Modify model architectures.

Algorithms such as BlurPool replace model parameters in place. This module contains helper functions for replacing parameters in Module and Optimizer instances.


Surgery replacement function protocol.

The function is provided with a torch.nn.Module and a counter of the number of instances of that module type seen so far. The function should return a replacement torch.nn.Module if the module should be replaced, or None otherwise.

  • module (Module) – Source module.

  • module_index (int) – The i-th instance of the module class encountered so far.


Optional[torch.nn.Module] – The replacement module, or None to indicate no modification.


(Module, int) -> Optional[Module]
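For illustration, a function conforming to this protocol might look like the sketch below. The name replace_first_linear is hypothetical, not part of the library; the example assumes only torch.

```python
from typing import Optional

import torch.nn as nn


def replace_first_linear(module: nn.Module, module_index: int) -> Optional[nn.Module]:
    """Matches the (Module, int) -> Optional[Module] protocol.

    Replaces only the first nn.Linear encountered with a wider layer;
    returns None for everything else, signaling "no modification".
    """
    if isinstance(module, nn.Linear) and module_index == 0:
        return nn.Linear(module.in_features, module.out_features * 2)
    return None
```

Because the function receives module_index, it can limit replacement to the first instance while leaving later instances untouched.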



Counts the number of instances of module_class in module, recursively.


Modify model in-place by recursively applying replacement policies.


Fully replaces an optimizer's parameters.


Remove old_params from the optimizers and insert new_params.


composer.utils.module_surgery.count_module_instances(module, module_class)[source]#

Counts the number of instances of module_class in module, recursively.


>>> from torch import nn
>>> module = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 64), nn.ReLU())
>>> count_module_instances(module, nn.Linear)
2
>>> count_module_instances(module, (nn.Linear, nn.ReLU))
3
  • module (Module) – The source module.

  • module_class (Type[Module] | Tuple[Type[Module], ...]) – The module type (or tuple of module types) to count.


int – The number of instances of module_class in module.
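The recursive count can be sketched in a few lines. This is an illustration of the semantics, not the library's implementation: nn.Module.modules() yields the module itself plus every descendant, so a flat isinstance scan covers arbitrary nesting.

```python
import torch.nn as nn


def count_instances(module: nn.Module, module_class) -> int:
    # modules() walks the module and all nested submodules, so a
    # single pass counts instances at every depth of the tree.
    return sum(isinstance(m, module_class) for m in module.modules())
```

Since isinstance accepts a tuple of types, the same sketch handles the tuple form of module_class.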

composer.utils.module_surgery.replace_module_classes(module, policies, optimizers=None, recurse_on_replacements=False, indices=None)[source]#

Modify model in-place by recursively applying replacement policies.


The following example replaces all convolution layers with linear layers, and replaces linear layers that have 16 input features. Recursion occurs on replacements.

  • The first replacement policy replaces the nn.Conv2d(1, 32, 3, 1) layer with a nn.Linear(16, 32) layer.

  • The second replacement policy recurses on this replaced layer. Because in_features == 16, this policy replaces the layer with a nn.Linear(32, 64).

  • This policy is invoked again on this new layer. However, since in_features == 32, no replacement occurs and this policy returns None.

  • Once no policy matches a layer, or every matching policy returns None, surgery is finished.

  • All replacements, including intermediate replacements, are returned.

>>> from torch import nn
>>> module = nn.Sequential(
...     nn.Conv2d(1, 32, 3, 1),
...     nn.ReLU(),
...     nn.MaxPool2d(2),
...     nn.Flatten(),
...     nn.Linear(5408, 128),
...     nn.ReLU(),
...     nn.LogSoftmax(dim=1),
... )
>>> policies = {
...     nn.Conv2d: lambda x, idx: nn.Linear(16, 32),
...     nn.Linear: lambda x, idx: nn.Linear(32, 64) if x.in_features == 16 else None
... }
>>> replace_module_classes(module, policies, recurse_on_replacements=True)
{Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1)): Linear(in_features=16, out_features=32, bias=True), Linear(in_features=16, out_features=32, bias=True): Linear(in_features=32, out_features=64, bias=True)}


When a module is replaced, tensor values within the module are not copied to the new module, even when the shapes are identical. For example, if model weights are initialized prior to calling this function, the initialized weights will not be preserved in any replacements.

  • module (Module) – Model to modify.

  • policies (Mapping[Type[Module], ReplacementFunction]) – Mapping of source module class to a replacement function. Matching policies are applied in the iteration order of the dictionary, so if order is important, an OrderedDict should be used. The replacement function may return either another Module or None. If the latter, the source module is not replaced.

  • recurse_on_replacements (bool) – If True, policies will be applied to any module returned by another policy. For example, if one policy replaces a Conv2d with a module containing another Conv2d, the replacement function will be invoked with this new child Conv2d instance. If the replacement policies are not conditioned on module properties that change during replacement, infinite recursion is possible.

  • indices (Dict[Any, int], optional) –

    A dictionary mapping module types to the number of times they've occurred so far in the recursive traversal of module and its child modules. The value is provided to replacement functions, so they may switch behaviors depending on the number of replacements that occurred for a given module type.


    These indices may not correspond to the order in which modules are called in the forward pass.

  • optimizers (Optimizer | Sequence[Optimizer], optional) – One or more Optimizer objects. If provided, this function will attempt to remove parameters in replaced modules from these optimizers, and add parameters from the newly-created modules. See update_params_in_optimizer() for more information.


Dict[torch.nn.Module, torch.nn.Module] – A dictionary of {original_module: replacement_module} reflecting the replacements applied to module and its children.
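The optimizers argument matters because replacing a module does not update optimizers that already hold its parameters. A minimal torch-only sketch of the stale-parameter problem (no composer calls here; the layer is swapped by hand):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Swap the layer directly, without telling the optimizer.
old_weight = model[0].weight
model[0] = nn.Linear(4, 4)

tracked = {id(p) for group in opt.param_groups for p in group['params']}
# The optimizer still steps the orphaned old parameters ...
assert id(old_weight) in tracked
# ... while the new module's parameters receive no updates.
assert id(model[0].weight) not in tracked
```

Passing optimizers to replace_module_classes() is meant to avoid exactly this: the replaced modules' parameters are removed and the new modules' parameters inserted.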

composer.utils.module_surgery.replace_params_in_optimizer(old_params, new_params, optimizers)[source]#

Fully replaces an optimizer's parameters.

This differs from update_params_in_optimizer() in that this method is capable of replacing parameters spanning multiple param groups. To accomplish this, this function assumes that parameters in new_params should inherit the param group of the corresponding parameter from old_params. Thus, this function also assumes that old_params and new_params have the same length.

  • old_params (Iterator[Parameter]) – Current parameters of the optimizer.

  • new_params (Iterator[Parameter]) – New parameters of the optimizer, given in the same order as old_params. Must be the same length as old_params.

  • optimizers (Optimizer | Sequence[Optimizer]) – One or more torch.optim.Optimizer objects.

  • NotImplementedError – If optimizers contains more than one optimizer.

  • RuntimeError – If old_params and new_params have different lengths, or if a param from old_params cannot be found.
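The group-inheritance idea can be sketched with plain torch. This is an illustration of the semantics, not the library's implementation, and (as noted below for update_params_in_optimizer()) mutating param_groups in place is not officially supported by PyTorch:

```python
import torch
import torch.nn as nn

old_layer, new_layer = nn.Linear(2, 2), nn.Linear(2, 2)
opt = torch.optim.SGD(old_layer.parameters(), lr=0.1)

old_params = list(old_layer.parameters())
new_params = list(new_layer.parameters())
assert len(old_params) == len(new_params)  # required 1:1 correspondence

# Each new parameter takes the slot (and therefore the param group) of
# its old counterpart; identity, not value equality, keys the mapping.
id_to_new = {id(old): new for old, new in zip(old_params, new_params)}
for group in opt.param_groups:
    group['params'] = [id_to_new.get(id(p), p) for p in group['params']]
```

Because the mapping is positional, the 1:1 correspondence lets replacements span multiple param groups: each new parameter lands wherever its old counterpart lived.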

composer.utils.module_surgery.update_params_in_optimizer(old_params, new_params, optimizers)[source]#

Remove old_params from the optimizers and insert new_params.

Newly added parameters will be added to the same param_group as the removed parameters. A RuntimeError will be raised if old_params is split across multiple parameter groups.

This function differs from replace_params_in_optimizer() in that len(old_params) need not equal len(new_params). However, this function does not support replacing parameters across multiple optimizer groups.


Dynamically removing parameters from an Optimizer and adding parameters to an existing param_group are not officially supported, so this function may fail when PyTorch is updated. The recommended practice is to instead recreate the optimizer when the parameter set changes. To simply add new parameters without replacing existing ones, use add_param_group().

  • old_params (Iterable[Parameter]) – Parameters in this iterable should be removed if they are not present in new_params.

  • new_params (Iterable[Parameter]) – Parameters in this iterable should be added if they are not present in old_params.

  • optimizers (Optimizer | Sequence[Optimizer]) – One or more Optimizer objects.

  • NotImplementedError – If optimizers contains more than one optimizer.

  • RuntimeError – If not all removed parameters are found in the same parameter group, or if any of them are not found at all.
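The remove-then-insert behavior within a single param group can be sketched with plain torch; unlike the 1:1 replacement above, the old and new parameter counts need not match. Again, this mutates param_groups directly, which PyTorch does not officially support, so treat it as an illustration only:

```python
import torch
import torch.nn as nn

old_layer = nn.Linear(3, 3)
opt = torch.optim.SGD(old_layer.parameters(), lr=0.1)

# The replacement layer may carry a different number of parameters.
new_layer = nn.Linear(3, 6)

group = opt.param_groups[0]  # sketch assumes a single param group
old_ids = {id(p) for p in old_layer.parameters()}
group['params'] = [p for p in group['params'] if id(p) not in old_ids]
group['params'].extend(new_layer.parameters())
```

The new parameters inherit the group's hyperparameters (lr, momentum, etc.) because they join the same param_group the removed parameters came from.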