# Source code for streaming.base.local
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""A non-streaming pytorch map Dataset."""
import json
import os
from typing import Any, Optional
import numpy as np
from torch.utils.data import Dataset
from streaming.base.array import Array
from streaming.base.format import get_index_basename, reader_from_json
from streaming.base.spanner import Spanner
# Names exported by ``from <this module> import *``; only the dataset class is public.
__all__ = ['LocalDataset']
class LocalDataset(Array, Dataset):
    """A streaming dataset whose shards reside locally as a pytorch Dataset.

    The index file for ``split`` is read once at construction, and one shard reader is
    built per shard entry in that index. Individual samples are addressed by a global
    sample ID, which is mapped to ``(shard ID, index within that shard)``.

    Args:
        local (str): Local dataset directory where shards are cached by split.
        split (str, optional): Which dataset split to use, if any. Defaults to ``None``.

    Raises:
        ValueError: If the index file's ``version`` field is not ``2``.
    """

    def __init__(self, local: str, split: Optional[str] = None) -> None:
        # ``None`` means "no split subdirectory"; an empty path component is skipped
        # by ``os.path.join`` when followed by another component.
        split = split or ''
        self.local = local
        self.split = split

        filename = os.path.join(local, split, get_index_basename())  # pyright: ignore
        # Context manager so the index file handle is closed promptly rather than
        # whenever the file object happens to be garbage-collected.
        with open(filename, encoding='utf-8') as index_file:
            obj = json.load(index_file)
        if obj['version'] != 2:
            raise ValueError(f'Unsupported streaming data version: {obj["version"]}. ' +
                             'Expected version 2.')

        # One reader per shard entry, in index order (order defines shard IDs).
        self.shards = [reader_from_json(local, split, info) for info in obj['shards']]
        self.num_samples = sum(shard.samples for shard in self.shards)

        # Explicit integer dtype so an index with zero shards still yields an integer
        # array (``np.array([])`` would otherwise default to float64).
        shard_sizes = np.array([shard.samples for shard in self.shards], dtype=np.int64)
        self.spanner = Spanner(shard_sizes)

    def __len__(self) -> int:
        """Get the length as a PyTorch Dataset.

        Returns:
            int: Dataset length (total samples across all shards).
        """
        return self.num_samples

    @property
    def size(self) -> int:
        """Get the size of the dataset in samples.

        Returns:
            int: Number of samples (same value as ``len(self)``).
        """
        return self.num_samples

    def get_item(self, sample_id: int) -> dict[str, Any]:
        """Get sample by global sample ID.

        Args:
            sample_id (int): Global sample ID across all shards.

        Returns:
            Dict[str, Any]: Column name with sample data.
        """
        # The spanner translates the global ID into the owning shard and the
        # sample's position inside that shard.
        shard_id, index_in_shard = self.spanner[sample_id]
        shard = self.shards[shard_id]
        return shard[index_in_shard]