# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Threshold stopping callback."""
from typing import Any, Callable, Optional, Union
import torch
from composer.core import Callback, State
from composer.loggers import Logger
[docs]class ThresholdStopper(Callback):
"""Halt training when a metric value reaches a certain threshold.
Example:
.. doctest::
>>> from composer import Evaluator, Trainer
>>> from composer.callbacks.threshold_stopper import ThresholdStopper
>>> # constructing trainer object with this callback
>>> threshold_stopper = ThresholdStopper('MulticlassAccuracy', 'my_evaluator', 0.7)
>>> evaluator = Evaluator(
... dataloader = eval_dataloader,
... label = 'my_evaluator',
... metric_names = ['MulticlassAccuracy']
... )
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=evaluator,
... optimizers=optimizer,
... max_duration="1ep",
... callbacks=[threshold_stopper],
... )
Args:
monitor (str): The name of the metric to monitor.
dataloader_label (str): The label of the dataloader or evaluator associated with the tracked metric. If
monitor is in an Evaluator, the dataloader_label field should be set to the Evaluator's label. If
monitor is a training metric or an ordinary evaluation metric not in an Evaluator, dataloader_label
should be set to 'train' or 'eval' respectively. If dataloader_label is set to 'train', then the
callback will stop training in the middle of the epoch.
threshold (float): The threshold that dictates when to halt training. Whether training stops if the metric
exceeds or falls below the threshold depends on the comparison operator.
comp (Callable[[Any, Any], Any], optional): A comparison operator to measure change of the monitored metric. The comparison
operator will be called ``comp(current_value, prev_best)``. For metrics where the optimal value is low
(error, loss, perplexity), use a less than operator and for metrics like accuracy where the optimal value
is higher, use a greater than operator. Defaults to :func:`torch.less` if loss, error, or perplexity are substrings
of the monitored metric, otherwise defaults to :func:`torch.greater`
stop_on_batch (bool, optional): A bool that indicates whether to stop training in the middle of an epoch if
the training metrics satisfy the threshold comparison. Defaults to False.
"""
def __init__(
self,
monitor: str,
dataloader_label: str,
threshold: float,
*,
comp: Optional[Union[str, Callable[[Any, Any], Any]]] = None,
stop_on_batch: bool = False,
):
self.monitor = monitor
self.threshold = threshold
self.dataloader_label = dataloader_label
self.stop_on_batch = stop_on_batch
if callable(comp):
self.comp_func = comp
if isinstance(comp, str):
if comp.lower() in ('greater', 'gt'):
self.comp_func = torch.greater
elif comp.lower() in ('less', 'lt'):
self.comp_func = torch.less
else:
raise ValueError(
"Unrecognized comp string. Use the strings 'gt', 'greater', 'lt' or 'less' or a callable comparison operator",
)
if comp is None:
if any(substr in monitor.lower() for substr in ['loss', 'error', 'perplexity']):
self.comp_func = torch.less
else:
self.comp_func = torch.greater
def _get_monitored_metric(self, state: State):
if self.dataloader_label == 'train' and state.train_metrics is not None:
if self.monitor in state.train_metrics:
return state.train_metrics[self.monitor].compute()
else:
if self.monitor in state.eval_metrics[self.dataloader_label]:
return state.eval_metrics[self.dataloader_label][self.monitor].compute()
raise ValueError(
f"Couldn't find the metric {self.monitor} with the dataloader label {self.dataloader_label}."
"Check that the dataloader_label is set to 'eval', 'train' or the evaluator name.",
)
def _compare_metric_and_update_state(self, state: State):
metric_val = self._get_monitored_metric(state)
if not torch.is_tensor(metric_val):
metric_val = torch.tensor(metric_val)
if self.comp_func(metric_val, self.threshold):
state.stop_training()
def eval_end(self, state: State, logger: Logger) -> None:
if self.dataloader_label == state.dataloader_label:
# if the monitored metric is an eval metric or in an evaluator
self._compare_metric_and_update_state(state)
def epoch_end(self, state: State, logger: Logger) -> None:
if self.dataloader_label == state.dataloader_label:
# if the monitored metric is not an eval metric, the right logic is run on EPOCH_END
self._compare_metric_and_update_state(state)
def batch_end(self, state: State, logger: Logger) -> None:
if self.stop_on_batch and self.dataloader_label == state.dataloader_label:
self._compare_metric_and_update_state(state)