# Source code for composer.utils.retrying

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Retry helper."""

from __future__ import annotations

import collections.abc
import functools
import inspect
import logging
import random
import time
from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast, overload

TCallable = TypeVar('TCallable', bound=Callable)

__all__ = ['retry']

log = logging.getLogger(__name__)


@overload
def retry(
    exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ...,
    num_attempts: int = ...,
    initial_backoff: float = ...,
    max_jitter: float = ...,
) -> Callable[[TCallable], TCallable]:
    ...


@overload
def retry(exc_class: TCallable) -> TCallable:
    # Use the decorator without parenthesis
    ...


# error: Type "(TCallable@retry) -> TCallable@retry" cannot be assigned to type "(func: Never) -> Never"
def retry(  # type: ignore
    exc_class: Union[TCallable, Type[Exception], Sequence[Type[Exception]]] = Exception,
    num_attempts: int = 3,
    initial_backoff: float = 1.0,
    max_jitter: float = 0.5,
):
    """Decorator to retry a function with backoff and jitter.

    Before attempt ``i + 1`` (``i`` being the zero-based index of the attempt that just
    failed), the wrapper sleeps for
    ``initial_backoff * 2**i + random.random() * max_jitter`` seconds.

    Optionally, the decorated function can specify ``retry_index`` as an argument to receive
    the current (zero-based) attempt number. Any ``retry_index`` keyword supplied by the
    caller is overwritten.

    The decorator can be used with arguments (``@retry(RuntimeError, num_attempts=5)``) or
    bare (``@retry``), in which case all defaults apply.

    Example:
        .. testcode::

            from composer.utils import retry

            @retry(RuntimeError, num_attempts=3, initial_backoff=0.1)
            def flaky_function(retry_index: int):
                if retry_index < 2:
                    raise RuntimeError("Called too soon!")
                return "Third time's a charm."

            print(flaky_function())

        .. testoutput::

            Third time's a charm.

    Args:
        exc_class (Type[Exception] | Sequence[Type[Exception]], optional): The exception class
            or classes to retry. Any sequence (tuple, list, ...) of classes is accepted.
            Defaults to Exception.
        num_attempts (int, optional): The total number of attempts to make. Defaults to 3.
        initial_backoff (float, optional): The initial backoff, in seconds. Defaults to 1.0.
        max_jitter (float, optional): The maximum amount of random jitter to add. Defaults to 0.5.

            Increasing the ``max_jitter`` can help prevent overloading a resource when multiple
            processes in parallel are calling the same underlying function.

    Returns:
        The decorated function when used bare (``@retry``); otherwise a decorator that wraps
        the function it is applied to.

    Raises:
        ValueError: If ``num_attempts`` is less than 1.
    """
    if num_attempts < 1:
        raise ValueError('num_attempts must be at least 1')

    # When used bare (``@retry``), the first positional argument is the function being
    # decorated rather than an exception specification.
    is_exception_spec = isinstance(exc_class, collections.abc.Sequence) or (isinstance(exc_class, type) and
                                                                            issubclass(exc_class, Exception))

    # ``except`` only accepts a class or a *tuple* of classes, so normalize any other
    # sequence type (e.g. a list) up front instead of failing when an exception is caught.
    if not is_exception_spec:
        caught = (Exception,)
    elif isinstance(exc_class, type):
        caught = (exc_class,)
    else:
        caught = tuple(exc_class)

    def wrapped_func(func: TCallable) -> TCallable:
        retry_index_param = 'retry_index'

        # The signature does not change between attempts, so inspect it once. Some callables
        # (e.g. certain builtins) expose no signature; those cannot receive ``retry_index``.
        try:
            accepts_retry_index = retry_index_param in inspect.signature(func).parameters
        except (TypeError, ValueError):
            accepts_retry_index = False

        @functools.wraps(func)
        def new_func(*args: Any, **kwargs: Any):
            i = 0
            while True:
                if accepts_retry_index:
                    kwargs[retry_index_param] = i
                try:
                    return func(*args, **kwargs)
                except caught as e:
                    log.debug('Attempt %s failed. Exception type: %s, message: %s.', i, type(e), e)
                    if i + 1 == num_attempts:
                        # Out of attempts: propagate the last failure with its traceback intact.
                        raise
                    time.sleep(initial_backoff * 2**i + random.random() * max_jitter)
                    i += 1

        return cast(TCallable, new_func)

    if not is_exception_spec:
        # Using the decorator without (), like @retry_with_backoff
        func = cast(TCallable, exc_class)
        return wrapped_func(func)

    return wrapped_func