# Copyright 2021 MosaicML. All Rights Reserved.
"""Helper functions to perform augmentations on a :class:`PIL.Image.Image`.
Augmentations that take an intensity value are normalized on a scale of 1-10,
where 10 is the strongest and maximum value an augmentation
function will accept.
Adapted from
`AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty <https://github.com/google-research/augmix/blob/master/augmentations.py>`_.
Attributes:
AugmentationFn ((PIL.Image.Image, float) -> PIL.Image.Image): The type annotation for describing an
augmentation function.
Each augmentation takes a :class:`~PIL.Image.Image` and an intensity level on the range ``[0; 10]``,
and returns an augmented image.
augmentation_sets (Dict[str, List[AugmentationFn]]): The collection of all augmentations.
This dictionary has the following entries:
* ``augmentation_sets["safe"]`` contains augmentations that do not overlap with ImageNet-C/CIFAR10-C test sets.
* ``augmentation_sets["original"]`` contains augmentations that use the original implementations of
enhancing color, contrast, brightness, and sharpness.
* ``augmentation_sets["all"]`` contains all augmentations.
"""
from typing import Callable
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
# Signature shared by the augmentation functions in this module: each one takes an image
# and an intensity level and returns an image.
AugmentationFn = Callable[[Image.Image, float], Image.Image]
# Public API of this module. The ``_``-prefixed sampling/scaling helpers are not listed.
__all__ = [
"AugmentationFn",
"autocontrast",
"equalize",
"posterize",
"rotate",
"solarize",
"shear_x",
"shear_y",
"translate_x",
"translate_y",
"color",
"color_original",
"contrast",
"contrast_original",
"brightness",
"brightness_original",
"sharpness",
"sharpness_original",
"augmentation_sets",
]
def _int_parameter(level: float, maxval: float):
    """Scale ``maxval`` by ``level / 10`` and truncate the result to an integer.

    Args:
        level (float): Level of the operation, expected to be on ``[0, 10]``.
        maxval (float): Value obtained (before truncation) when ``level`` is ``10``.

    Returns:
        int: ``level * maxval / 10`` converted with :class:`int`.
    """
    scaled = level * maxval / 10
    return int(scaled)
def _float_parameter(level: float, maxval: float):
    """Scale ``maxval`` by ``level / 10`` and return the result as a float.

    Args:
        level (float): Level of the operation, expected to be on ``[0, 10]``.
        maxval (float): Value obtained when ``level`` is ``10``.

    Returns:
        float: ``float(level) * maxval / 10``.
    """
    level_as_float = float(level)
    return level_as_float * maxval / 10.
def _sample_level(n: float):
    """Draw a single value from a uniform distribution between ``0.1`` and ``n``."""
    lower_bound, upper_bound = 0.1, n
    return np.random.uniform(low=lower_bound, high=upper_bound)
def _symmetric_sample(level: float):
    """Sample an enhancement factor distributed around ``1`` (the "no change" value).

    With probability one half the factor is drawn uniformly from ``[1, level]``; otherwise it is
    drawn uniformly from ``[1 - 0.09 * level, 1]``. For ``level == 10`` this is a distribution over
    the domain ``[0.1, 10]`` with median ``1`` and uniform probability of ``x | 0.1 <= x <= 1``
    and of ``x | 1 <= x <= 10``.

    Used for sampling transforms that can range from intensity 0 to infinity, and for which an
    intensity of ``1`` means no change.

    Args:
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        float: The sampled factor.
    """
    # The coin flip is drawn first, then exactly one value from the chosen range.
    use_upper_range = np.random.uniform() > 0.5
    if use_upper_range:
        return np.random.uniform(1, level)
    return np.random.uniform(1 - (0.09 * level), 1)
def autocontrast(pil_img: Image.Image, level: float = 0.0):
    """Apply PIL's automatic contrast adjustment to an image.

    .. seealso:: :func:`PIL.ImageOps.autocontrast`.

    Args:
        pil_img (Image.Image): The image.
        level (float): Ignored; present so the function has the same ``(image, level)``
            signature as the other augmentations. Default: ``0.0``.

    Returns:
        Image.Image: The result of :func:`PIL.ImageOps.autocontrast` on ``pil_img``.
    """
    del level  # unused
    adjusted = ImageOps.autocontrast(pil_img)
    return adjusted
def equalize(pil_img: Image.Image, level: float = 0.0):
    """Equalize an image.

    .. seealso:: :func:`PIL.ImageOps.equalize`.

    Args:
        pil_img (Image.Image): The image.
        level (float): Ignored; present so the function has the same ``(image, level)``
            signature as the other augmentations. Defaults to ``0.0``, matching
            :func:`autocontrast`, so callers may omit it.

    Returns:
        Image.Image: The result of :func:`PIL.ImageOps.equalize` on ``pil_img``.
    """
    del level  # unused
    return ImageOps.equalize(pil_img)
