👨‍👩‍👧‍👦 Distributed Training#
Composer supports distributed training on multiple devices, whether it be multiple GPUs on a single node or multiple GPUs across multiple nodes.
Data Parallelism#
Composer distributes work across devices via data-parallelism-only. We made this design choice in order to provide the most flexibility to algorithms, which can modify the training loop in complex ways. Data parallelism greatly simplifies model building and memory management. Every GPU performs the same work, so inspecting rank zero is sufficient to reason about memory, performance, and other properties.
Within Composer, we have two options for data-parallelism-only execution: PyTorch DDP (default) and PyTorch FSDP. Although PyTorch DDP is the default, PyTorch FSDP is the recommended option: when configured correctly, it increases memory and computational efficiency while producing the same results.
Usage#
To launch a multi-GPU training job, we provide the composer launcher:
# run training on 8 GPUs
>>> composer -n 8 my_training_script.py
Under the hood, this script (source code here) sets the required torch.distributed environment variables, launches the processes, and runs the script on each process.
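As a rough illustration of what each launched process can read, the sketch below collects rank information from environment variables and falls back to single-process defaults. The helper name `read_dist_env` is hypothetical, and the exact set of variables the launcher exports is an assumption based on the launcher arguments documented on this page and the usual torch.distributed conventions.

```python
import os
from typing import Mapping, Optional


def read_dist_env(env: Optional[Mapping[str, str]] = None) -> dict:
    """Collect rank information from environment variables.

    Hypothetical helper: falls back to single-process defaults when a
    variable is missing (for example, when no launcher started the script).
    """
    env = os.environ if env is None else env
    return {
        'rank': int(env.get('RANK', 0)),
        'local_rank': int(env.get('LOCAL_RANK', 0)),
        'world_size': int(env.get('WORLD_SIZE', 1)),
        'local_world_size': int(env.get('LOCAL_WORLD_SIZE', 1)),
        'node_rank': int(env.get('NODE_RANK', 0)),
        'master_addr': env.get('MASTER_ADDR', '127.0.0.1'),
    }


# Example: the values that process 3 of 8 might see on a single node.
info = read_dist_env({
    'RANK': '3',
    'LOCAL_RANK': '3',
    'WORLD_SIZE': '8',
    'LOCAL_WORLD_SIZE': '8',
})
print(info['rank'], info['world_size'])
```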
By default, only the rank zero logs will be sent to the console. To save the logs from all the ranks, use --stdout and --stderr:
>>> composer -n 8 --stdout stdout_{rank}.log --stderr stderr_{rank}.log script.py
The stdout for each rank will then be available at stdout_1.log, stdout_2.log, and so forth.
The filename is customizable; see the command help for more details.
Alternatively, the logs can also be captured using our FileLogger.
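To make the filename templates concrete, here is a minimal sketch of how a template such as stdout_{rank}.log could expand for each process. It assumes the template is filled in with Python's `str.format` and that every node runs the same number of processes; `expand_log_name` is a hypothetical helper, not part of the launcher.

```python
def expand_log_name(template: str, rank: int, local_world_size: int, world_size: int) -> str:
    """Expand a log filename template for one process (hypothetical helper).

    Assumes every node runs `local_world_size` processes, so the local rank
    and node rank can be derived from the global rank.
    """
    return template.format(
        rank=rank,
        local_rank=rank % local_world_size,
        world_size=world_size,
        node_rank=rank // local_world_size,
        local_world_size=local_world_size,
    )


# Single node with 8 processes: filenames for global ranks 1 to 3.
names = [expand_log_name('stdout_{rank}.log', r, 8, 8) for r in range(1, 4)]

# Two nodes with 8 processes each: global rank 9 is local rank 1 on node 1.
multi_node_name = expand_log_name('node{node_rank}_local{local_rank}.log', 9, 8, 16)
print(names, multi_node_name)
```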
Note
The batch_size passed to your dataloader should be the per-device minibatch size. We further split this into smaller microbatches with gradient accumulation.
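The arithmetic behind this note can be sketched as follows. The numbers (32 samples per device, 8 devices, microbatches of 8) are illustrative assumptions, and `split_into_microbatches` is a hypothetical helper rather than a Composer function.

```python
def split_into_microbatches(batch, microbatch_size):
    """Split one per-device batch into consecutive microbatches."""
    return [batch[i:i + microbatch_size] for i in range(0, len(batch), microbatch_size)]


per_device_batch_size = 32  # the value passed to the dataloader
world_size = 8              # number of devices

# Each device sees its own batch, so one optimizer step covers this many samples.
global_batch_size = per_device_batch_size * world_size

# On each device the batch is processed as several smaller microbatches.
device_batch = list(range(per_device_batch_size))
microbatches = split_into_microbatches(device_batch, microbatch_size=8)
print(global_batch_size, len(microbatches))
```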
For additional configurations of our launcher script, run composer --help.
usage: composer [-h] [--version] [-n NPROC] [--stdout STDOUT]
[--stderr STDERR] [-v] [-m] [-c] [--world_size WORLD_SIZE]
[--base_rank BASE_RANK] [--node_rank NODE_RANK]
[--master_addr MASTER_ADDR] [--master_port MASTER_PORT]
training_script ...
Named Arguments#
- --version
show program's version number and exit
- -n, --nproc
The number of processes to launch on this node. Overrides env var LOCAL_WORLD_SIZE if specified; otherwise, defaults to max(1, torch.cuda.device_count()).
- --stdout
Format string for a filename to dump the STDOUT from the non-local-rank-zero processes. The local rank zero process will be piped through to STDOUT. The available format variables are: "{rank}", "{local_rank}", "{world_size}", "{node_rank}", and "{local_world_size}". If specified, it is recommended to include "{rank}" or "{local_rank}" in the filename so each rank will write to its own file. By default, the STDOUT of the non-local-rank-zero processes is discarded; instead, use the FileLogger within Composer. This logger captures and saves the STDOUT of each process.
- --stderr
Format string for a filename to dump the STDERR from the non-local-rank-zero processes. The local rank zero process will be piped through to STDERR. The available format variables are: "{rank}", "{local_rank}", "{world_size}", "{node_rank}", and "{local_world_size}". If specified, it is recommended to include "{rank}" or "{local_rank}" in the filename so each rank will write to its own file. By default, the STDERR of the non-local-rank-zero processes is discarded; instead, use the FileLogger within Composer. This logger captures and saves the STDERR of each process.
- -v, --verbose
If set, print verbose messages
Default: False
- -m, --module_mode
If set, run the training script as a module instead of as a script. Cannot be used in conjunction with command_mode.
Default: False
- -c, --command_mode
If set, run the training script as a command (i.e. without python). Cannot be used in conjunction with module_mode.
Default: False
required arguments#
- training_script
The path to the training script used to initialize a single training process. Should be followed by any command-line arguments the script should be launched with.
- training_script_args
Any arguments for the training script, given in the expected order.
multi-node arguments#
These arguments generally only need to be set when training in a multi-node environment, i.e. when the world_size is bigger than nproc.
- --world_size
The total number of processes to launch across all nodes. Setting this to a value greater than nproc indicates a multi-node environment. Overrides env var WORLD_SIZE. Defaults to nproc.
- --base_rank
The rank of the lowest ranked process to launch on this node. Specifying a base_rank B and an nproc N will spawn processes with global ranks [B, B+1, …, B+N-1]. In a multi-node environment, at least one of base_rank and node_rank must be specified. If only one of base_rank and node_rank is provided, it is assumed that all nodes have the same number of processes, and that the two values are related as node_rank * nproc = base_rank. If this is not the case, both base_rank and node_rank must be provided. Overrides env var BASE_RANK. Defaults to 0 in a single-node environment.
- --node_rank
The rank of this node. See base_rank for information on when this must be provided. Overrides env var NODE_RANK. Defaults to 0 in a single-node environment.
- --master_addr
The FQDN of the node hosting the C10d TCP store. For single-node operation, this can generally be left as 127.0.0.1. Overrides env var MASTER_ADDR. Defaults to 127.0.0.1 in a single-node environment.
- --master_port
The port on the master hosting the C10d TCP store. If you are running multiple trainers on a single node, this generally needs to be unique for each one. Overrides env var MASTER_PORT. Defaults to a random free port in a single-node environment.
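The relationship between nproc, node_rank, and base_rank described above can be sketched as a small function. `ranks_on_node` is a hypothetical helper written for illustration; it applies the documented rule node_rank * nproc = base_rank when only node_rank is given.

```python
from typing import List, Optional


def ranks_on_node(nproc: int, node_rank: Optional[int] = None, base_rank: Optional[int] = None) -> List[int]:
    """Return the global ranks spawned on one node (hypothetical helper).

    If only node_rank is given, all nodes are assumed to run `nproc`
    processes, so base_rank = node_rank * nproc.
    """
    if node_rank is None and base_rank is None:
        raise ValueError('at least one of node_rank and base_rank must be specified')
    if base_rank is None:
        base_rank = node_rank * nproc
    # Global ranks [B, B+1, ..., B+N-1]
    return list(range(base_rank, base_rank + nproc))


# Second node (node_rank=1) of a job with 8 processes per node.
second_node = ranks_on_node(nproc=8, node_rank=1)

# Uneven layout: this node starts at rank 6 and runs 2 processes.
uneven_node = ranks_on_node(nproc=2, base_rank=6)
print(second_node, uneven_node)
```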
Distributed Properties#
Developers may need to access the current rank or world size in a distributed setting. For example, a callback may only want to log something for rank zero. Use our composer.utils.dist module to retrieve this information. The methods are similar to torch.distributed, but also return defaults in a non-distributed setting.
from composer.utils import dist
dist.get_world_size() # torch.distributed.get_world_size()
dist.get_local_rank()
dist.get_global_rank() # torch.distributed.get_rank()
For all retrievable properties, see composer.utils.dist.
Space-time Equivalence#
We consider an equivalency principle between distributed training and gradient accumulation. That is, batches can either be parallelized across space (e.g. devices) or across time (e.g. gradient accumulation). Furthermore, the two dimensions are interchangeable: more devices means less gradient accumulation, and vice versa. Our trainer strives to respect this equivalency and ensure identical behavior regardless of the combination of space and time parallelization used.
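The equivalence can be checked numerically on a toy problem. The sketch below is not Composer code: it assumes a one-parameter model with a mean-squared-error loss and equally sized shards, and shows that averaging gradients over devices (space) and over microbatches (time) yields the same gradient as a single pass over the whole batch.

```python
import math


def grad_mse(w, samples):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in samples) / len(samples)


def chunk(items, n):
    """Split items into n equally sized consecutive chunks."""
    if len(items) % n != 0:
        raise ValueError('items must divide evenly into n chunks')
    size = len(items) // n
    return [items[i * size:(i + 1) * size] for i in range(n)]


def combined_grad(w, batch, n_devices, n_microbatches):
    """Average gradients over microbatches (time), then over devices (space)."""
    device_grads = []
    for device_batch in chunk(batch, n_devices):
        micro_grads = [grad_mse(w, mb) for mb in chunk(device_batch, n_microbatches)]
        device_grads.append(sum(micro_grads) / n_microbatches)
    return sum(device_grads) / n_devices


w = 0.5
batch = [(1.0, 2.0), (2.0, 1.0), (3.0, 5.0), (4.0, 3.0),
         (5.0, 4.0), (6.0, 8.0), (7.0, 6.0), (8.0, 9.0)]

reference = grad_mse(w, batch)                      # one device, no accumulation
one_device_four_steps = combined_grad(w, batch, 1, 4)
two_devices_two_steps = combined_grad(w, batch, 2, 2)
four_devices_one_step = combined_grad(w, batch, 4, 1)
```

The equality relies on every shard having the same number of samples; with uneven shards, a plain average of per-shard gradients would no longer match the full-batch gradient.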
Distributed Sampling#
When passing a torch.utils.data.DataLoader to Composer whose dataset is a map-style torch.utils.data.Dataset (that is, not a torch.utils.data.IterableDataset), a torch.utils.data.distributed.DistributedSampler is necessary to ensure different devices receive different batches. Composer will raise an error if a DistributedSampler is not provided. composer.utils.dist provides a helper function, composer.utils.dist.get_sampler(), that creates a DistributedSampler with the correct parameters.
from torch.utils.data import DataLoader
from composer.utils import dist
sampler = dist.get_sampler(dataset, shuffle=True)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
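To see why a distributed sampler gives each device different samples, the following simplified sketch partitions dataset indices across ranks in a strided fashion, padding so that every rank receives the same number of samples. It is a plain-Python illustration without shuffling, not the implementation used by torch or Composer, and `shard_indices` is a hypothetical name.

```python
def shard_indices(dataset_len: int, rank: int, world_size: int) -> list:
    """Return the dataset indices assigned to one rank (simplified sketch).

    Assumes dataset_len >= world_size. Indices are padded by repeating the
    first few so that every rank receives the same number of samples.
    """
    indices = list(range(dataset_len))
    per_rank = -(-dataset_len // world_size)  # ceiling division
    total = per_rank * world_size
    indices += indices[:total - dataset_len]  # pad to a multiple of world_size
    return indices[rank:total:world_size]     # strided slice for this rank


# 10 samples over 4 ranks: each rank gets 3 indices, two of which are padding.
shards = [shard_indices(10, rank, 4) for rank in range(4)]
print(shards)
```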
composer.datasets.StreamingDataset is an IterableDataset, so a DistributedSampler is not supported: IterableDatasets need to handle multi-worker training internally. See the IterableDataset [docs](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset) for more information.
FullyShardedDataParallel (FSDP)#
Composer integrates PyTorch's FullyShardedDataParallel engine with some syntactic sugar to make it easy to write custom models that work with Composer + FSDP.
At a high level, when you use the Composer Trainer, you must pass it a ComposerModel like ComposerGPT that defines certain functions like forward, eval_forward, loss, etc. that are called during the training loop.
Inside that ComposerModel you may have one or many submodules, such as a .model or .language_model or .classifier, that is the actual torch.nn.Module that you will be deploying at inference time. In our case, this is the GPT module that we build and attach as ComposerGPT.model.
When you provide a parallelism_config={'fsdp': {...}} dictionary to the Composer Trainer, then on __init__, the Trainer will attempt to wrap each of the submodules of your ComposerModel with an FSDP auto wrap policy. This wrapping is recursive, so not only is GPT wrapped, but all submodules of GPT may or may not be wrapped too. See the FSDP documentation for more details on how auto wrap policies work.
The full spec and defaults for Composer's FSDP config are:
fsdp_config = {
'activation_checkpointing': bool = True | False, # Default: False
'activation_checkpointing_reentrant': bool = True | False, # Default: True
'activation_cpu_offload': bool = True | False, # Default: False
'backward_prefetch': str = 'BACKWARD_PRE' | 'BACKWARD_POST' | 'NONE', # Default: 'BACKWARD_POST'
'cpu_offload': bool = True | False, # Default: False, cpu_offload not supported yet
'data_parallel_shard_degree': int = -1, # Default: -1
'data_parallel_replicate_degree': int = 1, # Default: 1
'forward_prefetch': bool = True | False, # Default: False
'ignored_modules': Optional[Iterable[torch.nn.Module]], # Default: None
'keep_low_precision_grads': bool = True | False, # Default: False
'limit_all_gathers': bool = True | False, # Default: False
'load_monolith_rank0_only': bool = True | False, # Default: False
'load_planner': torch.distributed.checkpoint.planner.LoadPlanner, # Default: None
'mixed_precision': str = 'FULL' | 'DEFAULT' | 'PURE', # Default: 'DEFAULT'
# Note: you can explicitly provide a dictionary too
# 'mixed_precision': dict = {
# 'param_dtype': 'fp32' | 'fp16' | 'bf16',
# 'reduce_dtype': 'fp32' | 'fp16' | 'bf16',
# 'buffer_dtype': 'fp32' | 'fp16' | 'bf16',
# },
'save_planner': torch.distributed.checkpoint.planner.SavePlanner, # Default: None
'sharded_ckpt_prefix_dir': str = 'ep{epoch}-ba{batch}', # Default: 'ep{epoch}-ba{batch}'
'sharding_strategy': str = 'FULL_SHARD' | 'SHARD_GRAD_OP' | 'NO_SHARD', # Default: 'FULL_SHARD'
'state_dict_type': str = 'full' | 'local' | 'sharded', # Default: full
'sync_module_states': bool = True | False, # Default: False
'use_orig_params': bool = True | False, # Default: True
'verbose': bool = True | False, # Default: False
}
All values come with defaults and can be optionally defined in the fsdp_config. Most parameters map directly to parameters in the FSDP documentation.
This config is passed under parallelism_config['fsdp'] to the Composer Trainer. Two important parameters that do not map directly are data_parallel_shard_degree, which dictates the number of devices to shard across, and data_parallel_replicate_degree, which dictates the number of devices to replicate across.
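One way to reason about the two degrees is sketched below. It rests on two assumptions that this page does not state explicitly: that a shard degree of -1 means "infer it from the world size", and that the shard and replicate degrees must multiply to the world size. `resolve_degrees` is a hypothetical helper, not Composer's API.

```python
def resolve_degrees(world_size: int, shard_degree: int = -1, replicate_degree: int = 1):
    """Return (replicate_degree, shard_degree) for a given world size.

    Assumptions (not confirmed by the text above): -1 means "infer from the
    world size", and the two degrees must multiply to the world size.
    """
    if replicate_degree < 1 or world_size % replicate_degree != 0:
        raise ValueError('replicate_degree must evenly divide world_size')
    if shard_degree == -1:
        shard_degree = world_size // replicate_degree
    if shard_degree * replicate_degree != world_size:
        raise ValueError('shard_degree * replicate_degree must equal world_size')
    return replicate_degree, shard_degree


# Defaults on 8 devices: one replica, sharded across all 8 devices.
default_layout = resolve_degrees(8)

# 16 devices with 2 replicas: each replica is sharded across 8 devices.
hybrid_layout = resolve_degrees(16, shard_degree=-1, replicate_degree=2)
print(default_layout, hybrid_layout)
```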
One Composer-specific pattern is that if mixed_precision is provided as a str, then we automatically infer the settings to use from the Trainer's precision, which is already being used for autocast, and we construct an associated MixedPrecision object for FSDP:
# If mixed_precision = 'full'
mixed_precision = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
# If mixed_precision = 'default'; emulates automatic mixed precision training.
mixed_precision = MixedPrecision(
param_dtype=autocast_precision, # Master weights stored in fp32 but are downcast to autocast_precision before the dist all_gather
reduce_dtype=torch.float32, # Gradient dist all_reduce in fp32
buffer_dtype=autocast_precision, # Buffers stored in fp32 but are downcast to autocast_precision before the dist all_gather
)
# If mixed_precision = 'pure'
mixed_precision = MixedPrecision(
param_dtype=autocast_precision, # Master weights stored in fp32 but are downcast to autocast_precision before the dist all_gather
reduce_dtype=autocast_precision, # Gradient dist all_reduce in autocast_precision
buffer_dtype=autocast_precision, # Buffers stored in fp32 but are downcast to autocast_precision before the dist all_gather
)
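The three cases above amount to a small lookup. The sketch below restates them as a function that returns dtype names as strings instead of constructing a real MixedPrecision object; `resolve_mixed_precision` is a hypothetical helper, and the case-insensitive matching is an assumption.

```python
def resolve_mixed_precision(mode: str, autocast_precision: str) -> dict:
    """Map a mixed_precision string to the three dtype settings (sketch).

    `autocast_precision` stands in for the dtype implied by the Trainer's
    precision, e.g. 'bf16' or 'fp16'.
    """
    mode = mode.upper()  # assumption: the string is matched case-insensitively
    if mode == 'FULL':
        return {'param_dtype': 'fp32', 'reduce_dtype': 'fp32', 'buffer_dtype': 'fp32'}
    if mode == 'DEFAULT':
        # Low-precision params and buffers, but gradients are reduced in fp32.
        return {'param_dtype': autocast_precision, 'reduce_dtype': 'fp32', 'buffer_dtype': autocast_precision}
    if mode == 'PURE':
        return {'param_dtype': autocast_precision, 'reduce_dtype': autocast_precision, 'buffer_dtype': autocast_precision}
    raise ValueError(f'unknown mixed_precision mode: {mode!r}')


default_settings = resolve_mixed_precision('DEFAULT', 'bf16')
pure_settings = resolve_mixed_precision('PURE', 'bf16')
print(default_settings, pure_settings)
```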
An example code snippet for using FSDP with Composer is provided below:
import torch.nn as nn
from composer import Trainer
from composer.models import ComposerModel

class Block(nn.Module):
    ...

class Model(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(...) for _ in range(n_layers)
        ])
        self.head = nn.Linear(...)

    def forward(self, inputs):
        ...

    # FSDP Wrap Function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, Block)

    # Activation Checkpointing Function
    def activation_checkpointing_fn(self, module):
        return isinstance(module, Block)

class MyComposerModel(ComposerModel):
    def __init__(self, n_layers):
        super().__init__()
        self.model = Model(n_layers)
        ...

    def forward(self, batch):
        ...

    def eval_forward(self, batch, outputs=None):
        ...

    def loss(self, outputs, batch):
        ...

    ...

composer_model = MyComposerModel(n_layers=3)

fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'cpu_offload': False,  # Not supported yet
    'mixed_precision': 'DEFAULT',
    'backward_prefetch': 'BACKWARD_POST',
    'activation_checkpointing': False,
    'activation_cpu_offload': False,
    'verbose': True
}

trainer = Trainer(
    model=composer_model,
    parallelism_config={'fsdp': fsdp_config},
    ...
)

trainer.fit()
Warning
As of now, we don't support CPU offloading for FSDP.
Warning
As of now, default parameters might not provide optimal convergence. Please proceed with caution.
Composer's FSDP Auto Wrap Policy#
To make auto-wrapping easier on users, Composer uses a custom auto wrap policy that wraps modules according to the following rules:
1. If any module is attributed with module._fsdp_wrap = True | False, that choice will be respected.
2. If the root module (e.g. GPT) defines a function def fsdp_wrap_fn(module: torch.nn.Module) -> bool, then that function will be used to evaluate the root module's children.
These rules are meant to make it easy for users to modify existing models for usage with FSDP. You can either add attributes to modules you want to wrap (#1) or define a filter (#2).
In gpt.py, you can see that we used rule #2 to specify that all GPTBlock modules within GPT should be wrapped. Alternatively, we could have easily attributed each of the blocks with block._fsdp_wrap = True and it would have accomplished the same thing. Whichever style you prefer, it's up to you!
A very similar auto wrap policy is provided for activation checkpointing, with analogous rule #1 that looks for module._activation_checkpointing = True | False and rule #2 that looks for def activation_checkpointing_fn(module: torch.nn.Module) -> bool.
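The two wrapping rules can be sketched as a single decision function. The classes below are minimal stand-ins for torch.nn.Module so the example runs without torch; `should_wrap` is a hypothetical name, and returning False when neither rule applies is an assumption. The sketch only covers boolean results.

```python
class Module:
    """Minimal stand-in for torch.nn.Module, used only for this sketch."""


class Block(Module):
    pass


class Head(Module):
    pass


class Root(Module):
    # Rule #2: the root module defines a filter for its children.
    def fsdp_wrap_fn(self, module):
        return isinstance(module, Block)


def should_wrap(root, module) -> bool:
    """Decide whether `module` gets wrapped (hypothetical helper)."""
    # Rule #1: an explicit attribute on the module always wins.
    explicit = getattr(module, '_fsdp_wrap', None)
    if explicit is not None:
        return bool(explicit)
    # Rule #2: otherwise ask the root module's filter, if it defines one.
    wrap_fn = getattr(root, 'fsdp_wrap_fn', None)
    if callable(wrap_fn):
        return bool(wrap_fn(module))
    # Assumption: modules matched by neither rule are not wrapped.
    return False


root = Root()
plain_block = Block()
plain_head = Head()
opted_in_head = Head()
opted_in_head._fsdp_wrap = True
opted_out_block = Block()
opted_out_block._fsdp_wrap = False

decisions = [should_wrap(root, m) for m in (plain_block, plain_head, opted_in_head, opted_out_block)]
print(decisions)
```

The same shape of check would apply to activation checkpointing by swapping in the _activation_checkpointing attribute and the activation_checkpointing_fn filter.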
Experimental: Composer enables users to specify custom FSDP args for all wrapped modules. This is enabled by returning a dictionary of args instead of returning a bool.
import torch.nn as nn

class Block(nn.Module):
    ...

class BlockRequiringCustomArgs(nn.Module):
    ...

class Model(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(...) for _ in range(n_layers)
        ])
        self.custom_arg_blocks = nn.ModuleList([
            BlockRequiringCustomArgs(...) for _ in range(n_layers)
        ])
        self.head = nn.Linear(...)

    def forward(self, inputs):
        ...

    # FSDP Wrap function
    def fsdp_wrap_fn(self, module):
        if isinstance(module, Block):
            return True
        # extends FSDP wrapping to custom args
        if isinstance(module, BlockRequiringCustomArgs):
            return {
                'process_group': 'node',
                'mixed_precision': 'FULL',
            }
        # default to False
        return False

    # Activation Checkpointing Function
    def activation_checkpointing_fn(self, module):
        return isinstance(module, Block)
While the user can instantiate and pass in process groups, Composer enables process groups to be specified using the following options:
- 'self': the degenerate case where all process groups only operate within their current rank ('self' == 'set1'). This is useful when you do not want a layer to be synchronized across accelerators.
- 'node': instantiates process groups which operate within a node ('node' == f'set{local_world_size}'). This is useful for Expert Layers in MoE models.
- 'local_rank_across_nodes': instantiates process groups with the same local rank across all nodes ('local_rank_across_nodes' == f'mod{local_world_size}'). This is useful for Tensor Parallel Layers.
- 'setK': (K is an integer where world_size must be divisible by K) instantiates process groups which operate within a set of K GPUs. This is useful for Expert Layers in MoE models.
- 'modK': (K is an integer where world_size must be divisible by K) instantiates process groups which operate on every Kth GPU. This is useful for Tensor Parallel Layers.
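To illustrate which ranks end up in the same group under each option, here is a small sketch that computes group membership for a given rank. It follows the descriptions above ('setK' groups K consecutive ranks, 'modK' groups every Kth rank) but is a hypothetical helper, not Composer's implementation.

```python
def group_ranks(spec: str, rank: int, world_size: int, local_world_size: int) -> list:
    """Return the ranks sharing a process group with `rank` (sketch)."""
    aliases = {
        'self': 'set1',
        'node': f'set{local_world_size}',
        'local_rank_across_nodes': f'mod{local_world_size}',
    }
    spec = aliases.get(spec, spec)
    kind, size = spec[:3], spec[3:]
    if kind not in ('set', 'mod') or not size.isdigit():
        raise ValueError(f'unknown process group spec: {spec!r}')
    k = int(size)
    if k == 0 or world_size % k != 0:
        raise ValueError('world_size must be divisible by K')
    if kind == 'set':
        # K consecutive ranks, e.g. all ranks on one node.
        start = (rank // k) * k
        return list(range(start, start + k))
    # 'mod': every Kth rank, i.e. all ranks with the same remainder modulo K.
    return list(range(rank % k, world_size, k))


# Rank 10 in a 16-process job with 8 processes per node.
same_node = group_ranks('node', 10, 16, 8)
same_local_rank = group_ranks('local_rank_across_nodes', 10, 16, 8)
only_self = group_ranks('self', 10, 16, 8)
every_fourth = group_ranks('mod4', 10, 16, 8)
print(same_node, same_local_rank, only_self, every_fourth)
```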
Saving and Loading Sharded Checkpoints with FSDP#
To save and load sharded checkpoints with FSDP, you can make use of the state_dict_type field in fsdp_config.
Depending on the value you set for state_dict_type, you can get different checkpointing behavior:
1. state_dict_type='full'
The default. Saves one big checkpoint file for the whole model.
It does this by gathering the model state to the global rank 0 device, unflattening it, and then saving it out.
If load_monolith_rank0_only=True, then when loading checkpoints the global rank 0 device will load the checkpoint file and scatter the model and optimizer state to the other ranks, which will dramatically reduce memory usage on the system. Otherwise, all ranks will separately load the checkpoint file.
2. state_dict_type='sharded'
Each rank saves out an unflattened shard. For loading, each rank loads the checkpoint file corresponding to its unflattened shard.
Note: state_dict_type='sharded' is the recommended setting for sharded checkpointing in Composer for torch versions 2.0.0 or higher.
See the FSDP docs for more info.
If you use sharded checkpoints (state_dict_type='sharded'), your run will save as many files as you have ranks at each checkpointing event (plus one metadata file for torch versions 2.0.0 or higher). This can quickly pollute your save_folder with a lot of files after a couple of checkpointing events. To help keep your checkpoint shard files organized, Composer will save each set of shards in its own prefix directory, which you can configure by using 'sharded_ckpt_prefix_dir' (default value sharded_ckpt_prefix_dir='ep{epoch}-ba{batch}'). Checkpoint shards will be saved to {save_folder} / {sharded_ckpt_prefix_dir}.
For example, to save sharded checkpoints to disk locally (state_dict_type='sharded') with FSDP on PyTorch version 2.0.0 and higher, you can do:
import torch.nn as nn
from composer import Trainer
from composer.models import ComposerModel

class Block(nn.Module):
    ...

class Model(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(...) for _ in range(n_layers)
        ])
        self.head = nn.Linear(...)

    def forward(self, inputs):
        ...

    # FSDP Wrap Function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, Block)

class MyComposerModel(ComposerModel):
    def __init__(self, n_layers):
        super().__init__()
        self.model = Model(n_layers)
        ...

    def forward(self, batch):
        ...

    def eval_forward(self, batch, outputs=None):
        ...

    def loss(self, outputs, batch):
        ...

    ...

composer_model = MyComposerModel(n_layers=3)

fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'state_dict_type': 'sharded',
    'sharded_ckpt_prefix_dir': 'ba{batch}-shards'  # will save each set of checkpoint shards to a unique folder based on batch
}

trainer = Trainer(
    model=composer_model,
    max_duration='4ba',
    parallelism_config={'fsdp': fsdp_config},
    save_folder='checkpoints',
    save_interval='2ba',
    ...
)

trainer.fit()
After the second batch, this code will save N+1 checkpoint files to the local directory ./checkpoints/ba2-shards. For example, if you trained with 4 ranks, ./checkpoints/ba2-shards would contain 5 files: a metadata file, .metadata, and 4 checkpoint files, one for each rank: __0_0.distcp, __1_0.distcp, __2_0.distcp, and __3_0.distcp.
After the fourth batch, N+1 checkpoint files (.metadata, __0_0.distcp, __1_0.distcp, etc.) will be saved to ./checkpoints/ba4-shards.
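The resulting directory layout can be sketched with a short helper that lists the files expected after one checkpointing event. It simply follows the naming pattern described above (one .metadata file plus one __{rank}_0.distcp file per rank); `shard_checkpoint_files` is a hypothetical function and does not touch the filesystem.

```python
from pathlib import PurePosixPath


def shard_checkpoint_files(save_folder: str, prefix_dir: str, n_ranks: int, epoch: int, batch: int) -> list:
    """List the files expected from one sharded checkpoint save (sketch)."""
    # Fill in the prefix directory template, e.g. 'ba{batch}-shards' -> 'ba2-shards'.
    folder = PurePosixPath(save_folder) / prefix_dir.format(epoch=epoch, batch=batch)
    names = ['.metadata'] + [f'__{rank}_0.distcp' for rank in range(n_ranks)]
    return [str(folder / name) for name in names]


# 4 ranks, checkpoint taken after batch 2 with the prefix used in the example.
files = shard_checkpoint_files('checkpoints', 'ba{batch}-shards', n_ranks=4, epoch=0, batch=2)
print(len(files), files[0], files[-1])
```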
To load these checkpoint files, you would need to do something like this:
from composer import Trainer

fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'state_dict_type': 'sharded',
}

trainer = Trainer(
    model=composer_model,
    max_duration='4ba',
    parallelism_config={'fsdp': fsdp_config},
    load_path='./checkpoints/ba2-shards',  # load_path must be the path to the prefix directory and not to a specific file.
    ...
)
Four things to note in this load example:
1. Instead of setting load_path to the path to a specific file, we set it to the directory which contains all the checkpoint files.
2. We must set 'state_dict_type': 'sharded', like we did during the save.
3. Composer with PyTorch version 2.0.0 and higher does support elastic checkpointing (more ranks than checkpoint files or more files than ranks), so you can resume on a different number of ranks than you saved on.
4. To do multinode resumption (resuming on more than one node regardless of how many nodes you saved on), you must be using torch 2.0.1 or higher due to a bug in torch 2.0.0.
Tensor Parallel (TP)#
Composer integrates PyTorch's Tensor Parallel API with some syntactic sugar to make it easy to write custom models that work with Composer + TP.
To enable Tensor Parallel, a tensor parallel config must be passed to the Composer Trainer. The full spec and defaults for Composer's tensor parallelism config are:
tp_config = {
'tensor_parallel_degree': int = 1, # Default: 1
'layer_plan': dict = None, # Default: None, maps to torch's `parallelize_plan`
}
All values come with defaults and can be optionally defined in the tp_config. Most parameters map directly to parameters in the Tensor Parallel documentation.
This config is passed under parallelism_config['tp'] to the Composer Trainer. Important parameters which do not directly map include tensor_parallel_degree, which dictates the number of devices to shard across, and layer_plan, which simply corresponds to torch's parallelize_plan.
An example code snippet for using TP and FSDP with Composer is provided below:
import torch.nn as nn
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from composer import Trainer
from composer.models import ComposerModel

class Block(nn.Module):
    ...

class Model(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(...) for _ in range(n_layers)
        ])
        self.head = nn.Linear(...)

    def forward(self, inputs):
        ...

    # FSDP Wrap Function
    def fsdp_wrap_fn(self, module):
        return isinstance(module, Block)

    # Activation Checkpointing Function
    def activation_checkpointing_fn(self, module):
        return isinstance(module, Block)

class MyComposerModel(ComposerModel):
    def __init__(self, n_layers):
        super().__init__()
        self.model = Model(n_layers)
        ...

    def forward(self, batch):
        ...

    def eval_forward(self, batch, outputs=None):
        ...

    def loss(self, outputs, batch):
        ...

    ...

composer_model = MyComposerModel(n_layers=3)

fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'cpu_offload': False,  # Not supported yet
    'mixed_precision': 'DEFAULT',
    'backward_prefetch': 'BACKWARD_POST',
    'activation_checkpointing': False,
    'activation_cpu_offload': False,
    'verbose': True
}

tp_config = {
    'tensor_parallel_degree': 2,
    'layer_plan': {
        'model.0.fc1': ColwiseParallel(),
        'model.0.fc2': RowwiseParallel(),
    }
}

trainer = Trainer(
    model=composer_model,
    parallelism_config={
        'fsdp': fsdp_config,
        'tp': tp_config,
    },
    ...
)

trainer.fit()
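For intuition about pairing a column-wise layer with a row-wise layer, the following plain-Python sketch splits two weight matrices across two simulated devices and checks that summing the per-device partial outputs reproduces the unsharded result. It uses nested lists instead of tensors, omits biases and nonlinearities, and is not the torch Tensor Parallel API; all function names are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_cols(w, parts):
    """Column-wise sharding: each device keeps a slice of the output features."""
    step = len(w[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]


def split_rows(w, parts):
    """Row-wise sharding: each device keeps a slice of the input features."""
    step = len(w) // parts
    return [w[p * step:(p + 1) * step] for p in range(parts)]


def add(a, b):
    """Element-wise sum, standing in for an all-reduce across devices."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


x = [[1, 2]]                               # 1 x 2 input
w1 = [[1, 0, 2, 1], [0, 1, 1, 3]]          # 2 x 4, sharded by columns
w2 = [[1, 0], [0, 1], [2, 1], [1, 1]]      # 4 x 2, sharded by rows

reference = matmul(matmul(x, w1), w2)      # unsharded computation

# Each simulated device multiplies by its own shard of w1, then its shard of w2.
partials = [matmul(matmul(x, w1_shard), w2_shard)
            for w1_shard, w2_shard in zip(split_cols(w1, 2), split_rows(w2, 2))]
tp_output = add(partials[0], partials[1])
```

Because the column-wise shard of the first layer produces exactly the features that the matching row-wise shard of the second layer consumes, no communication is needed between the two layers; only the final sum is shared.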
Note
This is an experimental feature and is subject to change. Many features, such as load_monolith_rank0_only or tensor parallelism without FSDP, are not yet supported.