composer.algorithms.squeeze_excite.squeeze_excite
Functions
apply_squeeze_excite | Adds Squeeze-and-Excitation blocks (Hu et al., 2019) after Conv2d layers.
Classes
Algorithm | Base class for algorithms.
Event | Enum to represent events in the training loop.
Logger | An interface to record training data.
Optimizer | Base class for all optimizers.
SqueezeExcite | Adds Squeeze-and-Excitation blocks (Hu et al., 2019) after the Conv2d modules in a neural network.
SqueezeExcite2d | Squeeze-and-Excitation block from (Hu et al., 2019).
SqueezeExciteConv2d | Helper class used to add a SqueezeExcite2d module after a Conv2d.
State | The state of the trainer.
Attributes
Optional
Sequence
Union
annotations
log
- class composer.algorithms.squeeze_excite.squeeze_excite.SqueezeExcite(latent_channels=64, min_channels=128)
Bases: composer.core.algorithm.Algorithm
Adds Squeeze-and-Excitation blocks (Hu et al., 2019) after the Conv2d modules in a neural network. Runs on INIT. See SqueezeExcite2d for more information.
Parameters
latent_channels (float, optional) – Dimensionality of the hidden layer within the added MLP. If less than 1, interpreted as a fraction of the number of output channels in the Conv2d immediately preceding each Squeeze-and-Excitation block. Default: 64.
min_channels (int, optional) – An SE block is added after a Conv2d module conv only if min(conv.in_channels, conv.out_channels) >= min_channels. For models that reduce spatial size and increase channel count deeper in the network, this parameter can be used to only add SE blocks deeper in the network. This may be desirable because SE blocks add less overhead when their inputs have smaller spatial size. Default: 128.
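The two parameters above can be restated as two small rules: how latent_channels becomes a hidden-layer width, and which Conv2d layers receive an SE block. The plain-Python sketch below encodes the documented rules; the helper names hidden_width and gets_se_block are hypothetical, and rounding a fractional width down with int() (keeping at least one unit) is an assumption, not necessarily what Composer does.

```python
def hidden_width(latent_channels: float, conv_out_channels: int) -> int:
    """Hidden-layer width of an SE block (hypothetical helper, documented rule)."""
    if latent_channels < 1:
        # Fractional value: relative to the preceding conv's output channels.
        # Assumption: round down, but keep at least one hidden unit.
        return max(1, int(latent_channels * conv_out_channels))
    # Otherwise the value is an absolute number of hidden units.
    return int(latent_channels)


def gets_se_block(in_channels: int, out_channels: int, min_channels: int = 128) -> bool:
    """True if a Conv2d with these channel counts would be followed by an SE block."""
    return min(in_channels, out_channels) >= min_channels


# Illustrative (in_channels, out_channels) pairs: narrow early layers are skipped.
layers = [(3, 64), (64, 64), (128, 128), (256, 512), (512, 2048)]
selected = [layer for layer in layers if gets_se_block(*layer)]
print(selected)                  # [(128, 128), (256, 512), (512, 2048)]
print(hidden_width(64, 2048))    # 64
print(hidden_width(0.125, 512))  # 64
```

Using min() means a layer qualifies only when both its input and output widths reach the threshold, which is what restricts SE blocks to the wider, deeper part of a typical CNN.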
- class composer.algorithms.squeeze_excite.squeeze_excite.SqueezeExcite2d(num_features, latent_channels=0.125)
Bases: torch.nn.modules.module.Module
Squeeze-and-Excitation block from (Hu et al., 2019).
This block applies global average pooling to the input, feeds the resulting vector to a single-hidden-layer fully-connected network (MLP), and uses the output of this MLP as attention coefficients to rescale the input. This allows the network to take into account global information about each input, as opposed to only local receptive fields like in a convolutional layer.
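The squeeze, MLP, and rescale steps described above can be written out directly. This is a framework-free NumPy sketch of one forward pass on an (N, C, H, W) batch, assuming bias-free fully-connected layers with a ReLU between them and a sigmoid on the output (the usual choice in Hu et al.); the function name se_forward and the random weights are illustrative only and are not part of Composer.

```python
import numpy as np


def se_forward(x: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Squeeze-and-Excitation on an (N, C, H, W) array (illustrative sketch).

    w1: (C, hidden) weights of the first fully-connected layer.
    w2: (hidden, C) weights of the second fully-connected layer.
    """
    squeezed = x.mean(axis=(2, 3))                 # squeeze: global average pool -> (N, C)
    hidden = np.maximum(squeezed @ w1, 0.0)        # first FC layer + ReLU -> (N, hidden)
    coeffs = 1.0 / (1.0 + np.exp(-(hidden @ w2)))  # second FC layer + sigmoid -> (N, C)
    return x * coeffs[:, :, None, None]            # excite: rescale each channel of each sample


rng = np.random.default_rng(0)
n, c, height, width = 2, 16, 8, 8
hidden_units = int(0.125 * c)  # mirrors the default latent_channels=0.125 -> 2 units
x = rng.standard_normal((n, c, height, width))
w1 = rng.standard_normal((c, hidden_units))
w2 = rng.standard_normal((hidden_units, c))
y = se_forward(x, w1, w2)
print(y.shape)  # (2, 16, 8, 8)
```

Every spatial position in a channel is multiplied by the same coefficient, and that coefficient depends on the whole feature map through the pooled vector, which is how the block injects global information.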
- class composer.algorithms.squeeze_excite.squeeze_excite.SqueezeExciteConv2d(*args, latent_channels=0.125, conv=None, **kwargs)
Bases: torch.nn.modules.module.Module
Helper class used to add a SqueezeExcite2d module after a Conv2d.
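The wrapper pattern can be sketched in PyTorch as a module that runs a convolution and then feeds its output to an SE block sized to the convolution's output channels. The constructor mirrors the documented signature (wrap an existing conv, or build a new Conv2d from *args/**kwargs), but the class names SEBlockSketch and SEConv2dSketch are hypothetical, this is not Composer's source, and the rounding of fractional widths is an assumption.

```python
from typing import Optional

import torch
from torch import nn


class SEBlockSketch(nn.Module):
    """Hypothetical SE block: pool -> MLP -> sigmoid -> channel-wise rescale."""

    def __init__(self, num_features: int, latent_channels: float = 0.125):
        super().__init__()
        # Fractions are relative to num_features (assumed: round down, minimum 1).
        if latent_channels >= 1:
            hidden = int(latent_channels)
        else:
            hidden = max(1, int(latent_channels * num_features))
        self.pool_and_mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # squeeze: (N, C, H, W) -> (N, C, 1, 1)
            nn.Flatten(),             # -> (N, C)
            nn.Linear(num_features, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, num_features, bias=False),
            nn.Sigmoid(),             # attention coefficients in (0, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c = x.shape[:2]
        return x * self.pool_and_mlp(x).view(n, c, 1, 1)  # excite: rescale each channel


class SEConv2dSketch(nn.Module):
    """Hypothetical wrapper: a Conv2d followed by an SE block."""

    def __init__(self, *args, latent_channels: float = 0.125,
                 conv: Optional[nn.Conv2d] = None, **kwargs):
        super().__init__()
        # Reuse the given conv (e.g. during model surgery) or construct a fresh one.
        self.conv = conv if conv is not None else nn.Conv2d(*args, **kwargs)
        self.se = SEBlockSketch(self.conv.out_channels, latent_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.se(self.conv(x))


existing = nn.Conv2d(128, 256, kernel_size=3, padding=1)
wrapped = SEConv2dSketch(conv=existing, latent_channels=64)  # keeps the existing weights
fresh = SEConv2dSketch(128, 256, kernel_size=3, padding=1)    # builds its own Conv2d
out = wrapped(torch.randn(2, 128, 14, 14))
print(out.shape)  # torch.Size([2, 256, 14, 14])
```

Accepting an existing conv lets a surgery pass swap each Conv2d for a wrapper without reinitializing its trained weights.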
- composer.algorithms.squeeze_excite.squeeze_excite.apply_squeeze_excite(model, latent_channels=64, min_channels=128, optimizers=None)
Adds Squeeze-and-Excitation blocks (Hu et al., 2019) after Conv2d layers.
A Squeeze-and-Excitation block applies global average pooling to the input, feeds the resulting vector to a single-hidden-layer fully-connected network (MLP), and uses the output of this MLP as attention coefficients to rescale the input. This allows the network to take into account global information about each input, as opposed to only local receptive fields like in a convolutional layer.
Parameters
model (Module) – The module to which Squeeze-and-Excitation blocks are added.
latent_channels (float, optional) – Dimensionality of the hidden layer within the added MLP. If less than 1, interpreted as a fraction of the number of output channels in the Conv2d immediately preceding each Squeeze-and-Excitation block. Default: 64.
min_channels (int, optional) – An SE block is added after a Conv2d module conv only if both its input and output channel counts are at least this threshold, i.e. min(conv.in_channels, conv.out_channels) >= min_channels. Default: 128.
optimizers (Optimizer | Sequence[Optimizer], optional) – Existing optimizers bound to model.parameters(). All optimizers that have already been constructed with model.parameters() must be specified here so that they will optimize the correct parameters. If the optimizer(s) are constructed after calling this function, it is safe to omit this parameter; these optimizers will see the correct model parameters.
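The optimizers parameter exists because a PyTorch optimizer keeps references to the parameter tensors it was given at construction time. The sketch below uses plain PyTorch, with add_module on an nn.Sequential as a stand-in for model surgery (it is not Composer's surgery code), to show that parameters added afterwards are invisible to an existing optimizer, and that constructing the optimizer after the change avoids the problem.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8))
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # bound to the current parameters

# Stand-in for model surgery: attach a new submodule with fresh parameters.
model.add_module("1", nn.Linear(8, 8))

tracked = {id(p) for group in opt.param_groups for p in group["params"]}
untracked = [name for name, p in model.named_parameters() if id(p) not in tracked]
print(untracked)  # ['1.weight', '1.bias']

# Building the optimizer after the modification covers every parameter.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
tracked = {id(p) for group in opt.param_groups for p in group["params"]}
print(all(id(p) in tracked for p in model.parameters()))  # True
```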
Returns
The modified model.
Example
import composer.functional as cf
from torchvision import models

model = models.resnet50()
cf.apply_squeeze_excite(model, latent_channels=64, min_channels=128)