# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Helper methods for :mod:`torch.distributed`.
To use :mod:`torch.distributed`, launch your training script with the
:ref:`composer launcher for distributed training <distributed-training>`. For example,
the following command launches an eight-process training run.
.. code-block::
composer -n 8 path/to/train.py
The composer launcher will automatically configure the following environment variables, which are
required for distributed training:
* ``RANK``: The global rank of the process, which should be on ``[0; WORLD_SIZE - 1]``.
* ``LOCAL_RANK``: The local rank for the process, which should be on ``[0; LOCAL_WORLD_SIZE - 1]``.
* ``NODE_RANK``: The rank of the node.
* ``WORLD_SIZE``: The total number of processes.
* ``LOCAL_WORLD_SIZE``: The number of processes on the current node.
* ``MASTER_ADDR``: The hostname for the rank-zero process.
* ``MASTER_PORT``: The port for the rank-zero process.
If none of these environment variables are set, this module will safely assume a single-rank configuration, where::
RANK=0
LOCAL_RANK=0
NODE_RANK=0
WORLD_SIZE=1
LOCAL_WORLD_SIZE=1
"""
from __future__ import annotations
import datetime
import logging
import os
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union, cast
import torch
import torch.distributed as dist
import torch.utils.data
from composer.utils.device import get_device
if TYPE_CHECKING:
from composer.devices import Device
# Generic type variable; used in this file to type the object passed to and
# returned (per rank) by ``all_gather_object``.
TObj = TypeVar('TObj')
# Names exported by ``from <this module> import *``.
# NOTE(review): ``run_local_rank_zero_first``, ``local_rank_zero_download_and_wait`` and
# ``MissingEnvironmentError`` are defined in this file but are not listed below --
# confirm whether that omission is intentional.
__all__ = [
'all_gather',
'all_gather_object',
'all_reduce',
'barrier',
'broadcast',
'broadcast_object_list',
'get_global_rank',
'get_local_rank',
'get_local_world_size',
'get_node_rank',
'get_sampler',
'get_world_size',
'initialize_dist',
'is_available',
'is_initialized',
]
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class MissingEnvironmentError(Exception):
    """Raised when torch distributed is initialized but a required environment variable is not set."""
def _get_distributed_config_var(
    env_var: str,
    human_name: str,
    default: int,
    fetch_fn_name: Optional[str] = None,
) -> int:
    """Resolve one integer distributed-configuration value.

    Resolution order:

    1. If ``torch.distributed`` is not available, ``default`` is returned.
    2. If the process group is initialized and ``fetch_fn_name`` is given, the value is
       read by calling ``getattr(torch.distributed, fetch_fn_name)()``. When ``env_var``
       is also set, both values must agree.
    3. Otherwise the value of ``env_var`` is used when it is set.
    4. Otherwise ``default`` is returned, unless the process group is initialized.

    Args:
        env_var (str): Name of the environment variable holding the value.
        human_name (str): Human-readable name of the value, used in error messages.
        default (int): Fallback value.
        fetch_fn_name (str, optional): Name of a zero-argument function on
            ``torch.distributed`` that returns the value (default: ``None``).

    Returns:
        int: The resolved value.

    Raises:
        RuntimeError: If the value reported by ``torch.distributed`` differs from ``env_var``.
        MissingEnvironmentError: If the process group is initialized, the value cannot be
            fetched from ``torch.distributed``, and ``env_var`` is not set.
    """
    if not dist.is_available():
        return default

    process_group_ready = dist.is_initialized()
    raw_env_value = os.environ.get(env_var)

    if process_group_ready and fetch_fn_name is not None:
        # torch.distributed is the source of truth; the environment is only cross-checked.
        value_from_dist = int(getattr(dist, fetch_fn_name)())
        if raw_env_value is not None:
            value_from_env = int(raw_env_value)
            if value_from_env != value_from_dist:
                raise RuntimeError('Torch distributed has been initialized with a value of '
                                   f'{value_from_dist} for {human_name}, but environment variable '
                                   f'{env_var} has value {value_from_env}.')
        return value_from_dist

    if raw_env_value is not None:
        return int(raw_env_value)

    if process_group_ready:
        # Initialized, but there is no way to learn this value.
        raise MissingEnvironmentError('Torch distributed is initialized but environment variable '
                                      f'{env_var} is not set.')

    return default
def get_world_size() -> int:
    """Returns the world size, i.e. the number of processes participating in this training run.

    The value comes from ``torch.distributed.get_world_size`` when the process group is
    initialized, otherwise from the ``WORLD_SIZE`` environment variable, and defaults to ``1``.

    Returns:
        int: The world size.
    """
    world_size = _get_distributed_config_var(
        env_var='WORLD_SIZE',
        human_name='world size',
        default=1,
        fetch_fn_name='get_world_size',
    )
    return world_size
def get_global_rank() -> int:
    """Returns the global rank of the current process, which is on ``[0; WORLD_SIZE - 1]``.

    The value comes from ``torch.distributed.get_rank`` when the process group is
    initialized, otherwise from the ``RANK`` environment variable, and defaults to ``0``.

    Returns:
        int: The global rank.
    """
    global_rank = _get_distributed_config_var(
        env_var='RANK',
        human_name='global rank',
        default=0,
        fetch_fn_name='get_rank',
    )
    return global_rank
def get_local_world_size() -> int:
    """Returns the local world size, i.e. the number of processes on the current node.

    The value is read from the ``LOCAL_WORLD_SIZE`` environment variable and defaults to ``1``.

    Returns:
        int: The local world size.
    """
    return _get_distributed_config_var(
        env_var='LOCAL_WORLD_SIZE',
        human_name='local world size',
        default=1,
    )
def get_local_rank() -> int:
    """Returns the local rank of the current process, which is on ``[0; LOCAL_WORLD_SIZE - 1]``.

    The value is read from the ``LOCAL_RANK`` environment variable and defaults to ``0``.

    Returns:
        int: The local rank.
    """
    return _get_distributed_config_var(
        env_var='LOCAL_RANK',
        human_name='local rank',
        default=0,
    )
def get_node_rank() -> int:
    """Returns the rank of the node this process runs on.

    For example, with 2 nodes and 2 ranks per node, global ranks 0-1 have a node rank of 0,
    and global ranks 2-3 have a node rank of 1.

    The value is read from the ``NODE_RANK`` environment variable and defaults to ``0``.

    Returns:
        int: The node rank, starting at 0.
    """
    return _get_distributed_config_var(
        env_var='NODE_RANK',
        human_name='node rank',
        default=0,
    )
def barrier() -> None:
    """Synchronizes all processes.

    Blocks until every process has reached this call. When the process group is not
    initialized and the world size is 1, this is a no-op.

    .. seealso:: :func:`torch.distributed.barrier`

    Raises:
        RuntimeError: If the world size is not 1 but ``torch.distributed`` is unavailable
            or has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
        return

    # Without a process group, only the single-process case is valid.
    world_size = get_world_size()
    if world_size != 1:
        raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
                           'available or has not been initialized. Please check you have initialized '
                           'the distributed runtime and that PyTorch has been built with distributed '
                           'support. If calling this function outside Trainer, please ensure that '
                           '`composer.utils.dist.initialize_dist` has been called first.')
