compute_ema
- composer.functional.compute_ema(model, ema_model, smoothing=0.99)
Updates the weights of ema_model to be closer to the weights of model according to an exponentially weighted moving average. Weights are updated according to

\[W_{\text{ema\_model}}^{(t+1)} = \text{smoothing} \times W_{\text{ema\_model}}^{(t)} + (1 - \text{smoothing}) \times W_{\text{model}}^{(t)}\]

The update to ema_model happens in place.

The half-life of the weights for terms in the average, that is, the number of updates after which the contribution of a given set of model weights to the average has decayed by half, is given by

\[t_{1/2} = -\frac{\log(2)}{\log(\text{smoothing})}\]

Therefore, to obtain a target half-life, set smoothing according to

\[\text{smoothing} = \exp\left[-\frac{\log(2)}{t_{1/2}}\right]\]

- Parameters
model (Module) – the model containing the latest weights to use to update the moving average weights.
ema_model (Module, EMAParameters) – the model containing the moving average weights to be updated.
smoothing (float, optional) – the coefficient representing the degree to which older observations are kept. Must be in the interval \((0, 1)\). Default: 0.99.
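The two half-life formulas above can be computed directly when choosing smoothing. The following is a minimal standard-library sketch; the helper names smoothing_from_half_life and half_life_from_smoothing are hypothetical and are not part of composer.functional. Half-lives are measured in number of updates (calls to compute_ema).

```python
import math


def smoothing_from_half_life(half_life: float) -> float:
    """Hypothetical helper: smoothing = exp(-log(2) / t_half)."""
    if half_life <= 0:
        raise ValueError("half_life must be positive")
    return math.exp(-math.log(2) / half_life)


def half_life_from_smoothing(smoothing: float) -> float:
    """Hypothetical helper: t_half = -log(2) / log(smoothing), for 0 < smoothing < 1."""
    if not 0.0 < smoothing < 1.0:
        raise ValueError("smoothing must be in the interval (0, 1)")
    return -math.log(2) / math.log(smoothing)


# Usage: pick smoothing for a target half-life of 100 updates,
# and inspect the half-life implied by the default smoothing of 0.99.
target_smoothing = smoothing_from_half_life(100)
default_half_life = half_life_from_smoothing(0.99)
```

For example, the default smoothing=0.99 corresponds to a half-life of roughly 69 updates, and a target half-life of 100 updates calls for smoothing ≈ 0.9931.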
Example
    import composer.functional as cf
    from torchvision import models

    model = models.resnet50()
    ema_model = models.resnet50()
    # Moves ema_model's weights toward model's weights, in place
    cf.compute_ema(model, ema_model, smoothing=0.9)
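To make the update rule concrete, here is a minimal PyTorch sketch that applies the formula above to each pair of parameters in place. It is a hypothetical illustration (ema_update_sketch is not a Composer function), not Composer's implementation. It assumes both models have identical architectures and parameter ordering, only touches tensors returned by Module.parameters() (buffers such as BatchNorm running statistics are not updated), and does not accept EMAParameters.

```python
import copy

import torch
from torch import nn


def ema_update_sketch(model: nn.Module, ema_model: nn.Module, smoothing: float = 0.99) -> None:
    """Hypothetical sketch of W_ema <- smoothing * W_ema + (1 - smoothing) * W_model."""
    if not 0.0 < smoothing < 1.0:
        raise ValueError("smoothing must be in the interval (0, 1)")
    with torch.no_grad():  # the EMA update must not be tracked by autograd
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            # In place: scale the running average, then add the weighted new weights.
            ema_param.mul_(smoothing).add_(param, alpha=1.0 - smoothing)


# Usage: start the EMA copy from the model, move the model, then update the average.
torch.manual_seed(0)
model = nn.Linear(4, 2)
ema_model = copy.deepcopy(model)
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)  # stand-in for an optimizer step
ema_update_sketch(model, ema_model, smoothing=0.9)
# Each EMA weight moved 10% of the way toward the new weights: w0 + 0.1 = new weight - 0.9.
```

Starting ema_model as a deep copy of model is a common choice so that the average begins at the current weights rather than at an unrelated random initialization.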