def posterize(pil_img: Image.Image, level: float):
    """Posterize an image with a randomly sampled strength.

    .. seealso:: :func:`PIL.ImageOps.posterize`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The posterized image.
    """
    # Sample an intensity up to ``level`` and map it onto an integer number of bits (0-4)
    # to subtract from the 4 bits passed to PIL at zero intensity.
    bits_removed = _int_parameter(_sample_level(level), 4)
    bits_kept = 4 - bits_removed
    return ImageOps.posterize(pil_img, bits_kept)
def rotate(pil_img: Image.Image, level: float):
    """Rotate an image by a randomly sampled angle in a random direction.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``. An intensity of ``10``
            corresponds to a maximum magnitude of 30 degrees.

    Returns:
        Image.Image: The rotated image, resampled bilinearly.
    """
    # Magnitude is sampled first, then a coin flip decides the sign.
    magnitude = _int_parameter(_sample_level(level), 30)
    angle = -magnitude if np.random.uniform() > 0.5 else magnitude
    return pil_img.rotate(angle, resample=Image.BILINEAR)
def solarize(pil_img: Image.Image, level: float):
    """Solarize an image with a randomly sampled threshold.

    .. seealso:: :func:`PIL.ImageOps.solarize`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The solarized image.
    """
    # A higher sampled intensity lowers the threshold handed to PIL.
    reduction = _int_parameter(_sample_level(level), 256)
    threshold = 256 - reduction
    return ImageOps.solarize(pil_img, threshold)
def shear_x(pil_img: Image.Image, level: float):
    """Shear an image horizontally by a randomly sampled amount in a random direction.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``. An intensity of ``10``
            corresponds to a maximum shear coefficient of ``0.3``.

    Returns:
        Image.Image: The sheared image, same size as the input, resampled bilinearly.
    """
    # Magnitude is sampled first, then a coin flip decides the sign.
    magnitude = _float_parameter(_sample_level(level), 0.3)
    shear = -magnitude if np.random.uniform() > 0.5 else magnitude
    # Affine coefficients (a, b, c, d, e, f): the shear goes in slot ``b``.
    coefficients = (1, shear, 0, 0, 1, 0)
    return pil_img.transform(pil_img.size, Image.AFFINE, coefficients, resample=Image.BILINEAR)
def shear_y(pil_img: Image.Image, level: float):
    """Shear an image vertically by a randomly sampled amount in a random direction.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``. An intensity of ``10``
            corresponds to a maximum shear coefficient of ``0.3``.

    Returns:
        Image.Image: The sheared image, same size as the input, resampled bilinearly.
    """
    # Magnitude is sampled first, then a coin flip decides the sign.
    magnitude = _float_parameter(_sample_level(level), 0.3)
    shear = -magnitude if np.random.uniform() > 0.5 else magnitude
    # Affine coefficients (a, b, c, d, e, f): the shear goes in slot ``d``.
    coefficients = (1, 0, 0, shear, 1, 0)
    return pil_img.transform(pil_img.size, Image.AFFINE, coefficients, resample=Image.BILINEAR)
def translate_x(pil_img: Image.Image, level: float):
    """Translate an image horizontally.

    The offset magnitude is sampled randomly, scaled so that an intensity of ``10`` corresponds
    to at most one third of the image width, and applied in a random direction.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The translated image, same size as the input, resampled bilinearly.
    """
    level = _int_parameter(_sample_level(level), pil_img.size[0] / 3)
    if np.random.random() > 0.5:
        level = -level
    # Affine coefficients (a, b, c, d, e, f): slot ``c`` is the horizontal offset.
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR)
def translate_y(pil_img: Image.Image, level: float):
    """Translate an image vertically.

    The offset magnitude is sampled randomly, scaled so that an intensity of ``10`` corresponds
    to at most one third of the image height, and applied in a random direction.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The translated image, same size as the input, resampled bilinearly.
    """
    level = _int_parameter(_sample_level(level), pil_img.size[1] / 3)
    if np.random.random() > 0.5:
        level = -level
    # Affine coefficients (a, b, c, d, e, f): slot ``f`` is the vertical offset.
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR)
# The following augmentations overlap with corruptions in the ImageNet-C/CIFAR10-C test
# sets. Their original implementations also have an intensity sampling scheme that
# samples a value bounded by 0.118 at a minimum, and a maximum value of intensity*0.18+
# 0.1, which ranged from 0.28 (intensity = 1) to 1.9 (intensity 10). These augmentations
# have different effects depending on whether they are < 1 or > 1, so the original
# sampling scheme does not make sense to me. Accordingly, I replaced it with the
# _symmetric_sample() above.
def color(pil_img: Image.Image, level: float):
    """Enhance color on an image using a factor sampled around ``1``.

    .. seealso:: :class:`PIL.ImageEnhance.Color`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The color-enhanced image.
    """
    factor = _symmetric_sample(level)
    enhancer = ImageEnhance.Color(pil_img)
    return enhancer.enhance(factor)
