# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Core code for Stochastic Weight Averaging."""
from __future__ import annotations
import logging
import warnings
from typing import Any, Optional
import torch
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.swa_utils import SWALR, AveragedModel
from composer.core import Algorithm, Event, State, Time, TimeUnit
from composer.loggers import Logger
log = logging.getLogger(__name__)
__all__ = ['SWA']
def _assert_valid_duration(time: Time):
    """Check that a duration-unit ``Time`` lies within the closed interval [0, 1].

    Times expressed in any other unit are accepted without a range check.

    Raises:
        ValueError: If ``time`` is in duration units and is below 0 or above 1.
    """
    if time.unit != TimeUnit.DURATION:
        return
    out_of_range = time < 0 or time > 1
    if out_of_range:
        raise ValueError(f'time in duration units must be [0, 1], got {time}')
class SWA(Algorithm):
    """Applies Stochastic Weight Averaging (`Izmailov et al, 2018 <https://arxiv.org/abs/1803.05407>`_).

    Stochastic Weight Averaging (SWA) averages model weights sampled at
    different times near the end of training. This leads to better
    generalization than just using the final trained weights.

    Because this algorithm needs to maintain both the current value of the
    weights and the average of all of the sampled weights, it doubles the
    model's memory consumption. Note that this does not mean that the total
    memory required doubles, however, since stored activations and the
    optimizer state are not doubled.

    .. note::

        The AveragedModel is currently stored on the CPU device, which may
        cause slow training if the model weights are large.

    Uses PyTorch's `torch.optim.swa_utils
    <https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging>`_
    under the hood.

    See the :doc:`Method Card </method_cards/swa>` for more details.

    Example:
        .. testcode::

            from composer.algorithms import SWA
            from composer.trainer import Trainer

            swa_algorithm = SWA(
                swa_start="6ep",
                swa_end="8ep"
            )
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="10ep",
                algorithms=[swa_algorithm],
                optimizers=[optimizer]
            )

    Args:
        swa_start (str, optional): The time string denoting the amount of training
            completed before stochastic weight averaging begins. Currently only units of
            duration ('dur') and epoch ('ep') are supported. Default: ``'0.7dur'``.
        swa_end (str, optional): The time string denoting the amount of training
            completed before the baseline (non-averaged) model is replaced with the
            stochastic weight averaged model. It's important to have at least one epoch
            of training after the baseline model is replaced by the SWA model so that the
            SWA model can have its buffers (most importantly its batch norm statistics)
            updated. If ``swa_end`` occurs during the final epoch of training (e.g.
            ``swa_end = 0.9dur`` and ``max_duration = "5ep"``, or ``swa_end = 1.0dur``),
            the SWA model will not have its buffers updated, which can negatively impact
            accuracy, so ensure ``swa_end`` < :math:`\\frac{N_{epochs}-1}{N_{epochs}}`.
            Currently only units of duration ('dur') and epoch ('ep') are supported.
            Default: ``'0.97dur'``.
        update_interval (str, optional): Time string denoting how often the averaged
            model is updated. For example, ``'1ep'`` means the averaged model will be
            updated once per epoch and ``'5ba'`` means the averaged model will be updated
            every 5 batches. The value must be positive. Note that for single-epoch
            training runs (e.g. many NLP training runs), ``update_interval`` must be
            specified in units of ``'ba'``, otherwise SWA won't happen. Also note that very
            small update intervals (e.g. ``"1ba"``) can substantially slow down training.
            Default: ``'1ep'``.
        schedule_swa_lr (bool, optional): Flag to determine whether to apply an
            SWA-specific LR schedule during the period in which SWA is active. Default:
            ``False``.
        anneal_strategy (str, optional): SWA learning rate annealing schedule strategy.
            ``"linear"`` for linear annealing, ``"cos"`` for cosine annealing. Default:
            ``"linear"``.
        anneal_steps (int, optional): Number of SWA model updates over which to
            anneal SWA learning rate. Note that updates are determined by the
            ``update_interval`` argument. For example, if ``anneal_steps = 10`` and
            ``update_interval = '1ep'``, then the SWA LR will be annealed once per epoch
            for 10 epochs; if ``anneal_steps = 20`` and ``update_interval = '8ba'``, then
            the SWA LR will be annealed once every 8 batches over the course of 160
            batches (20 steps * 8 batches/step). Default: ``10``.
        swa_lr (float, optional): The final learning rate to anneal towards with the SWA
            LR scheduler. If ``None``, the target is the learning rate already in effect
            (the single scheduler's last LR, or each optimizer param group's current LR
            when there is no scheduler), i.e. no annealing. Only used when
            ``schedule_swa_lr`` is ``True``. Default: ``None``.

    Raises:
        ValueError: If the time strings use unsupported or mismatched units, if
            ``swa_start >= swa_end``, if a duration is outside [0, 1], if
            ``update_interval`` is not positive, if ``anneal_steps <= 0``, or if
            ``anneal_strategy`` is not recognized.
    """

    def __init__(
        self,
        swa_start: str = '0.7dur',
        swa_end: str = '0.97dur',
        update_interval: str = '1ep',
        schedule_swa_lr: bool = False,
        anneal_strategy: str = 'linear',
        anneal_steps: int = 10,
        swa_lr: Optional[float] = None,
    ):
        warnings.warn(
            'SWA has known issues when resuming from a checkpoint on multiple GPUs, which will cause an error when resuming without `load_weights_only=True`.',
        )
        self.schedule_swa_lr = schedule_swa_lr
        self.anneal_strategy = anneal_strategy
        self.anneal_steps = anneal_steps
        # User-configured annealing target. It is never overwritten, so the value chosen
        # here still applies when SWA is (re)initialized at its start time.
        self.swa_lr = swa_lr
        self.swa_completed = False
        self.swa_started = False

        # Check timestrings are parsable and convert into time objects
        self.swa_start = Time.from_timestring(swa_start)
        self.swa_end = Time.from_timestring(swa_end)
        self.update_interval = Time.from_timestring(update_interval)
        self._validate_time()

        if anneal_steps <= 0:
            raise ValueError('anneal_steps must be greater than 0')

        # Normalize the annealing-strategy aliases to the names SWALR expects
        if self.anneal_strategy.lower() in ['linear', 'lin']:
            self.anneal_strategy = 'linear'
        elif self.anneal_strategy.lower() in ['cos', 'cosine']:
            self.anneal_strategy = 'cos'
        else:
            raise ValueError("anneal_strategy must be one of {'linear', 'cos'}.")

        self.swa_scheduler: Optional[SWALR] = None
        self.swa_model: Optional[AveragedModel] = None

        # Keeps track of # steps so that we can know when to update averaged model
        self.step_counter = 0

        # Units were validated above, so exactly one of these branches applies
        if self.update_interval.unit == TimeUnit.BATCH:
            self.match_event = Event.BATCH_END
        elif self.update_interval.unit == TimeUnit.EPOCH:
            self.match_event = Event.EPOCH_END

    def _validate_time(self):
        """Validate the units and values of ``swa_start``, ``swa_end`` and ``update_interval``."""
        # validate time units
        if self.swa_start.unit != self.swa_end.unit:
            raise ValueError(f'swa_start and swa_end must have same units, got {self.swa_start} and {self.swa_end}')
        if self.swa_start.unit not in [TimeUnit.DURATION, TimeUnit.EPOCH]:
            raise ValueError(f'swa_start must be DURATION or EPOCH, got {self.swa_start.unit}')
        if self.update_interval.unit not in [TimeUnit.BATCH, TimeUnit.EPOCH]:
            raise ValueError(f'update_interval must be BATCH or EPOCH, got {self.update_interval.unit}')
        # A zero interval would raise ZeroDivisionError in the modulo check in ``apply``.
        if self.update_interval.value <= 0:
            raise ValueError(f'update_interval must be positive, got {self.update_interval}')

        # validate time
        if self.swa_start >= self.swa_end:
            raise ValueError('swa_end must be > swa_start.')
        if self.swa_end.unit == TimeUnit.DURATION and self.swa_end == 1:
            log.warning(
                "'swa_end' = '1dur'. Batch norm statistics of averaged model "
                'will not be updated. This will negatively impact accuracy. '
                'See the documentation for the `swa_end` parameter for details.',
            )

        _assert_valid_duration(self.swa_start)
        _assert_valid_duration(self.swa_end)

    def _get_time(self, state: State):
        """Return elapsed training time in the same unit as ``swa_start`` (epochs or duration)."""
        unit = self.swa_start.unit

        if unit == TimeUnit.EPOCH:
            return state.timestamp.epoch
        elif unit == TimeUnit.DURATION:
            time_elapsed = state.get_elapsed_duration()
            assert time_elapsed is not None, 'Time should have been set on BATCH_END or EPOCH_END.'
            return time_elapsed
        else:
            raise ValueError('units must be in epoch or duration.')

    def _get_last_lr(self, schedulers: list[LRScheduler]):
        """Return the last LR reported by the single scheduler.

        Returns the placeholder ``1.0`` for an empty list. ``_initialize_swa`` does not
        call this method without schedulers.

        Raises:
            RuntimeError: If there is more than one scheduler, or the scheduler reports
                more than one LR.
        """
        if len(schedulers) == 0:
            return 1.0

        if len(schedulers) != 1:
            raise RuntimeError(f'SWA supports only one scheduler, got {len(schedulers)}')
        scheduler = schedulers[0]
        last_lr = scheduler.get_last_lr()
        if len(last_lr) != 1:
            raise RuntimeError(f'SWA supports only one LR; instead found {len(last_lr)}')
        return last_lr[0]

    def match(self, event: Event, state: State) -> bool:
        """Run on INIT, then on the configured update event from ``swa_start`` until SWA completes."""
        if event == Event.INIT:
            return True

        # only match on BATCH_END or EPOCH_END, depending on the setting
        if event != self.match_event or self.swa_completed:
            return False

        return self._get_time(state) >= self.swa_start

    def _initialize_swa(self, state: State) -> None:
        """(Re)create the SWA LR scheduler (if enabled) and the CPU-resident averaged model.

        Raises:
            RuntimeError: If ``schedule_swa_lr`` is set and there is not exactly one
                optimizer, or the scheduler configuration is rejected by ``_get_last_lr``.
        """
        if self.schedule_swa_lr:
            # Derive the target locally so a user-provided ``swa_lr`` is respected and the
            # INIT-time call does not fix the target used when SWA actually starts.
            swa_lr = self.swa_lr
            if swa_lr is None and state.schedulers:
                swa_lr = self._get_last_lr(state.schedulers)

            if len(state.optimizers) != 1:
                raise RuntimeError('SWA supports only one optimizer')
            optimizer = state.optimizers[0]

            if swa_lr is None:
                # No user target and no scheduler: hold every param group at its current LR
                # rather than annealing toward an arbitrary placeholder value.
                swa_lr = [group['lr'] for group in optimizer.param_groups]

            self.swa_scheduler = SWALR(
                optimizer,
                swa_lr=swa_lr,
                anneal_epochs=self.anneal_steps,
                anneal_strategy=self.anneal_strategy,
            )

        self.swa_model = AveragedModel(state.model, device=torch.device('cpu'))

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        """Create SWA state on INIT; afterwards update the average and swap it in at ``swa_end``."""
        if event == Event.INIT:
            # on trainer init, we create the schedulers and models
            # so that the checkpoints can be loaded
            self._initialize_swa(state)
            return

        if not self.swa_started:
            # re-initialize swa once time > swa_start so averaging starts from current weights
            self._initialize_swa(state)
            self.swa_started = True

        if self.step_counter % self.update_interval.value == 0:
            assert self.swa_model is not None
            self.swa_model.update_parameters(state.model)

            if self.schedule_swa_lr:
                assert self.swa_scheduler is not None
                self.swa_scheduler.step()

        # Counts every matched event (not only updates) so the modulo above spaces updates
        self.step_counter += 1

        # Determine whether it's time to end SWA
        if self._get_time(state) >= self.swa_end:
            self.swa_completed = True

            if state.get_elapsed_duration() == 1:
                log.warning((
                    'The baseline model was replaced with the SWA model after the end of '
                    'training. This means that SWA model will not have its batch norm '
                    'statistics updated. This will negatively impact accuracy. See the '
                    'documentation for the `swa_end` parameter for details.'
                ))

            assert self.swa_model is not None
            state.model.load_state_dict(self.swa_model.module.state_dict())
            log.info('Set model to the averaged model')

    def state_dict(self) -> dict[str, Any]:
        """Serialize SWA progress, the averaged model, and the SWA scheduler (if any)."""
        state_dict = super().state_dict()

        # we pop the anneal_func from the SWALR state
        # since it is set in the SWALR __init__
        swa_scheduler_state = None
        if self.swa_scheduler is not None:
            swa_scheduler_state = self.swa_scheduler.state_dict()
            swa_scheduler_state.pop('anneal_func', None)

        state_dict = {
            'swa_model': self.swa_model.state_dict() if self.swa_model is not None else None,
            'swa_completed': self.swa_completed,
            'swa_started': self.swa_started,
            'swa_scheduler': swa_scheduler_state,
            'step_counter': self.step_counter,
            **state_dict,
        }
        return state_dict

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Restore state produced by ``state_dict``; objects created on INIT receive the saved weights."""
        self.swa_completed = state['swa_completed']
        self.step_counter = state['step_counter']
        self.swa_started = state['swa_started']

        if self.swa_scheduler is not None and state['swa_scheduler']:
            self.swa_scheduler.load_state_dict(state['swa_scheduler'])

        if self.swa_model is not None and state['swa_model']:
            self.swa_model.load_state_dict(state['swa_model'])