# Source code for composer.core.evaluator

# Copyright 2021 MosaicML. All Rights Reserved.

"""A wrapper for a dataloader to include metrics that apply to a specific dataset."""

from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Union

from torchmetrics import Metric, MetricCollection

from composer.core.data_spec import DataSpec as DataSpec

if TYPE_CHECKING:
    from composer.core.types import DataLoader

__all__ = ["Evaluator"]


class Evaluator:
    """A wrapper for a dataloader to include metrics that apply to a specific dataset.

    For example, :class:`~.nlp_metrics.CrossEntropyLoss` metric for NLP models.

    .. doctest::

       >>> from torchmetrics.classification.accuracy import Accuracy
       >>> eval_evaluator = Evaluator(label="myEvaluator", dataloader=eval_dataloader, metrics=Accuracy())
       >>> trainer = Trainer(
       ...     model=model,
       ...     train_dataloader=train_dataloader,
       ...     eval_dataloader=eval_evaluator,
       ...     optimizers=optimizer,
       ...     max_duration="1ep",
       ... )

    .. testcleanup::

        trainer.engine.close()

    Args:
        label (str): Name of the Evaluator.
        dataloader (Union[DataSpec, DataLoader]): DataLoader/DataSpec for evaluation data.
        metrics (Union[Metric, MetricCollection]): :class:`torchmetrics.Metric` to log.
            ``metrics`` will be deep-copied to ensure that each evaluator updates only
            its ``metrics``.
    """

    def __init__(self, *, label: str, dataloader: Union[DataSpec, DataLoader],
                 metrics: Union[Metric, MetricCollection]):
        self.label = label
        if isinstance(dataloader, DataSpec):
            self.dataloader = dataloader
        else:
            # Wrap a plain DataLoader so downstream consumers can always rely on
            # the DataSpec interface.
            self.dataloader = DataSpec(dataloader)

        # Deep-copy so this evaluator's metric state is never shared with the
        # caller's (or another evaluator's) metric objects.
        metrics = copy.deepcopy(metrics)
        # Forcing metrics to be a MetricCollection simplifies logging results.
        if isinstance(metrics, Metric):
            self.metrics = MetricCollection([metrics])
        else:
            self.metrics = metrics