def color_original(pil_img: Image.Image, level: float):
    """Enhance color on an image, following the corruptions in the ImageNet-C/CIFAR10-C test sets.

    The enhancement factor is ``0.1 + 1.8 * s / 10``, where ``s`` is sampled uniformly between
    ``0.1`` and ``level``.

    .. seealso:: :class:`PIL.ImageEnhance.Color`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The color-enhanced image.
    """
    level = _float_parameter(_sample_level(level), 1.8) + 0.1
    return ImageEnhance.Color(pil_img).enhance(level)
def contrast(pil_img: Image.Image, level: float):
    """Enhance contrast on an image using a factor sampled around ``1``.

    .. seealso:: :class:`PIL.ImageEnhance.Contrast`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The contrast-enhanced image.
    """
    factor = _symmetric_sample(level)
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)
def contrast_original(pil_img: Image.Image, level: float):
    """Enhance contrast on an image, following the corruptions in the ImageNet-C/CIFAR10-C test sets.

    The enhancement factor is ``0.1 + 1.8 * s / 10``, where ``s`` is sampled uniformly between
    ``0.1`` and ``level``.

    .. seealso:: :class:`PIL.ImageEnhance.Contrast`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The contrast-enhanced image.
    """
    sampled = _sample_level(level)
    factor = _float_parameter(sampled, 1.8) + 0.1
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)
def brightness(pil_img: Image.Image, level: float):
    """Enhance brightness on an image using a factor sampled around ``1``.

    Sampled factors greater than ``1`` are multiplied by ``0.75`` to reduce the intensity of
    brightness increases.

    .. seealso:: :class:`PIL.ImageEnhance.Brightness`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The brightness-enhanced image.
    """
    sampled = _symmetric_sample(level)
    # Reduce intensity of brightness increases; factors <= 1 are used as sampled.
    factor = sampled * .75 if sampled > 1 else sampled
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)
def brightness_original(pil_img: Image.Image, level: float):
    """Enhance brightness on an image, following the corruptions in the ImageNet-C/CIFAR10-C test sets.

    The enhancement factor is ``0.1 + 1.8 * s / 10``, where ``s`` is sampled uniformly between
    ``0.1`` and ``level``.

    .. seealso:: :class:`PIL.ImageEnhance.Brightness`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The brightness-enhanced image.
    """
    sampled = _sample_level(level)
    factor = _float_parameter(sampled, 1.8) + 0.1
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)
def sharpness(pil_img: Image.Image, level: float):
    """Enhance sharpness on an image using a factor sampled around ``1``.

    .. seealso:: :class:`PIL.ImageEnhance.Sharpness`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The sharpness-enhanced image.
    """
    factor = _symmetric_sample(level)
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)
def sharpness_original(pil_img: Image.Image, level: float):
    """Enhance sharpness on an image, following the corruptions in the ImageNet-C/CIFAR10-C test sets.

    The enhancement factor is ``0.1 + 1.8 * s / 10``, where ``s`` is sampled uniformly between
    ``0.1`` and ``level``.

    .. seealso:: :class:`PIL.ImageEnhance.Sharpness`.

    Args:
        pil_img (Image.Image): The image.
        level (float): The intensity, which should be on ``[0, 10]``.

    Returns:
        Image.Image: The sharpness-enhanced image.
    """
    sampled = _sample_level(level)
    factor = _float_parameter(sampled, 1.8) + 0.1
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)
# Augmentations that don't overlap with ImageNet-C/CIFAR10-C test sets. Every set below
# starts with these, in this order.
_SAFE_AUGMENTATIONS = (
    autocontrast,
    equalize,
    posterize,
    rotate,
    solarize,
    shear_x,
    shear_y,
    translate_x,
    translate_y,
)
augmentation_sets = {
    # The safe augmentations plus the symmetric-sampling color/contrast/brightness/sharpness
    "all": [*_SAFE_AUGMENTATIONS, color, contrast, brightness, sharpness],
    # Augmentations that don't overlap with ImageNet-C/CIFAR10-C test sets
    "safe": list(_SAFE_AUGMENTATIONS),
    # Augmentations that use original implementations of color, contrast, brightness, and sharpness
    "original": [
        *_SAFE_AUGMENTATIONS,
        color_original,
        contrast_original,
        brightness_original,
        sharpness_original,
    ],
}