Source code for composer.metrics.metrics

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A collection of common torchmetrics."""
from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.data import to_categorical

from composer.loss import soft_cross_entropy

__all__ = ['MIoU', 'Dice', 'CrossEntropy', 'LossMetric']


[docs]class MIoU(Metric): """Torchmetrics mean Intersection-over-Union (mIoU) implementation. IoU calculates the intersection area between the predicted class mask and the label class mask. The intersection is then divided by the area of the union of the predicted and label masks. This measures the quality of predicted class mask with respect to the label. The IoU for each class is then averaged and the final result is the mIoU score. Implementation is primarily based on `mmsegmentation <https://github.com/open-mmlab/mmsegmentation/blob/aa50358c71fe9c4cccdd2abe42433bdf702e757b/mmseg/core/evaluation/metrics.py#L132>`_ Args: num_classes (int): the number of classes in the segmentation task. ignore_index (int, optional): the index to ignore when computing mIoU. Default: ``-1``. """ # Make torchmetrics call update only once full_state_update = False def __init__(self, num_classes: int, ignore_index: int = -1): super().__init__(dist_sync_on_step=True) self.num_classes = num_classes self.ignore_index = ignore_index self.add_state('total_intersect', default=torch.zeros(num_classes, dtype=torch.float64), dist_reduce_fx='sum') self.add_state('total_union', default=torch.zeros(num_classes, dtype=torch.float64), dist_reduce_fx='sum')
[docs] def update(self, logits: Tensor, targets: Tensor): """Update the state with new predictions and targets.""" preds = logits.argmax(dim=1) for pred, target in zip(preds, targets): mask = (target != self.ignore_index) pred = pred[mask] target = target[mask] intersect = pred[pred == target] area_intersect = torch.histc(intersect.float(), bins=self.num_classes, min=0, max=self.num_classes - 1) area_prediction = torch.histc(pred.float(), bins=self.num_classes, min=0, max=self.num_classes - 1) area_target = torch.histc(target.float(), bins=self.num_classes, min=0, max=self.num_classes - 1) self.total_intersect += area_intersect self.total_union += area_prediction + area_target - area_intersect
[docs] def compute(self): """Aggregate state across all processes and compute final metric.""" total_intersect = self.total_intersect[self.total_union != 0] # type: ignore (third-party) total_union = self.total_union[self.total_union != 0] # type: ignore (third-party) return 100 * (total_intersect / total_union).mean()
[docs]class Dice(Metric): """The Dice Coefficient for evaluating image segmentation. The Dice Coefficient measures how similar predictions and targets are. More concretely, it is computed as 2 * the Area of Overlap divided by the total number of pixels in both images. Args: num_classes (int): the number of classes in the segmentation task. """ # Make torchmetrics call update only once full_state_update = False def __init__(self, num_classes: int): super().__init__(dist_sync_on_step=True) self.add_state('n_updates', default=torch.zeros(1), dist_reduce_fx='sum') self.add_state('dice', default=torch.zeros((num_classes,)), dist_reduce_fx='sum')
[docs] def update(self, preds: Tensor, targets: Tensor): """Update the state based on new predictions and targets.""" self.n_updates += 1 # type: ignore self.dice += self.compute_stats(preds, targets)
[docs] def compute(self): """Aggregate the state over all processes to compute the metric.""" dice = 100 * self.dice / self.n_updates # type: ignore best_sum_dice = dice[:] top_dice = round(torch.mean(best_sum_dice).item(), 2) return top_dice
@staticmethod def compute_stats(preds: Tensor, targets: Tensor): num_classes = preds.shape[1] scores = torch.zeros(num_classes - 1, device=preds.device, dtype=torch.float32) for i in range(1, num_classes): if (targets != i).all(): # no foreground class _, _pred = torch.max(preds, 1) scores[i - 1] += 1 if (_pred != i).all() else 0 continue _tp, _fp, _tn, _fn, _ = _stat_scores(preds, targets, class_index=i) # type: ignore denom = (2 * _tp + _fp + _fn).to(torch.float) score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0 scores[i - 1] += score_cls return scores
def _stat_scores( preds: Tensor, targets: Tensor, class_index: int, argmax_dim: int = 1, ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: if preds.ndim == targets.ndim + 1: preds = to_categorical(preds, argmax_dim=argmax_dim) tp = ((preds == class_index) * (targets == class_index)).to(torch.long).sum() fp = ((preds == class_index) * (targets != class_index)).to(torch.long).sum() tn = ((preds != class_index) * (targets != class_index)).to(torch.long).sum() fn = ((preds != class_index) * (targets == class_index)).to(torch.long).sum() sup = (targets == class_index).to(torch.long).sum() return tp, fp, tn, fn, sup
[docs]class CrossEntropy(Metric): """Torchmetrics cross entropy loss implementation. This class implements cross entropy loss as a :class:`torchmetrics.Metric` so that it can be returned by the :meth:`~.ComposerModel.metrics`. Args: ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. ``ignore_index`` is only applicable when the target contains class indices. Default: ``-100``. dist_sync_on_step (bool, optional): sync distributed metrics every step. Default: ``False``. """ # Make torchmetrics call update only once full_state_update = False def __init__(self, ignore_index: int = -100, dist_sync_on_step: bool = False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_index = ignore_index self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')
[docs] def update(self, preds: Tensor, targets: Tensor) -> None: """Update the state with new predictions and targets.""" # Loss calculated over samples/batch, accumulate loss over all batches self.sum_loss += soft_cross_entropy(preds, targets, ignore_index=self.ignore_index) assert isinstance(self.total_batches, Tensor) self.total_batches += 1
[docs] def compute(self) -> Tensor: """Aggregate state over all processes and compute the metric.""" # Return average loss over entire validation dataset assert isinstance(self.total_batches, Tensor) assert isinstance(self.sum_loss, Tensor) return self.sum_loss / self.total_batches
[docs]class LossMetric(Metric): """Turns a torch.nn Loss Module into distributed torchmetrics Metric. Args: loss_function (callable): loss function to compute and track. dist_sync_on_step (bool, optional): sync distributed metrics every step. Default: ``False``. """ # Make torchmetrics call update only once full_state_update = False def __init__(self, loss_function: Callable, dist_sync_on_step: bool = False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.loss_function = loss_function self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')
[docs] def update(self, preds: Tensor, targets: Tensor) -> None: """Update the state with new predictions and targets.""" # Loss calculated over samples/batch, accumulate loss over all batches self.sum_loss += self.loss_function(preds, targets) self.total_batches += 1 # type: ignore
[docs] def compute(self): """Aggregate state over all processes and compute the metric.""" # Return average loss over entire validation dataset assert isinstance(self.total_batches, Tensor) assert isinstance(self.sum_loss, Tensor) return self.sum_loss / self.total_batches