# Source code for streaming.base.format.xsv.writer

# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

""":class:`XSVWriter` writes samples to `.xsv` files that can be read by :class:`XSVReader`."""

import json
from typing import Any, Optional, Union

import numpy as np

from streaming.base.format.base.writer import SplitWriter
from streaming.base.format.xsv.encodings import is_xsv_encoding, xsv_encode

# Public writer classes exported by this module.
__all__ = [
    'XSVWriter',
    'CSVWriter',
    'TSVWriter',
]


class XSVWriter(SplitWriter):
    r"""Writes a streaming XSV dataset.

    Args:
        columns (Dict[str, str]): Sample columns, mapping each column name to its XSV
            encoding. Columns are written in sorted-name order.
        separator (str): String used to separate columns. Must not appear in any column
            name or encoded value.
        newline (str): Newline character inserted between samples. Must not appear in any
            column name or encoded value. Defaults to ``\n``.
        out (str | Tuple[str, str]): Output dataset directory to save shard files.

            1. If ``out`` is a local directory, shard files are saved locally.
            2. If ``out`` is a remote directory, a local temporary directory is created to
               cache the shard files and then the shard files are uploaded to a remote
               location. At the end, the temp directory is deleted once shards are uploaded.
            3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
               the `local_dir` and also uploaded to a remote location.
        keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
            directory or remove it after uploading. Defaults to ``False``.
        compression (str, optional): Optional compression or compression:level. Defaults to
            ``None``.
        hashes (List[str], optional): Optional list of hash algorithms to apply to shard
            files. Defaults to ``None``.
        size_limit (Union[int, str], optional): Optional shard size limit, after which point
            to start a new shard. If None, puts everything in one shard. Can specify bytes
            human-readable format as well, for example ``"100kb"`` for 100 kilobyte
            (100*1024) and so on. Defaults to ``1 << 26``.
        **kwargs (Any): Additional settings for the Writer.

            progress_bar (bool): Display TQDM progress bars for uploading output dataset
                files to a remote location. Default to ``False``.
            max_workers (int): Maximum number of threads used to upload output dataset files
                in parallel to a remote location. One thread is responsible for uploading one
                shard file to a remote location. Default to
                ``min(32, (os.cpu_count() or 1) + 4)``.
            exist_ok (bool): If the local directory exists and is not empty, whether to
                overwrite the content or raise an error. `False` raises an error. `True`
                deletes the content and starts fresh. Defaults to `False`.
    """

    format = 'xsv'

    def __init__(self,
                 *,
                 columns: dict[str, str],
                 separator: str,
                 newline: str = '\n',
                 out: Union[str, tuple[str, str]],
                 keep_local: bool = False,
                 compression: Optional[str] = None,
                 hashes: Optional[list[str]] = None,
                 size_limit: Optional[Union[int, str]] = 1 << 26,
                 **kwargs: Any) -> None:
        super().__init__(out=out,
                         keep_local=keep_local,
                         compression=compression,
                         hashes=hashes,
                         size_limit=size_limit,
                         **kwargs)

        self.columns = columns
        # Columns are laid out in sorted-name order so the header row and every sample row
        # agree on column positions.
        self.column_names: list[str] = []
        self.column_encodings: list[str] = []
        for name in sorted(columns):
            encoding = columns[name]
            # A name containing the separator or newline would corrupt the header row.
            assert newline not in name, f'Column name {name!r} contains the newline string'
            assert separator not in name, f'Column name {name!r} contains the separator string'
            assert is_xsv_encoding(encoding), \
                f'Invalid XSV encoding {encoding!r} for column {name!r}'
            self.column_names.append(name)
            self.column_encodings.append(encoding)

        self.separator = separator
        self.newline = newline

    def encode_sample(self, sample: dict[str, Any]) -> bytes:
        """Encode a sample dict to bytes.

        Each configured column's value is encoded with that column's XSV encoding. The
        encoded values are joined with ``separator`` in sorted-column order, and ``newline``
        is appended. Keys of ``sample`` that are not configured columns are ignored.

        Args:
            sample (Dict[str, Any]): Sample dict. Must contain every configured column.

        Returns:
            bytes: Sample encoded as UTF-8 bytes.
        """
        values = []
        for name, encoding in zip(self.column_names, self.column_encodings):
            value = xsv_encode(encoding, sample[name])
            # An encoded value containing the separator or newline would make the row
            # ambiguous to split back into columns.
            assert self.newline not in value, \
                f'Encoded value of column {name!r} contains the newline string'
            assert self.separator not in value, \
                f'Encoded value of column {name!r} contains the separator string'
            values.append(value)
        text = self.separator.join(values) + self.newline
        return text.encode('utf-8')

    def get_config(self) -> dict[str, Any]:
        """Get object describing shard-writing configuration.

        Extends the base config with the column layout, separator, and newline.

        Returns:
            Dict[str, Any]: JSON object.
        """
        obj = super().get_config()
        obj.update({
            'column_names': self.column_names,
            'column_encodings': self.column_encodings,
            'separator': self.separator,
            'newline': self.newline,
        })
        return obj

    def encode_split_shard(self) -> tuple[bytes, bytes]:
        """Encode a split shard out of the cached samples (data, meta files).

        The data file is a header row of column names followed by the encoded samples.
        The meta file is:

        - the sample count as uint32,
        - ``num_samples + 1`` uint32 byte offsets into the data file (the first points just
          past the header, the last to the end of the data),
        - the JSON config.

        Returns:
            Tuple[bytes, bytes]: Data file, meta file.

        Raises:
            ValueError: If the data file is too large for its offsets to fit in uint32.
        """
        header = (self.separator.join(self.column_names) + self.newline).encode('utf-8')
        data = b''.join([header, *self.new_samples])

        num_samples = np.uint32(len(self.new_samples))

        # Offsets are serialized as uint32. Accumulate in int64 and check the end offset so
        # that an oversized shard (e.g. size_limit=None) raises instead of silently wrapping
        # around and producing a corrupt index.
        sizes = [len(sample) for sample in self.new_samples]
        offsets = np.cumsum([len(header)] + sizes, dtype=np.int64)
        uint32_max = np.iinfo(np.uint32).max
        if offsets[-1] > uint32_max:
            raise ValueError(f'XSV shard is too large: its data file would be '
                             f'{int(offsets[-1])} bytes, but sample offsets are stored as '
                             f'uint32 (max {uint32_max} bytes). Use a smaller size_limit.')
        offsets = offsets.astype(np.uint32)

        text = json.dumps(self.get_config(), sort_keys=True)
        meta = num_samples.tobytes() + offsets.tobytes() + text.encode('utf-8')
        return data, meta
