# Copyright 2021 MosaicML. All Rights Reserved.
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.data import to_categorical
if TYPE_CHECKING:
from composer.core.types import Tensor
class Dice(Metric):
    """The Dice Coefficient for evaluating image segmentation.

    The Dice Coefficient measures how similar predictions and targets are.
    More concretely, it is computed as 2 * the Area of Overlap divided by
    the total number of pixels in both images.

    Args:
        nclass (int): number of per-class dice scores accumulated per update.
            NOTE(review): ``compute_stats`` returns ``num_classes - 1`` scores
            (class 0 is skipped as background), so ``nclass`` presumably must
            equal the number of *foreground* classes — confirm against callers.
    """

    def __init__(self, nclass):
        super().__init__(dist_sync_on_step=True)
        # Number of update() calls; denominator when averaging in compute().
        self.add_state("n_updates", default=torch.zeros(1), dist_reduce_fx="sum")
        # Running sum of per-class dice scores across updates.
        self.add_state("dice", default=torch.zeros((nclass,)), dist_reduce_fx="sum")

    def update(self, pred, target):
        """Update the state based on new predictions and targets."""
        self.n_updates += 1  # type: ignore
        self.dice += self.compute_stats(pred, target)

    def compute(self):
        """Aggregate the state over all processes to compute the metric."""
        # Average per-class dice over all updates, expressed as a percentage.
        dice = 100 * self.dice / self.n_updates  # type: ignore
        # Mean over classes, rounded to two decimals; returns a Python float.
        return round(torch.mean(dice).item(), 2)

    @staticmethod
    def compute_stats(pred, target):
        """Compute per-foreground-class dice scores for one batch.

        Args:
            pred: class scores with the class dimension at axis 1
                (shape ``(N, num_classes, ...)`` — assumed from ``pred.shape[1]``
                and ``torch.max(pred, 1)``; TODO confirm spatial dims).
            target: integer class labels comparable elementwise against the
                argmax of ``pred``.

        Returns:
            Float tensor of shape ``(num_classes - 1,)``, one score per
            foreground class (class 0 is treated as background and skipped).
        """
        num_classes = pred.shape[1]
        scores = torch.zeros(num_classes - 1, device=pred.device, dtype=torch.float32)
        for i in range(1, num_classes):
            if (target != i).all():
                # Class absent from the target: perfect score (1) if it is
                # also absent from the prediction, else 0.
                _, _pred = torch.max(pred, 1)
                scores[i - 1] += 1 if (_pred != i).all() else 0
                continue
            _tp, _fp, _tn, _fn, _ = _stat_scores(pred, target, class_index=i)  # type: ignore
            denom = (2 * _tp + _fp + _fn).to(torch.float)
            # Guard against a zero denominator (no positives on either side).
            score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
            scores[i - 1] += score_cls
        return scores
def _stat_scores(
preds: Tensor,
target: Tensor,
class_index: int,
argmax_dim: int = 1,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
if preds.ndim == target.ndim + 1:
preds = to_categorical(preds, argmax_dim=argmax_dim)
tp = ((preds == class_index) * (target == class_index)).to(torch.long).sum()
fp = ((preds == class_index) * (target != class_index)).to(torch.long).sum()
tn = ((preds != class_index) * (target != class_index)).to(torch.long).sum()
fn = ((preds != class_index) * (target == class_index)).to(torch.long).sum()
sup = (target == class_index).to(torch.long).sum()
return tp, fp, tn, fn, sup
def _infer_target_type(input: Tensor, targets: Tensor) -> str:
"""Infers whether the target is in indices format or one_hot format.
Example indices format: [1, 4, 7]
Example one_hot format [[0, 1, 0], [1, 0, 0], ...]
"""
if input.shape == targets.shape:
return 'one_hot'
elif input.ndim == targets.ndim + 1:
return 'indices'
else:
raise RuntimeError(f'Unable to infer indices or one_hot. Targets has shape {targets.shape}'
f' and the inputs to cross entropy has shape {input.shape}. For one_hot, '
'expect targets.shape == inputs.shape. For indices, expect '
'inputs.ndim == targets.ndim + 1')
def ensure_targets_one_hot(input: Tensor, targets: Tensor) -> Tensor:
    """Convert index-format targets to one-hot; pass one-hot targets through unchanged."""
    target_format = _infer_target_type(input, targets)
    if target_format == 'indices':
        # Width of the one-hot encoding comes from the class dimension of input.
        return F.one_hot(targets, num_classes=input.shape[1])
    return targets
def check_for_index_targets(targets: Tensor) -> bool:
    """Checks if a given set of targets are indices by looking at the type.

    Index-format targets are integer (``torch.long``) class labels. Comparing
    the dtype directly matches the tensors the original CPU/CUDA type-string
    check matched, and additionally works on any other device (e.g. MPS, XLA).
    """
    return targets.dtype == torch.long
def soft_cross_entropy(input: Tensor,
                       target: Tensor,
                       weight: Optional[Tensor] = None,
                       size_average: Optional[bool] = None,
                       ignore_index: int = -100,
                       reduce: Optional[bool] = None,
                       reduction: str = 'mean'):
    """Drop-in replacement for ``torch.CrossEntropy`` that can handle dense labels.

    This function will be obsolete with
    `this update <https://github.com/pytorch/pytorch/pull/61044>`_.
    """
    target_type = _infer_target_type(input, target)

    if target_type == 'indices':
        # Hard integer labels: defer entirely to the built-in implementation.
        return F.cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)

    if target_type == 'one_hot':
        assert reduction in ['sum', 'mean', 'none'], f"{reduction} reduction not supported."
        assert size_average is None, "size_average is deprecated"
        assert reduce is None, "reduce is deprecated"
        assert ignore_index == -100, "ignore_index not supported."
        # Elementwise negative log-likelihood weighted by the dense targets.
        log_probs = F.log_softmax(input, dim=1)
        losses = -1 * (target * log_probs)
        if weight is not None:
            # Normalized class weights; broadcasts along the batch dimension.
            losses *= weight / weight.sum()
        # Collapse the class dimension to one loss per sample.
        losses = losses.sum(dim=1)
        if reduction == 'sum':
            return losses.sum(dim=0)
        if reduction == 'mean':
            return losses.mean(dim=0)
        return losses

    raise ValueError(f"Unrecognized target type {target_type}")
class CrossEntropyLoss(Metric):
    """Torchmetric cross entropy loss implementation.

    This class implements cross entropy loss as a `torchmetric` so that
    it can be returned by the :meth:`~composer.models.BaseMosaicModel.metric`
    function in :class:`BaseMosaicModel`.
    """

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Running sum of per-batch losses; summed across processes on sync.
        self.add_state("sum_loss", default=torch.tensor(0.), dist_reduce_fx="sum")
        # Number of batches seen; denominator for the final average.
        self.add_state("total_batches", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update the state with new predictions and targets."""
        # Loss calculated over samples/batch, accumulate loss over all batches
        self.sum_loss += soft_cross_entropy(preds, target)
        assert isinstance(self.total_batches, Tensor)
        self.total_batches += 1

    def compute(self) -> Tensor:
        """Aggregate state over all processes and compute the metric."""
        # Return average loss over entire validation dataset
        assert isinstance(self.total_batches, Tensor)
        assert isinstance(self.sum_loss, Tensor)
        return self.sum_loss / self.total_batches