Source code for composer.datasets.evaluator

# Copyright 2021 MosaicML. All Rights Reserved.

"""Specifies an instance of an :class:`~.evaluator.Evaluator`, which wraps a dataloader to include metrics that apply to
a specific dataset."""

from __future__ import annotations

import copy
import logging
import textwrap
from dataclasses import dataclass
from typing import List, Optional

import yahp as hp
from torchmetrics import Metric, MetricCollection

from composer.core.evaluator import Evaluator
from composer.datasets import DataLoaderHparams
from composer.datasets.dataset_registry import get_dataset_registry
from composer.datasets.hparams import DatasetHparams
from composer.models.base import ComposerModel

log = logging.getLogger(__name__)

__all__ = ["EvaluatorHparams"]


@dataclass
class EvaluatorHparams(hp.Hparams):
    """Params for the :class:`~.evaluator.Evaluator`.

    Also see the documentation for the :class:`~.evaluator.Evaluator`.

    Args:
        label (str): Name of the Evaluator. Used for logging/reporting metrics.
        eval_dataset (DatasetHparams): Evaluation dataset.
        metric_names (list, optional): List of names of the metrics for the evaluator. Each name can be a
            :class:`torchmetrics.Metric` name or the class name of a metric returned by
            :meth:`~.ComposerModel.metrics`. If ``None``, uses all metrics in the model. Default: ``None``.
    """

    hparams_registry = {  # type: ignore
        "eval_dataset": get_dataset_registry(),
    }

    label: str = hp.required(doc="Name of the Evaluator object. Used for logging/reporting metrics")
    eval_dataset: DatasetHparams = hp.required(doc="Evaluation dataset for the Evaluator")
    metric_names: Optional[List[str]] = hp.optional(
        doc=textwrap.dedent("""Names of the metrics for the evaluator. Each can be a torchmetrics metric
            name or the class name of a metric returned by model.metrics(). If None (the default),
            uses all metrics in the model."""),
        default=None)

    def initialize_object(self, model: ComposerModel, batch_size: int, dataloader_hparams: DataLoaderHparams):
        """Initialize an :class:`~.evaluator.Evaluator`.

        If ``metric_names`` is empty or ``None``, the evaluator uses a copy of all of the model's default
        evaluation metrics.

        Args:
            model (ComposerModel): The model, which is used to retrieve metric names.
            batch_size (int): The device batch size to use for the evaluation dataset.
            dataloader_hparams (DataLoaderHparams): The hparams to use to construct a dataloader for the
                evaluation dataset.

        Returns:
            Evaluator: The evaluator.
        """
        dataloader = self.eval_dataset.initialize_object(batch_size=batch_size,
                                                         dataloader_hparams=dataloader_hparams)

        # Get and copy all the model's associated evaluation metrics
        model_metrics = model.metrics(train=False)
        if isinstance(model_metrics, Metric):
            # Forcing metrics to be a MetricCollection simplifies logging results
            model_metrics = MetricCollection([model_metrics])

        # Use all the metrics from the model if no metric_names are specified
        if self.metric_names is None:
            evaluator_metrics = copy.deepcopy(model_metrics)
        else:
            evaluator_metrics = MetricCollection([])
            for metric_name in self.metric_names:
                try:
                    metric = model_metrics[metric_name]
                except KeyError as e:
                    raise RuntimeError(
                        textwrap.dedent(f"""No metric found with the name {metric_name}. Check that this
                        metric is compatible/listed in your model metrics.""")) from e
                assert isinstance(metric, Metric), "all values of a MetricCollection.__getitem__ should be a metric"
                evaluator_metrics.add_metrics(copy.deepcopy(metric))
            if len(evaluator_metrics) == 0:
                raise RuntimeError(
                    textwrap.dedent("""No metrics compatible with your model were added to this evaluator.
                    Check that the metrics you specified are compatible/listed in your model."""))

        return Evaluator(
            label=self.label,
            dataloader=dataloader,
            metrics=evaluator_metrics,
        )
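
The ``metric_names`` lookup in ``initialize_object`` relies on :class:`torchmetrics.MetricCollection` indexing its members by class name. A minimal sketch of that behavior, using ``torchmetrics.Accuracy`` purely as an illustrative metric (it is not referenced by this module; newer torchmetrics versions may also require a ``task`` argument when constructing ``Accuracy``):

from torchmetrics import Accuracy, MetricCollection

# MetricCollection keys each metric by its class name, which is why an entry in
# EvaluatorHparams.metric_names such as "Accuracy" can resolve a metric returned
# by model.metrics().
metrics = MetricCollection([Accuracy()])
accuracy_metric = metrics["Accuracy"]  # the same lookup performed in initialize_object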
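
For reference, a hedged end-to-end sketch of how ``EvaluatorHparams`` might be used directly from Python rather than from a YAML config. ``my_dataset_hparams``, ``my_model``, and ``my_dataloader_hparams`` are hypothetical placeholders for a concrete :class:`DatasetHparams` subclass, a :class:`ComposerModel`, and a :class:`DataLoaderHparams`; they are not defined in this module.

# Hypothetical usage sketch; my_dataset_hparams, my_model, and my_dataloader_hparams
# are placeholders that must be supplied by the caller.
evaluator_hparams = EvaluatorHparams(
    label="eval",
    eval_dataset=my_dataset_hparams,  # a concrete DatasetHparams subclass
    metric_names=["Accuracy"],        # or None to use all of the model's eval metrics
)

evaluator = evaluator_hparams.initialize_object(
    model=my_model,                            # a ComposerModel
    batch_size=128,
    dataloader_hparams=my_dataloader_hparams,  # a DataLoaderHparams
)
# `evaluator` is a composer.core.evaluator.Evaluator wrapping the dataloader and metrics.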