Source code for composer.datasets.hparams

# Copyright 2021 MosaicML. All Rights Reserved.

"""Dataset Hyperparameter classes."""

from __future__ import annotations

import abc
import textwrap
from dataclasses import dataclass
from typing import Optional, Union

try:
    import custom_inherit
except ImportError:
    # if custom_inherit is not installed, then the docstrings will be incomplete. That's fine.
    metaclass = abc.ABCMeta
else:
    metaclass = custom_inherit.DocInheritMeta(style="google_with_merge", abstract_base_class=True)

import yahp as hp

from composer.core import DataSpec
from composer.core.types import DataLoader, MemoryFormat
from composer.datasets.dataloader import DataLoaderHparams

__all__ = ["SyntheticHparamsMixin", "DatasetHparams"]


[docs]@dataclass class SyntheticHparamsMixin(hp.Hparams, abc.ABC): """Synthetic dataset parameter mixin for :class:`DatasetHparams`. Args: use_synthetic (bool, optional): Whether to use synthetic data. Default: ``False``. synthetic_num_unique_samples (int, optional): The number of unique samples to allocate memory for. Ignored if :attr:`use_synthetic` is ``False``. Default: ``100``. synthetic_device (str, optional): The device to store the sample pool on. Set to ``'cuda'`` to store samples on the GPU and eliminate PCI-e bandwidth with the dataloader. Set to ``'cpu'`` to move data between host memory and the device on every batch. Ignored if :attr:`use_synthetic` is ``False``. Default: ``'cpu'``. synthetic_memory_format: The :class:`~.core.types.MemoryFormat` to use. Ignored if :attr:`use_synthetic` is ``False``. Default: ``'CONTIGUOUS_FORMAT'``. """ use_synthetic: bool = hp.optional("Whether to use synthetic data. Defaults to False.", default=False) synthetic_num_unique_samples: int = hp.optional("The number of unique samples to allocate memory for.", default=100) synthetic_device: str = hp.optional("Device to store the sample pool. Should be `cuda` or `cpu`. Defauls to `cpu`.", default="cpu") synthetic_memory_format: MemoryFormat = hp.optional("Memory format. Defaults to contiguous format.", default=MemoryFormat.CONTIGUOUS_FORMAT)
[docs]@dataclass class DatasetHparams(hp.Hparams, abc.ABC, metaclass=metaclass): """Abstract base class for hyperparameters to initialize a dataset. Args: datadir (str): The path to the data directory. is_train (bool): Whether to load the training data or validation data. Default: ``True``. drop_last (bool): If the number of samples is not divisible by the batch size, whether to drop the last batch or pad the last batch with zeros. Default: ``True``. shuffle (bool): Whether to shuffle the dataset. Default: ``True``. """ is_train: bool = hp.optional("Whether to load the training data (the default) or validation data.", default=True) drop_last: bool = hp.optional(textwrap.dedent("""\ If the number of samples is not divisible by the batch size, whether to drop the last batch (the default) or pad the last batch with zeros."""), default=True) shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch. Defaults to True.", default=True) datadir: Optional[str] = hp.optional("The path to the data directory", default=None)
[docs] @abc.abstractmethod def initialize_object(self, batch_size: int, dataloader_hparams: DataLoaderHparams) -> Union[DataLoader, DataSpec]: """Creates a :class:`~.core.types.DataLoader` or :class:`~.core.data_spec.DataSpec` for this dataset. Args: batch_size (int): The size of the batch the dataloader should yield. This batch size is device-specific and already incorporates the world size. dataloader_hparams (DataLoaderHparams): The dataset-independent hparams for the dataloader. Returns: DataLoader or DataSpec: The :class:`~core.types.DataLoader`, or if the dataloader yields batches of custom types, a :class:`~core.data_spec.DataSpec`. """ pass
[docs]@dataclass class WebDatasetHparams(DatasetHparams, abc.ABC, metaclass=metaclass): """Abstract base class for hyperparameters to initialize a webdataset. Args: webdataset_cache_dir (str): WebDataset cache directory. webdataset_cache_verbose (str): WebDataset cache verbosity. """ webdataset_cache_dir: str = hp.optional('WebDataset cache directory', default='/tmp/webdataset_cache/') webdataset_cache_verbose: bool = hp.optional('WebDataset cache verbosity', default=False) shuffle_buffer: int = hp.optional('WebDataset shuffle buffer size per worker', default=256)