Elastic Determinism

Deterministic and reproducible training across varying numbers of GPUs is essential for resizing workloads and debugging distributed training jobs. Streaming is built to provide elastically deterministic training and resumption. For example, a training run on 24 GPUs can be stopped, resumed on 16 GPUs, and later finished on 48 GPUs, all with the same loss curve and global batch size. Here’s an example of completely deterministic loss curves as the number of GPUs increases from 8 to 64:

[Figure: Elastic Determinism. Loss curves as the number of GPUs increases from 8 to 64.]

Combining elastic determinism with elastically sharded checkpoints, as our Composer training library does, makes distributed training easier and much more flexible. See the Composer documentation for more information on its checkpointing.


For elastic determinism, Streaming requires only that your global batch size stays constant over the course of the training job and is divisible by every GPU count you wish to run on. For example, with a global batch size of 18, you can train deterministically on 1, 2, 3, 6, 9, or 18 GPUs, but not on 7, since 18 samples cannot be evenly split among 7 GPUs.
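The divisibility rule above can be checked before launching a job. The following is a minimal sketch, not part of Streaming itself; `valid_gpu_counts` is a hypothetical helper name introduced here for illustration.

```python
def valid_gpu_counts(global_batch_size: int) -> list[int]:
    """Return every GPU count that evenly divides the global batch size.

    Hypothetical helper for illustration; not part of the Streaming API.
    """
    if global_batch_size <= 0:
        raise ValueError("global_batch_size must be a positive integer")
    return [n for n in range(1, global_batch_size + 1) if global_batch_size % n == 0]


# With a global batch size of 18, these are the GPU counts that split it evenly.
print(valid_gpu_counts(18))       # [1, 2, 3, 6, 9, 18]
print(7 in valid_gpu_counts(18))  # False
```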

Streaming uses the num_canonical_nodes parameter, which controls the number of buckets into which samples are partitioned, to ensure that the global sample order remains elastically deterministic. To retain determinism between runs, set num_canonical_nodes to the same value in every run. If not specified, num_canonical_nodes defaults to the number of physical nodes of the first run.

For example, if Run 1 was trained on 32 GPUs, where each physical node had 8 GPUs, then the total number of physical nodes is 4, and num_canonical_nodes defaults to 4. If Run 2 is required to have the same loss curve as Run 1, explicitly set num_canonical_nodes to 4, and adjust the per-device batch_size so that the global batch size stays the same:

from streaming import StreamingDataset

# Dataset for Run 1 does not specify `num_canonical_nodes`. Assuming that each physical node has 8 GPUs
# and Run 1 is launched on 32 GPUs, `num_canonical_nodes` defaults to the number of physical nodes, 4.
run_1_32_gpu_dataset = StreamingDataset(
    remote='oci://some_remote_path/dataset',
    local='tmp/local/cache',
    batch_size=4,  # Per-device batch size. Global batch size is 32 GPUs * 4 samples/GPU = 128 samples.
)

# To make Run 2 have the same loss curve as Run 1, explicitly set `num_canonical_nodes` to 4.
# Assuming Run 2 is launched on 8 GPUs, the per-device `batch_size` must increase by a factor of 4
# so that the global batch size stays the same (128 samples).
run_2_8_gpu_dataset = StreamingDataset(
    remote='oci://some_remote_path/dataset',
    local='tmp/local/cache',
    num_canonical_nodes=4,  # Explicitly set to the same value as Run 1 for deterministic training.
    batch_size=16,  # Per-device batch size. Global batch size is 8 GPUs * 16 samples/GPU = 128 samples.
)
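The per-device batch sizes used for Run 1 and Run 2 follow from dividing the fixed global batch size by the number of GPUs. As a rough sketch, the calculation could be wrapped in a small helper that fails fast when the split is uneven. The function name `per_device_batch_size` is an assumption made for this example and is not provided by Streaming.

```python
def per_device_batch_size(global_batch_size: int, num_gpus: int) -> int:
    """Compute the per-device batch size that keeps the global batch size fixed.

    Hypothetical helper for illustration; not part of the Streaming API.
    """
    if num_gpus <= 0:
        raise ValueError("num_gpus must be a positive integer")
    if global_batch_size % num_gpus != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by {num_gpus} GPUs"
        )
    return global_batch_size // num_gpus


# Run 1 (32 GPUs) and Run 2 (8 GPUs) share a global batch size of 128 samples.
for num_gpus in (32, 8):
    print(num_gpus, per_device_batch_size(128, num_gpus))
```

The returned value is what would be passed as `batch_size` to `StreamingDataset` for a run on that many GPUs.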

See this section for more information on how num_canonical_nodes is used.