class composer.algorithms.EMA(half_life, update_interval=None, train_with_ema_weights=False)

Maintains a shadow model with weights that follow the exponential moving average of the trained model weights.

Weights are updated according to

\[W_{\mathrm{ema\_model}}^{(t+1)} = \mathrm{smoothing} \times W_{\mathrm{ema\_model}}^{(t)} + (1 - \mathrm{smoothing}) \times W_{\mathrm{model}}^{(t)}\]

where smoothing is determined from half_life (written \(t_{1/2}\)) according to

\[\mathrm{smoothing} = \exp\left[-\frac{\log(2)}{t_{1/2}}\right]\]
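As a rough illustration of the two formulas above, the following plain-Python sketch computes the smoothing factor from a half life and applies the update to a single scalar weight. The names `compute_smoothing` and `ema_update` are hypothetical helpers written for this example and are not part of Composer's API; the half life is expressed here as a plain number of updates rather than a time string.

```python
import math


def compute_smoothing(half_life: float) -> float:
    """Smoothing factor for which an old term's weight halves every `half_life` updates."""
    if half_life == 0:
        # A half life of 0 means no averaging: the average simply tracks the model.
        return 0.0
    return math.exp(-math.log(2) / half_life)


def ema_update(w_ema: float, w_model: float, smoothing: float) -> float:
    """One EMA step: blend the previous average with the current model weight."""
    return smoothing * w_ema + (1.0 - smoothing) * w_model


smoothing = compute_smoothing(half_life=50)

# Start the average at 1.0 and feed it a constant model weight of 0.0.
# After exactly `half_life` updates, the initial value's contribution has halved.
w_ema = 1.0
for _ in range(50):
    w_ema = ema_update(w_ema, w_model=0.0, smoothing=smoothing)

print(round(w_ema, 6))  # 0.5
```

In a real model the same update would be applied elementwise to every parameter tensor; a scalar is used here only to keep the arithmetic easy to check.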

Model evaluation is done with the moving average weights, which can result in better generalization. Because of the shadow models, EMA triples the model's memory consumption. Note that this does not mean that the total memory required triples, since stored activations and the optimizer state are not duplicated. EMA also uses a small amount of extra compute to update the moving average weights.

See the Method Card for more details.

  • half_life (str) – The time string specifying the half life for terms in the average. A longer half life means old information is remembered longer; a shorter half life means old information is discarded sooner. A half life of 0 means no averaging is done; an infinite half life means no update is done. Currently only units of epoch ('ep') and batch ('ba') are supported. Value must be an integer.

  • update_interval (str, optional) – The time string specifying the period at which updates are done. For example, update_interval='1ep' means updates are done every epoch, while update_interval='10ba' means updates are done once every ten batches. Units must match the units used to specify half_life. If not specified, update_interval defaults to 1 in the units of half_life. Value must be an integer. Default: None.

  • train_with_ema_weights (bool, optional) – An experimental feature that uses the EMA weights as the training weights. In most cases this should be left as False. Default: False.
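The constraints on half_life and update_interval described above can be sketched in plain Python as follows. This is not Composer's implementation: `parse_time_string` and `per_update_smoothing` are hypothetical helpers, and the way the update interval enters the smoothing factor is an assumption (the factor is scaled so that the half life stays the same in batches or epochs regardless of how often updates happen).

```python
import math
import re
from typing import Optional, Tuple


def parse_time_string(value: str) -> Tuple[int, str]:
    """Split a string such as '50ba' into (50, 'ba'); only 'ep' and 'ba' are accepted."""
    match = re.fullmatch(r"(\d+)(ep|ba)", value)
    if match is None:
        raise ValueError(f"expected an integer followed by 'ep' or 'ba', got {value!r}")
    return int(match.group(1)), match.group(2)


def per_update_smoothing(half_life: str, update_interval: Optional[str] = None) -> float:
    """Smoothing factor applied at each update, under the assumption stated below."""
    hl_value, hl_unit = parse_time_string(half_life)
    if update_interval is None:
        # Default: one update per unit of half_life.
        ui_value, ui_unit = 1, hl_unit
    else:
        ui_value, ui_unit = parse_time_string(update_interval)
    if ui_unit != hl_unit:
        raise ValueError("update_interval must use the same units as half_life")
    if hl_value == 0:
        # A half life of 0 means no averaging is done.
        return 0.0
    # ASSUMPTION: each update covers `ui_value` time units, so the exponent is
    # scaled by the interval to keep the half life fixed in those units.
    return math.exp(-math.log(2) * ui_value / hl_value)


print(per_update_smoothing("50ba"))          # one update per batch
print(per_update_smoothing("50ba", "10ba"))  # one update every ten batches
```

Under this assumption, five updates at update_interval='10ba' decay an old term by the same factor of one half as fifty updates at update_interval='1ba'.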


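from composer import Trainer
from composer.algorithms import EMA

# `model` and `train_dataloader` are assumed to be defined elsewhere.
algorithm = EMA(half_life='50ba', update_interval='1ba')
trainer = Trainer(model=model, train_dataloader=train_dataloader, max_duration='1ep', algorithms=[algorithm])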

Copies the EMA model's parameters and buffers to the input model and returns it.


model (Module) – The model to convert into the EMA model.


torch.nn.Module – The input model with parameters and buffers replaced with the averaged parameters and buffers.