class CSVWriter(XSVWriter):
    r"""Writes a streaming CSV dataset.

    This is an :class:`XSVWriter` whose column separator is fixed to ``,``.

    Args:
        columns (Dict[str, str]): Sample columns, mapping each column name to its XSV
            encoding.
        newline (str): Newline character inserted between samples. Defaults to ``\n``.
        out (str | Tuple[str, str]): Output dataset directory to save shard files.

            1. If ``out`` is a local directory, shard files are saved locally.
            2. If ``out`` is a remote directory, a local temporary directory is created to
               cache the shard files and then the shard files are uploaded to a remote
               location. At the end, the temp directory is deleted once shards are uploaded.
            3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
               the `local_dir` and also uploaded to a remote location.
        keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
            directory or remove it after uploading. Defaults to ``False``.
        compression (str, optional): Optional compression or compression:level. Defaults to
            ``None``.
        hashes (List[str], optional): Optional list of hash algorithms to apply to shard
            files. Defaults to ``None``.
        size_limit (Union[int, str], optional): Optional shard size limit, after which point
            to start a new shard. If None, puts everything in one shard. Can specify bytes
            human-readable format as well, for example ``"100kb"`` for 100 kilobyte
            (100*1024) and so on. Defaults to ``1 << 26``.
        **kwargs (Any): Additional settings for the Writer.

            progress_bar (bool): Display TQDM progress bars for uploading output dataset
                files to a remote location. Default to ``False``.
            max_workers (int): Maximum number of threads used to upload output dataset files
                in parallel to a remote location. One thread is responsible for uploading one
                shard file to a remote location. Default to
                ``min(32, (os.cpu_count() or 1) + 4)``.
            exist_ok (bool): If the local directory exists and is not empty, whether to
                overwrite the content or raise an error. `False` raises an error. `True`
                deletes the content and starts fresh. Defaults to `False`.
    """

    format = 'csv'
    separator = ','

    def __init__(self,
                 *,
                 columns: dict[str, str],
                 newline: str = '\n',
                 out: Union[str, tuple[str, str]],
                 keep_local: bool = False,
                 compression: Optional[str] = None,
                 hashes: Optional[list[str]] = None,
                 size_limit: Optional[Union[int, str]] = 1 << 26,
                 **kwargs: Any) -> None:
        # The separator comes from the class attribute rather than a constructor argument.
        super().__init__(columns=columns,
                         separator=self.separator,
                         newline=newline,
                         out=out,
                         keep_local=keep_local,
                         compression=compression,
                         hashes=hashes,
                         size_limit=size_limit,
                         **kwargs)

    def get_config(self) -> dict[str, Any]:
        """Get object describing shard-writing configuration.

        Starts from the XSV config, sets ``format`` to ``'csv'``, and removes the
        ``separator`` key.

        Returns:
            Dict[str, Any]: JSON object.
        """
        obj = super().get_config()
        obj['format'] = self.format
        del obj['separator']
        return obj
