# Copyright 2021 MosaicML. All Rights Reserved.

"""Core MixUp classes and functions."""

from __future__ import annotations

import logging
from typing import Optional, Tuple

import numpy as np
import torch

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.loss.utils import ensure_targets_one_hot

log = logging.getLogger(__name__)

__all__ = ["MixUp", "mixup_batch"]

def mixup_batch(input: torch.Tensor,
                target: torch.Tensor,
                mixing: Optional[float] = None,
                alpha: float = 0.2,
                indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """Create new samples using convex combinations of pairs of samples.

    This is done by taking a convex combination of ``input`` with a permuted
    copy of ``input``. The permutation takes place along the sample axis
    (dim 0).

    The relative weight of the original ``input`` versus the permuted copy is
    defined by the ``mixing`` parameter: the result is
    ``(1 - mixing) * input + mixing * input[indices]``. This parameter should
    be chosen from a ``Beta(alpha, alpha)`` distribution for some parameter
    ``alpha > 0``. Note that the same ``mixing`` is used for the whole batch.

    Only the inputs are interpolated here. The targets are merely permuted and
    returned, so the caller decides how to combine them with the original
    targets (or how to combine the corresponding losses).

    Args:
        input (torch.Tensor): input tensor of shape ``(minibatch, ...)``, where
            ``...`` indicates zero or more dimensions.
        target (torch.Tensor): target tensor of shape ``(minibatch, ...)``, where
            ``...`` indicates zero or more dimensions.
        mixing (float, optional): coefficient used to interpolate
            between the two examples. If provided, must be in :math:`[0, 1]`
            (this is not checked). If ``None``, value is drawn from a
            ``Beta(alpha, alpha)`` distribution. Default: ``None``.
        alpha (float, optional): parameter for the Beta distribution over
            ``mixing``. Ignored if ``mixing`` is provided. Default: ``0.2``.
        indices (Tensor, optional): Permutation of the samples to use. If
            ``None``, a random permutation is generated. Default: ``None``.

    Returns:
        input_mixed (torch.Tensor): batch of inputs after mixup has been applied
        target_perm (torch.Tensor): The labels of the mixed-in examples
        mixing (float): the amount of mixing used

    Example:
        .. testcode::

            import torch
            from composer.functional import mixup_batch

            N, C, H, W = 2, 3, 4, 5
            num_classes = 10
            X = torch.randn(N, C, H, W)
            y = torch.randint(num_classes, size=(N,))
            X_mixed, y_perm, mixing = mixup_batch(
                X, y, alpha=0.2)
    """
    # ``is None`` (not truthiness) so that an explicit ``mixing=0.0`` is honored.
    if mixing is None:
        mixing = _gen_mixing_coef(alpha)

    # Create permuted versions of x and y in preparation for interpolation.
    # Use the given indices if there are any.
    if indices is None:
        permuted_idx = _gen_indices(input.shape[0])
    else:
        permuted_idx = indices
    x_permuted = input[permuted_idx]
    permuted_target = target[permuted_idx]

    # Interpolate between the inputs
    x_mixup = (1 - mixing) * input + mixing * x_permuted

    return x_mixup, permuted_target, mixing


class MixUp(Algorithm):
    """`MixUp <https://arxiv.org/abs/1710.09412>`_ trains the network on convex combinations of pairs of examples and
    targets rather than individual examples and targets.

    This is done by taking a convex combination of a given batch X with a
    randomly permuted copy of X. The mixing coefficient is drawn from a
    ``Beta(alpha, alpha)`` distribution.

    Training in this fashion sometimes reduces generalization error.

    Args:
        alpha (float, optional): the pseudocount for the Beta distribution used to sample
            mixing parameters. As ``alpha`` grows, the two samples
            in each pair tend to be weighted more equally. As ``alpha``
            approaches 0 from above, the combination approaches only using
            one element of the pair. Default: ``0.2``.
        interpolate_loss (bool, optional): Interpolates the loss rather than the labels.
            A useful trick when using a cross entropy loss. Will produce incorrect behavior
            if the loss is not a linear function of the targets. Default: ``False``

    Example:
        .. testcode::

            from composer.algorithms import MixUp
            algorithm = MixUp(alpha=0.2)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[algorithm],
                optimizers=[optimizer]
            )
    """

    def __init__(self, alpha: float = 0.2, interpolate_loss: bool = False):
        self.alpha = alpha
        self.interpolate_loss = interpolate_loss
        # Per-batch state: written at BEFORE_FORWARD and read back at
        # BEFORE_LOSS / BEFORE_BACKWARD of the same batch. The values below are
        # placeholders until the first batch is processed.
        self.mixing = 0.0
        self.indices = torch.Tensor()
        self.permuted_target = torch.Tensor()

    def match(self, event: Event, state: State) -> bool:
        """Select the events this algorithm runs on.

        The inputs are always mixed at ``BEFORE_FORWARD``. The second event
        depends on the mode: ``BEFORE_BACKWARD`` when the loss is interpolated,
        ``BEFORE_LOSS`` when the targets are interpolated.
        """
        if self.interpolate_loss:
            return event in [Event.BEFORE_FORWARD, Event.BEFORE_BACKWARD]
        else:
            return event in [Event.BEFORE_FORWARD, Event.BEFORE_LOSS]

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        """Mix the batch inputs, then either the targets or the loss.

        Raises:
            NotImplementedError: if the inputs, targets, outputs or losses are
                not single tensors.
            AttributeError: in loss-interpolation mode, if no loss is found on
                ``state.model`` or ``state.model.module``.
            TypeError: in loss-interpolation mode, if the loss is not callable
                or ``state.model.module`` is not a torch module.
        """
        input, target = state.batch_pair

        if event == Event.BEFORE_FORWARD:
            if not isinstance(input, torch.Tensor):
                raise NotImplementedError("Multiple tensors for inputs not supported yet.")
            if not isinstance(target, torch.Tensor):
                raise NotImplementedError("Multiple tensors for targets not supported yet.")

            # Sample the coefficient and permutation here (rather than inside
            # mixup_batch) so they can be stored and reused by the later event.
            self.mixing = _gen_mixing_coef(self.alpha)
            self.indices = _gen_indices(input.shape[0])

            new_input, self.permuted_target, _ = mixup_batch(
                input,
                target,
                mixing=self.mixing,
                indices=self.indices,
            )

            # Only the inputs are replaced at this point; the original targets
            # stay in the batch until the later event handles them.
            state.batch = (new_input, target)

        if not self.interpolate_loss and event == Event.BEFORE_LOSS:
            # Interpolate the targets
            if not isinstance(state.outputs, torch.Tensor):
                raise NotImplementedError("Multiple output tensors not supported yet")
            if not isinstance(target, torch.Tensor):
                raise NotImplementedError("Multiple target tensors not supported yet")
            # Make sure that the targets are dense/one-hot
            target = ensure_targets_one_hot(state.outputs, target)
            permuted_target = ensure_targets_one_hot(state.outputs, self.permuted_target)
            # Interpolate to get the new target
            mixed_up_target = (1 - self.mixing) * target + self.mixing * permuted_target
            # Create the new batch
            state.batch = (input, mixed_up_target)

        if self.interpolate_loss and event == Event.BEFORE_BACKWARD:
            # Grab the loss function, looking first on the model itself and
            # then on a wrapped ``.module`` attribute.
            if hasattr(state.model, "loss"):
                loss_fn = state.model.loss
            elif hasattr(state.model, "module") and hasattr(state.model.module, "loss"):
                if isinstance(state.model.module, torch.nn.Module):
                    loss_fn = state.model.module.loss
                else:
                    raise TypeError("state.model.module must be a torch module")
            else:
                raise AttributeError("Loss must be accesable via model.loss or model.module.loss")
            # Verify that the loss is callable
            if not callable(loss_fn):
                raise TypeError("Loss must be callable")
            # Interpolate the loss: evaluate the loss of the same outputs
            # against the permuted targets, then blend it with the existing
            # loss using the stored mixing coefficient.
            new_loss = loss_fn(state.outputs, (input, self.permuted_target))
            if not isinstance(state.loss, torch.Tensor):
                raise NotImplementedError("Multiple losses not supported yet")
            if not isinstance(new_loss, torch.Tensor):
                raise NotImplementedError("Multiple losses not supported yet")
            state.loss = (1 - self.mixing) * state.loss + self.mixing * new_loss


def _gen_mixing_coef(alpha: float) -> float:
    """Samples ``min(z, 1-z), z ~ Beta(alpha, alpha)``.

    Args:
        alpha (float): parameter of the symmetric Beta distribution. Must be
            non-negative; ``alpha == 0`` disables mixing.

    Returns:
        float: mixing coefficient in ``[0, 0.5]``; ``0.0`` when ``alpha == 0``.

    Raises:
        ValueError: if ``alpha`` is negative (or NaN).
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    # Written as ``not alpha >= 0`` so that NaN is rejected as well.
    if not alpha >= 0:
        raise ValueError(f"alpha must be non-negative, got {alpha}")

    # The beta distribution requires alpha > 0, but alpha = 0 is fine for
    # mixup: it simply means "no mixing".
    if alpha == 0:
        return 0.0

    # Draw the mixing parameter from a beta distribution.
    mixing_lambda = np.random.beta(alpha, alpha)

    # for symmetric beta distribution, can always use 0 <= lambda <= .5;
    # this way the "main" label is always the original one, which keeps
    # the training accuracy meaningful
    return float(min(mixing_lambda, 1. - mixing_lambda))


def _gen_indices(num_samples: int) -> torch.Tensor:
    """Generates a random permutation of the batch indices.

    Args:
        num_samples (int): number of samples in the batch.

    Returns:
        torch.Tensor: a random permutation of ``0 .. num_samples - 1``.
    """
    return torch.randperm(num_samples)