# Copyright 2021 MosaicML. All Rights Reserved.
# To keep the typing organized for this file, see iter_helpers.pyi
# All typing annotations are in there
# All methods signatures must be defined in there.
"""Utilities for iterating over collections."""
import contextlib
from collections.abc import Sequence
[docs]def map_collection(collection, map_fn):
"""Apply ``map_fn`` on each element in ``collection``.
* If ``collection`` is a tuple or list of elements, ``map_fn`` is applied on each element,
and a tuple or list, respectively, containing mapped values is returned.
* If ``collection`` is a dictionary, ``map_fn`` is applied on each value, and a dictionary
containing the mapped values is returned.
* If ``collection`` is ``None``, ``None`` is returned.
* If ``collection`` is a single element, the result of applying ``map_fn`` on it is returned.
Args:
collection: The element, or a tuple of elements.
map_fn: A function to invoke on each element.
Returns:
Collection: The result of applying ``map_fn`` on each element of ``collection``.
The type of ``collection`` is preserved.
"""
if collection is None:
return None
if isinstance(collection, (tuple, list)):
return type(collection)(map_fn(x) for x in collection)
if isinstance(collection, dict):
return {k: map_fn(v) for k, v in collection.items()}
return map_fn(collection)
[docs]def ensure_tuple(x):
"""Converts ``x`` into a tuple.
* If ``x`` is ``None``, then ``tuple()`` is returned.
* If ``x`` is a tuple, then ``x`` is returned as-is.
* If ``x`` is a list, then ``tuple(x)`` is returned.
* If ``x`` is a dict, then ``tuple(v for v in x.values())`` is returned.
Otherwise, a single element tuple of ``(x,)`` is returned.
Args:
x (Any): The input to convert into a tuple.
Returns:
tuple: A tuple of ``x``.
"""
if x is None:
return tuple()
if isinstance(x, (str, bytes, bytearray)):
return (x,)
if isinstance(x, Sequence):
return tuple(x)
if isinstance(x, dict):
return tuple(x.values())
return (x,)
[docs]def iterate_with_pbar(iterator, progress_bar=None):
"""Iterate over a batch iterator and update a :class:`tqdm.tqdm` progress bar by the batch size on each step.
This function iterates over ``iterator``, which is expected to yield batches of elements.
On each step, the batch is yielded back to the caller, and the ``progress_bar`` is updated by the
**length** of each batch.
.. note::
It is expected that the ``progress_bar = tqdm.tqdm(total=sum(len(x) for x in iterator))``.
Args:
iterator (Iterator[TSized]): An iterator that yields batches of elements.
progress_bar (Optional[tqdm.tqdm], optional): A :class:`tqdm.tqdm` progress bar.
If ``None`` (the default), then this function simply yields from ``iterator``.
Yields:
Iterator[TSized]: The elements of ``iterator``.
"""
with progress_bar if progress_bar is not None else contextlib.nullcontext(None) as pb:
for x in iterator:
yield x
if pb is not None:
pb.update(len(x))