# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Outputs profiling data in JSON trace format."""
from __future__ import annotations
import gzip
import json
import os
import pathlib
import queue
import tempfile
import textwrap
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from composer.loggers import Logger
from composer.profiler.json_trace_merger import merge_traces
from composer.profiler.profiler_action import ProfilerAction
from composer.profiler.trace_handler import TraceHandler
from composer.utils import (
FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
FORMAT_NAME_WITH_DIST_TABLE,
dist,
ensure_folder_is_empty,
format_name_with_dist,
format_name_with_dist_and_time,
)
if TYPE_CHECKING:
from composer.core import State, Timestamp
__all__ = ['JSONTraceHandler']
[docs]class JSONTraceHandler(TraceHandler): # noqa: D101
__doc__ = f"""Records trace events in Chrome JSON trace format.
See `this document <https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview>`_
for more information.
Traces are output to ``folder``. Traces can be visualized using the Chrome Trace Viewer.
To view in a Google Chrome browser, navigate to ``chrome://tracing`` and load the JSON trace file.
Args:
folder (str, optional): Format string for the trace file folder. Defaults to ``'{{run_name}}/traces'``.
The following format variables are available:
{textwrap.indent(FORMAT_NAME_WITH_DIST_TABLE, prefix=' ')}
For example, if the ``run_name`` is ``'awesome_training_run'``, and the default ``folder`` of
``'{{run_name}}/traces'`` is used, traces will be stored in ``'awesome_training_run/traces'``.
filename (str, optional): A format string describing how to name trace files.
(default: ``'ep{{epoch}}-ba{{batch}}-rank{{rank}}.json'``)
At the end of each batch where :meth:`~composer.profiler.Profiler.get_action` returns
:attr:`~composer.profiler._profiler_action.ProfilerAction.ACTIVE_AND_SAVE`, trace files are saved
approximately to ``{{folder}}/{{filename.format(...)}}``.
The following format variables are available:
{textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix=' ')}
Consider the following scenario, where:
* The :attr:`~.State.run_name` is ``'awesome-training-run'``
* The default ``folder='{{run_name}}/traces'`` is used.
* The default ``filename='ep{{epoch}}-ba{{batch}}-rank{{rank}}.json'`` is used.
* The current epoch count is ``1``.
* The current batch count is ``42``.
Each rank (process) will save traces to::
awesome-training-run/traces/ep1-ba42-rank0.json
awesome-training-run/traces/ep1-ba42-rank1.json
awesome-training-run/traces/ep1-ba42-rank2.json
...
remote_file_name (str, optional): Format string for the trace file's remote name.
(default: ``'{{run_name}}/traces/ep{{epoch}}-ba{{batch}}-rank{{rank}}.json'``)
Whenever a trace file is saved, it is also uploaded as a remote file according to this format string.
The same format variables as for ``filename`` are available.
.. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes for file uploading.
Leading slashes (``'/'``) will be stripped.
To disable uploading trace files, set this parameter to ``None``.
merged_trace_filename (str, optional): Format string for the merged trace filename.
(default: ``'merged_trace.json'``)
Each rank writes a separate trace file at the end of each profiling cycle. However, when visualizing
traces, it is generally helpful to merge traces together into a single file. This allows the traces
across all ranks to be shown in a single view.
The same format variables as for ``folder`` are available. The merged trace file is saved
approximately to ``{{folder}}/{{merged_trace_filename.format(...)}}`` on the local rank zero
process for each node.
If specified (the default), the local rank zero process merges together all trace files from that node,
across all profiling cycles, into a single trace file. The merged trace file is written to the filename
specified by this format string. There will be one merged trace file per node.
To disable merging, set this parameter to ``None``.
.. warning::
Trace merging blocks the training loop. When profiling live training runs, it is recommended to
disable trace merging by setting this parameter to ``None``. Instead, traces should be merged together
in a post-processing step. See :mod:`composer.profiler.json_trace_merger` for additional info.
merged_trace_remote_file_name (str, optional): Format string for the merged trace file's remote file name.
(default: ``'{{run_name}}/traces/merged_trace.json'``)
The same format variables as for ``folder`` are available.
This parameter has no effect if ``merged_trace_filename`` is None.
To disable uploading merged trace files, set this parameter to ``None``.
overwrite (bool, optional): Whether to overwrite existing traces. (default: ``False``)
If ``False``, the trace folder (as determined by the ``folder`` argument)
must be empty when training starts.
num_traces_to_keep (int, optional): The number of traces to keep locally. The oldest traces
are removed first. Set to ``-1`` to keep all traces locally. (default: ``-1``)
Traces will be removed after they have been uploaded. For example, when this handler
is used in conjunction with the :class:`.RemoteUploaderDownloader`, set this
parameter to ``0`` to immediately delete traces from the local disk after they have been uploaded to
the object store.
This parameter only controls how many traces are kept locally; traces are not deleted from
remote file systems.
Attributes:
saved_traces (List[Tuple[Timestamp, List[pathlib.Path]]]): The trace timestamps and filepaths.
This list contains tuples of the save timestamp and the trace filepaths.
This list will have at most ``num_traces_to_keep`` entries. The latest trace
will be at the end.
The index of a filepath in each list corresponds to the global rank of the process that wrote that file.
Each filepath is valid only on the process's (rank's) node.
"""
def __init__(
    self,
    folder: str = '{run_name}/traces',
    filename: str = 'ep{epoch}-ba{batch}-rank{rank}.json',
    remote_file_name: Optional[str] = '{run_name}/traces/ep{epoch}-ba{batch}-rank{rank}.json',
    merged_trace_filename: Optional[str] = 'merged_trace.json',
    merged_trace_remote_file_name: Optional[str] = '{run_name}/traces/merged_trace.json',
    *,
    overwrite: bool = False,
    num_traces_to_keep: int = -1,
):
    """Store the naming configuration and set up empty recording state.

    The constructor only assigns attributes; no folders are created and no
    files are touched here.

    Args:
        folder: Format string for the local trace folder.
        filename: Format string for each per-rank trace file.
        remote_file_name: Format string for the uploaded per-rank trace, or
            ``None`` to skip uploading.
        merged_trace_filename: Format string for the merged trace file, or
            ``None`` to skip merging.
        merged_trace_remote_file_name: Format string for the uploaded merged
            trace, or ``None`` to skip uploading it.
        overwrite: Whether existing traces may be overwritten.
        num_traces_to_keep: How many traces to keep locally; a negative value
            disables pruning.
    """
    # Naming templates for the per-rank trace files (local and remote).
    self.folder, self.filename = folder, filename
    self.remote_file_name = remote_file_name

    # Naming templates for the per-node merged trace (local and remote).
    self.merged_trace_filename = merged_trace_filename
    self.merged_trace_remote_file_name = merged_trace_remote_file_name

    # Retention policy.
    self.overwrite = overwrite
    self.num_traces_to_keep = num_traces_to_keep

    # History of saved traces, oldest first: (save timestamp, paths indexed by global rank).
    self.saved_traces: List[Tuple[Timestamp, List[pathlib.Path]]] = []

    # Serialized JSON events waiting to be written to the next trace file.
    self._queue: queue.Queue[str] = queue.Queue()

    # Cycle bookkeeping: whether a profiling cycle is in progress, and whether
    # the current batch should end with a trace file being written.
    self._is_trace_active = False
    self._save_at_batch_end = False