def all_reduce(
    tensor: torch.Tensor,
    reduce_operation: str = 'SUM',
) -> None:
    """Reduces ``tensor`` across all ranks by applying ``reduce_operation``, in-place.

    When the process group is not initialized and the world size is 1, ``tensor`` is left
    untouched.

    .. seealso:: :func:`torch.distributed.all_reduce`

    Args:
        tensor (torch.Tensor): Tensor to reduce. It is both the input and the output of the
            collective; the function operates in-place.
        reduce_operation (str, optional): Name of the reduction operation (default: ``SUM``).
            The name is upper-cased and looked up on ``torch.distributed.ReduceOp``.
            Valid options are:

            * ``SUM``
            * ``PRODUCT``
            * ``MIN``
            * ``MAX``
            * ``BAND``
            * ``BOR``
            * ``BXOR``

    Returns:
        None: ``tensor`` is modified in-place.

    Raises:
        RuntimeError: If the world size is not 1 but ``torch.distributed`` is unavailable
            or has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        op_name = reduce_operation.upper()
        dist.all_reduce(tensor, op=getattr(dist.ReduceOp, op_name))
        return

    # Without a process group, only the single-process case is valid.
    world_size = get_world_size()
    if world_size != 1:
        raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
                           'available or has not been initialized. Please check you have initialized '
                           'the distributed runtime and that PyTorch has been built with distributed '
                           'support. If calling this function outside Trainer, please ensure that '
                           '`composer.utils.dist.initialize_dist` has been called first.')
def broadcast(tensor: torch.Tensor, src: int) -> None:
    """Broadcasts ``tensor`` from rank ``src`` to the whole group.

    ``tensor`` must have the same number of elements in all processes participating in the
    collective. When the process group is not initialized and the world size is 1, this is
    a no-op.

    .. seealso:: :func:`torch.distributed.broadcast`

    Args:
        tensor (torch.Tensor): Data to be sent if ``src`` is the rank of the current process,
            and the tensor used to store the received data otherwise.
        src (int): Source rank.

    Raises:
        RuntimeError: If the world size is not 1 but ``torch.distributed`` is unavailable
            or has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        dist.broadcast(tensor, src)
        return

    # Without a process group, only the single-process case is valid.
    world_size = get_world_size()
    if world_size != 1:
        raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
                           'available or has not been initialized. Please check you have initialized '
                           'the distributed runtime and that PyTorch has been built with distributed '
                           'support. If calling this function outside Trainer, please ensure that '
                           '`composer.utils.dist.initialize_dist` has been called first.')