class TSVWriter(XSVWriter):
    r"""Writes a streaming TSV dataset.

    This is an :class:`XSVWriter` whose column separator is fixed to a tab (``\t``).

    Args:
        columns (Dict[str, str]): Sample columns, mapping each column name to its XSV
            encoding.
        newline (str): Newline character inserted between samples. Defaults to ``\n``.
        out (str | Tuple[str, str]): Output dataset directory to save shard files.

            1. If ``out`` is a local directory, shard files are saved locally.
            2. If ``out`` is a remote directory, a local temporary directory is created to
               cache the shard files and then the shard files are uploaded to a remote
               location. At the end, the temp directory is deleted once shards are uploaded.
            3. If ``out`` is a tuple of ``(local_dir, remote_dir)``, shard files are saved in
               the `local_dir` and also uploaded to a remote location.
        keep_local (bool): If the dataset is uploaded, whether to keep the local dataset
            directory or remove it after uploading. Defaults to ``False``.
        compression (str, optional): Optional compression or compression:level. Defaults to
            ``None``.
        hashes (List[str], optional): Optional list of hash algorithms to apply to shard
            files. Defaults to ``None``.
        size_limit (Union[int, str], optional): Optional shard size limit, after which point
            to start a new shard. If None, puts everything in one shard. Can specify bytes
            human-readable format as well, for example ``"100kb"`` for 100 kilobyte
            (100*1024) and so on. Defaults to ``1 << 26``.
        **kwargs (Any): Additional settings for the Writer.

            progress_bar (bool): Display TQDM progress bars for uploading output dataset
                files to a remote location. Default to ``False``.
            max_workers (int): Maximum number of threads used to upload output dataset files
                in parallel to a remote location. One thread is responsible for uploading one
                shard file to a remote location. Default to
                ``min(32, (os.cpu_count() or 1) + 4)``.
            exist_ok (bool): If the local directory exists and is not empty, whether to
                overwrite the content or raise an error. `False` raises an error. `True`
                deletes the content and starts fresh. Defaults to `False`.
    """

    format = 'tsv'
    separator = '\t'

    def __init__(self,
                 *,
                 columns: dict[str, str],
                 newline: str = '\n',
                 out: Union[str, tuple[str, str]],
                 keep_local: bool = False,
                 compression: Optional[str] = None,
                 hashes: Optional[list[str]] = None,
                 size_limit: Optional[Union[int, str]] = 1 << 26,
                 **kwargs: Any) -> None:
        # The separator comes from the class attribute rather than a constructor argument.
        super().__init__(columns=columns,
                         separator=self.separator,
                         newline=newline,
                         out=out,
                         keep_local=keep_local,
                         compression=compression,
                         hashes=hashes,
                         size_limit=size_limit,
                         **kwargs)

    def get_config(self) -> dict[str, Any]:
        """Get object describing shard-writing configuration.

        Starts from the XSV config, sets ``format`` to ``'tsv'``, and removes the
        ``separator`` key.

        Returns:
            Dict[str, Any]: JSON object.
        """
        obj = super().get_config()
        obj['format'] = self.format
        del obj['separator']
        return obj