⏱️ Performance Profiling#
Introduction#
The Composer Profiler enables practitioners to collect, analyze, and visualize performance metrics during training, which can be used to diagnose bottlenecks and facilitate model development.
The profiler enables users to capture the following metrics:
Duration of each Event, Callback, and Algorithm during training
Time taken by the data loader to return a batch
Host metrics such as CPU, system memory, disk, and network utilization over time
Execution order, latency, and attributes of PyTorch operators and GPU kernels (see torch.profiler)
This tutorial will demonstrate how to set up and configure profiling, as well as capture and visualize performance traces.
Getting Started#
In this tutorial, we will build a simple training application called profiler_demo.py using the MNIST dataset and a classifier model with the Composer Trainer.
Setup#
Install Composer, if it is not already installed:
pip install mosaicml
Steps#
Import required modules
Instantiate the dataset and model
Instantiate the Trainer and configure the Profiler
Run training with profiling
View and analyze traces
Import required modules#
In this example, we will use torch.utils.data.DataLoader with the MNIST dataset from torchvision. From composer, we will import the Trainer, the Profiler (along with JSONTraceHandler and cyclic_schedule), and the ComposerClassifier model wrapper.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from composer import Trainer
from composer.models.tasks import ComposerClassifier
from composer.profiler import JSONTraceHandler, cyclic_schedule
from composer.profiler.profiler import Profiler
Instantiate the dataset and model#
Next we instantiate the dataset, dataloader, and model.
# Specify Dataset and Instantiate DataLoader
batch_size = 2048
data_directory = '~/datasets'
mnist_transforms = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(data_directory, train=True, download=True, transform=mnist_transforms)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    pin_memory=True,
    persistent_workers=True,
    num_workers=8,
)
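The Trainer below also needs a model object. A minimal MNIST classifier wrapped in ComposerClassifier is sketched here; the architecture is illustrative, and any classifier that accepts 1x28x28 MNIST inputs would work.
# Define a small CNN and wrap it in ComposerClassifier (illustrative sketch)
class Model(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, (3, 3))
        self.conv2 = nn.Conv2d(16, 32, (3, 3))
        self.bn = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 16, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.bn(self.conv2(out)))
        out = F.adaptive_avg_pool2d(out, (4, 4))
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)

# `model` is the object passed to the Trainer in the next step
model = ComposerClassifier(Model(num_classes=10))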
Instantiate the Trainer and configure profiling#
To enable profiling, construct a Profiler and pass it to the Trainer. The trace_handlers and schedule are the only required arguments; all others are optional.
Here, we configure the following profiling options:
Set the trace_handlers to store Composer Profiler traces in the 'composer_profiler' folder
Set the profiling window via schedule
Set the torch_prof_folder to store Torch Profiler traces in the 'torch_profiler' folder
Limit the duration of the training run to keep the size of the trace files manageable
# Instantiate the trainer
composer_trace_dir = 'composer_profiler'
torch_trace_dir = 'torch_profiler'

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=train_dataloader,
    max_duration=2,
    device='gpu' if torch.cuda.is_available() else 'cpu',
    eval_interval=0,
    precision='amp' if torch.cuda.is_available() else 'fp32',
    train_subset_num_batches=16,
    profiler=Profiler(
        trace_handlers=[JSONTraceHandler(folder=composer_trace_dir, overwrite=True)],
        schedule=cyclic_schedule(
            wait=0,
            warmup=1,
            active=4,
            repeat=1,
        ),
        torch_prof_folder=torch_trace_dir,
        torch_prof_overwrite=True,
        torch_prof_memory_filename=None,
    ),
)
Note that we support both local and object store paths for the Composer Profiler, e.g.:
profiler = Profiler(
    trace_handlers=[JSONTraceHandler(remote_file_name='oci://your-bucket/composer_profiler/')],
    torch_remote_filename='s3://your-bucket/torch_profiler/',
    torch_prof_memory_filename=None,
    ...
)
Specifying the Profile Schedule#
When setting up profiling, it is important to specify the profiling schedule via the schedule
argument.
This schedule determines the profilerโs recording behavior. The schedule is a function that takes the training
State
and returns a ProfilerAction
.
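For example, a custom schedule could be written as in the following minimal sketch, assuming ProfilerAction exposes ACTIVE and SKIP members and that the current batch index is available from state.timestamp.batch_in_epoch:
from composer.core import State
from composer.profiler import ProfilerAction

def custom_schedule(state: State) -> ProfilerAction:
    # Record batches 2 through 5 of every epoch and skip everything else.
    # (Illustrative only; assumes ProfilerAction.ACTIVE and ProfilerAction.SKIP.)
    batch = int(state.timestamp.batch_in_epoch)
    return ProfilerAction.ACTIVE if 2 <= batch <= 5 else ProfilerAction.SKIP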
For convenience, the Composer Profiler includes a cyclic_schedule() which configures a cyclic profiling window that repeats each epoch. It takes the following arguments:
skip_first: Number of steps to offset the window relative to the start of the epoch.
wait: Start of the window; the number of steps to skip recording relative to the start of the profiling window.
warmup: Number of steps to start tracing but discard the results (PyTorch profiler only).
active: Number of steps the profiler is active and recording data. The end of the last step demarcates the end of the window.
repeat: Number of consecutive times the profiling window is repeated per epoch.
The profiling window for an epoch is defined as wait + warmup + active, while skip_first and repeat control profiler behavior preceding and after the window, respectively.
Warning
Profiling incurs additional overhead that can impact the performance of the workload. This overhead is fairly
minimal for the various profilers with the exception of the PyTorch profiler. However, the relative duration of
recorded events will remain consistent in all states except warmup
, which incurs a transient profiler initialization
penalty. Thus, trace data is discarded for these steps.
For example, let's assume the profiling options are set as follows:
skip_first=1, wait=1, warmup=1, active=2, repeat=1
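In code, this configuration corresponds to constructing the schedule with the cyclic_schedule() arguments described above (a sketch; the variable name is illustrative):
# Each profiling window spans wait + warmup + active = 1 + 1 + 2 = 4 batches,
# offset by skip_first=1 batch at the start of each epoch and repeated once more.
schedule = cyclic_schedule(
    skip_first=1,
    wait=1,
    warmup=1,
    active=2,
    repeat=1,
)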
Given the configuration above, profiling will be performed as follows:
| Epoch | Batch | Profiler State | Profiler Action |
|---|---|---|---|
| 0 | 0 | skip_first | Do not record |
| 0 | 1 | wait | Do not record |
| 0 | 2 | warmup | Record, Torch Profiler does not record |
| 0 | 3 | active | Record |
| 0 | 4 | active | Record |
| 0 | 5 | wait | Do not record |
| 0 | 6 | warmup | Record, Torch Profiler does not record |
| 0 | 7 | active | Record |
| 0 | 8 | active | Record |
| 0 | 9 | disabled | Do not record |
| 0 | … | | |
| 1 | 0 | skip_first | Do not record |
| 1 | 1 | wait | Do not record |
| 1 | 2 | warmup | Record, Torch Profiler does not record |
| 1 | 3 | active | Record |
| 1 | 4 | active | Record |
| 1 | 5 | wait | Do not record |
| 1 | 6 | warmup | Record, Torch Profiler does not record |
| 1 | 7 | active | Record |
| 1 | 8 | active | Record |
| 1 | 9 | disabled | Do not record |
| 1 | … | | |
As we can see above, the profiler skips the first batch of each epoch and is in the wait state during the following batch, after which the profiler warms up in the next batch and actively records trace data for the following two batches. The window is repeated once more in the epoch, and the pattern continues for the duration of the training run.
Limiting the scope of the training run#
Due to the additional overhead incurred by profiling, it is not usually practical to enable profiling for a full
training run. In this example, we limit the duration of the profiling run by specifying max_duration=2
epochs
and limit the number of batches in each epoch by setting train_subset_num_batches=16
to capture performance data
within a reasonable amount of time and limit the size of the trace file.
Since warmup=1, active=4, and repeat=1 in the cyclic_schedule above, we will record profiling data for 10 batches each epoch, starting with batch 0 (no offset, since skip_first defaults to 0 and wait=0).
Additionally, since we are only concerned with profiling during training, we disable validation by setting eval_interval=0.
Run training with profiling#
Lastly, we run the training loop by invoking Trainer.fit().
# Run training
trainer.fit()
Finally, we can run the application as follows on a single GPU:
python examples/profiler_demo.py
Or, we can profile on multiple GPUs:
composer -n N_GPUS examples/profiler_demo.py # set N_GPUS to the number of GPUs
Viewing traces#
Once the training loop is complete, you should see the following traces:
> ls composer_profiler/
... ep0-ba5-rank0.json ep1-ba21-rank0.json merged_trace.json
> ls torch_profiler/
... rank0.21.pt.trace.json rank0.5.pt.trace.json
The trace files within the composer_profiler folder contain all timing information and metrics collected during the profiling run. One file is generated per profiling cycle. The file named composer_profiler/node0.json contains all trace files merged together. Each file contains all profiler metrics (a quick way to peek at a merged trace is sketched after this list), including:
The durations of Algorithms/Callbacks/Events
Data loader latency
System host metrics
Torch Profiler events, such as kernel execution times
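As a quick programmatic check, a trace can be loaded directly with the standard library. This is a minimal sketch, assuming the file follows the Chrome trace event format (either a top-level list of events or an object with a 'traceEvents' key); substitute the path of any file listed under composer_profiler/.
import json
from collections import Counter

# Illustrative path: use any trace file produced in the composer_profiler/ folder
with open('composer_profiler/node0.json') as f:
    trace = json.load(f)

# Chrome trace event format: either a list of events or {"traceEvents": [...]}
events = trace['traceEvents'] if isinstance(trace, dict) else trace

# Show the ten most frequently recorded event names
print(Counter(e.get('name') for e in events if isinstance(e, dict)).most_common(10))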
The trace files within the torch_profiler
folder contain the raw trace files as generated by the PyTorch profiler.
They do not include the Composer Profiler metrics, such as event duration, dataloader latency, or system host metrics.
Viewing traces in Chrome Trace Viewer#
All traces can be viewed using the Chrome Trace Viewer. To launch, open a Chrome browser session and
navigate to chrome://tracing
in the address bar.
In the following example, we load the composer_profiler/node0.json file, which contains the unified trace data. Open the trace by clicking the "Load" button and selecting the composer_profiler/node0.json file. Depending on the size of the trace, it could take a moment to load. After the trace has been loaded, you will see the complete trace capture.
The Trace Viewer provides users the ability to navigate the trace and interact with individual events and analyze key attributes if the information has been recorded. For more details on using and interacting with the Trace Viewer, please see the Chromium How-To.
Viewing standalone Torch Profiler traces#
The Torch Profiler traces found in the torch_profiler folder can also be viewed using TensorBoard or the VSCode TensorBoard extension.
To view the Torch Profiler traces in TensorBoard, run:
pip install tensorboard torch_tb_profiler
tensorboard --logdir torch_profiler
Viewing composer_profiler
traces in TensorBoard is not currently supported.