# Copyright 2021 MosaicML. All Rights Reserved.
"""Core CutOut classes and functions."""
from __future__ import annotations
import logging
from typing import Optional, TypeVar
import numpy as np
import torch
from PIL.Image import Image as PillowImage
from torch import Tensor
from composer.algorithms.utils.augmentation_common import image_as_type
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
# Module-level logger.
# NOTE(review): ``log`` is not referenced anywhere in this view of the file — confirm it is still needed.
log = logging.getLogger(__name__)
# Names exported as this module's public API.
__all__ = ["CutOut", "cutout_batch"]
# Image type accepted by ``cutout_batch``: either a torch tensor or a PIL image.
# Used as both the input and return annotation so the output type follows the input type.
ImgT = TypeVar("ImgT", torch.Tensor, PillowImage)
def cutout_batch(input: ImgT, num_holes: int = 1, length: float = 0.5, uniform_sampling: bool = False) -> ImgT:
    """See :class:`CutOut`.

    The same holes are cut out of every image in a batch.

    Args:
        input (PIL.Image.Image or torch.Tensor): Image or batch of images. If
            a :class:`torch.Tensor`, must be a single image of shape ``(C, H, W)``
            or a batch of images of shape ``(N, C, H, W)``.
        num_holes (int, optional): Integer number of holes to cut out.
            Default: ``1``.
        length (float, optional): Relative side length of the masked region,
            interpreted as a fraction of ``min(H, W)``. The resulting box is a
            square with side length ``int(length * min(H, W))`` pixels. Must be
            in the interval :math:`(0, 1)`. Default: ``0.5``.
        uniform_sampling (bool, optional): If ``True``, sample the bounding
            box such that each pixel has an equal probability of being masked.
            If ``False``, defaults to the sampling used in the original paper
            implementation. Default: ``False``.

    Returns:
        X_cutout: Image or batch of images of the same type as ``input``, with
        ``num_holes`` square holes of size determined by ``length`` replaced
        with zeros.

    Example:
        .. testcode::

            from composer.algorithms.cutout import cutout_batch
            new_input_batch = cutout_batch(X_example, num_holes=1, length=0.25)
    """
    X_tensor = image_as_type(input, torch.Tensor)
    h = X_tensor.shape[-2]
    w = X_tensor.shape[-1]

    # Absolute side length in pixels. Kept separate from the relative ``length``
    # argument instead of rebinding it to a value of a different type.
    cutout_length = int(min(h, w) * length)

    # A single (H, W) mask broadcasts over any leading channel/batch dimensions.
    # This gives every image the same holes without allocating a mask as large
    # as the whole input.
    mask = torch.ones((h, w), dtype=X_tensor.dtype, device=X_tensor.device)
    for _ in range(num_holes):
        if uniform_sampling:
            # Let the box center fall up to half a box outside the image so that
            # border pixels are masked as often as interior ones.
            y = np.random.randint(-cutout_length // 2, high=h + cutout_length // 2)
            x = np.random.randint(-cutout_length // 2, high=w + cutout_length // 2)
        else:
            # Sampling from the original paper implementation: the center is a pixel inside the image.
            y = np.random.randint(h)
            x = np.random.randint(w)
        mask = _generate_mask(mask, w, h, x, y, cutout_length)

    X_cutout = X_tensor * mask
    X_out = image_as_type(X_cutout, input.__class__)  # pyright struggling with unions
    return X_out
class CutOut(Algorithm):
    """`Cutout <https://arxiv.org/abs/1708.04552>`_ is a data augmentation technique that works by masking out one or
    more square regions of an input image.

    This implementation cuts out the same square from all images in a batch.

    Example:
        .. testcode::

            from composer.algorithms import CutOut
            from composer.trainer import Trainer
            cutout_algorithm = CutOut(num_holes=1, length=0.25)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[cutout_algorithm],
                optimizers=[optimizer]
            )

    Args:
        num_holes (int, optional): Integer number of holes to cut out.
            Default: ``1``.
        length (float, optional): Relative side length of the masked region,
            interpreted as a fraction of ``min(H, W)``. The resulting box is a
            square with side length ``int(length * min(H, W))`` pixels. Must be
            in the interval :math:`(0, 1)`. Default: ``0.5``.
        uniform_sampling (bool, optional): If ``True``, sample the bounding
            box such that each pixel has an equal probability of being masked.
            If ``False``, use the sampling from the original paper
            implementation. Default: ``False``.
    """

    def __init__(self, num_holes: int = 1, length: float = 0.5, uniform_sampling: bool = False):
        # Stored as-is and forwarded to ``cutout_batch`` on every ``apply`` call.
        self.num_holes = num_holes
        self.length = length
        self.uniform_sampling = uniform_sampling

    def match(self, event: Event, state: State) -> bool:
        """Runs on ``Event.AFTER_DATALOADER``."""
        return event == Event.AFTER_DATALOADER

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        """Apply cutout to the input images of the current batch.

        Replaces ``state.batch`` with ``(inputs_with_holes, targets)``. Targets
        are passed through unchanged.

        Raises:
            AssertionError: If the batch inputs are not a single :class:`torch.Tensor`.
        """
        x, y = state.batch_pair
        assert isinstance(x, Tensor), "Multiple tensors not supported for Cutout."
        new_x = cutout_batch(x, num_holes=self.num_holes, length=self.length, uniform_sampling=self.uniform_sampling)
        state.batch = (new_x, y)
def _generate_mask(mask: Tensor, width: int, height: int, x: int, y: int, cutout_length: int) -> Tensor:
y1 = np.clip(y - cutout_length // 2, 0, height)
y2 = np.clip(y + cutout_length // 2, 0, height)
x1 = np.clip(x - cutout_length // 2, 0, width)
x2 = np.clip(x + cutout_length // 2, 0, width)
mask[..., y1:y2, x1:x2] = 0.
return mask