Source code for composer.optim.scheduler

# Copyright 2021 MosaicML. All Rights Reserved.

import logging
import re
from abc import ABC
from dataclasses import asdict, dataclass, fields
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import torch
import yahp as hp
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, MultiStepLR, StepLR, _LRScheduler

from composer.core.types import Optimizer, Scheduler
from composer.optim.pytorch_future import WarmUpLR

log = logging.getLogger(__name__)

Time = str
"""
Time: For scheduler hparams, we support providing time (e.g. milestones) as
both integers, which will be interpreted as epochs, or as a string in format:
* '12ep' -- 12 epochs
* '1024ba' -- 1024 batches
* '12ep32ba' -- 12 epochs and 32 batches

The provided time is converted and represented internally.
"""

_interval_doc = 'frequency of step() calls, either "batch" or "epoch". Default: "epoch"'

STR_REGEX = re.compile(r'^(?:([0-9]*)(ep))?(?:([0-9]*)(ba))?$', flags=re.IGNORECASE)

# Allow (batch, batches) or (epoch, epochs). Also accept "step" ~ "batch"
INTERVAL_MAP = {
    'batch': 'batch',
    'batches': 'batch',
    'epoch': 'epoch',
    'epochs': 'epoch',
    'step': 'batch',
    'steps': 'batch'
}


def _parse_time_string(timestring: str) -> Tuple[int, int]:
    """Parse timestring to (epoch, batches).

    Args:
        timestring (str): String in the format 'XXepYYba'.

    Returns:
        tuple: (epochs, batches)

    Raises:
        ValueError: The timestring is invalid

    Examples:
        >>> _parse_time_string('32ep173ba')
        (32, 173)
        >>> _parse_time_string('12ep')
        (12, 0)
        >>> _parse_time_string('1024ba')
        (0, 1024)
    """

    match = STR_REGEX.findall(timestring)
    if len(match) != 1:
        raise ValueError(f'Invalid timestring: {timestring}. Should be of format 32ep15ba, or 99ba or 7ep')
    match = match[0]

    epochs = 0 if 'ep' not in match else int(match[match.index('ep') - 1])
    batches = 0 if 'ba' not in match else int(match[match.index('ba') - 1])

    return epochs, batches


def _convert_time(time: Time, steps_per_epoch: Optional[int] = None, interval: str = 'epoch') -> int:
    """Convert time to either batches or epochs (based on interval argument)."""
    if isinstance(time, int):
        return time
    if steps_per_epoch is None:
        raise ValueError('steps_per_epoch must be provided to parse time string.')

    epochs, batches = _parse_time_string(time)
    if interval in ('batches', 'batch', 'steps', 'step'):
        log.info(f'Converting {time}, {interval} to {batches + epochs * steps_per_epoch}')
        return batches + epochs * steps_per_epoch
    elif interval in ('epochs', 'epoch'):
        epochs = epochs + batches // steps_per_epoch
        batches = batches % steps_per_epoch
        if batches != 0:
            log.warning('Scheduler is stepping every epoch, but provided timestring '
                        f'{time} had batches. Ignoring the batches term.')
        log.info(f'Converting {time}, {interval} to {epochs}')
        return epochs
    else:
        raise ValueError('interval must be one of (batch, epoch)')


