TorchProfiler
- class composer.callbacks.torch_profiler.TorchProfiler(*, tensorboard_trace_handler_dir: str = 'torch_profiler', tensorboard_use_gzip: bool = False, record_shapes: bool = True, profile_memory: bool = False, with_stack: bool = True, with_flops: bool = True, skip: int = 0, warmup: int = 1, active: int = 5, wait: int = 0, repeat: int = 0)[source]
Bases: composer.core.callback.Callback
Profile the execution using torch.profiler.profile. Profiling results are stored in TensorBoard format in the tensorboard_trace_handler_dir folder.
To view profiling results, run:
tensorboard --logdir tensorboard_trace_handler_dir
Also see https://pytorch.org/docs/stable/profiler.html.
- Parameters
tensorboard_trace_handler_dir (str) – Directory to store trace results. Relative to the run_directory. Defaults to torch_profiler in the run directory.
tensorboard_use_gzip (bool, optional) – Whether to use gzip for the trace. Defaults to False.
record_shapes (bool, optional) – Whether to record tensor shapes. Defaults to True.
profile_memory (bool, optional) – Whether to profile memory. Defaults to False.
with_stack (bool, optional) – Whether to record stack info. Defaults to True.
with_flops (bool, optional) – Whether to estimate flops for operators. Defaults to True.
skip (int, optional) – Number of batches to skip at epoch start. Defaults to 0.
warmup (int, optional) – Number of warmup batches in a cycle. Defaults to 1.
active (int, optional) – Number of batches to profile in a cycle. Defaults to 5.
wait (int, optional) – Number of batches to skip at the end of each cycle. Defaults to 0.
repeat (int, optional) – Number of profiling cycles to perform; 0 means cycles continue for the remainder of training. Defaults to 0.
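The cycle parameters above determine which batches are profiled. A minimal pure-Python sketch of the resulting schedule, assuming the phase order implied by the parameter descriptions (skip batches at epoch start, then repeating cycles of warmup → active → wait); the helper name profiler_phase is illustrative, not part of the API:

```python
def profiler_phase(batch_idx, skip=0, warmup=1, active=5, wait=0):
    """Return the profiler phase ("skip", "warmup", "active", or "wait")
    for a given batch index: `skip` batches at epoch start, then repeating
    cycles of `warmup` -> `active` -> `wait` batches."""
    if batch_idx < skip:
        return "skip"
    cycle_len = warmup + active + wait
    pos = (batch_idx - skip) % cycle_len
    if pos < warmup:
        return "warmup"
    if pos < warmup + active:
        return "active"
    return "wait"

# With the defaults (skip=0, warmup=1, active=5, wait=0), batch 0 is a
# warmup batch and batches 1-5 are actively profiled; the cycle then repeats.
phases = [profiler_phase(i) for i in range(7)]
print(phases)
```

Only batches in the "active" phase produce trace events; the warmup batches let the profiler's bookkeeping stabilize so the recorded timings are representative.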
- load_state_dict(state: composer.core.types.StateDict) → None[source]
Restores the state of the object.
- Parameters
state (StateDict) – The state of the object, as previously returned by state_dict().
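The state_dict()/load_state_dict() pair follows the usual PyTorch-style checkpoint convention: state_dict() serializes the object's state into a plain dict, and load_state_dict() restores it. A minimal sketch with a hypothetical stateful object (not Composer's actual implementation):

```python
class CycleCounter:
    """Hypothetical stateful object illustrating the state_dict() /
    load_state_dict() convention described above."""

    def __init__(self):
        self.batches_profiled = 0

    def state_dict(self):
        # Serialize the object's state into a plain dict.
        return {"batches_profiled": self.batches_profiled}

    def load_state_dict(self, state):
        # Restore state previously returned by state_dict().
        self.batches_profiled = state["batches_profiled"]

a = CycleCounter()
a.batches_profiled = 42
saved = a.state_dict()

b = CycleCounter()
b.load_state_dict(saved)
print(b.batches_profiled)  # -> 42
```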