# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Core CutMix classes and functions."""
from __future__ import annotations
import logging
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
from torch import Tensor
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.loss.utils import ensure_targets_one_hot
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Names exported as this module's public API.
__all__ = ['CutMix', 'cutmix_batch']
def cutmix_batch(
    input: Tensor,
    target: Tensor,
    length: Optional[float] = None,
    alpha: float = 1.,
    bbox: Optional[tuple] = None,
    indices: Optional[torch.Tensor] = None,
    uniform_sampling: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, float, tuple]:
    """Create new samples using combinations of pairs of samples.

    This is done by masking a region of each image in ``input`` and filling
    the masked region with the corresponding content from a random different
    image in ``input``.

    The position of the masked region is determined by drawing a center point
    uniformly at random from all spatial positions.

    The area of the masked region is computed using either ``length`` or
    ``alpha``. If ``length`` is provided, it directly determines the size
    of the masked region. If it is not provided, the fraction of the input
    area to mask is drawn from a ``Beta(alpha, alpha)`` distribution.
    The original paper uses a fixed value of ``alpha = 1``.

    Alternatively, one may provide a bounding box to mask directly, in
    which case ``alpha`` is ignored (no random area is drawn at all) and
    ``length`` must not be provided.

    The same masked region is used for the whole batch.

    .. note::

        The masked region is clipped at the spatial boundaries of the inputs.
        This means that there is no padding required, but the actual region
        used may be smaller than the nominal size computed using ``length``
        or ``alpha``. Integer truncation of the box side lengths can also
        make the region slightly smaller than the nominal size.

    Args:
        input (torch.Tensor): input tensor of shape ``(N, C, H, W)``.
        target (torch.Tensor): target tensor of either shape ``N`` or
            ``(N, num_classes)``. In the former case, elements of ``target``
            must be integer class ids in the range ``0..num_classes``. In the
            latter case, rows of ``target`` may be arbitrary vectors of targets,
            including, e.g., one-hot encoded class labels, smoothed class
            labels, or multi-output regression targets.
        length (float, optional): Relative side length of the masked region.
            If specified, ``length`` is interpreted as a fraction of ``H`` and
            ``W``, and the resulting box is nominally of size
            ``(length * H, length * W)``. Default: ``None``.
        alpha (float, optional): parameter for the Beta distribution over
            the fraction of the input to mask. Ignored if ``length`` or
            ``bbox`` is provided. Default: ``1``.
        bbox (tuple, optional): predetermined ``(x1, y1, x2, y2)``
            coordinates of the bounding box. The mixed region is
            ``input[:, :, x1:x2, y1:y2]``, i.e. the ``x`` range slices
            dimension 2 and the ``y`` range slices dimension 3.
            Default: ``None``.
        indices (torch.Tensor, optional): Permutation of the samples to use.
            Default: ``None``.
        uniform_sampling (bool, optional): If ``True``, sample the bounding box
            such that each pixel has an equal probability of being mixed.
            If ``False``, defaults to the sampling used in the original paper
            implementation. Default: ``False``.

    Returns:
        input_mixed (torch.Tensor): batch of inputs after cutmix has been
            applied.
        target_perm (torch.Tensor): The labels of the mixed-in examples
        area (float): The fractional area of the unmixed region.
        bounding_box (tuple): the ``(x1, y1, x2, y2)`` coordinates of
            the bounding box that defines the mixed region, using the same
            convention as the ``bbox`` argument.

    Raises:
        ValueError: If both ``length`` and ``bbox`` are provided.

    Example:
        .. testcode::

            import torch
            from composer.functional import cutmix_batch

            N, C, H, W = 2, 3, 4, 5
            num_classes = 10
            X = torch.randn(N, C, H, W)
            y = torch.randint(num_classes, size=(N,))
            X_mixed, target_perm, area, _ = cutmix_batch(X, y, alpha=0.2)
    """
    if bbox is not None and length is not None:
        raise ValueError(f'Cannot provide both length and bbox; got {length} and {bbox}')

    # Pair every sample with a partner: honor the caller's permutation when
    # one is given, otherwise draw a fresh one.
    shuffled_idx = _gen_indices(input) if indices is None else indices

    if bbox:
        # Caller-supplied box: nothing random is needed, so neither ``alpha``
        # nor ``length`` is consulted. (An empty ``bbox`` is treated like ``None``.)
        rx, ry, rw, rh = bbox[0], bbox[1], bbox[2], bbox[3]
    else:
        # ``_rand_bbox`` expects lambda to be the fraction of the area that is
        # KEPT; it masks a box whose sides scale with sqrt(1 - lambda).
        if length is None:
            cutmix_lambda = _gen_cutmix_coef(alpha)
        else:
            H, W = input.shape[-2], input.shape[-1]
            cut_w = int(length * W)
            cut_h = int(length * H)
            # Masked fraction is (cut_w * cut_h) / (H * W); lambda is its complement.
            cutmix_lambda = 1.0 - (cut_w * cut_h) / (H * W)
        # Sample a rectangular box using lambda. Use variable names from the paper.
        rx, ry, rw, rh = _rand_bbox(
            input.shape[2],
            input.shape[3],
            cutmix_lambda,
            uniform_sampling=uniform_sampling,
        )
        bbox = (rx, ry, rw, rh)

    # Work on a copy so the caller's tensor is left untouched, then fill the
    # box with the same region taken from each sample's partner.
    X_cutmix = torch.clone(input)
    X_cutmix[:, :, rx:rw, ry:rh] = X_cutmix[shuffled_idx, :, rx:rw, ry:rh]

    # Adjust lambda to exactly match the pixel ratio of the (clipped) box. This is an
    # implementation detail taken from the original implementation, and implies
    # lambda is not actually beta distributed.
    adjusted_lambda = _adjust_lambda(input, bbox)

    # Targets of the mixed-in partners, for later interpolation by the caller.
    y_shuffled = target[shuffled_idx]

    return X_cutmix, y_shuffled, adjusted_lambda, bbox
