# Source code for composer.profiler.dataloader_profiler

# Copyright 2021 MosaicML. All Rights Reserved.

"""Profiler to measure the time it takes the data loader to return a batch."""

from __future__ import annotations

from typing import TYPE_CHECKING, Iterator, Optional

from composer.core.callback import Callback
from composer.datasets.dataloader import WrappedDataLoader

if TYPE_CHECKING:
    from composer.core.state import State
    from composer.core.types import Batch, DataLoader
    from composer.loggers import Logger
    from composer.profiler import Profiler

__all__ = ["DataLoaderProfiler"]


class _ProfiledDataLoader(WrappedDataLoader):
    """Dataloader wrapper that times how long each batch takes to be produced.

    Every call to ``__next__`` is bracketed by the start and finish of a profiler marker named
    ``dataloader/{name}``. This class is internal and is not meant to be constructed directly.

    Args:
        profiler (Profiler): Profiler from which the timing marker is obtained.
        dataloader (DataLoader): The underlying dataloader whose batches are timed.
        name (str): Label appended to ``dataloader/`` to form the marker name.
    """

    def __init__(self, profiler: Profiler, dataloader: DataLoader, name: str) -> None:
        super().__init__(dataloader)
        marker_name = f"dataloader/{name}"
        self._marker = profiler.marker(marker_name, categories=["dataloader"])
        # Set by ``__iter__``; remains ``None`` until iteration has been started.
        self._iterator: Optional[Iterator[Batch]] = None

    def __iter__(self) -> _ProfiledDataLoader:
        # Obtain a new iterator from ``self.dataloader`` every time a loop begins.
        self._iterator = iter(self.dataloader)
        return self

    def __next__(self) -> Batch:
        source = self._iterator
        assert source is not None
        marker = self._marker
        marker.start()
        try:
            batch = next(source)
        finally:
            # Close the marker even when the iterator raises (including ``StopIteration``).
            marker.finish()
        return batch


class DataLoaderProfiler(Callback):
    """Profile a DataLoader.

    This callback measures the latency it takes for the DataLoader to yield a batch.

    .. note::

        The Composer :class:`~composer.trainer.trainer.Trainer` automatically creates an instance of this
        :class:`.DataLoaderProfiler` callback whenever the profiler is enabled.

        When using the Composer :class:`~composer.trainer.trainer.Trainer`, one does not need to directly create an
        instance of this :class:`.DataLoaderProfiler` callback.
    """

    def fit_start(self, state: State, logger: Logger) -> None:
        """Wrap the training and evaluator dataloaders so that batch latency is recorded.

        Each dataloader is replaced by a ``_ProfiledDataLoader`` unless
        ``_ProfiledDataLoader.is_dataloader_already_wrapped`` reports that it is already wrapped.

        Args:
            state (State): Supplies ``profiler``, ``train_dataloader`` and ``evaluators``;
                the dataloaders on it are replaced in place.
            logger (Logger): Unused.

        Raises:
            RuntimeError: If ``state.profiler`` is ``None``.
        """
        del logger  # unused
        if state.profiler is None:
            raise RuntimeError(("The Composer Profiler was not enabled, which is required to use the "
                                f"{type(self).__name__}. To enable, set the `prof_schedule` argument of the Trainer."))

        # The training dataloader is recorded under the fixed name "train".
        if not _ProfiledDataLoader.is_dataloader_already_wrapped(state.train_dataloader):
            state.train_dataloader = _ProfiledDataLoader(state.profiler, state.train_dataloader, "train")

        # Each evaluator's inner dataloader is recorded under that evaluator's label.
        for evaluator in state.evaluators:
            if not _ProfiledDataLoader.is_dataloader_already_wrapped(evaluator.dataloader.dataloader):
                evaluator.dataloader.dataloader = _ProfiledDataLoader(state.profiler, evaluator.dataloader.dataloader,
                                                                      evaluator.label)