# Source code for composer.utils.parallelism

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Parallelism configs."""

from dataclasses import dataclass, field, fields
from typing import Any, Optional

from torch.distributed._tensor.device_mesh import DeviceMesh


[docs]@dataclass class FSDPConfig: """Configuration for Fully Sharded Data Parallelism (FSDP).""" activation_checkpointing: bool = False activation_checkpointing_reentrant: bool = True activation_cpu_offload: bool = False auto_wrap: bool = True te_checkpoint_wrapper: bool = False te_shard_fp8_weight: bool = False backward_prefetch: str = 'BACKWARD_POST' backward_prefetch_limit: int = 1 cpu_offload: bool = False data_parallel_shard_degree: int = -1 data_parallel_replicate_degree: Optional[int] = None forward_prefetch: bool = False forward_prefetch_limit: int = 1 ignored_modules: Optional[Any] = None keep_low_precision_grads: bool = False limit_all_gathers: bool = True load_monolith_rank0_only: bool = False load_planner: Optional[Any] = None mixed_precision: str = 'DEFAULT' process_group: Optional[Any] = None save_planner: Optional[Any] = None sharded_ckpt_prefix_dir: str = 'ep{epoch}-ba{batch}' sharding_strategy: str = 'FULL_SHARD' state_dict_type: str = 'full' sync_module_states: bool = False use_orig_params: bool = True verbose: bool = False _device_mesh: Optional[DeviceMesh] = field(default=None, init=False, repr=False) def __init__(self, **kwargs): if 'device_mesh' in kwargs or '_device_mesh' in kwargs: raise ValueError( f'Directly specifying device mesh for FSDP was deprecated in Composer version 0.24.0. ' + f"Please specify 'data_parallel_shard_degree' and/or 'data_parallel_replicate_degree' instead.", ) for k, v in kwargs.items(): setattr(self, k, v) @property def device_mesh(self) -> Optional[DeviceMesh]: return self._device_mesh @device_mesh.setter def device_mesh(self, value: Optional[DeviceMesh]): self._device_mesh = value
[docs]@dataclass class TPConfig: """Configuration for tensor parallelism (TP).""" device_mesh: Optional[DeviceMesh] = None tensor_parallel_degree: int = 1 layer_plan: Any = None
@dataclass
class ParallelismConfig:
    """Configuration for parallelism.

    Bundles the per-strategy configuration objects. Each one defaults to ``None``.
    NOTE(review): ``None`` presumably means that strategy is not enabled — confirm
    with the code that consumes this config.

    Attributes:
        fsdp: Fully Sharded Data Parallelism settings, or ``None``.
        tp: Tensor parallelism settings, or ``None``.
    """

    fsdp: Optional[FSDPConfig] = field(default=None)
    tp: Optional[TPConfig] = field(default=None)