class CutMix(Algorithm):
    """`CutMix <https://arxiv.org/abs/1905.04899>`_ trains the network on non-overlapping combinations
    of pairs of examples and interpolated targets rather than individual examples and targets.

    This is done by taking a non-overlapping combination of a given batch X with a
    randomly permuted copy of X. The area is drawn from a ``Beta(alpha, alpha)``
    distribution.

    Training in this fashion sometimes reduces generalization error.

    The algorithm runs on ``Event.BEFORE_FORWARD`` (to mix the inputs) and then on
    either ``Event.BEFORE_LOSS`` (to mix the targets) or, when ``interpolate_loss``
    is set, on ``Event.BEFORE_BACKWARD`` (to mix the losses instead).

    Args:
        alpha (float, optional): the pseudocount for the Beta distribution
            used to sample area parameters. As ``alpha`` grows, the two samples
            in each pair tend to be weighted more equally. As ``alpha``
            approaches 0 from above, the combination approaches only using
            one element of the pair. Default: ``1``.
        interpolate_loss (bool, optional): Interpolates the loss rather than the labels.
            A useful trick when using a cross entropy loss. Will produce incorrect behavior
            if the loss is not a linear function of the targets. Default: ``False``
        uniform_sampling (bool, optional): If ``True``, sample the bounding
            box such that each pixel has an equal probability of being mixed.
            If ``False``, defaults to the sampling used in the original
            paper implementation. Default: ``False``.
        input_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the input
            from the batch. Can also be a pair of get and set functions, where the getter
            is assumed to be first in the pair.  The default is 0, which corresponds to any sequence, where the first element
            is the input. Default: ``0``.
        target_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the target
            from the batch. Can also be a pair of get and set functions, where the getter
            is assumed to be first in the pair. The default is 1, which corresponds to any sequence, where the second element
            is the target. Default: ``1``.

    Example:
        .. testcode::

            from composer.algorithms import CutMix
            algorithm = CutMix(alpha=0.2)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[algorithm],
                optimizers=[optimizer]
            )
    """

    def __init__(
        self,
        alpha: float = 1.,
        interpolate_loss: bool = False,
        uniform_sampling: bool = False,
        input_key: Union[str, int, tuple[Callable, Callable], Any] = 0,
        target_key: Union[str, int, tuple[Callable, Callable], Any] = 1,
    ):
        self.alpha = alpha
        self.interpolate_loss = interpolate_loss
        self._uniform_sampling = uniform_sampling
        # Per-batch state, written at BEFORE_FORWARD and read back at
        # BEFORE_LOSS / BEFORE_BACKWARD.
        self._indices = torch.Tensor()
        self._cutmix_lambda = 0.0
        self._bbox: tuple[int, int, int, int] = (0, 0, 0, 0)
        self._permuted_target = torch.Tensor()
        self._adjusted_lambda = 0.0
        self.input_key, self.target_key = input_key, target_key

    def match(self, event: Event, state: State) -> bool:
        """Return whether the algorithm should run on ``event``.

        Inputs are always mixed before the forward pass. The second hook depends on
        ``interpolate_loss``: the loss is mixed before backward when it is set,
        otherwise the targets are mixed before the loss is computed.
        """
        if self.interpolate_loss:
            return event in [Event.BEFORE_FORWARD, Event.BEFORE_BACKWARD]
        return event in [Event.BEFORE_FORWARD, Event.BEFORE_LOSS]

    def _has_spatial_target(self, inputs: torch.Tensor) -> bool:
        """Return ``True`` if the permuted target shares the input's last two dimensions."""
        permuted = self._permuted_target
        return permuted.ndim > 2 and permuted.shape[-2:] == inputs.shape[-2:]

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        """Mix the inputs, targets, or loss of the current batch depending on ``event``.

        Raises:
            NotImplementedError: If the input, target, outputs, or loss is not a single tensor.
            ValueError: If ``interpolate_loss`` is set and the target has the same
                height and width as the input.
            TypeError: If the loss is found on ``model.module`` but that is not a
                torch module, or if the loss is not callable.
            AttributeError: If no loss is found on ``model`` or ``model.module``.
        """
        inputs = state.batch_get_item(key=self.input_key)
        target = state.batch_get_item(key=self.target_key)

        if not isinstance(inputs, torch.Tensor):
            raise NotImplementedError('Multiple tensors for inputs not supported yet.')
        if not isinstance(target, torch.Tensor):
            raise NotImplementedError('Multiple tensors for targets not supported yet.')

        if event == Event.BEFORE_FORWARD:
            # The sampled quantities are kept on the instance: the later hook
            # needs the box, the adjusted lambda, and the permuted targets.
            self._indices = _gen_indices(inputs)
            self._cutmix_lambda = _gen_cutmix_coef(self.alpha)
            self._bbox = _rand_bbox(
                inputs.shape[2],
                inputs.shape[3],
                self._cutmix_lambda,
                uniform_sampling=self._uniform_sampling,
            )
            self._adjusted_lambda = _adjust_lambda(inputs, self._bbox)

            # The box and permutation are passed explicitly so that the mixed
            # inputs match the state saved above.
            new_input, self._permuted_target, _, _ = cutmix_batch(
                input=inputs,
                target=target,
                alpha=self.alpha,
                bbox=self._bbox,
                indices=self._indices,
                uniform_sampling=self._uniform_sampling,
            )

            state.batch_set_item(key=self.input_key, value=new_input)

        if not self.interpolate_loss and event == Event.BEFORE_LOSS:
            # Interpolate the targets
            if not isinstance(state.outputs, torch.Tensor):
                raise NotImplementedError('Multiple output tensors not supported yet')

            if self._has_spatial_target(inputs):
                # Target has the same height and width as the input, no need to interpolate:
                # paste the partner's target into the same box (in place).
                x1, y1, x2, y2 = self._bbox
                target[..., x1:x2, y1:y2] = self._permuted_target[..., x1:x2, y1:y2]
            else:
                # Need to interpolate on dense/one-hot targets.
                target = ensure_targets_one_hot(state.outputs, target)
                permuted_target = ensure_targets_one_hot(state.outputs, self._permuted_target)
                # Interpolate to get the new target
                target = self._adjusted_lambda * target + (1 - self._adjusted_lambda) * permuted_target

            # Create the new batch
            state.batch_set_item(key=self.target_key, value=target)

        if self.interpolate_loss and event == Event.BEFORE_BACKWARD:
            if self._has_spatial_target(inputs):
                raise ValueError("Can't interpolate loss when target has the same height and width as the input")

            # Grab the loss function, either directly from the model or from a wrapped module.
            if hasattr(state.model, 'loss'):
                loss_fn = state.model.loss
            elif hasattr(state.model, 'module') and hasattr(state.model.module, 'loss'):
                if isinstance(state.model.module, torch.nn.Module):
                    loss_fn = state.model.module.loss
                else:
                    raise TypeError('state.model.module must be a torch module')
            else:
                raise AttributeError('Loss must be accessible via model.loss or model.module.loss')

            # Verify that the loss is callable
            if not callable(loss_fn):
                raise TypeError('Loss must be callable')

            # Loss of the same outputs against the permuted targets.
            new_loss = loss_fn(state.outputs, (inputs, self._permuted_target))
            if not isinstance(state.loss, torch.Tensor):
                raise NotImplementedError('Multiple losses not supported yet')
            if not isinstance(new_loss, torch.Tensor):
                raise NotImplementedError('Multiple losses not supported yet')

            # Interpolate the loss
            state.loss = self._adjusted_lambda * state.loss + (1 - self._adjusted_lambda) * new_loss
