# Copyright 2021 MosaicML. All Rights Reserved.
from __future__ import annotations
import functools
import logging
import warnings
from typing import Optional, Sequence, Union
import numpy as np
import torch
from torch.optim import Optimizer
from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d
from composer.algorithms.warnings import NoEffectWarning
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import module_surgery
# Module-level logger; used by _log_surgery_result (called from apply_blurpool)
# and by BlurPool.__init__ / BlurPool._log_results.
log = logging.getLogger(__name__)
def apply_blurpool(model: torch.nn.Module,
                   replace_convs: bool = True,
                   replace_maxpools: bool = True,
                   blur_first: bool = True,
                   optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None) -> torch.nn.Module:
    """Add anti-aliasing filters to the strided :class:`torch.nn.Conv2d` and/or :class:`torch.nn.MaxPool2d` modules
    within ``model``.

    These filters increase invariance to small spatial shifts in the input
    (`Zhang 2019 <http://proceedings.mlr.press/v97/zhang19a.html>`_).

    Args:
        model (torch.nn.Module): the model to modify in-place
        replace_convs (bool, optional): replace strided :class:`torch.nn.Conv2d` modules with
            :class:`.BlurConv2d` modules. Default: ``True``.
        replace_maxpools (bool, optional): replace eligible :class:`torch.nn.MaxPool2d` modules
            with :class:`.BlurMaxPool2d` modules. Default: ``True``.
        blur_first (bool, optional): for ``replace_convs``, blur input before the associated
            convolution. When set to ``False``, the convolution is applied with
            a stride of 1 before the blurring, resulting in significant
            overhead (though more closely matching
            `the paper <http://proceedings.mlr.press/v97/zhang19a.html>`_).
            See :class:`.BlurConv2d` for further discussion. Default: ``True``.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
            Existing optimizers bound to ``model.parameters()``. All optimizers that have already been
            constructed with ``model.parameters()`` must be specified here so
            they will optimize the correct parameters.

            If the optimizer(s) are constructed *after* calling this function,
            then it is safe to omit this parameter. These optimizers will see
            the correct model parameters.

    Returns:
        The modified model

    Example:
        .. testcode::

            import composer.functional as cf
            from torchvision import models
            model = models.resnet50()
            cf.apply_blurpool(model)
    """
    # (enabled?, module class to match, factory producing its replacement).
    # MaxPool2d is listed before Conv2d so the resulting policy dict keeps that order.
    candidate_policies = (
        (replace_maxpools, torch.nn.MaxPool2d, BlurMaxPool2d.from_maxpool2d),
        (replace_convs, torch.nn.Conv2d, functools.partial(_maybe_replace_strided_conv2d, blur_first=blur_first)),
    )
    policies = {module_cls: factory for enabled, module_cls, factory in candidate_policies if enabled}

    module_surgery.replace_module_classes(model, optimizers=optimizers, policies=policies)
    _log_surgery_result(model)
    return model
