# Copyright 2021 MosaicML. All Rights Reserved.
"""Profiler to collect :mod:`torch` performance metrics during training."""
from __future__ import annotations

import json
import os
import pathlib
import textwrap
from typing import Optional, OrderedDict

import torch.profiler
from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

from composer.core.callback import Callback
from composer.core.state import State
from composer.loggers import Logger
from composer.loggers.logger import LogLevel
from composer.profiler.profiler_action import ProfilerAction
from composer.utils import dist
from composer.utils.file_helpers import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE,
                                         ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time)
__all__ = ["TorchProfiler"]
class TorchProfiler(Callback):
    # ``__doc__`` is assigned explicitly because an f-string cannot serve as a regular docstring; the f-string is
    # needed to splice the shared format-variable tables into the ``Args`` section.
    # NOTE(review): the ``textwrap.indent`` prefix below was reconstructed as 12 spaces so that the spliced tables
    # line up with the surrounding argument descriptions -- confirm against the rendered documentation.
    __doc__ = f"""Profile the execution using the :class:`PyTorch Profiler <torch.profiler.profile>`.

    Profiling results are stored in TensorBoard format in the directory specified by ``folder``.

    .. note::

        The Composer :class:`~composer.trainer.trainer.Trainer` automatically creates an instance of this
        :class:`.TorchProfiler` callback whenever any of the PyTorch Profiler arguments
        (``torch_prof_record_shapes``, ``torch_prof_profile_memory``, ``torch_prof_with_stack``, or
        ``torch_prof_with_flops``) are enabled.

        When using the Composer :class:`~composer.trainer.trainer.Trainer`, one does not need to directly create an
        instance of this :class:`.TorchProfiler` callback.

    To view profiling results, run::

        pip install tensorboard torch_tb_profiler
        tensorboard --logdir path/to/torch/trace_folder

    .. note::

        See :doc:`profiler` for additional usage details on the :class:`torch.profiler.profile`.

    .. note::

        Enabling shape and stack tracing results in additional overhead.
        When ``record_shapes=True`` is specified, the profiler will temporarily hold references to tensors which
        may prevent certain optimizations that depend on the reference count and can introduce extra tensor copies.

    Args:
        folder (str, optional): Format string for the folder containing the Torch Profiler trace files.
            Defaults to ``'{{run_name}}/torch_traces'``.

            The following format variables are available:

            {textwrap.indent(FORMAT_NAME_WITH_DIST_TABLE, prefix='            ')}

            For example, if the ``run_name`` is ``'awesome_training_run'``, and the default ``folder`` of
            ``'{{run_name}}/torch_traces'`` is used, Torch Profiler traces will be stored in
            ``'awesome_training_run/torch_traces'``.

        filename (str, optional): A format string describing how to name Torch Profiler trace files.
            Defaults to ``'rank{{rank}}.{{batch}}.pt.trace.json'``.

            At the end of each batch where the Composer profiler schedule returns
            :attr:`~composer.profiler.profiler_action.ProfilerAction.ACTIVE_AND_SAVE`, trace files are saved
            approximately to ``{{folder.format(...)}}/{{filename.format(...)}}``.

            The following format variables are available:

            {textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix='            ')}

            Consider the following scenario, where:

            *   The :attr:`~.Logger.run_name` is ``'awesome-training-run'``.
            *   The default ``folder='{{run_name}}/torch_traces'`` is used.
            *   The default ``filename='rank{{rank}}.{{batch}}.pt.trace.json'`` is used.
            *   The current epoch count is ``1``.
            *   The current batch count is ``42``.

            Each rank (process) will save traces to::

                awesome-training-run/torch_traces/rank0.42.pt.trace.json
                awesome-training-run/torch_traces/rank1.42.pt.trace.json
                awesome-training-run/torch_traces/rank2.42.pt.trace.json
                ...

        artifact_name (str, optional): Format string for a Torch Profiler trace file's artifact name.
            Defaults to ``'{{run_name}}/torch_traces/rank{{rank}}.{{batch}}.pt.trace.json'``.

            Whenever a trace file is saved, it is also logged as a file artifact according to this format string.
            The same format variables as for ``filename`` are available.

            .. seealso:: :meth:`~composer.loggers.logger.Logger.file_artifact` for file artifact logging.

            Leading slashes (``'/'``) will be stripped.

            To disable logging trace files as file artifacts, set this parameter to ``None``.
        overwrite (bool, optional): Whether to overwrite existing Torch Profiler traces. Defaults to False.
            If False, then the trace folder as determined by ``folder`` must be empty.
        use_gzip (bool, optional): Whether to use gzip for the trace. Defaults to False.
            If True, ``'.gz'`` will be appended to ``filename`` and ``artifact_name``
            (if they do not already end in ``'.gz'``).
        record_shapes (bool, optional): Whether to record tensor shapes. Defaults to False.
        profile_memory (bool, optional): Whether to profile memory. Defaults to True.
        with_stack (bool, optional): Whether to record stack info. Defaults to False.
        with_flops (bool, optional): Whether to estimate flops for operators. Defaults to True.
        num_traces_to_keep (int, optional): The number of trace files to keep locally. Defaults to -1.
            If set to -1, then all trace files are kept locally.

            After a trace has been saved and logged as a file artifact, the oldest traces are removed until
            ``num_traces_to_keep`` traces remain. This parameter only controls how many traces are kept locally;
            traces are not deleted from artifact stores.

            It can be useful to set this parameter to ``0`` when using an artifact logger such as the
            :class:`~composer.loggers.object_store_logger.ObjectStoreLogger`. This combination will minimize local
            disk usage by deleting trace files immediately after they have been uploaded to the object store.

    Attributes:
        saved_traces (OrderedDict[Timestamp, List[pathlib.Path]]): The saved traces, keyed by save timestamp.

            Each value is the list of trace filepaths written at that timestamp.
            When ``num_traces_to_keep`` is non-negative, this mapping will have at most ``num_traces_to_keep``
            entries. The latest trace will be at the end.

            The index of a filepath in each list corresponds to the global rank of the process that wrote that file.
            Each filepath is valid only on the process's (rank's) node.
    """

    def __init__(
        self,
        folder: str = '{run_name}/torch_traces',
        filename: str = 'rank{rank}.{batch}.pt.trace.json',
        artifact_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
        *,
        overwrite: bool = False,
        use_gzip: bool = False,
        record_shapes: bool = False,
        profile_memory: bool = True,
        with_stack: bool = False,
        with_flops: bool = True,
        num_traces_to_keep: int = -1,
    ) -> None:
        """Store the profiler configuration; the torch profiler itself is created later, in :meth:`init`."""
        self.overwrite = overwrite
        self.folder = folder

        # When gzip is requested, make sure both the local filename and the artifact name carry the '.gz' suffix.
        if use_gzip and not filename.endswith('.gz'):
            filename += ".gz"
        self.filename = filename
        if use_gzip and artifact_name is not None and not artifact_name.endswith('.gz'):
            artifact_name += ".gz"
        self.artifact_name = artifact_name

        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_flops = with_flops
        self.num_traces_to_keep = num_traces_to_keep

        # Maps save timestamp -> trace filepaths (indexed by global rank); insertion order is oldest first.
        self.saved_traces = OrderedDict()
        self.profiler: Optional[torch.profiler.profile] = None

    def init(self, state: State, logger: Logger) -> None:
        """Prepare the trace folder, then create and start the torch profiler.

        Args:
            state (State): The training state; ``state.profiler`` must be set.
            logger (Logger): The logger, used for the run name and for logging trace files as artifacts.

        Raises:
            RuntimeError: If the Composer Profiler is not enabled (``state.profiler`` is ``None``).
        """
        if state.profiler is None:
            raise RuntimeError(("The Composer Profiler was not enabled, which is required to use the "
                                f"{type(self).__name__}. To enable, set the `prof_schedule` argument of the Trainer."))

        folder_name = format_name_with_dist(self.folder, logger.run_name)
        os.makedirs(folder_name, exist_ok=True)
        if not self.overwrite:
            ensure_folder_is_empty(folder_name)

        # Synchronize ranks after the folder has been created and checked, before any rank starts profiling.
        dist.barrier()

        def scheduler_fn(torch_profiler_step: int) -> TorchProfilerAction:
            """Translate the Composer profiler schedule into the equivalent torch profiler action."""
            del torch_profiler_step  # the torch profiler step is unused. Using the composer timer instead.
            assert state.profiler is not None
            composer_profiler_action = state.profiler.schedule(state)
            if composer_profiler_action == ProfilerAction.ACTIVE_AND_SAVE:
                return TorchProfilerAction.RECORD_AND_SAVE
            if composer_profiler_action == ProfilerAction.ACTIVE:
                return TorchProfilerAction.RECORD
            if composer_profiler_action == ProfilerAction.WARMUP:
                return TorchProfilerAction.WARMUP
            assert composer_profiler_action == ProfilerAction.SKIP, f"unexpected action: {composer_profiler_action}"
            return TorchProfilerAction.NONE

        def handler_fn(prof: torch.profiler.profiler.profile):
            """Export the finished trace, log it as an artifact, and prune old local traces."""
            assert state.profiler is not None

            timestamp = state.timer.get_timestamp()

            trace_file_name = os.path.join(
                folder_name,
                format_name_with_dist_and_time(self.filename, run_name=logger.run_name, timestamp=timestamp),
            )
            # ``filename`` may itself contain sub-directories, so create them before exporting.
            trace_file_dirname = os.path.dirname(trace_file_name)
            if trace_file_dirname:
                os.makedirs(trace_file_dirname, exist_ok=True)
            prof.export_chrome_trace(trace_file_name)
            state.profiler.record_chrome_json_trace_file(trace_file_name)

            if self.artifact_name is not None:
                trace_artifact_name = format_name_with_dist_and_time(self.artifact_name,
                                                                     run_name=logger.run_name,
                                                                     timestamp=timestamp)
                trace_artifact_name = trace_artifact_name.lstrip('/')
                logger.file_artifact(LogLevel.BATCH,
                                     artifact_name=trace_artifact_name,
                                     file_path=trace_file_name,
                                     overwrite=self.overwrite)

            # Record the trace just written. Without this, ``saved_traces`` stays empty and the pruning loop
            # below never removes anything, so ``num_traces_to_keep`` would have no effect.
            # The pruning loop indexes each entry by global rank, so the entry holds every rank's filepath.
            # NOTE(review): this is a collective call -- it presumes every rank reaches this handler on the same
            # batch (all ranks follow the same profiler schedule); confirm ``dist.all_gather_object`` semantics.
            self.saved_traces[timestamp] = [pathlib.Path(path) for path in dist.all_gather_object(trace_file_name)]

            if self.num_traces_to_keep >= 0:
                global_rank = dist.get_global_rank()
                while len(self.saved_traces) > self.num_traces_to_keep:
                    # self.saved_traces is an ordered dict, so the zeroth item will be the oldest trace
                    oldest_timestamp, filepaths = next(iter(self.saved_traces.items()))
                    if global_rank < len(filepaths):
                        # Remove this rank's trace file
                        os.remove(filepaths[global_rank])
                    del self.saved_traces[oldest_timestamp]

        self.profiler = torch.profiler.profile(
            schedule=scheduler_fn,
            on_trace_ready=handler_fn,
            record_shapes=self.record_shapes,
            profile_memory=self.profile_memory,
            with_stack=self.with_stack,
            with_flops=self.with_flops,
        )
        # The profiler must stay active across callback events, so it is entered manually here and exited
        # in :meth:`close` instead of being wrapped in a ``with`` block.
        self.profiler.__enter__()

    def batch_end(self, state: State, logger: Logger) -> None:
        """Tag the trace with this process's global rank and advance the torch profiler by one step."""
        del state, logger  # unused
        assert self.profiler is not None
        self.profiler.add_metadata_json("global_rank", json.dumps(dist.get_global_rank()))
        self.profiler.step()

    def batch_start(self, state: State, logger: Logger) -> None:
        """Log the name of the torch profiler's current action under ``profiler/state``."""
        del state  # unused
        assert self.profiler is not None
        logger.data_batch({"profiler/state": self.profiler.current_action.name})

    def close(self, state: State, logger: Logger) -> None:
        """Exit the torch profiler if it was started; calling this again afterwards is a no-op."""
        del state, logger  # unused
        if self.profiler is not None:
            self.profiler.__exit__(None, None, None)
            # Drop the reference so a repeated ``close`` does not exit the same profiler twice.
            self.profiler = None