streaming.base.partition.get_partitions_relaxed(num_samples, num_canonical_nodes, num_physical_nodes, ranks_per_node, workers_per_rank, batch_size=None, drop_first=0, initial_physical_nodes=None)

Partition the given number of samples to nodes, ranks, and workers.

When partitioning over the initial number of physical nodes, either the number of canonical nodes or the number of physical nodes must be evenly divisible by the other. For partitions during resumption, the only constraint is that the global batch size, which remains constant during training, must be evenly divisible by the total number of devices, which is num_physical_nodes * ranks_per_node.
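The two constraints above can be sketched as plain divisibility checks. This is a minimal illustration of the rules as stated in this description, not the library's own validation code; the helper names `nodes_evenly_divide` and `resumption_batch_ok` and the `global_batch_size` argument are hypothetical and exist only for this example.

```python
def nodes_evenly_divide(num_canonical_nodes: int, num_physical_nodes: int) -> bool:
    """Hypothetical check: one node count must evenly divide the other."""
    return (num_canonical_nodes % num_physical_nodes == 0
            or num_physical_nodes % num_canonical_nodes == 0)


def resumption_batch_ok(global_batch_size: int, num_physical_nodes: int,
                        ranks_per_node: int) -> bool:
    """Hypothetical check: the global batch size must split evenly over all devices."""
    total_devices = num_physical_nodes * ranks_per_node
    return global_batch_size % total_devices == 0


# Initial partitioning: 8 canonical nodes over 4 physical nodes satisfies the rule,
# while 8 canonical nodes over 3 physical nodes does not.
initial_ok = nodes_evenly_divide(8, 4)
initial_bad = nodes_evenly_divide(8, 3)

# Resumption on 3 nodes with 8 ranks each (24 devices): a global batch size of 480
# divides evenly, while 512 does not.
resume_ok = resumption_batch_ok(480, 3, 8)
resume_bad = resumption_batch_ok(512, 3, 8)
```

Checking these conditions before launching or resuming a run is one way to catch an incompatible node count early.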

It is suggested to set num_canonical_nodes higher than your expected number of physical nodes. Because the same global sample order is preserved, scaling your number of nodes below that level may result in more shards being used across node boundaries.

Parameters

  • num_samples (int) – Dataset size.

  • num_canonical_nodes (int) – Number of canonical nodes.

  • num_physical_nodes (int) – Number of physical nodes.

  • ranks_per_node (int) – Number of ranks per node.

  • workers_per_rank (int) – Number of worker partitions per rank.

  • batch_size (int, optional) – Batch size of its DataLoader, which affects how the dataset is partitioned over the workers. Defaults to None.

  • drop_first (int) – Number of samples seen already, which are dropped. Defaults to 0.

  • initial_physical_nodes (int, optional) – Number of physical nodes at the start of training. Defaults to None.


Returns

NDArray[np.int64] – Partitions of shape (physical nodes, ranks per node, workers per rank, batches per worker, batch size).
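To make the five-dimensional layout concrete, the sketch below builds a stand-in NumPy array with the documented axis order and indexes into it. It does not call get_partitions_relaxed; the sizes and the sequential fill values are arbitrary assumptions chosen only to show how one worker's batches could be selected from an array of this shape.

```python
import numpy as np

# Arbitrary illustration sizes, one per documented axis.
num_physical_nodes, ranks_per_node, workers_per_rank = 2, 4, 2
batches_per_worker, batch_size = 3, 8
shape = (num_physical_nodes, ranks_per_node, workers_per_rank,
         batches_per_worker, batch_size)

# Stand-in for the returned array: sequential IDs, not a real partitioning.
partitions = np.arange(np.prod(shape), dtype=np.int64).reshape(shape)

# Select the batches for physical node 1, rank 2, worker 0.
worker_batches = partitions[1, 2, 0]
print(worker_batches.shape)  # (3, 8): batches per worker, batch size

# The first batch for that worker is a 1-D array of batch_size entries.
first_batch = worker_batches[0]
```

Indexing the first three axes with a (node, rank, worker) triple leaves a 2-D array of shape (batches per worker, batch size), which a worker could then iterate over batch by batch.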