Source code for composer.profiler.json_trace_merger

# Copyright 2021 MosaicML. All Rights Reserved.

"""Merge trace files together.

To run:

.. code-block::

    python -m composer.profiler.json_trace_merger -o merged_trace_output.json path/to/input_file_1.json path/to/input_file_2.json ...

To view the traces, open a Google Chrome browser window, navigate to ``chrome://tracing`` and load the ``merged_trace_output.json``
to visualize the trace.
"""
import argparse
import json
import pathlib
from typing import Dict, List, Tuple, Union

__all__ = ["merge_traces"]


def _load_trace(file: Union[str, pathlib.Path]) -> Union[Dict, List]:
    with open(file, "r") as f:
        trace_str = f.read().strip()
        if trace_str.startswith("["):
            if trace_str.endswith("}"):
                trace_str += "]"
            if trace_str.endswith(","):
                trace_str = trace_str[-1] + "]"
        return json.loads(trace_str)


def _get_global_rank_from_file(file: Union[str, pathlib.Path]) -> int:
    trace_json = _load_trace(file)
    if isinstance(trace_json, list):
        for event in trace_json:
            if event["ph"] == "M" and event["name"] == "global_rank":
                return event["args"]["value"]
    else:
        assert isinstance(trace_json, dict)
        return trace_json["global_rank"]
    raise RuntimeError("global rank not found in file")


def _get_rank_to_clock_syncs(trace_files: Tuple[Union[str, pathlib.Path], ...]) -> Dict[int, int]:
    rank_to_clock_sync: Dict[int, int] = {}
    for filename in trace_files:
        rank = _get_global_rank_from_file(filename)
        trace_json = _load_trace(filename)
        if isinstance(trace_json, list):
            for event in trace_json:
                if event["ph"] == "M" and event["name"] == "clock_sync_timestamp_us":
                    clock_sync = event["args"]["value"]
                    rank_to_clock_sync[rank] = clock_sync
                    break
        else:
            assert isinstance(trace_json, dict)
            if trace_json.get("clock_sync_timestamp_us") is not None:
                rank_to_clock_sync[rank] = trace_json["clock_sync_timestamp_us"]

    return rank_to_clock_sync


[docs]def merge_traces(output_file: Union[str, pathlib.Path], *trace_files: Union[str, pathlib.Path]): """Merge profiler output JSON trace files together. This function will update the trace events such that: - The ``pid`` will be set to the global rank. - The ``ts`` is syncronized with that of the rank 0 process. - The backward pass process appears below the forward process. Args: output_file (str | pathlib.Path): The file to write the merged trace to trace_files (str | pathlib.Path): Variable number of trace files to merge together """ ranks_to_clock_sync = _get_rank_to_clock_syncs(trace_files) rank_to_backwards_thread = {} rank_to_seen_threads = {rank: set() for rank in ranks_to_clock_sync.keys()} rank_zero_clock_sync = ranks_to_clock_sync[0] with open(output_file, "w+") as output_f: is_first_line = True output_f.write("[") for trace_filename in trace_files: rank = _get_global_rank_from_file(trace_filename) clock_sync_diff = rank_zero_clock_sync - ranks_to_clock_sync[rank] with open(trace_filename, "r") as trace_f: trace_data = json.load(trace_f) if isinstance(trace_data, list): trace_list = trace_data else: assert isinstance(trace_data, dict) trace_list = trace_data["traceEvents"] for event in trace_list: if "pid" not in event: # we need the pid to merge continue if "tid" not in event: continue if "PyTorch Profiler" in str(event["tid"]): # skip this line; it pollutes the UI continue if "ts" in event: event["ts"] = event["ts"] + clock_sync_diff event["pid"] = rank if event["tid"] not in rank_to_seen_threads[rank]: # By default, make all threads display last # The training loop thread later sets itself as thread 0 # and the backwards pass thread is set as thread 1 if not is_first_line: output_f.write(",") output_f.write("\n ") json.dump( { "name": "thread_sort_index", "ph": "M", "pid": rank, "tid": event["tid"], "args": { "sort_index": 99999, } }, output_f) rank_to_seen_threads[rank].add(event['tid']) is_first_line = False if event["name"] == "MulBackward0": rank_to_backwards_thread[rank] = event["tid"] if not is_first_line: output_f.write(",") is_first_line = False output_f.write(f"\n ") json.dump(event, output_f) for pid, tid in rank_to_backwards_thread.items(): output_f.write(",\n ") json.dump({ "name": "thread_sort_index", "ph": "M", "pid": pid, "tid": tid, "args": { "sort_index": 1 } }, output_f) output_f.write("\n]\n")
if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('input_files', type=str, nargs='+', help='Input files') parser.add_argument("-o", "--output_file", help="Output File", required=True) args = parser.parse_args() output_file = args.output_file input_files = args.input_files merge_traces(output_file, *input_files)