@dataclass
class SchedulerHparams(hp.Hparams, ABC):

    scheduler_object = None  # type: Optional[Callable[..., Scheduler]]
    interval = 'epochs'  # type: str

    def convert_time_fields(self, steps_per_epoch: Optional[int] = None) -> None:
        """Convert time fields into integers.

        Converts all fields that were provided as timestrings (e.g. "32ep11ba") into
        integers, representing either epochs or batches, depending on the
        scheduler's interval attribute.

        Examples:
            >>> hp = StepLRHparams(step_size='32ep77ba', interval='batch')
            >>> hp.convert_time_fields(steps_per_epoch=100)
            >>> hp.step_size
            3277
            >>> hp = StepLRHparams(step_size='32ep77ba', interval='epoch')
            >>> hp.convert_time_fields(steps_per_epoch=100)
            >>> hp.step_size
            32
            >>> hp = StepLRHparams(step_size=5, interval='epoch')
            >>> hp.convert_time_fields()  # steps_per_epoch not needed
            >>> hp.step_size
            5
            >>> hp = MultiStepLRHParams(milestones=['50ep', '8050ba'], interval='batch')
            >>> hp.convert_time_fields(steps_per_epoch=100)
            >>> hp.milestones
            [5000, 8050]
            >>> hp = MultiStepLRHParams(milestones=['50ep', '8050ba'], interval='epoch')
            >>> hp.convert_time_fields(steps_per_epoch=100)
            >>> hp.milestones
            [50, 80]

        Args:
            steps_per_epoch (int): used to convert between epochs <-> batches. Need not be
                                   provided if all fields are provided as integers.
        """
        assert hasattr(self, 'interval'), "Scheduler Hparams needs an interval (str) parameter."

        for field in fields(self):
            # TODO: switch Time back to Union[int, str]
            if field.name not in ('interval', 'warmup_method') and field.type == Time or field.type == List[Time]:
                time = getattr(self, field.name)
                if isinstance(time, list):
                    result = [_convert_time(t, steps_per_epoch, self.interval) for t in time]
                else:
                    result = _convert_time(time, steps_per_epoch, self.interval)

                setattr(self, field.name, result)

    def initialize_object(  # type: ignore
            self,
            optimizer: Optimizer,
            steps_per_epoch: Optional[int] = None,
    ) -> Tuple[Scheduler, str]:
        """Create the scheduler object from the current hparams.

        Args:
            optimizer (Optimizer): the optimizer associated with this scheduler
            steps_per_epoch (Optional[int], optional): number of steps per epoch. Default: ``None``.

        Returns:
            (Scheduler, str): (The parametrized scheduler instance, schedule step interval)
        """

        assert self.scheduler_object is not None, "Scheduler Hparams needs scheduler_object to initialize."
        assert hasattr(self, 'interval'), "Scheduler Hparams needs an interval (str) parameter."
        self.convert_time_fields(steps_per_epoch)

        # we pass the interval to the trainer directly
        kwargs = {k: v for k, v in asdict(self).items() if k not in ['interval']}
        obj = self.scheduler_object(optimizer, **kwargs)
        obj.interval = self.interval  # type: ignore
        obj.steps_per_epoch = steps_per_epoch  # type: ignore
        return obj, self.interval


