# Copyright 2021 MosaicML. All Rights Reserved.
"""Helpers for running distributed data parallel training."""
from contextlib import contextmanager, nullcontext
from typing import Callable, ContextManager, Union, cast
import torch.nn
from torch.nn.parallel import DistributedDataParallel
from composer.core.state import State
from composer.utils import dist
from composer.utils.string_enum import StringEnum
# Only the strategy enum is exported; the underscore-prefixed helpers in this
# module are internal.
__all__ = ["DDPSyncStrategy"]
class DDPSyncStrategy(StringEnum):
    """How and when DDP gradient synchronization should happen.

    Attributes:
        SINGLE_AUTO_SYNC: The default behavior for DDP. Gradients are synchronized as they
            are computed, for only the final microbatch of a batch. This is the most efficient
            strategy, but can lead to errors when ``find_unused_parameters`` is set, since
            it is possible that different microbatches use different sets of parameters,
            leading to an incomplete sync.
        MULTI_AUTO_SYNC: The default behavior for DDP when ``find_unused_parameters`` is set.
            Gradients are synchronized as they are computed for all microbatches. This ensures
            complete synchronization, but is less efficient than :attr:`SINGLE_AUTO_SYNC`. This
            efficiency gap is usually small, as long as either DDP syncs are a small portion
            of the trainer's overall runtime, or the number of microbatches per batch is
            relatively small.
        FORCED_SYNC: Gradients are manually synchronized only after all gradients have been
            computed for the final microbatch of a batch. Like :attr:`MULTI_AUTO_SYNC`, this
            strategy ensures complete gradient synchronization, but this tends to be slower than
            :attr:`MULTI_AUTO_SYNC`. This is because ordinarily syncs can happen in parallel
            with the ``loss.backward()`` computation, meaning syncs can be mostly complete by
            the time that function finishes. However, in certain circumstances, syncs may take
            a very long time to complete - if there are also a lot of microbatches per batch,
            this strategy may be optimal.
    """

    SINGLE_AUTO_SYNC = "single_auto_sync"
    MULTI_AUTO_SYNC = "multi_auto_sync"
    FORCED_SYNC = "forced_sync"
@contextmanager
def _ddp_sync_context(state: State, is_final_microbatch: bool, sync_strategy: Union[str, DDPSyncStrategy]):
    """Context manager that applies a :class:`DDPSyncStrategy` around a microbatch.

    When ``state.model`` is not a
    :class:`torch.nn.parallel.DistributedDataParallel` instance, the context
    does nothing and simply yields.

    Args:
        state (State): The state of the :class:`~composer.trainer.trainer.Trainer`.
            Its ``model`` and ``optimizers`` attributes are read.
        is_final_microbatch (bool): Whether or not the context is being used during the
            final microbatch of the gradient accumulation steps.
        sync_strategy (str or DDPSyncStrategy): The ddp sync strategy to use. If a string
            is provided, the string must be one of the values in :class:`DDPSyncStrategy`.

    Raises:
        AssertionError: If the model is DDP-wrapped but ``state.optimizers`` is ``None``.
        ValueError: If the resolved strategy matches none of the known members.
    """
    # Nothing to coordinate unless the model is DDP-wrapped.
    if not isinstance(state.model, DistributedDataParallel):
        yield
        return

    assert state.optimizers is not None, "optimizers have not been initialized"

    strategy = DDPSyncStrategy(sync_strategy)
    # DDP's ``no_sync`` suppresses gradient synchronization inside its scope.
    suppress_sync = cast(Callable[[], ContextManager], state.model.no_sync)
    # "Auto sync" means leaving DDP's normal behavior alone.
    default_sync = nullcontext

    if strategy == DDPSyncStrategy.SINGLE_AUTO_SYNC:
        # Only the final microbatch is allowed to synchronize.
        if is_final_microbatch:
            chosen_context = default_sync
        else:
            chosen_context = suppress_sync
        with chosen_context():
            yield
        return

    if strategy == DDPSyncStrategy.MULTI_AUTO_SYNC:
        # Every microbatch synchronizes as usual.
        with default_sync():
            yield
        return

    if strategy == DDPSyncStrategy.FORCED_SYNC:
        try:
            with suppress_sync():
                yield
        finally:
            # Runs even if the wrapped code raised: after the final microbatch,
            # reduce each existing gradient by hand and divide by world size.
            # NOTE(review): presumably ``dist.all_reduce`` sums by default, making
            # this an average - confirm against ``composer.utils.dist``.
            if is_final_microbatch:
                for optim in state.optimizers:
                    for param_group in optim.param_groups:
                        for param in param_group["params"]:
                            if param.grad is None:
                                continue
                            dist.all_reduce(param.grad)
                            param.grad = param.grad / dist.get_world_size()
        return

    raise ValueError("Unknown sync strategy", strategy)
def _prepare_ddp_module(module: torch.nn.Module, find_unused_parameters: bool) -> torch.nn.Module:
    """Wraps ``module`` in :class:`torch.nn.parallel.DistributedDataParallel` when distributed training is set up.

    Args:
        module (torch.nn.Module): The module to wrap.
        find_unused_parameters (bool): Whether or not to do a pass over the autograd graph
            to find parameters to not expect gradients for. This is useful if there are some
            parameters in the model that are not being trained.

    Returns:
        torch.nn.Module: The DDP-wrapped module if at least one parameter has
        ``requires_grad`` set; otherwise ``module`` itself, unwrapped.

    Raises:
        RuntimeError: If ``dist`` reports that distributed support is unavailable, or
            that it is available but has not been initialized.
    """
    if not dist.is_available():
        raise RuntimeError("When the world size is > 1, ``torch.distributed`` must be used. However, it is "
                           "not available in your installation of PyTorch. Please install or build PyTorch "
                           "with distributed support.")

    if not dist.is_initialized():
        raise RuntimeError("Please call dist.initialize_dist() before calling ddp.prepare_module()")

    # A module with no trainable parameters is returned as-is rather than wrapped.
    has_trainable_param = any(param.requires_grad for param in module.parameters())
    if not has_trainable_param:
        return module

    return DistributedDataParallel(module, find_unused_parameters=find_unused_parameters)