CutMix
- class composer.algorithms.CutMix(alpha=1.0, interpolate_loss=False, uniform_sampling=False, input_key=0, target_key=1)
CutMix trains the network on non-overlapping combinations of pairs of examples and interpolated targets rather than individual examples and targets.
This is done by taking a non-overlapping combination of a given batch X with a randomly permuted copy of X. The relative area of the mixed region is drawn from a Beta(alpha, alpha) distribution.
Training in this fashion sometimes reduces generalization error.
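The mixing step described above can be sketched in plain Python for a single pair of images. This is an illustrative sketch, not Composer's implementation: the helper names `sample_box` and `cutmix_pair` are hypothetical, images are nested lists rather than tensors, and the box is sampled the way the original CutMix paper does it (uniform center, side lengths scaled by sqrt(1 - lam), clipped to the image). Because clipping can shrink the box, the label weight is recomputed from the area that was actually pasted.

```python
import math
import random


def sample_box(height, width, lam, rng):
    """Pick a box covering roughly (1 - lam) of the image, clipped to its bounds."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_h = int(height * cut_ratio)
    cut_w = int(width * cut_ratio)
    # The box center is uniform over the image (paper-style sampling).
    cy = rng.randrange(height)
    cx = rng.randrange(width)
    top = max(cy - cut_h // 2, 0)
    bottom = min(cy + cut_h // 2, height)
    left = max(cx - cut_w // 2, 0)
    right = min(cx + cut_w // 2, width)
    return top, bottom, left, right


def cutmix_pair(img_a, img_b, label_a, label_b, alpha=1.0, rng=None):
    """Paste a random box from img_b into a copy of img_a and mix the labels."""
    rng = rng or random.Random()
    height, width = len(img_a), len(img_a[0])
    lam = rng.betavariate(alpha, alpha)
    top, bottom, left, right = sample_box(height, width, lam, rng)

    mixed = [row[:] for row in img_a]  # copy so img_a is left untouched
    for y in range(top, bottom):
        mixed[y][left:right] = img_b[y][left:right]

    # Recompute the weight of img_a from the area that was actually replaced.
    pasted_area = (bottom - top) * (right - left)
    lam_adj = 1.0 - pasted_area / (height * width)
    mixed_label = [lam_adj * a + (1.0 - lam_adj) * b
                   for a, b in zip(label_a, label_b)]
    return mixed, mixed_label, lam_adj


# Toy data: an all-zeros image mixed with an all-ones image, one-hot labels.
rng = random.Random(0)
img_a = [[0.0] * 8 for _ in range(8)]
img_b = [[1.0] * 8 for _ in range(8)]
mixed, mixed_label, lam = cutmix_pair(
    img_a, img_b, [1.0, 0.0], [0.0, 1.0], alpha=1.0, rng=rng
)
print(lam, mixed_label)
```

In a batched setting the partner image would come from a randomly permuted copy of the batch, as the description states; the per-pair logic stays the same.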
- Parameters
alpha (float, optional) – The pseudocount for the Beta distribution used to sample area parameters. As alpha grows, the two samples in each pair tend to be weighted more equally. As alpha approaches 0 from above, the combination approaches only using one element of the pair. Default: 1.0.
interpolate_loss (bool, optional) – Interpolates the loss rather than the labels. A useful trick when using a cross entropy loss. Will produce incorrect behavior if the loss is not a linear function of the targets. Default: False.
uniform_sampling (bool, optional) – If True, sample the bounding box such that each pixel has an equal probability of being mixed. If False, defaults to the sampling used in the original paper implementation. Default: False.
input_key (str | int | Tuple[Callable, Callable] | Any, optional) – A key that indexes to the input from the batch. Can also be a pair of get and set functions, where the getter is assumed to be first in the pair. The default is 0, which corresponds to any sequence where the first element is the input. Default: 0.
target_key (str | int | Tuple[Callable, Callable] | Any, optional) – A key that indexes to the target from the batch. Can also be a pair of get and set functions, where the getter is assumed to be first in the pair. The default is 1, which corresponds to any sequence where the second element is the target. Default: 1.
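The caveat on interpolate_loss comes down to linearity. Cross entropy with soft targets is a linear function of the target vector, so the loss on the interpolated target equals the interpolation of the two per-target losses. Squared error is not linear in the targets, so the two quantities differ. The check below is a stdlib-only sketch; the helper names and the numbers are made up purely for illustration and are not part of Composer.

```python
import math


def cross_entropy(probs, target):
    """Soft-target cross entropy: -sum_i target_i * log(probs_i)."""
    return -sum(t * math.log(p) for p, t in zip(probs, target))


def squared_error(probs, target):
    """Sum of squared differences, which is NOT linear in the target."""
    return sum((p - t) ** 2 for p, t in zip(probs, target))


def mix(a, b, lam):
    """Elementwise interpolation lam * a + (1 - lam) * b."""
    return [lam * x + (1.0 - lam) * y for x, y in zip(a, b)]


probs = [0.7, 0.2, 0.1]        # hypothetical model output
target_a = [1.0, 0.0, 0.0]     # one-hot target of the first example
target_b = [0.0, 0.0, 1.0]     # one-hot target of its partner
lam = 0.3
mixed_target = mix(target_a, target_b, lam)

# Option 1: interpolate the labels, then compute the loss once.
ce_on_mixed_target = cross_entropy(probs, mixed_target)
se_on_mixed_target = squared_error(probs, mixed_target)

# Option 2: compute the loss per target, then interpolate the losses.
ce_interpolated = (lam * cross_entropy(probs, target_a)
                   + (1.0 - lam) * cross_entropy(probs, target_b))
se_interpolated = (lam * squared_error(probs, target_a)
                   + (1.0 - lam) * squared_error(probs, target_b))

ce_gap = abs(ce_on_mixed_target - ce_interpolated)  # zero up to rounding
se_gap = abs(se_on_mixed_target - se_interpolated)  # clearly non-zero
print(ce_gap, se_gap)
```

For cross entropy the two options agree up to floating-point rounding, which is why interpolating the loss is safe there; for squared error they do not.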
Example
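    from composer.algorithms import CutMix
    from composer.trainer import Trainer

    # model, train_dataloader, eval_dataloader, and optimizer are assumed
    # to be defined elsewhere.
    algorithm = CutMix(alpha=0.2)
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        max_duration="1ep",
        algorithms=[algorithm],
        optimizers=[optimizer],
    )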
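
To get a feel for what a setting such as alpha=0.2 does, it can help to sample the Beta(alpha, alpha) distribution directly. The sketch below uses only the standard library; the helper name `mean_distance_from_half` and the sample size are arbitrary choices for illustration. It measures how far the sampled mixing proportion sits from an even 0.5 split on average: small alpha pushes samples toward 0 or 1 (one example dominates), while large alpha pulls them toward 0.5.

```python
import random


def mean_distance_from_half(alpha, n_samples=20000, seed=0):
    """Average |x - 0.5| for x drawn from Beta(alpha, alpha)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        total += abs(rng.betavariate(alpha, alpha) - 0.5)
    return total / n_samples


# Smaller alpha gives a larger average distance from an even split.
for alpha in (0.2, 1.0, 5.0):
    print(alpha, round(mean_distance_from_half(alpha), 3))
```

With alpha=1.0 the distribution is uniform on [0, 1], so that row serves as a reference point between the two extremes.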