def broadcast_object_list(object_list: List[Any], src: int = 0) -> None:
    """Broadcasts the picklable objects in ``object_list`` from rank ``src`` to the whole group.

    Similar to :func:`broadcast`, but arbitrary Python objects can be passed in. Every object
    in ``object_list`` must be picklable in order to be broadcast. When the process group is
    not initialized and the world size is 1, ``object_list`` is left untouched.

    .. seealso:: :func:`torch.distributed.broadcast_object_list`

    Args:
        object_list (List[Any]): List of input objects to broadcast. Each object must be
            picklable. Only the objects on the ``src`` rank are broadcast, but each rank
            must provide a list of equal size.
        src (int, optional): Source rank (default: ``0``).

    Returns:
        None: ``object_list`` is modified in-place and set to the values of ``object_list``
        from the ``src`` rank.

    Raises:
        RuntimeError: If the world size is not 1 but ``torch.distributed`` is unavailable
            or has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        # The result is delivered by in-place mutation of ``object_list``; nothing is returned.
        dist.broadcast_object_list(object_list, src)
        return

    # Without a process group, only the single-process case is valid.
    world_size = get_world_size()
    if world_size != 1:
        raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
                           'available or has not been initialized. Please check you have initialized '
                           'the distributed runtime and that PyTorch has been built with distributed '
                           'support. If calling this function outside Trainer, please ensure that '
                           '`composer.utils.dist.initialize_dist` has been called first.')
def all_gather(tensor: torch.Tensor) -> Sequence[torch.Tensor]:
    """Collects a :class:`~torch.Tensor` from each rank.

    When the process group is not initialized and the world size is 1, a one-element list
    holding ``tensor`` itself is returned.

    .. seealso:: :func:`torch.distributed.all_gather`

    Args:
        tensor (torch.Tensor): Tensor from each rank to be gathered.

    Returns:
        Sequence[Tensor]: A sequence of tensors indexed by rank.

    Raises:
        RuntimeError: If the world size is not 1 but ``torch.distributed`` is unavailable
            or has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        # One pre-allocated receive buffer per rank, shaped like the local tensor.
        gathered = [torch.zeros_like(tensor) for _ in range(get_world_size())]
        dist.all_gather(gathered, tensor)
        return gathered

    # Without a process group, only the single-process case is valid.
    world_size = get_world_size()
    if world_size != 1:
        raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
                           'available or has not been initialized. Please check you have initialized '
                           'the distributed runtime and that PyTorch has been built with distributed '
                           'support. If calling this function outside Trainer, please ensure that '
                           '`composer.utils.dist.initialize_dist` has been called first.')
    return [tensor]
