SWA
- class composer.algorithms.SWA(swa_start='0.7dur', swa_end='0.97dur', update_interval='1ep', schedule_swa_lr=False, anneal_strategy='linear', anneal_steps=10, swa_lr=None)
Applies Stochastic Weight Averaging (Izmailov et al., 2018).
Stochastic Weight Averaging (SWA) averages model weights sampled at different times near the end of training. This leads to better generalization than just using the final trained weights.
Because this algorithm needs to maintain both the current weights and the running average of the sampled weights, it doubles the model's memory consumption. This does not double the total memory required, however, since stored activations and the optimizer state are not duplicated.
Note
The AveragedModel is currently stored on the CPU device, which may cause slow training if the model weights are large.
Uses PyTorch's torch.optim.swa_utils under the hood.
See the Method Card for more details.
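For readers curious about what is being wrapped, the following is a minimal sketch of the equivalent raw PyTorch torch.optim.swa_utils workflow (AveragedModel, SWALR, and update_bn). The model, dataloader, optimizer, loss, and epoch thresholds are placeholder assumptions for illustration; Composer's Trainer performs these steps automatically when the SWA algorithm is enabled.

import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

# Placeholder model, data, and optimizer (assumptions for this sketch).
model = torch.nn.Linear(10, 2)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

swa_model = AveragedModel(model)  # keeps the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=5, anneal_strategy="linear")
swa_start = 7  # epoch at which weight averaging begins (placeholder)

for epoch in range(10):
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # sample current weights into the average
        swa_scheduler.step()                # anneal the LR during the SWA phase

update_bn(loader, swa_model)  # refresh batch norm statistics for the averaged model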
Example
from composer.algorithms import SWA
from composer.trainer import Trainer

swa_algorithm = SWA(
    swa_start="6ep",
    swa_end="8ep"
)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="10ep",
    algorithms=[swa_algorithm],
    optimizers=[optimizer]
)
- Parameters
- swa_start (str, optional) – The time string denoting the amount of training completed before stochastic weight averaging begins. Currently only units of duration ('dur') and epoch ('ep') are supported. Default: '0.7dur'.
- swa_end (str, optional) – The time string denoting the amount of training completed before the baseline (non-averaged) model is replaced with the stochastic weight averaged model. It's important to have at least one epoch of training after the baseline model is replaced by the SWA model so that the SWA model can have its buffers (most importantly its batch norm statistics) updated. If swa_end occurs during the final epoch of training (e.g. swa_end = '0.9dur' and max_duration = '5ep', or swa_end = '1.0dur'), the SWA model will not have its buffers updated, which can negatively impact accuracy, so ensure swa_end < \(\frac{N_{epochs}-1}{N_{epochs}}\). Currently only units of duration ('dur') and epoch ('ep') are supported. Default: '0.97dur'.
- update_interval (str, optional) – Time string denoting how often the averaged model is updated. For example, '1ep' means the averaged model will be updated once per epoch and '5ba' means the averaged model will be updated every 5 batches. Note that for single-epoch training runs (e.g. many NLP training runs), update_interval must be specified in units of 'ba', otherwise SWA won't happen. Also note that very small update intervals (e.g. '1ba') can substantially slow down training. Default: '1ep'.
- schedule_swa_lr (bool, optional) – Flag to determine whether to apply an SWA-specific LR schedule during the period in which SWA is active (a configuration sketch using this flag follows this parameter list). Default: False.
- anneal_strategy (str, optional) – SWA learning rate annealing schedule strategy: "linear" for linear annealing, "cos" for cosine annealing. Default: "linear".
- anneal_steps (int, optional) – Number of SWA model updates over which to anneal the SWA learning rate. Note that updates are determined by the update_interval argument. For example, if anneal_steps = 10 and update_interval = '1ep', then the SWA LR will be annealed once per epoch for 10 epochs; if anneal_steps = 20 and update_interval = '8ba', then the SWA LR will be annealed once every 8 batches over the course of 160 batches (20 steps * 8 batches/step). Default: 10.
- swa_lr (float, optional) – The final learning rate to anneal towards with the SWA LR scheduler. Set to None for no annealing. Default: None.
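As a complement to the example above, here is a hedged sketch of how the LR-annealing parameters might be combined. The specific values, and the model, dataloader, and optimizer objects, are illustrative assumptions rather than recommendations. Note that for a 10-epoch run, swa_end = '0.85dur' satisfies the swa_end < \(\frac{N_{epochs}-1}{N_{epochs}}\) guideline (0.85 < 0.9), leaving time for the averaged model's buffers to be updated.

from composer.algorithms import SWA
from composer.trainer import Trainer

# SWA with an SWA-specific cosine LR schedule (values are placeholders).
swa_algorithm = SWA(
    swa_start="0.6dur",      # begin averaging after 6 of 10 epochs
    swa_end="0.85dur",       # swap in the averaged model at 8.5 epochs (< 0.9dur)
    update_interval="1ep",   # sample weights into the average once per epoch
    schedule_swa_lr=True,    # enable the SWA-specific LR schedule
    anneal_strategy="cos",   # cosine annealing toward swa_lr
    anneal_steps=2,          # anneal over 2 SWA updates (2 epochs here)
    swa_lr=1e-4,             # final learning rate to anneal towards
)
trainer = Trainer(
    model=model,                        # placeholder: assumed defined elsewhere
    train_dataloader=train_dataloader,  # placeholder
    eval_dataloader=eval_dataloader,    # placeholder
    max_duration="10ep",
    algorithms=[swa_algorithm],
    optimizers=[optimizer],             # placeholder
)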