EMA
- class composer.algorithms.EMA(half_life='1000ba', smoothing=None, ema_start='0.0dur', update_interval=None)
Maintains a set of weights that follow the exponential moving average of the training model weights.
Weights are updated according to
\[W_{\text{ema}}^{(t+1)} = \text{smoothing} \times W_{\text{ema}}^{(t)} + (1 - \text{smoothing}) \times W_{\text{model}}^{(t)}\]
where the smoothing is determined from half_life according to
\[\text{smoothing} = \exp\left[-\frac{\log(2)}{t_{1/2}}\right]\]
and \(t_{1/2}\) is the half life. Model evaluation is done with the moving average weights, which can result in better generalization. Because it keeps a second copy of the weights, EMA can double the model's memory consumption. This does not mean that the total memory required doubles, since stored activations and the optimizer state are not duplicated. EMA also uses a small amount of extra compute to update the moving average weights.
See the Method Card for more details.
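The update rule and the half-life formula above can be sketched in plain Python. This is a minimal illustration, not Composer's implementation: weights are modeled as flat lists of floats, and the helpers `smoothing_from_half_life` and `ema_update` are hypothetical names. `t_half` is assumed to be the half life counted in EMA updates; with the default `update_interval` (one unit of `half_life`), that is simply the numeric value of `half_life`.

```python
import math
from typing import List


def smoothing_from_half_life(t_half: float) -> float:
    """Hypothetical helper: smoothing = exp(-log(2) / t_half).

    Assumes t_half is the half life measured in EMA updates.
    """
    if t_half <= 0:
        raise ValueError("t_half must be positive")
    return math.exp(-math.log(2) / t_half)


def ema_update(ema_weights: List[float], model_weights: List[float],
               smoothing: float) -> List[float]:
    """One step of W_ema <- smoothing * W_ema + (1 - smoothing) * W_model."""
    return [
        smoothing * w_ema + (1.0 - smoothing) * w_model
        for w_ema, w_model in zip(ema_weights, model_weights)
    ]


# e.g. half_life='1000ba' with one update per batch
s = smoothing_from_half_life(1000)
print(f"smoothing = {s:.6f}")

# Start the average at 1.0 and keep feeding in a model weight of 0.0.
# After t_half updates, the share of the starting value has halved.
ema = [1.0]
for _ in range(1000):
    ema = ema_update(ema, [0.0], s)
print(f"EMA after 1000 updates: {ema[0]:.4f}")  # approximately 0.5
```

The loop makes the meaning of "half life" concrete: since \(\text{smoothing}^{t_{1/2}} = \exp(-\log 2) = 1/2\), any old weight's contribution to the average is halved every \(t_{1/2}\) updates.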
- Parameters
half_life (str, optional) – The time string specifying the half life for terms in the average. A longer half life means old information is remembered longer; a shorter half life means old information is discarded sooner. A half life of 0 means no averaging is done, and an infinite half life means no update is done. Currently only units of epoch ("ep") and batch ("ba") are supported. Time must be an integer value in the units specified. Cannot be used if smoothing is also specified. Default: "1000ba".
smoothing (float, optional) – The coefficient representing the degree to which older observations are kept. Must be in the interval \((0, 1)\). Cannot be used if half_life is also specified. This value will not be adjusted if update_interval is changed. Default: None.
ema_start (str, optional) – The time string denoting the amount of training completed before EMA begins. Currently only units of duration ("dur"), batch ("ba") and epoch ("ep") are supported. Default: '0.0dur'.
update_interval (str, optional) – The time string specifying the period at which updates are done. For example, update_interval='1ep' means updates are done every epoch, while update_interval='10ba' means updates are done once every ten batches. Units must match the units used to specify half_life if not using smoothing. If not specified, update_interval will default to 1 in the units of half_life, or "1ba" if smoothing is specified. Time must be an integer value in the units specified. Default: None.
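The parameter rules above (time-string units, the update_interval default, the half_life/smoothing exclusivity, and when EMA starts) can be sketched with a few small helpers. Everything here is hypothetical and for illustration only: `parse_time`, `resolve_update_interval`, `smoothing_for`, `effective_half_life_batches` and `ema_has_started` are not Composer functions, and the sketch does not use Composer's own time classes. It also assumes that \(t_{1/2}\) is counted in updates, i.e. the half_life value divided by the update_interval value, and that `total_batches` (a made-up parameter) is the planned length of the run in batches.

```python
import math
import re
from typing import Optional, Tuple

# Hypothetical parser for time strings such as "1000ba", "1ep" or "0.0dur".
# It is an illustration only and does not use Composer's own time classes.
_TIME_RE = re.compile(r"^(\d+(?:\.\d+)?)(ba|ep|dur)$")


def parse_time(time_str: str) -> Tuple[float, str]:
    """Split a time string into (value, unit); 'ba' and 'ep' must be integers."""
    match = _TIME_RE.match(time_str)
    if match is None:
        raise ValueError(f"Unsupported time string: {time_str!r}")
    value, unit = float(match.group(1)), match.group(2)
    if unit != "dur" and not value.is_integer():
        raise ValueError(f"{time_str!r} must be an integer number of {unit}")
    return value, unit


def resolve_update_interval(half_life: Optional[str], smoothing: Optional[float],
                            update_interval: Optional[str]) -> str:
    """Apply the documented defaulting and unit-matching rules for update_interval."""
    if (half_life is None) == (smoothing is None):
        raise ValueError("Specify exactly one of half_life or smoothing")
    if update_interval is None:
        # Default: 1 in the units of half_life, or "1ba" when smoothing is given.
        return f"1{parse_time(half_life)[1]}" if half_life is not None else "1ba"
    if half_life is not None and parse_time(update_interval)[1] != parse_time(half_life)[1]:
        raise ValueError("update_interval must use the same units as half_life")
    return update_interval


def smoothing_for(half_life: str, update_interval: str) -> float:
    """smoothing = exp(-log(2) / t_half), assuming t_half = half_life / update_interval."""
    hl_value, _ = parse_time(half_life)
    ui_value, _ = parse_time(update_interval)
    if hl_value == 0:
        return 0.0  # Half life of 0: no averaging, the EMA copy equals the model.
    return math.exp(-math.log(2) * ui_value / hl_value)


def effective_half_life_batches(smoothing: float, update_interval_batches: int) -> float:
    """Batches until an old weight's share halves when smoothing is held fixed."""
    updates_to_halve = math.log(0.5) / math.log(smoothing)
    return updates_to_halve * update_interval_batches


def ema_has_started(ema_start: str, batches_done: int, epochs_done: int,
                    total_batches: int) -> bool:
    """True once the ema_start threshold is reached; total_batches is the run length."""
    value, unit = parse_time(ema_start)
    if unit == "dur":
        return batches_done >= value * total_batches  # Fraction of the whole run.
    if unit == "ba":
        return batches_done >= value
    return epochs_done >= value  # unit == "ep"


interval = resolve_update_interval("1000ba", None, None)  # "1ba"
print(interval, smoothing_for("1000ba", interval))
# A fixed smoothing is not rescaled, so a 10x longer update_interval
# gives a 10x longer half life measured in batches.
print(effective_half_life_batches(0.999, 1), effective_half_life_batches(0.999, 10))
print(ema_has_started("0.5dur", batches_done=400, epochs_done=0, total_batches=1000))  # False
```

The `effective_half_life_batches` helper illustrates the note under smoothing: a fixed coefficient is applied once per update, so its half life measured in batches grows in proportion to update_interval.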
Example
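```python
from composer import Trainer
from composer.algorithms import EMA

# model, train_dataloader, eval_dataloader and optimizer are assumed to be defined elsewhere.
algorithm = EMA(half_life='1000ba', update_interval='1ba')
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    algorithms=[algorithm],
    optimizers=[optimizer],
)
```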
- ensure_compatible_state_dict(state)
Ensure state dicts created prior to Composer 0.13.0 are compatible with later versions.