[docs]class ConstantLR(_LRScheduler): """Scheduler that does not change the optimizer's learning rate. Args: optimizer (Optimizer): the optimizer associated with this scheduler. last_epoch (int, optional): The index of the last epoch. Can be used to restore the state of the learning rate schedule. Default: ``-1``. verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``. """ def __init__(self, optimizer: Optimizer, last_epoch: int = -1, verbose: int = False): self.optimizer = optimizer super(ConstantLR, self).__init__(optimizer, last_epoch, verbose) # type: ignore def get_lr(self): """ Get the current learning rate for each parameter group. Returns: List of float: The current learning rate for each parameter group. """ return self.base_lrs # type: ignore def _get_closed_form_lr(self): """ Get the current learning rate for each parameter group. Returns: List of float: The current learning rate for each parameter group. """ return [base_lr for base_lr in self.base_lrs] # type: ignore
@dataclass class ConstantLRHparams(SchedulerHparams): """Hyperparameters for the :class:`ConstantLR` scheduler.""" verbose: bool = hp.optional(default=False, doc='prints message to stdout') interval: str = hp.optional(default='epoch', doc=_interval_doc) scheduler_object = ConstantLR @dataclass class StepLRHparams(SchedulerHparams): """Hyperparameters for the `StepLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html#torch.optim.lr_scheduler.StepLR>`_ scheduler. """ step_size: Time = hp.required(doc='Period of learning rate decay') gamma: float = hp.optional(default=0.1, doc='multiplicative factor of decay') verbose: bool = hp.optional(default=False, doc='prints message to stdout') interval: str = hp.optional(default='epoch', doc=_interval_doc) scheduler_object = torch.optim.lr_scheduler.StepLR @dataclass class MultiStepLRHparams(SchedulerHparams): """Hyperparameters for the `MultiStepLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html#torch.optim.lr_scheduler.MultiStepLR>`_ scheduler. """ milestones: List[Time] = hp.required(doc='List of epoch indices') gamma: float = hp.optional(default=0.1, doc='multiplicative factor of decay') verbose: bool = hp.optional(default=False, doc='prints message to stdout') interval: str = hp.optional(default='epoch', doc=_interval_doc) scheduler_object = torch.optim.lr_scheduler.MultiStepLR @dataclass class ExponentialLRHparams(SchedulerHparams): """Hyperparameters for the `ExponentialLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html#torch.optim.lr_scheduler.ExponentialLR>`_ scheduler. """ gamma: float = hp.required(doc='multiplicative factor of decay') verbose: bool = hp.optional(default=False, doc='prints message to stdout') interval: str = hp.optional(default='epoch', doc=_interval_doc) scheduler_object = torch.optim.lr_scheduler.ExponentialLR @dataclass class CosineAnnealingLRHparams(SchedulerHparams): """Hyperparameters for the `CosineAnnealingLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR>`_ scheduler. """ T_max: Time = hp.required(doc="Maximum number of iterations.") eta_min: float = hp.optional(default=0.0, doc='minimum learning rate.') verbose: bool = hp.optional(default=False, doc='prints message to stdout') interval: str = hp.optional(default='epoch', doc=_interval_doc) scheduler_object = torch.optim.lr_scheduler.CosineAnnealingLR def initialize_object(self, optimizer: Optimizer, steps_per_epoch: Optional[int] = None): self.convert_time_fields(steps_per_epoch) return super().initialize_object(optimizer, steps_per_epoch) @dataclass class CosineAnnealingWarmRestartsHparams(SchedulerHparams): """Hyperparameters for the ``CosineAnnealingWarmRestarts` <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html#torch.optim.lr_scheduler.CosineAnnealingWarmRestarts>`_ scheduler. """ T_0: Time = hp.required("Number of iterations for the first restart.") eta_min: float = hp.optional(default=0.0, doc='minimum learning rate.') verbose: bool = hp.optional(default=False, doc='prints message to stdout') interval: str = hp.optional(default='epoch', doc=_interval_doc) T_mult: int = hp.optional("A factor increases :math:`T_{i}` after a restart. Default: 1.", default=1) scheduler_object = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts def initialize_object(self, optimizer: Optimizer, steps_per_epoch: Optional[int] = None): self.convert_time_fields(steps_per_epoch) return super().initialize_object(optimizer, steps_per_epoch) @dataclass class WarmUpLRHparams(SchedulerHparams): """Hyperparameters for the :class:`~composer.optim.pytorch_future.WarmUpLR` scheduler. See the documentation for :class:`~composer.optim.pytorch_future.WarmUpLR`. """ warmup_factor: float = hp.optional("Number to multiply learning rate at start.", default=1.0 / 3) warmup_iters: Time = hp.optional("Number of warmup step. Default: 5 iterations.", default="5ba") warmup_method: str = hp.optional("Warmup method (linear or constant)", default='linear') verbose: bool = hp.optional('Prints message to stdout', default=False) interval: str = hp.optional('Warmup the LR every step or epoch. Default: epoch', default='epoch') scheduler_object = WarmUpLR def ensure_warmup_last(schedulers: List[SchedulerHparams]) -> List[SchedulerHparams]: """Ensure that WarmUp-based schedulers appear last in the provided list. Args: schedulers (List[SchedulerHparams]): List of schedulers. Returns: List[SchedulerHparams]: A sorted list of schedulers with WarmUp-based schedulers at the end. """ return sorted(schedulers, key=lambda x: isinstance(x, (WarmUpLR, WarmUpLRHparams))) def get_num_warmup_batches(scheduler_hparams: Sequence[SchedulerHparams], steps_per_epoch: Optional[int] = None) -> int: """Gets the number of warmup steps declared by a list of schedulers. Args: scheduler_hparams (Sequence[SchedulerHparams]): List of schedulers steps_per_epoch (Optional[int], optional): Number of steps in a single epoch. Default: ``None``. Returns: int: Number of warmup steps """ warmup_scheduler_hparams = [scheduler for scheduler in scheduler_hparams if isinstance(scheduler, WarmUpLRHparams)] if len(warmup_scheduler_hparams): warmup_iters = warmup_scheduler_hparams[0].warmup_iters if isinstance(warmup_iters, str): return _convert_time( time=warmup_iters, steps_per_epoch=steps_per_epoch, interval=warmup_scheduler_hparams[0].interval, ) else: return warmup_iters return 0
[docs]class ComposedScheduler(_LRScheduler): """Handles warmup for a chained list of schedulers. With one call, will run each scheduler's ``step()``. If :class:`WarmUpLR` is in the list, will delay the stepping of schedulers that need to be silent during warmup. ``ComposedScheduler`` handles warmups, where as `ChainedScheduler <https://pytorch.org/docs/1.10./generated/torch.optim.lr_scheduler.ChainedScheduler.html?highlight=chained#torch.optim.lr_scheduler.ChainedScheduler>`_ only combines schedulers. `CosineAnnealingLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR>`_ and `ExponentialLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html#torch.optim.lr_scheduler.ExponentialLR>`_ are not stepped during the warmup period. Other schedulers, such as `MultiStepLR <https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html#torch.optim.lr_scheduler.MultiStepLR>`_ are still stepped, to keep their milestones unchanged. Handles running the :class:`WarmUpLR` at every step if :attr:`WarmUpLR.interval='batch'`, and other schedulers at every epoch. Args: schedulers (list): List of chained schedulers. Example: >>> # Assuming optimizer uses lr = 1. for all groups >>> # lr = 0.1 if epoch == 0 >>> # lr = 0.1 if epoch == 1 >>> # lr = 0.9 if epoch == 2 # ExponentialLR effect starts here >>> # lr = 0.81 if epoch == 3 >>> # lr = 0.729 if epoch == 4 >>> scheduler1 = WarmUpLR(self.opt, warmup_factor=0.1, warmup_iters=2, warmup_method="constant") >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) >>> scheduler = ComposedScheduler(zip([scheduler1, scheduler2], ["epoch", "epoch"])) >>> for epoch in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() >>> # Assuming optimizer uses lr = 1. for all groups >>> # lr = 0.1 if epoch == 0 >>> # lr = 0.1 if epoch == 1 >>> # lr = 1.0 if epoch == 2 >>> # lr = 1.0 if epoch == 3 >>> # lr = 0.2 if epoch == 4 . # MultiStepLR effect starts here >>> scheduler1 = WarmUpLR(self.opt, warmup_factor=0.1, warmup_iters=2, warmup_method="constant") >>> scheduler2 = MultiStepLR(optimizer, milestones=[4], gamma=0.2) >>> scheduler = ComposedScheduler(zip([scheduler1, scheduler2], ["epoch", "epoch"])) >>> for epoch in range(100): >>> train(...) >>> validate(...) >>> scheduler.step() """ def __init__(self, schedulers): # check for tuple if not all(isinstance(scheduler, tuple) for scheduler in schedulers): raise ValueError('Schedulers must be a tuple of (Scheduler, interval), ' 'where interval is one of "epoch" or "batch".') self._validate_same_optimizers(schedulers) self.schedulers, self.intervals = list(zip(*schedulers)) # unpack (scheduler, interval) # generous with spelling (batch, batches)/(step, steps) and (epoch, epochs) self.intervals = [INTERVAL_MAP[interval] for interval in self.intervals] warmup = [(scheduler, interval) for scheduler, interval in zip(self.schedulers, self.intervals) if isinstance(scheduler, WarmUpLR)] if warmup: assert len(warmup) == 1, "ComposedScheduler only supports one WarmUpLR " \ f"in the provided list, found {len(warmup)}." warmup, interval = warmup[0] self.warmup_iters = warmup.warmup_iters log.info(f'Setting LR Warmup to {self.warmup_iters} {interval}') else: self.warmup_iters = 0 # these schedulers need to be silent during warmup self.delay_schedulers = [CosineAnnealingLR, ExponentialLR] self._warmup_counter = 0 # counter to track warmups def step(self, interval: str = 'epoch'): """Step all applicable schedulers. Args: interval (str, optional): The interval of the current step. Must be either ``'step'`` or ``'epoch'``. Default: ``epoch``. """ for scheduler, scheduler_interval in zip(self.schedulers, self.intervals): if self._warmup_counter < self.warmup_iters and \ any(isinstance(scheduler, delay) for delay in self.delay_schedulers): continue if interval == scheduler_interval: scheduler.step() if isinstance(scheduler, WarmUpLR): self._warmup_counter += 1 def _validate_schedulers(self, warmup_epochs: int) -> None: """Verify that any stepwise schedulers do not change the LR during the desired warmup period. Args: warmup_epochs (int): Number of epochs for warmup. """ # since WarmUpLR is non-chainable form, step LR milestones must # occur after warmup is completed lr_step_schedulers = [ scheduler for scheduler in self.schedulers if isinstance(scheduler, (StepLR, MultiStepLR)) ] for scheduler in lr_step_schedulers: if isinstance(scheduler, StepLR) and scheduler.step_size <= warmup_epochs: # type: ignore raise ValueError(f'StepLR step_size {scheduler.step_size} must ' # type: ignore 'be greater than warmup_iters {self.warmup_iters}') elif isinstance(scheduler, MultiStepLR): if any(ms <= warmup_epochs for ms in scheduler.milestones.elements()): #type: ignore raise ValueError(f'MultiStepLR milestones must be greater than warmup_iters {warmup_epochs}') def state_dict(self) -> Dict[str, Any]: """Returns a dictionary containing the state of all composed schedulers. Returns: Dict: the state dictionary """ state_dict = { "schedulers": {scheduler.__class__.__qualname__: scheduler.state_dict() for scheduler in self.schedulers}, "_warmup_counter": self._warmup_counter, } return state_dict def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """Load the state of all composed schedulers from the provided dictionary. Args: state_dict (Dict[str, Any]): A dict containing the state of all composed schedulers. Should be an object returned from a call to :meth:`state_dict()`. """ for scheduler in self.schedulers: scheduler.load_state_dict(state_dict["schedulers"][scheduler.__class__.__qualname__]) self._warmup_counter = state_dict["_warmup_counter"] def _validate_same_optimizers(self, schedulers): """Verify that all schedulers correspond to the same optimizer.""" for scheduler_idx in range(1, len(schedulers)): if (schedulers[scheduler_idx][0].optimizer != schedulers[0][0].optimizer): # type: ignore raise ValueError("ComposedScheduler expects all schedulers to belong to the same optimizer, but " "got schedulers at index {} and {} to be different".format(0, scheduler_idx))