# Copyright 2021 MosaicML. All Rights Reserved.
"""Reference for common types used throughout the composer library.
Attributes:
    Batch (BatchPair | BatchDict | torch.Tensor): Union type covering the most common representations of batches.
        A batch of data can be represented in several formats, depending on the application.
    BatchPair (Sequence[Union[torch.Tensor, Sequence[torch.Tensor]]]): Commonly used in computer vision tasks.
        The object is assumed to contain exactly two elements, where the first represents inputs
        and the second represents targets.
    BatchDict (Dict[str, torch.Tensor]): Commonly used in natural language processing tasks.
    PyTorchScheduler (torch.optim.lr_scheduler._LRScheduler): Alias for base class of learning rate schedulers such
        as :class:`torch.optim.lr_scheduler.ConstantLR`.
    JSON (str | float | int | None | List['JSON'] | Dict[str, 'JSON']): JSON Data.
    Dataset (torch.utils.data.Dataset[Batch]): Alias for :class:`torch.utils.data.Dataset`.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Union
import torch
import torch.utils.data
from composer.utils.string_enum import StringEnum
try:
from typing import Protocol
except ImportError:
Protocol = object # Protocol is not available in python 3.7
if TYPE_CHECKING:
from typing import Protocol
# Names exported by a star-import of this module.
# NOTE(review): "MemoryFormat" is listed here but is not defined in the visible part of this
# module, and ``Dataset`` (defined below) is not listed -- confirm both are intentional.
__all__ = [
"Batch", "BatchPair", "BatchDict", "PyTorchScheduler", "JSON", "MemoryFormat", "as_batch_dict", "as_batch_pair",
"DataLoader", "BreakEpochException"
]
# For BatchPair, if it is a list, then it should always be of length 2.
# PyTorch's default collate_fn returns a list even when the dataset returns a tuple.
BatchPair = Sequence[Union[torch.Tensor, Sequence[torch.Tensor]]]
# Mapping from string keys to tensors.
BatchDict = Dict[str, torch.Tensor]
# Any of the supported batch representations: a pair, a dict, or a bare tensor.
Batch = Union[BatchPair, BatchDict, torch.Tensor]
# ``torch.utils.data.Dataset`` parameterized by ``Batch``.
Dataset = torch.utils.data.Dataset[Batch]
# Alias for the (underscore-prefixed) base class of torch learning rate schedulers.
PyTorchScheduler = torch.optim.lr_scheduler._LRScheduler
# Recursive alias: the quoted 'JSON' entries are forward references to this alias itself,
# since the name is not yet bound while the right-hand side is being evaluated.
JSON = Union[str, float, int, None, List['JSON'], Dict[str, 'JSON']]
def as_batch_dict(batch: Batch) -> BatchDict:
    """Casts a :class:`Batch` as a :class:`BatchDict`.

    The batch is returned as-is (no copy is made); this function only
    verifies that it is a :class:`dict`. Keys and values are not inspected.

    Args:
        batch (Batch): A batch.

    Raises:
        TypeError: If the ``batch`` is not a :class:`BatchDict`.

    Returns:
        BatchDict: The batch, represented as a :class:`BatchDict`.
    """
    if not isinstance(batch, dict):
        raise TypeError(f'batch_dict requires batch of type dict, got {type(batch)}')
    return batch
def as_batch_pair(batch: Batch) -> BatchPair:
    """Casts a :class:`Batch` as a :class:`BatchPair`.

    The batch is returned as-is (no copy is made); this function only
    verifies that it is a :class:`tuple` or :class:`list` with exactly
    two elements. The elements themselves are not inspected.

    Args:
        batch (Batch): A batch.

    Returns:
        BatchPair: The batch, represented as a :class:`BatchPair`.

    Raises:
        TypeError: If the batch is not a :class:`BatchPair`, i.e. it is not
            a tuple or list, or it does not have exactly two elements.
    """
    if not isinstance(batch, (tuple, list)):
        raise TypeError(f'batch_pair required batch to be a tuple or list, got {type(batch)}')
    # A wrong length is also reported as TypeError, matching the documented contract.
    if len(batch) != 2:
        raise TypeError(f'batch has length {len(batch)}, expected length 2')
    return batch
class BreakEpochException(Exception):
    """Raising this exception will immediately end the current epoch.

    If you're wondering whether you should use this, the answer is no.
    """
class DataLoader(Protocol):
    """Protocol for custom DataLoaders compatible with
    :class:`torch.utils.data.DataLoader`.

    Attributes:
        dataset (Dataset): Dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load for a
            single device (default: ``1``).
        num_workers (int): How many subprocesses to use for data loading.
            ``0`` means that the data will be loaded in the main process.
        pin_memory (bool): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them.
        drop_last (bool): If ``len(dataset)`` is not evenly
            divisible by :attr:`batch_size`, whether the last batch is
            dropped (if True) or truncated (if False).
        timeout (float): The timeout for collecting a batch from workers.
        sampler (torch.utils.data.Sampler[int]): The dataloader sampler.
        prefetch_factor (int): Number of samples loaded in advance by each
            worker. ``2`` means there will be a total of
            2 * :attr:`num_workers` samples prefetched across all workers.
    """

    # Attribute declarations only: the protocol assigns no values here.
    dataset: Dataset
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: torch.utils.data.Sampler[int]
    prefetch_factor: int

    def __iter__(self) -> Iterator[Batch]:
        """Iterates over the dataset.

        Yields:
            Iterator[Batch]: An iterator over batches.
        """
        ...

    def __len__(self) -> int:
        """Returns the number of batches in an epoch.

        Raises:
            NotImplementedError: Raised if the dataset has unknown length.

        Returns:
            int: Number of batches in an epoch.
        """
        ...