# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Shuffling algorithm that shuffles intra-shard in one place.

This algorithm is roughly twice as fast as algorithm ``py2s``, and ever so slightly biased.

Bias here refers only to how samples are assigned when we split shards at canonical node
boundaries, which is non-random in this algorithm: a shard that straddles a boundary always
contributes its leading samples to the earlier node. In practice, we found this does not hurt
convergence, while making shuffling faster.
"""

import numpy as np
from numpy.typing import NDArray


def divide_spans(spans: list[tuple[int, int]], num_samples: int, num_parts: int) -> \
        tuple[list[tuple[int, int]], list[tuple[int, int]]]:
    """Divide the spans into discrete, equal sized partitions.

    Don't use ``spans`` after this, as it is modified in-place for performance reasons.

    Args:
        spans (List[Tuple[int, int]]): List of spans to partition.
        num_samples (int): Total number of samples across all spans.
        num_parts (int): Number of groupings to divide spans into.

    Returns:
        Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]: Spans and super spans.
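
    Example:
        A small hand-worked case (illustrative inputs): two shards spanning samples
        [0, 4) and [4, 10) are divided across two parts, so the second shard is
        split at the part boundary, and ``spans`` itself is modified:

        >>> spans = [(0, 4), (4, 10)]
        >>> divide_spans(spans, 10, 2)
        ([(0, 4), (4, 5), (5, 10)], [(0, 2), (2, 3)])
        >>> spans  # the split shard span was rewritten in place
        [(0, 4), (5, 10)]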
    """
    begin_part = 0
    span_index = 0
    samples_so_far = 0

    out_spans = []
    super_spans = []

    for part in range(num_parts):
        # part (canonical node) boundaries fall at num_samples * (part + 1) // num_parts,
        # so each part holds num_samples / num_parts samples rounded to a whole sample
        # (part sizes differ by at most one).
        part_end = num_samples * (part + 1) // num_parts

        # loop over spans until we've filled up our part (canonical node) completely
        while True:
            if span_index == len(spans):
                break

            # input spans are the shard spans. these can be unequally sized and may cross
            # part (canonical node) boundaries.
            span = spans[span_index]
            # spans are (begin, end excl)
            samples_this_span = span[1] - span[0]
            # check if the shard span contains more samples than the part (canonical node) can fit
            if part_end < samples_so_far + samples_this_span:
                # if there is space left in the part, split the span
                if samples_so_far < part_end:
                    split = part_end - samples_so_far
                    # create a span, filling up with as many samples as possible from shard span
                    new_span = span[0], span[0] + split
                    out_spans.append(new_span)
                    # modify the old shard span to reflect that it's been split
                    spans[span_index] = span[0] + split, span[1]
                    samples_so_far += split
                break

            out_spans.append(span)
            span_index += 1
            samples_so_far += samples_this_span

        # super spans tell us which new spans belong to each part (canonical node),
        # as a tuple of (begin span index, end span index excl)
        super_span = begin_part, len(out_spans)
        super_spans.append(super_span)
        begin_part = len(out_spans)

    return out_spans, super_spans


def get_shuffle_py1s(shard_sizes: NDArray[np.int64],
                     num_canonical_nodes: int,
                     seed: int,
                     epoch: int,
                     block_size: int = 1 << 18) -> NDArray[np.int64]:
    """Get the shuffled global ordering of samples for an epoch.

    The assignment of shards to nodes is fixed across epochs, but each grouping of shards is
    processed concurrently in a different order by each node's workers each epoch.

    Args:
        shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order.
        num_canonical_nodes (int): Number of canonical nodes.
        seed (int): Base random seed, which is held constant over an entire training run.
        epoch (int): Current epoch, which is added to the seed to get a different deterministic
            shuffle each epoch.
        block_size (int): Unit of shuffle (ignored, because we shuffle on the basis of shards).
            Defaults to ``1 << 18``.

    Returns:
        NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID.
    """
    # Create each shard's sample ID span (begin, end excl).
    spans = []
    num_samples = 0
    for shard_size in shard_sizes:
        span = num_samples, num_samples + shard_size
        spans.append(span)
        num_samples += shard_size

    # Generate the initial ordering of shards, which is fixed over an entire training run.
    run_rng = np.random.default_rng(seed)
    run_rng.shuffle(spans)

    # Break the shard spans at canonical node boundaries.
    spans, super_spans = divide_spans(spans, num_samples, num_canonical_nodes)

    # Shuffle the span ordering within each canonical node uniquely to this epoch.
    epoch_rng = np.random.default_rng(seed + epoch)
    for begin, end in super_spans:
        part = spans[begin:end]
        epoch_rng.shuffle(part)  # pyright: ignore
        spans[begin:end] = part

    # Populate the global sample ID mapping, shuffling within each span.
    ids = np.empty(num_samples, np.int64)
    offset = 0
    for begin, end in spans:
        span_size = end - begin
        ids[offset:offset + span_size] = np.arange(begin, end)
        epoch_rng.shuffle(ids[offset:offset + span_size])
        offset += span_size
    return ids
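

# --- Illustrative usage sketch (added for exposition; not part of the library) ---
# A minimal check, assuming three hypothetical shards of sizes 4, 6, and 5: the
# result is a permutation of all sample IDs, and changing the epoch (with the seed
# held fixed) almost surely changes the ordering.
if __name__ == '__main__':
    shard_sizes = np.array([4, 6, 5], np.int64)
    epoch0 = get_shuffle_py1s(shard_sizes, num_canonical_nodes=2, seed=42, epoch=0)
    epoch1 = get_shuffle_py1s(shard_sizes, num_canonical_nodes=2, seed=42, epoch=1)
    assert sorted(epoch0) == list(range(int(shard_sizes.sum())))  # covers every sample ID once
    assert not np.array_equal(epoch0, epoch1)  # different epochs reshuffle the same shards
    print('epoch 0:', epoch0)
    print('epoch 1:', epoch1)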