# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Helper methods for :mod:`torch.distributed`.
To use :mod:`torch.distributed`, launch your training script with the
:ref:`composer launcher for distributed training <distributed-training>`. For example,
the following command launches an eight-process training run.
.. code-block::
composer -n 8 path/to/train.py
The composer launcher will automatically configure the following environment variables, which are
required for distributed training:
* ``RANK``: The global rank of the process, which should be in ``[0; WORLD_SIZE - 1]``.
* ``LOCAL_RANK``: The local rank for the process, which should be in ``[0; LOCAL_WORLD_SIZE - 1]``.
* ``NODE_RANK``: The rank of the node.
* ``WORLD_SIZE``: The total number of processes.
* ``LOCAL_WORLD_SIZE``: The number of processes on the current node.
* ``MASTER_ADDR``: The hostname for the rank-zero process.
* ``MASTER_PORT``: The port for the rank-zero process.
If none of these environment variables are set, this module will safely assume a single-rank configuration, where::
RANK=0
LOCAL_RANK=0
NODE_RANK=0
WORLD_SIZE=1
LOCAL_WORLD_SIZE=1
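Under that single-rank fallback, the helpers in this module simply return the defaults above. A minimal
sketch (illustrative only; assumes none of the distributed environment variables are set):
.. code-block:: python

    from composer.utils import dist

    assert dist.get_world_size() == 1
    assert dist.get_global_rank() == 0
    assert dist.get_local_rank() == 0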
"""
from __future__ import annotations
import datetime
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast
import torch
import torch.distributed as dist
import torch.utils.data
from composer.utils.device import get_device
if TYPE_CHECKING:
from composer.devices import Device
TObj = TypeVar('TObj')
__all__ = [
'all_gather',
'all_gather_object',
'all_reduce',
'barrier',
'broadcast',
'broadcast_object_list',
'get_global_rank',
'get_local_rank',
'get_local_world_size',
'get_node_rank',
'get_sampler',
'get_world_size',
'initialize_dist',
'is_available',
'is_initialized',
]
log = logging.getLogger(__name__)
class MissingEnvironmentError(Exception):
pass
def _get_distributed_config_var(
env_var: str,
human_name: str,
default: int,
fetch_fn_name: Optional[str] = None,
) -> int:
if not dist.is_available():
return default
if dist.is_initialized() and fetch_fn_name is not None:
dist_value = int(getattr(dist, fetch_fn_name)())
if env_var in os.environ:
env_value = int(os.environ[env_var])
if dist_value != env_value:
raise RuntimeError('Torch distributed has been initialized with a value of '
f'{dist_value} for {human_name}, but environment variable '
f'{env_var} has value {env_value}.')
return dist_value
if env_var in os.environ:
return int(os.environ[env_var])
if dist.is_initialized():
raise MissingEnvironmentError('Torch distributed is initialized but environment variable '
f'{env_var} is not set.')
return default
def get_world_size() -> int:
"""Returns the world size, which is the number of processes participating in this training run.
Returns:
int: The world size.
"""
return _get_distributed_config_var(env_var='WORLD_SIZE',
human_name='world size',
default=1,
fetch_fn_name='get_world_size')
def get_global_rank() -> int:
"""Returns the global rank of the current process, which is on ``[0; WORLD_SIZE - 1]``.
Returns:
int: The global rank.
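A common pattern is to restrict work such as logging to the global rank zero process. A minimal sketch
(the print call is purely illustrative):
.. code-block:: python

    from composer.utils import dist

    if dist.get_global_rank() == 0:
        print('Only the global rank zero process runs this.')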
"""
return _get_distributed_config_var(env_var='RANK', human_name='global rank', default=0, fetch_fn_name='get_rank')
def get_local_world_size() -> int:
"""Returns the local world size, which is the number of processes for the current node.
Returns:
int: The local world size.
"""
return _get_distributed_config_var(env_var='LOCAL_WORLD_SIZE', default=1, human_name='local world size')
def get_local_rank() -> int:
"""Returns the local rank for the current process, which is on ``[0; LOCAL_WORLD_SIZE - 1]``.
Returns:
int: The local rank.
"""
return _get_distributed_config_var(env_var='LOCAL_RANK', default=0, human_name='local rank')
def get_node_rank() -> int:
"""Returns the node rank.
For example, if there are 2 nodes, and 2 ranks per node, then global ranks 0-1 will have a
node rank of 0, and global ranks 2-3 will have a node rank of 1.
Returns:
int: The node rank, starting at 0.
"""
return _get_distributed_config_var(env_var='NODE_RANK', default=0, human_name='node rank')
def barrier() -> None:
"""Synchronizes all processes.
This function blocks until all processes reach this function.
.. seealso:: :func:`torch.distributed.barrier`
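For example, a barrier can ensure that setup performed only on rank zero finishes before any other rank
proceeds. A minimal sketch (``prepare_shared_resource`` is a hypothetical rank-zero-only setup step; assumes
the process group has been initialized, e.g. via :func:`initialize_dist`):
.. code-block:: python

    from composer.utils import dist

    if dist.get_global_rank() == 0:
        prepare_shared_resource()  # hypothetical setup done only once
    dist.barrier()  # every rank blocks here until all ranks, including rank zero, arrive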
"""
if dist.is_available() and dist.is_initialized():
dist.barrier()
return
world_size = get_world_size()
if world_size == 1:
return
raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
'available or has not been initialized. Please check you have initialized '
'the distributed runtime and that PyTorch has been built with distributed '
'support. If calling this function outside Trainer, please ensure that '
'`composer.utils.dist.initialize_dist` has been called first.')
def all_reduce(
tensor: torch.Tensor,
reduce_operation: str = 'SUM',
) -> None:
"""Reduce a ``tensor`` by applying the ``reduce_operation``.
All ranks get the same, bitwise-identical result.
.. seealso:: :func:`torch.distributed.all_reduce`
Args:
tensor (torch.Tensor): Input and output of the collective. The function operates in-place.
reduce_operation (str, optional): The reduction operation (default: ``SUM``).
Valid options are:
* ``SUM``
* ``PRODUCT``
* ``MIN``
* ``MAX``
* ``BAND``
* ``BOR``
* ``BXOR``
Returns:
None: ``tensor`` is modified in-place.
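A typical use is summing a per-rank scalar to compute a global metric. A minimal sketch (``local_loss``
is an illustrative per-rank value; assumes the process group has been initialized):
.. code-block:: python

    import torch
    from composer.utils import dist

    loss = torch.tensor([local_loss])
    dist.all_reduce(loss, reduce_operation='SUM')  # loss now holds the sum over all ranks
    mean_loss = loss.item() / dist.get_world_size()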
"""
if dist.is_available() and dist.is_initialized():
reduce_op = getattr(dist.ReduceOp, reduce_operation.upper())
dist.all_reduce(tensor, op=reduce_op)
return
world_size = get_world_size()
if world_size == 1:
return
raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
'available or has not been initialized. Please check you have initialized '
'the distributed runtime and that PyTorch has been built with distributed '
'support. If calling this function outside Trainer, please ensure that '
'`composer.utils.dist.initialize_dist` has been called first.')
def broadcast(tensor: torch.Tensor, src: int) -> None:
"""Broadcasts the tensor to the whole group.
``tensor`` must have the same number of elements in all processes participating in the collective.
See :func:`torch.distributed.broadcast`.
Args:
tensor (torch.Tensor): Data to be sent if ``src`` is the rank of current process,
and tensor to be used to save received data otherwise.
src (int): Source rank
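A minimal sketch (assumes the process group has been initialized; the broadcast value is illustrative):
.. code-block:: python

    import torch
    from composer.utils import dist

    # Every rank allocates a tensor of the same shape; rank 0 holds the values to send.
    scale = torch.tensor([1.5]) if dist.get_global_rank() == 0 else torch.zeros(1)
    dist.broadcast(scale, src=0)  # afterwards, every rank holds tensor([1.5])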
"""
if dist.is_available() and dist.is_initialized():
dist.broadcast(tensor, src)
return
world_size = get_world_size()
if world_size == 1:
return
raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
'available or has not been initialized. Please check you have initialized '
'the distributed runtime and that PyTorch has been built with distributed '
'support. If calling this function outside Trainer, please ensure that '
'`composer.utils.dist.initialize_dist` has been called first.')
def broadcast_object_list(object_list: List[Any], src: int = 0) -> None:
"""Broadcasts picklable objects in ``object_list`` to the whole group.
Similar to :func:`broadcast`, but Python objects can be passed in.
Note that all objects in ``object_list`` must be picklable in order to be broadcasted.
.. seealso:: :func:`torch.distributed.broadcast_object_list`.
Args:
object_list (List[Any]): List of input objects to broadcast.
Each object must be picklable. Only objects on the ``src`` rank will be broadcast,
but each rank must provide lists of equal sizes.
src (int, optional): Source rank (default: ``0``)
Returns:
None: ``object_list`` will be modified in-place and set to values of ``object_list`` from the ``src`` rank.
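This is handy for sharing small Python objects such as a configuration dict. A minimal sketch (the dict
contents are illustrative; assumes the process group has been initialized):
.. code-block:: python

    from composer.utils import dist

    # Each rank supplies a list of the same length; only rank 0's contents are sent.
    obj_list = [{'seed': 42}] if dist.get_global_rank() == 0 else [None]
    dist.broadcast_object_list(obj_list, src=0)
    config = obj_list[0]  # every rank now sees {'seed': 42}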
"""
if dist.is_available() and dist.is_initialized():
dist.broadcast_object_list(object_list, src)
# dist.broadcast_object_list modifies object_list in-place, replacing each element on
# non-src ranks with the corresponding object received from the src rank
return
world_size = get_world_size()
if world_size == 1:
return
raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
'available or has not been initialized. Please check you have initialized '
'the distributed runtime and that PyTorch has been built with distributed '
'support. If calling this function outside Trainer, please ensure that '
'`composer.utils.dist.initialize_dist` has been called first.')
def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]:
"""Collects a :class:`~torch.Tensor` from each rank.
.. seealso:: :func:`torch.distributed.all_gather`
Args:
tensor (torch.Tensor): Tensor from each rank to be gathered.
Returns:
Sequence[Tensor]: A sequence of tensors indexed by rank.
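A minimal sketch (assumes the process group has been initialized and that every rank passes a tensor of
the same shape):
.. code-block:: python

    import torch
    from composer.utils import dist

    rank_tensor = torch.tensor([float(dist.get_global_rank())])
    gathered = dist.all_gather(rank_tensor)
    # gathered[i] is the tensor contributed by rank i; every rank sees the same list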
"""
if dist.is_available() and dist.is_initialized():
obj_gather_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
dist.all_gather(obj_gather_list, tensor)
return obj_gather_list
world_size = get_world_size()
if world_size == 1:
return [tensor]
raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
'available or has not been initialized. Please check you have initialized '
'the distributed runtime and that PyTorch has been built with distributed '
'support. If calling this function outside Trainer, please ensure that '
'`composer.utils.dist.initialize_dist` has been called first.')
def all_gather_object(obj: TObj) -> List[TObj]:
"""Collect a pickleable object from each rank and return a list of these objects indexed by rank.
.. seealso:: :func:`torch.distributed.all_gather_object`
Args:
obj (TObj): Object to be gathered.
Returns:
List[TObj]: A list of objects indexed by rank.
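A minimal sketch (``num_samples_seen`` is an illustrative per-rank value; assumes the process group has
been initialized):
.. code-block:: python

    from composer.utils import dist

    per_rank_counts = dist.all_gather_object(num_samples_seen)
    total_samples = sum(per_rank_counts)  # identical result on every rank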
"""
if dist.is_available() and dist.is_initialized():
obj_gather_list = [None for _ in range(get_world_size())]
dist.all_gather_object(obj_gather_list, obj)
# dist.all_gather_object replaces the None placeholders in obj_gather_list in-place on
# every rank with the objects gathered from all ranks
return cast(List[TObj], obj_gather_list)
world_size = get_world_size()
if world_size == 1:
return [obj]
raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
'available or has not been initialized. Please check you have initialized '
'the distributed runtime and that PyTorch has been built with distributed '
'support. If calling this function outside Trainer, please ensure that '
'`composer.utils.dist.initialize_dist` has been called first.')
def is_available() -> bool:
"""Returns whether PyTorch was built with distributed support.
.. seealso:: :func:`torch.distributed.is_available`
Returns:
bool: Whether PyTorch distributed support is available.
"""
return dist.is_available()
def is_initialized() -> bool:
"""Returns whether PyTorch distributed is initialized.
.. seealso:: :func:`torch.distributed.is_initialized`
Returns:
bool: Whether PyTorch distributed is initialized.
"""
return dist.is_initialized()
def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
"""Initialize the default PyTorch distributed process group.
This function assumes that the following environment variables are set:
* ``RANK``: The global rank of the process, which should be in ``[0; WORLD_SIZE - 1]``.
* ``LOCAL_RANK``: The local rank for the process, which should be in ``[0; LOCAL_WORLD_SIZE - 1]``.
* ``NODE_RANK``: The rank of the node.
* ``WORLD_SIZE``: The total number of processes.
* ``LOCAL_WORLD_SIZE``: The number of processes on the current node.
* ``MASTER_ADDR``: The hostname for the rank-zero process.
* ``MASTER_PORT``: The port for the rank-zero process.
If none of the environment variables are set, this function will assume a single-rank
configuration and initialize the default process group using a :class:`torch.distributed.HashStore` store.
.. seealso:: :func:`torch.distributed.init_process_group`
Args:
device (str | Device): The device from which the distributed backend is
interpreted. Either a string corresponding to a device (one of ``'cpu'``,
``'gpu'``, ``'mps'``, or ``'tpu'``) or a :class:`.Device`.
timeout (float, optional): The timeout for operations executed against the process
group, expressed in seconds. (default: ``300.0``).
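A minimal sketch (assumes the script was launched with the composer launcher, or that the environment
variables above are already set):
.. code-block:: python

    from composer.utils import dist

    dist.initialize_dist('gpu', timeout=300.0)
    dist.barrier()  # collectives are safe to call once the process group exists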
"""
# If device is string, get corresponding composer.devices.Device object
device_obj = get_device(device)
timeout_timedelta = datetime.timedelta(seconds=timeout)
if get_world_size() > 1 and not dist.is_available():
raise RuntimeError('When the world size is > 1, ``torch.distributed`` must be used. However, it is '
'not available in your installation of PyTorch. Please install or build PyTorch '
'with distributed support.')
if dist.is_initialized():
if dist.get_backend() != device_obj.dist_backend.lower():
raise RuntimeError(f'The requested backend ({device_obj.dist_backend}) differs from the backend '
f'of the current process group ({dist.get_backend()}). If you '
'wish to change backends, please restart the python process.')
return
# If any of these variables are set, and they do not match the single rank defaults,
# then do not automatically configure distributed. There are no reasonable defaults to infer
# for the other variables. Instead, let torch.dist error on an incomplete configuration.
# If none of these variables are set, or some are set but they match the single rank defaults,
# then fill the rest in.
dist_env_var_defaults = {
'NODE_RANK': '0',
'WORLD_SIZE': '1',
'LOCAL_WORLD_SIZE': '1',
'RANK': '0',
'LOCAL_RANK': '0',
}
log.debug(
'Initializing torch.dist: global_rank=%d, local_rank=%d, world_size=%d, local_world_size=%d, node_rank=%d',
get_global_rank(),
get_local_rank(),
get_world_size(),
get_local_world_size(),
get_node_rank(),
)
dist_env_vars_match_defaults = all(os.environ.get(k, v) == v for (k, v) in dist_env_var_defaults.items())
if dist_env_vars_match_defaults:
# Fill in the remaining single-rank variables
os.environ.update(dist_env_var_defaults)
dist.init_process_group(device_obj.dist_backend, store=dist.HashStore(), world_size=1, rank=0)
else:
dist.init_process_group(device_obj.dist_backend, timeout=timeout_timedelta)
def get_sampler(dataset: torch.utils.data.Dataset, *, drop_last: bool = False, shuffle: bool = False):
"""Constructs a :class:`~torch.utils.data.distributed.DistributedSampler` for a dataset.
The :class:`~torch.utils.data.distributed.DistributedSampler` assumes that each rank has a complete copy of the
dataset. It ensures that each rank sees a unique shard for each epoch containing
``len(dataset) / get_world_size()`` samples.
.. note::
If the ``dataset`` is already sharded by rank, use a :class:`~torch.utils.data.SequentialSampler`
or :class:`~torch.utils.data.RandomSampler`.
Args:
dataset (torch.utils.data.Dataset): The dataset.
drop_last (bool): Whether to drop the last batch.
shuffle (bool): Whether to shuffle the dataset.
Returns:
torch.utils.data.distributed.DistributedSampler: The sampler.
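A minimal sketch (``my_dataset`` is an illustrative map-style dataset):
.. code-block:: python

    import torch.utils.data
    from composer.utils import dist

    sampler = dist.get_sampler(my_dataset, shuffle=True)
    dataloader = torch.utils.data.DataLoader(my_dataset, batch_size=32, sampler=sampler)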
"""
return torch.utils.data.DistributedSampler[int](
dataset,
drop_last=drop_last,
shuffle=shuffle,
num_replicas=get_world_size(),
rank=get_global_rank(),
)
@contextmanager
def run_local_rank_zero_first():
"""Context manager to hold all non-zero ranks until rank zero completes.
The example below lets the local rank zero process download
the dataset and holds all other ranks until the
download is complete.
.. code-block:: python
with run_local_rank_zero_first():
dataset = CIFAR10(
...,
download=True,
)
This prevents race conditions where multiple
ranks attempt to download the dataset to the
same location.
"""
if dist.is_available() and dist.is_initialized():
# hold non-zero ranks until rank zero done
if get_local_rank() != 0:
dist.barrier()
yield
else:
yield
dist.barrier()
return
world_size = get_world_size()
if world_size == 1:
yield
return
raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
'available or has not been initialized. Please check you have initialized '
'the distributed runtime and that PyTorch has been built with distributed '
'support. If calling this function outside Trainer, please ensure that '
'`composer.utils.dist.initialize_dist` has been called first.')