# Copyright 2021 MosaicML. All Rights Reserved.
import logging
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import yahp as hp
from torch.nn import functional as F
from composer.algorithms import AlgorithmHparams
from composer.core.types import Algorithm, Event, Logger, State, Tensor
from composer.models.loss import check_for_index_targets
log = logging.getLogger(__name__)
def gen_indices(x: Tensor) -> Tensor:
    """Sample a random permutation of the batch dimension.

    Args:
        x: input tensor of shape (B, d1, d2, ..., dn), B is batch size, d1-dn
            are feature dimensions.

    Returns:
        indices: a 1-D tensor containing the batch indices ``0..B-1`` in
            random order.
    """
    batch_size = x.shape[0]
    return torch.randperm(batch_size)
def gen_cutmix_lambda(alpha: float) -> float:
    """Generates lambda from ``Beta(alpha, alpha)``.

    Args:
        alpha: Parameter for the Beta(alpha, alpha) distribution. Must be
            non-negative.

    Returns:
        cutmix_lambda: Lambda parameter for performing cutmix, in ``[0, 1]``.

    Raises:
        ValueError: If ``alpha`` is negative.
    """
    # Validate with a real exception rather than ``assert``: asserts are
    # stripped when Python runs with -O, which would let a negative alpha
    # flow through to np.random.beta and fail there with a cryptic error.
    if alpha < 0:
        raise ValueError(f'alpha must be non-negative, got {alpha}')
    # The beta distribution requires alpha > 0, but alpha == 0 is a valid
    # cutmix setting (no mixing), so handle it explicitly.
    if alpha == 0:
        return 0
    return np.random.beta(alpha, alpha)
def rand_bbox(W: int,
              H: int,
              cutmix_lambda: float,
              cx: Optional[int] = None,
              cy: Optional[int] = None) -> Tuple[int, int, int, int]:
    """Randomly samples a bounding box with area determined by cutmix_lambda.

    Adapted from the original implementation at
    https://github.com/clovaai/CutMix-PyTorch

    Args:
        W: Width of the image.
        H: Height of the image.
        cutmix_lambda: Lambda param from cutmix, used to set the area of the
            box: the (unclipped) box covers a ``1 - cutmix_lambda`` fraction
            of the image.
        cx: Optional x coordinate of the center of the box. Sampled uniformly
            at random when omitted.
        cy: Optional y coordinate of the center of the box. Sampled uniformly
            at random when omitted.

    Returns:
        bbx1: Leftmost edge of the bounding box.
        bby1: Top edge of the bounding box.
        bbx2: Rightmost edge of the bounding box.
        bby2: Bottom edge of the bounding box.
    """
    # Each side is scaled by sqrt(1 - lambda) so the unclipped box area is a
    # (1 - lambda) fraction of the image area.
    side_ratio = np.sqrt(1.0 - cutmix_lambda)
    half_w = int(W * side_ratio) // 2
    half_h = int(H * side_ratio) // 2

    # Sample the box center uniformly unless the caller pinned it.
    if cx is None:
        cx = np.random.randint(W)
    if cy is None:
        cy = np.random.randint(H)

    # Clip the box to the image bounds.
    bbx1 = np.clip(cx - half_w, 0, W)
    bby1 = np.clip(cy - half_h, 0, H)
    bbx2 = np.clip(cx + half_w, 0, W)
    bby2 = np.clip(cy + half_h, 0, H)
    return bbx1, bby1, bbx2, bby2