def init(self, state: State, logger: Logger) -> None:
    """Prepare the local trace folder before any traces are recorded.

    Creates the trace folder, optionally verifies that it is empty, removes a
    merged trace file left over at the configured path, and then waits at a
    barrier so that every rank has finished its setup.

    Args:
        state: Supplies ``run_name`` for resolving the folder/file templates.
        logger: Unused.
    """
    del logger  # unused
    trace_folder = format_name_with_dist(self.folder, run_name=state.run_name)

    os.makedirs(trace_folder, exist_ok=True)
    if not self.overwrite:
        ensure_folder_is_empty(trace_folder)

    # Remove any existing merged trace file.
    if self.merged_trace_filename is not None:
        merged_trace_filename = os.path.join(
            trace_folder,
            format_name_with_dist(self.merged_trace_filename, state.run_name),
        )
        # Several ranks on one node can resolve to the same merged trace path.
        # An exists()-then-remove() pair lets one rank delete the file between
        # another rank's check and its removal, so attempt the removal directly
        # and treat "already gone" as success.
        # (The joined path always has a directory component because
        # ``trace_folder`` was just created, so no dirname guard is needed.)
        try:
            os.remove(merged_trace_filename)
        except FileNotFoundError:
            pass

    # Ensure all ranks have finished checking/cleaning the folder before proceeding.
    dist.barrier()
