DataLoaders
DataLoaders are used to pass in training or evaluation data to the
Composer Trainer. There are three different ways of doing so:
- Passing PyTorch torch.utils.data.DataLoader objects directly.
- Providing a DataSpec, which contains a PyTorch dataloader as well as additional configurations, such as on-device transforms.
- (For validation) Providing Evaluator objects, which contain both a dataloader and relevant metrics for validation.
We walk through each of these ways in detail and provide usage examples below.
Passing a PyTorch DataLoader
Composer dataloaders have type torch.utils.data.DataLoader
(see PyTorch documentation) and can be passed directly to the
Trainer.
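from torch.utils.data import DataLoader
from composer import Trainer

# `training_data` (a torch.utils.data.Dataset) and `model` (a ComposerModel)
# are assumed to be defined elsewhere.
train_dataloader = DataLoader(
    training_data,
    batch_size=2048,
    shuffle=True,
)
trainer = Trainer(model=model, train_dataloader=train_dataloader, max_duration="1ep")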
Note
The batch_size passed to the dataloader should be the per-device batch
size, before it is split into microbatches. For example, with grad_accum=2,
a batch_size of 2048 means that each microbatch (one forward/backward pass)
has a batch size of 1024.
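The arithmetic in the note above can be checked with a small sketch. It uses torch.Tensor.chunk only to illustrate the split; it is not necessarily how the Trainer splits batches internally. The world_size value is hypothetical, and the global batch size line assumes plain data parallelism, where every device processes its own per-device batch each optimizer step.

```python
import torch

per_device_batch_size = 2048  # the batch_size passed to each device's DataLoader
grad_accum = 2                # microbatches per optimizer step
world_size = 8                # hypothetical number of devices

batch = torch.randn(per_device_batch_size, 16)  # one per-device batch of 16-dim features
microbatches = batch.chunk(grad_accum)          # each chunk is one forward/backward pass
microbatch_sizes = [mb.shape[0] for mb in microbatches]
print(microbatch_sizes)  # [1024, 1024]

# Samples consumed per optimizer step across all devices (data parallelism)
global_batch_size = per_device_batch_size * world_size
print(global_batch_size)  # 16384
```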
For performance, we highly recommend:
- num_workers > 0: usually set this to the number of CPU cores in your machine divided by the number of GPUs.
- pin_memory=True: Pinned memory can speed up copying data from the CPU to the GPU. Use it wherever possible; the main drawback is that pinned (page-locked) memory reduces the RAM available to the host.
- persistent_workers=True: Keeping workers alive between epochs avoids the overhead of re-creating them, at the cost of some RAM for the workers' persistent state. This option requires num_workers > 0.
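The recommendations above can be combined into one DataLoader constructor. This is a minimal sketch: the helper name make_train_dataloader is hypothetical, and the workers-per-GPU heuristic simply follows the rule of thumb from the list. One deliberate deviation is that pin_memory is enabled only when CUDA is available, since pinning only helps host-to-GPU copies.

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def make_train_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
    """Hypothetical helper applying the performance tips above."""
    num_gpus = max(torch.cuda.device_count(), 1)  # treat a CPU-only machine as one device
    # Rule of thumb: CPU cores divided by GPUs, but always at least one worker
    num_workers = max((os.cpu_count() or 1) // num_gpus, 1)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # pinning only helps copies to a GPU
        persistent_workers=True,  # valid because num_workers >= 1
    )


# Usage with a small synthetic dataset (features and binary labels)
data = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
loader = make_train_dataloader(data, batch_size=16)
```

Constructing the loader does not start any worker processes; they are spawned on first iteration and, with persistent_workers=True, reused in later epochs.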
Note
Samplers specify the order of indices in dataloading. When using
distributed training, it is important to use PyTorch's DistributedSampler
so that each process sees a unique shard of the dataset. If the dataset
is already sharded, use a SequentialSampler or RandomSampler instead.
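The sharding behavior described in the note can be sketched in a single process. In a real multi-process run, DistributedSampler infers num_replicas and rank from the initialized process group; here they are passed explicitly (an assumption for illustration) so the example runs without torch.distributed. The helper shard_indices is hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def shard_indices(dataset, world_size: int, rank: int, epoch: int = 0) -> list:
    """Hypothetical helper: the dataset indices `rank` would see in `epoch`."""
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
    )
    sampler.set_epoch(epoch)  # same seed + epoch gives every rank the same permutation
    return list(sampler)


dataset = TensorDataset(torch.arange(10).float())
shards = [shard_indices(dataset, world_size=2, rank=r) for r in range(2)]

# Pass the sampler to the DataLoader; do not also set shuffle=True
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
```

Because every rank draws the same shuffled permutation and then takes every world_size-th index starting at its own rank, the shards are disjoint. Calling set_epoch before each epoch changes the shuffle order from epoch to epoch.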
DataSpec
Sometimes, the data configuration requires more than just the dataloader. Examples of additional configuration include:
- Some transforms should be run on the data after it has been moved onto the correct device (e.g. GPU).
- Custom batch types need a split_batch function that tells the trainer how to split batches into microbatches for gradient accumulation.
- Optionally tracking the number of tokens (or samples) seen during training so far.
- Providing the length of a dataset when len or a similar function isn't in the dataloader's interface.
For these and other potential use cases, the trainer also accepts a
DataSpec object with these additional settings. For example:
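from composer import Trainer
from composer.core import DataSpec

# Assumes `my_train_dataloader` yields dict batches whose 'text' entry is a
# (batch_size, seq_len) tensor of token IDs, and `model` is defined elsewhere.
data_spec = DataSpec(
    dataloader=my_train_dataloader,
    num_tokens=193820,
    get_num_tokens_in_batch=lambda batch: batch['text'].numel(),
)
trainer = Trainer(model=model, train_dataloader=data_spec, max_duration="1ep")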
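What counts as a "token" depends on the batch layout. As a sketch, assuming hypothetical batches with an 'input_ids' tensor and an 'attention_mask' tensor (1 for real tokens, 0 for padding), both of shape (batch_size, seq_len), a get_num_tokens_in_batch callable can either count every position or skip padding. Neither function is part of Composer; they are plain callables that could be passed to DataSpec.

```python
import torch


def count_tokens_with_padding(batch: dict) -> int:
    """Counts every position, padding included: batch_size * seq_len."""
    return batch['input_ids'].numel()


def count_tokens(batch: dict) -> int:
    """Counts only real tokens, assuming a 0/1 'attention_mask' (hypothetical key)."""
    return int(batch['attention_mask'].sum().item())


batch = {
    'input_ids': torch.zeros(2, 4, dtype=torch.long),
    'attention_mask': torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
}
```

Here count_tokens_with_padding(batch) gives 8 and count_tokens(batch) gives 5. Whichever convention is chosen, num_tokens should be computed the same way so that progress tracking stays consistent.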
Examples of how DataSpec is used for popular datasets can be seen in
our ImageNet and ADE20k files. For reference, the DataSpec arguments
are shown below.
class composer.core.DataSpec(dataloader, num_samples=None, num_tokens=None, device_transforms=None, split_batch=None, get_num_samples_in_batch=None, get_num_tokens_in_batch=None)

Specifications for operating and training on data.

An example of constructing a DataSpec object with a device_transforms callable (NormalizationFn) and then using it with Trainer:

>>> # In this case, we apply NormalizationFn
>>> # Construct DataSpec as shown below to apply this transformation
>>> from composer.datasets.utils import NormalizationFn
>>> CHANNEL_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
>>> CHANNEL_STD = (0.229 * 255, 0.224 * 255, 0.225 * 255)
>>> device_transform_fn = NormalizationFn(mean=CHANNEL_MEAN, std=CHANNEL_STD)
>>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
>>> # The same function can be used for eval dataloader as well
>>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
>>> # Use this DataSpec object to construct trainer
>>> trainer = Trainer(
...     model=model,
...     train_dataloader=train_dspec,
...     eval_dataloader=eval_dspec,
...     optimizers=optimizer,
...     max_duration="1ep",
... )

Parameters
- dataloader (Union[Iterable, DataLoader]) – The dataloader, which can be any iterable that yields batches.
- num_samples (int, optional) – The total number of samples in an epoch, across all ranks. This field is used by the Timestamp (training progress tracker). If not specified, then len(dataloader.dataset) is used (if this property is available). Otherwise, the dataset is assumed to be unsized.
- num_tokens (int, optional) – The total number of tokens in an epoch. This field is used by the Timestamp (training progress tracker).
- device_transforms ((Batch) -> Batch, optional) – Function called by the Trainer to modify the batch once it has been moved onto the device. For example, this function can be used for GPU-based normalization. It can modify the batch in-place, and it should return the modified batch. If not specified, the batch is not modified.
- split_batch ((Batch, int) -> Sequence[Batch], optional) – Function called by the Trainer to split a batch (the first parameter) into microbatches of a given size (the second parameter). If the dataloader yields batches not of type torch.Tensor, Mapping, Tuple, or List, then this function must be specified.
- get_num_samples_in_batch ((Batch) -> int, optional) – Function that is called by the Trainer to get the number of samples in the provided batch. By default, if the batch contains tensors that all have the same 0th dim, then the value of the 0th dim is returned. If the batch contains tensors whose 0th dims differ, then this function must be specified.
- get_num_tokens_in_batch ((Batch) -> int, optional) – Function that is called by the Trainer to get the number of tokens in the provided batch. By default, it returns 0, meaning that the number of tokens processed is not tracked as part of training progress. This function must be specified to track the number of tokens processed during training.
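As an illustration of two of these parameters, the sketch below assumes a hypothetical detection-style batch: a dict whose 'images' entry has shape (N, 3, H, W) and whose 'boxes' entry concatenates all boxes in the batch into shape (total_boxes, 4). Because the 0th dims differ, the sample count must come from 'images' alone. The normalization constants are the ImageNet values from the example above; both functions are plain callables written for this sketch, not Composer utilities.

```python
import torch

CHANNEL_MEAN = torch.tensor([0.485, 0.456, 0.406]) * 255
CHANNEL_STD = torch.tensor([0.229, 0.224, 0.225]) * 255


def normalize_on_device(batch: dict) -> dict:
    """device_transforms-style callable: runs after the batch is on the device."""
    images = batch['images'].float()
    # Move the constants to wherever the batch lives (CPU or GPU)
    mean = CHANNEL_MEAN.to(images.device).view(1, 3, 1, 1)
    std = CHANNEL_STD.to(images.device).view(1, 3, 1, 1)
    batch['images'] = (images - mean) / std  # modify the batch, then return it
    return batch


def num_samples(batch: dict) -> int:
    """get_num_samples_in_batch-style callable: 'images' and 'boxes' differ in dim 0."""
    return batch['images'].shape[0]


batch = {
    'images': torch.full((2, 3, 4, 4), 128, dtype=torch.uint8),
    'boxes': torch.zeros(5, 4),  # 5 boxes spread across the 2 images
}
batch = normalize_on_device(batch)
```

With Composer installed, these would be passed as DataSpec(my_dataloader, device_transforms=normalize_on_device, get_num_samples_in_batch=num_samples), where my_dataloader is a hypothetical dataloader yielding such batches.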
 
 
Validation
To track training progress, validation datasets can be provided to the
Composer Trainer through the eval_dataloader parameter. If there are
multiple datasets to use for validation/evaluation, each
with their own metrics, Evaluator objects can be used to
pass in multiple dataloaders/datasets to the trainer.
For more information, see Evaluation.
Batch Types
For custom batch types (not torch.Tensor, List, Tuple, or Mapping), implement a
split_batch function and provide it to the trainer through DataSpec, as described
above. Here's an example function for when the batch from the dataloader is a
tuple of two tensors:
import torch

def split_batch(batch, num_microbatches: int):
    x, y = batch  # batch is a tuple of two tensors
    assert isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor)
    # Chunk along the batch dimension and pair up the resulting microbatches
    return list(zip(x.chunk(num_microbatches), y.chunk(num_microbatches)))
Suppose instead the batch had one input image and several target images,
e.g. (Tensor, (Tensor, Tensor, Tensor)). Then the function would be:
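import torch

def split_batch(batch, num_microbatches: int):
    n = num_microbatches
    x, ys = batch  # ys is a tuple of target tensors, e.g. (y1, y2, y3)
    # Chunk each target, then regroup so microbatch i holds (y1_i, y2_i, y3_i)
    target_chunks = zip(*(y.chunk(n) for y in ys))
    return list(zip(x.chunk(n), target_chunks))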
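The two functions above follow the same pattern: chunk every tensor along dim 0, then regroup the chunks so each microbatch keeps the original nesting. A recursive version can handle arbitrarily nested tuples and lists of tensors. This is a sketch under the assumption that every tensor in the batch has the same size in dim 0. It is not Composer's built-in splitting logic, and it converts named tuples to plain tuples.

```python
from typing import Any, List

import torch


def split_nested(batch: Any, num_microbatches: int) -> List[Any]:
    """Split a nested tuple/list of tensors along dim 0, preserving the nesting."""
    if isinstance(batch, torch.Tensor):
        return list(batch.chunk(num_microbatches))
    if isinstance(batch, (tuple, list)):
        # Split each element, then regroup so microbatch i holds every element's chunk i
        per_element = [split_nested(item, num_microbatches) for item in batch]
        regrouped = zip(*per_element)
        return [tuple(p) if isinstance(batch, tuple) else list(p) for p in regrouped]
    raise TypeError(f"Unsupported batch element type: {type(batch).__name__}")


x = torch.arange(8).view(8, 1)
ys = (torch.arange(8), torch.arange(8) * 10, torch.arange(8) * 100)
micro = split_nested((x, ys), 4)
```

Each entry of micro has the same shape as the input batch, (x_i, (y1_i, y2_i, y3_i)), so a function like this could be passed as the split_batch argument of DataSpec for such nested batches.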