# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
""":class:`MDSWriter` writes samples to ``.mds`` files that can be read by :class:`MDSReader`."""

import json
from typing import Any, Optional, Union

import numpy as np

from streaming.base.format.base.writer import JointWriter
from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings,
                                                 is_mds_encoding, mds_encode)

__all__ = ['MDSWriter']


class MDSWriter(JointWriter):
"""Writes a streaming MDS dataset.
Args:
columns (Dict[str, str]): Sample columns.
        out (str | Tuple[str, str]): Output dataset directory to save shard files.

            1. If ``out`` is a local directory, shard files are saved locally.
            2. If ``out`` is a remote directory, a local temporary directory is created to
               cache the shard files, which are then uploaded to the remote location. The
               temporary directory is deleted once the shards have been uploaded.
            3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
               ``local_dir`` and also uploaded to ``remote_dir``.
keep_local (bool): If the dataset is uploaded, whether to keep the local dataset directory
or remove it after uploading. Defaults to ``False``.
compression (str, optional): Optional compression or compression:level. Defaults to
``None``.
hashes (List[str], optional): Optional list of hash algorithms to apply to shard files.
Defaults to ``None``.
        size_limit (Union[int, str], optional): Optional shard size limit, after which point
            a new shard is started. If ``None``, everything goes into a single shard. A
            human-readable byte string may also be given, e.g. ``"100kb"`` for 100 kibibytes
            (100 * 1024 bytes). Defaults to ``1 << 26``.
**kwargs (Any): Additional settings for the Writer.
            progress_bar (bool): Display TQDM progress bars for uploading output dataset
                files to a remote location. Defaults to ``False``.
            max_workers (int): Maximum number of threads used to upload output dataset files
                in parallel to a remote location. One thread is responsible for uploading one
                shard file to a remote location. Defaults to
                ``min(32, (os.cpu_count() or 1) + 4)``.
            exist_ok (bool): If the local directory exists and is not empty, whether to
                overwrite its contents or raise an error. ``False`` raises an error; ``True``
                deletes the existing contents and starts fresh. Defaults to ``False``.
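
    Example:
        A minimal usage sketch; the output path and sample values below are illustrative:

        .. code-block:: python

            from streaming import MDSWriter

            columns = {'text': 'str', 'number': 'int'}
            with MDSWriter(out='/tmp/example_dataset', columns=columns) as writer:
                for i in range(10):
                    writer.write({'text': f'sample {i}', 'number': i})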
"""
format = 'mds'
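    # Each sample adds one uint32 entry (4 bytes) to the shard's offset table; this
    # overhead is counted toward ``size_limit`` along with the sample payload itself.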
extra_bytes_per_sample = 4

    def __init__(self,
*,
columns: dict[str, str],
out: Union[str, tuple[str, str]],
keep_local: bool = False,
compression: Optional[str] = None,
hashes: Optional[list[str]] = None,
size_limit: Optional[Union[int, str]] = 1 << 26,
**kwargs: Any) -> None:
super().__init__(out=out,
keep_local=keep_local,
compression=compression,
hashes=hashes,
size_limit=size_limit,
extra_bytes_per_sample=self.extra_bytes_per_sample,
**kwargs)
self.columns = columns
self.column_names = []
self.column_encodings = []
self.column_sizes = []
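        # Process columns in sorted name order so the on-disk column layout is
        # deterministic and matches the order recorded in the shard config.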
for name in sorted(columns):
encoding = columns[name]
if not is_mds_encoding(encoding):
                raise TypeError(f'MDSWriter was passed column `{name}` with unsupported ' +
                                f'encoding `{encoding}`. Supported encodings are ' +
                                f'{get_mds_encodings()}.')
size = get_mds_encoded_size(encoding)
self.column_names.append(name)
self.column_encodings.append(encoding)
self.column_sizes.append(size)
obj = self.get_config()
text = json.dumps(obj, sort_keys=True)
self.config_data = text.encode('utf-8')
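        # Per-shard overhead beyond the samples: a uint32 sample count, the final uint32
        # offset entry, and the JSON config blob embedded in every shard.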
self.extra_bytes_per_shard = 4 + 4 + len(self.config_data)
self._reset_cache()

    def encode_sample(self, sample: dict[str, Any]) -> bytes:
        """Encode a sample dict to bytes.

        Args:
            sample (Dict[str, Any]): Sample dict.

        Returns:
            bytes: Sample encoded as bytes.
        """
sizes = []
data = []
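        # Encode each column in canonical (sorted) order. Fixed-size columns are checked
        # against the size declared by their encoding; variable-length columns instead
        # record their byte size in the per-sample header.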
for key, encoding, size in zip(self.column_names, self.column_encodings,
self.column_sizes):
value = sample[key]
datum = mds_encode(encoding, value)
if size is None:
size = len(datum)
sizes.append(size)
else:
if size != len(datum):
raise KeyError(f'Unexpected data size; was this data typed with the correct ' +
f'encoding ({encoding})?')
data.append(datum)
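        # Per-sample layout: uint32 sizes of the variable-length columns, followed by the
        # concatenated encoded column values.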
head = np.array(sizes, np.uint32).tobytes()
body = b''.join(data)
return head + body

    def get_config(self) -> dict[str, Any]:
        """Get object describing shard-writing configuration.

        Returns:
            Dict[str, Any]: JSON object.
        """
obj = super().get_config()
obj.update({
'column_names': self.column_names,
'column_encodings': self.column_encodings,
'column_sizes': self.column_sizes
})
return obj

    def encode_joint_shard(self) -> bytes:
        """Encode a joint shard out of the cached samples (single file).

        Returns:
            bytes: File data.
        """
num_samples = np.uint32(len(self.new_samples))
sizes = list(map(len, self.new_samples))
offsets = np.array([0] + sizes).cumsum().astype(np.uint32)
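        # Shift the offsets past the shard header (count + offset table + config) so each
        # entry is an absolute byte position within the shard file.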
offsets += len(num_samples.tobytes()) + len(offsets.tobytes()) + len(self.config_data)
sample_data = b''.join(self.new_samples)
return num_samples.tobytes() + offsets.tobytes() + self.config_data + sample_data