compute_ema

composer.functional.compute_ema(model, ema_model, smoothing=0.99)

Updates the weights of ema_model to move them closer to the weights of model using an exponentially weighted moving average. Weights are updated according to

W_{\text{ema\_model}}^{(t+1)} = \text{smoothing} \times W_{\text{ema\_model}}^{(t)} + (1 - \text{smoothing}) \times W_{\text{model}}^{(t)}

The update to ema_model happens in place.
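A minimal sketch of the update rule applied in place, assuming the weights are stored as a plain Python list of floats. The helper name ema_update is hypothetical; this is not Composer's implementation, which operates on the parameters of a torch Module.

```python
def ema_update(ema_weights: list[float], weights: list[float], smoothing: float = 0.99) -> None:
    """Hypothetical helper: w_ema <- smoothing * w_ema + (1 - smoothing) * w, in place."""
    if not 0.0 < smoothing < 1.0:
        raise ValueError("smoothing must be in the interval (0, 1)")
    if len(ema_weights) != len(weights):
        raise ValueError("weight lists must have the same length")
    # Overwrite each moving-average weight rather than building a new list,
    # mirroring the in-place update described above.
    for i, w in enumerate(weights):
        ema_weights[i] = smoothing * ema_weights[i] + (1.0 - smoothing) * w


ema = [0.0, 1.0]
ema_update(ema, [1.0, 1.0], smoothing=0.9)
print(ema)  # first entry moves 10% of the way toward 1.0; second is unchanged
```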

The half-life of a term's weight in the average, measured in number of updates, is given by

t1/2=โˆ’logโก(2)logโก(smoothing)t_{1/2} = -\frac{\log(2)}{\log(smoothing)}

Therefore, to obtain a target half-life, set smoothing according to

smoothing=expโก[โˆ’logโก(2)t1/2]smoothing = \exp\left[-\frac{\log(2)}{t_{1/2}}\right]
Parameters
  • model (Module) – the model whose latest weights are used to update the moving-average weights.

  • ema_model (Module, EMAParameters) – the model containing the moving-average weights to be updated.

  • smoothing (float, optional) – the coefficient representing the degree to which older observations are kept. Must be in the interval (0, 1). Default: 0.99.

Example

import composer.functional as cf
from torchvision import models
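# The model being trained, and a second model that holds the moving-average weights
model = models.resnet50()
ema_model = models.resnet50()
# Move each of ema_model's weights 10% of the way toward model's weights (in place)
cf.compute_ema(model, ema_model, smoothing=0.9)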
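The update is typically applied repeatedly, for example once per training step (an assumed usage pattern; this page does not prescribe one). The sketch below uses a single float in place of a weight tensor and a hypothetical helper, ema_after_n_updates. It shows that when the model weight stays fixed, each call shrinks the gap between the average and the target by a factor of smoothing, so after n calls the average equals target + smoothing^n × (initial − target).

```python
def ema_after_n_updates(ema_init: float, target: float, smoothing: float, n: int) -> float:
    """Hypothetical helper: apply the EMA update n times toward a fixed target weight."""
    ema = ema_init
    for _ in range(n):
        ema = smoothing * ema + (1.0 - smoothing) * target
    return ema


iterated = ema_after_n_updates(0.0, 1.0, smoothing=0.9, n=10)
# Closed form for a constant target: target + smoothing**n * (ema_init - target)
closed_form = 1.0 + 0.9 ** 10 * (0.0 - 1.0)
print(round(iterated, 4), round(closed_form, 4))  # 0.6513 0.6513
```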