def _gen_indices(x: Tensor) -> Tensor:
    """Draw a random permutation of the batch positions of ``x``.

    Args:
        x (torch.Tensor): tensor whose first dimension is the batch dimension.

    Returns:
        indices: A 1-D tensor holding a random permutation of
            ``0 .. x.shape[0] - 1``, as produced by ``torch.randperm``.
    """
    batch_size = x.shape[0]
    permutation = torch.randperm(batch_size)
    return permutation
def _gen_cutmix_coef(alpha: float) -> float:
    """Generates lambda from ``Beta(alpha, alpha)``.

    Args:
        alpha (float): Parameter for the ``Beta(alpha, alpha)`` distribution.
            Must be non-negative.

    Returns:
        cutmix_lambda: Lambda parameter for performing CutMix. ``0.0`` when
            ``alpha`` is 0, otherwise a sample from ``Beta(alpha, alpha)``.

    Raises:
        ValueError: If ``alpha`` is negative (or NaN).
    """
    # Validate with a real exception rather than ``assert`` so the check
    # survives ``python -O``. Written as ``not >=`` so NaN is rejected too.
    if not alpha >= 0:
        raise ValueError(f'alpha must be non-negative; got {alpha}')
    # The beta distribution requires alpha > 0, but alpha = 0 is fine for
    # cutmix, so that case is handled without sampling.
    if alpha == 0:
        return 0.0
    # Draw the area parameter from a beta distribution.
    return float(np.random.beta(alpha, alpha))
