Source code for composer.callbacks.lr_monitor

# Copyright 2021 MosaicML. All Rights Reserved.

from composer.core import Callback, Logger, State
from composer.utils import ensure_tuple


class LRMonitor(Callback):
    """Logs the learning rate.

    The learning rate of every parameter group of every optimizer is logged
    once per batch under the ``lr-{OPTIMIZER_NAME}/group{GROUP_NUMBER}`` key,
    where ``OPTIMIZER_NAME`` is the optimizer's class name and
    ``GROUP_NUMBER`` is the index of the group in ``optimizer.param_groups``.
    """

    def __init__(self) -> None:
        super().__init__()

    def batch_end(self, state: State, logger: Logger) -> None:
        """Log the current learning rate of each optimizer parameter group.

        Args:
            state: Training state; its ``optimizers`` attribute must be set.
            logger: Logger whose ``metric_batch`` receives one
                single-entry dict per parameter group.

        Raises:
            AssertionError: If ``state.optimizers`` is ``None``.
            KeyError: If a parameter group has no ``'lr'`` entry.
        """
        assert state.optimizers is not None, "optimizers must be defined"
        # ensure_tuple lets a single optimizer and a collection of optimizers
        # be iterated the same way.
        for optimizer in ensure_tuple(state.optimizers):
            name = optimizer.__class__.__name__
            # Collect every rate up front so a group without an 'lr' entry
            # fails before anything is logged for this optimizer.
            lrs = [group['lr'] for group in optimizer.param_groups]
            # A single pass: each group is logged exactly once per batch.
            for idx, lr in enumerate(lrs):
                logger.metric_batch({f'lr-{name}/group{idx}': lr})