# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Parallelism configs."""
import warnings
from dataclasses import dataclass, field, fields
from typing import Any, Optional
from torch.distributed._tensor.device_mesh import DeviceMesh


@dataclass
class FSDPConfig:
"""Configuration for Fully Sharded Data Parallelism (FSDP)."""
activation_checkpointing: bool = False
activation_checkpointing_reentrant: bool = True
activation_cpu_offload: bool = False
auto_wrap: bool = True
te_checkpoint_wrapper: bool = False
te_shard_fp8_weight: bool = False
backward_prefetch: str = 'BACKWARD_POST'
backward_prefetch_limit: int = 1
cpu_offload: bool = False
data_parallel_shard_degree: int = -1
data_parallel_replicate_degree: Optional[int] = None
forward_prefetch: bool = False
forward_prefetch_limit: int = 1
ignored_modules: Optional[Any] = None
keep_low_precision_grads: bool = False
limit_all_gathers: bool = True
load_monolith_rank0_only: bool = False
load_planner: Optional[Any] = None
mixed_precision: str = 'DEFAULT'
process_group: Optional[Any] = None
save_planner: Optional[Any] = None
sharded_ckpt_prefix_dir: str = 'ep{epoch}-ba{batch}'
sharding_strategy: str = 'FULL_SHARD'
state_dict_type: str = 'full'
sync_module_states: bool = False
use_orig_params: bool = True
verbose: bool = False
_device_mesh: Optional[DeviceMesh] = field(default=None, init=False, repr=False)

    def __init__(self, **kwargs):
        if 'device_mesh' in kwargs or '_device_mesh' in kwargs:
            raise ValueError(
                'Directly specifying device mesh for FSDP was deprecated in Composer version 0.24.0. '
                "Please specify 'data_parallel_shard_degree' and/or 'data_parallel_replicate_degree' instead.",
            )

        for k, v in kwargs.items():
            setattr(self, k, v)

    @property
    def device_mesh(self) -> Optional[DeviceMesh]:
        return self._device_mesh

    @device_mesh.setter
    def device_mesh(self, value: Optional[DeviceMesh]):
        self._device_mesh = value
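
# Example (illustrative sketch, not part of the original source): a typical FSDPConfig is
# constructed with keyword arguments and later wrapped in a ParallelismConfig (defined below).
# The specific keyword values here are assumptions chosen for illustration only.
#
#   fsdp_config = FSDPConfig(
#       sharding_strategy='FULL_SHARD',
#       mixed_precision='PURE',
#       activation_checkpointing=True,
#   )
#
# Note that passing `device_mesh` directly raises a ValueError; use
# `data_parallel_shard_degree` / `data_parallel_replicate_degree` instead.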


@dataclass
class FSDP2Config:
"""Configuration for Fully Sharded Data Parallelism (FSDP2).
Args:
device_mesh (Optional[DeviceMesh]): The DeviceMesh for sharding. If None, a default 1D mesh is created.
For 1D mesh, parameters are fully sharded across the mesh (FSDP).
For 2D mesh, parameters are sharded across the 1st dimension and replicated across the 0th dimension (HSDP).
reshard_after_forward (Union[bool, int]): Controls parameter behavior after forward.
"""

    # Settable core FSDP2 attrs
    device_mesh: Optional[DeviceMesh] = None
    reshard_after_forward: bool | int = True
    # TODO: If we have reasonable evidence that activation checkpointing/activation offloading is decoupled from FSDP(2)
    # in most of our use cases, we can decouple these two attributes from the FSDP2Config class.
    activation_checkpointing: bool = False
    activation_cpu_offload: bool = False
    verbose: bool = False

    @classmethod
    def settable_attrs(cls) -> set[str]:
        """Return a set of all settable attributes of FSDP2Config."""
        return {field.name for field in fields(cls)}

    @classmethod
    def from_compatible_attrs(cls, attrs: dict[str, Any]) -> 'FSDP2Config':
        """Create an FSDP2Config by filtering the given attrs down to FSDP2-compatible attributes.

        Only attributes that are valid for FSDP2Config are used, and a warning is issued for any
        attribute that cannot be transferred. The method therefore accepts both FSDP1 and FSDP2
        attributes; its main use case is FSDP1 backwards compatibility.

        Args:
            attrs (dict[str, Any]): Dictionary of FSDP1/2 configuration attributes.

        Returns:
            FSDP2Config: A new FSDP2Config instance with the compatible attributes.

        Warnings:
            UserWarning: If an attribute in the input dictionary is not a settable attribute
                of FSDP2Config and will be ignored.
        """
        # Get the settable attributes of FSDP2Config
        settable_attrs = cls.settable_attrs()

        # Filter the input attributes to only include settable ones
        valid_attrs = {}
        for key, value in attrs.items():
            if key in settable_attrs:
                valid_attrs[key] = value
            else:
                warnings.warn(
                    f"Attribute '{key}: {value}' is not a settable attribute of FSDP2Config and will be ignored",
                    UserWarning,
                )

        # Create and return a new FSDP2Config with the valid attributes
        return FSDP2Config(**valid_attrs)
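
    # Example (illustrative sketch): migrating an FSDP1-style attribute dict. Keys that are not
    # dataclass fields of FSDP2Config (e.g. 'sharding_strategy') are dropped with a UserWarning,
    # while FSDP2-settable keys are transferred as-is.
    #
    #   fsdp2_config = FSDP2Config.from_compatible_attrs({
    #       'activation_checkpointing': True,   # kept (settable on FSDP2Config)
    #       'sharding_strategy': 'FULL_SHARD',  # ignored with a UserWarning
    #   })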

    ### Temporary read-only properties for FSDP 1 compatibility ###
    # to be supported in FSDP2
    @property
    def auto_wrap(self) -> bool:
        return False

    @property
    def load_monolith_rank0_only(self) -> bool:
        return False

    @property
    def sync_module_states(self) -> bool:
        return False

    @property
    def load_planner(self) -> Optional[Any]:
        return None

    @property
    def save_planner(self) -> Optional[Any]:
        return None

    @property
    def sharded_ckpt_prefix_dir(self) -> str:
        return 'ep{epoch}-ba{batch}'

    @property
    def data_parallel_shard_degree(self) -> int:
        return -1

    @property
    def data_parallel_replicate_degree(self) -> Optional[int]:
        return None

    # to be deprecated in FSDP2
    @property
    def state_dict_type(self) -> str:
        return 'sharded'

    @property
    def use_orig_params(self) -> bool:
        return True

    def __post_init__(self):
        warnings.warn('FSDP2 Config/APIs are experimental and subject to heavy changes', UserWarning)
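
# Example (illustrative sketch, values assumed for illustration): constructing an FSDP2Config
# directly. Instantiation emits the experimental-API UserWarning from __post_init__.
#
#   fsdp2_config = FSDP2Config(
#       reshard_after_forward=True,   # free unsharded params after forward
#       activation_checkpointing=True,
#   )
#
# A 2D DeviceMesh can be passed via `device_mesh` for HSDP-style shard + replicate layouts.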


@dataclass
class TPConfig:
"""Configuration for tensor parallelism (TP)."""
device_mesh: Optional[DeviceMesh] = None
tensor_parallel_degree: int = 1
layer_plan: Any = None
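
# Example (illustrative sketch, assuming `layer_plan` follows the torch tensor-parallel
# `parallelize_module` plan format; the module names below are hypothetical):
#
#   from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
#
#   tp_config = TPConfig(
#       tensor_parallel_degree=2,
#       layer_plan={
#           'model.mlp.up_proj': ColwiseParallel(),
#           'model.mlp.down_proj': RowwiseParallel(),
#       },
#   )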


@dataclass
class ParallelismConfig:
"""Configuration for parallelism."""
fsdp: Optional[FSDPConfig] = None
tp: Optional[TPConfig] = None
fsdp2: Optional[FSDP2Config] = None
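
# Example (illustrative sketch): combining the configs above. Passing the result to the
# Trainer's `parallelism_config` argument is assumed here for illustration.
#
#   parallelism_config = ParallelismConfig(
#       fsdp=FSDPConfig(sharding_strategy='FULL_SHARD'),
#       tp=TPConfig(tensor_parallel_degree=2),
#   )
#   # trainer = Trainer(model=..., parallelism_config=parallelism_config)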