# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Core Label Smoothing classes and functions."""

from __future__ import annotations

from typing import Any, Callable, Optional, Union

import torch

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.loss.utils import ensure_targets_one_hot

__all__ = ['LabelSmoothing', 'smooth_labels']


def smooth_labels(logits: torch.Tensor, target: torch.Tensor, smoothing: float = 0.1):
    """Shrink targets towards a uniform distribution as in `Szegedy et al <https://arxiv.org/abs/1512.00567>`_.

    The smoothed labels are computed as ``(1 - smoothing) * targets + smoothing * unif`` where ``unif``
    is a vector with elements all equal to ``1 / num_classes``.

    Args:
        logits (torch.Tensor): predicted value for ``target``, or any other tensor with the same shape.
            Shape must be ``(N, num_classes, ...)`` for ``N`` examples and ``num_classes`` classes with
            any number of optional extra dimensions.
        target (torch.Tensor): target tensor of either shape ``N`` or ``(N, num_classes, ...)``. In the
            former case, elements of ``target`` must be integer class ids in the range
            ``0..num_classes - 1``. In the latter case, ``target`` must have the same shape as ``logits``.
        smoothing (float, optional): strength of the label smoothing, in :math:`[0, 1]`. ``smoothing=0``
            means no label smoothing, and ``smoothing=1`` means maximal smoothing (targets are ignored).
            Default: ``0.1``.

    Returns:
        torch.Tensor: The smoothed targets.

    Example:
        .. testcode::

            import torch

            num_classes = 10
            targets = torch.randint(num_classes, size=(100,))
            logits = torch.randn(100, num_classes)

            from composer.algorithms.label_smoothing import smooth_labels
            new_targets = smooth_labels(logits=logits, target=targets, smoothing=0.1)
    """
    target = ensure_targets_one_hot(logits, target)
    n_classes = logits.shape[1]
    return (target * (1. - smoothing)) + (smoothing / n_classes)
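
# A worked example (a sketch for illustration only, not part of the library):
# with ``num_classes = 4`` and ``smoothing = 0.1``, integer class ids are first
# expanded to one-hot rows by ``ensure_targets_one_hot``, then each row is
# scaled by ``1 - smoothing`` and shifted by ``smoothing / num_classes``:
#
#     import torch
#
#     logits = torch.randn(2, 4)       # shape (N, num_classes)
#     target = torch.tensor([1, 3])    # integer class ids
#     smooth_labels(logits, target, smoothing=0.1)
#     # tensor([[0.0250, 0.9250, 0.0250, 0.0250],
#     #         [0.0250, 0.0250, 0.0250, 0.9250]])
#
# The true class gets (1 - 0.1) * 1 + 0.1 / 4 = 0.925; every other class gets
# 0.1 / 4 = 0.025, so each row still sums to 1.
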
class LabelSmoothing(Algorithm):
    """Shrink targets towards a uniform distribution as in `Szegedy et al <https://arxiv.org/abs/1512.00567>`_.

    The smoothed labels are computed as ``(1 - smoothing) * targets + smoothing * unif`` where ``unif``
    is a vector with elements all equal to ``1 / num_classes``.

    Args:
        smoothing: Strength of the label smoothing, in :math:`[0, 1]`. ``smoothing=0`` means no label
            smoothing, and ``smoothing=1`` means maximal smoothing (targets are ignored).
            Default: ``0.1``.
        target_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the
            target from the batch. Can also be a pair of get and set functions, where the getter is
            assumed to be first in the pair. The default is ``1``, which corresponds to any sequence
            where the second element is the target. Default: ``1``.

    Example:
        .. testcode::

            from composer.algorithms import LabelSmoothing

            algorithm = LabelSmoothing(smoothing=0.1)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[algorithm],
                optimizers=[optimizer],
            )
    """

    def __init__(
        self,
        smoothing: float = 0.1,
        target_key: Union[str, int, tuple[Callable, Callable], Any] = 1,
    ):
        self.smoothing = smoothing
        self.original_labels = torch.Tensor()
        self.target_key = target_key

    def match(self, event: Event, state: State) -> bool:
        return event in [Event.BEFORE_LOSS, Event.AFTER_LOSS]

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        labels = state.batch_get_item(self.target_key)

        if event == Event.BEFORE_LOSS:
            assert isinstance(state.outputs, torch.Tensor), 'Multiple tensors not supported yet'
            assert isinstance(labels, torch.Tensor), 'Multiple tensors not supported yet'

            # Stash the original labels so they can be restored after the loss is computed.
            self.original_labels = labels.clone()
            smoothed_labels = smooth_labels(
                state.outputs,
                labels,
                smoothing=self.smoothing,
            )
            state.batch_set_item(self.target_key, smoothed_labels)
        elif event == Event.AFTER_LOSS:
            # Restore the target to the non-smoothed version.
            state.batch_set_item(self.target_key, self.original_labels)
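
# A minimal usage sketch, not part of the library. ``LabelSmoothing`` swaps the
# batch targets for their smoothed version on ``Event.BEFORE_LOSS`` and puts the
# originals back on ``Event.AFTER_LOSS``, so anything that runs after the loss
# (e.g. metrics) still sees the hard labels. For a dict-style batch,
# ``target_key`` can be a (getter, setter) pair instead of an index; the batch
# layout below and the setter returning the updated batch are assumptions of
# this sketch, not guarantees of the Composer API.
if __name__ == '__main__':

    def get_labels(batch):
        # Hypothetical batch layout: {'inputs': ..., 'labels': ...}
        return batch['labels']

    def set_labels(batch, value):
        batch['labels'] = value
        return batch

    algorithm = LabelSmoothing(smoothing=0.1, target_key=(get_labels, set_labels))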