# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Composer Profiler."""
from __future__ import annotations
import logging
import pathlib
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union
from composer.core import Callback
from composer.loggers import Logger
from composer.profiler.json_trace_handler import JSONTraceHandler
from composer.profiler.marker import Marker
from composer.profiler.profiler_action import ProfilerAction
from composer.profiler.system_profiler import SystemProfiler
from composer.profiler.torch_profiler import TorchProfiler
from composer.profiler.trace_handler import TraceHandler
from composer.utils import ensure_tuple, parse_uri
if TYPE_CHECKING:
from composer.core import State
__all__ = ['Profiler']
log = logging.getLogger(__name__)
class Profiler(Callback):
"""Composer Profiler.
See the :doc:`Profiling Guide </trainer/performance_tutorials/profiling>` for additional information.
Args:
schedule ((State) -> ProfilerAction): The profiling scheduling function.
It takes the training state and returns a :class:`.ProfilerAction`.
For convenience, Composer includes a :meth:`~composer.profiler.cyclic_schedule.cyclic_schedule` helper.
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
original_profiler_init = Profiler.__init__
def new_profiler_init(self, dummy_ellipsis=None, **kwargs):
if 'trace_handlers' not in kwargs:
kwargs['trace_handlers'] = []
kwargs['torch_prof_memory_filename'] = None
original_profiler_init(self, **kwargs)
Profiler.__init__ = new_profiler_init
.. testcode::
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(
...,
schedule=cyclic_schedule(
skip_first=1,
wait=0,
warmup=1,
active=4,
repeat=1,
),
torch_prof_memory_filename=None,
)
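Roughly, the schedule above skips the first batch, then cycles through waiting zero batches, warming
up for one batch, and actively recording four batches, repeating the cycle once per epoch (mirroring
the semantics of ``torch.profiler.schedule``).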
trace_handlers (TraceHandler | Sequence[TraceHandler]): Trace handlers which record and
save profiling data to traces. Additionally supports full object store paths (see the example
after this argument list).
sys_prof_cpu (bool, optional): Whether to record CPU statistics. (default: ``True``).
sys_prof_memory (bool, optional): Whether to record memory statistics. (default: ``False``).
sys_prof_disk (bool, optional): Whether to record disk statistics. (default: ``False``).
sys_prof_net (bool, optional): Whether to record network statistics. (default: ``False``).
sys_prof_stats_thread_interval_seconds (float, optional): Interval to record stats, in seconds.
(default: ``0.5``).
torch_prof_folder (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
Additionally supports full object store paths, e.g. ``s3://bucket/path/to/file``.
torch_prof_memory_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_memory_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
Additionally supports full object store paths, e.g. ``s3://bucket/path/to/file``.
torch_prof_overwrite (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_use_gzip (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_record_shapes (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_profile_memory (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_with_stack (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_with_flops (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_num_traces_to_keep (int, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
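For example, trace handlers and remote paths can be configured together (a minimal sketch; the
handler and folder arguments are illustrative):

.. testcode::

    from composer.profiler import JSONTraceHandler, Profiler, cyclic_schedule

    profiler = Profiler(
        ...,
        schedule=cyclic_schedule(),
        trace_handlers=[JSONTraceHandler(folder='{run_name}/traces')],
        torch_prof_memory_filename=None,
    )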
"""
def __init__(
self,
schedule: Callable[[State], ProfilerAction],
trace_handlers: Union[TraceHandler, Sequence[TraceHandler]],
sys_prof_cpu: bool = True,
sys_prof_memory: bool = False,
sys_prof_disk: bool = False,
sys_prof_net: bool = False,
sys_prof_stats_thread_interval_seconds: float = 0.5,
torch_prof_folder: str = '{run_name}/torch_traces',
torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json',
torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
torch_prof_memory_filename: Optional[str] = None,
torch_prof_memory_remote_file_name: Optional[str] = (
'{run_name}/torch_memory_traces/'
'rank{rank}.{batch}.pt.memory_trace.html'
),
torch_prof_overwrite: bool = False,
torch_prof_use_gzip: bool = False,
torch_prof_record_shapes: bool = False,
torch_prof_profile_memory: bool = True,
torch_prof_with_stack: bool = False,
torch_prof_with_flops: bool = True,
torch_prof_num_traces_to_keep: int = -1,
) -> None:
self._names_to_markers: dict[str, Marker] = {}
self._trace_handlers = list(ensure_tuple(trace_handlers))
self.schedule = schedule
self.state: Optional[State] = None
self._callbacks: list[Callback] = []
# Used to count skip_first starting from resumption timestamp
self.resumption_batch_idx: int = 0
self.remote_filenames: list[str] = []
# First, add each remote file name (e.g. s3://bucket/path/to/file) to self.remote_filenames so that the
# trainer creates a RemoteUploaderDownloader logger for it. Then strip each remote file name down to a
# local path (e.g. path/to/file) to pass into the torch profiler and system profiler.
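# For reference, parse_uri splits a URI into (backend, bucket, path); e.g. parse_uri('s3://bucket/path/to/file')
# returns ('s3', 'bucket', 'path/to/file'), while a plain local path passes through with empty backend and bucket.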
if torch_prof_remote_file_name:
self.remote_filenames.append(torch_prof_remote_file_name)
_, _, torch_prof_remote_file_name = parse_uri(torch_prof_remote_file_name)
if torch_prof_memory_remote_file_name:
self.remote_filenames.append(torch_prof_memory_remote_file_name)
_, _, torch_prof_memory_remote_file_name = parse_uri(torch_prof_memory_remote_file_name)
for handler in self._trace_handlers:
if isinstance(handler, JSONTraceHandler):
if handler.remote_file_name:
self.remote_filenames.append(handler.remote_file_name)
_, _, handler.remote_file_name = parse_uri(handler.remote_file_name)
if handler.merged_trace_remote_file_name:
self.remote_filenames.append(handler.merged_trace_remote_file_name)
_, _, handler.merged_trace_remote_file_name = parse_uri(handler.merged_trace_remote_file_name)
if sys_prof_cpu or sys_prof_memory or sys_prof_disk or sys_prof_net:
self._callbacks.append(
SystemProfiler(
profile_cpu=sys_prof_cpu,
profile_memory=sys_prof_memory,
profile_disk=sys_prof_disk,
profile_net=sys_prof_net,
stats_thread_interval_seconds=sys_prof_stats_thread_interval_seconds,
),
)
if torch_prof_memory_filename is not None:
if not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory):
raise ValueError(
f'torch_prof_memory_filename is set. Generating the memory timeline graph requires all three flags torch_prof_with_stack, torch_prof_record_shapes, and torch_prof_profile_memory to be true. Got torch_prof_with_stack={torch_prof_with_stack}, torch_prof_record_shapes={torch_prof_record_shapes}, torch_prof_profile_memory={torch_prof_profile_memory}',
)
log.info(
f'Memory profiling is enabled and uses {torch_prof_memory_filename} as the filename to generate the memory timeline graph. To disable the memory timeline graph generation, explicitly set torch_prof_memory_filename to None.',
)
else:
log.info('torch_prof_memory_filename is set to None. Memory timeline will not be generated.')
if torch_prof_record_shapes or torch_prof_profile_memory or torch_prof_with_stack or torch_prof_with_flops:
self._callbacks.append(
TorchProfiler(
filename=torch_prof_filename,
folder=torch_prof_folder,
remote_file_name=torch_prof_remote_file_name,
memory_filename=torch_prof_memory_filename,
memory_remote_file_name=torch_prof_memory_remote_file_name,
num_traces_to_keep=torch_prof_num_traces_to_keep,
overwrite=torch_prof_overwrite,
record_shapes=torch_prof_record_shapes,
profile_memory=torch_prof_profile_memory,
use_gzip=torch_prof_use_gzip,
with_stack=torch_prof_with_stack,
with_flops=torch_prof_with_flops,
),
)
def bind_to_state(
self,
state: State,
):
"""Bind the profiler to the ``state``.
.. note::
The :class:`.Trainer` automatically invokes this method.
Args:
state (State): The training state.
"""
self.state = state
self.state.callbacks.append(self)
self.state.callbacks.extend(self._callbacks)
self.state.callbacks.extend(self._trace_handlers)
@property
def trace_handlers(self):
"""Profiler trace handlers."""
return self._trace_handlers
@trace_handlers.setter
def trace_handlers(self, trace_handlers: Union[TraceHandler, Sequence[TraceHandler]]):
"""Profiler trace handlers."""
self._trace_handlers[:] = ensure_tuple(trace_handlers)
def record_chrome_json_trace_file(self, filepath: Union[str, pathlib.Path]):
"""Record trace events in Chrome JSON format in the trace handlers.
See `this document <https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview>`_
for more information about Chrome JSON format.
.. note::
For custom profiling, it is recommended to use :meth:`marker` instead of manually creating a Chrome JSON
trace file. By default, the Composer Profiler automatically saves :class:`.Marker` events in Chrome
JSON format.
This method exists for external profilers that natively record events in Chrome JSON format (such as the
:class:`~composer.profiler.torch_profiler.TorchProfiler`). These profilers can use this method to route
their profiling traces to the Composer profiler :attr:`~trace_handlers` so events from both the Composer
Profiler and external profilers are recorded in the same trace file.
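For illustration, an external profiler that has already written a Chrome JSON trace to disk could
route it to the trace handlers as follows (a sketch; ``external_trace.json`` is a placeholder path):

.. code-block:: python

    profiler.record_chrome_json_trace_file('external_trace.json')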
"""
for recorder in self.trace_handlers:
recorder.process_chrome_json_trace_file(pathlib.Path(filepath))
def marker(
self,
name: str,
actions: Sequence[ProfilerAction] = (
ProfilerAction.WARMUP,
ProfilerAction.ACTIVE,
ProfilerAction.ACTIVE_AND_SAVE,
),
record_instant_on_start: bool = False,
record_instant_on_finish: bool = False,
categories: Union[list[str], tuple[str, ...]] = (),
) -> Marker:
"""Create and get an instance of a :class:`.Marker`.
If a :class:`.Marker` with the specified ``name`` does not already exist, it will be created.
Otherwise, the existing instance will be returned.
.. note::
:meth:`.Profiler.marker()` should be used to construct markers. :class:`.Marker` **should not** be
instantiated directly by the user.
For example:
.. testsetup:: composer.profiler.profiler.Profiler.marker
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
state.profiler = profiler
.. doctest:: composer.profiler.profiler.Profiler.marker
>>> marker = profiler.marker("foo")
>>> marker
<composer.profiler.marker.Marker object at ...>
Please see :meth:`.Marker.start()` and :meth:`.Marker.finish()` for measuring duration events,
:meth:`.Marker.instant()` for marking instant events, and :meth:`.Marker.counter()` for counting
(a short sketch of a duration event follows).
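For instance, a duration event could be recorded as follows (a minimal sketch; ``profiler`` is a
bound profiler as in the setup above, and the work being timed is elided):

.. code-block:: python

    marker = profiler.marker("my_event")
    marker.start()
    # ... work to be measured ...
    marker.finish()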
Args:
name (str): The name for the :class:`.Marker`.
actions (Sequence[ProfilerAction], optional): :class:`.ProfilerAction` states to record on.
Defaults to (:attr:`~.ProfilerAction.WARMUP`, :attr:`~.ProfilerAction.ACTIVE`,
:attr:`~.ProfilerAction.ACTIVE_AND_SAVE`).
record_instant_on_start (bool, optional): Whether to record an instant event whenever the marker is started.
Defaults to ``False``.
record_instant_on_finish (bool, optional): Whether to record an instant event whenever the marker is finished.
Defaults to ``False``.
categories (Union[list[str], tuple[str, ...]], optional): Categories for this marker. Defaults to ``()``.
Returns:
Marker: Marker instance.
"""
if self.state is None:
raise RuntimeError('Profiler.bind_to_state() must be invoked before the Profiler can be used.')
if name not in self._names_to_markers:
def should_record(state: State) -> bool:
return self.schedule(state) in actions
self._names_to_markers[name] = Marker(
state=self.state,
trace_handlers=self.trace_handlers,
name=name,
should_record=should_record,
record_instant_on_start=record_instant_on_start,
record_instant_on_finish=record_instant_on_finish,
categories=categories,
)
self._names_to_markers[name].categories = categories
return self._names_to_markers[name]
def after_load(self, state: State, logger: Logger) -> None:
del logger
self.resumption_batch_idx = int(state.timestamp.batch_in_epoch)