Source code for composer.callbacks.speed_monitor

# Copyright 2021 MosaicML. All Rights Reserved.

from __future__ import annotations

import time
from collections import deque
from typing import Deque, Optional

from composer import Logger, State
from composer.callbacks.callback_hparams import SpeedMonitorHparams
from composer.core.callback import RankZeroCallback
from composer.core.types import StateDict


class SpeedMonitor(RankZeroCallback):
    """Logs the training throughput.

    It logs:

    * A rolling average (over the ``window_size`` most recent batches) of the
      number of samples processed per second to the ``throughput/step`` key.
    * The number of samples processed per second, averaged over an entire
      epoch, to the ``throughput/epoch`` key.
    * The total elapsed training time to the ``wall_clock_train`` key.

    Args:
        window_size (int): Number of batches to use for a rolling average of throughput.
    """

    def __init__(self, window_size: int):
        super().__init__()
        self.train_examples_per_epoch = 0
        self.wall_clock_train = 0.0
        self.epoch_start_time = 0.0
        self.batch_end_times: Deque[float] = deque(maxlen=window_size + 1)  # rolling list of batch end times
        self.batch_num_samples: Deque[int] = deque(maxlen=window_size)  # rolling list of num samples per batch
        self.hparams = SpeedMonitorHparams(window_size=window_size)
        self.loaded_state: Optional[StateDict] = None
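
    # The timing deque holds window_size + 1 entries because n + 1 timestamps
    # bracket n batch intervals: with window_size=3, batch end times t0..t3
    # and sample counts n1, n2, n3 give a throughput of
    # (n1 + n2 + n3) / (t3 - t0).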

    def state_dict(self) -> StateDict:
        current_time = time.time()
        return {
            "train_examples_per_epoch": self.train_examples_per_epoch,
            "wall_clock_train": self.wall_clock_train,
            "epoch_duration": current_time - self.epoch_start_time,
            "batch_durations": [current_time - x for x in self.batch_end_times],
            "batch_num_samples": self.batch_num_samples,
        }
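
    # Timestamps are serialized as durations relative to the save time rather
    # than as absolute times, so a checkpoint does not depend on the wall
    # clock of the machine that wrote it.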

    def load_state_dict(self, state: StateDict) -> None:
        self.loaded_state = state

    def _load_state(self) -> None:
        current_time = time.time()
        if self.loaded_state is not None:
            self.train_examples_per_epoch = self.loaded_state["train_examples_per_epoch"]
            self.wall_clock_train = self.loaded_state["wall_clock_train"]
            self.epoch_start_time = current_time - self.loaded_state["epoch_duration"]
            self.batch_end_times = deque([current_time - x for x in self.loaded_state["batch_durations"]],
                                         maxlen=self.hparams.window_size + 1)
            self.batch_num_samples = self.loaded_state["batch_num_samples"]
            self.loaded_state = None
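
    # The stashed state is applied lazily, on the next batch_start or
    # epoch_start, rather than inside load_state_dict: the durations are
    # re-anchored to the wall clock at the moment training resumes, so time
    # spent paused between save and resume is not counted as training time.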

    def batch_start(self, state: State, logger: Logger) -> None:
        del state, logger  # unused
        self._load_state()

    def epoch_start(self, state: State, logger: Logger):
        del state, logger  # unused
        self._load_state()
        self.epoch_start_time = time.time()
        self.batch_end_times.clear()
        self.batch_num_samples.clear()
        self.train_examples_per_epoch = 0

    def batch_end(self, state: State, logger: Logger):
        self.batch_end_times.append(time.time())
        batch_num_samples = 0
        batch_num_samples += state.last_batch_size
        # TODO this is a hack around not having proper syncing / reduction available in callbacks.
        # Ideally, callbacks would have a way of reducing tensors.
        # This assumes that each process has the same batch size.
        # For the speed monitor, we might be able to use the static step converter with num_samples.
        batch_num_samples *= state.world_size
        self.batch_num_samples.append(batch_num_samples)
        self.train_examples_per_epoch += batch_num_samples

        if len(self.batch_end_times) == self.hparams.window_size + 1:
            throughput = sum(self.batch_num_samples) / (self.batch_end_times[-1] - self.batch_end_times[0])
            logger.metric_batch({'throughput/step': throughput})
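
    # Note: 'throughput/step' is only emitted once the timing window is full.
    # Since the deques are cleared at epoch_start, the first window_size
    # batches of each epoch do not log a step throughput.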

    def epoch_end(self, state: State, logger: Logger):
        del state  # unused
        epoch_time = time.time() - self.epoch_start_time
        self.wall_clock_train += epoch_time
        logger.metric_epoch({
            "wall_clock_train": self.wall_clock_train,
        })
        logger.metric_epoch({
            "throughput/epoch": self.train_examples_per_epoch / epoch_time,
        })
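
# A minimal usage sketch of attaching the callback to a Composer Trainer.
# The import path for SpeedMonitor matches this module; the Trainer arguments
# shown (model, train_dataloader, max_duration, callbacks) are illustrative
# assumptions, and `model` / `train_dl` are hypothetical objects defined
# elsewhere, not part of this module.
#
#     from composer.callbacks.speed_monitor import SpeedMonitor
#     from composer.trainer import Trainer
#
#     speed_monitor = SpeedMonitor(window_size=100)
#     trainer = Trainer(
#         model=model,                # hypothetical ComposerModel
#         train_dataloader=train_dl,  # hypothetical DataLoader
#         max_duration="1ep",
#         callbacks=[speed_monitor],
#     )
#     trainer.fit()
#
# The callback's own timing state then round-trips through checkpoints via
# state_dict() / load_state_dict() without any extra handling.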