def _rand_bbox(
    W: int,
    H: int,
    cutmix_lambda: float,
    cx: Optional[int] = None,
    cy: Optional[int] = None,
    uniform_sampling: bool = False,
) -> tuple[int, int, int, int]:
    """Randomly samples a bounding box with area determined by ``cutmix_lambda``.

    Adapted from original implementation https://github.com/clovaai/CutMix-PyTorch

    The nominal box sides are ``int(W * r)`` and ``int(H * r)`` with
    ``r = sqrt(1 - cutmix_lambda)``. The box extends ``side // 2`` to each side
    of the center and is then clipped to ``[0, W]`` / ``[0, H]``, so the
    returned box can be smaller than nominal.

    Args:
        W (int): Extent of the axis that the ``x`` coordinates index.
        H (int): Extent of the axis that the ``y`` coordinates index.
        cutmix_lambda (float): Lambda param from cutmix; ``1 - cutmix_lambda``
            is the nominal fraction of the area covered by the box.
        cx (int, optional): Optional x coordinate of the center of the box.
            Sampled at random if not given.
        cy (int, optional): Optional y coordinate of the center of the box.
            Sampled at random if not given.
        uniform_sampling (bool, optional): If true, sample the bounding box such that each pixel
            has an equal probability of being mixed. If false, defaults to the
            sampling used in the original paper implementation. Default: ``False``.

    Returns:
        bbx1: Lower ``x`` edge of the bounding box
        bby1: Lower ``y`` edge of the bounding box
        bbx2: Upper ``x`` edge of the bounding box (exclusive when used as a slice)
        bby2: Upper ``y`` edge of the bounding box (exclusive when used as a slice)
    """
    cut_ratio = np.sqrt(1.0 - cutmix_lambda)
    cut_w = int(W * cut_ratio)
    cut_h = int(H * cut_ratio)
    half_w = cut_w // 2
    half_h = cut_h // 2

    # Sample any missing center coordinate. With uniform sampling the center
    # may fall outside the image by up to half a box; the clipping below
    # brings the box back in bounds. ``cx`` is drawn before ``cy``.
    if cx is None:
        if uniform_sampling:
            cx = np.random.randint(-cut_w // 2, high=W + half_w)
        else:
            cx = np.random.randint(W)
    if cy is None:
        if uniform_sampling:
            cy = np.random.randint(-cut_h // 2, high=H + half_h)
        else:
            cy = np.random.randint(H)

    # Clip to the valid range and convert NumPy scalars to plain ints so the
    # result matches the annotated return type.
    bbx1 = int(np.clip(cx - half_w, 0, W))
    bby1 = int(np.clip(cy - half_h, 0, H))
    bbx2 = int(np.clip(cx + half_w, 0, W))
    bby2 = int(np.clip(cy + half_h, 0, H))

    return bbx1, bby1, bbx2, bby2
def _adjust_lambda(x: Tensor, bbox: tuple) -> float:
    """Recompute the cutmix lambda from the actual bounding box.

    Args:
        x (torch.Tensor): input tensor; only the sizes of its last two
            dimensions are used.
        bbox (tuple): (x1, y1, x2, y2) coordinates of the bounding box, obeying x2 > x1, y2 > y1.

    Returns:
        adjusted_lambda: One minus the ratio of the box area to the product of
            the last two dimensions of ``x``, which accounts for the box having
            been clipped to the bounds of the input.
    """
    x1, y1 = bbox[0], bbox[1]
    x2, y2 = bbox[2], bbox[3]
    box_area = (x2 - x1) * (y2 - y1)
    dims = x.size()
    total_area = dims[-1] * dims[-2]
    return 1 - (box_area / total_area)