get_sampler

composer.utils.dist.get_sampler(dataset, *, drop_last=False, shuffle=False, num_replicas=None, rank=None)

Constructs a DistributedSampler for a dataset.

The DistributedSampler assumes that each rank has a complete copy of the dataset. It ensures that each rank sees a unique shard of the dataset each epoch, containing len(dataset) / get_world_size() samples.

Note

If the dataset is already sharded by rank, use a SequentialSampler or RandomSampler.
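For example, if the data pipeline already gives each rank only its own shard, a plain RandomSampler is sufficient. The sketch below is illustrative: the per-rank TensorDataset is a stand-in for whatever pre-sharded dataset each rank constructs.

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

from composer.utils import dist

# Hypothetical pre-sharded setup: each rank's dataset holds only that rank's samples.
rank = dist.get_global_rank()
per_rank_data = torch.arange(rank * 250, (rank + 1) * 250).float().unsqueeze(-1)
per_rank_dataset = TensorDataset(per_rank_data)

# No DistributedSampler is needed; RandomSampler shuffles within the local shard.
dataloader = DataLoader(per_rank_dataset, batch_size=32, sampler=RandomSampler(per_rank_dataset))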

Parameters
  • dataset (Dataset) – The dataset.

  • drop_last (bool) – Whether to drop the last batch.

  • shuffle (bool) – Whether to shuffle the dataset.

  • num_replicas (int, optional) – The number of replicas. If None, defaults to the world size.

  • rank (int, optional) – The rank. If None, defaults to the global rank.

Returns

torch.utils.data.distributed.DistributedSampler – The sampler.
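A minimal usage sketch follows. The dataset and batch size are placeholders; the get_sampler call itself matches the signature above.

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer.utils import dist

# Placeholder dataset; any map-style torch Dataset works here.
dataset = TensorDataset(torch.randn(1000, 16), torch.randint(0, 10, (1000,)))

# Each rank receives a unique shard of roughly len(dataset) / get_world_size() samples per epoch.
sampler = dist.get_sampler(dataset, shuffle=True, drop_last=True)

# Pass the sampler to the DataLoader; do not also set shuffle=True on the DataLoader.
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)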