importbisectimportitertoolsimportmathimportwarningsfromtypingimport(cast,Dict,Generic,Iterable,List,Optional,Sequence,Tuple,TypeVar,Union,)# No 'default_generator' in torch/__init__.pyifromtorchimportdefault_generator,randpermfrom...importGenerator,Tensor__all__=["Dataset","IterableDataset","TensorDataset","StackDataset","ConcatDataset","ChainDataset","Subset","random_split",]T_co=TypeVar("T_co",covariant=True)T=TypeVar("T")T_dict=Dict[str,T_co]T_tuple=Tuple[T_co,...]T_stack=TypeVar("T_stack",T_tuple,T_dict)
[docs]classDataset(Generic[T_co]):r"""An abstract class representing a :class:`Dataset`. All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample for a given key. Subclasses could also optionally overwrite :meth:`__len__`, which is expected to return the size of the dataset by many :class:`~torch.utils.data.Sampler` implementations and the default options of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally implement :meth:`__getitems__`, for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples. .. note:: :class:`~torch.utils.data.DataLoader` by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided. """def__getitem__(self,index)->T_co:raiseNotImplementedError("Subclasses of Dataset should implement __getitem__.")# def __getitems__(self, indices: List) -> List[T_co]:# Not implemented to prevent false-positives in fetcher check in# torch.utils.data._utils.fetch._MapDatasetFetcherdef__add__(self,other:"Dataset[T_co]")->"ConcatDataset[T_co]":returnConcatDataset([self,other])
# No `def __len__(self)` default?# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]# in pytorch/torch/utils/data/sampler.pyclassIterableDataset(Dataset[T_co],Iterable[T_co]):r"""An iterable Dataset. All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream. All subclasses should overwrite :meth:`__iter__`, which would return an iterator of samples in this dataset. When a subclass is used with :class:`~torch.utils.data.DataLoader`, each item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator. When :attr:`num_workers > 0`, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker process, returns information about the worker. It can be used in either the dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's :attr:`worker_init_fn` option to modify each copy's behavior. Example 1: splitting workload across all workers in :meth:`__iter__`:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) >>> # xdoctest: +SKIP("Fails on MacOS12") >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # xdoctest: +REQUIRES(POSIX) >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> # xdoctest: +IGNORE_WANT("non deterministic") >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> # xdoctest: +IGNORE_WANT("non deterministic") >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] Example 2: splitting workload across all workers using :attr:`worker_init_fn`:: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER) >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6] """def__add__(self,other:Dataset[T_co]):returnChainDataset([self,other])# No `def __len__(self)` default? Subclasses raise `TypeError` when needed.# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]classTensorDataset(Dataset[Tuple[Tensor,...]]):r"""Dataset wrapping tensors. Each sample will be retrieved by indexing tensors along the first dimension. Args: *tensors (Tensor): tensors that have the same size of the first dimension. """tensors:Tuple[Tensor,...]def__init__(self,*tensors:Tensor)->None:assertall(tensors[0].size(0)==tensor.size(0)fortensorintensors),"Size mismatch between tensors"self.tensors=tensorsdef__getitem__(self,index):returntuple(tensor[index]fortensorinself.tensors)def__len__(self):returnself.tensors[0].size(0)classStackDataset(Dataset[T_stack]):r"""Dataset as a stacking of multiple datasets. This class is useful to assemble different parts of complex input data, given as datasets. Example: >>> # xdoctest: +SKIP >>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} Args: *args (Dataset): Datasets for stacking returned as tuple. **kwargs (Dataset): Datasets for stacking returned as dict. """datasets:Union[tuple,dict]def__init__(self,*args:Dataset[T_co],**kwargs:Dataset[T_co])->None:ifargs:ifkwargs:raiseValueError("Supported either ``tuple``- (via ``args``) or""``dict``- (via ``kwargs``) like input/output, but both types are given.")self._length=len(args[0])# type: ignore[arg-type]ifany(self._length!=len(dataset)fordatasetinargs):# type: ignore[arg-type]raiseValueError("Size mismatch between datasets")self.datasets=argselifkwargs:tmp=list(kwargs.values())self._length=len(tmp[0])# type: ignore[arg-type]ifany(self._length!=len(dataset)fordatasetintmp):# type: ignore[arg-type]raiseValueError("Size mismatch between datasets")self.datasets=kwargselse:raiseValueError("At least one dataset should be passed")def__getitem__(self,index):ifisinstance(self.datasets,dict):return{k:dataset[index]fork,datasetinself.datasets.items()}returntuple(dataset[index]fordatasetinself.datasets)def__getitems__(self,indices:list):# add batched sampling support when parent datasets supports it.ifisinstance(self.datasets,dict):dict_batch:List[T_dict]=[{}for_inindices]fork,datasetinself.datasets.items():ifcallable(getattr(dataset,"__getitems__",None)):items=dataset.__getitems__(indices)# type: ignore[attr-defined]iflen(items)!=len(indices):raiseValueError("Nested dataset's output size mismatch."f" Expected {len(indices)}, got {len(items)}")fordata,d_sampleinzip(items,dict_batch):d_sample[k]=dataelse:foridx,d_sampleinzip(indices,dict_batch):d_sample[k]=dataset[idx]returndict_batch# tuple datalist_batch:List[list]=[[]for_inindices]fordatasetinself.datasets:ifcallable(getattr(dataset,"__getitems__",None)):items=dataset.__getitems__(indices)# type: ignore[attr-defined]iflen(items)!=len(indices):raiseValueError("Nested dataset's output size mismatch."f" Expected {len(indices)}, got {len(items)}")fordata,t_sampleinzip(items,list_batch):t_sample.append(data)else:foridx,t_sampleinzip(indices,list_batch):t_sample.append(dataset[idx])tuple_batch:List[T_tuple]=[tuple(sample)forsampleinlist_batch]returntuple_batchdef__len__(self):returnself._lengthclassConcatDataset(Dataset[T_co]):r"""Dataset as a concatenation of multiple datasets. This class is useful to assemble different existing datasets. Args: datasets (sequence): List of datasets to be concatenated """datasets:List[Dataset[T_co]]cumulative_sizes:List[int]@staticmethoddefcumsum(sequence):r,s=[],0foreinsequence:l=len(e)r.append(l+s)s+=lreturnrdef__init__(self,datasets:Iterable[Dataset])->None:super().__init__()self.datasets=list(datasets)assertlen(self.datasets)>0,"datasets should not be an empty iterable"# type: ignore[arg-type]fordinself.datasets:assertnotisinstance(d,IterableDataset),"ConcatDataset does not support IterableDataset"self.cumulative_sizes=self.cumsum(self.datasets)def__len__(self):returnself.cumulative_sizes[-1]def__getitem__(self,idx):ifidx<0:if-idx>len(self):raiseValueError("absolute value of index should not exceed dataset length")idx=len(self)+idxdataset_idx=bisect.bisect_right(self.cumulative_sizes,idx)ifdataset_idx==0:sample_idx=idxelse:sample_idx=idx-self.cumulative_sizes[dataset_idx-1]returnself.datasets[dataset_idx][sample_idx]@propertydefcummulative_sizes(self):warnings.warn("cummulative_sizes attribute is renamed to ""cumulative_sizes",DeprecationWarning,stacklevel=2,)returnself.cumulative_sizesclassChainDataset(IterableDataset):r"""Dataset for chaining multiple :class:`IterableDataset` s. This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient. Args: datasets (iterable of IterableDataset): datasets to be chained together """def__init__(self,datasets:Iterable[Dataset])->None:super().__init__()self.datasets=datasetsdef__iter__(self):fordinself.datasets:assertisinstance(d,IterableDataset),"ChainDataset only supports IterableDataset"yield fromddef__len__(self):total=0fordinself.datasets:assertisinstance(d,IterableDataset),"ChainDataset only supports IterableDataset"total+=len(d)# type: ignore[arg-type]returntotalclassSubset(Dataset[T_co]):r""" Subset of a dataset at specified indices. Args: dataset (Dataset): The whole Dataset indices (sequence): Indices in the whole set selected for subset """dataset:Dataset[T_co]indices:Sequence[int]def__init__(self,dataset:Dataset[T_co],indices:Sequence[int])->None:self.dataset=datasetself.indices=indicesdef__getitem__(self,idx):ifisinstance(idx,list):returnself.dataset[[self.indices[i]foriinidx]]returnself.dataset[self.indices[idx]]def__getitems__(self,indices:List[int])->List[T_co]:# add batched sampling support when parent dataset supports it.# see torch.utils.data._utils.fetch._MapDatasetFetcherifcallable(getattr(self.dataset,"__getitems__",None)):returnself.dataset.__getitems__([self.indices[idx]foridxinindices])# type: ignore[attr-defined]else:return[self.dataset[self.indices[idx]]foridxinindices]def__len__(self):returnlen(self.indices)defrandom_split(dataset:Dataset[T],lengths:Sequence[Union[int,float]],generator:Optional[Generator]=default_generator,)->List[Subset[T]]:r""" Randomly split a dataset into non-overlapping new datasets of given lengths. If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided. After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left. Optionally fix the generator for reproducible results, e.g.: Example: >>> # xdoctest: +SKIP >>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) Args: dataset (Dataset): Dataset to be split lengths (sequence): lengths or fractions of splits to be produced generator (Generator): Generator used for the random permutation. """ifmath.isclose(sum(lengths),1)andsum(lengths)<=1:subset_lengths:List[int]=[]fori,fracinenumerate(lengths):iffrac<0orfrac>1:raiseValueError(f"Fraction at index {i} is not between 0 and 1")n_items_in_split=int(math.floor(len(dataset)*frac)# type: ignore[arg-type])subset_lengths.append(n_items_in_split)remainder=len(dataset)-sum(subset_lengths)# type: ignore[arg-type]# add 1 to all the lengths in round-robin fashion until the remainder is 0foriinrange(remainder):idx_to_add_at=i%len(subset_lengths)subset_lengths[idx_to_add_at]+=1lengths=subset_lengthsfori,lengthinenumerate(lengths):iflength==0:warnings.warn(f"Length of split at index {i} is 0. "f"This might result in an empty dataset.")# Cannot verify that dataset is Sizedifsum(lengths)!=len(dataset):# type: ignore[arg-type]raiseValueError("Sum of input lengths does not equal the length of the input dataset!")indices=randperm(sum(lengths),generator=generator).tolist()# type: ignore[arg-type, call-overload]lengths=cast(Sequence[int],lengths)return[Subset(dataset,indices[offset-length:offset])foroffset,lengthinzip(itertools.accumulate(lengths),lengths)]