StreamingDatasetWriter is used to convert a list of samples into binary .mds files that can be read as a StreamingDataset.




class composer.datasets.streaming.writer.StreamingDatasetWriter(dirname, fields, shard_size_limit=16777216)

Used for writing a StreamingDataset from a list of samples.

Each sample is expected to be a Dict[str, bytes] that maps field names to raw bytes.

For each sample, StreamingDatasetWriter writes out only the values of a fixed subset of keys (the fields), which every sample in the dataset must share. Any other keys are ignored.
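The field-selection rule can be illustrated with a minimal sketch. It assumes a hypothetical helper `select_fields`, which is not part of Composer's API. The helper keeps only the configured fields, in the order given by `fields`, and drops extra keys such as `"unused"`. Raising on a missing field or a non-bytes value is an assumption of this sketch, not documented Composer behavior.

```python
from typing import Dict, List


def select_fields(sample: Dict[str, bytes], fields: List[str]) -> Dict[str, bytes]:
    """Hypothetical helper: keep only the shared fields of one sample.

    Illustration of the documented rule (keys outside `fields` are ignored),
    not Composer's implementation. The error checks are assumptions.
    """
    missing = [name for name in fields if name not in sample]
    if missing:
        raise KeyError(f"sample is missing required fields: {missing}")
    for name in fields:
        if not isinstance(sample[name], bytes):
            raise TypeError(f"field {name!r} must be bytes, got {type(sample[name]).__name__}")
    # Preserve the order of `fields`; keys such as "unused" are dropped.
    return {name: sample[name] for name in fields}


sample = {"uid": b"000007", "data": (21).to_bytes(4, "big"), "unused": b"blah"}
print(select_fields(sample, ["uid", "data"]))
```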

StreamingDatasetWriter automatically splits the dataset into shards so that each shard is at most shard_size_limit bytes.
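A size-capped sharding rule like this is commonly implemented as a greedy pass over the encoded samples. The sketch below uses a hypothetical `plan_shards` function that works on sample sizes only. It ignores any per-shard header or encoding overhead the real .mds format may add. Giving a single oversized sample its own shard is an assumption of the sketch, not documented Composer behavior.

```python
from typing import List


def plan_shards(sample_sizes: List[int], shard_size_limit: int) -> List[List[int]]:
    """Hypothetical greedy planner: group sample indices into shards.

    A new shard starts whenever adding the next sample would push the current
    shard past `shard_size_limit` bytes. A sample larger than the limit still
    gets a shard of its own (an assumption of this sketch).
    """
    if shard_size_limit <= 0:
        raise ValueError("shard_size_limit must be positive")
    shards: List[List[int]] = []
    current: List[int] = []
    current_size = 0
    for index, size in enumerate(sample_sizes):
        if current and current_size + size > shard_size_limit:
            shards.append(current)  # close the full shard
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        shards.append(current)  # leftover samples form the last shard
    return shards


# Ten 10-byte samples with a 32-byte limit: at most three samples fit per shard.
print(plan_shards([10] * 10, shard_size_limit=32))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```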

To write the dataset:
>>> from composer.datasets.streaming import StreamingDatasetWriter
>>> samples = [
...     {
...         "uid": f"{ix:06}".encode("utf-8"),
...         "data": (3 * ix).to_bytes(4, "big"),
...         "unused": "blah".encode("utf-8"),
...     }
...     for ix in range(100)
... ]
>>> dirname = "remote"
>>> fields = ["uid", "data"]
>>> with StreamingDatasetWriter(dirname=dirname, fields=fields) as writer:
...     writer.write_samples(samples=samples)

To read the dataset:
>>> from composer.datasets.streaming import StreamingDataset
>>> remote = "remote"
>>> local = "local"
>>> decoders = {
...     "uid": lambda uid_bytes: uid_bytes.decode("utf-8"),
...     "data": lambda data_bytes: int.from_bytes(data_bytes, "big"),
... }
>>> dataset = StreamingDataset(remote=remote, local=local, shuffle=False, decoders=decoders)
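Each field is stored as raw bytes, so the decoders passed to StreamingDataset must undo whatever encoding was chosen when writing. The self-contained check below does not require Composer. It pairs the decoders from the read example with hypothetical `encoders` that reproduce how the write example builds each field, and confirms that every sample round-trips.

```python
from typing import Callable, Dict

# Hypothetical encoders reproducing how the write example builds each field;
# the writer itself only ever sees the resulting bytes.
encoders: Dict[str, Callable[[int], bytes]] = {
    "uid": lambda ix: f"{ix:06}".encode("utf-8"),
    "data": lambda ix: (3 * ix).to_bytes(4, "big"),
}

# The decoders from the read example.
decoders: Dict[str, Callable[[bytes], object]] = {
    "uid": lambda uid_bytes: uid_bytes.decode("utf-8"),
    "data": lambda data_bytes: int.from_bytes(data_bytes, "big"),
}


def round_trip(ix: int) -> Dict[str, object]:
    """Encode sample `ix` as the writer would receive it, then decode it back."""
    raw = {name: encode(ix) for name, encode in encoders.items()}
    return {name: decoders[name](value) for name, value in raw.items()}


# Every sample written by the example decodes back to its original values.
assert all(round_trip(ix) == {"uid": f"{ix:06}", "data": 3 * ix} for ix in range(100))
print(round_trip(7))  # {'uid': '000007', 'data': 21}
```

The 4-byte big-endian encoding only holds values below 2**32. `int.to_bytes` raises OverflowError for anything larger, so the byte width must be chosen to fit the data.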

Parameters:
  • dirname (str) – Directory to write shards to.

  • fields (List[str]) – The fields to save for each sample.

  • shard_size_limit (int) – Maximum shard size in bytes. Default: 1 << 24 (16 MiB).
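The default 1 << 24 equals the 16777216 shown in the class signature. A quick sanity check, together with a hypothetical `min_shard_count` helper, shows what the limit implies. The helper computes a lower bound on the number of shards. It is only meaningful when every individual sample fits within the limit, and it ignores any per-shard overhead.

```python
DEFAULT_SHARD_SIZE_LIMIT = 1 << 24  # the documented default

# 1 << 24 shifts 1 left by 24 bits: 2**24 bytes, i.e. 16 MiB.
assert DEFAULT_SHARD_SIZE_LIMIT == 2**24 == 16 * 1024 * 1024 == 16777216


def min_shard_count(total_bytes: int, shard_size_limit: int = DEFAULT_SHARD_SIZE_LIMIT) -> int:
    """Hypothetical helper: fewest shards that can hold `total_bytes` of sample data.

    Uses integer ceiling division. Only a lower bound, valid when every
    individual sample fits within `shard_size_limit`.
    """
    return -(-total_bytes // shard_size_limit)


print(min_shard_count(1 << 30))  # 1 GiB of samples -> 64
```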


Complete writing the dataset by flushing the remaining samples to a final shard, then writing an index file.
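The two finishing steps are easier to see in a toy writer. The sketch below is hypothetical and all of its names are illustrative. It is not Composer's implementation and does not produce .mds files: each shard is a plain concatenation of field bytes, and the index is JSON. It shows the last partially filled shard being flushed before the index is written. It also finalizes when the `with` block exits.

```python
import json
import os
import tempfile
from typing import Dict, List


class ToyShardWriter:
    """Hypothetical, simplified writer illustrating flush-then-index (not .mds)."""

    def __init__(self, dirname: str, fields: List[str], shard_size_limit: int = 1 << 24) -> None:
        self.dirname = dirname
        self.fields = fields
        self.shard_size_limit = shard_size_limit
        self.buffer: List[bytes] = []  # encoded samples for the open shard
        self.buffer_size = 0
        self.shards: List[Dict[str, int]] = []  # per-shard metadata for the index
        os.makedirs(dirname, exist_ok=True)

    def write_sample(self, sample: Dict[str, bytes]) -> None:
        data = b"".join(sample[name] for name in self.fields)  # extra keys ignored
        # Flush first if this sample would push the open shard past the limit.
        if self.buffer and self.buffer_size + len(data) > self.shard_size_limit:
            self._flush_shard()
        self.buffer.append(data)
        self.buffer_size += len(data)

    def _flush_shard(self) -> None:
        path = os.path.join(self.dirname, f"shard.{len(self.shards):05}.bin")
        with open(path, "wb") as f:
            f.write(b"".join(self.buffer))
        self.shards.append({"samples": len(self.buffer), "bytes": self.buffer_size})
        self.buffer, self.buffer_size = [], 0

    def finish(self) -> None:
        # Step 1: flush the remaining samples to a last (possibly smaller) shard.
        if self.buffer:
            self._flush_shard()
        # Step 2: write an index describing every shard.
        with open(os.path.join(self.dirname, "index.json"), "w") as f:
            json.dump({"fields": self.fields, "shards": self.shards}, f)

    def __enter__(self) -> "ToyShardWriter":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.finish()


with tempfile.TemporaryDirectory() as tmp:
    with ToyShardWriter(tmp, ["uid", "data"], shard_size_limit=32) as writer:
        for ix in range(10):  # each sample is 6 + 4 = 10 bytes
            writer.write_sample({"uid": f"{ix:06}".encode("utf-8"),
                                 "data": (3 * ix).to_bytes(4, "big"),
                                 "unused": b"blah"})
    with open(os.path.join(tmp, "index.json")) as f:
        index = json.load(f)
    shard_files = sorted(os.listdir(tmp))

print([s["samples"] for s in index["shards"]])  # [3, 3, 3, 1]
```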


Add a sample to the dataset.


Parameters:
  sample (Dict[str, bytes]) – The new sample, whose keys must contain the fields to save (other keys are ignored).

write_samples(samples, use_tqdm=True, total=None)

Add the samples from the given iterable to the dataset.

Parameters:
  • samples (Iterable[Dict[str, bytes]]) – The new samples.

  • use_tqdm (bool) – Whether to display a progress bar. Default: True.

  • total (int, optional) – Total number of samples, shown by the progress bar. Useful when samples is a generator, whose length cannot be determined in advance. Default: None.
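The `total` argument matters because a generator has no `len()`. Progress-bar libraries such as tqdm typically infer the total from `len()` when it is available. Without a length, they can only count processed items, not show a percentage. The hypothetical `progress_total` helper below mimics that fallback. It is an illustration, not Composer code.

```python
from typing import Iterable, Optional


def progress_total(samples: Iterable, total: Optional[int] = None) -> Optional[int]:
    """Hypothetical helper: the length a progress bar can display.

    An explicit `total` wins; otherwise fall back to len(), which sized
    containers such as lists support but generators do not.
    """
    if total is not None:
        return total
    try:
        return len(samples)  # type: ignore[arg-type]
    except TypeError:
        return None  # unknown length: the bar can only count items


samples_list = [{"uid": f"{ix:06}".encode("utf-8")} for ix in range(100)]
samples_gen = ({"uid": f"{ix:06}".encode("utf-8")} for ix in range(100))
print(progress_total(samples_list))            # 100
print(progress_total(samples_gen))             # None
print(progress_total(samples_gen, total=100))  # 100
```

So when streaming samples from a generator, passing the count explicitly, as in `writer.write_samples(samples=samples_gen, total=100)`, lets the progress bar show overall progress.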