# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A collection of common torchmetrics."""

from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.data import to_categorical

from composer.loss import soft_cross_entropy

__all__ = ['MIoU', 'Dice', 'CrossEntropy', 'LossMetric']
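

# ``Dice.compute_stats`` below relies on a module-level ``_stat_scores`` helper whose definition is not included
# above. The sketch here is an assumption modeled on the per-class stat-scores helper from older torchmetrics
# releases (which would also explain the ``to_categorical`` import); it may differ from the original implementation.
def _stat_scores(
    preds: Tensor,
    targets: Tensor,
    class_index: int,
    argmax_dim: int = 1,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Count true/false positives and negatives, plus support, for a single class."""
    if preds.ndim == targets.ndim + 1:
        # Convert per-class scores of shape (N, C, ...) into hard class predictions of shape (N, ...).
        preds = to_categorical(preds, argmax_dim=argmax_dim)

    tp = ((preds == class_index) * (targets == class_index)).to(torch.long).sum()
    fp = ((preds == class_index) * (targets != class_index)).to(torch.long).sum()
    tn = ((preds != class_index) * (targets != class_index)).to(torch.long).sum()
    fn = ((preds != class_index) * (targets == class_index)).to(torch.long).sum()
    sup = (targets == class_index).to(torch.long).sum()

    return tp, fp, tn, fn, sup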


class MIoU(Metric):
    """Torchmetrics mean Intersection-over-Union (mIoU) implementation.

    IoU calculates the intersection area between the predicted class mask and the label class mask. The intersection
    is then divided by the area of the union of the predicted and label masks. This measures the quality of the
    predicted class mask with respect to the label. The IoU for each class is then averaged and the final result is
    the mIoU score.

    Implementation is primarily based on `mmsegmentation
    <https://github.com/open-mmlab/mmsegmentation/blob/aa50358c71fe9c4cccdd2abe42433bdf702e757b/mmseg/core/evaluation/metrics.py#L132>`_

    Args:
        num_classes (int): the number of classes in the segmentation task.
        ignore_index (int, optional): the index to ignore when computing mIoU. Default: ``-1``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, num_classes: int, ignore_index: int = -1):
        super().__init__(dist_sync_on_step=True)
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.add_state('total_intersect', default=torch.zeros(num_classes, dtype=torch.float64), dist_reduce_fx='sum')
        self.add_state('total_union', default=torch.zeros(num_classes, dtype=torch.float64), dist_reduce_fx='sum')

    def update(self, logits: Tensor, targets: Tensor):
        """Update the state with new predictions and targets."""
        preds = logits.argmax(dim=1)
        for pred, target in zip(preds, targets):
            mask = (target != self.ignore_index)
            pred = pred[mask]
            target = target[mask]
            intersect = pred[pred == target]
            area_intersect = torch.histc(intersect.float(), bins=self.num_classes, min=0, max=self.num_classes - 1)
            area_prediction = torch.histc(pred.float(), bins=self.num_classes, min=0, max=self.num_classes - 1)
            area_target = torch.histc(target.float(), bins=self.num_classes, min=0, max=self.num_classes - 1)
            assert isinstance(self.total_intersect, Tensor)
            assert isinstance(self.total_union, Tensor)
            self.total_intersect += area_intersect
            self.total_union += area_prediction + area_target - area_intersect

    def compute(self):
        """Aggregate state across all processes and compute final metric."""
        total_intersect = self.total_intersect[self.total_union != 0]  # type: ignore (third-party)
        total_union = self.total_union[self.total_union != 0]  # type: ignore (third-party)
        return 100 * (total_intersect / total_union).mean()
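
# Illustrative usage sketch for ``MIoU``, kept as a comment so importing this module has no side effects; the
# shapes below are assumptions for the example:
#
#     miou = MIoU(num_classes=3, ignore_index=-1)
#     logits = torch.randn(2, 3, 16, 16)          # (batch, num_classes, H, W) per-class scores
#     targets = torch.randint(0, 3, (2, 16, 16))  # (batch, H, W) class indices
#     miou.update(logits, targets)
#     score = miou.compute()                      # mean IoU over observed classes, as a percentage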


class Dice(Metric):
    """The Dice Coefficient for evaluating image segmentation.

    The Dice Coefficient measures how similar predictions and targets are. More concretely, it is computed as 2 * the
    Area of Overlap divided by the total number of pixels in both images.

    Args:
        num_classes (int): the number of classes in the segmentation task.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, num_classes: int):
        super().__init__(dist_sync_on_step=True)
        self.add_state('n_updates', default=torch.zeros(1), dist_reduce_fx='sum')
        self.add_state('dice', default=torch.zeros((num_classes,)), dist_reduce_fx='sum')

    def update(self, preds: Tensor, targets: Tensor):
        """Update the state based on new predictions and targets."""
        self.n_updates += 1  # type: ignore
        assert isinstance(self.dice, Tensor)
        self.dice += self.compute_stats(preds, targets)

    def compute(self):
        """Aggregate the state over all processes to compute the metric."""
        dice = 100 * self.dice / self.n_updates  # type: ignore
        best_sum_dice = dice[:]
        top_dice = round(torch.mean(best_sum_dice).item(), 2)
        return top_dice

    @staticmethod
    def compute_stats(preds: Tensor, targets: Tensor):
        """Compute the Dice score for each foreground class in the batch.

        A class that is absent from both the targets and the predictions receives a score of 1; a class that is
        absent from the targets but predicted anyway receives a score of 0.
        """
        num_classes = preds.shape[1]
        scores = torch.zeros(num_classes - 1, device=preds.device, dtype=torch.float32)
        for i in range(1, num_classes):
            if (targets != i).all():
                # no foreground class
                _, _pred = torch.max(preds, 1)
                scores[i - 1] += 1 if (_pred != i).all() else 0
                continue
            _tp, _fp, _tn, _fn, _ = _stat_scores(preds, targets, class_index=i)  # type: ignore
            denom = (2 * _tp + _fp + _fn).to(torch.float)
            score_cls = (2 * _tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
            scores[i - 1] += score_cls
        return scores
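
# Illustrative usage sketch for ``Dice``, kept as a comment so importing this module has no side effects; shapes are
# assumptions, and class 0 is treated as background:
#
#     dice = Dice(num_classes=3)
#     preds = torch.randn(2, 3, 16, 16)           # (batch, num_classes, H, W) per-class scores
#     targets = torch.randint(0, 3, (2, 16, 16))  # (batch, H, W) class indices
#     dice.update(preds, targets)
#     score = dice.compute()                      # mean Dice over foreground classes, in percent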


class CrossEntropy(Metric):
    """Torchmetrics cross entropy loss implementation.

    This class implements cross entropy loss as a :class:`torchmetrics.Metric` so that it can be returned by
    :meth:`~.ComposerModel.metrics`.

    Args:
        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
            gradient. ``ignore_index`` is only applicable when the target contains class indices. Default: ``-100``.
        dist_sync_on_step (bool, optional): sync distributed metrics every step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, ignore_index: int = -100, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.ignore_index = ignore_index
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, preds: Tensor, targets: Tensor) -> None:
        """Update the state with new predictions and targets."""
        # Loss calculated over samples/batch, accumulate loss over all batches
        assert isinstance(self.sum_loss, Tensor)
        self.sum_loss += soft_cross_entropy(preds, targets, ignore_index=self.ignore_index)
        assert isinstance(self.total_batches, Tensor)
        self.total_batches += 1

    def compute(self) -> Tensor:
        """Aggregate state over all processes and compute the metric."""
        # Return average loss over entire validation dataset
        assert isinstance(self.total_batches, Tensor)
        assert isinstance(self.sum_loss, Tensor)
        return self.sum_loss / self.total_batches
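
# Illustrative usage sketch for ``CrossEntropy``, kept as a comment so importing this module has no side effects;
# shapes are assumptions:
#
#     ce_metric = CrossEntropy(ignore_index=-100)
#     preds = torch.randn(8, 10)            # (batch, num_classes) logits
#     targets = torch.randint(0, 10, (8,))  # (batch,) class indices
#     ce_metric.update(preds, targets)
#     avg_loss = ce_metric.compute()        # average loss across all update() calls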


class LossMetric(Metric):
    """Turns a torch.nn loss module into a distributed torchmetrics Metric.

    Args:
        loss_function (callable): loss function to compute and track.
        dist_sync_on_step (bool, optional): sync distributed metrics every step. Default: ``False``.
    """

    # Make torchmetrics call update only once
    full_state_update = False

    def __init__(self, loss_function: Callable, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.loss_function = loss_function
        self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('total_batches', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, preds: Tensor, targets: Tensor) -> None:
        """Update the state with new predictions and targets."""
        # Loss calculated over samples/batch, accumulate loss over all batches
        self.sum_loss += self.loss_function(preds, targets)
        self.total_batches += 1  # type: ignore

    def compute(self):
        """Aggregate state over all processes and compute the metric."""
        # Return average loss over entire validation dataset
        assert isinstance(self.total_batches, Tensor)
        assert isinstance(self.sum_loss, Tensor)
        return self.sum_loss / self.total_batches
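
# Illustrative usage sketch for ``LossMetric`` wrapping an arbitrary loss callable, kept as a comment so importing
# this module has no side effects; the loss function and shapes are assumptions:
#
#     mse_metric = LossMetric(torch.nn.functional.mse_loss)
#     preds = torch.randn(8, 4)
#     targets = torch.randn(8, 4)
#     mse_metric.update(preds, targets)
#     avg_loss = mse_metric.compute()     # average loss across batches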