EMA#

class composer.algorithms.EMA(half_life='1000ba', smoothing=None, ema_start='0.0dur', update_interval=None)[source]#

Maintains a set of weights that follow the exponential moving average of the training model weights.

Weights are updated according to

\[W_{\text{ema\_model}}^{(t+1)} = \text{smoothing}\times W_{\text{ema\_model}}^{(t)}+(1-\text{smoothing})\times W_{\text{model}}^{(t)} \]

where smoothing is determined from half_life (with \(t_{1/2}\) the half life measured in units of update_interval) according to

\[\text{smoothing} = \exp\left[-\frac{\log(2)}{t_{1/2}}\right] \]
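The relationship between the two formulas can be checked numerically. The helpers below are a toy sketch (not Composer's implementation, and the function names are hypothetical): one computes the smoothing coefficient for a half life measured in units of update_interval, the other applies the update rule to a scalar weight.

```python
import math


def smoothing_from_half_life(half_life: float) -> float:
    """Smoothing coefficient for a half life given in units of update_interval.

    After half_life updates, a term's contribution decays by a factor of 2.
    """
    return math.exp(-math.log(2) / half_life)


def ema_update(w_ema: float, w_model: float, smoothing: float) -> float:
    """One EMA step: w_ema <- smoothing * w_ema + (1 - smoothing) * w_model."""
    return smoothing * w_ema + (1.0 - smoothing) * w_model


# Corresponds to half_life='1000ba' with update_interval='1ba':
smoothing = smoothing_from_half_life(1000)
```

With this coefficient, `smoothing ** 1000 == 0.5`, i.e. a weight observed 1000 updates ago contributes half as much as one observed now.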

Model evaluation is done with the moving average weights, which can result in better generalization. Because a second copy of the weights is maintained, EMA can roughly double the model's memory consumption. Note that this does not mean the total memory required doubles, since stored activations and the optimizer state are not duplicated. EMA also uses a small amount of extra compute to update the moving average weights.

See the Method Card for more details.

Parameters
  • half_life (str, optional) – The time string specifying the half life for terms in the average. A longer half life means old information is remembered longer; a shorter half life means old information is discarded sooner. A half life of 0 means no averaging is done; an infinite half life means the average is never updated. Currently only units of epoch ('ep') and batch ('ba') are supported. Time must be an integer value in the units specified. Cannot be used if smoothing is also specified. Default: "1000ba".

  • smoothing (float, optional) – The coefficient representing the degree to which older observations are kept. Must be in the interval \((0, 1)\). Cannot be used if half_life is also specified. This value will not be adjusted if update_interval is changed. Default: None.

  • ema_start (str, optional) – The time string denoting the amount of training completed before EMA begins. Currently only units of duration ('dur'), batch ('ba') and epoch ('ep') are supported. Default: '0.0dur'.

  • update_interval (str, optional) – The time string specifying the period at which updates are done. For example, update_interval='1ep' means updates are done every epoch, while update_interval='10ba' means updates are done once every ten batches. Units must match the units used to specify half_life if not using smoothing. If not specified, update_interval will default to 1 in the units of half_life, or "1ba" if smoothing is specified. Time must be an integer value in the units specified. Default: None.
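The defaulting rule for update_interval can be sketched as a small helper. `resolve_update_interval` is a hypothetical name for illustration, not part of Composer's API; it only mirrors the behavior documented above for well-formed inputs.

```python
def resolve_update_interval(half_life=None, smoothing=None, update_interval=None):
    """Mirror the documented defaults: an explicit update_interval wins;
    otherwise 1 in the units of half_life; otherwise '1ba' when smoothing
    is specified instead of half_life."""
    if update_interval is not None:
        return update_interval
    if smoothing is not None:
        return '1ba'
    # half_life is a time string such as '1000ba' or '10ep'; reuse its unit.
    unit = 'ep' if half_life.endswith('ep') else 'ba'
    return '1' + unit
```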

Example

from composer import Trainer
from composer.algorithms import EMA
algorithm = EMA(half_life='1000ba', update_interval='1ba')
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    algorithms=[algorithm],
    optimizers=[optimizer]
)
ensure_compatible_state_dict(state)[source]#

Ensure state dicts created prior to Composer 0.13.0 are compatible with later versions.

get_ema_model(model)[source]#

Replaces the parameters of the supplied model with the EMA parameters if they are not already active.

Parameters

model (Module) – The model to replace the parameters of.

Returns

torch.nn.Module – The model with the EMA parameters.

get_training_model(model)[source]#

Replaces the parameters of the supplied model with the training parameters if they are not already active.

Parameters

model (Module) – The model to replace the parameters of.

Returns

torch.nn.Module – The model with the training parameters.
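The swap contract shared by get_ema_model and get_training_model can be illustrated with a toy stand-in. The class below operates on a plain dict of parameters rather than a torch module and is only a sketch of the documented behavior (each call replaces parameters in place, and is a no-op if the requested set is already active); it is not Composer's implementation.

```python
class ToyEMASwapper:
    """Toy illustration of the EMA/training parameter-swapping contract."""

    def __init__(self, training_params):
        # The EMA copy starts equal to the training weights.
        self.ema_params = dict(training_params)
        self.stashed_training = dict(training_params)
        self.ema_active = False

    def get_ema_model(self, model_params):
        # Replace the model's parameters with the EMA parameters
        # if they are not already active.
        if not self.ema_active:
            self.stashed_training = dict(model_params)
            model_params.clear()
            model_params.update(self.ema_params)
            self.ema_active = True
        return model_params

    def get_training_model(self, model_params):
        # Restore the training parameters if the EMA parameters are active.
        if self.ema_active:
            model_params.clear()
            model_params.update(self.stashed_training)
            self.ema_active = False
        return model_params
```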