Source code for composer.datasets.mnist

# Copyright 2021 MosaicML. All Rights Reserved.

from dataclasses import dataclass

import yahp as hp
from torchvision import datasets, transforms

from composer.datasets.hparams import DataloaderSpec, DatasetHparams


@dataclass
class MNISTDatasetHparams(DatasetHparams):
    """Defines an instance of the MNIST dataset for image classification.

    Parameters:
        is_train (bool): Whether to load the training or validation dataset.
        datadir (str): Data directory to use.
        download (bool): Whether to download the dataset, if needed.
        drop_last (bool): Whether to drop the last samples for the last batch.
            Defaults to ``True``.
        shuffle (bool): Whether to shuffle the dataset for each epoch.
            Defaults to ``True``.
    """

    is_train: bool = hp.required("whether to load the training or validation dataset")
    datadir: str = hp.required("data directory")
    download: bool = hp.required("whether to download the dataset, if needed")
    drop_last: bool = hp.optional("Whether to drop the last samples for the last batch", default=True)
    shuffle: bool = hp.optional("Whether to shuffle the dataset for each epoch", default=True)

    def initialize_object(self) -> DataloaderSpec:
        """Create the MNIST dataset and wrap it in a ``DataloaderSpec``.

        The dataset is built from ``datadir``, ``is_train`` and ``download``;
        ``drop_last`` and ``shuffle`` are forwarded unchanged to the spec.

        Returns:
            DataloaderSpec: Spec holding the ``torchvision`` MNIST dataset
            together with the ``drop_last`` and ``shuffle`` settings.
        """
        # ToTensor is the only transform applied to each sample.
        transform = transforms.Compose([transforms.ToTensor()])
        dataset = datasets.MNIST(
            self.datadir,
            train=self.is_train,
            download=self.download,
            transform=transform,
        )
        return DataloaderSpec(
            dataset=dataset,
            drop_last=self.drop_last,
            shuffle=self.shuffle,
        )