def all_gather_object(obj: TObj) -> List[TObj]:
    """Collects a picklable object from each rank and returns them as a list indexed by rank.

    When the process group is not initialized and the world size is 1, a one-element list
    holding ``obj`` itself is returned.

    .. seealso:: :func:`torch.distributed.all_gather_object`

    Args:
        obj (TObj): Object to be gathered.

    Returns:
        List[TObj]: A list of objects indexed by rank.

    Raises:
        RuntimeError: If the world size is not 1 but ``torch.distributed`` is unavailable
            or has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        # Placeholder slots, one per rank; torch.distributed fills them in-place.
        gathered = [None] * get_world_size()
        dist.all_gather_object(gathered, obj)
        return cast(List[TObj], gathered)

    # Without a process group, only the single-process case is valid.
    world_size = get_world_size()
    if world_size != 1:
        raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
                           'available or has not been initialized. Please check you have initialized '
                           'the distributed runtime and that PyTorch has been built with distributed '
                           'support. If calling this function outside Trainer, please ensure that '
                           '`composer.utils.dist.initialize_dist` has been called first.')
    return [obj]
def is_available() -> bool:
    """Returns whether PyTorch was built with distributed support.

    Thin wrapper around :func:`torch.distributed.is_available`.

    .. seealso:: :func:`torch.distributed.is_available`

    Returns:
        bool: Whether PyTorch distributed support is available.
    """
    return dist.is_available()
def is_initialized() -> bool:
    """Returns whether PyTorch distributed is initialized.

    Returns ``False`` when PyTorch was built without distributed support. The availability
    check comes first because ``torch.distributed`` does not expose its other APIs
    (including ``is_initialized``) on such builds; this matches the
    ``dist.is_available() and dist.is_initialized()`` guard used by the other helpers in
    this module.

    .. seealso:: :func:`torch.distributed.is_initialized`

    Returns:
        bool: Whether PyTorch distributed is available and initialized.
    """
    return dist.is_available() and dist.is_initialized()
def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
    """Initialize the default PyTorch distributed process group.

    This function assumes that the following environment variables are set:

    * ``RANK``: The global rank of the process, which should be on ``[0; WORLD_SIZE - 1]``.
    * ``LOCAL_RANK``: The local rank for the process, which should be on ``[0; LOCAL_WORLD_SIZE - 1]``.
    * ``NODE_RANK``: The rank of the node.
    * ``WORLD_SIZE``: The total number of processes.
    * ``LOCAL_WORLD_SIZE``: The number of processes on the current node.
    * ``MASTER_ADDR``: The hostname for the rank-zero process.
    * ``MASTER_PORT``: The port for the rank-zero process.

    If none of the environment variables are set (or those that are set equal the single-rank
    defaults), this function assumes a single-rank configuration, fills in the missing
    variables, and initializes the default process group using a
    :class:`torch.distributed.HashStore` store.

    If the process group is already initialized, the function only verifies that its backend
    matches the backend of ``device`` and then returns.

    .. seealso:: :func:`torch.distributed.init_process_group`

    Args:
        device (str | Device): The device from which the distributed backend is
            interpreted. Either a string corresponding to a device (one of ``'cpu'``,
            ``'gpu'``, ``'mps'``, or ``'tpu'``) or a :class:`.Device`.
        timeout (float, optional): The timeout for operations executed against the process
            group, expressed in seconds. (default: ``300.0``). It is only passed to
            ``init_process_group`` in the multi-rank branch.

    Raises:
        RuntimeError: If the world size is greater than 1 but ``torch.distributed`` is not
            available, or if the process group is already initialized with a different backend.
    """
    # Resolve a device string into the corresponding composer.devices.Device object.
    resolved_device = get_device(device)
    group_timeout = datetime.timedelta(seconds=timeout)

    if get_world_size() > 1 and not dist.is_available():
        raise RuntimeError('When the world size is > 1, ``torch.distributed`` must be used. However, it is '
                           'not available in your installation of PyTorch. Please install or build PyTorch '
                           'with distributed support.')

    if dist.is_initialized():
        active_backend = dist.get_backend()
        if active_backend != resolved_device.dist_backend.lower():
            raise RuntimeError(f'The requested backend ({resolved_device.dist_backend}) differs from the backend '
                               f'of the current process group ({active_backend}). If you '
                               'wish to change backends, please restart the python process.')
        # Already initialized with the requested backend; nothing left to do.
        return

    # If any of these variables are set to something other than the single-rank default,
    # do not configure distributed automatically: there are no reasonable values to infer
    # for the remaining variables, so torch.distributed is left to error on an incomplete
    # configuration. If none are set, or those that are set match the single-rank
    # defaults, the rest are filled in.
    single_rank_env = {
        'NODE_RANK': '0',
        'WORLD_SIZE': '1',
        'LOCAL_WORLD_SIZE': '1',
        'RANK': '0',
        'LOCAL_RANK': '0',
    }

    log.debug(
        'Initializing torch.dist: global_rank=%d, local_rank=%d, world_size=%d, local_world_size=%d, node_rank=%d',
        get_global_rank(),
        get_local_rank(),
        get_world_size(),
        get_local_world_size(),
        get_node_rank(),
    )

    # An unset variable counts as matching its default.
    any_non_default = any(os.environ.get(name, expected) != expected for name, expected in single_rank_env.items())
    if not any_non_default:
        # Fill in the remaining single-rank variables
        os.environ.update(single_rank_env)
        dist.init_process_group(resolved_device.dist_backend, store=dist.HashStore(), world_size=1, rank=0)
    else:
        dist.init_process_group(resolved_device.dist_backend, timeout=group_timeout)
def get_sampler(dataset: torch.utils.data.Dataset, *, drop_last: bool = False, shuffle: bool = False):
    """Constructs a :class:`~torch.utils.data.distributed.DistributedSampler` for a dataset.

    The :class:`~torch.utils.data.distributed.DistributedSampler` assumes that each rank has a
    complete copy of the dataset. It ensures that each rank sees a unique shard for each epoch
    containing ``len(dataset) / get_world_size()`` samples.

    The sampler is configured with ``num_replicas=get_world_size()`` and
    ``rank=get_global_rank()``.

    .. note::

        If the ``dataset`` is already sharded by rank, use a
        :class:`~torch.utils.data.SequentialSampler` or :class:`~torch.utils.data.RandomSampler`.

    Args:
        dataset (torch.utils.data.Dataset): The dataset.
        drop_last (bool): Forwarded to the sampler as ``drop_last`` (default: ``False``).
        shuffle (bool): Whether to shuffle the dataset (default: ``False``).

    Returns:
        torch.utils.data.distributed.DistributedSampler: The sampler.
    """
    sampler_cls = torch.utils.data.DistributedSampler[int]
    sampler = sampler_cls(
        dataset,
        drop_last=drop_last,
        shuffle=shuffle,
        num_replicas=get_world_size(),
        rank=get_global_rank(),
    )
    return sampler
@contextmanager
def local_rank_zero_download_and_wait(expected_file_path: str):
    """Context manager that makes every rank except local rank zero wait for a file to exist.

    The file is expected to be created by local rank zero, which enters the ``with`` body
    immediately. Every other local rank polls the filesystem and enters the body only once
    ``expected_file_path`` exists.

    This is useful as an alternative to ``run_local_rank_zero_first`` when downloading a file,
    because it does not require dist to be initialized; it only requires that the
    ``LOCAL_RANK`` environment variable is set. If dist is initialized, use
    ``run_local_rank_zero_first`` instead to avoid busy waiting.

    Args:
        expected_file_path (str): The file to wait for existence of.
    """
    if get_local_rank() != 0:
        # Poll every 100 ms; there is no timeout, so this waits until the file appears.
        while True:
            if os.path.exists(expected_file_path):
                break
            time.sleep(0.1)
    yield
@contextmanager
def run_local_rank_zero_first():
    """Context manager that lets local rank zero run the body before all other ranks.

    The example below lets local rank zero download the dataset and holds all other ranks
    until the download is complete.

    .. code-block:: python

        with run_local_rank_zero_first():
            dataset = CIFAR10(
                ...,
                download=True,
            )

    This prevents race conditions where multiple ranks attempt to download the dataset to the
    same location.

    When the process group is not initialized and the world size is 1, the body simply runs.

    .. note::

        The barrier that releases the waiting ranks runs after the body on local rank zero
        and is not wrapped in ``try``/``finally``; if the body raises on local rank zero,
        that barrier is not reached.

    Raises:
        RuntimeError: If the world size is not 1 but ``torch.distributed`` is unavailable
            or has not been initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Without a process group, only the single-process case is valid.
        world_size = get_world_size()
        if world_size != 1:
            raise RuntimeError(f'The world_size({world_size}) > 1, but the distributed package is not '
                               'available or has not been initialized. Please check you have initialized '
                               'the distributed runtime and that PyTorch has been built with distributed '
                               'support. If calling this function outside Trainer, please ensure that '
                               '`composer.utils.dist.initialize_dist` has been called first.')
        yield
        return

    is_local_rank_zero = get_local_rank() == 0

    # Non-zero local ranks wait at the barrier before running the body.
    if not is_local_rank_zero:
        dist.barrier()

    yield

    # Local rank zero reaches the barrier only after its body finished, releasing the others.
    if is_local_rank_zero:
        dist.barrier()