Source code for composer.algorithms.utils.augmentation_common

# Copyright 2021 MosaicML. All Rights Reserved.
from typing import Callable, Iterable, Type, TypeVar, cast

import torch
import torchvision.transforms.functional
from PIL.Image import Image as PillowImage

# Constrained type variables: an image is either a tensor or a Pillow image.
# Inputs and outputs get separate variables so that a function can accept one
# representation and return the other.
_InputImgT = TypeVar("_InputImgT", torch.Tensor, PillowImage)
_OutputImgT = TypeVar("_OutputImgT", torch.Tensor, PillowImage)


def image_as_type(image: _InputImgT, typ: Type[_OutputImgT]) -> _OutputImgT:
    """Converts between :class:`torch.Tensor` and :class:`PIL.Image.Image` image representations

    Args:
        image: a single image represented as a :class:`PIL.Image.Image` or
            a rank 2 or rank 3 :class:`torch.Tensor` in ``HW`` or ``CHW``
            format. A rank 4 or higher tensor can also be provided as long
            as no type conversion is needed; in this case, the input tensor
            will be returned. This case is allowed so that functions that
            natively operate on batch tensors can safely call
            ``image_as_type(image, torch.Tensor)`` without additional error
            and type checking.
        typ: type of the returned image. Must be :class:`PIL.Image.Image`
            or :class:`torch.Tensor`

    Returns:
        ``image`` itself (the same object, not a copy) when it already is
        an instance of ``typ``; otherwise a newly created image of type
        ``typ`` produced by torchvision's conversion functions.

    Raises:
        TypeError: if a conversion is needed and ``typ`` is not one of
            :class:`torch.Tensor` or :class:`PIL.Image.Image`
        ValueError: if ``image`` cannot be converted to the ``typ``,
            such as when requesting conversion of a rank 4 tensor to
            :class:`PIL.Image.Image`.
    """
    # Pass-through: nothing to convert. This is also the path that lets batch
    # (rank >= 4) tensors through when a tensor is requested.
    if isinstance(image, typ):
        return image
    if typ not in (torch.Tensor, PillowImage):
        raise TypeError(f"Only typ={{torch.Tensor, Image}} is supported; got {typ}")
    if typ is torch.Tensor:
        return cast(_OutputImgT, torchvision.transforms.functional.to_tensor(image))  # PIL -> Tensor
    return cast(_OutputImgT, torchvision.transforms.functional.to_pil_image(image))  # Tensor -> PIL
def _image_base_type(img) -> type:
    """Returns the base representation type (tensor or Pillow image) of ``img``."""
    if isinstance(img, torch.Tensor):
        return torch.Tensor
    if isinstance(img, PillowImage):
        return PillowImage
    # Not a supported representation: keep the concrete type so the conversion
    # back can reject it.
    return type(img)


def map_pillow_function(f_pil: Callable[[PillowImage], PillowImage], imgs: _OutputImgT) -> _OutputImgT:
    """Lifts a function that requires pillow images to also work on tensors.

    Args:
        f_pil: a callable that maps :class:`PIL.Image.Image` objects to other
            :class:`PIL.Image.Image` objects.
        imgs: a :class:`PIL.Image.Image` or a :class:`torch.Tensor` in ``HW``,
            ``CHW`` or ``NCHW`` format. Any other iterable of images is
            processed element by element.

    Returns:
        The result of applying ``f_pil`` to each image in ``imgs``, converted
        back to the same type and (if applicable) tensor layout as ``imgs``.
        A single input image yields a single image, an ``NCHW`` tensor yields
        a stacked tensor, and any other iterable yields a list. For an ``HW``
        input the channel axis is dropped again only when the result has a
        single channel.
    """
    is_tensor = isinstance(imgs, torch.Tensor)
    # Tensors are iterable, so rank 2 (HW) and rank 3 (CHW) tensors must be
    # recognised explicitly as one image instead of being walked row by row.
    single_image_input = not isinstance(imgs, Iterable) or (is_tensor and imgs.ndim in (2, 3))
    # Materialise the input so that non-subscriptable iterables work as well.
    imgs_as_list = [imgs] if single_image_input else list(imgs)

    imgs_pil = [image_as_type(img, PillowImage) for img in imgs_as_list]
    imgs_out_pil = [f_pil(img_pil) for img_pil in imgs_pil]

    # Convert back using the base representation of the first input rather
    # than its concrete class: subclasses (e.g. file-backed Pillow images)
    # are not accepted as conversion targets by image_as_type.
    out_type = _image_base_type(imgs_as_list[0]) if imgs_as_list else PillowImage
    imgs_out = [image_as_type(img_pil, out_type) for img_pil in imgs_out_pil]

    if single_image_input:
        img_out = imgs_out[0]
        if is_tensor and imgs.ndim == 2:
            # The tensor conversion produced a channel axis; remove it to
            # restore the HW layout when there is exactly one channel.
            img_tensor = cast(torch.Tensor, img_out)
            if img_tensor.ndim == 3 and img_tensor.shape[0] == 1:
                img_out = img_tensor.squeeze(0)
        return cast(_OutputImgT, img_out)

    if is_tensor and imgs.ndim == 4:  # batch of imgs
        stacked = torch.stack([cast(torch.Tensor, img) for img in imgs_out], dim=0)
        return cast(_OutputImgT, stacked)

    return cast(_OutputImgT, imgs_out)