Source code for composer.callbacks.runtime_estimator

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Estimate total time of training."""
from __future__ import annotations

import logging
import time
import warnings
from typing import Optional

from composer.core import Callback, State, TimeUnit
from composer.loggers import Logger

log = logging.getLogger(__name__)

__all__ = ['RuntimeEstimator']


[docs]class RuntimeEstimator(Callback): """Estimates total training time. The training time is computed by taking the time elapsed for the current duration and multiplying out to the full extended length of the training run. This callback provides a best attempt estimate. This estimate may be inaccurate if throughput changes through training or other significant changes are made to the model or dataloader. Example: .. doctest:: >>> from composer import Trainer >>> from composer.callbacks import RuntimeEstimator >>> # constructing trainer object with this callback >>> trainer = Trainer( ... model=model, ... train_dataloader=train_dataloader, ... eval_dataloader=eval_dataloader, ... optimizers=optimizer, ... max_duration='1ep', ... callbacks=[RuntimeEstimator()], ... ) The runtime estimate is logged by the :class:`.Logger` to the following key as described below. +-----------------------------------+----------------------------------------------------------------+ | Key | Logged data | +===================================+================================================================+ | `time/remaining_estimate` | Estimated time to completion | +-----------------------------------+----------------------------------------------------------------+ | `time/remaining_estimate_unit` | Unit of time specified by user (seconds, minutes, hours, days) | +-----------------------------------+----------------------------------------------------------------+ Args: skip_batches (int, optional): Number of batches to skip before starting clock to estimate remaining time. Typically, the first few batches are slower due to dataloader, cache warming, and other reasons. Defaults to 1. time_unit (str, optional): Time unit to use for `time` logging. Can be one of 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'. """ def __init__(self, skip_batches: int = 1, time_unit: str = 'hours') -> None: self._enabled = True self.batches_left_to_skip = skip_batches self.start_time = None self.start_dur = None self.train_dataloader_len = None self.time_unit = time_unit self.divider = 1 if time_unit == 'seconds': self.divider = 1 elif time_unit == 'minutes': self.divider = 60 elif time_unit == 'hours': self.divider = 60 * 60 elif time_unit == 'days': self.divider = 60 * 60 * 24 else: raise ValueError( f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".', ) # Keep track of time spent evaluating self.total_eval_wct = 0.0 self.eval_wct_per_label: dict[str, list[float]] = {} # How often eval is called as fraction of total training time self.eval_frequency_per_label: dict[str, float] = {} self.last_elapsed_fraction: float = 0.0 def _get_elapsed_duration(self, state: State) -> Optional[float]: """Get the elapsed duration. Unlike `state.get_elapsed_duration`, this method computes fractional progress in an epoch provided at least 1 epoch has passed by recording how many batches were in each epoch. """ if state.max_duration is None: return None if state.max_duration.unit == TimeUnit('ep'): if state.timestamp.epoch.value >= 1: batches_per_epoch = ( state.timestamp.batch - state.timestamp.batch_in_epoch ).value / state.timestamp.epoch.value return state.timestamp.get('ba').value / (state.max_duration.value * batches_per_epoch) elif self.train_dataloader_len is not None: return state.timestamp.get('ba').value / (state.max_duration.value * self.train_dataloader_len) elapsed_dur = state.get_elapsed_duration() if elapsed_dur is not None: return elapsed_dur.value return None def init(self, state: State, logger: Logger) -> None: if self._enabled: logger.log_hyperparameters({'time/remaining_estimate_unit': self.time_unit}) def batch_start(self, state: State, logger: Logger) -> None: if self._enabled and self.start_time is None and self.batches_left_to_skip == 0: self.start_time = time.time() self.start_dur = self._get_elapsed_duration(state) if self.start_dur is None: warnings.warn('`max_duration` is not set. Cannot estimate remaining time.') self._enabled = False # Cache train dataloader len if specified for `_get_elapsed_duration` if state.dataloader_len is not None: self.train_dataloader_len = state.dataloader_len.value def batch_end(self, state: State, logger: Logger) -> None: if not self._enabled: return if self.batches_left_to_skip > 0: self.batches_left_to_skip -= 1 return elapsed_dur = self._get_elapsed_duration(state) assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start if enabled' assert self.start_dur is not None assert self.start_time is not None if elapsed_dur > self.start_dur: elapsed_time = time.time() - self.start_time elapsed_time -= self.total_eval_wct # Subtract time spent evaluating rate = elapsed_time / (elapsed_dur - self.start_dur) remaining_time = rate * (1 - elapsed_dur) # Add remaining time from each evaluator using known frequencies. We explicitly compute # frequency instead of using time interpolation to avoid saw tooth pattern in estimates for dataloader_label, eval_wcts in self.eval_wct_per_label.items(): # Discard first eval_wct if possible as it is often slower due to dataset downloading eval_wct_avg = None num_evals_finished = len(eval_wcts) if num_evals_finished > 1: eval_wct_avg = sum(eval_wcts[1:]) / (num_evals_finished - 1) else: eval_wct_avg = sum(eval_wcts) / num_evals_finished eval_rate = self.eval_frequency_per_label[dataloader_label] num_total_evals = 1 / eval_rate * (1 - self.start_dur) remaining_calls = num_total_evals - num_evals_finished remaining_time += eval_wct_avg * remaining_calls logger.log_metrics({'time/remaining_estimate': remaining_time / self.divider}) def eval_end(self, state: State, logger: Logger) -> None: # If eval is called before training starts, ignore it if not self._enabled or self.start_time is None: return self.total_eval_wct += state.eval_timestamp.total_wct.total_seconds() # state.dataloader_label should always be non-None unless user explicitly sets evaluator # label to None, ignoring type hints assert state.dataloader_label is not None, 'evaluator label must not be None' if state.dataloader_label not in self.eval_wct_per_label: self.eval_wct_per_label[state.dataloader_label] = [] self.eval_wct_per_label[state.dataloader_label].append(state.eval_timestamp.total_wct.total_seconds()) elapsed_dur = self._get_elapsed_duration(state) assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start if enabled' assert self.start_dur is not None, 'start_dur is set on batch_start if enabled' elapsed_fraction = elapsed_dur - self.start_dur num_evals_finished = len(self.eval_wct_per_label[state.dataloader_label]) self.eval_frequency_per_label[state.dataloader_label] = elapsed_fraction / num_evals_finished