ThresholdStopper

class composer.callbacks.ThresholdStopper(monitor, dataloader_label, threshold, *, comp=None, stop_on_batch=False)

Halt training when a metric value reaches a certain threshold.
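At its core, the stopping rule compares each new value of the monitored metric against the threshold and halts at the first check where the comparison holds. The sketch below simulates that rule over a list of per-check metric values. `first_stop_index` is a hypothetical helper, not part of Composer, and `operator.gt`/`operator.lt` stand in for the `torch.greater`/`torch.less` comparisons described under `comp`.

```python
import operator
from typing import Callable, Optional, Sequence


def first_stop_index(
    values: Sequence[float],
    threshold: float,
    comp: Callable[[float, float], bool] = operator.gt,
) -> Optional[int]:
    """Return the index of the first check at which training would halt.

    Hypothetical helper that mirrors the rule "stop once
    comp(value, threshold) is true". Returns None if the threshold
    is never reached.
    """
    for i, value in enumerate(values):
        # Compare the current metric value against the fixed threshold.
        if comp(value, threshold):
            return i
    return None


# Accuracy-like metric (higher is better): stop once it exceeds 0.7.
accuracies = [0.52, 0.64, 0.69, 0.73, 0.75]
print(first_stop_index(accuracies, 0.7))  # 3

# Loss-like metric (lower is better): stop once it drops below 0.2.
losses = [0.9, 0.5, 0.3, 0.18, 0.1]
print(first_stop_index(losses, 0.2, comp=operator.lt))  # 3
```

Because `operator.gt` and `operator.lt` are strict, a value exactly equal to the threshold does not trigger a stop in this sketch.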

Example

>>> from composer import Evaluator, Trainer
>>> from composer.callbacks.threshold_stopper import ThresholdStopper
>>> # model, train_dataloader, eval_dataloader, and optimizer are assumed to be defined.
>>> # Stop training once 'MulticlassAccuracy' on 'my_evaluator' exceeds 0.7.
>>> threshold_stopper = ThresholdStopper('MulticlassAccuracy', 'my_evaluator', 0.7)
>>> evaluator = Evaluator(
...     dataloader=eval_dataloader,
...     label='my_evaluator',
...     metric_names=['MulticlassAccuracy'],
... )
>>> # Construct the trainer with this callback.
>>> trainer = Trainer(
...     model=model,
...     train_dataloader=train_dataloader,
...     eval_dataloader=evaluator,
...     optimizers=optimizer,
...     max_duration="1ep",
...     callbacks=[threshold_stopper],
... )
Parameters
  • monitor (str) – The name of the metric to monitor.

  • dataloader_label (str) – The label of the dataloader or evaluator associated with the tracked metric. If monitor is in an Evaluator, set dataloader_label to the Evaluator's label. If monitor is a training metric or an ordinary evaluation metric not in an Evaluator, set dataloader_label to 'train' or 'eval', respectively. If dataloader_label is set to 'train', the callback can stop training in the middle of an epoch (see stop_on_batch).

  • threshold (float) – The threshold that dictates when to halt training. Whether training stops when the metric exceeds or falls below the threshold depends on the comparison operator.

  • comp (Callable[[Any, Any], Any], optional) – A comparison operator applied to the monitored metric and the threshold. It is called as comp(current_value, threshold), and training halts when the result is true. For metrics where lower is better (error, loss, perplexity), use a less-than operator; for metrics where higher is better, such as accuracy, use a greater-than operator. Defaults to torch.less if loss, error, or perplexity is a substring of the monitored metric's name; otherwise defaults to torch.greater.

  • stop_on_batch (bool, optional) – Whether to stop training in the middle of an epoch if the training metrics satisfy the threshold comparison. Defaults to False.
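When comp is not given, the default is chosen from the metric name by a substring rule. The sketch below reproduces that rule in plain Python and applies the chosen comparison to the current value and the threshold. `default_comparator` and `should_stop` are hypothetical helpers, `operator.lt`/`operator.gt` stand in for `torch.less`/`torch.greater`, and matching the name case-insensitively is an assumption made here, not something the parameter description states.

```python
import operator
from typing import Callable


def default_comparator(monitor: str) -> Callable[[float, float], bool]:
    """Pick a comparison for `monitor` following the documented default.

    Hypothetical helper. operator.lt / operator.gt stand in for
    torch.less / torch.greater. Lowercasing the name is an assumption
    so that, e.g., 'CrossEntropyLoss' matches the substring 'loss'.
    """
    name = monitor.lower()
    if any(key in name for key in ("loss", "error", "perplexity")):
        return operator.lt  # lower is better: stop once value < threshold
    return operator.gt  # higher is better: stop once value > threshold


def should_stop(monitor: str, current_value: float, threshold: float) -> bool:
    # Apply the chosen comparison as comp(current_value, threshold).
    comp = default_comparator(monitor)
    return comp(current_value, threshold)


print(should_stop("MulticlassAccuracy", 0.72, 0.7))  # True
print(should_stop("CrossEntropyLoss", 0.35, 0.3))  # False
print(should_stop("Perplexity", 12.0, 15.0))  # True
```

If a metric's name does not fit the substring rule (for example, a lower-is-better metric without "loss", "error", or "perplexity" in its name), passing comp explicitly avoids relying on the default.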