Mixing Datasets#

Training a model often requires combining data from multiple different sources. Streaming makes combining these data sources, or streams, easy and configurable. See the main concepts page for a high-level view of distributed training with multiple streams.

Using Streams#

A stream is a data source, as a collection of shard files (or set of subdirectories containing shard files). Shard files can optionally be compressed. Streams are represented by the streaming.Stream object. Similar to streaming.StreamingDataset itself, a Stream object can take in remote and local paths – see here for an example.

It is possible, though not recommended, for streams to have different schemas.

Configuring the data mix#

The proportion, repeat, or choose arguments to Stream are used to configure different dataset mixing schemes. Only one of them may be set at a time, and all streams must use the same mixing scheme (e.g., Stream A with proportion and Stream B with choose are incompatible).

  • proportion: Specifies how to sample this Stream relative to other Streams.

  • repeat: Specifies the degree to which a Stream is upsampled or downsampled.

  • choose: Specifies the number of samples to choose from a Stream.

Let’s look at some examples of dataset mixing in action.

Using proportion for relative weighting#

As an example, let’s say we have Stream A with 100 samples and Stream B with 200 samples. The epoch_size, if not set, will default to the total number of unique samples – in this case, 300. To configure our training dataset to be 25% from Stream A and 75% from Stream B, we simply set proportion to those values:

stream_A = Stream(
    remote = 's3://stream_A_remote',
    local = '/tmp/stream_A',
    proportion = 0.25,
stream_B = Stream(
    remote = 's3://stream_B_remote',
    local = '/tmp/stream_B',
    proportion = 0.75,
dataset = StreamingDataset(
    streams = [stream_A, stream_B],

Since epoch_size has not been specified, the epoch will be 300 samples long, of which 75 samples will come from Stream A, and 225 from Stream B. Equivalently, we could have also set proportion to 2 for Stream A and 6 for Stream B for the same weighting – StreamingDataset will normalize the proportion weights.

If epoch_size is explicitly set, then proportions will apply to that value instead. For example, if epoch_size was passed as 400 to StreamingDataset, as below, and proportions stayed the same, then in each epoch, 100 samples would be from Stream A and 300 would be from Stream B.

dataset = StreamingDataset(
    epoch_size = 400,
    streams = [stream_A, stream_B],    # With proportions A: 0.25 and B: 0.75.

For multi-epoch training, to control how samples are chosen between epochs, see the inter-epoch sampling section.

Using repeat for absolute weighting#

It can be useful to specify how many times to upsample or downsample a Stream – the repeat argument fulfills this use case. For example, to see every sample from Stream A 3 times per epoch, simply set repeat to 3:

stream_A = Stream(
    remote = 's3://stream_A_remote',
    local = '/tmp/stream_A',
    repeat = 3,

To downsample a stream, meaning that only a fraction of the total samples from that stream are seen every epoch, set repeat to less than 1. For example, to see only a quarter of the samples from Stream A per epoch, set repeat to 0.25.

Using choose for absolute weighting#

Specifying the absolute number of samples to choose from a Stream can also be useful when mixing datasets. Use the choose argument to indicate the number of samples to take from a stream per epoch. For example, to see exactly 250 samples from Stream A per epoch, set choose to 250:

stream_A = Stream(
    remote = 's3://stream_A_remote',
    local = '/tmp/stream_A',
    choose = 250,

Batching Methods#

Controlling how a global batch is constructed is a requirement for some training runs. StreamingDataset’s batching_method argument takes in three different options to configure the composition of each global batch:

  • 'random': (default) Global batches respect dataset mixing in expectation. Stream proportions can vary somewhat between batches.

  • 'stratified': Every global batch respects dataset mixing exactly. Can help mitigate loss spikes and divergence by making sure stream proportions hold for every batch.

  • 'per_stream': Each global batch contains samples from only one stream at a time. Particularly useful when your streams contain data of different tensor shapes/sizes, so that each batch can contain samples of the same shape/size.

  • 'device_per_stream': Each device batch contains samples from only one stream at a time. Like per_stream, this is useful when mixing data from streams with different tensor shapes/sizes. Crucially, by mixing at the device batch level, each global batch can still contain samples from multiple streams.

As an example, suppose we have Stream A (green) and Stream B (blue), each making up half of our total dataset. Applying each of the batching methods would make global batches look like this:

Batching Methods

Each bar represents a single global batch. We see that random batching can have some variance in stream composition, while stratified batching keeps composition exact, and per_stream batching constructs each batch with a single stream.