Source code for composer.utils.iter_helpers

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

# To keep the typing organized for this file, see iter_helpers.pyi
# All typing annotations are in there
# All method signatures must be defined in there.

"""Utilities for iterating over collections."""
from __future__ import annotations

import collections.abc
import io
from typing import Any


def map_collection(collection, map_fn) -> Any:
    """Applies ``map_fn`` on each element in ``collection``.

    * If ``collection`` is a tuple or list of elements, ``map_fn`` is applied on each element,
      and a tuple or list, respectively, containing mapped values is returned. Named tuples
      are rebuilt as the same named tuple type.
    * If ``collection`` is a dictionary, ``map_fn`` is applied on each value, and a dictionary
      containing the mapped values is returned.
    * If ``collection`` is ``None``, ``None`` is returned.
    * If ``collection`` is a single element, the result of applying ``map_fn`` on it is returned.

    Args:
        collection: The element, or a tuple, list, or dictionary of elements.
        map_fn: A function to invoke on each element.

    Returns:
        Collection: The result of applying ``map_fn`` on each element of ``collection``.
        The type of ``collection`` is preserved for tuples and lists; dictionaries are
        returned as a plain ``dict``.
    """
    if collection is None:
        return None
    if isinstance(collection, (tuple, list)):
        mapped = (map_fn(x) for x in collection)
        if isinstance(collection, tuple) and hasattr(collection, '_fields'):
            # Named tuple constructors take one positional argument per field,
            # so passing a single iterable would raise a TypeError.
            return type(collection)(*mapped)
        return type(collection)(mapped)
    if isinstance(collection, dict):
        return {k: map_fn(v) for k, v in collection.items()}
    return map_fn(collection)
def ensure_tuple(x) -> tuple[Any, ...]:
    """Converts ``x`` into a tuple.

    * ``None`` becomes the empty tuple.
    * A ``str``, ``bytes``, or ``bytearray`` is treated as a single element and wrapped
      as ``(x,)`` rather than being split into its items.
    * Any other sequence (such as a tuple or list) is passed through ``tuple(x)``.
    * A dict becomes a tuple of its values.
    * Anything else is wrapped in a single element tuple ``(x,)``.

    Args:
        x (Any): The input to convert into a tuple.

    Returns:
        tuple: A tuple of ``x``.
    """
    if x is None:
        return ()

    # Text and byte strings are sequences too, but they count as one element here.
    is_atomic = isinstance(x, (str, bytes, bytearray))
    if not is_atomic:
        if isinstance(x, collections.abc.Sequence):
            return tuple(x)
        if isinstance(x, dict):
            return tuple(x.values())
    return (x,)
class IteratorFileStream(io.RawIOBase):
    """Class used to convert iterator of bytes into a file-like binary stream object.

    Original implementation found `here <https://stackoverflow.com/questions/6657820/how-to-convert-an-iterable-to-a-stream/20260030#20260030>`_.

    .. note::

        A usage example ``f = io.BufferedReader(IteratorFileStream(iterator), buffer_size=buffer_size)``

    Args:
        iterator: An iterator (or any iterable) over bytes objects.
    """

    def __init__(self, iterator):
        super().__init__()
        # Bytes from the most recent chunk that did not fit in the caller's buffer.
        self.leftover = None
        # ``iter()`` returns an iterator unchanged, and lets plain iterables work too.
        self.iterator = iter(iterator)

    def readinto(self, b):
        """Fill ``b`` with up to ``len(b)`` bytes and return the number of bytes written.

        Args:
            b: A pre-allocated, writable bytes-like object (e.g. ``bytearray`` or ``memoryview``).

        Returns:
            int: The number of bytes written into ``b``; ``0`` once the iterator is exhausted.
        """
        max_bytes = len(b)  # max bytes to read
        chunk = self.leftover
        # Skip empty chunks: returning 0 would be interpreted as EOF by the io
        # machinery and silently truncate the stream.
        while not chunk:
            try:
                chunk = next(self.iterator)
            except StopIteration:
                self.leftover = None
                return 0  # EOF
        output, self.leftover = chunk[:max_bytes], chunk[max_bytes:]
        b[:len(output)] = output
        return len(output)

    def readable(self):
        """Report that this stream supports reading."""
        return True
def iterate_with_callback(iterator, total_len, callback=None):
    """Yield every chunk from ``iterator``, reporting progress through ``callback``.

    Args:
        iterator (Iterator): The iterator, which should yield chunks of data.
        total_len (int): The total length of the iterator.
        callback (Callable[[int, int], None], optional): Invoked once before the first
            chunk (with a cumulative size of ``0``) and again after each chunk has been
            yielded back to the caller. It receives the cumulative size of all chunks
            yielded thus far and ``total_len``. Defaults to None, for no callback.

    Yields:
        The chunks of ``iterator``, unchanged and in order.
    """

    def _report(size_so_far):
        # No-op when the caller did not ask for progress updates.
        if callback is not None:
            callback(size_so_far, total_len)

    yielded_len = 0
    # Initial call so the callback can perform any initialization.
    _report(yielded_len)
    for piece in iterator:
        yielded_len += len(piece)
        yield piece
        _report(yielded_len)