# Copyright 2021 MosaicML. All Rights Reserved.
"""Hyperparameters for callbacks."""
from __future__ import annotations

import abc
import textwrap
from dataclasses import dataclass
from typing import Optional

import yahp as hp

from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.callbacks.grad_monitor import GradMonitor
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.speed_monitor import SpeedMonitor
from composer.core.callback import Callback
from composer.core.time import Time
from composer.utils import import_object

__all__ = [
    "CallbackHparams",
    "GradMonitorHparams",
    "MemoryMonitorHparams",
    "LRMonitorHparams",
    "SpeedMonitorHparams",
    "CheckpointSaverHparams",
]
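
# A minimal usage sketch (assuming only the classes defined in this module):
# each ``*Hparams`` dataclass is a serializable description of one callback,
# and ``initialize_object()`` turns it into the live ``Callback`` instance.
#
#     lr_monitor = LRMonitorHparams().initialize_object()  # -> LRMonitor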


@dataclass
class CallbackHparams(hp.Hparams, abc.ABC):
    """Base class for Callback hyperparameters."""

    @abc.abstractmethod
    def initialize_object(self) -> Callback:
        """Initialize the callback.

        Returns:
            Callback: An instance of the callback.
        """
        pass


@dataclass
class GradMonitorHparams(CallbackHparams):
    """:class:`~.GradMonitor` hyperparameters.

    Args:
        log_layer_grad_norms (bool, optional):
            See :class:`~.GradMonitor` for documentation.
    """

    log_layer_grad_norms: bool = hp.optional(
        doc="Whether to log gradient norms for individual layers.",
        default=False,
    )

    def initialize_object(self) -> GradMonitor:
        """Initialize the GradMonitor callback.

        Returns:
            GradMonitor: An instance of :class:`~.GradMonitor`.
        """
        return GradMonitor(log_layer_grad_norms=self.log_layer_grad_norms)
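
# Sketch: constructing the hparams object directly and initializing the
# callback. Passing ``log_layer_grad_norms=True`` is an illustrative choice,
# not the default.
#
#     grad_monitor = GradMonitorHparams(log_layer_grad_norms=True).initialize_object()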


@dataclass
class MemoryMonitorHparams(CallbackHparams):
    """:class:`~.MemoryMonitor` hyperparameters.

    There are no hyperparameters, as :class:`~.MemoryMonitor` takes no arguments.
    """

    def initialize_object(self) -> MemoryMonitor:
        """Initialize the MemoryMonitor callback.

        Returns:
            MemoryMonitor: An instance of :class:`~.MemoryMonitor`.
        """
        return MemoryMonitor()


@dataclass
class LRMonitorHparams(CallbackHparams):
    """:class:`~.LRMonitor` hyperparameters.

    There are no hyperparameters, as :class:`~.LRMonitor` takes no arguments.
    """

    def initialize_object(self) -> LRMonitor:
        """Initialize the LRMonitor callback.

        Returns:
            LRMonitor: An instance of :class:`~.LRMonitor`.
        """
        return LRMonitor()


@dataclass
class SpeedMonitorHparams(CallbackHparams):
    """:class:`~.SpeedMonitor` hyperparameters.

    Args:
        window_size (int, optional): See :class:`~.SpeedMonitor` for documentation.
    """

    window_size: int = hp.optional(
        doc="Number of batches to use for a rolling average of throughput.",
        default=100,
    )

    def initialize_object(self) -> SpeedMonitor:
        """Initialize the SpeedMonitor callback.

        Returns:
            SpeedMonitor: An instance of :class:`~.SpeedMonitor`.
        """
        return SpeedMonitor(window_size=self.window_size)
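
# Sketch: ``window_size`` trades smoothness for responsiveness of the reported
# throughput; ``window_size=20`` below is illustrative, not a recommendation.
#
#     speed_monitor = SpeedMonitorHparams(window_size=20).initialize_object()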


@dataclass
class CheckpointSaverHparams(CallbackHparams):
    """:class:`~.CheckpointSaver` hyperparameters.

    Args:
        save_folder (str, optional): See :class:`~.CheckpointSaver`.
        filename (str, optional): See :class:`~.CheckpointSaver`.
        artifact_name (str, optional): See :class:`~.CheckpointSaver`.
        latest_filename (str, optional): See :class:`~.CheckpointSaver`.
        overwrite (bool, optional): See :class:`~.CheckpointSaver`.
        weights_only (bool, optional): See :class:`~.CheckpointSaver`.
        num_checkpoints_to_keep (int, optional): See :class:`~.CheckpointSaver`.
        save_interval (str, optional): Either a :doc:`time-string </trainer/time>` or a path to a function.
            If a :doc:`time-string </trainer/time>`, checkpoints will be saved according to this interval.
            If a path to a function, it should be of the format ``'path.to.module:function_name'``. The function
            should take (:class:`~.State`, :class:`~.Event`) and return a
            boolean indicating whether a checkpoint should be saved given the current state and event. The event will
            be either :attr:`~composer.core.event.Event.BATCH_CHECKPOINT` or
            :attr:`~composer.core.event.Event.EPOCH_CHECKPOINT`.
    """
    save_folder: str = hp.optional(doc="Folder where checkpoints will be saved.", default="{run_name}/checkpoints")
    filename: str = hp.optional(doc="Checkpoint name format string.", default="ep{epoch}-ba{batch}-rank{rank}")
    artifact_name: str = hp.optional(doc="Checkpoint artifact name format string.",
                                     default="{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}")
    latest_filename: Optional[str] = hp.optional(doc="Latest checkpoint symlink format string.",
                                                 default="latest-rank{rank}")
    overwrite: bool = hp.optional(doc="Whether to overwrite existing checkpoints.", default=False)
    weights_only: bool = hp.optional(doc="Whether to save only the model weights (and not the full training state).",
                                     default=False)
    save_interval: str = hp.optional(doc=textwrap.dedent("""\
        Checkpoint interval: either a time-string, or a path to a `(State, Event) -> bool` function
        returning whether a checkpoint should be saved."""),
                                     default="1ep")
    num_checkpoints_to_keep: int = hp.optional(
        doc="Number of checkpoints to persist locally. Set to -1 to never delete checkpoints.",
        default=-1,
    )
    def initialize_object(self) -> CheckpointSaver:
        """Initialize the CheckpointSaver callback.

        Returns:
            CheckpointSaver: An instance of :class:`~.CheckpointSaver`.
        """
        try:
            save_interval = Time.from_timestring(self.save_interval)
        except ValueError:
            # Not a valid time-string, so assume it is a path to a function
            # of the form ``'path.to.module:function_name'``.
            save_interval = import_object(self.save_interval)
        return CheckpointSaver(
            folder=self.save_folder,
            filename=self.filename,
            artifact_name=self.artifact_name,
            latest_filename=self.latest_filename,
            overwrite=self.overwrite,
            save_interval=save_interval,
            weights_only=self.weights_only,
            num_checkpoints_to_keep=self.num_checkpoints_to_keep,
        )
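
# Sketch of the two accepted ``save_interval`` forms. A time-string is parsed
# with ``Time.from_timestring``; anything else is treated as a path to a
# ``(State, Event) -> bool`` function (``my_project.ckpt_rules:save_rule`` is
# a hypothetical path, not a real module).
#
#     every_epoch = CheckpointSaverHparams(save_interval="1ep").initialize_object()
#     custom = CheckpointSaverHparams(save_interval="my_project.ckpt_rules:save_rule").initialize_object()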