get_sampler
- composer.utils.dist.get_sampler(dataset, *, drop_last=False, shuffle=False, num_replicas=None, rank=None)
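Constructs a DistributedSampler for a dataset.
The DistributedSampler assumes that each rank has a complete copy of the dataset. It ensures that each rank sees a unique shard for each epoch containing len(dataset) / get_world_size() samples. When the division is not exact, the count is rounded up and the shortfall is padded with repeated samples, or rounded down when drop_last=True.
Note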
If the dataset is already sharded by rank, use a SequentialSampler or RandomSampler.
- Parameters
dataset (Dataset) – The dataset.
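drop_last (bool) – Whether to drop the tail of the data so that it divides evenly across replicas. If False, samples are repeated to pad every replica to the same length.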
shuffle (bool) – Whether to shuffle the dataset.
num_replicas (int, optional) – The number of replicas. If None, defaults to the world size.
rank (int, optional) – The rank. If None, defaults to the global rank.
- Returns
torch.utils.data.distributed.DistributedSampler – The sampler.
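The sharding described above can be sketched in pure Python to make the per-rank split concrete. `shard_indices` is a hypothetical helper, not part of Composer or PyTorch. It assumes the partitioning PyTorch documents for DistributedSampler: pad by repeating indices (or drop the tail when drop_last=True), then give rank r every num_replicas-th index starting at r. It also substitutes Python's `random` module for PyTorch's generator, so shuffled orders will not match the real sampler.

```python
import math
import random
from typing import List


def shard_indices(
    dataset_len: int,
    num_replicas: int,
    rank: int,
    *,
    drop_last: bool = False,
    shuffle: bool = False,
    seed: int = 0,
    epoch: int = 0,
) -> List[int]:
    """Return the dataset indices assigned to `rank` (hypothetical sketch)."""
    if not 0 <= rank < num_replicas:
        raise ValueError(f"rank {rank} must be in [0, {num_replicas})")

    indices = list(range(dataset_len))
    if shuffle:
        # Every rank uses the same seed + epoch, so all ranks build the same
        # permutation and the strided shards below stay disjoint.
        # Assumption: Python's RNG stands in for PyTorch's generator here.
        random.Random(seed + epoch).shuffle(indices)

    if drop_last and dataset_len % num_replicas != 0:
        # Round the per-rank count down; the leftover tail is discarded.
        num_samples = dataset_len // num_replicas
    else:
        # Round the per-rank count up; the shortfall is padded below.
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    if drop_last:
        indices = indices[:total_size]
    else:
        # Pad by repeating indices from the front until the list is total_size long.
        padding = total_size - len(indices)
        while padding > 0:
            indices += indices[:padding]
            padding = total_size - len(indices)

    # Strided split: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    # 10 samples over 3 replicas: each rank gets 4, padded with indices 0 and 1.
    for r in range(3):
        print(r, shard_indices(10, num_replicas=3, rank=r))
    # 0 [0, 3, 6, 9]
    # 1 [1, 4, 7, 0]
    # 2 [2, 5, 8, 1]
```

Because every rank derives its permutation from the same seed and epoch, the real DistributedSampler only reshuffles between epochs if its `set_epoch` method is called at the start of each epoch. Applying this split to a dataset that is already sharded by rank would divide each rank's shard a second time, which is why the note above recommends a SequentialSampler or RandomSampler in that case.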