Source code for composer.algorithms.cutout.cutout

# Copyright 2021 MosaicML. All Rights Reserved.

"""Core CutOut classes and functions."""

from __future__ import annotations

import logging
from typing import Optional, TypeVar

import numpy as np
import torch
from PIL.Image import Image as PillowImage
from torch import Tensor

from composer.algorithms.utils.augmentation_common import image_as_type
from composer.core import Algorithm, Event, State
from composer.loggers import Logger

log = logging.getLogger(__name__)

__all__ = ["CutOut", "cutout_batch"]

ImgT = TypeVar("ImgT", torch.Tensor, PillowImage)


def cutout_batch(input: ImgT, num_holes: int = 1, length: float = 0.5, uniform_sampling: bool = False) -> ImgT:
    """Mask out square regions of an image or batch of images. See :class:`CutOut`.

    One set of hole positions is sampled per call and applied along the last
    two dimensions of the input, so every image (and channel) in a batch
    receives the same holes. Holes may overlap each other and are clipped at
    the image border.

    Args:
        input (PIL.Image.Image or torch.Tensor): Image or batch of images. If a
            :class:`torch.Tensor`, must be a single image of shape ``(C, H, W)``
            or a batch of images of shape ``(N, C, H, W)``.
        num_holes (int, optional): Integer number of holes to cut out.
            Default: ``1``.
        length (float, optional): Relative side length of the masked region.
            ``length`` is interpreted as a fraction of ``H`` and ``W``, and the
            resulting box is a square with side length ``length * min(H, W)``
            (truncated to a whole number of pixels). Must be in the interval
            :math:`(0, 1)`. Default: ``0.5``.
        uniform_sampling (bool, optional): If truthy, sample the bounding box
            such that each pixel has an equal probability of being masked
            (box centers may then fall outside the image). If falsy, box
            centers are sampled inside the image, as in the original paper
            implementation. Default: ``False``.

    Returns:
        X_cutout: Image(s) of the same type as ``input`` with ``num_holes``
        square holes, with dimension determined by ``length``, replaced with
        zeros.

    Example:
         .. testcode::

            from composer.algorithms.cutout import cutout_batch
            new_input_batch = cutout_batch(X_example, num_holes=1, length=0.25)
    """
    X_tensor = image_as_type(input, torch.Tensor)
    h = X_tensor.shape[-2]
    w = X_tensor.shape[-1]

    # Convert the relative length into pixels of the shorter image side. Kept
    # in its own name so the float parameter is not silently rebound to an int.
    side = int(min(h, w) * length)

    mask = torch.ones_like(X_tensor)
    for _ in range(num_holes):
        # Truthiness test instead of ``is True``: identity comparison would
        # send truthy non-``bool`` flags (e.g. ``numpy.bool_(True)`` or ``1``)
        # down the non-uniform branch without any warning.
        if uniform_sampling:
            # Centers may lie up to half a box outside the image so that
            # border pixels are covered as often as interior ones.
            y = np.random.randint(-side // 2, high=h + side // 2)
            x = np.random.randint(-side // 2, high=w + side // 2)
        else:
            y = np.random.randint(h)
            x = np.random.randint(w)

        mask = _generate_mask(mask, w, h, x, y, side)

    X_cutout = X_tensor * mask
    # Convert back to the caller's image type (pyright struggles with unions here).
    return image_as_type(X_cutout, input.__class__)
class CutOut(Algorithm):
    """`Cutout <https://arxiv.org/abs/1708.04552>`_ is a data augmentation technique
    that works by masking out one or more square regions of an input image.

    This implementation cuts out the same square from all images in a batch.

    Example:
        .. testcode::

            from composer.algorithms import CutOut
            from composer.trainer import Trainer
            cutout_algorithm = CutOut(num_holes=1, length=0.25)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[cutout_algorithm],
                optimizers=[optimizer]
            )

    Args:
        num_holes (int, optional): Integer number of holes to cut out.
            Default: ``1``.
        length (float, optional): Relative side length of the masked region,
            interpreted as a fraction of ``H`` and ``W``; the resulting box is
            a square with side length ``length * min(H, W)``. Must be in the
            interval :math:`(0, 1)`. Default: ``0.5``.
        uniform_sampling (bool, optional): Forwarded unchanged to
            :func:`cutout_batch`, where it selects how hole positions are
            sampled. Default: ``False``.
    """

    def __init__(self, num_holes: int = 1, length: float = 0.5, uniform_sampling: bool = False):
        # Plain storage only; the values are consumed later by ``apply``.
        self.num_holes = num_holes
        self.length = length
        self.uniform_sampling = uniform_sampling

    def match(self, event: Event, state: State) -> bool:
        """Return ``True`` only for ``Event.AFTER_DATALOADER``; ``state`` is not inspected."""
        return event == Event.AFTER_DATALOADER

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        """Replace the inputs of ``state.batch`` with their cutout version.

        The batch is read through ``state.batch_pair`` as an ``(inputs, targets)``
        pair; only the inputs are modified and the targets are passed through.
        ``event`` and ``logger`` are unused. Nothing is returned explicitly.
        """
        inputs, targets = state.batch_pair
        assert isinstance(inputs, Tensor), "Multiple tensors not supported for Cutout."
        state.batch = (
            cutout_batch(
                inputs,
                num_holes=self.num_holes,
                length=self.length,
                uniform_sampling=self.uniform_sampling,
            ),
            targets,
        )
def _generate_mask(mask: Tensor, width: int, height: int, x: int, y: int, cutout_length: int) -> Tensor:
    """Zero out, in place, one box of ``mask`` centered at column ``x`` / row ``y``.

    The box spans ``cutout_length // 2`` pixels to each side of the center
    along the last two dimensions of ``mask`` and is clamped to
    ``[0, height]`` x ``[0, width]``, so a box that pokes past the border is
    truncated rather than wrapped. Because the half-length is floored, an odd
    ``cutout_length`` produces a box one pixel smaller than requested.

    Args:
        mask: Tensor whose trailing two dimensions are indexed as (row, column).
        width: Upper clamp for column indices.
        height: Upper clamp for row indices.
        x: Column of the box center (may lie outside the image).
        y: Row of the box center (may lie outside the image).
        cutout_length: Nominal side length of the box in pixels.

    Returns:
        The same ``mask`` object, modified in place.
    """
    half = cutout_length // 2

    # Clamp each edge into the valid index range before slicing.
    top = min(max(y - half, 0), height)
    bottom = min(max(y + half, 0), height)
    left = min(max(x - half, 0), width)
    right = min(max(x + half, 0), width)

    mask[..., top:bottom, left:right] = 0.
    return mask