TorchProfiler

class composer.callbacks.torch_profiler.TorchProfiler(*, tensorboard_trace_handler_dir: str = 'torch_profiler', tensorboard_use_gzip: bool = False, record_shapes: bool = True, profile_memory: bool = False, with_stack: bool = True, with_flops: bool = True, skip: int = 0, warmup: int = 1, active: int = 5, wait: int = 0, repeat: int = 0)

Bases: composer.core.callback.Callback

Profile the execution using torch.profiler.profile.

Profiling results are stored in TensorBoard format in the folder given by tensorboard_trace_handler_dir.

To view profiling results, run the following, replacing tensorboard_trace_handler_dir with the path to that folder:

tensorboard --logdir tensorboard_trace_handler_dir

Also see https://pytorch.org/docs/stable/profiler.html.

Parameters
  • tensorboard_trace_handler_dir (str, optional) – Directory in which to store trace results, relative to the run directory. Defaults to torch_profiler.

  • tensorboard_use_gzip (bool, optional) – Whether to use gzip for the trace. Defaults to False.

  • record_shapes (bool, optional) – Whether to record tensor shapes. Defaults to True.

  • profile_memory (bool, optional) – Whether to profile memory. Defaults to False.

  • with_stack (bool, optional) – Whether to record stack info. Defaults to True.

  • with_flops (bool, optional) – Whether to estimate flops for operators. Defaults to True.

  • skip (int, optional) – Number of batches to skip at epoch start. Defaults to 0.

  • warmup (int, optional) – Number of warmup batches in a cycle. Defaults to 1.

  • active (int, optional) – Number of batches to profile in a cycle. Defaults to 5.

  • wait (int, optional) – Number of batches to skip at the end of each cycle. Defaults to 0.

  • repeat (int, optional) – Number of times to repeat the profiling cycle. Defaults to 0.
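The interaction between skip, warmup, active, wait, and repeat is easiest to see by mapping each batch of an epoch to a profiler action. The function below is a hypothetical model of the schedule as described above, not Composer's implementation. It assumes that batches are counted from 0 within each epoch, that each cycle is ordered warmup, then active, then wait, and that repeat=0 means cycles continue until the epoch ends; profiler_action and its string return values are illustrative names only.

```python
def profiler_action(batch_in_epoch: int, skip: int = 0, warmup: int = 1,
                    active: int = 5, wait: int = 0, repeat: int = 0) -> str:
    """Return 'none', 'warmup', or 'record' for a 0-based batch index in an epoch.

    Hypothetical model of the documented schedule; assumes repeat=0 means
    "keep cycling until the epoch ends".
    """
    if batch_in_epoch < skip:
        return "none"  # still inside the skipped prefix of the epoch
    cycle_len = warmup + active + wait
    if cycle_len == 0:
        return "none"  # degenerate schedule: nothing to profile
    cycle, pos = divmod(batch_in_epoch - skip, cycle_len)
    if repeat > 0 and cycle >= repeat:
        return "none"  # all requested cycles have completed
    if pos < warmup:
        return "warmup"  # profiler is tracing, but these results are discarded
    if pos < warmup + active:
        return "record"  # these batches end up in the saved trace
    return "none"  # trailing `wait` batches of the cycle


# skip=2, warmup=1, active=2, wait=1, repeat=1 over an 8-batch epoch
print([profiler_action(i, skip=2, warmup=1, active=2, wait=1, repeat=1) for i in range(8)])
# ['none', 'none', 'warmup', 'record', 'record', 'none', 'none', 'none']
```

With the defaults (skip=0, warmup=1, active=5, wait=0, repeat=0) this model gives one warmup batch followed by five recorded batches, repeating every six batches for the rest of the epoch.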

batch_end(state: State, logger: Logger) → None

Called on the Event.BATCH_END event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

batch_start(state: State, logger: Logger) → None

Called on the Event.BATCH_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

epoch_start(state: State, logger: Logger) → None

Called on the Event.EPOCH_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

load_state_dict(state: composer.core.types.StateDict) → None

Restores the state of the object.

Parameters
  • state (StateDict) – The state of the object, as previously returned by state_dict().

state_dict() → composer.core.types.StateDict

Returns a dictionary representing the internal state.

The returned dictionary must be picklable via torch.save().

Returns

StateDict – The state of the object.
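The state_dict() / load_state_dict() contract amounts to a round trip: whatever state_dict() returns must survive serialization and be accepted back by load_state_dict(). The sketch below illustrates this with a hypothetical CounterCallback whose only state is a batch counter; it is not TorchProfiler's actual state. It assumes StateDict is a plain Dict[str, Any], and uses the standard pickle module directly because torch.save() pickles with it by default.

```python
import io
import pickle
from typing import Any, Dict

# Assumption: mirrors composer.core.types.StateDict as a plain string-keyed dict.
StateDict = Dict[str, Any]


class CounterCallback:
    """Hypothetical callback whose only state is a batch counter."""

    def __init__(self) -> None:
        self.batches_seen = 0

    def state_dict(self) -> StateDict:
        # Only plain Python values, so the dict can be pickled (as torch.save() would do).
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state: StateDict) -> None:
        self.batches_seen = state["batches_seen"]


original = CounterCallback()
original.batches_seen = 7

buffer = io.BytesIO()
pickle.dump(original.state_dict(), buffer)  # stand-in for torch.save(state, f)
buffer.seek(0)

restored = CounterCallback()
restored.load_state_dict(pickle.load(buffer))  # stand-in for torch.load(f)
print(restored.batches_seen)  # 7
```

Keeping the returned values to plain Python types (ints, strings, lists, dicts) is what keeps the dictionary picklable; live objects such as an open profiler handle generally would not be.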

training_start(state: State, logger: Logger) → None

Called on the Event.TRAINING_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.
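The four hooks above fire in a fixed nesting: training_start once, then epoch_start at the top of each epoch, then batch_start and batch_end around every batch. The simulation below shows that order with a hypothetical stand-in callback that only logs calls; the hooks drop the state and logger arguments for brevity, and the run function is a simplified trainer loop that ignores the other events a real trainer emits. It also assumes that the per-epoch batch counter is reset in epoch_start, which would explain why skip is counted from the start of each epoch.

```python
from typing import List


class LoggingCallback:
    """Hypothetical stand-in for TorchProfiler that records which hook fired."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.batch_in_epoch = 0

    def training_start(self) -> None:
        self.calls.append("training_start")

    def epoch_start(self) -> None:
        self.batch_in_epoch = 0  # assumption: counter restarts every epoch
        self.calls.append("epoch_start")

    def batch_start(self) -> None:
        self.calls.append(f"batch_start[{self.batch_in_epoch}]")

    def batch_end(self) -> None:
        self.calls.append(f"batch_end[{self.batch_in_epoch}]")
        self.batch_in_epoch += 1


def run(callback: LoggingCallback, num_epochs: int, batches_per_epoch: int) -> None:
    """Simplified trainer loop: only the four events documented above."""
    callback.training_start()
    for _ in range(num_epochs):
        callback.epoch_start()
        for _ in range(batches_per_epoch):
            callback.batch_start()
            callback.batch_end()


cb = LoggingCallback()
run(cb, num_epochs=2, batches_per_epoch=2)
print(cb.calls)
```

Because the batch index restarts at 0 in each epoch_start, a schedule keyed on that index would apply skip at the beginning of every epoch rather than only once per training run.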