Source code for composer.algorithms.sam.sam

# Copyright 2021 MosaicML. All Rights Reserved.

from __future__ import annotations

import logging
from typing import Optional

import torch

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import ensure_tuple

log = logging.getLogger(__name__)


class SAMOptimizer(torch.optim.Optimizer):
    """Wraps an optimizer with sharpness-aware minimization (`Foret et al, 2020
    <https://arxiv.org/abs/2010.01412>`_). See :class:`SAM` for details.

    Implementation based on https://github.com/davda54/sam
    """

    def __init__(self, base_optimizer, rho: float = 0.05, epsilon: float = 1.0e-12, interval: int = 1, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        self.base_optimizer = base_optimizer
        self.global_step = 0
        self.interval = interval
        self._step_supports_amp_closure = True  # Flag for Composer trainer
        defaults = dict(rho=rho, epsilon=epsilon, **kwargs)
        super(SAMOptimizer, self).__init__(self.base_optimizer.param_groups, defaults)

    @torch.no_grad()
    def sub_e_w(self):
        for group in self.param_groups:
            for p in group["params"]:
                if "e_w" not in self.state[p]:
                    continue
                e_w = self.state[p]["e_w"]  # retrieve stale e(w)
                p.sub_(e_w)  # get back to "w" from "w + e(w)"

    @torch.no_grad()
    def first_step(self):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + group["epsilon"])
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

    @torch.no_grad()
    def second_step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or "e_w" not in self.state[p]:
                    continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
        self.base_optimizer.step()  # do the actual "sharpness-aware" update

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires a closure, but none was provided"

        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
        loss = None

        if (self.global_step + 1) % self.interval == 0:
            # Compute gradient at (w) per-GPU, and do not sync
            loss = closure(ddp_sync=False)  # type: ignore
            if loss:
                self.first_step()  # Compute e(w) and set weights to (w + e(w)) separately per-GPU
                if closure():  # Compute gradient at (w + e(w))
                    self.second_step()  # Reset weights to (w) and step base optimizer
                else:
                    self.sub_e_w()  # If second forward-backward closure fails, reset weights to (w)
        else:
            loss = closure()
            if loss:
                self.base_optimizer.step()

        self.global_step += 1
        return loss

    def _grad_norm(self):
        norm = torch.norm(torch.stack(
            [p.grad.norm(p=2) for group in self.param_groups for p in group["params"] if p.grad is not None]),
                          p="fro")
        return norm
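
# --- Example: driving SAMOptimizer directly. This is a minimal sketch, not
# part of the original module: the model, data, and loss function below are
# hypothetical placeholders. The Composer trainer normally supplies a closure
# that accepts a ``ddp_sync`` keyword, so this sketch's closure takes
# ``**kwargs`` to absorb it.
def _example_sam_step():
    import torch.nn.functional as F

    model = torch.nn.Linear(10, 1)
    inputs, targets = torch.randn(8, 10), torch.randn(8, 1)
    base = torch.optim.SGD(model.parameters(), lr=0.1)
    opt = SAMOptimizer(base_optimizer=base, rho=0.05, epsilon=1.0e-12, interval=1)

    def closure(**kwargs):
        opt.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        return loss

    # One SAM step runs two forward-backward passes: the ascent to
    # (w + e(w)) via first_step, then the descent step of the wrapped
    # base optimizer via second_step.
    return opt.step(closure=closure)
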
class SAM(Algorithm):
    """Adds sharpness-aware minimization (`Foret et al, 2020
    <https://arxiv.org/abs/2010.01412>`_) by wrapping an existing optimizer
    with a :class:`SAMOptimizer`.

    Args:
        rho (float, optional): The neighborhood size parameter of SAM. Must be greater than 0.
            Default: ``0.05``.
        epsilon (float, optional): A small value added to the gradient norm for numerical stability.
            Default: ``1e-12``.
        interval (int, optional): SAM will run once per ``interval`` steps. A value of 1 will cause SAM
            to run every step. Steps on which SAM runs take roughly twice as much time to complete.
            Default: ``1``.
    """

    def __init__(
        self,
        rho: float = 0.05,
        epsilon: float = 1.0e-12,
        interval: int = 1,
    ):
        """__init__ is constructed from the same fields as in hparams."""
        self.rho = rho
        self.epsilon = epsilon
        self.interval = interval
    def match(self, event: Event, state: State) -> bool:
        """Run on Event.INIT.

        Args:
            event (:class:`Event`): The current event.
            state (:class:`State`): The current state.

        Returns:
            bool: True if this algorithm should run now.
        """
        return event == Event.INIT
    def apply(self, event: Event, state: State, logger: Optional[Logger]) -> Optional[int]:
        """Applies SAM by wrapping the base optimizer with the SAM optimizer.

        Args:
            event (:class:`Event`): The current event.
            state (:class:`State`): The current trainer state.
            logger (:class:`Logger`): The training logger.
        """
        assert state.optimizers is not None

        state.optimizers = tuple(
            SAMOptimizer(
                base_optimizer=optimizer,
                rho=self.rho,
                epsilon=self.epsilon,
                interval=self.interval,
            ) for optimizer in ensure_tuple(state.optimizers))
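
# --- Example: enabling SAM through the Composer trainer. A hedged sketch,
# not part of the original module: ``make_model`` and ``make_dataloader`` are
# hypothetical placeholders for a ComposerModel and a dataloader. The SAM
# algorithm wraps ``state.optimizers`` with :class:`SAMOptimizer` at
# ``Event.INIT``, so no manual wrapping is needed.
if __name__ == "__main__":
    from composer.trainer import Trainer

    model = make_model()  # hypothetical: returns a ComposerModel
    train_dataloader = make_dataloader()  # hypothetical: returns a DataLoader

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="1ep",
        algorithms=[SAM(rho=0.05, epsilon=1.0e-12, interval=1)],
    )
    trainer.fit()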