Source code for composer.utils.parallelism

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Parallelism configs."""

import warnings
from dataclasses import dataclass
from typing import Any, Optional

from torch.distributed._tensor.device_mesh import DeviceMesh

from composer.utils.warnings import VersionedDeprecationWarning


@dataclass
class FSDPConfig:
    """Configuration for Fully Sharded Data Parallelism (FSDP)."""
    activation_checkpointing: bool = False
    activation_checkpointing_reentrant: bool = True
    activation_cpu_offload: bool = False
    auto_wrap: bool = True
    te_checkpoint_wrapper: bool = False
    te_shard_fp8_weight: bool = False
    backward_prefetch: str = 'BACKWARD_POST'
    backward_prefetch_limit: int = 1
    cpu_offload: bool = False
    data_parallel_shard_degree: int = -1
    data_parallel_replicate_degree: Optional[int] = None
    device_mesh: Optional[DeviceMesh] = None
    forward_prefetch: bool = False
    forward_prefetch_limit: int = 1
    ignored_modules: Optional[Any] = None
    keep_low_precision_grads: bool = False
    limit_all_gathers: bool = True
    load_monolith_rank0_only: bool = False
    load_planner: Optional[Any] = None
    mixed_precision: str = 'DEFAULT'
    process_group: Optional[Any] = None
    save_planner: Optional[Any] = None
    sharded_ckpt_prefix_dir: str = 'ep{epoch}-ba{batch}'
    sharding_strategy: str = 'FULL_SHARD'
    state_dict_type: str = 'full'
    sync_module_states: bool = False
    use_orig_params: bool = True
    verbose: bool = False
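A minimal sketch of constructing this dataclass directly and overriding a few of the defaults above; the chosen field values are illustrative only, not recommendations, and the import path is assumed from the page title:

# Illustrative only: build an FSDPConfig, overriding a few defaults.
from composer.utils.parallelism import FSDPConfig

fsdp_config = FSDPConfig(
    sharding_strategy='SHARD_GRAD_OP',   # any valid FSDP sharding strategy name
    activation_checkpointing=True,
    state_dict_type='sharded',
)
# data_parallel_shard_degree defaults to -1, i.e. shard across all available ranks.
print(fsdp_config.data_parallel_shard_degree)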
def create_fsdp_config(fsdp_config: dict[str, Any]):
    """Modify fsdp_config to set default values for missing keys."""
    fsdp_config = {**fsdp_config}  # Shallow copy to avoid modifying input
    if 'process_group' in fsdp_config:
        warnings.warn(
            VersionedDeprecationWarning(
                'process_group is deprecated. Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.',
                remove_version='0.24',
            ),
        )

    if 'device_mesh' in fsdp_config:
        warnings.warn(
            VersionedDeprecationWarning(
                'device_mesh is deprecated. Please specify `data_parallel_shard_degree` and `data_parallel_replicate_degree` instead.',
                remove_version='0.24',
            ),
        )
        if 'data_parallel_shard_degree' in fsdp_config or 'data_parallel_replicate_degree' in fsdp_config:
            raise ValueError(
                'Cannot specify both `device_mesh` and `data_parallel_shard_degree` or `data_parallel_replicate_degree`. Please remove `device_mesh`.',
            )
        device_mesh = fsdp_config.pop('device_mesh')
        if len(device_mesh) == 1:
            fsdp_config['data_parallel_shard_degree'] = device_mesh[0]
        elif len(device_mesh) == 2:
            fsdp_config['data_parallel_replicate_degree'] = device_mesh[0]
            fsdp_config['data_parallel_shard_degree'] = device_mesh[1]
        else:
            raise ValueError(
                f'device_mesh must be of length 1 or 2 but received length {len(device_mesh)} with device mesh {device_mesh}.',
            )

    return FSDPConfig(**fsdp_config)
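A short usage sketch of the translation performed above: a deprecated two-element `device_mesh` becomes replicate and shard degrees on the returned FSDPConfig. The dict values are made up for illustration:

# Illustrative only: create_fsdp_config fills defaults and translates the
# deprecated `device_mesh` key into replicate/shard degrees (and emits a
# VersionedDeprecationWarning while doing so).
cfg = create_fsdp_config({'device_mesh': [2, 4], 'state_dict_type': 'sharded'})
assert cfg.data_parallel_replicate_degree == 2
assert cfg.data_parallel_shard_degree == 4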
@dataclass
class TPConfig:
    """Configuration for tensor parallelism (TP)."""
    device_mesh: Optional[DeviceMesh] = None
    tensor_parallel_degree: int = 1
    layer_plan: Any = None
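A hedged sketch of a TPConfig with a layer plan. The module names ('model.fc1', 'model.fc2') are hypothetical, and the assumption here is that `layer_plan` follows the torch.distributed.tensor.parallel convention of mapping module names to parallel styles; this module itself places no constraint on the value beyond `Any`:

# Sketch only: map (hypothetical) submodule names to tensor-parallel styles.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

tp_config = TPConfig(
    tensor_parallel_degree=2,
    layer_plan={
        'model.fc1': ColwiseParallel(),
        'model.fc2': RowwiseParallel(),
    },
)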
@dataclass
class ParallelismConfig:
    """Configuration for parallelism."""
    fsdp: Optional[FSDPConfig] = None
    tp: Optional[TPConfig] = None
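A final sketch combining the two configs into a ParallelismConfig. How the object is consumed downstream (for example, a `parallelism_config` argument on Composer's Trainer) is an assumption about the wider API, not something shown in this module:

# Sketch only: compose FSDP and TP settings into a single config object.
parallelism_config = ParallelismConfig(
    fsdp=FSDPConfig(state_dict_type='sharded'),
    tp=TPConfig(tensor_parallel_degree=2),
)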