Source code for composer.datasets.mnist

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from composer.core import MemoryFormat
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.utils import dist


def build_mnist_dataloader(
    datadir: str,
    global_batch_size: int,
    is_train: bool = True,
    download: bool = True,
    drop_last: bool = True,
    shuffle: bool = True,
    **dataloader_kwargs: Any,
) -> DataLoader:
    """Create a per-rank :class:`~torch.utils.data.DataLoader` over torchvision's MNIST dataset.

    The global batch is divided evenly across all ranks. The dataset object is
    constructed inside ``dist.run_local_rank_zero_first()``, and only local rank
    zero is ever allowed to pass ``download=True`` to torchvision.

    Args:
        datadir (str): Directory holding the MNIST files (or the download target).
        global_batch_size (int): Batch size summed over all ranks. Must be divisible
            by the world size.
        is_train (bool): Load the training split when ``True``, otherwise the
            validation split. Default: ``True``.
        download (bool, optional): Permit downloading the dataset if it is missing.
            Default: ``True``.
        drop_last (bool): Discard the trailing incomplete batch. Default: ``True``.
        shuffle (bool): Randomize sample order. Default: ``True``.
        **dataloader_kwargs (Any): Extra keyword arguments forwarded verbatim to
            ``DataLoader`` (e.g. ``num_workers``).

    Returns:
        DataLoader: A dataloader yielding ``global_batch_size // world_size`` samples
        per batch on this rank.

    Raises:
        ValueError: If ``global_batch_size`` is not a multiple of the world size.
    """
    world_size = dist.get_world_size()
    if global_batch_size % world_size != 0:
        raise ValueError(
            f'global_batch_size ({global_batch_size}) must be divisible by world_size ({world_size}).')
    per_rank_batch = global_batch_size // world_size

    to_tensor = transforms.Compose([transforms.ToTensor()])

    with dist.run_local_rank_zero_first():
        # Non-zero local ranks never download; they read whatever files exist in datadir.
        is_local_leader = dist.get_local_rank() == 0
        mnist = datasets.MNIST(
            datadir,
            train=is_train,
            download=is_local_leader and download,
            transform=to_tensor,
        )

    loader_sampler = dist.get_sampler(mnist, drop_last=drop_last, shuffle=shuffle)
    return DataLoader(
        mnist,
        batch_size=per_rank_batch,
        sampler=loader_sampler,
        drop_last=drop_last,
        **dataloader_kwargs,
    )
def build_synthetic_mnist_dataloader(
    global_batch_size: int,
    is_train: bool = True,
    drop_last: bool = True,
    shuffle: bool = True,
    num_unique_samples: int = 100,
    device: str = 'cpu',
    memory_format: MemoryFormat = MemoryFormat.CONTIGUOUS_FORMAT,
    **dataloader_kwargs: Any,
) -> DataLoader:
    """Builds a synthetic MNIST dataloader.

    Samples are produced by :class:`SyntheticBatchPairDataset` with shape
    ``[1, 28, 28]`` and 10 classes, instead of being read from disk.

    Args:
        global_batch_size (int): Global batch size, summed over all ranks. Must be
            divisible by the world size.
        is_train (bool): Whether to size the dataset like the training data or the
            validation data. Default: ``True``.
        drop_last (bool): Drop remainder samples. Default: ``True``.
        shuffle (bool): Shuffle the dataset. Default: ``True``.
        num_unique_samples (int): Number of unique samples in the synthetic dataset.
            Default: ``100``.
        device (str): Device with which to load the dataset. Default: ``cpu``.
        memory_format (:class:`composer.core.MemoryFormat`): Memory format of the
            tensors. Default: ``CONTIGUOUS_FORMAT``.
        **dataloader_kwargs (Any): Additional settings for the dataloader
            (e.g. num_workers, etc.)

    Returns:
        DataLoader: A dataloader yielding ``global_batch_size // world_size``
        samples per batch on this rank.

    Raises:
        ValueError: If ``global_batch_size`` is not divisible by the world size.
    """
    # Query the world size once rather than once per use.
    world_size = dist.get_world_size()
    if global_batch_size % world_size != 0:
        raise ValueError(
            f'global_batch_size ({global_batch_size}) must be divisible by world_size ({world_size}).')
    batch_size = global_batch_size // world_size

    # Total synthetic sample counts per split (these appear to mirror the
    # standard MNIST train/test split sizes).
    train_size = 60_000
    eval_size = 10_000

    dataset = SyntheticBatchPairDataset(
        total_dataset_size=train_size if is_train else eval_size,
        data_shape=[1, 28, 28],
        num_classes=10,
        num_unique_samples_to_create=num_unique_samples,
        device=device,
        memory_format=memory_format,
    )
    sampler = dist.get_sampler(dataset, drop_last=drop_last, shuffle=shuffle)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
        **dataloader_kwargs,
    )