def batch_start(self, state: State, logger: Logger) -> None:
    """Open a new profiling cycle when the schedule asks for one.

    When the profiler schedule is not ``SKIP`` and no cycle is currently
    active, a set of metadata (``ph='M'``) events is queued: names and sort
    indices for the process/thread rows, the global rank, and a clock
    synchronization record. Independently of that, the handler arms saving at
    batch end whenever the schedule returns ``ACTIVE_AND_SAVE``.

    Args:
        state: Provides the profiler and its schedule.
        logger: Unused.

    Raises:
        RuntimeError: If ``state.profiler`` is ``None``.
    """
    del logger  # unused
    profiler = state.profiler
    if profiler is None:
        raise RuntimeError((
            'The Composer Profiler was not enabled, which is required to use the '
            f'{type(self).__name__}. To enable, set the `prof_schedule` argument of the Trainer.'
        ))

    if profiler.schedule(state) != ProfilerAction.SKIP and not self._is_trace_active:
        # Starting a new profiling cycle. Every metadata event of the cycle
        # shares this one timestamp.
        cycle_start_ns = time.time_ns()
        rank = dist.get_global_rank()
        process_id = os.getpid()

        def emit_metadata(event_name, event_args):
            # Events are keyed by global rank as ``pid`` and OS process id as ``tid``.
            self._record_event(
                name=event_name,
                ph='M',  # metadata
                wall_clock_ns=cycle_start_ns,
                tid=process_id,
                pid=rank,
                args=event_args,
            )

        # Order matters only in that it is the order events appear in the file.
        labelling_events = (
            ('process_name', {'name': f'Rank {rank} training loop process'}),
            ('thread_name', {'name': 'Training Loop'}),
            # the training loop thread should be listed first
            ('thread_sort_index', {'sort_index': 0}),
            ('global_rank', {'value': rank}),
            # processes are sorted by their global rank
            ('process_sort_index', {'sort_index': rank}),
        )
        for event_name, event_args in labelling_events:
            emit_metadata(event_name, event_args)

        # Synchronize the clocks: each rank samples its clock right after a
        # barrier, i.e. at approximately the same real-world time. The second
        # barrier bounds the error; the span between the outer samples is
        # recorded as the error bound.
        sync_begin_ns = time.time_ns()
        dist.barrier()
        sync_point_ns = time.time_ns()
        dist.barrier()
        sync_end_ns = time.time_ns()

        # Both values are converted from nanoseconds to microseconds.
        emit_metadata('clock_sync_timestamp_us', {'value': sync_point_ns // 1000})
        emit_metadata('clock_sync_error_bound', {'value': (sync_end_ns - sync_begin_ns) // 1000})

        self._is_trace_active = True

    if profiler.schedule(state) == ProfilerAction.ACTIVE_AND_SAVE:
        self._save_at_batch_end = True
def batch_end(self, state: State, logger: Logger) -> None:
    """Write, upload, merge and prune traces at the end of a saving batch.

    Nothing happens unless ``batch_start`` armed ``_save_at_batch_end``. When
    it did, this method:

    1. Drains the event queue into this rank's trace file (a JSON array).
    2. Uploads that file if ``remote_file_name`` is set.
    3. Gathers every rank's trace path and records them in ``saved_traces``.
    4. On local rank zero, merges this node's traces (plus any existing merged
       trace) into the merged trace file, and uploads it if
       ``merged_trace_remote_file_name`` is set.
    5. Deletes this rank's oldest local traces beyond ``num_traces_to_keep``.
    6. Resets the cycle flags.

    Args:
        state: Provides ``run_name``, ``timestamp`` and the profiler.
        logger: Used to upload the trace files.
    """
    assert state.profiler is not None
    if not self._save_at_batch_end:
        return

    timestamp = state.timestamp
    trace_folder = format_name_with_dist(self.folder, run_name=state.run_name)

    # The cycle is ending: empty the queue and save the trace file.
    trace_filename = os.path.join(
        trace_folder,
        format_name_with_dist_and_time(self.filename, state.run_name, timestamp),
    )
    trace_dirname = os.path.dirname(trace_filename)
    if trace_dirname:
        os.makedirs(trace_dirname, exist_ok=True)
    with open(trace_filename, 'w+') as f:
        f.write('[\n')
        separator = ''  # no separator before the first entry
        while True:
            try:
                entry = self._queue.get_nowait()
            except queue.Empty:
                break
            f.write(separator + entry)
            separator = ',\n'
        f.write('\n]\n')

    if self.remote_file_name is not None:
        remote_file_name = format_name_with_dist_and_time(self.remote_file_name, state.run_name, timestamp)
        logger.upload_file(
            remote_file_name=remote_file_name,
            file_path=trace_filename,
            overwrite=self.overwrite,
        )

    # Gather the filenames from all ranks; the list is indexed by global rank.
    trace_files = [pathlib.Path(x) for x in dist.all_gather_object(trace_filename)]
    self.saved_traces.append((timestamp, trace_files))

    # Ensure that all traces have been saved before anything is merged.
    dist.barrier()

    if self.merged_trace_filename is not None and dist.get_local_rank() == 0:
        # Merge together all traces from this node into one file.
        start_rank = dist.get_global_rank()
        end_rank = start_rank + dist.get_local_world_size()
        trace_files_to_merge = trace_files[start_rank:end_rank]
        merged_trace_filename = os.path.join(
            trace_folder,
            format_name_with_dist(
                self.merged_trace_filename,
                state.run_name,
            ),
        )
        merged_trace_dirname = os.path.dirname(merged_trace_filename)
        if merged_trace_dirname:
            os.makedirs(merged_trace_dirname, exist_ok=True)

        if os.path.exists(merged_trace_filename):
            # Include the existing merged trace in the new trace. The result is
            # staged in a scratch directory next to the destination so that the
            # final move stays on one filesystem; a scratch file in the system
            # temp dir cannot be renamed across devices.
            # NOTE(review): merge_traces is given a path that does not exist
            # yet, in case it creates its output exclusively -- confirm against
            # composer.profiler.json_trace_merger.
            with tempfile.TemporaryDirectory(dir=merged_trace_dirname or os.curdir) as scratch_dir:
                staged_trace = os.path.join(scratch_dir, os.path.basename(merged_trace_filename))
                merge_traces(staged_trace, merged_trace_filename, *trace_files_to_merge)
                os.replace(staged_trace, merged_trace_filename)
        else:
            # Write the trace directly.
            merge_traces(merged_trace_filename, *trace_files_to_merge)

        if self.merged_trace_remote_file_name is not None:
            merged_trace_remote_file_name = format_name_with_dist(
                self.merged_trace_remote_file_name,
                state.run_name,
            )
            logger.upload_file(
                remote_file_name=merged_trace_remote_file_name,
                # Upload the local merged file, not a path named after the remote object.
                file_path=merged_trace_filename,
                overwrite=True,
            )

    # Delete old trace files, oldest first.
    if self.num_traces_to_keep >= 0:
        global_rank = dist.get_global_rank()
        while len(self.saved_traces) > self.num_traces_to_keep:
            _, stale_trace_files = self.saved_traces[0]
            if global_rank < len(stale_trace_files):
                # Remove only this rank's trace.
                os.remove(stale_trace_files[global_rank])
            del self.saved_traces[0]

    self._is_trace_active = False
    self._save_at_batch_end = False
def process_duration_event(
    self,
    name: str,
    categories: Union[List[str], Tuple[str, ...]],
    is_start: bool,
    timestamp: Timestamp,
    wall_clock_time_ns: int,
) -> None:
    """Queue the begin or end marker of a duration event.

    Args:
        name: Event name.
        categories: Event categories; joined with commas in the trace.
        is_start: ``True`` records a begin (``'B'``) event, ``False`` an end
            (``'E'``) event.
        timestamp: Its epoch and batch values are attached as event args.
        wall_clock_time_ns: Wall clock time of the event, in nanoseconds.
    """
    phase = 'B' if is_start else 'E'
    training_position = {
        'epoch': timestamp.epoch.value,
        'batch': timestamp.batch.value,
    }
    self._record_event(
        name=name,
        categories=','.join(categories),
        ph=phase,
        wall_clock_ns=wall_clock_time_ns,
        pid=dist.get_global_rank(),
        args=training_position,
        tid=os.getpid(),
    )
def process_instant_event(
    self,
    name: str,
    categories: Union[List[str], Tuple[str, ...]],
    timestamp: Timestamp,
    wall_clock_time_ns: int,
) -> None:
    """Queue an instant (``'i'``) event scoped to the process.

    Args:
        name: Event name.
        categories: Event categories; joined with commas in the trace.
        timestamp: Its epoch and batch values are attached as event args.
        wall_clock_time_ns: Wall clock time of the event, in nanoseconds.
    """
    training_position = {
        'epoch': timestamp.epoch.value,
        'batch': timestamp.batch.value,
    }
    self._record_event(
        name=name,
        categories=','.join(categories),
        ph='i',
        wall_clock_ns=wall_clock_time_ns,
        args=training_position,
        pid=dist.get_global_rank(),
        tid=os.getpid(),
        s='p',  # scope: mark the instant event at the process level
    )
def process_counter_event(
    self,
    name: str,
    categories: Union[List[str], Tuple[str, ...]],
    timestamp: Timestamp,
    wall_clock_time_ns: int,
    values: Dict[str, Union[int, float]],
) -> None:
    """Queue a counter (``'C'``) event.

    Args:
        name: Event name.
        categories: Event categories; joined with commas in the trace.
        timestamp: Accepted for interface compatibility; not read here.
        wall_clock_time_ns: Wall clock time of the event, in nanoseconds.
        values: Counter values, recorded unchanged as the event's ``args``.
    """
    joined_categories = ','.join(categories)
    self._record_event(
        name=name,
        categories=joined_categories,
        ph='C',  # counter event
        wall_clock_ns=wall_clock_time_ns,
        pid=dist.get_global_rank(),
        tid=os.getpid(),
        args=values,
    )
def _record_event(self, name: str, ph: str, wall_clock_ns: int, pid: int, tid: int, categories: str = '', **kwargs):
    """Serialize one trace event to JSON and append it to the pending queue.

    Args:
        name (str): Event name.
        ph (str): Event phase (type) code. Known codes include:

            * Duration events: ``B`` (begin), ``E`` (end)
            * Complete events: ``X``
            * Instant events: ``i``
            * Counter events: ``C``
            * Async events: ``b`` (nestable start), ``n`` (nestable instant), ``e`` (nestable end)
            * Flow events: ``s`` (start), ``t`` (step), ``f`` (end)
            * Sample events: ``P``
            * Object events: ``N`` (created), ``O`` (snapshot), ``D`` (destroyed)
            * Metadata events: ``M``
            * Memory dump events: ``V`` (global), ``v`` (process)
            * Mark events: ``R``
            * Clock sync events: ``c``
        wall_clock_ns (int): Wall clock time, in nanoseconds. It is stored in
            the event as ``ts`` in whole microseconds.
        pid (int): Value stored as the event's ``pid``. Callers in this class
            pass the global rank.
        tid (int): Value stored as the event's ``tid``. Callers in this class
            pass ``os.getpid()``.
        categories (str): Comma-separated string of event categories, stored as ``cat``.
        kwargs: Any extra, event-specific fields; merged into the event after
            the standard fields.
    """
    timestamp_us = wall_clock_ns // 1000  # tracing clock timestamp, in microseconds
    event = dict(name=name, cat=categories, ph=ph, ts=timestamp_us, pid=pid, tid=tid)
    event.update(kwargs)
    self._queue.put_nowait(json.dumps(event))
def process_chrome_json_trace_file(self, filepath: pathlib.Path) -> None:
    """Import the events of an existing Chrome JSON trace file into the queue.

    The file may be plain or gzip-compressed (detected by a ``.gz`` suffix),
    and may hold either a bare JSON array of events or a JSON object with a
    ``traceEvents`` list. Each event's ``pid`` is overwritten with this
    process's global rank before the event is queued.

    Args:
        filepath (pathlib.Path): Path of the trace file to import.

    Raises:
        TypeError: If the events container is not a list.
    """
    opener = gzip.open if str(filepath).endswith('.gz') else open
    # JSON interchange text is UTF-8; do not depend on the locale's default encoding.
    with opener(filepath, 'rt', encoding='utf-8') as f:
        trace_data_str = f.read().strip()

    # It may be an incomplete trace file that is missing the closing ] bracket,
    # as is permitted in the chrome json format spec. A writer that was cut
    # short typically leaves a trailing comma after the last event, which must
    # be dropped before the array is closed or the result is not valid JSON.
    if trace_data_str.startswith('['):
        trace_data_str = trace_data_str.rstrip(',').rstrip()
        if not trace_data_str.endswith(']'):
            trace_data_str += ']'

    trace_data = json.loads(trace_data_str)

    event_list = trace_data['traceEvents'] if isinstance(trace_data, dict) else trace_data
    if not isinstance(event_list, list):
        raise TypeError('A trace file should either be a dict or a list')

    global_rank = dist.get_global_rank()
    for entry in event_list:
        entry['pid'] = global_rank  # override the PID to the global rank
        self._queue.put_nowait(json.dumps(entry, indent=None))