Source code for composer.datasets.dataloader

# Copyright 2021 MosaicML. All Rights Reserved.

"""Common settings across both the training and eval datasets.

These settings are dataset independent.
"""

from __future__ import annotations

import logging
import textwrap
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Optional

import torch
import torch.distributed
import torch.utils.data
import yahp as hp

from composer.core.types import Batch, DataLoader, Dataset

log = logging.getLogger(__name__)

__all__ = ["WrappedDataLoader", "DataLoaderHparams", "unwrap_data_loader"]


class WrappedDataLoader(DataLoader):
    """A wrapper around a dataloader.

    Args:
        dataloader (DataLoader): A wrapped or unwrapped dataloader.
    """

    def __init__(self, dataloader: DataLoader) -> None:
        if self.is_dataloader_already_wrapped(dataloader):
            log.debug(
                textwrap.dedent("""\
                    The dataloader is already wrapped with %s; it will be wrapped again.
                    If this is unintended behavior, guard the wrapping of the dataloader, e.g. with:
                    if not %s.is_dataloader_already_wrapped(dataloader): dataloader = %s(dataloader)"""),
                type(self).__name__,
                type(self).__name__,
                type(self).__name__,
            )
        self.dataset = dataloader.dataset
        self.batch_size = dataloader.batch_size
        self.num_workers = dataloader.num_workers
        self.pin_memory = dataloader.pin_memory
        self.drop_last = dataloader.drop_last
        self.timeout = dataloader.timeout
        self.sampler = dataloader.sampler
        self.prefetch_factor = dataloader.prefetch_factor
        self.dataloader = dataloader

    def __len__(self) -> int:
        """Get the length of the wrapped dataloader.

        Returns:
            int: Length of the wrapped dataloader.
        """
        return len(self.dataloader)

    def __iter__(self) -> Iterator[Batch]:
        """Get an iterator over the wrapped dataloader.

        Returns:
            Iterator: Iterator over the wrapped dataloader.
        """
        return iter(self.dataloader)

    def __bool__(self) -> bool:
        """Convert to bool.

        Returns:
            bool: Always ``True``.
        """
        return True

    def __setattr__(self, name: str, value: Any) -> None:
        """Set an attribute, if it is not a reserved keyword.

        Args:
            name (str): The attribute name.
            value (Any): The attribute value.
        """
        if hasattr(self, name) and name in ("dataset", "batch_size", "num_workers", "pin_memory", "drop_last",
                                            "timeout", "sampler", "prefetch_factor", "dataloader"):
            raise RuntimeError(f"Property {name} cannot be set after initialization in a DataLoader")
        return super().__setattr__(name, value)
    @classmethod
    def is_dataloader_already_wrapped(cls, dataloader: DataLoader):
        """Returns whether the ``dataloader`` is wrapped with ``cls``.

        This helper method checks recursively through all wrappings until the underlying dataloader is reached.

        Args:
            dataloader (DataLoader): The dataloader to check.

        Returns:
            bool: Whether the ``dataloader`` is wrapped recursively with ``cls``.
        """
        if isinstance(dataloader, cls):
            return True
        if not isinstance(dataloader, WrappedDataLoader):
            return False
        if not isinstance(dataloader.dataloader, WrappedDataLoader):
            return False
        return cls.is_dataloader_already_wrapped(dataloader.dataloader)
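
# Illustrative usage sketch (not part of the original module): guarding the wrap as the
# debug message above suggests. ``_example_guarded_wrap`` is a hypothetical helper name
# used only for this example.
def _example_guarded_wrap(dataloader: DataLoader) -> DataLoader:
    """Wrap ``dataloader`` only if it is not already wrapped with :class:`WrappedDataLoader`."""
    if not WrappedDataLoader.is_dataloader_already_wrapped(dataloader):
        dataloader = WrappedDataLoader(dataloader)
    return dataloader
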
def unwrap_data_loader(dataloader: DataLoader) -> DataLoader:
    """Recursively unwraps a dataloader if it is of type :class:`WrappedDataLoader`.

    Args:
        dataloader (DataLoader): The dataloader to unwrap.

    Returns:
        DataLoader: The underlying dataloader.
    """
    if isinstance(dataloader, WrappedDataLoader):
        return unwrap_data_loader(dataloader.dataloader)
    return dataloader
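
# Illustrative sketch (not part of the original module): no matter how many times a
# dataloader has been wrapped, ``unwrap_data_loader`` recovers the innermost one.
# ``_example_unwrap_roundtrip`` is a hypothetical helper name used only for this example.
def _example_unwrap_roundtrip(dataloader: DataLoader) -> None:
    """Wrap a dataloader twice and confirm that unwrapping returns the original object."""
    wrapped = WrappedDataLoader(WrappedDataLoader(dataloader))
    assert unwrap_data_loader(wrapped) is dataloader
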
@dataclass
class DataLoaderHparams(hp.Hparams):
    """Hyperparameters to initialize a :class:`torch.utils.data.DataLoader`.

    Args:
        num_workers (int, optional): Number of CPU workers to use per device to fetch data.
            Set to ``0`` to use the main training thread for dataloading.
            While zero workers can be useful for debugging, it should not be used for performance reasons.
            Default: ``8``.
        prefetch_factor (int, optional): Number of samples loaded in advance by each worker.
            For example, ``2`` means there will be a total of ``2 * num_workers`` samples prefetched across
            all workers. If ``num_workers = 0``, then ``prefetch_factor`` must be left at the default value.
            Default: ``2``.
        persistent_workers (bool): Whether to reuse dataloader workers across epochs. If ``num_workers`` is ``0``,
            then this field must be ``False``. Default: ``True``.
        pin_memory (bool, optional): Whether or not to copy Tensors into CUDA pinned memory before returning them.
            If ``num_workers = 0``, then ``pin_memory`` must be ``False``. Default: ``True``.
        timeout (float): Timeout, in seconds, for collecting a batch from workers. Set to ``0`` for no timeout.
            Default: ``0``.
    """

    num_workers: int = hp.optional(textwrap.dedent("""\
        Number of CPU workers to use per device to fetch data.
        Set to ``0`` to use the main training thread for dataloading.
        While zero workers can be useful for debugging, it should not be used for performance reasons."""),
                                   default=8)
    prefetch_factor: int = hp.optional(textwrap.dedent("""\
        Number of samples loaded in advance by each worker.
        For example, 2 means there will be a total of 2 * num_workers samples prefetched across all workers.
        If ``num_workers = 0``, then ``prefetch_factor`` must be left at the default value."""),
                                       default=2)
    persistent_workers: bool = hp.optional(textwrap.dedent("""\
        Whether to reuse dataloader workers across epochs.
        If ``num_workers`` is ``0``, then this field must be ``False``."""),
                                           default=True)
    pin_memory: bool = hp.optional(textwrap.dedent("""\
        Whether or not to copy Tensors into CUDA pinned memory before returning them.
        If ``num_workers = 0``, then ``pin_memory`` must be ``False``."""),
                                   default=True)
    timeout: float = hp.optional(
        "Timeout, in seconds, for collecting a batch from workers. Set to ``0`` for no timeout.", default=0)
    def initialize_object(
        self,
        dataset: Dataset,
        *,
        batch_size: int,
        sampler: Optional[torch.utils.data.Sampler[int]],
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
    ) -> DataLoader:
        """Create a dataloader.

        Args:
            dataset (Dataset): The dataset.
            batch_size (int): The per-device batch size.
            sampler (torch.utils.data.Sampler[int] or None): The sampler to use for the dataloader.
            drop_last (bool): Whether to drop the last batch if the number of samples is not evenly
                divisible by the batch size.
            collate_fn (callable, optional): Custom collate function. Default: ``None``.
            worker_init_fn (callable, optional): Custom worker init function. Default: ``None``.

        Returns:
            DataLoader: The dataloader.
        """
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           num_workers=self.num_workers,
                                           pin_memory=self.pin_memory,
                                           drop_last=drop_last,
                                           sampler=sampler,
                                           collate_fn=collate_fn,
                                           worker_init_fn=worker_init_fn,
                                           timeout=self.timeout,
                                           prefetch_factor=self.prefetch_factor,
                                           persistent_workers=self.persistent_workers)
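
# Illustrative usage sketch (not part of the original module): building a dataloader for a
# toy in-memory dataset from these hparams. The ``TensorDataset``, ``RandomSampler``, and
# ``_example_initialize_dataloader`` name are placeholder choices for this example only;
# the hparam values follow the constraints documented above for ``num_workers = 0``.
def _example_initialize_dataloader() -> DataLoader:
    """Construct a dataloader for a small random dataset using :class:`DataLoaderHparams`."""
    dataset = torch.utils.data.TensorDataset(torch.randn(128, 3, 32, 32), torch.randint(0, 10, (128,)))
    hparams = DataLoaderHparams(num_workers=0, persistent_workers=False, pin_memory=False)
    return hparams.initialize_object(
        dataset,
        batch_size=16,
        sampler=torch.utils.data.RandomSampler(dataset),
        drop_last=False,
    )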