class BlurPool(Algorithm):
    """`BlurPool <http://proceedings.mlr.press/v97/zhang19a.html>`_ adds anti-aliasing filters to convolutional layers
    to increase accuracy and invariance to small shifts in the input.

    Runs on :attr:`~composer.core.event.Event.INIT`.

    If both ``replace_convs`` and ``replace_maxpools`` are falsy, a warning is logged at
    construction time because the algorithm would not modify the model.

    Args:
        replace_convs (bool): replace strided :class:`torch.nn.Conv2d` modules with
            :class:`.BlurConv2d` modules. Default: ``True``.
        replace_maxpools (bool): replace eligible :class:`torch.nn.MaxPool2d` modules
            with :class:`.BlurMaxPool2d` modules. Default: ``True``.
        blur_first (bool): when ``replace_convs`` is ``True``, blur input before the
            associated convolution. When set to ``False``, the convolution is
            applied with a stride of 1 before the blurring, resulting in
            significant overhead (though more closely matching the paper).
            See :class:`.BlurConv2d` for further discussion. Default: ``True``.
    """

    def __init__(self, replace_convs: bool = True, replace_maxpools: bool = True, blur_first: bool = True) -> None:
        self.replace_convs = replace_convs
        self.replace_maxpools = replace_maxpools
        self.blur_first = blur_first

        # Test truthiness rather than ``is False`` so falsy non-bool values (e.g. ``0``)
        # are also recognised as "disabled" and the no-op warning is not silently skipped.
        if not self.replace_maxpools and not self.replace_convs:
            log.warning('Both replace_maxpools and replace_convs are set to False. '
                        'BlurPool will not be modifying the model.')

    def match(self, event: Event, state: State) -> bool:
        """Runs on :attr:`~composer.core.event.Event.INIT`.

        Args:
            event (Event): The current event.
            state (State): The current state.

        Returns:
            bool: True if this algorithm should run now.
        """
        return event == Event.INIT

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        """Adds anti-aliasing filters to the maxpools and/or convolutions.

        Args:
            event (Event): the current event
            state (State): the current trainer state
            logger (Logger): the training logger

        Returns:
            None
        """
        assert state.model is not None
        # The state's optimizers are passed through so surgery can keep them pointed at
        # the replaced modules' parameters.
        apply_blurpool(state.model,
                       optimizers=state.optimizers,
                       replace_convs=self.replace_convs,
                       replace_maxpools=self.replace_maxpools,
                       blur_first=self.blur_first)
        self._log_results(event, state, logger)

    def _log_results(self, event: Event, state: State, logger: Logger) -> None:
        """Logs the result of BlurPool application, including the number of layers that have been replaced."""
        assert state.model is not None

        num_blurpool_layers = module_surgery.count_module_instances(state.model, BlurMaxPool2d)
        num_blurconv_layers = module_surgery.count_module_instances(state.model, BlurConv2d)

        # python logger
        log.info(f'Applied BlurPool to model {state.model.__class__.__name__} '
                 f'with replace_maxpools={self.replace_maxpools}, '
                 f'replace_convs={self.replace_convs}. '
                 f'Model now has {num_blurpool_layers} BlurMaxPool2d '
                 f'and {num_blurconv_layers} BlurConv2D layers.')

        # trainer logger
        logger.data_fit({
            'blurpool/num_blurpool_layers': num_blurpool_layers,
            'blurpool/num_blurconv_layers': num_blurconv_layers,
        })
def _log_surgery_result(model: torch.nn.Module):
    """Report the outcome of BlurPool surgery on ``model``.

    Emits a :class:`NoEffectWarning` when ``model`` contains neither :class:`.BlurMaxPool2d`
    nor :class:`.BlurConv2d` modules, then logs how many of each it contains at INFO level.

    Args:
        model (torch.nn.Module): the model after surgery has been applied.
    """
    pool_count = module_surgery.count_module_instances(model, BlurMaxPool2d)
    conv_count = module_surgery.count_module_instances(model, BlurConv2d)

    # NOTE(review): the counts are presumably taken over the whole model, not only modules
    # replaced in this call, so a model that already had blur layers will not warn.
    if (pool_count, conv_count) == (0, 0):
        warnings.warn(
            NoEffectWarning("Applying BlurPool did not change any layers. "
                            "No strided Conv2d or Pool2d layers were found."))

    log.info(f'Applied BlurPool to model {model.__class__.__name__}. '
             f'Model now has {pool_count} BlurMaxPool2d '
             f'and {conv_count} BlurConv2D layers.')
def _maybe_replace_strided_conv2d(module: torch.nn.Conv2d,
                                  module_index: int,
                                  blur_first: bool,
                                  min_channels: int = 16) -> Optional[torch.nn.Module]:
    """Return a :class:`.BlurConv2d` replacement for ``module`` if it is strided and wide enough.

    This is the :class:`torch.nn.Conv2d` policy that :func:`apply_blurpool` hands to
    ``module_surgery.replace_module_classes``. It returns ``None`` for convolutions that
    should not be replaced. Presumably the surgery then leaves that module untouched;
    confirm against ``replace_module_classes``.

    Args:
        module (torch.nn.Conv2d): the convolution being considered for replacement.
        module_index (int): forwarded unchanged to :meth:`.BlurConv2d.from_conv2d`.
        blur_first (bool): forwarded unchanged to :meth:`.BlurConv2d.from_conv2d`.
        min_channels (int, optional): minimum number of input channels a strided
            convolution must have to be replaced. Narrower convolutions are kept as-is.
            Default: ``16``.

    Returns:
        A :class:`.BlurConv2d` built from ``module`` when any stride is greater than 1 and
        ``module.in_channels >= min_channels``, otherwise ``None``.
    """
    # ``stride`` holds one entry per spatial dimension; any entry > 1 means the conv downsamples.
    is_strided = np.max(module.stride) > 1
    if is_strided and module.in_channels >= min_channels:
        return BlurConv2d.from_conv2d(module, module_index, blur_first=blur_first)
    return None