# Source code for composer.algorithms.swa.swa

# Copyright 2021 MosaicML. All Rights Reserved.

"""Core code for Stochastic Weight Averaging."""

from __future__ import annotations

import logging
from typing import Optional

import torch
from torch.optim.swa_utils import SWALR, AveragedModel

from composer.core import Algorithm, Event, State, Time, TimeUnit
from composer.loggers import Logger

# Logger named after this module's dotted import path.
log = logging.getLogger(name=__name__)

# Names exported by ``from <this module> import *``.
__all__ = ["SWA"]


class SWA(Algorithm):
    """Apply Stochastic Weight Averaging (`Izmailov et al, 2018 <https://arxiv.org/abs/1803.05407>`_)

    Stochastic Weight Averaging (SWA) averages model weights sampled at
    different times near the end of training. This leads to better
    generalization than just using the final trained weights.

    Because this algorithm needs to maintain both the current value of the
    weights and the average of all of the sampled weights, it doubles the
    model's memory consumption. Note that this does not mean that the total
    memory required doubles, however, since stored activations and the
    optimizer state are not doubled.

    Uses PyTorch's `torch.optim.swa_utils <https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging>`_
    under the hood.

    See the :doc:`Method Card </method_cards/swa>` for more details.

    Example:
        .. testcode::

            from composer.algorithms import SWA
            from composer.trainer import Trainer
            swa_algorithm = SWA(
                swa_start="6ep",
                swa_end="8ep"
            )
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="10ep",
                algorithms=[swa_algorithm],
                optimizers=[optimizer]
            )

    Args:
        swa_start (str, optional): The time string denoting the amount of training
            completed before stochastic weight averaging begins. Currently only units of
            duration ('dur') and epoch ('ep') are supported. Default = ``'0.7dur'``.
        swa_end (str, optional): The time string denoting the amount of training
            completed before the baseline (non-averaged) model is replaced with the
            stochastic weight averaged model. It's important to have at least one epoch
            of training after the baseline model is replaced by the SWA model so that the
            SWA model can have its buffers (most importantly its batch norm statistics)
            updated. If ``swa_end`` occurs during the final epoch of training (e.g.
            ``swa_end = 0.9dur`` and ``max_duration = "5ep"``, or ``swa_end = 1.0dur``), the
            SWA model will not have its buffers updated, which can negatively impact
            accuracy, so ensure ``swa_end`` < :math:`\\frac{N_{epochs}-1}{N_{epochs}}`.
            Currently only units of duration ('dur') and epoch ('ep') are supported.
            Default = ``'0.97dur'``.
        update_interval (str, optional): Time string denoting how often the averaged
            model is updated. For example, ``'1ep'`` means the averaged model will be
            updated once per epoch, and ``'5ba'`` means the averaged model will be updated
            every 5 batches. Only units of batch ('ba') and epoch ('ep') are supported.
            Note that for single-epoch training runs (e.g. many NLP training runs)
            ``update_interval`` must be specified in units of ``'ba'``, otherwise SWA
            won't happen. Also note that very small update intervals (e.g. ``"1ba"``) can
            substantially slow down training. Default = ``'1ep'``.
        schedule_swa_lr (bool, optional): Flag to determine whether to apply an
            SWA-specific LR schedule during the period in which SWA is active.
            Default = ``False``.
        anneal_strategy (str, optional): SWA learning rate annealing schedule strategy.
            ``"linear"`` (or ``"lin"``) for linear annealing, ``"cos"`` (or ``"cosine"``)
            for cosine annealing. Case-insensitive. Default = ``"linear"``.
        anneal_steps (int, optional): Number of SWA model updates over which to anneal
            SWA learning rate. Note that updates are determined by the
            ``update_interval`` argument. For example, if ``anneal_steps = 10`` and
            ``update_interval = '1ep'``, then the SWA LR will be annealed once per epoch
            for 10 epochs; if ``anneal_steps = 20`` and ``update_interval = '8ba'``, then
            the SWA LR will be annealed once every 8 batches over the course of 160
            batches (20 steps * 8 batches/step). Must be > 0. Default = ``10``.
        swa_lr (float, optional): The final learning rate to anneal towards with the
            SWA LR scheduler. Only used when ``schedule_swa_lr`` is ``True``. If ``None``,
            the most recent learning rate reported by the (single) scheduler is used as
            the target, so the LR is effectively held near its current value.
            Default = ``None``.

    Raises:
        ValueError: If a time string cannot be parsed, uses an unsupported unit, or has
            an out-of-range value; if ``anneal_steps <= 0``; or if ``anneal_strategy``
            is not recognized.
    """

    def __init__(self,
                 swa_start: str = "0.7dur",
                 swa_end: str = "0.97dur",
                 update_interval: str = "1ep",
                 schedule_swa_lr: bool = False,
                 anneal_strategy: str = "linear",
                 anneal_steps: int = 10,
                 swa_lr: Optional[float] = None):
        self.schedule_swa_lr = schedule_swa_lr
        self.anneal_strategy = anneal_strategy
        self.anneal_steps = anneal_steps
        self.swa_lr = swa_lr
        # Lazily created on the first update so it copies the model as it is at swa_start.
        self.swa_model: Optional[torch.nn.Module] = None
        # Once True, ``match`` never fires again.
        self.swa_completed = False

        # Check timestrings are parsable and convert into time objects.
        self.swa_start = self._parse_timestring(swa_start, "swa_start")
        self.swa_end = self._parse_timestring(swa_end, "swa_end")
        self.update_interval = self._parse_timestring(update_interval, "update_interval")

        # Check time objects have supported units.
        for time_attr in ["swa_start", "swa_end"]:
            time_obj = getattr(self, time_attr)
            if time_obj.unit not in [TimeUnit.DURATION, TimeUnit.EPOCH]:
                raise ValueError(f"Invalid unit string for parameter {time_attr}: {time_obj.unit}")

        if self.update_interval.unit not in [TimeUnit.BATCH, TimeUnit.EPOCH]:
            raise ValueError(f"Invalid unit string for parameter update_interval: "
                             f"{self.update_interval.unit}")

        # Check time objects have valid values.
        if self.swa_start.unit == TimeUnit.DURATION:
            if self.swa_start < 0 or self.swa_start >= 1:
                raise ValueError("If swa_start is specified in units of 'dur', it must "
                                 "be in the interval [0, 1).")

        if self.swa_end.unit == TimeUnit.DURATION:
            if self.swa_end == 1:
                log.warning("'swa_end' = '1dur'. Batch norm statistics of averaged model "
                            "will not be updated. This will negatively impact accuracy. "
                            "See the documentation for the `swa_end` parameter for details.")
            if self.swa_end > 1:
                raise ValueError("If swa_end is specified in units of 'dur', it must be ≤1.")

        if self.update_interval < 1:
            raise ValueError("update_interval must be ≥ 1.")

        if anneal_steps <= 0:
            raise ValueError("anneal_steps must be greater than 0")

        # Normalize the annealing strategy to the names SWALR accepts.
        if self.anneal_strategy.lower() in ["linear", "lin"]:
            self.anneal_strategy = "linear"
        elif self.anneal_strategy.lower() in ["cos", "cosine"]:
            self.anneal_strategy = "cos"
        else:
            raise ValueError("Parameter 'anneal_strategy' must have an argument that is one of {'linear', 'cos'}.")

        self.swa_scheduler: Optional[SWALR] = None

        # Counts matched events so that we know when to update the averaged model.
        self.step_counter: Optional[int] = None

        # update_interval's unit was validated above, so it is BATCH or EPOCH here.
        if self.update_interval.unit == TimeUnit.BATCH:
            self.match_event = Event.BATCH_END
        else:
            self.match_event = Event.EPOCH_END

    @staticmethod
    def _parse_timestring(timestring: str, param_name: str) -> Time:
        """Parse ``timestring`` into a :class:`Time`, naming ``param_name`` on failure.

        Raises:
            ValueError: If ``Time.from_timestring`` rejects ``timestring``.
        """
        try:
            return Time.from_timestring(timestring)
        except ValueError as error:
            raise ValueError(f"Invalid time string for parameter {param_name}") from error

    def match(self, event: Event, state: State) -> bool:
        """Run on ``self.match_event`` once ``swa_start`` has been reached, until SWA completes."""
        # Cheap checks first, so elapsed time is not computed for unrelated events.
        if event != self.match_event or self.swa_completed:
            return False
        if self.swa_start.unit == TimeUnit.DURATION:
            return state.get_elapsed_duration() >= self.swa_start
        if self.swa_start.unit == TimeUnit.EPOCH:
            return state.timer.get("ep") >= self.swa_start
        return False

    def _build_swa_scheduler(self, state: State) -> SWALR:
        """Create the SWA LR scheduler for the single optimizer in ``state``.

        If ``self.swa_lr`` is ``None``, it is set from the last LR of the single
        scheduler in ``state``.

        Raises:
            RuntimeError: If there is not exactly one scheduler (when ``swa_lr`` is
                ``None``), the scheduler reports more than one LR, or there is not
                exactly one optimizer.
        """
        if self.swa_lr is None:
            if len(state.schedulers) != 1:
                raise RuntimeError("SWA supports only one scheduler")
            last_lr = state.schedulers[0].get_last_lr()
            if len(last_lr) != 1:
                raise RuntimeError(f"SWA supports only one LR; instead found {len(last_lr)}")
            log.info(f'Setting SWA LR to {last_lr}')
            self.swa_lr = last_lr[0]

        if len(state.optimizers) != 1:
            raise RuntimeError("SWA supports one and only one optimizer")

        return SWALR(
            state.optimizers[0],
            swa_lr=self.swa_lr,
            anneal_epochs=self.anneal_steps,
            anneal_strategy=self.anneal_strategy,
        )

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        """Update the averaged model and, once ``swa_end`` is reached, swap it in.

        On every ``update_interval.value``-th matched event (starting with the
        first), the averaged model's parameters are updated from ``state.model``.
        If enabled, the SWA LR scheduler is stepped at the same time. When
        ``swa_end`` is reached, the averaged weights are loaded into ``state.model``
        and SWA is marked complete.
        """
        if self.step_counter is None:
            self.step_counter = 0

        if self.swa_scheduler is None and self.schedule_swa_lr:
            self.swa_scheduler = self._build_swa_scheduler(state)

        if self.step_counter % self.update_interval.value == 0:
            if self.swa_model is None:
                self.swa_model = AveragedModel(state.model)

            self.swa_model.update_parameters(state.model)  # type: ignore

            if self.schedule_swa_lr:
                if self.swa_scheduler is None:
                    raise ValueError('SWA LR scheduler was not set.')
                self.swa_scheduler.step()

        self.step_counter += 1

        # Determine whether it's time to end SWA.
        if self.swa_end.unit == TimeUnit.DURATION and (state.get_elapsed_duration() >= self.swa_end):
            self.swa_completed = True
        if self.swa_end.unit == TimeUnit.EPOCH and (state.timer.get("ep") >= self.swa_end):
            self.swa_completed = True

        if self.swa_completed:
            if state.get_elapsed_duration() == 1:
                log.warning("The baseline model was replaced with the SWA model after the end of "
                            "training. This means that SWA model will not have its batch norm "
                            "statistics updated. This will negatively impact accuracy. See the "
                            "documentation for the `swa_end` parameter for details.")
            # AveragedModel wraps the original module; copy its averaged weights back.
            state.model.load_state_dict(self.swa_model.module.state_dict())  # type: ignore
            log.info('Set model to the averaged model')