# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""Read and decode sample from shards."""
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterator, Optional, Union
from streaming.base.array import Array
from streaming.base.util import bytes_to_int
__all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader']


@dataclass
class FileInfo(object):
"""File validation info.
Args:
basename (str): File basename.
bytes (int): File size in bytes.
hashes (Dict[str, str]): Mapping of hash algorithm to hash value.
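
    Example:
        An illustrative instance (all values are hypothetical)::

            info = FileInfo(basename='shard.00000.mds',
                            bytes=1_048_576,
                            hashes={'sha1': 'da39a3ee5e6b4b0d3255bfef95601890afd80709'})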
"""
basename: str
bytes: int
hashes: dict[str, str]


class Reader(Array, ABC):
"""Provides random access to the samples of a shard.
Args:
dirname (str): Local dataset directory.
split (str, optional): Which dataset split to use, if any.
compression (str, optional): Optional compression or compression:level.
hashes (List[str]): Optional list of hash algorithms to apply to shard files.
samples (int): Number of samples in this shard.
size_limit (Union[int, str], optional): Optional shard size limit, after which
point to start a new shard. If None, puts everything in one shard. Can
specify bytes in human-readable format as well, for example ``"100kb"``
for 100 kilobyte (100*1024) and so on.
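
    Example:
        A minimal sketch of a concrete subclass, given only to illustrate the abstract
        interface (``TxtReader`` and its one-field text encoding are hypothetical, not
        part of this module)::

            class TxtReader(Reader):

                def decode_sample(self, data: bytes) -> dict[str, Any]:
                    return {'text': data.decode('utf-8')}

                def get_sample_data(self, idx: int) -> bytes:
                    ...  # Look up and read the idx'th sample's bytes from the shard file.

            reader = TxtReader(dirname='/tmp/data', split='train', compression=None,
                               hashes=[], samples=128, size_limit='100kb')
            assert reader.size_limit == 100 * 1024  # '100kb' is parsed by bytes_to_int.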
"""

    def __init__(
self,
dirname: str,
split: Optional[str],
compression: Optional[str],
hashes: list[str],
samples: int,
size_limit: Optional[Union[int, str]],
) -> None:
        if size_limit:
            if isinstance(size_limit, str):
                size_limit = bytes_to_int(size_limit)
            if size_limit < 0:
                raise ValueError(f'`size_limit` must be non-negative, but found: {size_limit}.')

self.dirname = dirname
self.split = split or ''
self.compression = compression
self.hashes = hashes
self.samples = samples
self.size_limit = size_limit
        self.file_pairs: list[tuple[FileInfo, Optional[FileInfo]]] = []

    def validate(self, allow_unsafe_types: bool) -> None:
        """Check whether this shard is acceptable to be part of some Stream.

        Args:
            allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
                execution during deserialization, whether to keep going if ``True`` or raise
                an error if ``False``.
        """
        pass

    @property
    def size(self) -> int:
        """Get the number of samples in this shard.

        Returns:
            int: Sample count.
        """
return self.samples

    def __len__(self) -> int:
        """Get the number of samples in this shard.

        Returns:
            int: Sample count.
        """
return self.samples

    def _evict_raw(self) -> int:
        """Remove all raw files belonging to this shard.

        Returns:
            int: Bytes evicted from cache.
        """
size = 0
for raw_info, _ in self.file_pairs:
filename = os.path.join(self.dirname, self.split, raw_info.basename)
if os.path.exists(filename):
os.remove(filename)
size += raw_info.bytes
return size

    def _evict_zip(self) -> int:
        """Remove all zip files belonging to this shard.

        Returns:
            int: Bytes evicted from cache.
        """
size = 0
for _, zip_info in self.file_pairs:
if zip_info:
filename = os.path.join(self.dirname, self.split, zip_info.basename)
if os.path.exists(filename):
os.remove(filename)
size += zip_info.bytes
return size

    def evict(self) -> int:
        """Remove all files belonging to this shard.

        Returns:
            int: Bytes evicted from cache.
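
        Example:
            Illustrative use when freeing local cache space::

                freed = shard.evict()  # Deletes this shard's raw and zip files, if present.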
"""
return self._evict_raw() + self._evict_zip()

    def set_up_local(self, listing: set[str], safe_keep_zip: bool) -> int:
        """Bring what shard files are present to a consistent state, returning cache usage.

        Args:
            listing (Set[str]): The listing of all files under dirname/[split/]. This is listed
                once and then saved because there could potentially be very many shard files.
            safe_keep_zip (bool): Whether to keep zip files when decompressing. Possible when
                compression was used. Necessary when local is the remote or there is no remote.

        Returns:
            int: Bytes of cache usage, or zero if the shard is not present.
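
        Example:
            An illustrative call (paths hypothetical). Note that ``listing`` must hold full
            joined paths, matching how filenames are built inside this method::

                root = os.path.join(shard.dirname, shard.split)
                listing = {os.path.join(root, name) for name in os.listdir(root)}
                usage = shard.set_up_local(listing, safe_keep_zip=False)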
"""
# For raw/zip to be considered present, each raw/zip file must be present.
raw_files_present = 0
zip_files_present = 0
for raw_info, zip_info in self.file_pairs:
if raw_info:
filename = os.path.join(self.dirname, self.split, raw_info.basename)
if filename in listing:
raw_files_present += 1
if zip_info:
filename = os.path.join(self.dirname, self.split, zip_info.basename)
if filename in listing:
zip_files_present += 1
# If the shard raw files are partially present, garbage collect the present ones and mark
# the shard raw as not present, in order to achieve consistency.
if not raw_files_present:
has_raw = False
elif raw_files_present < len(self.file_pairs):
has_raw = False
self._evict_raw()
else:
has_raw = True
# Same as the above, but for shard zip files.
if not zip_files_present:
has_zip = False
elif zip_files_present < len(self.file_pairs):
has_zip = False
self._evict_zip()
else:
has_zip = True
# Enumerate cases of raw/zip presence.
if self.compression:
if safe_keep_zip:
if has_raw:
if has_zip:
# Present (normalized).
pass
else:
# Missing: there is no natural way to arrive at this state, so drop raw.
has_raw = False
self._evict_raw()
else:
if has_zip:
# Present: but missing raw, so need to decompress upon use.
pass
else:
# Missing (normalized).
pass
else:
if has_raw:
if has_zip:
# Present: zip is unnecessary, so evict it.
has_zip = False
                        self._evict_zip()
else:
# Present (normalized).
pass
else:
if has_zip:
# Present: but missing raw, so need to decompress and evict zip upon use.
pass
else:
# Missing (normalized).
pass
else:
if has_zip:
raise ValueError('Shard is invalid: compression was not used, but has a ' +
'compressed form.')
        # Get cache usage. The shard is present if either its raw or zip form is present.
size = 0
if has_raw:
size += self.get_raw_size()
if has_zip:
size += self.get_zip_size() or 0
return size

    def get_raw_size(self) -> int:
        """Get the raw (uncompressed) size of this shard.

        Returns:
            int: Size in bytes.
        """
size = 0
for info, _ in self.file_pairs:
size += info.bytes
return size

    def get_zip_size(self) -> Optional[int]:
        """Get the zip (compressed) size of this shard, if compression was used.

        Returns:
            Optional[int]: Size in bytes, or ``None`` if this shard has no compressed form.
        """
size = 0
for _, info in self.file_pairs:
if info is None:
return None
size += info.bytes
return size

    def get_max_size(self) -> int:
        """Get the full size of this shard.

        "Max" in this case means both the raw (decompressed) and zip (compressed) versions are
        resident (assuming it has a zip form). This is the maximum disk usage the shard can
        reach. When compression was used, even if ``keep_zip`` is ``False``, the zip form must
        still be resident alongside the raw form during decompression.

        Returns:
            int: Size in bytes.
        """
return self.get_raw_size() + (self.get_zip_size() or 0)

    def get_persistent_size(self, keep_zip: bool) -> int:
        """Get the persistent size of this shard.

        "Persistent" in this case means the size the shard settles at on disk, which depends on
        ``keep_zip``: if we are not keeping zip files after decompression, they don't count
        toward the shard's persistent size on disk.

        Args:
            keep_zip (bool): Whether to keep zip files after decompressing.

        Returns:
            int: Size in bytes.
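
        Example:
            For a compressed shard with a 10 MB raw form and a 4 MB zip form (numbers
            illustrative)::

                shard.get_max_size()                       # 14 MB: raw + zip resident.
                shard.get_persistent_size(keep_zip=True)   # 14 MB: zip is kept.
                shard.get_persistent_size(keep_zip=False)  # 10 MB: raw only.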
"""
if self.compression:
if keep_zip:
size = self.get_max_size()
else:
size = self.get_raw_size()
else:
size = self.get_raw_size()
return size

    @abstractmethod
    def decode_sample(self, data: bytes) -> dict[str, Any]:
        """Decode a sample dict from bytes.

        Args:
            data (bytes): The sample encoded as bytes.

        Returns:
            Dict[str, Any]: Sample dict.
        """
raise NotImplementedError

    @abstractmethod
    def get_sample_data(self, idx: int) -> bytes:
        """Get the raw sample data at the index.

        Args:
            idx (int): Sample index.

        Returns:
            bytes: Sample data.
        """
raise NotImplementedError

    def get_item(self, idx: int) -> dict[str, Any]:
        """Get the sample at the index.

        Args:
            idx (int): Sample index.

        Returns:
            Dict[str, Any]: Sample dict.
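
        Example:
            Fetching a decoded sample by index; plain indexing, provided by
            :class:`~streaming.base.array.Array`, routes through this method::

                sample = reader.get_item(0)
                sample = reader[0]  # Equivalent.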
"""
data = self.get_sample_data(idx)
return self.decode_sample(data)

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Iterate over the samples of this shard.

        Returns:
            Iterator[Dict[str, Any]]: Iterator over samples.
        """
for i in range(len(self)):
yield self[i]


class JointReader(Reader):
    """Provides random access to the samples of a joint shard.

    Args:
        dirname (str): Local dataset directory.
        split (str, optional): Which dataset split to use, if any.
        compression (str, optional): Optional compression or compression:level.
        hashes (List[str]): Optional list of hash algorithms to apply to shard files.
        raw_data (FileInfo): Uncompressed data file info.
        samples (int): Number of samples in this shard.
        size_limit (Union[int, str], optional): Optional shard size limit, after which
            point to start a new shard. If ``None``, puts everything in one shard.
        zip_data (FileInfo, optional): Compressed data file info.
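
    Example:
        An illustrative construction of an uncompressed joint shard (all values are
        hypothetical)::

            reader = JointReader(dirname='/tmp/data',
                                 split='train',
                                 compression=None,
                                 hashes=[],
                                 raw_data=FileInfo('shard.00000.mds', 1_048_576, {}),
                                 samples=256,
                                 size_limit=None,
                                 zip_data=None)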
"""

    def __init__(
self,
dirname: str,
split: Optional[str],
compression: Optional[str],
hashes: list[str],
raw_data: FileInfo,
samples: int,
size_limit: Optional[Union[int, str]],
zip_data: Optional[FileInfo],
) -> None:
super().__init__(dirname, split, compression, hashes, samples, size_limit)
self.raw_data = raw_data
self.zip_data = zip_data
self.file_pairs.append((raw_data, zip_data))


class SplitReader(Reader):
    """Provides random access to the samples of a split shard.

    Args:
        dirname (str): Local dataset directory.
        split (str, optional): Which dataset split to use, if any.
        compression (str, optional): Optional compression or compression:level.
        hashes (List[str]): Optional list of hash algorithms to apply to shard files.
        raw_data (FileInfo): Uncompressed data file info.
        raw_meta (FileInfo): Uncompressed meta file info.
        samples (int): Number of samples in this shard.
        size_limit (Union[int, str], optional): Optional shard size limit, after which
            point to start a new shard. If ``None``, puts everything in one shard.
        zip_data (FileInfo, optional): Compressed data file info.
        zip_meta (FileInfo, optional): Compressed meta file info.
    """

    def __init__(
self,
dirname: str,
split: Optional[str],
compression: Optional[str],
hashes: list[str],
raw_data: FileInfo,
raw_meta: FileInfo,
samples: int,
size_limit: Optional[Union[int, str]],
zip_data: Optional[FileInfo],
zip_meta: Optional[FileInfo],
) -> None:
super().__init__(dirname, split, compression, hashes, samples, size_limit)
self.raw_data = raw_data
self.raw_meta = raw_meta
self.zip_data = zip_data
self.zip_meta = zip_meta
self.file_pairs.append((raw_meta, zip_meta))
self.file_pairs.append((raw_data, zip_data))