Source code for composer.optim.optimizer_hparams

# Copyright 2021 MosaicML. All Rights Reserved.

"""Hyperparameters for optimizers."""

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Dict, Iterable, List, Type, Union

import torch
import torch_optimizer
import yahp as hp
from torch.optim import Optimizer

from composer.optim import DecoupledAdamW, DecoupledSGDW

# Optimizer parameters and defaults match those in torch.optim

__all__ = [
    "OptimizerHparams", "AdamHparams", "RAdamHparams", "AdamWHparams", "DecoupledAdamWHparams", "SGDHparams",
    "DecoupledSGDWHparams", "RMSpropHparams"
]


@dataclass
class OptimizerHparams(hp.Hparams, ABC):
    """Base class for optimizer hyperparameter classes.

    Optimizer parameters that are added to :class:`~composer.trainer.trainer_hparams.TrainerHparams`
    (e.g. via YAML or the CLI) are initialized in the training loop.
    """

    @property
    @abstractmethod
    def optimizer_object(cls) -> Type[Optimizer]:
        pass

    def initialize_object(
        self,
        param_group: Union[Iterable[torch.Tensor], Iterable[Dict[str, torch.Tensor]]],
    ) -> Optimizer:
        """Initializes the optimizer.

        Args:
            param_group (Iterable[torch.Tensor] | Iterable[Dict[str, torch.Tensor]]): Parameters for
                this optimizer to optimize.
        """
        assert issubclass(self.optimizer_object, torch.optim.Optimizer)
        return self.optimizer_object(param_group, **asdict(self))
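
A minimal usage sketch (not part of the module source): a concrete subclass such as ``SGDHparams`` (defined below) is constructed like any dataclass and then turned into a ``torch.optim`` optimizer via ``initialize_object``. The tiny model here is purely illustrative:

    import torch
    import torch.nn as nn

    from composer.optim.optimizer_hparams import SGDHparams

    model = nn.Linear(16, 4)  # stand-in for a real model

    # Build the hyperparameter dataclass, then materialize the optimizer.
    sgd_hparams = SGDHparams(lr=0.1, momentum=0.9)
    optimizer = sgd_hparams.initialize_object(model.parameters())
    assert isinstance(optimizer, torch.optim.SGD)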


@dataclass
class AdamHparams(OptimizerHparams):
    """Hyperparameters for the :class:`~torch.optim.Adam` optimizer.

    See :class:`~torch.optim.Adam` for documentation.

    Args:
        lr (float, optional): See :class:`~torch.optim.Adam`.
        betas (List[float], optional): See :class:`~torch.optim.Adam`.
        eps (float, optional): See :class:`~torch.optim.Adam`.
        weight_decay (float, optional): See :class:`~torch.optim.Adam`.
        amsgrad (bool, optional): See :class:`~torch.optim.Adam`.
    """

    lr: float = hp.optional(default=0.001, doc="learning rate")
    betas: List[float] = hp.optional(default_factory=lambda: [0.9, 0.999],
                                     doc="coefficients used for computing running averages of gradient and its square.")
    eps: float = hp.optional(default=1e-8, doc="term for numerical stability")
    weight_decay: float = hp.optional(default=0.0, doc="weight decay (L2 penalty)")
    amsgrad: bool = hp.optional(default=False, doc="use AMSGrad variant")

    @property
    def optimizer_object(cls) -> Type[torch.optim.Adam]:
        return torch.optim.Adam


@dataclass
class RAdamHparams(OptimizerHparams):
    """Hyperparameters for the :class:`~torch.optim.RAdam` optimizer.

    See :class:`~torch.optim.RAdam` for documentation.

    Args:
        lr (float, optional): See :class:`~torch.optim.RAdam`.
        betas (List[float], optional): See :class:`~torch.optim.RAdam`.
        eps (float, optional): See :class:`~torch.optim.RAdam`.
        weight_decay (float, optional): See :class:`~torch.optim.RAdam`.
    """

    lr: float = hp.optional(default=0.001, doc="learning rate")
    betas: List[float] = hp.optional(default_factory=lambda: [0.9, 0.999],
                                     doc="coefficients used for computing running averages of gradient and its square.")
    eps: float = hp.optional(default=1e-8, doc="term for numerical stability")
    weight_decay: float = hp.optional(default=0.0, doc="weight decay (L2 penalty)")

    @property
    def optimizer_object(cls) -> Type[torch_optimizer.RAdam]:
        return torch_optimizer.RAdam


@dataclass
class AdamWHparams(OptimizerHparams):
    """Hyperparameters for the :class:`~torch.optim.AdamW` optimizer.

    See :class:`~torch.optim.AdamW` for documentation.

    Args:
        lr (float, optional): See :class:`~torch.optim.AdamW`.
        betas (List[float], optional): See :class:`~torch.optim.AdamW`.
        eps (float, optional): See :class:`~torch.optim.AdamW`.
        weight_decay (float, optional): See :class:`~torch.optim.AdamW`.
        amsgrad (bool, optional): See :class:`~torch.optim.AdamW`.
    """

    lr: float = hp.optional(default=0.001, doc="learning rate")
    betas: List[float] = hp.optional(default_factory=lambda: [0.9, 0.999],
                                     doc="coefficients used for computing running averages of gradient and its square.")
    eps: float = hp.optional(default=1e-8, doc="term for numerical stability")
    weight_decay: float = hp.optional(default=1e-2, doc="weight decay (L2 penalty)")
    amsgrad: bool = hp.optional(default=False, doc="use AMSGrad variant")

    @property
    def optimizer_object(cls) -> Type[torch.optim.AdamW]:
        return torch.optim.AdamW


@dataclass
class DecoupledAdamWHparams(OptimizerHparams):
    """Hyperparameters for the :class:`~.DecoupledAdamW` optimizer.

    See :class:`~.DecoupledAdamW` for documentation.

    Args:
        lr (float, optional): See :class:`~.DecoupledAdamW`.
        betas (List[float], optional): See :class:`~.DecoupledAdamW`.
        eps (float, optional): See :class:`~.DecoupledAdamW`.
        weight_decay (float, optional): See :class:`~.DecoupledAdamW`.
        amsgrad (bool, optional): See :class:`~.DecoupledAdamW`.
    """

    lr: float = hp.optional(default=0.001, doc="learning rate")
    betas: List[float] = hp.optional(default_factory=lambda: [0.9, 0.999],
                                     doc="coefficients used for computing running averages of gradient and its square.")
    eps: float = hp.optional(default=1e-8, doc="term for numerical stability")
    weight_decay: float = hp.optional(default=1e-2, doc="weight decay (L2 penalty)")
    amsgrad: bool = hp.optional(default=False, doc="use AMSGrad variant")

    @property
    def optimizer_object(cls) -> Type[DecoupledAdamW]:
        return DecoupledAdamW


@dataclass
class SGDHparams(OptimizerHparams):
    """Hyperparameters for the :class:`~torch.optim.SGD` optimizer.

    See :class:`~torch.optim.SGD` for documentation.

    Args:
        lr (float): See :class:`~torch.optim.SGD`.
        momentum (float, optional): See :class:`~torch.optim.SGD`.
        weight_decay (float, optional): See :class:`~torch.optim.SGD`.
        dampening (float, optional): See :class:`~torch.optim.SGD`.
        nesterov (bool, optional): See :class:`~torch.optim.SGD`.
    """

    lr: float = hp.required(doc="learning rate")
    momentum: float = hp.optional(default=0.0, doc="momentum factor")
    weight_decay: float = hp.optional(default=0.0, doc="weight decay (L2 penalty)")
    dampening: float = hp.optional(default=0.0, doc="dampening for momentum")
    nesterov: bool = hp.optional(default=False, doc="Nesterov momentum")

    @property
    def optimizer_object(cls) -> Type[torch.optim.SGD]:
        return torch.optim.SGD
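
Since ``initialize_object`` also accepts an iterable of parameter-group dicts, per-group overrides follow the usual ``torch.optim`` convention. A hedged sketch with an illustrative model:

    import torch.nn as nn

    from composer.optim.optimizer_hparams import SGDHparams

    model = nn.Linear(16, 4)
    param_groups = [
        {"params": [model.weight], "lr": 0.05},  # per-group learning-rate override
        {"params": [model.bias]},                # inherits lr=0.1 from the hparams
    ]
    optimizer = SGDHparams(lr=0.1).initialize_object(param_groups)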


@dataclass
class DecoupledSGDWHparams(OptimizerHparams):
    """Hyperparameters for the :class:`~.DecoupledSGDW` optimizer.

    See :class:`~.DecoupledSGDW` for documentation.

    Args:
        lr (float): See :class:`~.DecoupledSGDW`.
        momentum (float, optional): See :class:`~.DecoupledSGDW`.
        weight_decay (float, optional): See :class:`~.DecoupledSGDW`.
        dampening (float, optional): See :class:`~.DecoupledSGDW`.
        nesterov (bool, optional): See :class:`~.DecoupledSGDW`.
    """

    lr: float = hp.required(doc="learning rate")
    momentum: float = hp.optional(default=0.0, doc="momentum factor")
    weight_decay: float = hp.optional(default=0.0, doc="weight decay (L2 penalty)")
    dampening: float = hp.optional(default=0.0, doc="dampening for momentum")
    nesterov: bool = hp.optional(default=False, doc="Nesterov momentum")

    @property
    def optimizer_object(cls) -> Type[DecoupledSGDW]:
        return DecoupledSGDW


@dataclass
class RMSpropHparams(OptimizerHparams):
    """Hyperparameters for the :class:`~torch.optim.RMSprop` optimizer.

    See :class:`~torch.optim.RMSprop` for documentation.

    Args:
        lr (float): See :class:`~torch.optim.RMSprop`.
        alpha (float, optional): See :class:`~torch.optim.RMSprop`.
        eps (float, optional): See :class:`~torch.optim.RMSprop`.
        momentum (float, optional): See :class:`~torch.optim.RMSprop`.
        weight_decay (float, optional): See :class:`~torch.optim.RMSprop`.
        centered (bool, optional): See :class:`~torch.optim.RMSprop`.
    """

    lr: float = hp.required(doc="learning rate")
    alpha: float = hp.optional(default=0.99, doc="smoothing constant")
    eps: float = hp.optional(default=1e-8, doc="term for numerical stability")
    momentum: float = hp.optional(default=0.0, doc="momentum factor")
    weight_decay: float = hp.optional(default=0.0, doc="weight decay (L2 penalty)")
    centered: bool = hp.optional(
        default=False,
        doc="normalize gradient by an estimation of variance",
    )

    @property
    def optimizer_object(cls) -> Type[torch.optim.RMSprop]:
        return torch.optim.RMSprop
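
Because every hparams class above is a plain dataclass, a configuration that would normally arrive via YAML or the CLI (handled by yahp) can be sketched with an ordinary dict; the values below are illustrative only:

    import torch.nn as nn

    from composer.optim import DecoupledAdamW
    from composer.optim.optimizer_hparams import DecoupledAdamWHparams

    model = nn.Linear(16, 4)

    # Hypothetical values; in practice yahp populates the dataclass from YAML/CLI.
    adamw_config = {"lr": 1e-3, "betas": [0.9, 0.95], "weight_decay": 1e-2}
    optimizer = DecoupledAdamWHparams(**adamw_config).initialize_object(model.parameters())
    assert isinstance(optimizer, DecoupledAdamW)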