SWA

class composer.algorithms.SWA(swa_start='0.7dur', swa_end='0.97dur', update_interval='1ep', schedule_swa_lr=False, anneal_strategy='linear', anneal_steps=10, swa_lr=None)

Applies Stochastic Weight Averaging (Izmailov et al., 2018).

Stochastic Weight Averaging (SWA) averages model weights sampled at different times near the end of training. This typically leads to better generalization than using only the final trained weights.
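As a minimal sketch of what "averaging" means here, assuming every sampled snapshot is weighted equally (the default behaviour of PyTorch's AveragedModel), the averaged weights can be kept as a running mean, so individual snapshots never need to be stored. Plain Python lists stand in for parameter tensors, and `update_running_average` is a hypothetical helper, not part of Composer.

```python
from typing import List


def update_running_average(avg: List[float], new: List[float], num_averaged: int) -> List[float]:
    """Fold one more weight snapshot into an equal-weight running mean.

    `avg` holds the mean of the first `num_averaged` snapshots; the result
    is the mean of those snapshots plus `new`.
    """
    return [a + (w - a) / (num_averaged + 1) for a, w in zip(avg, new)]


# Three hypothetical snapshots of a two-parameter model taken late in training.
snapshots = [[1.0, 4.0], [3.0, 2.0], [5.0, 0.0]]

avg = list(snapshots[0])  # the first snapshot initializes the average
for n, snap in enumerate(snapshots[1:], start=1):
    avg = update_running_average(avg, snap, n)

print(avg)  # [3.0, 2.0], the element-wise mean of the three snapshots
```

Only one extra copy of the weights (`avg`) is kept no matter how many snapshots are folded in.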

Because this algorithm must maintain both the current weights and the running average of all sampled weights, it doubles the memory used by the model's weights. The total memory required does not double, however, since stored activations and the optimizer state are not duplicated.
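The following back-of-the-envelope estimate illustrates why only the weight memory doubles. It is a rough sketch under assumed numbers, not a measurement: a hypothetical 100M-parameter fp32 model, one gradient per weight, an Adam-style optimizer with two state values per parameter, and a made-up 2 GB of stored activations. The extra copy is counted across devices, since the averaged model may live on the CPU.

```python
def training_memory_bytes(num_params: int, activation_bytes: int,
                          bytes_per_value: int = 4,
                          optimizer_states_per_param: int = 2,
                          use_swa: bool = False) -> int:
    """Rough total training memory across devices (hypothetical estimate)."""
    weights = num_params * bytes_per_value  # current model weights
    grads = num_params * bytes_per_value    # one gradient per weight
    # e.g. Adam's first and second moment estimates
    optimizer_state = num_params * bytes_per_value * optimizer_states_per_param
    total = weights + grads + optimizer_state + activation_bytes
    if use_swa:
        total += weights  # SWA adds exactly one extra copy: the averaged weights
    return total


baseline = training_memory_bytes(100_000_000, activation_bytes=2_000_000_000)
with_swa = training_memory_bytes(100_000_000, activation_bytes=2_000_000_000, use_swa=True)
print(baseline, with_swa, with_swa / baseline)
```

In this example the weight memory doubles from 0.4 GB to 0.8 GB, while the total grows from 3.6 GB to 4.0 GB, about 11%.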

Note

The AveragedModel is currently stored on the CPU device, which may cause slow training if the model weights are large.

Uses PyTorch's torch.optim.swa_utils under the hood.

See the Method Card for more details.

Example

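from composer.algorithms import SWA
from composer.trainer import Trainer

# model, train_dataloader, eval_dataloader, and optimizer are assumed
# to be defined beforehand.

# Average weights from epoch 6 until epoch 8 of a 10-epoch run, leaving
# two epochs of training after the SWA model replaces the baseline model.
swa_algorithm = SWA(
    swa_start="6ep",
    swa_end="8ep"
)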
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="10ep",
    algorithms=[swa_algorithm],
    optimizers=[optimizer]
)

Parameters
  • swa_start (str, optional) – The time string denoting the amount of training completed before stochastic weight averaging begins. Currently only units of duration ('dur') and epoch ('ep') are supported. Default: '0.7dur'.

  • swa_end (str, optional) – The time string denoting the amount of training completed before the baseline (non-averaged) model is replaced with the stochastic weight averaged model. It's important to have at least one epoch of training after the baseline model is replaced by the SWA model so that the SWA model can have its buffers (most importantly its batch norm statistics) updated. If swa_end occurs during the final epoch of training (e.g. swa_end = '0.9dur' with max_duration = "5ep", or swa_end = '1.0dur'), the SWA model will not have its buffers updated, which can hurt accuracy. To avoid this, ensure that swa_end < \(\frac{N_{epochs}-1}{N_{epochs}}\) in 'dur' units, where \(N_{epochs}\) is the total number of training epochs. Currently only units of duration ('dur') and epoch ('ep') are supported. Default: '0.97dur'.

  • update_interval (str, optional) – Time string denoting how often the averaged model is updated. For example, '1ep' means the averaged model is updated once per epoch, and '5ba' means it is updated every 5 batches. For single-epoch training runs (e.g. many NLP training runs), update_interval must be specified in batches ('ba'); otherwise no SWA updates will occur. Very small update intervals (e.g. "1ba") can also substantially slow down training. Default: '1ep'.

  • schedule_swa_lr (bool, optional) – Whether to apply an SWA-specific LR schedule during the period in which SWA is active. Default: False.

  • anneal_strategy (str, optional) – SWA learning rate annealing strategy: "linear" for linear annealing, "cos" for cosine annealing. Default: "linear".

  • anneal_steps (int, optional) – Number of SWA model updates over which to anneal the SWA learning rate. Updates are determined by the update_interval argument. For example, if anneal_steps = 10 and update_interval = '1ep', the SWA LR is annealed once per epoch for 10 epochs; if anneal_steps = 20 and update_interval = '8ba', the SWA LR is annealed once every 8 batches over the course of 160 batches (20 steps * 8 batches/step). Default: 10.

  • swa_lr (float, optional) – The final learning rate to anneal towards with the SWA LR scheduler. Set to None for no annealing. Default: None.
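The swa_end guideline above can be checked mechanically. The sketch below uses hypothetical helpers that are not part of Composer. It assumes max_duration is a whole number of epochs and that time strings take the '<number>dur' or '<number>ep' form described above. It converts swa_end to epochs and tests whether it falls before the start of the final epoch, which is equivalent to swa_end < (N_epochs - 1) / N_epochs.

```python
import re


def to_epochs(time_str: str, max_epochs: int) -> float:
    """Convert a '<number>dur' or '<number>ep' time string to epochs."""
    match = re.fullmatch(r"(\d+(?:\.\d+)?)(dur|ep)", time_str)
    if match is None:
        raise ValueError(f"unsupported time string: {time_str!r}")
    value, unit = float(match.group(1)), match.group(2)
    # 'dur' is a fraction of the whole run; 'ep' is already in epochs.
    return value * max_epochs if unit == "dur" else value


def swa_end_leaves_bn_epoch(swa_end: str, max_epochs: int) -> bool:
    """True if swa_end falls before the final epoch starts, leaving at least
    one full epoch for the SWA model's buffers to be updated."""
    return to_epochs(swa_end, max_epochs) < max_epochs - 1


print(swa_end_leaves_bn_epoch("0.9dur", 5))  # False: epoch 4.5 is inside the final epoch
print(swa_end_leaves_bn_epoch("8ep", 10))    # True: matches the example above
```

Under this check, the default '0.97dur' leaves a full epoch after the switch only when training runs for at least 34 epochs, since 0.97 N < N - 1 requires N > 33.3.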
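To see why a single-epoch run needs update_interval in 'ba' units, the following simplified sketch lists the batches at which the averaged model would be updated. The scheduling rule is a hypothetical simplification, not Composer's actual event logic: an update fires at every multiple of the interval, counted in batches from the start of training, but only while start_batch <= t < end_batch. The 1000-batch epoch is likewise an assumed figure.

```python
from typing import List


def swa_update_batches(start_batch: int, end_batch: int, interval_batches: int) -> List[int]:
    """Batches at which the averaged model would be updated, assuming updates
    fire every `interval_batches` batches (counted from the start of training)
    while start_batch <= t < end_batch."""
    return [t for t in range(interval_batches, end_batch, interval_batches)
            if t >= start_batch]


# A hypothetical single-epoch run with 1000 batches:
# swa_start='0.7dur' -> batch 700, swa_end='0.97dur' -> batch 970.
batches_per_epoch = 1000
per_epoch = swa_update_batches(700, 970, interval_batches=batches_per_epoch)  # '1ep'
every_5 = swa_update_batches(700, 970, interval_batches=5)                     # '5ba'
print(len(per_epoch), len(every_5))  # 0 54
```

With '1ep', the only epoch boundary (batch 1000) lies after swa_end, so no averaging happens; with '5ba', 54 updates land inside the SWA window.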
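The interaction of anneal_steps, anneal_strategy, and swa_lr can be illustrated with a simple interpolation from the optimizer's current learning rate to swa_lr. This is a sketch under stated assumptions, not the exact recurrence used by PyTorch's SWALR scheduler. It assumes one annealing step per SWA update and uses the standard linear ramp and the cosine ramp (1 - cos(pi t)) / 2. `annealed_lrs` is a hypothetical helper.

```python
import math
from typing import List


def annealed_lrs(initial_lr: float, swa_lr: float, anneal_steps: int,
                 strategy: str = "linear") -> List[float]:
    """Learning rate after each of `anneal_steps` SWA updates, moving from
    `initial_lr` to `swa_lr` (simplified; not SWALR's exact update rule)."""
    lrs = []
    for step in range(1, anneal_steps + 1):
        t = step / anneal_steps  # fraction of the annealing period completed
        if strategy == "linear":
            alpha = t
        elif strategy == "cos":
            alpha = (1 - math.cos(math.pi * t)) / 2  # slow start and finish
        else:
            raise ValueError(f"unknown anneal strategy: {strategy!r}")
        lrs.append(initial_lr * (1 - alpha) + swa_lr * alpha)
    return lrs


linear = annealed_lrs(0.1, 0.01, anneal_steps=10, strategy="linear")
cosine = annealed_lrs(0.1, 0.01, anneal_steps=10, strategy="cos")
print([round(lr, 4) for lr in linear])
print([round(lr, 4) for lr in cosine])

# Annealing length in batches for anneal_steps=20 and update_interval='8ba'.
anneal_steps, interval_batches = 20, 8
print(anneal_steps * interval_batches)  # 160
```

Both strategies reach swa_lr after anneal_steps updates. The cosine ramp changes the learning rate more gently at the start and end of the annealing period.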