# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Logs metrics to the console and show a progress bar."""
from __future__ import annotations
import os
import sys
from typing import TYPE_CHECKING, Any, Optional, TextIO, Union
import tqdm.auto
import yaml
from composer.core.time import TimeUnit
from composer.loggers.logger import Logger, format_log_data_value
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import dist, is_notebook
if TYPE_CHECKING:
from composer.core import State, Timestamp
__all__ = ['ProgressBarLogger']
_IS_TRAIN_TO_KEYS_TO_LOG = {
True: ['loss/train'],
False: ['eval'],
}
class _ProgressBar:
def __init__(
self,
total: Optional[int],
position: Optional[int],
bar_format: str,
file: TextIO,
metrics: dict[str, Any],
keys_to_log: list[str],
timestamp_key: str,
unit: str = 'it',
) -> None:
self.keys_to_log = keys_to_log
self.metrics = metrics
self.position = position
self.timestamp_key = timestamp_key
self.file = file
is_atty = is_notebook() or os.isatty(self.file.fileno())
self.pbar = tqdm.auto.tqdm(
total=total,
position=position,
bar_format=bar_format,
file=file,
ncols=None if is_atty else 120,
dynamic_ncols=is_atty,
# We set `leave=False` so TQDM does not jump around, but we emulate `leave=True` behavior when closing
# by printing a dummy newline and refreshing to force tqdm to print to a stale line
# But on k8s, we need `leave=True`, as it would otherwise overwrite the pbar in place
# If in a notebook, then always set leave=True, as otherwise jupyter would remote the progress bars
leave=True if is_notebook() else not is_atty,
postfix=metrics,
unit=unit,
)
def log_data(self, data: dict[str, Any]):
formatted_data = {}
for (k, v) in data.items():
# Check if any substring of the key matches the keys to log
if any(key_to_log in k for key_to_log in self.keys_to_log):
formatted_data[k] = format_log_data_value(v)
self.metrics.update(formatted_data)
self.pbar.set_postfix(self.metrics)
def update(self, n=1):
self.pbar.update(n=n)
def update_to_timestamp(self, timestamp: Timestamp):
n = int(getattr(timestamp, self.timestamp_key))
n = n - self.pbar.n
self.update(int(n))
def close(self):
if is_notebook():
# If in a notebook, always refresh before closing, so the
# finished progress is displayed
self.pbar.refresh()
else:
if self.position != 0:
# Force a (potentially hidden) progress bar to re-render itself
# Don't render the dummy pbar (at position 0), since that will clear a real pbar (at position 1)
self.pbar.refresh()
# Create a newline that will not be erased by leave=False. This allows for the finished pbar to be cached in the terminal
# This emulates `leave=True` without progress bar jumping
if not self.file.closed:
print('', file=self.file, flush=True)
self.pbar.close()
def state_dict(self) -> dict[str, Any]:
pbar_state = self.pbar.format_dict
return {
'total': pbar_state['total'],
'position': self.position,
'bar_format': pbar_state['bar_format'],
'metrics': self.metrics,
'keys_to_log': self.keys_to_log,
'n': pbar_state['n'],
'timestamp_key': self.timestamp_key,
}
[docs]class ProgressBarLogger(LoggerDestination):
"""Log metrics to the console and optionally show a progress bar.
.. note::
This logger is automatically instantiated by the trainer via the ``progress_bar``,
and ``console_stream`` options. This logger does not need to be created manually.
`TQDM <https://github.com/tqdm/tqdm>`_ is used to display progress bars.
During training, the progress bar logs the batch and training loss.
During validation, the progress bar logs the batch and validation accuracy.
Example progress bar output::
Epoch 1: 100%|โโโโโโโโโโ| 64/64 [00:01<00:00, 53.17it/s, loss/train=2.3023]
Epoch 1 (val): 100%|โโโโโโโโโโ| 20/20 [00:00<00:00, 100.96it/s, accuracy/val=0.0995]
Args:
stream (str | TextIO, optional): The console stream to use. If a string, it can either be ``'stdout'`` or
``'stderr'``. (default: :attr:`sys.stderr`)
log_traces (bool): Whether to log traces or not. (default: ``False``)
"""
def __init__(
self,
stream: Union[str, TextIO] = sys.stderr,
log_traces: bool = False,
) -> None:
# The dummy pbar is to fix issues when streaming progress bars over k8s, where the progress bar in position 0
# doesn't update until it is finished.
# Need to have a dummy progress bar in position 0, so the "real" progress bars in position 1 doesn't jump around
self.dummy_pbar: Optional[_ProgressBar] = None
self.train_pbar: Optional[_ProgressBar] = None
self.eval_pbar: Optional[_ProgressBar] = None
# set the stream
if isinstance(stream, str):
if stream.lower() == 'stdout':
stream = sys.stdout
elif stream.lower() == 'stderr':
stream = sys.stderr
else:
raise ValueError(f'stream must be one of ("stdout", "stderr", TextIO-like), got {stream}')
self.should_log_traces = log_traces
self.stream = stream
self.state: Optional[State] = None
self.hparams: dict[str, Any] = {}
self.hparams_already_logged_to_console: bool = False
@property
def show_pbar(self) -> bool:
return dist.get_local_rank() == 0
def log_hyperparameters(self, hyperparameters: dict[str, Any]):
# Lazy logging of hyperparameters.
self.hparams.update(hyperparameters)
def _log_hparams_to_console(self):
if dist.get_local_rank() == 0:
self._log_to_console('*' * 30)
self._log_to_console('Config:')
self._log_to_console(yaml.dump(self.hparams))
self._log_to_console('*' * 30)
self.hparams_already_logged_to_console = True
def log_traces(self, traces: dict[str, Any]):
if self.should_log_traces:
for trace_name, trace in traces.items():
trace_str = format_log_data_value(trace)
self._log_to_console(f'[trace]: {trace_name}:' + trace_str + '\n')
def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None:
for metric_name, metric_value in metrics.items():
# Only log metrics and losses to pbar.
if 'metric' in metric_name or 'loss' in metric_name:
self.log_to_pbar(data={metric_name: metric_value})
def log_to_pbar(self, data: dict[str, Any]):
# log to progress bar
current_pbar = self.eval_pbar if self.eval_pbar is not None else self.train_pbar
if current_pbar:
# Logging outside an epoch
current_pbar.log_data(data)
def _log_to_console(self, log_str: str):
"""Logs to the console, avoiding interleaving with a progress bar."""
current_pbar = self.eval_pbar if self.eval_pbar is not None else self.train_pbar
if current_pbar:
# use tqdm.write to avoid interleaving
current_pbar.pbar.write(log_str)
else:
# write directly to self.stream; no active progress bar
print(log_str, file=self.stream, flush=True)
def _build_pbar(self, state: State, is_train: bool) -> _ProgressBar:
"""Builds a pbar.
* If ``state.max_duration.unit`` is :attr:`.TimeUnit.EPOCH`, then a new progress bar will be created for each epoch.
Mid-epoch evaluation progress bars will be labeled with the batch and epoch number.
* Otherwise, one progress bar will be used for all of training. Evaluation progress bars will be labeled
with the time (in units of ``max_duration.unit``) at which evaluation runs.
"""
# Always using position=1 to avoid jumping progress bars
# In jupyter notebooks, no need for the dummy pbar, so use the default position
position = None if is_notebook() else 1
desc = f'{state.dataloader_label:15}'
max_duration_unit = None if state.max_duration is None else state.max_duration.unit
if max_duration_unit == TimeUnit.EPOCH or max_duration_unit is None:
total = int(state.dataloader_len) if state.dataloader_len is not None else None
timestamp_key = 'batch_in_epoch'
unit = TimeUnit.BATCH
n = state.timestamp.epoch.value
if self.train_pbar is None and not is_train:
# epochwise eval results refer to model from previous epoch (n-1)
n = n - 1 if n > 0 else 0
if self.train_pbar is None:
desc += f'Epoch {n:3}'
else:
# For evaluation mid-epoch, show the total batch count
desc += f'Batch {int(state.timestamp.batch):3}'
desc += ': '
else:
if is_train:
assert state.max_duration is not None, 'max_duration should be set if training'
unit = max_duration_unit
total = state.max_duration.value
# pad for the expected length of an eval pbar -- which is 14 characters (see the else logic below)
desc = desc.ljust(len(desc) + 14)
else:
unit = TimeUnit.BATCH
total = int(state.dataloader_len) if state.dataloader_len is not None else None
value = int(state.timestamp.get(max_duration_unit))
# Longest unit name is sample (6 characters)
desc += f'{max_duration_unit.name.capitalize():6} {value:5}: '
timestamp_key = unit.name.lower()
return _ProgressBar(
file=self.stream,
total=total,
position=position,
keys_to_log=_IS_TRAIN_TO_KEYS_TO_LOG[is_train],
# In a notebook, the `bar_format` should not include the {bar}, as otherwise
# it would appear twice.
bar_format=desc + ' {l_bar}' + ('' if is_notebook() else '{bar:25}') + '{r_bar}{bar:-1b}',
unit=unit.value.lower(),
metrics={},
timestamp_key=timestamp_key,
)
def init(self, state: State, logger: Logger) -> None:
del logger # unused
if not is_notebook():
# Notebooks don't need the dummy progress bar; otherwise, it would be visible.
self.dummy_pbar = _ProgressBar(
file=self.stream,
position=0,
total=1,
metrics={},
keys_to_log=[],
bar_format='{bar:-1b}',
timestamp_key='',
)
self.state = state
def fit_start(self, state: State, logger: Logger) -> None:
if not self.hparams_already_logged_to_console:
self._log_hparams_to_console()
def predict_start(self, state: State, logger: Logger) -> None:
if not self.hparams_already_logged_to_console:
self._log_hparams_to_console()
def epoch_start(self, state: State, logger: Logger) -> None:
if self.show_pbar and not self.train_pbar:
self.train_pbar = self._build_pbar(state=state, is_train=True)
def eval_start(self, state: State, logger: Logger) -> None:
if not self.hparams_already_logged_to_console:
self._log_hparams_to_console()
if self.show_pbar:
self.eval_pbar = self._build_pbar(state, is_train=False)
def batch_end(self, state: State, logger: Logger) -> None:
if self.train_pbar:
self.train_pbar.update_to_timestamp(state.timestamp)
def eval_batch_end(self, state: State, logger: Logger) -> None:
if self.eval_pbar:
self.eval_pbar.update_to_timestamp(state.eval_timestamp)
def epoch_end(self, state: State, logger: Logger) -> None:
# Only close progress bars at epoch end if the duration is in epochs, since
# a new pbar will be created for each epoch
# If the duration is in other units, then one progress bar will be used for all of training.
assert state.max_duration is not None, 'max_duration should be set'
if self.train_pbar and state.max_duration.unit == TimeUnit.EPOCH:
self.train_pbar.close()
self.train_pbar = None
def close(self, state: State, logger: Logger) -> None:
del state, logger # unused
# Close any open progress bars
if self.eval_pbar:
self.eval_pbar.close()
self.eval_pbar = None
if self.train_pbar:
self.train_pbar.close()
self.train_pbar = None
if self.dummy_pbar:
self.dummy_pbar.close()
self.dummy_pbar = None
def eval_end(self, state: State, logger: Logger) -> None:
if self.eval_pbar:
self.eval_pbar.close()
self.eval_pbar = None
def state_dict(self) -> dict[str, Any]:
return {
'train_pbar': self.train_pbar.state_dict() if self.train_pbar else None,
'eval_pbar': self.eval_pbar.state_dict() if self.eval_pbar else None,
}
def load_state_dict(self, state: dict[str, Any]) -> None:
if state['train_pbar']:
n = state['train_pbar'].pop('n')
train_pbar = self._ensure_backwards_compatibility(state['train_pbar'])
self.train_pbar = _ProgressBar(file=self.stream, **train_pbar)
self.train_pbar.update(n=n)
if state['eval_pbar']:
n = state['train_pbar'].pop('n')
eval_pbar = self._ensure_backwards_compatibility(state['eval_pbar'])
self.eval_pbar = _ProgressBar(file=self.stream, **eval_pbar)
self.eval_pbar.update(n=n)
def _ensure_backwards_compatibility(self, state: dict[str, Any]) -> dict[str, Any]:
# ensure backwards compatible with mosaicml<=v0.8.0 checkpoints
state.pop('epoch_style', None)
# old checkpoints do not have timestamp_key
if 'timestamp_key' not in state:
if 'unit' not in state:
raise ValueError('Either unit or timestamp_key must be in pbar state of checkpoint.')
unit = state['unit']
assert isinstance(unit, TimeUnit)
state['timestamp_key'] = unit.name.lower()
# new format expects unit as str, not TimeUnit
if 'unit' in state and isinstance(state['unit'], TimeUnit):
state['unit'] = state['unit'].value.lower()
return state