# Source code for composer.algorithms.mixup.mixup

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Core MixUp classes and functions."""

from __future__ import annotations

import logging
from typing import Any, Callable, Optional, Union

import numpy as np
import torch

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.loss.utils import ensure_targets_one_hot

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Public names of this module (what ``from ... import *`` exports).
__all__ = ['MixUp', 'mixup_batch']


def mixup_batch(
    input: torch.Tensor,
    target: torch.Tensor,
    mixing: Optional[float] = None,
    alpha: float = 0.2,
    indices: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, float]:
    """Create new samples using convex combinations of pairs of samples.

    This is done by taking a convex combination of ``input`` with a randomly
    permuted copy of ``input``. The permutation takes place along the sample
    axis (``dim=0``).

    The relative weight of the original ``input`` versus the permuted copy is
    defined by the ``mixing`` parameter; the result is
    ``(1 - mixing) * input + mixing * input[indices]``. When ``mixing`` is not
    provided it is sampled as ``min(z, 1 - z)`` with ``z ~ Beta(alpha, alpha)``,
    so it lies in ``[0, 0.5]`` and the original sample keeps at least half of
    the weight. Note that the same ``mixing`` is used for the whole batch.

    Args:
        input (torch.Tensor): input tensor of shape ``(minibatch, ...)``, where
            ``...`` indicates zero or more dimensions.
        target (torch.Tensor): target tensor of shape ``(minibatch, ...)``,
            where ``...`` indicates zero or more dimensions.
        mixing (float, optional): coefficient used to interpolate between the
            two examples. If provided, must be in :math:`[0, 1]`. If ``None``,
            the value is sampled as described above. Default: ``None``.
        alpha (float, optional): parameter for the Beta distribution over
            ``mixing``. Ignored if ``mixing`` is provided. Default: ``0.2``.
        indices (torch.Tensor, optional): Permutation of the samples to use.
            If ``None``, a random permutation is generated. Default: ``None``.

    Returns:
        input_mixed (torch.Tensor): batch of inputs after mixup has been applied
        target_perm (torch.Tensor): The labels of the mixed-in examples
        mixing (float): the amount of mixing used

    Example:
        .. testcode::

            import torch
            from composer.functional import mixup_batch

            N, C, H, W = 2, 3, 4, 5
            num_classes = 10
            X = torch.randn(N, C, H, W)
            y = torch.randint(num_classes, size=(N,))
            X_mixed, y_perm, mixing = mixup_batch(
                X,
                y,
                alpha=0.2,
            )
    """
    if mixing is None:
        mixing = _gen_mixing_coef(alpha)

    # Create permuted versions of x and y in preparation for interpolation.
    # Caller-supplied indices take precedence so the caller controls (and can
    # record) which samples get paired.
    if indices is None:
        permuted_idx = _gen_indices(input.shape[0])
    else:
        permuted_idx = indices

    x_permuted = input[permuted_idx]
    permuted_target = target[permuted_idx]

    # Interpolate between the inputs; targets are returned unmixed so the caller
    # can interpolate labels or losses as it prefers.
    x_mixup = (1 - mixing) * input + mixing * x_permuted

    return x_mixup, permuted_target, mixing
class MixUp(Algorithm):
    """`MixUp <https://arxiv.org/abs/1710.09412>`_ trains the network on convex batch combinations.

    Every batch ``X`` is blended with a randomly permuted copy of itself, and the
    targets (or, optionally, the losses) are blended with the same weight. The
    mixing coefficient comes from a ``Beta(alpha, alpha)`` sample that is folded
    so it never exceeds ``0.5``. Training in this fashion sometimes reduces
    generalization error.

    Args:
        alpha (float, optional): the pseudocount for the Beta distribution used to
            sample mixing parameters. As ``alpha`` grows, the two samples in each
            pair tend to be weighted more equally. As ``alpha`` approaches 0 from
            above, the combination approaches only using one element of the pair.
            Default: ``0.2``.
        interpolate_loss (bool, optional): Interpolates the loss rather than the
            labels. A useful trick when using a cross entropy loss. Will produce
            incorrect behavior if the loss is not a linear function of the
            targets. Default: ``False``
        input_key (str | int | tuple[Callable, Callable] | Any, optional): A key
            that indexes to the input from the batch. Can also be a pair of get
            and set functions, where the getter is assumed to be first in the
            pair. The default is 0, which corresponds to any sequence, where the
            first element is the input. Default: ``0``.
        target_key (str | int | tuple[Callable, Callable] | Any, optional): A key
            that indexes to the target from the batch. Can also be a pair of get
            and set functions, where the getter is assumed to be first in the
            pair. The default is 1, which corresponds to any sequence, where the
            second element is the target. Default: ``1``.

    Example:
        .. testcode::

            from composer.algorithms import MixUp
            algorithm = MixUp(alpha=0.2)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[algorithm],
                optimizers=[optimizer]
            )
    """

    def __init__(
        self,
        alpha: float = 0.2,
        interpolate_loss: bool = False,
        input_key: Union[str, int, tuple[Callable, Callable], Any] = 0,
        target_key: Union[str, int, tuple[Callable, Callable], Any] = 1,
    ):
        self.alpha = alpha
        self.interpolate_loss = interpolate_loss
        self.input_key, self.target_key = input_key, target_key
        # Per-batch bookkeeping: written on Event.BEFORE_FORWARD, then read back on
        # Event.BEFORE_LOSS or Event.BEFORE_BACKWARD for the same batch.
        self.mixing = 0.0
        self.indices = torch.Tensor()
        self.permuted_target = torch.Tensor()

    def match(self, event: Event, state: State) -> bool:
        # Inputs are always mixed before the forward pass. The second hook is either
        # before the loss (mix targets) or before backward (mix the loss itself).
        follow_up = Event.BEFORE_BACKWARD if self.interpolate_loss else Event.BEFORE_LOSS
        return event in [Event.BEFORE_FORWARD, follow_up]

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        batch_input = state.batch_get_item(key=self.input_key)
        batch_target = state.batch_get_item(key=self.target_key)

        # The three events are distinct, so at most one branch runs per call.
        if event == Event.BEFORE_FORWARD:
            self._mix_input(state, batch_input, batch_target)
        elif event == Event.BEFORE_LOSS and not self.interpolate_loss:
            self._mix_target(state, batch_target)
        elif event == Event.BEFORE_BACKWARD and self.interpolate_loss:
            self._mix_loss(state, batch_input)

    def _mix_input(self, state: State, batch_input: Any, batch_target: Any) -> None:
        """Draw a fresh coefficient and permutation, then write the mixed input into the batch."""
        if not isinstance(batch_input, torch.Tensor):
            raise NotImplementedError('Multiple tensors for inputs not supported yet.')
        if not isinstance(batch_target, torch.Tensor):
            raise NotImplementedError('Multiple tensors for targets not supported yet.')

        self.mixing = _gen_mixing_coef(self.alpha)
        self.indices = _gen_indices(batch_input.shape[0])

        mixed_input, self.permuted_target, _ = mixup_batch(
            batch_input,
            batch_target,
            mixing=self.mixing,
            indices=self.indices,
        )
        state.batch_set_item(self.input_key, mixed_input)

    def _mix_target(self, state: State, batch_target: Any) -> None:
        """Replace the batch target with the mixing-weighted blend of dense targets."""
        if not isinstance(state.outputs, torch.Tensor):
            raise NotImplementedError('Multiple output tensors not supported yet')
        if not isinstance(batch_target, torch.Tensor):
            raise NotImplementedError('Multiple target tensors not supported yet')

        # Both targets must be dense/one-hot before they can be blended.
        dense_target = ensure_targets_one_hot(state.outputs, batch_target)
        dense_permuted = ensure_targets_one_hot(state.outputs, self.permuted_target)

        blended = (1 - self.mixing) * dense_target + self.mixing * dense_permuted
        state.batch_set_item(self.target_key, blended)

    def _mix_loss(self, state: State, batch_input: Any) -> None:
        """Replace ``state.loss`` with the mixing-weighted blend of original and permuted losses."""
        loss_fn = self._resolve_loss_fn(state)

        # NOTE(review): the loss receives a plain ``(input, target)`` tuple. That
        # presumably matches the batch layout only for the default keys 0 / 1; confirm
        # before using custom input_key/target_key with interpolate_loss.
        permuted_loss = loss_fn(state.outputs, (batch_input, self.permuted_target))

        if not isinstance(state.loss, torch.Tensor):
            raise NotImplementedError('Multiple losses not supported yet')
        if not isinstance(permuted_loss, torch.Tensor):
            raise NotImplementedError('Multiple losses not supported yet')

        state.loss = (1 - self.mixing) * state.loss + self.mixing * permuted_loss

    @staticmethod
    def _resolve_loss_fn(state: State) -> Callable:
        """Find the model's loss via ``model.loss`` or, for wrapped models, ``model.module.loss``."""
        model = state.model
        if hasattr(model, 'loss'):
            loss_fn = model.loss
        elif hasattr(model, 'module') and hasattr(model.module, 'loss'):
            if not isinstance(model.module, torch.nn.Module):
                raise TypeError('state.model.module must be a torch module')
            loss_fn = model.module.loss
        else:
            raise AttributeError('Loss must be accesable via model.loss or model.module.loss')

        if not callable(loss_fn):
            raise TypeError('Loss must be callable')
        return loss_fn
def _gen_mixing_coef(alpha: float) -> float: """Samples ``max(z, 1-z), z ~ Beta(alpha, alpha)``.""" # First check if alpha is positive. assert alpha >= 0 # Draw the mixing parameter from a beta distribution. # Check here is needed because beta distribution requires alpha > 0 # but alpha = 0 is fine for mixup. if alpha == 0: mixing_lambda = 0 else: mixing_lambda = np.random.beta(alpha, alpha) # for symmetric beta distribution, can always use 0 <= lambda <= .5; # this way the "main" label is always the original one, which keeps # the training accuracy meaningful return min(mixing_lambda, 1. - mixing_lambda) def _gen_indices(num_samples: int) -> torch.Tensor: """Generates a random permutation of the batch indices.""" return torch.randperm(num_samples)