StreamingDataset#
- class streaming.StreamingDataset(*, streams=None, remote=None, local=None, split=None, download_retry=2, download_timeout=60.0, validate_hash=None, keep_zip=False, epoch_size=None, predownload=None, cache_limit=None, sampling_method='balanced', sampling_granularity=1, partition_algo='relaxed', num_canonical_nodes=None, batch_size=None, shuffle=False, shuffle_algo='py1e', shuffle_seed=9176, shuffle_block_size=None, batching_method='random', allow_unsafe_types=False, replication=None)[source]#
A mid-epoch-resumable streaming/caching PyTorch IterableDataset.
Features elastically deterministic shuffling, which enables fast mid-epoch resumption.
Checkpoints are represented in JSON as follows:
{ "epoch" :"int", "sample_in_epoch": "int", "shuffle_seed": "int", "num_canonical_nodes": "int" }
StreamingDataset init takes two kinds of arguments:
- What to iterate:
  - One or more streams (you must provide either streams or remote/local):
    - streams
    - remote
    - local
  - Knobs to control streaming behavior, which, if multiple streams are provided, become defaults applied to each of them:
    - split
    - download_retry
    - download_timeout
    - validate_hash
    - keep_zip
  - Absolute dataset size, if streams were weighted relatively:
    - epoch_size
- How to iterate:
  - Shard lifecycle:
    - predownload
    - cache_limit
  - Sampling:
    - sampling_method
    - sampling_granularity
  - Determinism:
    - partition_algo
    - num_canonical_nodes
    - batch_size
  - Shuffling:
    - shuffle
    - shuffle_algo
    - shuffle_seed
    - shuffle_block_size
  - Batching:
    - batching_method
- Parameters
  - streams (Sequence[Stream], optional) – One or more streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either streams or remote/local. Defaults to None.
  - remote (str, optional) – Remote path or directory to download the dataset from. If None, its data must exist locally. StreamingDataset uses either streams or remote/local. Defaults to None.
  - local (str, optional) – Local working directory to download shards to. This is where shards are cached while they are being used. Uses a temp directory if not set. StreamingDataset uses either streams or remote/local. Defaults to None.
  - split (str, optional) – Which dataset split to use, if any. If provided, we stream from/to the split subdirs of remote and local. Defaults to None.
  - download_retry (int) – Number of download re-attempts before giving up. Defaults to 2.
  - download_timeout (float) – Number of seconds to wait for a shard to download before raising an exception. Defaults to 60.
  - validate_hash (str, optional) – Optional hash or checksum algorithm to use to validate shards. Defaults to None.
  - keep_zip (bool) – Whether to keep or delete the compressed form when decompressing downloaded shards. If False, keep iff remote is local or there is no remote. Defaults to False.
  - epoch_size (Union[int, str], optional) – Number of samples to draw per epoch, balanced across all streams. If None, takes its value from the total number of underlying samples. Provide this field if you are weighting streams relatively to target a larger or smaller epoch size. Can also take human-readable number abbreviations (e.g., "100k", "64M", "77b"). Defaults to None.
  - predownload (int, optional) – Target number of samples to download per worker in advance of the current sample. Workers will attempt to download ahead by this many samples during, but not before, training. We recommend a value greater than the per-device batch size, to ensure that at least one per-device batch of samples is cached locally. If None, its value is set to 8 * batch_size. Defaults to None.
  - cache_limit (Union[int, str], optional) – Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to None to disable shard eviction. Supports integer bytes as well as human-readable byte strings (e.g., 100b, 64kb, 77mb, and so on). Defaults to None.
  - sampling_method (str) – Which sampling method to use, either balanced or fixed. Defaults to balanced.
  - sampling_granularity (int) – When picking samples for a stream's final partial repeat, how many samples to pick from the same shard at a time (1 for evenly balanced across shards, 1000 to pick 1000 samples from the same shard at a time, etc.). Defaults to 1.
  - partition_algo (str) – Which partitioning algorithm to use. Defaults to relaxed.
  - num_canonical_nodes (int, optional) – Canonical number of nodes for shuffling with resumption. The sample space is divided evenly according to the number of canonical nodes. The higher the value, the more independent non-overlapping paths the StreamingDataset replicas take through the shards per model replica (increasing data source diversity). If None, this is interpreted as 64 times the number of physical nodes of the initial run if shuffle_algo is py1s or py2s, and simply the number of physical nodes of the initial run otherwise. Defaults to None.
    Note: For sequential sample ordering, set shuffle to False and num_canonical_nodes to the number of physical nodes of the initial run.
  - batch_size (int, optional) – Per-device batch size, the same as what is passed to the DataLoader. This affects how the dataset is partitioned over the workers and is necessary for deterministic resumption and optimal performance. Defaults to None.
  - shuffle (bool) – Whether to iterate over the samples in randomized order. Defaults to False.
  - shuffle_algo (str) – Which shuffling algorithm to use. Defaults to py1e.
  - shuffle_seed (int) – Seed for deterministic data shuffling. Defaults to 9176.
  - shuffle_block_size (int, optional) – Unit of shuffle. A canonical node's samples are split into blocks of this size, and samples within each block are shuffled. If None, its value is calculated as max(4_000_000 // num_canonical_nodes, 1 << 18). Defaults to None.
  - batching_method (str) – Which batching method to use, either random, stratified, per_stream, or device_per_stream. Defaults to random.
  - allow_unsafe_types (bool) – If a shard contains Pickle, which allows arbitrary code execution during deserialization, whether to keep going (True) or raise an error (False). Defaults to False.
  - replication (int, optional) – Determines how many consecutive devices will receive the same samples. Useful for training with tensor or sequence parallelism, where multiple devices need to see the same partition of the dataset. Defaults to None.
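Two of the None defaults above resolve to computed values at init time. A minimal sketch of those formulas (these helper functions are illustrative, not part of the streaming API):

```python
def default_predownload(batch_size: int) -> int:
    # If predownload is None, it falls back to 8 * batch_size.
    return 8 * batch_size

def default_shuffle_block_size(num_canonical_nodes: int) -> int:
    # If shuffle_block_size is None, it is computed as
    # max(4_000_000 // num_canonical_nodes, 1 << 18).
    return max(4_000_000 // num_canonical_nodes, 1 << 18)

print(default_predownload(32))          # 256
print(default_shuffle_block_size(8))    # 500000
print(default_shuffle_block_size(128))  # 262144 (the 1 << 18 floor)
```

Note that larger num_canonical_nodes shrinks the per-node shuffle block until the 1 << 18 floor kicks in.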
- property cache_usage#
Get the cache usage.
- Returns
int – Cache usage in bytes.
- evict_coldest_shard()[source]#
Evict the coldest (i.e., least recently accessed) shard.
This method is multithread/multiprocess-safe.
- evict_shard(shard_id)[source]#
Evict the given shard.
This method is multithread/multiprocess-safe.
- Parameters
shard_id (int) – Shard to evict.
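The cache_limit and eviction behavior described here amounts to a least-recently-used policy over resident shards. A toy sketch of that policy, with hypothetical shard ids and sizes:

```python
from collections import OrderedDict

class ShardCacheSketch:
    """Toy LRU shard cache; evicts coldest shards to stay under a byte limit."""

    def __init__(self, cache_limit: int):
        self.cache_limit = cache_limit
        self.shards = OrderedDict()  # shard_id -> size in bytes, coldest first
        self.usage = 0

    def touch(self, shard_id: int) -> None:
        # Mark a shard as most recently used.
        self.shards.move_to_end(shard_id)

    def add(self, shard_id: int, size: int) -> None:
        # Evict coldest shards until the new shard fits under the limit.
        while self.shards and self.usage + size > self.cache_limit:
            _, freed = self.shards.popitem(last=False)
            self.usage -= freed
        self.shards[shard_id] = size
        self.usage += size

cache = ShardCacheSketch(cache_limit=100)
cache.add(0, 40)
cache.add(1, 40)
cache.touch(0)    # shard 0 becomes the warmest
cache.add(2, 40)  # needs space: shard 1 (now coldest) is evicted
print(sorted(cache.shards))  # [0, 2]
```

The real dataset performs the same bookkeeping across processes via shared memory; this sketch only shows the eviction order.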
- get_item(sample_id, retry=7)[source]#
Get sample by global index, blocking to download its shard if not present.
- Parameters
  - sample_id (int) – Sample index.
  - retry (int) – Maximum number of re-attempts if the shard is evicted before the sample can be read. Defaults to 7.
- Returns
Dict[str, Any] – Mapping of column name to column data.
- load_state_dict(obj)[source]#
Load a dict containing training state (called from non-worker process).
This is called on each copy of the dataset when resuming.
We just save the state to shared memory for workers to pick up when __iter__ is next called. We use shm because changes to this copy of the dataset wouldn't be picked up by persistent workers.
- Parameters
obj (Dict[str, Any]) – The state.
- property next_epoch#
Get the next epoch.
- Returns
int – Next epoch.
- on_exception(future)[source]#
Raise an exception to the caller if an exception was generated by a thread.
Also, set the thread event to let the other threads know about the exception.
- Parameters
future (Future) β The status of the task.
- Raises
Exception – re-raises the exception.
- prepare_shard(shard_id, blocking=True)[source]#
Download a shard, either waiting or skipping if the download is already in progress by another worker.
This method is multithread/multiprocess-safe.
If a cache limit is enabled, this method may delete one or more other shards to make space for this download.
- Parameters
  - shard_id (int) – Shard to download.
  - blocking (bool) – Whether to wait, rather than skip, if another worker is already downloading the shard. Defaults to True.
- resample_streams(epoch, stream_id=None)[source]#
Perform the up/down-sampling needed to generate the weighted epoch.
- Parameters
  - epoch (int) – Which epoch this is for (affects sampling).
  - stream_id (int, optional) – Which stream to resample, if resampling a single stream. Defaults to None.
- property size#
Get the size of the dataset in samples.
- Returns
int – Number of samples.
- state_dict(num_samples, from_beginning)[source]#
Get a dict containing training state (called from non-worker process).
This is called on rank zero.
Our stock StreamingDataLoader counts samples from the start of training (from_beginning=False). However, if you are always counting from the start of the epoch, set from_beginning=True.
- Parameters
  - num_samples (int) – Number of samples processed so far.
  - from_beginning (bool) – Whether num_samples counts from the start of the epoch (True) or from the start of training (False).
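One way to read the from_beginning flag: the saved checkpoint records sample_in_epoch, and the flag determines whether num_samples needs the resumption offset added back. A hedged sketch of that arithmetic (the helper and resume_offset are illustrative, not the library's internals):

```python
def sample_in_epoch(num_samples: int, from_beginning: bool, resume_offset: int) -> int:
    # from_beginning=True: num_samples already counts from the start of the epoch.
    # from_beginning=False (StreamingDataLoader's behavior): num_samples counts
    # from the start of training, so the sample offset at resumption is added back.
    if from_beginning:
        return num_samples
    return resume_offset + num_samples

print(sample_in_epoch(1000, True, 500))   # 1000
print(sample_in_epoch(1000, False, 500))  # 1500
```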