# Copyright 2021 MosaicML. All Rights Reserved.
"""Core CutMix classes and functions."""
from __future__ import annotations
import logging
from typing import Optional, Tuple
import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.loss.utils import check_for_index_targets
# Public API of this module.
__all__ = [
    "CutMix",
    "cutmix_batch",
]

# Module-level logger.
log = logging.getLogger(__name__)
def cutmix_batch(input: Tensor,
                 target: Tensor,
                 num_classes: int,
                 length: Optional[float] = None,
                 alpha: float = 1.,
                 bbox: Optional[Tuple] = None,
                 indices: Optional[torch.Tensor] = None,
                 uniform_sampling: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create new samples using combinations of pairs of samples.

    This is done by masking a region of each image in ``input`` and filling
    the masked region with the corresponding content from a random different
    image in ``input``.

    The position of the masked region is determined by drawing a center point
    uniformly at random from all spatial positions.

    The area of the masked region is computed using either ``length`` or
    ``alpha``. If ``length`` is provided, it directly determines the size
    of the masked region. If it is not provided, the fraction of the input
    area to mask is drawn from a ``Beta(alpha, alpha)`` distribution.
    The original paper used a fixed value of ``alpha = 1``.

    Alternatively, one may provide a bounding box to mask directly, in
    which case ``alpha`` is ignored and ``length`` must not be provided.

    The same masked region is used for the whole batch.

    .. note::

        The masked region is clipped at the spatial boundaries of the inputs.
        This means that there is no padding required, but the actual region
        used may be smaller than the nominal size computed using ``length``
        or ``alpha``.

    Args:
        input (torch.Tensor): input tensor of shape ``(N, C, H, W)``
        target (torch.Tensor): target tensor of either shape ``N`` or
            ``(N, num_classes)``. In the former case, elements of ``target``
            must be integer class ids in the range ``0..num_classes - 1``. In
            the latter case, rows of ``target`` may be arbitrary vectors of
            targets, including, e.g., one-hot encoded class labels, smoothed
            class labels, or multi-output regression targets.
        num_classes (int): total number of classes or output variables
        length (float, optional): Relative side length of the masked region.
            If specified, ``length`` is interpreted as a fraction of ``H`` and
            ``W``, and the resulting box is approximately of size
            ``(length * H, length * W)`` (rounded down to whole pixels, then
            clipped at the borders). Default: ``None``.
        alpha (float, optional): parameter for the Beta distribution over
            the fraction of the input to mask. Ignored if ``length`` or
            ``bbox`` is provided. Default: ``1``.
        bbox (tuple, optional): predetermined ``(x1, y1, x2, y2)``
            coordinates of the bounding box, where ``x`` indexes dimension 2
            and ``y`` indexes dimension 3 of ``input``. Default: ``None``.
        indices (torch.Tensor, optional): Permutation of the samples to use.
            Default: ``None``.
        uniform_sampling (bool, optional): If ``True``, sample the bounding box
            such that each pixel has an equal probability of being mixed.
            If ``False``, defaults to the sampling used in the original paper
            implementation. Ignored if ``bbox`` is provided. Default: ``False``.

    Returns:
        input_mixed (torch.Tensor): batch of inputs after cutmix has been
            applied.
        target_mixed (torch.Tensor): soft labels for mixed input samples.
            These are a convex combination of the (possibly one-hot-encoded)
            labels from the original samples and the samples chosen to fill
            the masked regions, with the relative weighting equal to the
            fraction of the spatial size that is cut.
            E.g., if a sample was originally an image with label ``0`` and
            40% of the image was replaced with data from an image with label
            ``2``, the resulting labels, assuming only three classes, would be
            ``[1, 0, 0] * 0.6 + [0, 0, 1] * 0.4 = [0.6, 0, 0.4]``.

    Raises:
        ValueError: If both ``length`` and ``bbox`` are provided.

    Example:
        .. testcode::

            import torch
            from composer.functional import cutmix_batch

            N, C, H, W = 2, 3, 4, 5
            num_classes = 10
            X = torch.randn(N, C, H, W)
            y = torch.randint(num_classes, size=(N,))
            X_mixed, y_mixed = cutmix_batch(
                X, y, num_classes=num_classes, alpha=0.2)
    """
    if bbox is not None and length is not None:
        raise ValueError(f"Cannot provide both length and bbox; got {length} and {bbox}")

    # Create shuffled indices across the batch in preparation for cutting and mixing.
    # Use given indices if there are any.
    shuffled_idx = _gen_indices(input) if indices is None else indices

    H, W = input.shape[-2], input.shape[-1]

    # Throughout, ``cutmix_lambda`` is the fraction of each image that is *kept*
    # (the paper's convention): ``_rand_bbox`` cuts a box whose sides are
    # ``sqrt(1 - cutmix_lambda)`` of the input's, and ``_adjust_lambda`` returns
    # ``1 - box_area / input_area``.
    if bbox is not None:
        # Predetermined box: ``alpha`` is ignored, so no Beta sample is drawn
        # (drawing one would needlessly advance NumPy's global RNG).
        rx, ry, rw, rh = bbox[0], bbox[1], bbox[2], bbox[3]
        cutmix_lambda = 1.0 - ((rw - rx) * (rh - ry)) / (H * W)
    else:
        if length is None:
            cutmix_lambda = _gen_cutmix_coef(alpha)
        else:
            cut_w = int(length * W)
            cut_h = int(length * H)
            # The box covers ``cut_w * cut_h`` pixels, so the kept fraction is
            # one minus that area ratio. (Previously the cut fraction itself was
            # used, which inverted the box size: small ``length`` -> huge box.)
            cutmix_lambda = 1.0 - (cut_w * cut_h) / (H * W)
        # Sample a rectangular box using lambda. Use variable names from the paper.
        rx, ry, rw, rh = _rand_bbox(input.shape[2], input.shape[3], cutmix_lambda, uniform_sampling=uniform_sampling)
        bbox = (rx, ry, rw, rh)

    # Create the new inputs and fill in the box with a part of a random image.
    # The right-hand side uses advanced indexing, so it is a copy taken before
    # the assignment writes into ``X_cutmix``.
    X_cutmix = torch.clone(input)
    X_cutmix[:, :, rx:rw, ry:rh] = X_cutmix[shuffled_idx, :, rx:rw, ry:rh]

    # Adjust lambda to exactly match the pixel ratio of the (clipped) box. This is
    # an implementation detail taken from the original implementation, and implies
    # lambda is not actually beta distributed.
    adjusted_lambda = _adjust_lambda(cutmix_lambda, input, bbox)

    # Make a shuffled version of y for interpolation.
    y_shuffled = target[shuffled_idx]

    # Interpolate between labels using the adjusted lambda.
    # First check if labels are indices. If so, convert them to one-hots.
    # This is under the assumption that the loss expects torch.LongTensor, which is true for pytorch cross_entropy.
    if check_for_index_targets(target):
        y_onehot = F.one_hot(target, num_classes=num_classes)
        y_shuffled_onehot = F.one_hot(y_shuffled, num_classes=num_classes)
        y_cutmix = adjusted_lambda * y_onehot + (1 - adjusted_lambda) * y_shuffled_onehot
    else:
        y_cutmix = adjusted_lambda * target + (1 - adjusted_lambda) * y_shuffled

    return X_cutmix, y_cutmix
class CutMix(Algorithm):
    """`CutMix <https://arxiv.org/abs/1905.04899>`_ trains the network on non-overlapping combinations of pairs of
    examples and interpolated targets rather than individual examples and targets.

    This is done by taking a non-overlapping combination of a given batch X with a
    randomly permuted copy of X. The area is drawn from a ``Beta(alpha, alpha)``
    distribution.

    Training in this fashion sometimes reduces generalization error.

    Args:
        num_classes (int): the number of classes in the task labels.
        alpha (float, optional): the pseudocount for the Beta distribution
            used to sample area parameters. As ``alpha`` grows, the two samples
            in each pair tend to be weighted more equally. As ``alpha``
            approaches 0 from above, the combination approaches only using
            one element of the pair. Default: ``1``.
        uniform_sampling (bool, optional): If ``True``, sample the bounding
            box such that each pixel has an equal probability of being mixed.
            If ``False``, defaults to the sampling used in the original
            paper implementation. Default: ``False``.

    Example:
        .. testcode::

            from composer.algorithms import CutMix
            algorithm = CutMix(num_classes=10, alpha=0.2)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[algorithm],
                optimizers=[optimizer]
            )
    """

    def __init__(self, num_classes: int, alpha: float = 1., uniform_sampling: bool = False):
        self.num_classes = num_classes
        self.alpha = alpha
        self._uniform_sampling = uniform_sampling
        # The attributes below are overwritten on every ``apply`` call and are
        # kept only so tests can inspect the last permutation / lambda / box used.
        self._indices = torch.Tensor()
        self._cutmix_lambda = 0.0
        self._bbox: Tuple[int, int, int, int] = (0, 0, 0, 0)

    def match(self, event: Event, state: State) -> bool:
        """Runs on Event.AFTER_DATALOADER.

        Args:
            event (:class:`Event`): The current event.
            state (:class:`State`): The current state.

        Returns:
            bool: True if this algorithm should run now.
        """
        return event == Event.AFTER_DATALOADER

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        """Applies CutMix augmentation on State input.

        Replaces ``state.batch`` with the mixed ``(input, target)`` pair.

        Args:
            event (Event): the current event
            state (State): the current trainer state
            logger (Logger): the training logger
        """
        input, target = state.batch_pair
        assert isinstance(input, Tensor) and isinstance(target, Tensor), \
            "Multiple tensors for inputs or targets not supported yet."
        alpha = self.alpha

        # The permutation, box, and adjusted lambda are sampled here (rather than
        # inside ``cutmix_batch``) and saved only for testing.
        self._indices = _gen_indices(input)
        _cutmix_lambda = _gen_cutmix_coef(alpha)
        self._bbox = _rand_bbox(input.shape[2], input.shape[3], _cutmix_lambda, uniform_sampling=self._uniform_sampling)
        self._cutmix_lambda = _adjust_lambda(_cutmix_lambda, input, self._bbox)

        new_input, new_target = cutmix_batch(
            input=input,
            target=target,
            num_classes=self.num_classes,
            alpha=alpha,
            bbox=self._bbox,
            indices=self._indices,
            uniform_sampling=self._uniform_sampling,
        )

        state.batch = (new_input, new_target)
def _gen_indices(x: Tensor) -> Tensor:
    """Return a random permutation of the batch positions of ``x``.

    Args:
        x: input tensor of shape ``(B, d1, d2, ..., dn)``; ``B`` is the batch
            size and ``d1``-``dn`` are feature dimensions.

    Returns:
        indices: a 1-D tensor holding a random ordering of ``0..B-1``.
    """
    batch_size = x.shape[0]
    permutation = torch.randperm(batch_size)
    return permutation
def _gen_cutmix_coef(alpha: float) -> float:
    """Generates lambda from ``Beta(alpha, alpha)``.

    Args:
        alpha: Parameter for the ``Beta(alpha, alpha)`` distribution. Must be
            non-negative; ``alpha == 0`` is accepted and yields ``0.0``.

    Returns:
        cutmix_lambda: Lambda parameter for performing cutmix, always returned
            as a Python ``float`` (previously ``alpha == 0`` returned the int ``0``).

    Raises:
        AssertionError: If ``alpha`` is negative.
    """
    # Reject negative alpha. NOTE(review): this is an ``assert`` and is stripped
    # under ``python -O``; it is kept so callers see the same exception type.
    assert alpha >= 0, f"alpha must be non-negative; got {alpha}"
    # The Beta distribution requires alpha > 0, but alpha == 0 is fine for
    # cutmix, so that case is handled without drawing a sample.
    if alpha == 0:
        return 0.0
    # Draw the area parameter from a beta distribution.
    return float(np.random.beta(alpha, alpha))
def _rand_bbox(W: int,
               H: int,
               cutmix_lambda: float,
               cx: Optional[int] = None,
               cy: Optional[int] = None,
               uniform_sampling: bool = False,
               cut_w: Optional[int] = None,
               cut_h: Optional[int] = None) -> Tuple[int, int, int, int]:
    """Randomly samples a bounding box with area determined by cutmix_lambda.

    Adapted from original implementation https://github.com/clovaai/CutMix-PyTorch

    Args:
        W: Size of the first spatial axis the box spans (callers in this file
            pass ``input.shape[2]``); ``bbx1``/``bbx2`` lie in ``[0, W]``.
        H: Size of the second spatial axis the box spans (callers in this file
            pass ``input.shape[3]``); ``bby1``/``bby2`` lie in ``[0, H]``.
        cutmix_lambda: Lambda param from cutmix, i.e. the fraction of the input
            to keep. Each side of the box is ``sqrt(1 - cutmix_lambda)`` of the
            corresponding input side. Used only for the sides not given by
            ``cut_w`` / ``cut_h``.
        cx: Optional x coordinate of the center of the box.
        cy: Optional y coordinate of the center of the box.
        uniform_sampling: If true, sample the bounding box such that each pixel
            has an equal probability of being mixed. If false, defaults to the
            sampling used in the original paper implementation.
        cut_w: Optional explicit box extent along the ``W`` axis, in pixels.
        cut_h: Optional explicit box extent along the ``H`` axis, in pixels.

    Returns:
        bbx1: Leftmost edge of the bounding box
        bby1: Top edge of the bounding box
        bbx2: Rightmost edge of the bounding box
        bby2: Bottom edge of the bounding box

        All four are Python ``int``s, clipped to the input bounds.
    """
    if cut_w is None or cut_h is None:
        cut_ratio = np.sqrt(1.0 - cutmix_lambda)
        if cut_w is None:
            cut_w = int(W * cut_ratio)
        if cut_h is None:
            cut_h = int(H * cut_ratio)

    # Sample the box center. Note ``-cut_w // 2`` parses as ``(-cut_w) // 2``;
    # this is kept as-is so the sampled distribution is unchanged.
    if cx is None:
        if uniform_sampling:
            cx = np.random.randint(-cut_w // 2, high=W + cut_w // 2)
        else:
            cx = np.random.randint(W)
    if cy is None:
        if uniform_sampling:
            cy = np.random.randint(-cut_h // 2, high=H + cut_h // 2)
        else:
            cy = np.random.randint(H)

    half_w = cut_w // 2
    half_h = cut_h // 2
    # Clip to the input and convert NumPy scalars to plain ints, as annotated.
    bbx1 = int(np.clip(cx - half_w, 0, W))
    bby1 = int(np.clip(cy - half_h, 0, H))
    bbx2 = int(np.clip(cx + half_w, 0, W))
    bby2 = int(np.clip(cy + half_h, 0, H))
    return bbx1, bby1, bbx2, bby2
def _adjust_lambda(cutmix_lambda: float, x: Tensor, bbox: Tuple) -> float:
    """Rescale the cutmix lambda according to the size of the clipped bounding box.

    Args:
        cutmix_lambda: Lambda param from cutmix, used to set the area of the box.
            Not read here (the result depends only on ``bbox`` and ``x``); kept
            for interface compatibility.
        x: input tensor of shape (B, d1, d2, ..., dn), B is batch size, d1-dn
            are feature dimensions. Only the last two dimensions are used.
        bbox: (x1, y1, x2, y2) coordinates of the bounding box, obeying
            x2 >= x1, y2 >= y1.

    Returns:
        adjusted_lambda: ``1 - box_area / (x.shape[-2] * x.shape[-1])`` as a
            Python ``float``, i.e. the fraction of the input actually kept
            once part of the box may have been clipped away.
    """
    rx, ry, rw, rh = bbox[0], bbox[1], bbox[2], bbox[3]
    box_area = (rw - rx) * (rh - ry)
    input_area = x.shape[-1] * x.shape[-2]
    # ``float`` keeps the result a Python float even when the bbox holds NumPy
    # integers (which would otherwise produce ``np.float64``).
    return float(1 - box_area / input_area)