# Source code for composer.algorithms.blurpool.blurpool

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import functools
import logging
import warnings
from typing import Optional, Sequence, Union

import numpy as np
import torch
from torch.optim import Optimizer

from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d
from composer.algorithms.warnings import NoEffectWarning
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import module_surgery

log = logging.getLogger(__name__)


def apply_blurpool(
    model: torch.nn.Module,
    replace_convs: bool = True,
    replace_maxpools: bool = True,
    blur_first: bool = True,
    min_channels: int = 16,
    optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
) -> None:
    """Add anti-aliasing filters to strided :class:`torch.nn.Conv2d` and/or :class:`torch.nn.MaxPool2d` modules.

    These filters increase invariance to small spatial shifts in the input
    (`Zhang 2019 <http://proceedings.mlr.press/v97/zhang19a.html>`_). The model is modified in
    place through :func:`composer.utils.module_surgery.replace_module_classes`. Afterwards, a
    :class:`.NoEffectWarning` is issued if the model contains no :class:`.BlurMaxPool2d` or
    :class:`.BlurConv2d` layers at all.

    Args:
        model (:class:`torch.nn.Module`): the model to modify in-place.
        replace_convs (bool, optional): replace strided :class:`torch.nn.Conv2d` modules with
            :class:`.BlurConv2d` modules. Default: ``True``.
        replace_maxpools (bool, optional): hand every :class:`torch.nn.MaxPool2d` module to
            :meth:`.BlurMaxPool2d.from_maxpool2d` for replacement. Default: ``True``.
        blur_first (bool, optional): for ``replace_convs``, blur the input before the
            associated convolution. When set to ``False``, the convolution is applied with a
            stride of 1 before the blurring, resulting in significant overhead (though more
            closely matching `the paper <http://proceedings.mlr.press/v97/zhang19a.html>`_).
            See :class:`.BlurConv2d` for further discussion. Default: ``True``.
        min_channels (int, optional): strided convolutions with ``in_channels < min_channels``
            are left untouched. This is commonly used to avoid blurring the first layer. It
            has no effect on :class:`torch.nn.MaxPool2d` replacement. Default: 16.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
            Existing optimizers bound to ``model.parameters()``. Every optimizer that was
            already constructed with ``model.parameters()`` must be passed here so that it
            keeps optimizing the correct parameters after surgery. Optimizers constructed
            *after* calling this function may be omitted; they will see the correct model
            parameters.

    Example:
        .. testcode::

            import composer.functional as cf
            from torchvision import models
            model = models.resnet50()
            cf.apply_blurpool(model)
    """
    # One (enabled, target layer class, replacement factory) row per layer type. The
    # comprehension keeps row order, so MaxPool2d is registered before Conv2d.
    candidate_policies = (
        (replace_maxpools, torch.nn.MaxPool2d, BlurMaxPool2d.from_maxpool2d),
        (
            replace_convs,
            torch.nn.Conv2d,
            functools.partial(
                _maybe_replace_strided_conv2d,
                blur_first=blur_first,
                min_channels=min_channels,
            ),
        ),
    )
    policies = {layer_cls: factory for enabled, layer_cls, factory in candidate_policies if enabled}

    module_surgery.replace_module_classes(model, optimizers=optimizers, policies=policies)
    _log_surgery_result(model)
