Source code for composer.datasets.ffcv_utils

import json
import logging
from typing import Optional

import numpy as np

from composer.core.types import Dataset
from composer.datasets.webdataset_utils import init_webdataset_meta
from composer.utils import MissingConditionalImportError

try:
    import ffcv  # type: ignore
    ffcv_installed = True
except ImportError:
    ffcv_installed = False

log = logging.getLogger(__name__)

__all__ = ["write_ffcv_dataset", "ffcv_monkey_patches"]


def _require_ffcv():
    if not ffcv_installed:
        raise MissingConditionalImportError(extra_deps_group="ffcv", conda_package="ffcv")


def ffcv_monkey_patches():
    _require_ffcv()

    # ffcv's __len__ is expensive because it always calls self.next_traversal_order, which shuffles.
    # Composer calls len(dataloader) in the training loop for every batch, so this caused a 2x slowdown.
    # ffcv's __len__ is fixed on the 1.0.0 branch, but for a different reason (https://github.com/libffcv/ffcv/issues/163).
    def new_len(self):
        if not hasattr(self, "init_traversal_order"):
            self.init_traversal_order = self.next_traversal_order()
        if self.drop_last:
            return len(self.init_traversal_order) // self.batch_size
        else:
            return int(np.ceil(len(self.init_traversal_order) / self.batch_size))

    ffcv.loader.loader.Loader.__len__ = new_len
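

# NOTE: Hypothetical usage sketch, not part of the original module: one way to apply the patch
# before building an ffcv Loader so that len(loader) stays cheap inside the training loop.
# The file path, batch size, and pipelines below are illustrative assumptions.
def _example_patched_loader(ffcv_file: str = '/tmp/dataset.ffcv'):
    """Minimal sketch: monkey-patch ``Loader.__len__``, then construct a loader over ``ffcv_file``."""
    from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
    from ffcv.loader import Loader, OrderOption
    from ffcv.transforms import ToTensor

    ffcv_monkey_patches()  # replaces Loader.__len__ with the cached-traversal-order version above
    loader = Loader(ffcv_file,
                    batch_size=128,
                    num_workers=8,
                    order=OrderOption.RANDOM,
                    pipelines={
                        'image': [SimpleRGBImageDecoder(), ToTensor()],
                        'label': [IntDecoder(), ToTensor()],
                    })
    # len(loader) now reuses the first traversal order instead of reshuffling on every call.
    return loader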


def write_ffcv_dataset(dataset: Optional[Dataset] = None,
                       remote: Optional[str] = None,
                       write_path: str = '/tmp/dataset.ffcv',
                       max_resolution: Optional[int] = None,
                       num_workers: int = 16,
                       write_mode: str = 'raw',
                       compress_probability: float = 0.50,
                       jpeg_quality: float = 90,
                       chunk_size: int = 100):
    """Converts a PyTorch ``dataset`` or a webdataset at ``remote`` into FFCV format at filepath ``write_path``.

    Args:
        dataset (Iterable[Sample]): A PyTorch dataset. Default: ``None``.
        remote (str): A remote path for a webdataset. Default: ``None``.
        write_path (str): Write results to this file. Default: ``"/tmp/dataset.ffcv"``.
        max_resolution (int): Limit resolution if provided. Default: ``None``.
        num_workers (int): Number of workers to use. Default: ``16``.
        write_mode (str): Write mode for the dataset. Default: ``'raw'``.
        compress_probability (float): Probability with which an image is JPEG-compressed. Default: ``0.5``.
        jpeg_quality (float): Quality to use for JPEG compression. Default: ``90``.
        chunk_size (int): Size of chunks processed by each worker during conversion. Default: ``100``.
    """
    _require_ffcv()
    if dataset is None and remote is None:
        raise ValueError('At least one of dataset or remote should not be None.')

    log.info(f'Writing dataset in FFCV <file>.ffcv format to {write_path}.')
    writer = ffcv.writer.DatasetWriter(write_path, {
        'image':
            ffcv.fields.RGBImageField(write_mode=write_mode,
                                      max_resolution=max_resolution,
                                      compress_probability=compress_probability,
                                      jpeg_quality=jpeg_quality),
        'label':
            ffcv.fields.IntField()
    },
                                       num_workers=num_workers)

    if dataset:
        writer.from_indexed_dataset(dataset, chunksize=chunk_size)
    elif remote is not None:
        pipeline = lambda dataset: dataset.decode('pil').to_tuple('jpg', 'cls')
        text = init_webdataset_meta(remote)
        metadata = json.loads(text)
        num_shards = metadata['n_shards']
        if metadata['n_leftover'] > 0:
            num_shards = num_shards + 1
        if remote.startswith('s3://'):
            urls = [f'pipe: aws s3 cp {remote}/{idx:05d}.tar -' for idx in range(num_shards)]
        else:
            urls = [f'{remote}/{idx:05d}.tar' for idx in range(num_shards)]
        lengths = np.repeat(metadata['samples_per_shard'], metadata['n_shards'])
        lengths = np.insert(lengths, 0, 0)
        # We don't have n_leftover in the lengths array, so we don't need to remove it from the offsets.
        offsets = np.cumsum(lengths)
        todos = zip(urls, offsets)
        total_len = metadata['samples_per_shard'] * metadata['n_shards'] + metadata['n_leftover']

        # We call this internal API instead of writer.from_webdataset because writer.from_webdataset
        # downloads the whole dataset twice.
        writer._write_common(total_len, todos, ffcv.writer.worker_job_webdataset, (pipeline,))
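

# NOTE: Hypothetical end-to-end sketch, not part of the original module: converting a torchvision
# CIFAR-10 dataset to FFCV format with write_ffcv_dataset. The dataset choice and paths are
# illustrative assumptions; any indexed dataset yielding (PIL image, int label) pairs would work.
def _example_write_cifar10_ffcv(data_dir: str = '/tmp/cifar10', ffcv_file: str = '/tmp/cifar10_train.ffcv'):
    """Minimal sketch: write an indexed (image, label) dataset to ``ffcv_file``."""
    from torchvision.datasets import CIFAR10

    # Without a transform, CIFAR10 yields (PIL.Image, int) samples, which match the
    # 'image'/'label' fields that write_ffcv_dataset configures on the DatasetWriter.
    train_dataset = CIFAR10(root=data_dir, train=True, download=True)
    write_ffcv_dataset(dataset=train_dataset, write_path=ffcv_file, num_workers=8, write_mode='raw')

    # For a webdataset stored remotely, one would instead pass only ``remote`` (for example an
    # ``s3://`` prefix containing numbered ``.tar`` shards and the webdataset metadata).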