StreamingDataset

class streaming.StreamingDataset(*, streams=None, remote=None, local=None, split=None, download_retry=2, download_timeout=60, validate_hash=None, keep_zip=False, epoch_size=None, predownload=None, cache_limit=None, sampling_method='balanced', sampling_granularity=1, partition_algo='relaxed', num_canonical_nodes=None, batch_size=None, shuffle=False, shuffle_algo='py1e', shuffle_seed=9176, shuffle_block_size=None, batching_method='random', allow_unsafe_types=False, replication=None)

A mid-epoch-resumable streaming/caching PyTorch IterableDataset.

It features elastically deterministic shuffling, which enables fast mid-epoch resumption.

Checkpoints are represented in JSON as follows:

{
    "epoch": "int",
    "sample_in_epoch": "int",
    "shuffle_seed": "int",
    "num_canonical_nodes": "int"
}
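As an illustration, the snippet below builds a checkpoint with made-up values and round-trips it through JSON using only the standard library. It reads the "int" entries above as type placeholders, and the helper name `make_checkpoint` is hypothetical; it only mirrors the schema shown above, not the library's own serialization code.

```python
import json


def make_checkpoint(epoch: int, sample_in_epoch: int,
                    shuffle_seed: int, num_canonical_nodes: int) -> str:
    """Hypothetical helper: serialize the four checkpoint fields to JSON."""
    state = {
        "epoch": epoch,
        "sample_in_epoch": sample_in_epoch,
        "shuffle_seed": shuffle_seed,
        "num_canonical_nodes": num_canonical_nodes,
    }
    return json.dumps(state)


# Illustrative values: a run that stopped 1,024 samples into epoch 3.
encoded = make_checkpoint(epoch=3, sample_in_epoch=1024,
                          shuffle_seed=9176, num_canonical_nodes=8)
restored = json.loads(encoded)
print(restored["sample_in_epoch"])  # 1024
```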

StreamingDataset’s __init__ takes two kinds of arguments:

  • What to iterate:

    • One or more streams (you must provide either streams or remote/local):

      • streams

      • remote

      • local

    • Knobs to control streaming behavior, which, if multiple streams are provided, become defaults applied to each of them:

      • split

      • download_retry

      • download_timeout

      • validate_hash

      • keep_zip

    • Absolute dataset size, if streams were weighted relatively:

      • epoch_size

  • How to iterate:

    • Shard lifecycle:

      • predownload

      • cache_limit

    • Sampling:

      • sampling_method

      • sampling_granularity

    • Determinism:

      • partition_algo

      • num_canonical_nodes

      • batch_size

    • Shuffling:

      • shuffle

      • shuffle_algo

      • shuffle_seed

      • shuffle_block_size

    • Batching:

      • batching_method

Parameters
  • streams (Sequence[Stream], optional) – One or more streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either streams or remote/local. Defaults to None.

  • remote (str, optional) – Remote path or directory to download the dataset from. If None, its data must exist locally. StreamingDataset uses either streams or remote/local. Defaults to None.

  • local (str, optional) – Local working directory to download shards to. This is where shards are cached while they are being used. Uses a temp directory if not set. StreamingDataset uses either streams or remote/local. Defaults to None.

  • split (str, optional) – Which dataset split to use, if any. If provided, we stream from/to the split subdirs of remote and local. Defaults to None.

  • download_retry (int) – Number of download re-attempts before giving up. Defaults to 2.

  • download_timeout (float) – Number of seconds to wait for a shard to download before raising an exception. Defaults to 60.

  • validate_hash (str, optional) – Optional hash or checksum algorithm to use to validate shards. Defaults to None.

  • keep_zip (bool) – Whether to keep or delete the compressed form when decompressing downloaded shards. If False, the compressed form is kept if and only if remote is local or there is no remote. Defaults to False.

  • epoch_size (Union[int, str], optional) – Number of samples to draw per epoch, balanced across all streams. If None, takes its value from the total number of underlying samples. Provide this field if you are weighting streams relatively to target a larger or smaller epoch size. Also accepts human-readable number abbreviations (e.g., "100k", "64M", "77b"). Defaults to None.

  • predownload (int, optional) – Target number of samples to download per worker in advance of the current sample. Workers attempt to download ahead by this many samples during, but not before, training. We recommend a value greater than the per-device batch size, so that at least one per-device batch of samples is cached locally. If None, its value is set to 8 * batch_size. Defaults to None.

  • cache_limit (Union[int, str], optional) – Maximum size in bytes of this StreamingDataset’s shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to None to disable shard eviction. Supports integer bytes as well as string human-readable bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.

  • sampling_method (str) – Which sampling method to use, either balanced or fixed. Defaults to balanced.

  • sampling_granularity (int) – When picking samples for a stream’s final partial repeat, how many samples to pick from the same shard at a time (1 for evenly balanced across shards, 1000 to pick 1000 samples from the same shard at a time, etc). Defaults to 1.

  • partition_algo (str) – Which partitioning algorithm to use. Defaults to relaxed.

  • num_canonical_nodes (int, optional) –

    Canonical number of nodes for shuffling with resumption. The sample space is divided evenly according to the number of canonical nodes. The higher the value, the more independent, non-overlapping paths through the shards are taken per model replica, which increases data source diversity. If None, this is interpreted as 64 times the number of physical nodes of the initial run if shuffle_algo is py1s or py2s, and simply the number of physical nodes of the initial run otherwise. Defaults to None.

    Note

    For sequential sample ordering, set shuffle to False and num_canonical_nodes to the number of physical nodes of the initial run.

  • batch_size (int, optional) – Per-device batch size, the same as what is passed to the DataLoader. This affects how the dataset is partitioned over the workers and is necessary for deterministic resumption and optimal performance. Defaults to None.

  • shuffle (bool) – Whether to iterate over the samples in randomized order. Defaults to False.

  • shuffle_algo (str) – Which shuffling algorithm to use. Defaults to py1e.

  • shuffle_seed (int) – Seed for deterministic data shuffling. Defaults to 9176.

  • shuffle_block_size (int, optional) – Unit of shuffle. A canonical node’s samples are split into blocks of this size, and samples within each block are shuffled. If None, its value is calculated as max(4_000_000 // num_canonical_nodes, 1 << 18). Defaults to None.

  • batching_method (str) – Which batching method to use, either random, stratified, per_stream, or device_per_stream. Defaults to random.

  • allow_unsafe_types (bool) – Whether to proceed when a shard contains Pickle data, which allows arbitrary code execution during deserialization. If True, keep going; if False, raise an error. Defaults to False.

  • replication (int, optional) – Determines how many consecutive devices will receive the same samples. Useful for training with tensor or sequence parallelism, where multiple devices need to see the same partition of the dataset. Defaults to None.
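Several parameters above derive their defaults from others. The sketch below restates those documented rules as plain Python so the arithmetic is easy to check. The function names are hypothetical, and `num_physical_nodes` stands in for the number of physical nodes of the initial run; this is not the library's own code.

```python
def default_num_canonical_nodes(shuffle_algo: str, num_physical_nodes: int) -> int:
    """64x the physical node count for py1s/py2s, otherwise the physical node count."""
    if shuffle_algo in ("py1s", "py2s"):
        return 64 * num_physical_nodes
    return num_physical_nodes


def default_shuffle_block_size(num_canonical_nodes: int) -> int:
    """max(4_000_000 // num_canonical_nodes, 1 << 18), per the shuffle_block_size entry."""
    return max(4_000_000 // num_canonical_nodes, 1 << 18)


def default_predownload(batch_size: int) -> int:
    """8 * batch_size, per the predownload entry."""
    return 8 * batch_size


# Example: 2 physical nodes, the default py1e shuffle, and batch_size=32.
ncn = default_num_canonical_nodes("py1e", 2)
block = default_shuffle_block_size(ncn)
prefetch = default_predownload(32)
print(ncn, block, prefetch)  # 2 2000000 256
```

With py1s on the same two nodes, the canonical node count becomes 128, and the block size bottoms out at 1 << 18 (262,144) because 4,000,000 // 128 is only 31,250.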

property cache_usage

Get the cache usage.

Returns

int – Cache usage in bytes.

evict_coldest_shard()

Evict the coldest (i.e., least recently accessed) shard.

This method is multithread/multiprocess-safe.

evict_shard(shard_id)

Evict the given shard.

This method is multithread/multiprocess-safe.

Parameters

shard_id (int) – Shard to evict.
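The eviction policy can be pictured as a least-recently-used cache keyed by shard ID with a byte budget, in the spirit of cache_limit. The sketch below uses a hypothetical `ShardCache` class; it is a single-process illustration that is not multithread/multiprocess-safe like the real methods, and it ignores how shards are actually stored on disk.

```python
from collections import OrderedDict


class ShardCache:
    """Hypothetical LRU shard cache with a byte budget (single-process sketch)."""

    def __init__(self, cache_limit: int) -> None:
        self.cache_limit = cache_limit
        # shard_id -> size in bytes; order runs from coldest to hottest.
        self.shards = OrderedDict()

    @property
    def cache_usage(self) -> int:
        return sum(self.shards.values())

    def touch(self, shard_id: int) -> None:
        """Mark a shard as the most recently accessed."""
        self.shards.move_to_end(shard_id)

    def evict_shard(self, shard_id: int) -> None:
        """Evict the given shard."""
        del self.shards[shard_id]

    def evict_coldest_shard(self) -> int:
        """Evict the least recently accessed shard and return its ID."""
        shard_id, _ = self.shards.popitem(last=False)
        return shard_id

    def add(self, shard_id: int, size: int) -> list:
        """Evict cold shards until the new shard fits, then record it."""
        evicted = []
        while self.shards and self.cache_usage + size > self.cache_limit:
            evicted.append(self.evict_coldest_shard())
        self.shards[shard_id] = size
        return evicted


cache = ShardCache(cache_limit=100)
cache.add(0, 40)
cache.add(1, 40)
cache.touch(0)            # shard 1 is now the coldest
print(cache.add(2, 40))   # [1]
print(cache.cache_usage)  # 80
```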

get_item(sample_id, retry=7)

Get sample by global index, blocking to download its shard if not present.

Parameters
  • sample_id (int) – Sample index.

  • retry (int) – Maximum number of times to download its shard before giving up. In the edge case of a shard being evicted before sample access, you will have to re-download it. Defaults to 7.

Returns

Dict[str, Any] – Mapping of column name to column data.
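The retry argument guards against a race in which a shard is evicted after it has been downloaded but before the sample is read. The loop below is a hypothetical sketch of that idea: `prepare` and `read` are stand-in callables, and `ShardEvicted` is an invented exception type, not part of the library.

```python
from typing import Any, Callable, Dict


class ShardEvicted(Exception):
    """Hypothetical: the shard vanished from the cache before it was read."""


def get_item_with_retry(sample_id: int,
                        prepare: Callable[[int], None],
                        read: Callable[[int], Dict[str, Any]],
                        retry: int = 7) -> Dict[str, Any]:
    """Ensure the sample's shard is present, then read it; re-download on eviction."""
    for _ in range(retry):
        prepare(sample_id)  # block until the shard is downloaded
        try:
            return read(sample_id)  # may fail if the shard was evicted meanwhile
        except ShardEvicted:
            continue  # evicted in between: download it again
    raise RuntimeError(f"gave up on sample {sample_id} after {retry} attempts")


# A fake reader whose shard is evicted twice before the read succeeds.
attempts = {"n": 0}


def fake_read(sample_id: int) -> Dict[str, Any]:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ShardEvicted()
    return {"id": sample_id}


item = get_item_with_retry(5, prepare=lambda i: None, read=fake_read)
print(item)  # {'id': 5}
```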

load_state_dict(obj)

Load a dict containing training state (called from non-worker process).

This is called on each copy of the dataset when resuming.

We save the state to shared memory for workers to pick up the next time __iter__ is called. Shared memory is used because changes to this copy of the dataset would not be picked up by persistent workers.

Parameters

obj (Dict[str, Any]) – The state.
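Because persistent DataLoader workers hold their own copies of the dataset, the state has to travel through a channel every process can see. The sketch below shows the general idea with the standard library's multiprocessing.shared_memory: the parent writes a length-prefixed JSON blob, and a reader attaches by name. The helper names, block size, and layout are hypothetical; the library's own shared-memory layout is not described here.

```python
import json
import os
import struct
from multiprocessing import shared_memory

HEADER = struct.Struct("<I")  # 4-byte little-endian length prefix


def save_state(name: str, state: dict, size: int = 4096) -> shared_memory.SharedMemory:
    """Parent side (hypothetical): write the state into a new named block."""
    payload = json.dumps(state).encode("utf-8")
    if HEADER.size + len(payload) > size:
        raise ValueError("state does not fit in the shared-memory block")
    shm = shared_memory.SharedMemory(name=name, create=True, size=size)
    HEADER.pack_into(shm.buf, 0, len(payload))
    shm.buf[HEADER.size:HEADER.size + len(payload)] = payload
    return shm  # the caller keeps this handle and unlinks it when done


def load_state(name: str) -> dict:
    """Worker side (hypothetical): attach by name and decode the state."""
    shm = shared_memory.SharedMemory(name=name)
    try:
        (length,) = HEADER.unpack_from(shm.buf, 0)
        return json.loads(bytes(shm.buf[HEADER.size:HEADER.size + length]))
    finally:
        shm.close()  # detach; the block stays alive for other readers


block_name = f"demo_state_{os.getpid()}"
owner = save_state(block_name, {"epoch": 1, "sample_in_epoch": 512})
restored = load_state(block_name)  # in practice this runs inside a worker process
owner.close()
owner.unlink()  # free the block once no process needs it
print(restored)  # {'epoch': 1, 'sample_in_epoch': 512}
```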

property next_epoch

Get the next epoch.

Returns

int – Next epoch.

on_exception(future)

Raise an exception to the caller if an exception was generated by a thread.

Also, set the thread event to let the other threads know about the exception.

Parameters

future (Future) – The status of the task.

Raises

Exception – re-raises the exception.
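A common pattern behind a callback like this is to attach it to each concurrent.futures.Future, record the failure, and set a shared threading.Event so sibling threads can stop early. The standalone sketch below is hypothetical (`ErrorMonitor` and `task` are invented names). One deliberate difference: an exception raised inside a done-callback is only logged by concurrent.futures, so this sketch stores the error and lets the caller re-raise it.

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional


class ErrorMonitor:
    """Hypothetical helper: remember the first thread failure and signal the others."""

    def __init__(self) -> None:
        self.event = threading.Event()
        self.lock = threading.Lock()
        self.error: Optional[BaseException] = None

    def on_exception(self, future: Future) -> None:
        """Done-callback: record the future's exception, if any, and set the event."""
        exc = future.exception()
        if exc is None:
            return
        with self.lock:
            if self.error is None:
                self.error = exc
        self.event.set()  # let the other threads know an exception occurred

    def check(self) -> None:
        """Call from the main thread to re-raise a recorded failure."""
        if self.error is not None:
            raise self.error


def task(monitor: ErrorMonitor, fail: bool) -> str:
    if monitor.event.is_set():
        return "skipped"  # another thread already failed
    if fail:
        raise ValueError("shard download failed")
    return "ok"


monitor = ErrorMonitor()
with ThreadPoolExecutor(max_workers=2) as pool:
    for fail in (False, True):
        pool.submit(task, monitor, fail).add_done_callback(monitor.on_exception)

try:
    monitor.check()
except ValueError as err:
    print(f"worker failed: {err}")  # worker failed: shard download failed
```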

prepare_shard(shard_id, blocking=True)

Download a shard, either waiting or skipping if in progress by another worker.

This method is multithread/multiprocess-safe.

If cache limit is enabled, this method may delete one or more other shards to make space for this download.

Parameters
  • shard_id (int) – Shard to download.

  • blocking (bool) – Whether to wait for, or skip, a shard that is currently being downloaded by someone else. Defaults to True.
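The wait-or-skip behavior can be sketched with a per-shard event table guarded by a lock. In this hypothetical thread-only version (`ShardPreparer` is an invented name), the first caller for a shard becomes its owner and downloads it; later callers either wait on the shard's event (blocking=True) or return immediately (blocking=False). The real method also coordinates across processes, which this sketch does not attempt.

```python
import threading
import time
from typing import Callable, Dict


class ShardPreparer:
    """Hypothetical thread-only sketch: download each shard once; others wait or skip."""

    def __init__(self, download: Callable[[int], None]) -> None:
        self.download = download
        self.lock = threading.Lock()
        # shard_id -> event that is set once that shard's download has finished
        self.events: Dict[int, threading.Event] = {}

    def prepare_shard(self, shard_id: int, blocking: bool = True) -> None:
        with self.lock:
            event = self.events.get(shard_id)
            is_owner = event is None
            if is_owner:
                event = threading.Event()
                self.events[shard_id] = event
        if is_owner:
            try:
                self.download(shard_id)
            finally:
                event.set()  # wake waiters (this sketch ignores download failures)
        elif blocking:
            event.wait()  # another thread is downloading it, or already has
        # A non-blocking caller that is not the owner returns immediately.


downloads = []


def slow_download(shard_id: int) -> None:
    time.sleep(0.05)  # pretend to fetch the shard
    downloads.append(shard_id)


preparer = ShardPreparer(slow_download)
threads = [threading.Thread(target=preparer.prepare_shard, args=(3,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(downloads)  # [3]
```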

resample_streams(epoch, stream_id=None)

Perform the up/down-sampling needed to generate the weighted epoch.

Parameters
  • epoch (int) – What epoch this is for. Used in seeding the sampling RNG.

  • stream_id (Optional[int]) – Which stream to resample. If None, resample all streams. Defaults to None.

Returns

Tuple[NDArray[np.int64], NDArray[np.int64]] – Sampled shard sizes and sample mapping.
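For a single stream, up- or down-sampling to a target count can be pictured as whole repeats of the stream plus a seeded random subset for the final partial repeat, picked in runs as the sampling_granularity entry describes. The sketch below is a hypothetical pure-Python illustration (`resample_stream` is an invented name): it returns sample indices only, treats consecutive indices as a stand-in for "samples from the same shard", and does not reproduce the library's RNG, shard-size bookkeeping, or the fixed sampling method.

```python
import random
from typing import List


def resample_stream(num_samples: int, target: int, seed: int, epoch: int,
                    granularity: int = 1) -> List[int]:
    """Hypothetical: pick `target` sample indices from a stream of `num_samples`."""
    full_repeats, remainder = divmod(target, num_samples)
    # Every sample appears once per full repeat (upsampling when full_repeats >= 1).
    chosen = list(range(num_samples)) * full_repeats
    # Seed from the shuffle seed and the epoch so each epoch is reproducible.
    rng = random.Random(f"{seed}:{epoch}")
    # Fill the final partial repeat in runs of `granularity` consecutive indices.
    starts = list(range(0, num_samples, granularity))
    rng.shuffle(starts)
    partial: List[int] = []
    for start in starts:
        if len(partial) >= remainder:
            break
        partial.extend(range(start, min(start + granularity, num_samples)))
    chosen.extend(partial[:remainder])
    return chosen


up = resample_stream(num_samples=10, target=25, seed=9176, epoch=0)
down = resample_stream(num_samples=10, target=4, seed=9176, epoch=0, granularity=2)
print(len(up), len(down))  # 25 4
```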

property size

Get the size of the dataset in samples.

Returns

int – Number of samples.

state_dict(num_samples, from_beginning)

Get a dict containing training state (called from non-worker process).

This is called on rank zero.

Our stock StreamingDataLoader counts samples from the start of the current training run (from_beginning=False). However, if you always count samples from the start of the epoch, set from_beginning=True.

Parameters
  • num_samples (int) – The number of samples processed so far in the current epoch.

  • from_beginning (bool) – Whether num_samples counts from the start of this epoch (True), or from the start of just this potentially resumed training run within this epoch (False).

Returns

Dict[str, Any] – The state.
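The from_beginning flag only changes how num_samples is turned into a position within the epoch. Below is a hypothetical sketch of that bookkeeping, assuming the current run resumed resume_offset samples into the epoch; `build_state` and `resume_offset` are illustrative names, not the library's internals.

```python
from typing import Any, Dict


def build_state(epoch: int, num_samples: int, from_beginning: bool,
                resume_offset: int, shuffle_seed: int,
                num_canonical_nodes: int) -> Dict[str, Any]:
    """Hypothetical: compute sample_in_epoch and package the checkpoint fields."""
    if from_beginning:
        # num_samples already counts from the start of the epoch.
        sample_in_epoch = num_samples
    else:
        # num_samples counts from the start of this (possibly resumed) run,
        # so add back the samples consumed before the resumption.
        sample_in_epoch = resume_offset + num_samples
    return {
        "epoch": epoch,
        "sample_in_epoch": sample_in_epoch,
        "shuffle_seed": shuffle_seed,
        "num_canonical_nodes": num_canonical_nodes,
    }


resumed = build_state(epoch=2, num_samples=200, from_beginning=False,
                      resume_offset=1000, shuffle_seed=9176, num_canonical_nodes=8)
fresh = build_state(epoch=2, num_samples=200, from_beginning=True,
                    resume_offset=1000, shuffle_seed=9176, num_canonical_nodes=8)
print(resumed["sample_in_epoch"], fresh["sample_in_epoch"])  # 1200 200
```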