# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import logging
from typing import Optional, Sequence, Union

import torch
from torch.optim import Optimizer

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import module_surgery

log = logging.getLogger(__name__)


def apply_squeeze_excite(
model: torch.nn.Module,
latent_channels: float = 64,
min_channels: int = 128,
optimizers: Optional[Union[Optimizer, Sequence[Optimizer]]] = None,
) -> None:
"""Adds Squeeze-and-Excitation blocks (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_) after
:class:`torch.nn.Conv2d` layers.
A Squeeze-and-Excitation block applies global average pooling to the input,
feeds the resulting vector to a single-hidden-layer fully-connected
network (MLP), and uses the outputs of this MLP as attention coefficients
to rescale the input. This allows the network to take into account global
information about each input, as opposed to only local receptive fields
like in a convolutional layer.
Args:
model (torch.nn.Module): The module to apply squeeze excite replacement to.
latent_channels (float, optional): Dimensionality of the hidden layer within the added
MLP. If less than 1, interpreted as a fraction of the number of
output channels in the :class:`torch.nn.Conv2d` immediately
preceding each Squeeze-and-Excitation block. Default: ``64``.
min_channels (int, optional): An SE block is added after a :class:`torch.nn.Conv2d`
module ``conv`` only if one of the layer's input or output channels is greater than
this threshold. Default: ``128``.
optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional):
Existing optimizer(s) bound to ``model.parameters()``. All optimizers that have already been
constructed with ``model.parameters()`` must be specified here so that
they will optimize the correct parameters.
If the optimizer(s) are constructed *after* calling this function,
then it is safe to omit this parameter. These optimizers will see the correct
model parameters.
Example:
.. testcode::
import composer.functional as cf
from torchvision import models
model = models.resnet50()
cf.apply_stochastic_depth(
model,
target_layer_name='ResNetBottleneck'
)
"""
def convert_module(module: torch.nn.Module, module_index: int):
assert isinstance(module, torch.nn.Conv2d), 'should only be called with conv2d'
already_squeeze_excited = hasattr(module, '_already_squeeze_excited') and module._already_squeeze_excited
if min(module.in_channels, module.out_channels) >= min_channels and not already_squeeze_excited:
return SqueezeExciteConv2d.from_conv2d(module, module_index, latent_channels=latent_channels)

    module_surgery.replace_module_classes(model, optimizers=optimizers, policies={torch.nn.Conv2d: convert_module})


class SqueezeExcite2d(torch.nn.Module):
"""Squeeze-and-Excitation block from (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_)
This block applies global average pooling to the input, feeds the resulting
vector to a single-hidden-layer fully-connected network (MLP), and uses the
outputs of this MLP as attention coefficients to rescale the input. This
allows the network to take into account global information about each input,
as opposed to only local receptive fields like in a convolutional layer.
Args:
num_features (int): Number of features or channels in the input.
latent_channels (float, optional): Dimensionality of the hidden layer within the added
MLP. If less than 1, interpreted as a fraction of ``num_features``. Default: ``0.125``.
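
    Example:
        A minimal forward-pass sketch (assumes ``SqueezeExcite2d`` is importable from
        ``composer.algorithms``):

        .. testcode::

            import torch

            from composer.algorithms import SqueezeExcite2d

            se = SqueezeExcite2d(num_features=64, latent_channels=0.125)
            x = torch.randn(8, 64, 32, 32)
            y = se(x)
            assert y.shape == x.shape  # channels are rescaled; the shape is unchanged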
"""

    def __init__(self, num_features: int, latent_channels: float = 0.125):
super().__init__()
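        # ``latent_channels`` >= 1 is an absolute hidden width; a value < 1 is a fraction of ``num_features``.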
self.latent_channels = int(latent_channels if latent_channels >= 1 else latent_channels * num_features)
flattened_dims = num_features
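        # Squeeze (global average pool) followed by excite (bottleneck MLP producing per-channel gates in (0, 1)).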
self.pool_and_mlp = torch.nn.Sequential(
torch.nn.AdaptiveAvgPool2d(1),
torch.nn.Flatten(),
torch.nn.Linear(flattened_dims, self.latent_channels, bias=False),
torch.nn.ReLU(),
torch.nn.Linear(self.latent_channels, num_features, bias=False),
torch.nn.Sigmoid(),
)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
n, c, _, _ = input.shape
attention_coeffs = self.pool_and_mlp(input)
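        # Broadcast the per-channel attention coefficients over the spatial dimensions.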
return input * attention_coeffs.reshape(n, c, 1, 1)


class SqueezeExciteConv2d(torch.nn.Module):
"""Helper class used to add a :class:`.SqueezeExcite2d` module after a :class:`torch.nn.Conv2d`."""

    def __init__(self, *args, latent_channels: float = 0.125, conv: Optional[torch.nn.Conv2d] = None, **kwargs):
super().__init__()
self.conv = torch.nn.Conv2d(*args, **kwargs) if conv is None else conv
self.conv._already_squeeze_excited = True # Mark to avoid rewrapping on duplicate calls # pyright: ignore[reportGeneralTypeIssues]
self.se = SqueezeExcite2d(num_features=self.conv.out_channels, latent_channels=latent_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.se(self.conv(input))

    @staticmethod
def from_conv2d(module: torch.nn.Conv2d, module_index: int, latent_channels: float):
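        # ``module_index`` is unused here; the parameter exists so the signature matches the
        # module_surgery replacement-policy interface.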
return SqueezeExciteConv2d(conv=module, latent_channels=latent_channels)


class SqueezeExcite(Algorithm):
"""Adds Squeeze-and-Excitation blocks (`Hu et al, 2019 <https://arxiv.org/abs/1709.01507>`_) after the
:class:`torch.nn.Conv2d` modules in a neural network.
Runs on :attr:`.Event.INIT`. See :class:`SqueezeExcite2d` for more information.
Args:
latent_channels (float, optional): Dimensionality of the hidden layer within the added
MLP. If less than 1, interpreted as a fraction of the number of
output channels in the :class:`torch.nn.Conv2d` immediately
preceding each Squeeze-and-Excitation block. Default: ``64``.
min_channels (int, optional): An SE block is added after a :class:`torch.nn.Conv2d`
module ``conv`` only if ``min(conv.in_channels, conv.out_channels) >= min_channels``.
For models that reduce spatial size and increase channel count
deeper in the network, this parameter can be used to only
add SE blocks deeper in the network. This may be desirable
because SE blocks add less overhead when their inputs have
smaller spatial size. Default: ``128``.
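
    Example:
        A usage sketch with the Composer :class:`.Trainer` (assumes ``model`` and
        ``train_dataloader`` are defined elsewhere):

        .. testcode::

            from composer.algorithms import SqueezeExcite
            from composer.trainer import Trainer

            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                max_duration='1ep',
                algorithms=[SqueezeExcite(latent_channels=64, min_channels=128)],
            )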
"""

    def __init__(
self,
latent_channels: float = 64,
min_channels: int = 128,
):
self.latent_channels = latent_channels
self.min_channels = min_channels

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(latent_channels={self.latent_channels},min_channels={self.min_channels})'

    @staticmethod
    def required_on_load() -> bool:
        return True

    def match(self, event: Event, state: State) -> bool:
        return event == Event.INIT

    def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
apply_squeeze_excite(
state.model,
optimizers=state.optimizers,
latent_channels=self.latent_channels,
min_channels=self.min_channels,
)
layer_count = module_surgery.count_module_instances(state.model, SqueezeExciteConv2d)
log.info(
f'Applied SqueezeExcite to model {state.model.__class__.__name__} '
f'with latent_channels={self.latent_channels}, '
f'min_channels={self.min_channels}. '
f'Model now has {layer_count} SqueezeExcite layers.',
)
logger.log_hyperparameters({
'squeeze_excite/num_squeeze_excite_layers': layer_count,
})