compute_ema

composer.functional.compute_ema(model, ema_model, smoothing=0.99)

Updates the weights of ema_model toward the weights of model using an exponentially weighted moving average. The weights are updated according to

\[W_{\text{ema\_model}}^{(t+1)} = \text{smoothing}\times W_{\text{ema\_model}}^{(t)} + (1-\text{smoothing})\times W_{\text{model}}^{(t)} \]

The update to ema_model happens in place.
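A minimal sketch of the update rule in plain PyTorch follows, to show what "in place" means for each weight tensor. The helper name `ema_update_sketch` is hypothetical and this is not Composer's implementation: it assumes both modules share the same architecture (so `parameters()` yields matching tensors in the same order) and it averages parameters only, leaving buffers such as BatchNorm running statistics untouched.

```python
import torch
from torch import nn


@torch.no_grad()
def ema_update_sketch(model: nn.Module, ema_model: nn.Module, smoothing: float = 0.99) -> None:
    """Hypothetical helper: W_ema <- smoothing * W_ema + (1 - smoothing) * W_model, in place.

    Assumes model and ema_model have identical architectures; only
    parameters are averaged (buffers are not handled by this sketch).
    """
    if not 0.0 < smoothing < 1.0:
        raise ValueError(f"smoothing must be in (0, 1), got {smoothing}")
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        # Scale the old average, then add the newly weighted model weights.
        ema_p.mul_(smoothing).add_(p, alpha=1.0 - smoothing)


model = nn.Linear(2, 1)
ema_model = nn.Linear(2, 1)
with torch.no_grad():
    model.weight.fill_(1.0)
    ema_model.weight.fill_(0.0)

ema_update_sketch(model, ema_model, smoothing=0.9)
# Each weight entry is now 0.9 * 0.0 + 0.1 * 1.0 = 0.1
print(ema_model.weight)
```

Because the tensors are modified with `mul_` and `add_` under `torch.no_grad()`, ema_model keeps the same parameter objects and no autograd history is recorded.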

The half-life, measured in number of updates, of each term's weight in the average is given by

\[t_{1/2} = -\frac{\log(2)}{\log(\text{smoothing})} \]
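Unrolling the update shows where this comes from: the model weights from k updates ago enter the average with coefficient (1 - smoothing) * smoothing^k, so a term's weight shrinks by a factor of smoothing per update and halves once smoothing^t = 1/2. The sketch below evaluates the formula for the default smoothing of 0.99; the function name `half_life` is a hypothetical helper, not part of Composer.

```python
import math


def half_life(smoothing: float) -> float:
    """Number of updates after which a term's weight in the average has halved."""
    return -math.log(2) / math.log(smoothing)


t_half = half_life(0.99)
print(f"{t_half:.2f}")  # 68.97, i.e. roughly 69 updates

# Sanity check: after t_half further updates a weight is scaled by about 0.5.
print(0.99 ** t_half)
```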

Conversely, to obtain a target half-life \(t_{1/2}\), set smoothing to

\[\text{smoothing} = \exp\left[-\frac{\log(2)}{t_{1/2}}\right] \]

Parameters
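The inverse formula can be sketched the same way, for example to pick smoothing so that an update's influence halves after 100 updates. The helper `smoothing_for_half_life` is hypothetical and only evaluates the equation above.

```python
import math


def smoothing_for_half_life(t_half: float) -> float:
    """Smoothing coefficient whose half-life is t_half updates (hypothetical helper)."""
    if t_half <= 0:
        raise ValueError("t_half must be positive")
    return math.exp(-math.log(2) / t_half)


s = smoothing_for_half_life(100)
print(f"{s:.5f}")  # 0.99309
```

The result always lies in (0, 1) for a positive half-life, so it is a valid value for the smoothing argument.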
  • model (Module) – the model containing the latest weights, used to update the moving-average weights.

  • ema_model (Module, EMAParameters) – the model containing the moving-average weights to be updated.

  • smoothing (float, optional) – the coefficient representing the degree to which older observations are kept. Must be in the interval \((0, 1)\). Default: 0.99.

Example

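import composer.functional as cf
from torchvision import models

# The model being trained and a second model that holds the moving average
model = models.resnet50()
ema_model = models.resnet50()

# Move ema_model's weights 10% of the way toward model's weights, in place
cf.compute_ema(model, ema_model, smoothing=0.9)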