def adjust_lambda(cutmix_lambda: float, x: Tensor, bbox: Tuple) -> float:
    """Rescale the cutmix lambda according to the size of the clipped bounding box.

    Since ``rand_bbox`` clips boxes to the image bounds, the actual masked
    area can be smaller than the nominal one; the returned value is the exact
    unmasked pixel fraction.

    Args:
        cutmix_lambda: Lambda param from cutmix. NOTE: not read by this
            function — kept for interface compatibility with callers.
        x: input tensor of shape (B, d1, d2, ..., dn); only its last two
            dimensions are used.
        bbox: (x1, y1, x2, y2) coordinates of the bounding box, obeying
            x2 > x1, y2 > y1.

    Returns:
        adjusted_lambda: Rescaled cutmix_lambda accounting for the part of
            the bounding box that fell out of bounds of the input.
    """
    x1, y1, x2, y2 = bbox
    box_area = (x2 - x1) * (y2 - y1)
    image_area = x.size()[-1] * x.size()[-2]
    return 1 - box_area / image_area
def cutmix(x: Tensor,
           y: Tensor,
           alpha: float,
           n_classes: int,
           cutmix_lambda: Optional[float] = None,
           bbox: Optional[Tuple] = None,
           indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create new samples using combinations of pairs of samples.

    A rectangular region of ``x`` is masked and filled with the corresponding
    region from a batch-permuted copy of ``x``. The cutmix parameter lambda
    should be drawn from a ``Beta(alpha, alpha)`` distribution for some
    alpha > 0; the masked area is determined by lambda, and labels are
    interpolated accordingly. The same lambda is used for every example in
    the batch. The original paper used a fixed alpha = 1.

    Labels are interpolated (densified if given as indices), since for many
    loss functions (such as cross entropy) targets arrive as indices and
    interpolation must be handled explicitly.

    Args:
        x: input tensor of shape (B, d1, d2, ..., dn), B is batch size, d1-dn
            are feature dimensions.
        y: target tensor of shape (B, f1, f2, ..., fm), B is batch size, f1-fn
            are possible target dimensions.
        alpha: parameter for the beta distribution of the cutmix region size.
        n_classes: total number of classes.
        cutmix_lambda: optional, fixed size of cutmix region.
        bbox: optional, predetermined (rx1, ry1, rx2, ry2) coords of the
            bounding box.
        indices: Permutation of the batch indices `1..B`. Used for permuting
            without randomness.

    Returns:
        x_cutmix: batch of inputs after cutmix has been applied.
        y_cutmix: labels after cutmix has been applied.

    Example:
        from composer import functional as CF
        for X, y in dataloader:
            X, y, _, _ ,_ = CF.cutmix(X, y, alpha, nclasses)
            pred = model(X)
            loss = loss_fun(pred, y) # loss_fun must accept dense labels (ie NOT indices)
    """
    # Permutation pairing each example with its mixing partner; honor a
    # caller-supplied permutation for reproducibility.
    shuffled_idx = gen_indices(x) if indices is None else indices

    # Work on a copy so the caller's batch is left untouched.
    x_cutmix = torch.clone(x)

    if cutmix_lambda is None:
        cutmix_lambda = gen_cutmix_lambda(alpha)

    # Use the provided box if any; otherwise sample one sized by lambda.
    # Variable names follow the paper.
    if bbox:
        rx, ry, rw, rh = bbox[0], bbox[1], bbox[2], bbox[3]
    else:
        rx, ry, rw, rh = rand_bbox(x.shape[2], x.shape[3], cutmix_lambda)
        bbox = (rx, ry, rw, rh)

    # Paste the region from the permuted batch into the copy.
    x_cutmix[:, :, rx:rw, ry:rh] = x_cutmix[shuffled_idx, :, rx:rw, ry:rh]

    # Adjust lambda to exactly match the pixel ratio of the clipped box. This
    # is an implementation detail taken from the original implementation, and
    # implies lambda is not actually beta distributed.
    adjusted_lambda = adjust_lambda(cutmix_lambda, x, bbox)

    # Shuffled labels for interpolation.
    y_shuffled = y[shuffled_idx]

    # If labels are indices, convert to one-hots first — this assumes the
    # loss expects torch.LongTensor, which is true for pytorch cross_entropy.
    if check_for_index_targets(y):
        y_onehot = F.one_hot(y, num_classes=n_classes)
        y_shuffled_onehot = F.one_hot(y_shuffled, num_classes=n_classes)
        y_cutmix = adjusted_lambda * y_onehot + (1 - adjusted_lambda) * y_shuffled_onehot
    else:
        y_cutmix = adjusted_lambda * y + (1 - adjusted_lambda) * y_shuffled
    return x_cutmix, y_cutmix
@dataclass
class CutMixHparams(AlgorithmHparams):
    """Hyperparameters for :class:`CutMix`."""

    # Beta-distribution concentration controlling cutmix region size.
    alpha: float = hp.required('Strength of interpolation, should be >= 0. No interpolation if alpha=0.',
                               template_default=1.0)

    def initialize_object(self) -> "CutMix":
        """Build the :class:`CutMix` algorithm from these hyperparameters."""
        return CutMix(**asdict(self))
class CutMix(Algorithm):
    """`CutMix <https://arxiv.org/abs/1905.04899>`_ trains the network on
    non-overlapping combinations of pairs of examples and interpolated targets
    rather than individual examples and targets.

    This is done by taking a non-overlapping combination of a given batch X
    with a randomly permuted copy of X. The area is drawn from a
    ``Beta(alpha, alpha)`` distribution.

    Training in this fashion reduces generalization error.

    Args:
        alpha: the pseudocount for the Beta distribution used to sample
            area parameters. As ``alpha`` grows, the two samples in each pair
            tend to be weighted more equally. As ``alpha`` approaches 0 from
            above, the combination approaches only using one element of the
            pair.
    """

    def __init__(self, alpha: float):
        self.hparams = CutMixHparams(alpha=alpha)

    def match(self, event: Event, state: State) -> bool:
        """Runs on Event.INIT and Event.AFTER_DATALOADER.

        Args:
            event (:class:`Event`): The current event.
            state (:class:`State`): The current state.

        Returns:
            bool: True if this algorithm should run now.
        """
        return event in (Event.AFTER_DATALOADER, Event.INIT)

    @property
    def indices(self) -> Tensor:
        # Batch permutation used for the most recent application.
        return self._indices

    @indices.setter
    def indices(self, new_indices: Tensor) -> None:
        self._indices = new_indices

    @property
    def cutmix_lambda(self) -> float:
        # Mixing parameter used for the most recent application.
        return self._cutmix_lambda

    @cutmix_lambda.setter
    def cutmix_lambda(self, new_lambda: float) -> None:
        self._cutmix_lambda = new_lambda

    @property
    def bbox(self) -> tuple:
        # Bounding box used for the most recent application.
        return self._bbox

    @bbox.setter
    def bbox(self, new_bbox: tuple) -> None:
        self._bbox = new_bbox

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        """Applies CutMix augmentation on State input.

        Args:
            event (Event): the current event
            state (State): the current trainer state
            logger (Logger): the training logger
        """
        if event == Event.INIT:
            # Cache the class count once the model is available.
            self.num_classes: int = state.model.num_classes  # type: ignore
            return

        inputs, targets = state.batch_pair
        assert isinstance(inputs, Tensor) and isinstance(targets, Tensor), \
            "Multiple tensors for inputs or targets not supported yet."
        alpha = self.hparams.alpha

        # Sample and stash the randomness on the instance so it can be
        # inspected or replayed; cutmix() consumes these exact values.
        self.indices = gen_indices(inputs)
        self.cutmix_lambda = gen_cutmix_lambda(alpha)
        self.bbox = rand_bbox(inputs.shape[2], inputs.shape[3], self.cutmix_lambda)
        self.cutmix_lambda = adjust_lambda(self.cutmix_lambda, inputs, self.bbox)

        new_input, new_target = cutmix(
            x=inputs,
            y=targets,
            alpha=alpha,
            n_classes=self.num_classes,
            cutmix_lambda=self.cutmix_lambda,
            bbox=self.bbox,
            indices=self.indices,
        )
        state.batch = (new_input, new_target)