Source code for composer.callbacks.early_stopper

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Early stopping callback."""

from __future__ import annotations

import logging
from typing import Any, Callable, Optional, Union

import torch

from composer.core import Callback, State, Time, TimeUnit
from composer.loggers import Logger

log = logging.getLogger(__name__)

__all__ = ['EarlyStopper']


class EarlyStopper(Callback):
    """Track a metric and halt training if it does not improve within a given interval.

    Example:
        .. doctest::

            >>> from composer import Evaluator, Trainer
            >>> from composer.callbacks.early_stopper import EarlyStopper
            >>> # constructing trainer object with this callback
            >>> early_stopper = EarlyStopper('MulticlassAccuracy', 'my_evaluator', patience=1)
            >>> evaluator = Evaluator(
            ...     dataloader = eval_dataloader,
            ...     label = 'my_evaluator',
            ...     metric_names = ['MulticlassAccuracy']
            ... )
            >>> trainer = Trainer(
            ...     model=model,
            ...     train_dataloader=train_dataloader,
            ...     eval_dataloader=evaluator,
            ...     optimizers=optimizer,
            ...     max_duration="1ep",
            ...     callbacks=[early_stopper],
            ... )

    Args:
        monitor (str): The name of the metric to monitor.
        dataloader_label (str): The label of the dataloader or evaluator associated with the tracked metric.

            If ``monitor`` is in an :class:`.Evaluator`, the ``dataloader_label`` field should be set to the label of
            the :class:`.Evaluator`. If monitor is a training metric or an ordinary evaluation metric not in an
            :class:`.Evaluator`, the ``dataloader_label`` should be set to the dataloader label, which defaults to
            ``'train'`` or ``'eval'``, respectively.
        comp (str | (Any, Any) -> Any, optional): A comparison operator to measure change of the monitored metric.
            The comparison operator will be called ``comp(current_value, prev_best)``. For metrics where the optimal
            value is low (error, loss, perplexity), use a less than operator, and for metrics like accuracy where the
            optimal value is higher, use a greater than operator. Defaults to :func:`torch.less` if loss, error, or
            perplexity are substrings of the monitored metric, otherwise defaults to :func:`torch.greater`.
        min_delta (float, optional): An optional float that requires a new value to exceed the best value by at least
            that amount. Default: ``0.0``.
        patience (Time | int | str, optional): The interval of time the monitored metric can not improve without
            stopping training. Default: 1 epoch. If patience is an integer, it is interpreted as the number of epochs.
    """

    def __init__(
        self,
        monitor: str,
        dataloader_label: str,
        comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
        min_delta: float = 0.0,
        patience: Union[int, str, Time] = 1,
    ):
        self.monitor = monitor
        self.dataloader_label = dataloader_label
        # min_delta is a magnitude; the direction of improvement comes from comp_func.
        self.min_delta = abs(min_delta)

        # Resolve the comparison operator. Fail fast on an unsupported ``comp``:
        # previously a non-callable, non-str, non-None value silently left
        # ``comp_func`` unset, deferring the failure to an opaque AttributeError
        # during evaluation.
        if callable(comp):
            self.comp_func = comp
        elif isinstance(comp, str):
            if comp.lower() in ('greater', 'gt'):
                self.comp_func = torch.greater
            elif comp.lower() in ('less', 'lt'):
                self.comp_func = torch.less
            else:
                raise ValueError(
                    "Unrecognized comp string. Use the strings 'gt', 'greater', 'lt' or 'less' or a callable comparison operator",
                )
        elif comp is None:
            # Infer direction from the metric name: lower is better for losses.
            if any(substr in monitor.lower() for substr in ['loss', 'error', 'perplexity']):
                self.comp_func = torch.less
            else:
                self.comp_func = torch.greater
        else:
            raise ValueError(
                "Unrecognized comp. Use the strings 'gt', 'greater', 'lt' or 'less' or a callable comparison operator",
            )

        # Best metric value seen so far and the timestamp at which it occurred.
        self.best: Optional[torch.Tensor] = None
        self.best_occurred = None

        if isinstance(patience, str):
            self.patience = Time.from_timestring(patience)
        elif isinstance(patience, int):
            self.patience = Time(patience, TimeUnit.EPOCH)
        else:
            self.patience = patience
            if self.patience.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH):
                raise ValueError('If `patience` is an instance of Time, it must have units of EPOCH or BATCH.')

    def _get_monitored_metric(self, state: State):
        """Return the current computed value of the monitored metric.

        Raises:
            ValueError: If no metric named ``self.monitor`` exists under ``self.dataloader_label``.
        """
        if self.dataloader_label == 'train' and state.train_metrics is not None:
            if self.monitor in state.train_metrics:
                return state.train_metrics[self.monitor].compute()
        # Guard the label lookup: an unknown dataloader_label previously raised an
        # opaque KeyError here before the helpful ValueError below could be reached.
        elif self.dataloader_label in state.eval_metrics:
            if self.monitor in state.eval_metrics[self.dataloader_label]:
                return state.eval_metrics[self.dataloader_label][self.monitor].compute()
        raise ValueError(
            f"Couldn't find the metric {self.monitor} with the dataloader label {self.dataloader_label}. "
            "Check that the dataloader_label is set to 'eval', 'train' or the evaluator name.",
        )

    def _update_stopper_state(self, state: State):
        """Record the metric, update the running best, and stop training if patience is exhausted."""
        metric_val = self._get_monitored_metric(state)

        if not torch.is_tensor(metric_val):
            metric_val = torch.tensor(metric_val)

        if self.best is None:
            self.best = metric_val
            self.best_occurred = state.timestamp
        # A new best must both improve per comp_func AND beat the old best by more than min_delta.
        elif self.comp_func(metric_val, self.best) and torch.abs(metric_val - self.best) > self.min_delta:
            self.best = metric_val
            self.best_occurred = state.timestamp
        assert self.best_occurred is not None

        if self.patience.unit == TimeUnit.EPOCH:
            if state.timestamp.epoch - self.best_occurred.epoch > self.patience:
                state.stop_training()
        elif self.patience.unit == TimeUnit.BATCH:
            if state.timestamp.batch - self.best_occurred.batch > self.patience:
                state.stop_training()
        else:
            # Unreachable given the validation in __init__, kept as a defensive check.
            raise ValueError('The units of `patience` should be EPOCH or BATCH.')

    def eval_end(self, state: State, logger: Logger) -> None:
        if self.dataloader_label == state.dataloader_label:
            # if the monitored metric is an eval metric or in an evaluator
            self._update_stopper_state(state)

    def epoch_end(self, state: State, logger: Logger) -> None:
        if self.dataloader_label == state.dataloader_label:
            # if the monitored metric is not an eval metric, the right logic is run on EPOCH_END
            self._update_stopper_state(state)

    def batch_end(self, state: State, logger: Logger) -> None:
        if self.patience.unit == TimeUnit.BATCH and self.dataloader_label == state.dataloader_label:
            self._update_stopper_state(state)