class BlurPool(Algorithm):
    """`BlurPool <http://proceedings.mlr.press/v97/zhang19a.html>`_ adds anti-aliasing filters to convolutional layers.

    This algorithm increases accuracy and invariance to small shifts in the input. It runs on
    :attr:`.Event.INIT` and delegates the model surgery to :func:`apply_blurpool`.

    Args:
        replace_convs (bool): replace strided :class:`torch.nn.Conv2d` modules with
            :class:`.BlurConv2d` modules. Default: ``True``.
        replace_maxpools (bool): replace eligible :class:`torch.nn.MaxPool2d` modules with
            :class:`.BlurMaxPool2d` modules. Default: ``True``.
        blur_first (bool): when ``replace_convs`` is ``True``, blur input before the
            associated convolution. When set to ``False``, the convolution is applied with a
            stride of 1 before the blurring, resulting in significant overhead (though more
            closely matching the paper). See :class:`.BlurConv2d` for further discussion.
            Default: ``True``.
        min_channels (int, optional): Skip replacing strided convolutions with
            ``in_channels < min_channels``. Commonly used to prevent the blurring of the first
            layer. Default: 16.

    Raises:
        ValueError: if both ``replace_convs`` and ``replace_maxpools`` are falsy, since the
            algorithm would then leave the model unchanged.
    """

    def __init__(
        self,
        replace_convs: bool = True,
        replace_maxpools: bool = True,
        blur_first: bool = True,
        min_channels: int = 16,
    ) -> None:
        self.replace_convs = replace_convs
        self.replace_maxpools = replace_maxpools
        self.blur_first = blur_first
        self.min_channels = min_channels

        # Use truthiness rather than ``is False``. That way falsy non-bool values are also
        # rejected, e.g. ``numpy.bool_(False)`` or ``0`` coming from a config file.
        if not self.replace_maxpools and not self.replace_convs:
            raise ValueError(
                'Both replace_maxpools and replace_convs are set to False. BlurPool will not be modifying the model.',
            )

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(replace_convs={self.replace_convs},replace_maxpools={self.replace_maxpools},blur_first={self.blur_first},min_channels={self.min_channels})'

    @staticmethod
    def required_on_load() -> bool:
        # NOTE(review): presumably tells the trainer to re-apply this algorithm when loading a
        # checkpoint, because it changes the module structure. Confirm against
        # ``Algorithm.required_on_load``.
        return True

    def match(self, event: Event, state: State) -> bool:
        """Run only once, on :attr:`.Event.INIT`."""
        return event == Event.INIT

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
        """Replace eligible layers of ``state.model`` in place, then log the resulting layer counts."""
        assert state.model is not None
        # Forward the already-built optimizers so that they keep tracking the right
        # parameters after modules are swapped out.
        apply_blurpool(
            state.model,
            optimizers=state.optimizers,
            replace_convs=self.replace_convs,
            replace_maxpools=self.replace_maxpools,
            blur_first=self.blur_first,
            min_channels=self.min_channels,
        )
        self._log_results(event, state, logger)

    def _log_results(self, event: Event, state: State, logger: Logger) -> None:
        """Log how many BlurMaxPool2d and BlurConv2d layers the model contains after surgery.

        The counts go both to the Python logger and, as hyperparameters, to ``logger``.
        """
        assert state.model is not None

        num_blurpool_layers = module_surgery.count_module_instances(state.model, BlurMaxPool2d)
        num_blurconv_layers = module_surgery.count_module_instances(state.model, BlurConv2d)

        # python logger
        log.info(
            f'Applied BlurPool to model {state.model.__class__.__name__} '
            f'with replace_maxpools={self.replace_maxpools}, '
            f'replace_convs={self.replace_convs}. '
            f'Model now has {num_blurpool_layers} BlurMaxPool2d '
            f'and {num_blurconv_layers} BlurConv2D layers.',
        )

        logger.log_hyperparameters({
            'blurpool/num_blurpool_layers': num_blurpool_layers,
            'blurpool/num_blurconv_layers': num_blurconv_layers,
        })
def _log_surgery_result(model: torch.nn.Module): num_blurpool_layers = module_surgery.count_module_instances(model, BlurMaxPool2d) num_blurconv_layers = module_surgery.count_module_instances(model, BlurConv2d) if num_blurconv_layers == 0 and num_blurpool_layers == 0: warnings.warn( NoEffectWarning( 'Applying BlurPool did not change any layers. ' 'No strided Conv2d or Pool2d layers were found.', ), ) log.info( f'Applied BlurPool to model {model.__class__.__name__}. ' f'Model now has {num_blurpool_layers} BlurMaxPool2d ' f'and {num_blurconv_layers} BlurConv2D layers.', ) def _maybe_replace_strided_conv2d( module: torch.nn.Conv2d, module_index: int, blur_first: bool, min_channels: int = 16, ): already_blurpooled = hasattr(module, '_already_blurpooled') and module._already_blurpooled if np.max(module.stride) > 1 and module.in_channels >= min_channels and not already_blurpooled: return BlurConv2d.from_conv2d(module, module_index, blur_first=blur_first) return None