Source code for composer.datasets.imagenet

# Copyright 2021 MosaicML. All Rights Reserved.

"""ImageNet classfication dataset.

The most widely used dataset for Image Classification algorithms. Please refer to the `ImageNet 2012 Classification
Dataset <http://image-net.org/>`_ for more details. Also includes streaming dataset versions based on the `WebDatasets
<https://github.com/webdataset/webdataset>`_.
"""

import os
import textwrap
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
import torch.utils.data
import yahp as hp
from torchvision import transforms
from torchvision.datasets import ImageFolder

from composer.core import DataSpec
from composer.core.types import DataLoader
from composer.datasets.dataloader import DataLoaderHparams
from composer.datasets.ffcv_utils import ffcv_monkey_patches, write_ffcv_dataset
from composer.datasets.hparams import DatasetHparams, SyntheticHparamsMixin, WebDatasetHparams
from composer.datasets.synthetic import SyntheticBatchPairDataset
from composer.datasets.utils import NormalizationFn, pil_image_collate
from composer.utils import dist

# ImageNet normalization values from torchvision: https://pytorch.org/vision/stable/models.html
IMAGENET_CHANNEL_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
IMAGENET_CHANNEL_STD = (0.229 * 255, 0.224 * 255, 0.225 * 255)
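# Note (added for clarity): these constants are scaled to the [0, 255] range because the
# loaders below collate raw PIL images into uint8 tensors (``pil_image_collate``) and
# normalize on-device afterwards (``NormalizationFn``), rather than normalizing
# float tensors in [0, 1] on the CPU.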

__all__ = ["ImagenetDatasetHparams", "Imagenet1kWebDatasetHparams", "TinyImagenet200WebDatasetHparams"]


@dataclass
class ImagenetDatasetHparams(DatasetHparams, SyntheticHparamsMixin):
    """Defines an instance of the ImageNet dataset for image classification.

    Args:
        resize_size (int, optional): The resize size to use. Use ``-1`` to not resize. Default: ``-1``.
        crop_size (int): The crop size to use. Default: ``224``.
        use_ffcv (bool): Whether to use FFCV dataloaders. Default: ``False``.
        ffcv_dir (str): A directory containing train/val <file>.ffcv files. If these files don't exist and
            ``ffcv_write_dataset`` is ``True``, train/val <file>.ffcv files will be created in this dir.
            Default: ``"/tmp"``.
        ffcv_dest_train (str): <file>.ffcv file that has training samples. Default: ``"train.ffcv"``.
        ffcv_dest_val (str): <file>.ffcv file that has validation samples. Default: ``"val.ffcv"``.
        ffcv_write_dataset (bool): Whether to create the dataset in FFCV format (<file>.ffcv) if it doesn't exist.
            Default: ``False``.
    """

    resize_size: int = hp.optional("resize size. Set to -1 to not resize", default=-1)
    crop_size: int = hp.optional("crop size", default=224)
    use_ffcv: bool = hp.optional("whether to use ffcv for faster dataloading", default=False)
    ffcv_dir: str = hp.optional(
        "A directory containing train/val <file>.ffcv files. If these files don't exist and ffcv_write_dataset is "
        "true, train/val <file>.ffcv files will be created in this dir.",
        default="/tmp")
    ffcv_dest_train: str = hp.optional("<file>.ffcv file that has training samples", default="train.ffcv")
    ffcv_dest_val: str = hp.optional("<file>.ffcv file that has validation samples", default="val.ffcv")
    ffcv_write_dataset: bool = hp.optional(
        "Whether to create dataset in FFCV format (<file>.ffcv) if it doesn't exist", default=False)
    def initialize_object(self, batch_size: int, dataloader_hparams: DataLoaderHparams) -> DataSpec:
        if self.use_synthetic:
            total_dataset_size = 1_281_167 if self.is_train else 50_000
            dataset = SyntheticBatchPairDataset(
                total_dataset_size=total_dataset_size,
                data_shape=[3, self.crop_size, self.crop_size],
                num_classes=1000,
                num_unique_samples_to_create=self.synthetic_num_unique_samples,
                device=self.synthetic_device,
                memory_format=self.synthetic_memory_format,
            )
            collate_fn = None
            device_transform_fn = None

        elif self.use_ffcv:
            try:
                import ffcv  # type: ignore
                from ffcv.fields.decoders import CenterCropRGBImageDecoder  # type: ignore
                from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder  # type: ignore
                from ffcv.pipeline.operation import Operation  # type: ignore
            except ImportError:
                raise ImportError(
                    textwrap.dedent("""\
                    Composer was installed without ffcv support. To use ffcv with Composer, please install ffcv
                    in your environment."""))

            if self.is_train:
                dataset_file = self.ffcv_dest_train
                split = "train"
            else:
                dataset_file = self.ffcv_dest_val
                split = "val"
            dataset_filepath = os.path.join(self.ffcv_dir, dataset_file)

            # Always create the dataset if ffcv_write_dataset is true.
            if self.ffcv_write_dataset:
                if dist.get_local_rank() == 0:
                    if self.datadir is None:
                        raise ValueError(
                            "datadir is required if use_synthetic is False and ffcv_write_dataset is True.")
                    ds = ImageFolder(os.path.join(self.datadir, split))
                    write_ffcv_dataset(dataset=ds,
                                       write_path=dataset_filepath,
                                       max_resolution=500,
                                       num_workers=dataloader_hparams.num_workers,
                                       compress_probability=0.50,
                                       jpeg_quality=90)
                # Wait for local rank 0 to finish creating the dataset in ffcv format.
                dist.barrier()

            this_device = torch.device(f'cuda:{dist.get_local_rank()}')
            label_pipeline: List[Operation] = [
                IntDecoder(),
                ffcv.transforms.ToTensor(),
                ffcv.transforms.Squeeze(),
                ffcv.transforms.ToDevice(this_device, non_blocking=True)
            ]
            image_pipeline: List[Operation] = []
            if self.is_train:
                image_pipeline.extend([
                    RandomResizedCropRGBImageDecoder((self.crop_size, self.crop_size)),
                    ffcv.transforms.RandomHorizontalFlip()
                ])
                dtype = np.float16
            else:
                image_pipeline.extend(
                    [CenterCropRGBImageDecoder((self.crop_size, self.crop_size), ratio=224 / 256)])
                dtype = np.float32

            # Common transforms for train and test
            image_pipeline.extend([
                ffcv.transforms.ToTensor(),
                ffcv.transforms.ToDevice(this_device, non_blocking=True),
                ffcv.transforms.ToTorchImage(),
                ffcv.transforms.NormalizeImage(np.array(IMAGENET_CHANNEL_MEAN), np.array(IMAGENET_CHANNEL_STD),
                                               dtype),
            ])

            is_distributed = dist.get_world_size() > 1

            ffcv_monkey_patches()
            ordering = ffcv.loader.OrderOption.RANDOM if self.is_train else ffcv.loader.OrderOption.SEQUENTIAL

            return ffcv.Loader(
                dataset_filepath,
                batch_size=batch_size,
                num_workers=dataloader_hparams.num_workers,
                order=ordering,
                distributed=is_distributed,
                pipelines={
                    'image': image_pipeline,
                    'label': label_pipeline
                },
                batches_ahead=dataloader_hparams.prefetch_factor,
                drop_last=self.drop_last,
            )

        else:
            if self.is_train:
                # Include a fixed-size resize before RandomResizedCrop in training only
                # if requested (by specifying a size > 0).
                train_resize_size = self.resize_size
                train_transforms: List[torch.nn.Module] = []
                if train_resize_size > 0:
                    train_transforms.append(transforms.Resize(train_resize_size))
                # Always include RandomResizedCrop and RandomHorizontalFlip.
                train_transforms += [
                    transforms.RandomResizedCrop(self.crop_size, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)),
                    transforms.RandomHorizontalFlip()
                ]
                transformation = transforms.Compose(train_transforms)
                split = "train"
            else:
                transformation = transforms.Compose([
                    transforms.Resize(self.resize_size),
                    transforms.CenterCrop(self.crop_size),
                ])
                split = "val"

            device_transform_fn = NormalizationFn(mean=IMAGENET_CHANNEL_MEAN, std=IMAGENET_CHANNEL_STD)
            collate_fn = pil_image_collate

            if self.datadir is None:
                raise ValueError("datadir must be specified if use_synthetic is False.")
            dataset = ImageFolder(os.path.join(self.datadir, split), transformation)

        sampler = dist.get_sampler(dataset, drop_last=self.drop_last, shuffle=self.shuffle)

        return DataSpec(dataloader=dataloader_hparams.initialize_object(
            dataset=dataset,
            batch_size=batch_size,
            sampler=sampler,
            drop_last=self.drop_last,
            collate_fn=collate_fn,
        ),
                        device_transforms=device_transform_fn)
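# A minimal usage sketch (not part of the original module), assuming the default
# DataLoaderHparams values are usable on your machine. The datadir path and batch
# size are hypothetical; the directory is expected to contain ImageFolder-style
# ``train/`` and ``val/`` subdirectories.
#
#     from composer.datasets.dataloader import DataLoaderHparams
#
#     hparams = ImagenetDatasetHparams(
#         is_train=True,
#         datadir='/path/to/imagenet',  # hypothetical path
#         resize_size=-1,
#         crop_size=224,
#     )
#     data_spec = hparams.initialize_object(batch_size=256, dataloader_hparams=DataLoaderHparams())
#     # data_spec.dataloader yields (image, target) batches; data_spec.device_transforms
#     # applies the channel normalization on the device.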
@dataclass
class TinyImagenet200WebDatasetHparams(WebDatasetHparams):
    """Defines an instance of the TinyImagenet-200 WebDataset for image classification.

    Args:
        remote (str): S3 bucket or root directory where dataset is stored.
            Default: ``'s3://mosaicml-internal-dataset-tinyimagenet200'``.
        name (str): Key used to determine where dataset is cached on local filesystem.
            Default: ``'tinyimagenet200'``.
        n_train_samples (int): Number of training samples. Default: ``100000``.
        n_val_samples (int): Number of validation samples. Default: ``10000``.
        height (int): Sample image height in pixels. Default: ``64``.
        width (int): Sample image width in pixels. Default: ``64``.
        n_classes (int): Number of output classes. Default: ``200``.
        channel_means (list of float): Channel means for normalization. Default: ``(0.485, 0.456, 0.406)``.
        channel_stds (list of float): Channel stds for normalization. Default: ``(0.229, 0.224, 0.225)``.
    """

    remote: str = hp.optional('WebDataset S3 bucket name', default='s3://mosaicml-internal-dataset-tinyimagenet200')
    name: str = hp.optional('WebDataset local cache name', default='tinyimagenet200')
    n_train_samples: int = hp.optional('Number of samples in training split', default=100_000)
    n_val_samples: int = hp.optional('Number of samples in validation split', default=10_000)
    height: int = hp.optional('Image height', default=64)
    width: int = hp.optional('Image width', default=64)
    n_classes: int = hp.optional('Number of output classes', default=200)
    channel_means: List[float] = hp.optional('Mean per image channel', default=(0.485, 0.456, 0.406))
    channel_stds: List[float] = hp.optional('Std per image channel', default=(0.229, 0.224, 0.225))
    def initialize_object(self, batch_size: int, dataloader_hparams: DataLoaderHparams) -> DataLoader:
        from composer.datasets.webdataset_utils import load_webdataset

        if self.is_train:
            split = 'train'
            transform = transforms.Compose([
                transforms.RandomCrop((self.height, self.width), (self.height // 8, self.width // 8)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self.channel_means, self.channel_stds),
            ])
        else:
            split = 'val'
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(self.channel_means, self.channel_stds),
            ])
        preprocess = lambda dataset: dataset.decode('pil').map_dict(jpg=transform).to_tuple('jpg', 'cls')
        dataset = load_webdataset(self.remote, self.name, split, self.webdataset_cache_dir,
                                  self.webdataset_cache_verbose, self.shuffle, self.shuffle_buffer, preprocess,
                                  dist.get_world_size(), dataloader_hparams.num_workers, batch_size, self.drop_last)
        return dataloader_hparams.initialize_object(dataset,
                                                    batch_size=batch_size,
                                                    sampler=None,
                                                    drop_last=self.drop_last)
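# A minimal usage sketch (not part of the original module). The remote value below is
# hypothetical; any S3 bucket or root directory holding the expected WebDataset shards
# works, and default DataLoaderHparams values are assumed.
#
#     from composer.datasets.dataloader import DataLoaderHparams
#
#     hparams = TinyImagenet200WebDatasetHparams(
#         is_train=True,
#         remote='/path/to/tinyimagenet200-shards',  # hypothetical local mirror
#     )
#     dataloader = hparams.initialize_object(batch_size=512, dataloader_hparams=DataLoaderHparams())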
@dataclass
class Imagenet1kWebDatasetHparams(WebDatasetHparams):
    """Defines an instance of the ImageNet-1k WebDataset for image classification.

    Args:
        remote (str): S3 bucket or root directory where dataset is stored.
            Default: ``'s3://mosaicml-internal-dataset-imagenet1k'``.
        name (str): Key used to determine where dataset is cached on local filesystem. Default: ``'imagenet1k'``.
        resize_size (int, optional): The resize size to use. Use ``-1`` to not resize. Default: ``-1``.
        crop_size (int): The crop size to use. Default: ``224``.
    """

    remote: str = hp.optional('WebDataset S3 bucket name', default='s3://mosaicml-internal-dataset-imagenet1k')
    name: str = hp.optional('WebDataset local cache name', default='imagenet1k')
    resize_size: int = hp.optional("resize size. Set to -1 to not resize", default=-1)
    crop_size: int = hp.optional("crop size", default=224)
    def initialize_object(self, batch_size: int, dataloader_hparams: DataLoaderHparams) -> DataSpec:
        from composer.datasets.webdataset_utils import load_webdataset

        if self.is_train:
            # Include a fixed-size resize before RandomResizedCrop in training only
            # if requested (by specifying a size > 0).
            train_resize_size = self.resize_size
            train_transforms: List[torch.nn.Module] = []
            if train_resize_size > 0:
                train_transforms.append(transforms.Resize(train_resize_size))
            # Always include RandomResizedCrop and RandomHorizontalFlip.
            train_transforms += [
                transforms.RandomResizedCrop(self.crop_size, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)),
                transforms.RandomHorizontalFlip()
            ]
            transform = transforms.Compose(train_transforms)
        else:
            transform = transforms.Compose([
                transforms.Resize(self.resize_size),
                transforms.CenterCrop(self.crop_size),
            ])
        split = 'train' if self.is_train else 'val'
        preprocess = lambda dataset: dataset.decode('pil').map_dict(jpg=transform).to_tuple('jpg', 'cls')
        dataset = load_webdataset(self.remote, self.name, split, self.webdataset_cache_dir,
                                  self.webdataset_cache_verbose, self.shuffle, self.shuffle_buffer, preprocess,
                                  dist.get_world_size(), dataloader_hparams.num_workers, batch_size, self.drop_last)

        collate_fn = pil_image_collate
        device_transform_fn = NormalizationFn(mean=IMAGENET_CHANNEL_MEAN, std=IMAGENET_CHANNEL_STD)

        return DataSpec(dataloader=dataloader_hparams.initialize_object(
            dataset=dataset,
            batch_size=batch_size,
            sampler=None,
            drop_last=self.drop_last,
            collate_fn=collate_fn,
        ),
                        device_transforms=device_transform_fn)
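# A minimal usage sketch (not part of the original module): the validation pipeline with
# the standard resize-256 / center-crop-224 evaluation transform. The remote value is
# hypothetical, and default DataLoaderHparams values are assumed.
#
#     from composer.datasets.dataloader import DataLoaderHparams
#
#     hparams = Imagenet1kWebDatasetHparams(
#         is_train=False,
#         remote='/path/to/imagenet1k-shards',  # hypothetical local mirror
#         resize_size=256,
#         crop_size=224,
#     )
#     data_spec = hparams.initialize_object(batch_size=256, dataloader_hparams=DataLoaderHparams())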