# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Core CutOut classes and functions."""
from __future__ import annotations
import logging
from typing import Any, Callable, Optional, TypeVar, Union
import numpy as np
import torch
from PIL.Image import Image as PillowImage
from torch import Tensor
from composer.algorithms.utils.augmentation_common import image_as_type
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Public names exported by ``from <this module> import *``.
__all__ = ['CutOut', 'cutout_batch']
# Constrained type variable: ``cutout_batch`` returns the same kind of object
# (``torch.Tensor`` or PIL image) that it was given.
ImgT = TypeVar('ImgT', torch.Tensor, PillowImage)
def cutout_batch(input: ImgT, num_holes: int = 1, length: float = 0.5, uniform_sampling: bool = False) -> ImgT:
    """See :class:`.CutOut`.

    Args:
        input (PIL.Image.Image | torch.Tensor): Image or batch of images. If
            a :class:`torch.Tensor`, must be a single image of shape ``(C, H, W)``
            or a batch of images of shape ``(N, C, H, W)``.
        num_holes: Integer number of holes to cut out. Default: ``1``.
        length (float, optional): Relative side length of the masked region.
            If specified, ``length`` is interpreted as a fraction of ``H`` and
            ``W``, and the resulting box is a square with side length
            ``length * min(H, W)``. Must be in the interval :math:`(0, 1)`.
            Default: ``0.5``.
        uniform_sampling (bool, optional): If truthy, sample the bounding
            box such that each pixel has an equal probability of being masked.
            If falsy, defaults to the sampling used in the original paper
            implementation. Default: ``False``.

    Returns:
        X_cutout: Image(s) of the same type as ``input`` with ``num_holes``
        square holes, of side length determined by ``length``, replaced with
        zeros. The same holes are applied to every image in a batch.

    Example:
        .. testcode::

            from composer.algorithms.cutout import cutout_batch
            new_input_batch = cutout_batch(X_example, num_holes=1, length=0.25)
    """
    X_tensor = image_as_type(input, torch.Tensor)
    h = X_tensor.shape[-2]
    w = X_tensor.shape[-1]

    # Absolute side length of each hole, in pixels.
    cutout_length = int(min(h, w) * length)

    # Bounds (low inclusive, high exclusive) for the hole center. With uniform
    # sampling the center may lie up to half a hole outside the image, so that
    # border pixels are covered as often as interior ones; otherwise the center
    # is drawn from inside the image, as in the paper's implementation.
    # Note: use truthiness, not ``is True``, so ``np.bool_``/``1`` also work.
    if uniform_sampling:
        y_low, y_high = -cutout_length // 2, h + cutout_length // 2
        x_low, x_high = -cutout_length // 2, w + cutout_length // 2
    else:
        y_low, y_high = 0, h
        x_low, x_high = 0, w

    # One mask shared by every image (and channel) in the batch.
    mask = torch.ones_like(X_tensor)
    for _ in range(num_holes):
        y = np.random.randint(y_low, high=y_high)
        x = np.random.randint(x_low, high=x_high)
        mask = _generate_mask(mask, w, h, x, y, cutout_length)

    X_cutout = X_tensor * mask
    X_out = image_as_type(X_cutout, input.__class__)  # pyright struggling with unions
    return X_out
class CutOut(Algorithm):
    """`CutOut <https://arxiv.org/abs/1708.04552>`_ is a data augmentation technique
    that works by masking out one or more square regions of an input image.

    This implementation cuts out the same square from all images in a batch.

    Example:
        .. testcode::

            from composer.algorithms import CutOut
            from composer.trainer import Trainer

            cutout_algorithm = CutOut(num_holes=1, length=0.25)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[cutout_algorithm],
                optimizers=[optimizer]
            )

    Args:
        num_holes (int, optional): Integer number of holes to cut out.
            Default: ``1``.
        length (float, optional): Relative side length of the masked region.
            If specified, ``length`` is interpreted as a fraction of ``H`` and
            ``W``, and the resulting box is a square with side length
            ``length * min(H, W)``. Must be in the interval :math:`(0, 1)`.
            Default: ``0.5``.
        uniform_sampling (bool, optional): If ``True``, sample the bounding
            box such that each pixel has an equal probability of being masked.
            If ``False``, defaults to the sampling used in the original paper
            implementation. Default: ``False``.
        input_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the input
            from the batch. Can also be a pair of get and set functions, where the getter
            is assumed to be first in the pair. The default is 0, which corresponds to any sequence, where the first element
            is the input. Default: ``0``.
    """

    def __init__(
        self,
        num_holes: int = 1,
        length: float = 0.5,
        uniform_sampling: bool = False,
        # Annotation now matches the documented (getter, setter) pair form.
        input_key: Union[str, int, tuple[Callable, Callable], Any] = 0,
    ):
        # All options are stored as-is and forwarded to ``cutout_batch`` in ``apply``.
        self.num_holes = num_holes
        self.length = length
        self.uniform_sampling = uniform_sampling
        self.input_key = input_key

    def match(self, event: Event, state: State) -> bool:
        """Return ``True`` only for :attr:`Event.AFTER_DATALOADER`."""
        return event == Event.AFTER_DATALOADER

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        """Replace the batch input selected by ``input_key`` with its cut-out version.

        The input is read with ``state.batch_get_item``, must be a single
        :class:`torch.Tensor`, is passed through :func:`cutout_batch` with this
        algorithm's settings, and is written back with ``state.batch_set_item``.
        """
        x = state.batch_get_item(self.input_key)
        assert isinstance(x, Tensor), 'Multiple tensors not supported for Cutout.'
        new_x = cutout_batch(x, num_holes=self.num_holes, length=self.length, uniform_sampling=self.uniform_sampling)
        state.batch_set_item(self.input_key, new_x)
def _generate_mask(mask: Tensor, width: int, height: int, x: int, y: int, cutout_length: int) -> Tensor:
    """Zero a ``cutout_length`` x ``cutout_length`` square of ``mask`` in place.

    The square spans rows ``[y - cutout_length // 2, y - cutout_length // 2 + cutout_length)``
    and the analogous columns around ``x``. It is clipped to the image bounds, so
    squares near or beyond the border are truncated (possibly to nothing).

    Args:
        mask (Tensor): Tensor whose last two dimensions are ``(height, width)``;
            every leading index (batch, channel, ...) receives the same square.
            Modified in place.
        width (int): Image width used to clip the column range.
        height (int): Image height used to clip the row range.
        x (int): Column of the square's center (may lie outside the image).
        y (int): Row of the square's center (may lie outside the image).
        cutout_length (int): Side length of the square in pixels.

    Returns:
        Tensor: The same ``mask`` object, after modification.
    """
    half = cutout_length // 2
    # Derive the far edge from the near edge so the unclipped box is exactly
    # ``cutout_length`` wide. Using ``center + half`` for the far edge would
    # drop one pixel for odd lengths (and mask nothing for length 1).
    y1 = np.clip(y - half, 0, height)
    y2 = np.clip(y - half + cutout_length, 0, height)
    x1 = np.clip(x - half, 0, width)
    x2 = np.clip(x - half + cutout_length, 0, width)
    mask[..., y1:y2, x1:x2] = 0.
    return mask