DataLoaders are used to pass training or evaluation data to the
Trainer. There are three ways of doing so:
Passing a PyTorch DataLoader directly.
Providing a DataSpec, which contains a PyTorch dataloader as well as additional configuration, such as on-device transforms.
(For validation) Providing Evaluator objects, which contain both a dataloader and relevant metrics for validation.
We walk through each of these ways in detail and provide usage examples below.
Passing a PyTorch DataLoader
from torch.utils.data import DataLoader
from composer import Trainer

train_dataloader = DataLoader(
    training_data,
    batch_size=2048,
    shuffle=True
)

trainer = Trainer(..., train_dataloader=train_dataloader, ...)
The batch_size passed to the dataloader should be the per-device overall
batch size. For example, with grad_accum=2, a batch_size of 2048 would
mean that each microbatch (one forward/backward pass) would have a batch
size of 1024.
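The relationship between the dataloader batch size and the microbatch size can be sketched with plain arithmetic. The helper name `microbatch_size` is hypothetical and not part of Composer, and the sketch assumes the batch size divides evenly by `grad_accum`.

```python
def microbatch_size(batch_size: int, grad_accum: int) -> int:
    """Per-device samples in one forward/backward pass (illustrative helper)."""
    # Assumption for this sketch: the batch splits into equal microbatches.
    if batch_size % grad_accum != 0:
        raise ValueError("this sketch assumes batch_size is divisible by grad_accum")
    return batch_size // grad_accum


print(microbatch_size(2048, 2))  # 1024
```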
For performance, we highly recommend:
num_workers > 0: a common starting point is the number of CPU cores in your machine divided by the number of GPUs.
pin_memory=True: Pinned memory can speed up copying data from the CPU to a GPU. Use it wherever possible; the main drawback is the reduced RAM available to the host.
persistent_workers=True: Persisting workers reduces the overhead of creating workers, but uses some RAM since these workers keep some persistent state.
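As a rough sketch of these recommendations, the settings can be collected into keyword arguments for a PyTorch `DataLoader`. The helper `suggested_loader_kwargs` and the example core and GPU counts are hypothetical; treat the result as a starting point to tune rather than a fixed rule.

```python
import os
from typing import Optional


def suggested_loader_kwargs(num_gpus: int, cpu_cores: Optional[int] = None) -> dict:
    """Build DataLoader keyword arguments following the tips above (illustrative)."""
    cores = cpu_cores if cpu_cores is not None else (os.cpu_count() or 1)
    # CPU cores divided by GPUs, but always at least one worker process.
    num_workers = max(1, cores // max(1, num_gpus))
    return {
        "num_workers": num_workers,
        "pin_memory": True,
        # PyTorch only accepts persistent_workers=True when num_workers > 0.
        "persistent_workers": True,
    }


kwargs = suggested_loader_kwargs(num_gpus=8, cpu_cores=64)
print(kwargs)  # {'num_workers': 8, 'pin_memory': True, 'persistent_workers': True}

# The dict can then be unpacked into the dataloader, for example:
# DataLoader(training_data, batch_size=2048, shuffle=True, **kwargs)
```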
Sometimes, the data configuration requires more than just the dataloader. Examples of additional configuration include:
Some transforms should be run on the data after it has been moved onto the correct device (e.g. GPU-based normalization).
Custom batch types need a split_batch function that tells the trainer how to split batches into microbatches for gradient accumulation.
Optionally tracking the number of tokens (or samples) seen during training so far.
Providing the length of a dataset when len or a similar function isn't in the dataloader's interface.
For these and other potential use cases, the trainer can also accept a
DataSpec object with these additional settings. For example,
from composer import Trainer
from composer.core import DataSpec

data_spec = DataSpec(
    dataloader=my_train_dataloader,
    num_tokens=193820,
    # Must return an int, so index the shape rather than returning it whole
    get_num_tokens_in_batch=lambda batch: batch['text'].shape[0]
)

trainer = Trainer(train_dataloader=data_spec, ...)
- class composer.core.DataSpec(dataloader, num_samples=None, num_tokens=None, device_transforms=None, split_batch=None, get_num_samples_in_batch=None, get_num_tokens_in_batch=None)
Specifications for operating and training on data.
>>> # In this case, we apply NormalizationFn
>>> # Construct DataSpec as shown below to apply this transformation
>>> from composer.datasets.utils import NormalizationFn
>>> CHANNEL_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
>>> CHANNEL_STD = (0.229 * 255, 0.224 * 255, 0.225 * 255)
>>> device_transform_fn = NormalizationFn(mean=CHANNEL_MEAN, std=CHANNEL_STD)
>>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
>>> # The same function can be used for eval dataloader as well
>>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
>>> # Use this DataSpec object to construct trainer
>>> trainer = Trainer(
...     model=model,
...     train_dataloader=train_dspec,
...     eval_dataloader=eval_dspec,
...     optimizers=optimizer,
...     max_duration="1ep",
... )
dataloader (Iterable) – The dataloader, which can be any iterable that yields batches.
num_samples (int, optional) – The total number of samples in an epoch, across all ranks. This field is used by the Timestamp (training progress tracker). If not specified, then len(dataloader.dataset) is used (if this property is available). Otherwise, the dataset is assumed to be unsized.
device_transforms ((Batch) -> Batch, optional) – Function called by the Trainer to modify the batch once it has been moved onto the device. For example, this function can be used for GPU-based normalization. It can modify the batch in-place, and it should return the modified batch. If not specified, the batch is not modified.
split_batch ((Batch, int) -> Sequence[Batch], optional) – Function called by the Trainer to split a batch (the first parameter) into the number of microbatches specified (the second parameter). If the dataloader yields batches not of type torch.Tensor, Mapping, Tuple, or List, then this function must be specified.
get_num_samples_in_batch ((Batch) -> int, optional) –
Function called by the Trainer to get the number of samples in the provided batch.
By default, if the batch contains tensors that all have the same 0th dim, then the value of the 0th dim will be returned. If the batch contains tensors whose 0th dims differ, then this function must be specified.
get_num_tokens_in_batch ((Batch) -> int, optional) –
Function called by the Trainer to get the number of tokens in the provided batch.
By default, it returns 0, meaning that the number of tokens processed will not be tracked as part of the training progress tracking. This function must be specified to track the number of tokens processed during training.
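To illustrate the two counting hooks, here is a minimal sketch for a hypothetical text batch whose fields have different leading sizes. To stay dependency-free it uses plain Python lists; the field names (`input_ids`, `attention_mask`, `entity_spans`) are assumptions for this example, and with real tensors the bodies would use tensor operations instead.

```python
# Hypothetical batch: 2 samples, padded to length 4, plus 3 flattened span labels.
batch = {
    "input_ids": [[5, 6, 7, 0], [8, 9, 0, 0]],
    "attention_mask": [[1, 1, 1, 0], [1, 1, 0, 0]],
    "entity_spans": [(0, 1), (1, 2), (0, 0)],  # leading size differs from input_ids
}


def get_num_samples_in_batch(batch: dict) -> int:
    # One row of input_ids per sample, so its length is the sample count.
    return len(batch["input_ids"])


def get_num_tokens_in_batch(batch: dict) -> int:
    # Count only non-padding positions, marked with 1 in the attention mask.
    return sum(sum(row) for row in batch["attention_mask"])


print(get_num_samples_in_batch(batch))  # 2
print(get_num_tokens_in_batch(batch))   # 5

# Both functions would be passed through the DataSpec parameters listed above:
# DataSpec(dataloader, get_num_samples_in_batch=get_num_samples_in_batch,
#          get_num_tokens_in_batch=get_num_tokens_in_batch)
```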
To track training progress, validation datasets can be provided to the
Composer Trainer through the
eval_dataloader parameter. If there are
multiple datasets to use for validation/evaluation, each
with their own metrics,
Evaluator objects can be used to
pass in multiple dataloaders/datasets to the trainer.
For more information, see Evaluation.
For custom batch types (not torch.Tensor, List, Tuple, Mapping), implement a
split_batch function and provide it to the trainer using the
DataSpec described above. Here's an
example function for when the batch from the dataloader is a tuple of two tensors:
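from typing import List
import torch

def split_batch(batch, num_microbatches: int) -> List:
    x, y = batch
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        # Pair the i-th chunk of x with the i-th chunk of y
        return list(zip(x.chunk(num_microbatches), y.chunk(num_microbatches)))
    raise TypeError("expected a tuple of two tensors")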
Suppose instead the batch had one input image and several target images,
(Tensor, (Tensor, Tensor, Tensor)). Then the function would be:
from typing import List

def split_batch(batch, num_microbatches: int) -> List:
    n = num_microbatches
    x, (y1, y2, y3) = batch
    # Each microbatch keeps the original (input, (targets...)) structure
    targets = zip(y1.chunk(n), y2.chunk(n), y3.chunk(n))
    return list(zip(x.chunk(n), targets))
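The same idea extends to batch types that are not tuples at all. The sketch below uses a hypothetical `PairBatch` dataclass holding plain lists, so it runs without PyTorch; the chunking mirrors `torch.chunk` in that it uses ceiling division and may return fewer pieces than requested.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PairBatch:
    """Hypothetical custom batch type, used only for illustration."""
    inputs: List[float]
    targets: List[int]


def split_batch(batch: PairBatch, num_microbatches: int) -> List[PairBatch]:
    n = len(batch.inputs)
    # Ceiling division gives the size of every chunk except possibly the last.
    size = max(1, -(-n // num_microbatches))
    return [
        PairBatch(batch.inputs[i:i + size], batch.targets[i:i + size])
        for i in range(0, n, size)
    ]


batch = PairBatch(inputs=[0.1, 0.2, 0.3, 0.4, 0.5], targets=[0, 1, 0, 1, 1])
parts = split_batch(batch, 2)
print([len(p.inputs) for p in parts])  # [3, 2]
```

A function with this signature is what the split_batch parameter of DataSpec expects: the batch first, then the number of microbatches.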