Source code for composer.callbacks.nan_monitor
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Callback for catching loss NaNs."""
from typing import Dict, Sequence
import torch
from composer import Callback, Logger, State
__all__ = ['NaNMonitor']
class NaNMonitor(Callback):
    """Catches NaNs in the loss and raises an error if one is found.

    The check runs in the ``after_loss`` hook and inspects ``state.loss``, which
    may be a single tensor, a sequence of tensors, or a dict of named tensors.
    """

    def after_loss(self, state: State, logger: Logger):
        """Check if loss is NaN and raise an error if so.

        Args:
            state (State): The training state. Only ``state.loss`` is read.
            logger (Logger): Not used by this callback.

        Raises:
            RuntimeError: If any element of any loss tensor is NaN. For dict
                losses, the message names the offending key.
            TypeError: If ``state.loss`` is not a tensor, a sequence of tensors,
                or a dict of tensors.
        """
        loss = state.loss
        # A bare tensor is checked first: it is the common case, and a tensor is
        # not treated as a ``Sequence`` by the branch below.
        if isinstance(loss, torch.Tensor):
            if self._has_nan(loss):
                raise RuntimeError('Train loss contains a NaN.')
        elif isinstance(loss, Sequence):
            # ``any`` short-circuits on the first tensor containing a NaN.
            if any(self._has_nan(item) for item in loss):
                raise RuntimeError('Train loss contains a NaN.')
        elif isinstance(loss, Dict):
            for name, value in loss.items():
                if self._has_nan(value):
                    raise RuntimeError(f'Train loss {name} contains a NaN.')
        else:
            # Lists all three accepted shapes. The previous message omitted the
            # dict case even though it is handled above.
            raise TypeError(f'Loss is of type {type(loss)}, but should be a tensor, '
                            'a sequence of tensors, or a dict of tensors')

    @staticmethod
    def _has_nan(tensor: torch.Tensor) -> bool:
        """Return ``True`` if any element of ``tensor`` is NaN."""
        return bool(torch.isnan(tensor).any())