# Source code for streaming.base.shuffle
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""Shuffle epochs of samples from different shards across worker partitions."""
import numpy as np
from numpy.typing import NDArray
from streaming.base.shuffle.naive import get_shuffle_naive
from streaming.base.shuffle.py1b import get_shuffle_py1b
from streaming.base.shuffle.py1br import get_shuffle_py1br
from streaming.base.shuffle.py1e import get_shuffle_py1e
from streaming.base.shuffle.py1s import get_shuffle_py1s
from streaming.base.shuffle.py2s import get_shuffle_py2s
# Registry of shuffle implementations, keyed by the algorithm name that callers pass to
# ``get_shuffle``.
algos = {
    'py1b': get_shuffle_py1b,
    'py1br': get_shuffle_py1br,
    'py1e': get_shuffle_py1e,
    'py1s': get_shuffle_py1s,
    'py2s': get_shuffle_py2s,
    'naive': get_shuffle_naive,
}
def get_shuffle(algo: str,
                shard_sizes: NDArray[np.int64],
                num_canonical_nodes: int,
                seed: int,
                epoch: int,
                block_size: int = 1 << 18) -> NDArray[np.int64]:
    """Get the shuffled global ordering of samples for an epoch.

    The assignment of shards to nodes is fixed across epochs, but each grouping of shards is
    processed concurrently in a different order by each node's workers each epoch.

    This function only dispatches: it looks up ``algo`` in the module-level ``algos`` registry
    and forwards the remaining arguments, unchanged and in order, to that implementation.

    Args:
        algo (str): Which shuffling algorithm to use. Must be a key of ``algos``.
        shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order.
        num_canonical_nodes (int): Number of canonical nodes.
        seed (int): Base random seed, which is held constant over an entire training run.
        epoch (int): Current epoch, which is added to the seed to get a different deterministic
            shuffle each epoch.
        block_size (int): Unit of shuffle. Defaults to ``1 << 18``.

    Returns:
        NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID.

    Raises:
        KeyError: If ``algo`` is not a registered shuffling algorithm.
    """
    try:
        get = algos[algo]
    except KeyError:
        # Keep the exception type callers may already catch, but name the valid choices so a
        # typo in the algorithm name is easy to diagnose.
        known = ', '.join(sorted(algos))
        raise KeyError(
            f'Unknown shuffle algorithm {algo!r}; expected one of: {known}.') from None
    return get(shard_sizes, num_canonical_nodes, seed, epoch, block_size)