Source code for composer.distributed.dist_strategy

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Helpers for running distributed data parallel training."""

import logging
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, ContextManager, Iterator, Optional, Sequence, Union, cast

import torch
from packaging import version
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
    offload_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp.wrap import CustomPolicy
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Metric, MetricCollection

from composer.core import Precision, State
from composer.core.precision import _validate_precision
from composer.devices import Device, DeviceGPU
from composer.distributed.mosaic_parallelism import (
    BACKWARD_PREFETCH_MAP,
    SHARDING_MAP,
    get_cpu_offload,
    get_mixed_precision,
    set_custom_fsdp_module_kwargs,
)
from composer.utils import FSDPConfig, StringEnum, TPConfig, dist, ensure_tuple, get_device

__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module']

log = logging.getLogger(__name__)

process_group_cache = {}


class DDPSyncStrategy(StringEnum):
    """How and when gradient synchronization should happen.

    Attributes:
        SINGLE_AUTO_SYNC: The default behavior. Gradients are synchronized as they are
            computed, for only the final microbatch of a batch. This is the most efficient
            strategy, but can lead to errors when ``find_unused_parameters`` is set, since
            it is possible different microbatches may use different sets of parameters,
            leading to an incomplete sync.
        MULTI_AUTO_SYNC: The default behavior when ``find_unused_parameters`` is set.
            Gradients are synchronized as they are computed for all microbatches. This
            ensures complete synchronization, but is less efficient than
            :attr:`SINGLE_AUTO_SYNC`. This efficiency gap is usually small, as long as
            either DDP syncs are a small portion of the trainer's overall runtime, or the
            number of microbatches per batch is relatively small.
        FORCED_SYNC: Gradients are manually synchronized only after all gradients have
            been computed for the final microbatch of a batch. Like
            :attr:`MULTI_AUTO_SYNC`, this strategy ensures complete gradient
            synchronization, but this tends to be slower than :attr:`MULTI_AUTO_SYNC`.
            This is because ordinarily syncs can happen in parallel with the
            ``loss.backward()`` computation, meaning syncs can be mostly complete by the
            time that function finishes. However, in certain circumstances, syncs may take
            a very long time to complete - if there are also a lot of microbatches per
            batch, this strategy may be optimal.
    """

    SINGLE_AUTO_SYNC = 'single_auto_sync'
    MULTI_AUTO_SYNC = 'multi_auto_sync'
    FORCED_SYNC = 'forced_sync'

@contextmanager
def ddp_sync_context(state: State, is_final_microbatch: bool, sync_strategy: Union[str, DDPSyncStrategy]):
    """A context manager for handling the :class:`DDPSyncStrategy`.

    Args:
        state (State): The state of the :class:`.Trainer`.
        is_final_microbatch (bool): Whether or not the context is being used during the final
            microbatch of the gradient accumulation steps.
        sync_strategy (str | DDPSyncStrategy): The ddp sync strategy to use. If a string
            is provided, the string must be one of the values in :class:`DDPSyncStrategy`.
    """
    if not isinstance(state.model, DistributedDataParallel):
        yield
        return

    assert state.optimizers is not None, 'optimizers have not been initialized'
    sync_strategy = DDPSyncStrategy(sync_strategy)

    no_sync_context = cast(Callable[[], ContextManager], state.model.no_sync)
    auto_sync_context = nullcontext

    if sync_strategy == DDPSyncStrategy.SINGLE_AUTO_SYNC:
        context = auto_sync_context if is_final_microbatch else no_sync_context
        with context():
            yield

    elif sync_strategy == DDPSyncStrategy.MULTI_AUTO_SYNC:
        with auto_sync_context():
            yield

    elif sync_strategy == DDPSyncStrategy.FORCED_SYNC:
        try:
            with no_sync_context():
                yield
        finally:
            if is_final_microbatch:
                for optimizer in state.optimizers:
                    for group in optimizer.param_groups:
                        for p in group['params']:
                            if p.grad is not None:
                                dist.all_reduce(p.grad)
                                p.grad = p.grad / dist.get_world_size()

    else:
        raise ValueError('Unknown sync strategy', sync_strategy)

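# Illustrative usage sketch (not part of this module): how `ddp_sync_context` is
# typically used inside a gradient-accumulation loop. `state`, `microbatches`,
# `compute_loss`, and `find_unused_parameters` are hypothetical stand-ins for
# Trainer internals, shown only to clarify the strategy selection above.
#
#     strategy = (DDPSyncStrategy.MULTI_AUTO_SYNC
#                 if find_unused_parameters else DDPSyncStrategy.SINGLE_AUTO_SYNC)
#     for i, microbatch in enumerate(microbatches):
#         is_final = i == len(microbatches) - 1
#         with ddp_sync_context(state, is_final, strategy):
#             loss = compute_loss(state.model(microbatch))
#             loss.backward()
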
def prepare_ddp_module(module: torch.nn.Module, find_unused_parameters: bool) -> torch.nn.Module:
    """Wraps the module in a :class:`torch.nn.parallel.DistributedDataParallel` object if running distributed training.

    Args:
        module (torch.nn.Module): The module to wrap.
        find_unused_parameters (bool): Whether or not to do a pass over the autograd graph
            to find parameters to not expect gradients for. This is useful if there are some
            parameters in the model that are not being trained.
    """
    if dist.is_available() and dist.is_initialized():
        if any((p.requires_grad for p in module.parameters())):
            log.debug('Wrapping model with DistributedDataParallel')
            ddp_model = DistributedDataParallel(module, find_unused_parameters=find_unused_parameters)
            return ddp_model
        return module
    if dist.is_available():
        raise RuntimeError('Please call dist.initialize_dist() before calling ddp.prepare_module()')

    raise RuntimeError(
        'When the world size is > 1, ``torch.distributed`` must be used. However, it is '
        'not available in your installation of PyTorch. Please install or build PyTorch '
        'with distributed support.',
    )

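# Illustrative usage sketch (not part of this module): wrap a model for DDP after the
# process group has been initialized. `MyModel` is a hypothetical torch.nn.Module.
#
#     dist.initialize_dist(get_device(None))
#     model = prepare_ddp_module(MyModel(), find_unused_parameters=False)
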
def _recreate_fsdp_param_groups_from_unwrapped_opt_info(
    fsdp_wrapped_named_params: Iterator[tuple[str, torch.nn.Parameter]],
    non_wrapped_param_names_to_group_num: dict[str, int],
    group_num_to_optimizer_info: dict[int, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Helper function to recreate optimizer groups for FSDP wrapped modules.

    Optimizer param groups are formatted as:

        [
            {'params': [p1, p2], 'lr': lr1},  # group 0
            {'params': [p3], 'lr': lr2},  # group 1
        ]

    i.e. there are multiple parameters per group. Here, we track the group number in order
    to map multiple parameters to the same group.

    Args:
        fsdp_wrapped_named_params: output of model.named_parameters() of an FSDP wrapped model
        non_wrapped_param_names_to_group_num: a dict mapping from the original model param names
            to the param group number
        group_num_to_optimizer_info: stores info like lr, eps for each group

    Returns a list of param groups, referencing the fsdp parameters.
    """
    # Initialize an empty list of parameters for each optimizer group
    for group_num in group_num_to_optimizer_info.keys():
        group_num_to_optimizer_info[group_num]['params'] = []

    for fsdp_name, param in fsdp_wrapped_named_params:
        unwrapped_name = clean_tensor_name(fsdp_name)

        # Since we are iterating over all model.named_parameters() after fsdp wrapping, we need to check
        # if the parameter was included in the optimizer param_group pre fsdp wrapping, in order to support
        # passing a subset of model params in the optimizer
        if unwrapped_name in non_wrapped_param_names_to_group_num:
            # Need to have a 1:1 mapping between a fsdp param name and the non-wrapped vanilla param name
            retrieved_group_num = non_wrapped_param_names_to_group_num[unwrapped_name]
            group_num_to_optimizer_info[retrieved_group_num]['params'].append(param)

    # return sorted optimizer info groups
    return [group_num_to_optimizer_info[num] for num in sorted(group_num_to_optimizer_info.keys())]

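# Illustrative data shapes for the helper above (hypothetical names and values):
#
#     non_wrapped_param_names_to_group_num = {
#         'model.fc1.weight': 0, 'model.fc1.bias': 0, 'model.fc2.weight': 1,
#     }
#     group_num_to_optimizer_info = {0: {'lr': 1e-3}, 1: {'lr': 1e-4}}
#
# After FSDP wrapping with use_orig_params=True, each name from
# fsdp_model.named_parameters() is cleaned with `clean_tensor_name` (stripping FSDP
# prefixes) and looked up in the first dict, so the returned groups reference the
# wrapped parameters while preserving the per-group settings such as lr and eps.
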
def prepare_tp_module(
    model: torch.nn.Module,
    optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
    tp_config: TPConfig,
) -> None:
    """Prepare a module (assumed ComposerModel) for use with tensor parallel."""
    optimizers_tuple = ensure_tuple(optimizers)
    if len(optimizers_tuple) != 1:
        raise NotImplementedError(f'Only one optimizer is supported; found {len(optimizers_tuple)} optimizers')

    optim = optimizers_tuple[0]
    if len(optim.param_groups) > 1:
        raise RuntimeError('Multiple optimizer groups are not supported with tensor parallelism.',)
    if len(optim.param_groups[0]['params']) != len(list(model.parameters())):
        raise ValueError(
            'Passing in a subset of model parameters to the optimizer is not supported with tensor parallelism.',
        )

    from torch.distributed.tensor.parallel import parallelize_module

    device_mesh = tp_config.device_mesh
    assert device_mesh is not None  # For type checking, set in State.__init__
    parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan=tp_config.layer_plan,
    )

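# Illustrative usage sketch (not part of this module): the Trainer normally builds
# `tp_config` from a parallelism config and State fills in `device_mesh` before this
# function is called. The layer plan maps submodule names to tensor-parallel styles;
# the module names and the `tensor_parallel_degree` argument below are hypothetical.
#
#     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
#     tp_config = TPConfig(
#         layer_plan={'ffn.up_proj': ColwiseParallel(), 'ffn.down_proj': RowwiseParallel()},
#         tensor_parallel_degree=2,
#     )
#     prepare_tp_module(model, optimizers=optimizer, tp_config=tp_config)
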
def prepare_fsdp_module(
    model: torch.nn.Module,
    optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
    fsdp_config: FSDPConfig,
    precision: Optional[Union[str, Precision]] = None,
    device: Optional[Union[str, Device]] = None,
    auto_microbatching: bool = False,
    te_rng_seed: int = 1234,
) -> tuple[list, dict]:
    """Prepare a module (assumed ComposerModel) and optimizer for use with :class:`torch.distributed.fsdp.FullyShardedDataParallel`.

    Args:
        model (torch.nn.Module): The model to wrap.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer], optional): The optimizer for `model`,
            assumed to have a single param group := model.parameters().
        fsdp_config (FSDPConfig): The FSDP config.
        precision (Precision): The precision being used by the Trainer, used to fill in defaults for FSDP `mixed_precision` settings.
        device (str | Device, optional): The device being used by the Trainer.
        auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
        te_rng_seed (int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
    """
    device = get_device(device)

    if precision is None:
        precision = Precision.AMP_FP16 if isinstance(device, DeviceGPU) else Precision.FP32
    elif isinstance(precision, str):
        precision = Precision(precision)
    _validate_precision(precision, device)

    # Check sync_module_states is True for mixed initialization or HSDP
    if fsdp_config.sync_module_states == False:
        rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
        all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
        dist.all_reduce(all_ranks_meta, reduce_operation='MIN')
        any_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
        dist.all_reduce(any_ranks_meta, reduce_operation='MAX')
        if all_ranks_meta.item() == 0 and any_ranks_meta.item() == 1:
            raise ValueError(
                'Detected mixed initialization where some ranks have model on cpu or '
                'gpu and some ranks are on meta. Either keep all ranks on the same '
                "device or set parallelism_config['fsdp']['sync_module_states'] = True. Otherwise, "
                'some weights may be randomly initialized when loading a checkpoint.',
            )

    # Handles of FSDP sync hooks if automicrobatching is on
    hook_handles = []

    # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
    # may happen when close to memory limit or with uneven memory usage across ranks. Since we
    # need to do this before the model weights are gathered for the next FSDP block, we wrap every
    # FSDP block with a hook that checks if any other rank OOMed.
    def sync_hook(*args):
        # Check if any other rank hit an OOM
        found_cuda_oom_tensor = device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
        dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
        found_cuda_oom = found_cuda_oom_tensor.item()
        # Signal current rank is still in batch
        all_ranks_finished_tensor = device.tensor_to_device(torch.tensor([0], dtype=torch.uint8))
        dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

        if found_cuda_oom == 1:
            raise RuntimeError('CUDA out of memory encountered on a different rank')

    # Necessary variables for optimizers with multiple param groups in FSDP
    param_name_to_group_num = None
    group_num_to_opt_group_info = None
    single_param_group_opt_info = None

    if optimizers:
        optimizers_tuple = ensure_tuple(optimizers)
        if len(optimizers_tuple) != 1:
            raise NotImplementedError(f'Only one optimizer is supported; found {len(optimizers_tuple)} optimizers')

        # clearing optimizer param groups and state
        # that will be recreated at the end of prepare_fsdp_module
        optim = optimizers_tuple[0]

        # Simplest case - single param group & all model params stored in optimizer
        if len(optim.param_groups) == 1 and len(optim.param_groups[0]['params']) == len(list(model.parameters())):
            single_param_group_opt_info = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
        elif fsdp_config.use_orig_params:
            # this code block stores information about param groups pre-fsdp wrapping in order to recreate them post-wrapping
            # to do so, it relies on the ptrs of the model.parameters() in a model and the names of the params
            # for this to work, use_orig_params=True, as we need the names of the params post-wrapping
            # TP is not supported, as the underlying parameters in the model differ from the params in the param groups after being dtensorified
            ptr_to_param_name = {id(p): n for n, p in model.named_parameters()}
            param_name_to_group_num = {}
            group_num_to_opt_group_info = {}
            for group_num in range(len(optim.param_groups)):
                # Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
                # group = optim.param_groups[group_num]
                for param_num in range(len(optim.param_groups[group_num]['params'])):
                    param_ptr = id(optim.param_groups[group_num]['params'][param_num])
                    if param_ptr not in ptr_to_param_name:
                        raise ValueError('The same model must be passed to the optimizer and trainer.')
                    param_name_to_group_num[ptr_to_param_name[param_ptr]] = group_num

                # this includes optimizer-specific values like lr, eps
                # this will be used as the kwargs for the optim param groups later
                optimizer_specific_group_info = {
                    k: v for k, v in optim.param_groups[group_num].items() if k != 'params'
                }
                group_num_to_opt_group_info[group_num] = optimizer_specific_group_info
        else:
            if len(optim.param_groups) > 1:
                raise RuntimeError('Multiple optimizer groups with FSDP are not supported with use_orig_params=False.',)

            if len(optim.param_groups[0]['params']) != len(list(model.parameters())):
                raise ValueError(
                    'Passing in a subset of model parameters to the optimizer is not supported with use_orig_params=False.',
                )

        optim.param_groups.clear()
        optim.state.clear()

    sharding_map_key = fsdp_config.sharding_strategy.upper()
    sharding_strategy = SHARDING_MAP[sharding_map_key]

    kwargs = {}
    if version.parse(
        torch.__version__.split('.dev')[0],
    ) >= version.parse('2.2.0') and fsdp_config.device_mesh is not None:
        if fsdp_config.process_group is not None:
            warnings.warn(
                'process_group and device_mesh are set for FSDP, so ignoring device_mesh. Please set process_group to None.',
            )
        else:
            ndim = fsdp_config.device_mesh.ndim
            if ndim == 1 and sharding_strategy == ShardingStrategy.HYBRID_SHARD:
                sharding_strategy = ShardingStrategy.FULL_SHARD
                warnings.warn('HYBRID_SHARD is not supported with 1D device mesh. Using FULL_SHARD instead.')
            elif ndim == 1 and sharding_strategy == ShardingStrategy._HYBRID_SHARD_ZERO2:
                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
                warnings.warn('_HYBRID_SHARD_ZERO2 is not supported with 1D device mesh. Using SHARD_GRAD_OP instead.')
            elif ndim == 2 and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
                sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
                warnings.warn('SHARD_GRAD_OP is not supported with 2D device mesh. Using _HYBRID_SHARD_ZERO2 instead.')
            elif ndim == 2 and sharding_strategy == ShardingStrategy.FULL_SHARD:
                sharding_strategy = ShardingStrategy.HYBRID_SHARD
                warnings.warn('FULL_SHARD is not supported with 2D device mesh. Using HYBRID_SHARD instead.')
            kwargs['device_mesh'] = fsdp_config.device_mesh

    cpu_offload = get_cpu_offload(cpu_offload=fsdp_config.cpu_offload)

    mixed_precision = fsdp_config.mixed_precision
    keep_low_precision_grads = fsdp_config.keep_low_precision_grads
    mixed_precision, _, _, _ = get_mixed_precision(
        precision,
        mixed_precision=mixed_precision,
        keep_low_precision_grads=keep_low_precision_grads,
    )

    process_group = None
    if fsdp_config.process_group is not None:
        process_group_dict = {'process_group': fsdp_config.process_group}
        process_group = set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group']

    backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config.backward_prefetch.upper()]
    activation_checkpointing = fsdp_config.activation_checkpointing
    activation_cpu_offload = fsdp_config.activation_cpu_offload
    sync_module_states = fsdp_config.sync_module_states
    forward_prefetch = fsdp_config.forward_prefetch
    limit_all_gathers = fsdp_config.limit_all_gathers
    ignored_modules = fsdp_config.ignored_modules
    state_dict_type = fsdp_config.state_dict_type
    activation_checkpointing_reentrant = fsdp_config.activation_checkpointing_reentrant
    te_checkpoint_wrapper = fsdp_config.te_checkpoint_wrapper if precision == Precision.AMP_FP8 else False
    te_shard_fp8_weight = fsdp_config.te_shard_fp8_weight if precision == Precision.AMP_FP8 else False
    sharded_ckpt_prefix_dir = fsdp_config.sharded_ckpt_prefix_dir
    use_orig_params = fsdp_config.use_orig_params

    fsdp_obj_named_modules = {}

    # We choose to not wrap the ComposerModel directly, but instead wrap any submodules like `ComposerModel.model`
    # This makes it safer to call ComposerModel-specific functions like 'eval_forward' that
    # may make calls to sharded submodules. If we only wrap the submodules, then any call that ComposerModel makes
    # to a FSDP-wrapped submodule's `forward()` function will be safe and all-gather the necessary weights before `forward()`.
    for obj_name, obj in model.named_children():
        if not isinstance(obj, (Metric, MetricCollection)):

            # Skip wrapping submodules which are explicitly marked with no wrap
            if hasattr(obj, '_fsdp_wrap') and not bool(obj._fsdp_wrap):
                continue

            # A dictionary of all tied parameter pointers to (module, attr) tuples
            tied_pointers = {}

            # Goes through all modules finding which weights have the same pointers
            for mod in obj.modules():
                for attr_name, attr in mod.named_parameters(recurse=False):
                    ptr = id(attr)
                    mod_attr_list = tied_pointers.get(ptr, [])
                    mod_attr_list.append((mod, attr_name))
                    tied_pointers[ptr] = mod_attr_list

            # Dictionary mapping the source module to a list of (target module, source attr, target attr) tuples
            source_mod_to_mod_attr = {}
            for mod_attr_list in tied_pointers.values():
                # If there is only one module for this pointer, then there is no weight tying
                if len(mod_attr_list) == 1:
                    continue

                # Arbitrarily choose the first module as the source module
                first_mod, first_attr = mod_attr_list[0]
                source_mod_to_mod_attr[first_mod] = [
                    (target_mod, first_attr, dest_attr) for target_mod, dest_attr in mod_attr_list[1:]
                ]

            # Clean up no longer needed module references for memory safety
            del tied_pointers

            def _param_init_fn(module: torch.nn.Module) -> None:
                # If we do not have any parameters or buffers on meta device managed by this module directly,
                # we do not need to call the parameter init function. It is assumed that whatever process moved
                # the parameters off of meta device initialized them. We expect this to occur if we have tied
                # weights, as the second module will already have the weights initialized.
                is_meta = any(param.is_meta for param in module.parameters(recurse=False)
                             ) or any(buffer.is_meta for buffer in module.buffers(recurse=False))
                if not is_meta:
                    return

                # Move all parameters and buffers to the current device
                module.to_empty(device=f'cuda:{torch.cuda.current_device()}', recurse=False)

                # Redo weight tying, which will have been broken by the above line that moves parameters off of meta device
                if module in source_mod_to_mod_attr:
                    for target_mod, first_attr, dest_attr in source_mod_to_mod_attr[module]:
                        setattr(target_mod, dest_attr, getattr(module, first_attr))

                # Run the specified initialization
                if hasattr(obj, 'param_init_fn') and isinstance(obj.param_init_fn, Callable):
                    obj.param_init_fn(module)
                elif hasattr(module, 'reset_parameters') and isinstance(module.reset_parameters, Callable):
                    module.reset_parameters()
                else:
                    raise ValueError(
                        f'Object `{obj_name}` does not have a ``param_init_fn`` or a ``reset_parameters`` function. '
                        'This leaves parameters without initialization. Please add a ``param_init_fn`` or ``reset_parameters`` '
                        f'to module `{obj_name}`.',
                    )

            def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
                ret = False
                if hasattr(module, '_fsdp_wrap'):
                    ret = bool(module._fsdp_wrap)
                elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
                    ret = obj.fsdp_wrap_fn(module)
                    if isinstance(ret, dict):
                        ret = set_custom_fsdp_module_kwargs(ret, process_group_cache)
                return ret

            _auto_wrap_policy = CustomPolicy(lambda_fn)

            fsdp_obj = FullyShardedDataParallel(
                obj,
                process_group=process_group,
                sharding_strategy=sharding_strategy,
                auto_wrap_policy=_auto_wrap_policy,  # type: ignore FSDP type bug
                cpu_offload=cpu_offload,
                mixed_precision=mixed_precision,
                backward_prefetch=backward_prefetch,
                ignored_modules=ignored_modules,
                param_init_fn=_param_init_fn,
                device_id=torch.cuda.current_device(),
                sync_module_states=sync_module_states,
                forward_prefetch=forward_prefetch,
                limit_all_gathers=limit_all_gathers,
                use_orig_params=use_orig_params,
                **kwargs,
            )

            if te_shard_fp8_weight:
                try:
                    from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp
                except ModuleNotFoundError:
                    raise ModuleNotFoundError('Please install transformer-engine to use prepare_te_modules_for_fsdp')
                log.info('Calling prepare_te_modules_for_fsdp to enable TE weights sharding')
                prepare_te_modules_for_fsdp(fsdp_obj)

            # The following sync hooks are added to prevent FSDP deadlocks that are caused when some ranks OOM
            # and other ranks do not OOM, leading to OOMing ranks calling all_reduce to wait on the non-OOMing
            # ranks and the non-OOMing ranks calling all_gather_base to continue with FSDP training:
            #
            # forward_pre_hook: before forwards of FSDP modules
            # full_backward_pre_hook: before backwards of FSDP modules
            # full_backward_hook: before a prefetched unshard called by FSDP's `post_backward_reshard`
            if auto_microbatching:
                for _, module in fsdp_obj.named_modules():
                    if isinstance(module, FullyShardedDataParallel):
                        hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True))
                        hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True))
                    else:
                        hook_handles.append(module.register_full_backward_hook(sync_hook))
                fsdp_obj_named_modules.update(dict(fsdp_obj.named_modules()))

            if hasattr(fsdp_obj, '_exec_order_data'):
                if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
                    fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config.forward_prefetch_limit
                else:
                    warnings.warn(
                        'FSDP._exec_order_data does not have attribute _forward_prefetch_limit '
                        'which is unexpected and will result in `forward_prefetch_limit` from FSDP '
                        'config being ignored. Please open an issue to Composer to report this.',
                    )
                if hasattr(fsdp_obj._exec_order_data, '_backward_prefetch_limit'):
                    fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config.backward_prefetch_limit
                else:
                    warnings.warn(
                        'FSDP._exec_order_data does not have attribute _backward_prefetch_limit '
                        'which is unexpected and will result in `backward_prefetch_limit` from FSDP '
                        'config being ignored. Please open an issue to Composer to report this.',
                    )
            else:
                warnings.warn(
                    'FSDP does not have attribute _exec_order_data which is unexpected and will '
                    'result in `forward_prefetch_limit` and `backward_prefetch_limit` from FSDP '
                    'config being ignored. Please open an issue to Composer to report this.',
                )

            # Activation Checkpointing
            if activation_checkpointing or activation_cpu_offload:
                # FP8 TE requires using the TE checkpoint function, FSDP activation checkpointing only works with TE non-reentrant checkpointing
                if te_checkpoint_wrapper:
                    assert not activation_checkpointing_reentrant, 'TE checkpoint only works with non-reentrant checkpointing'
                if not activation_checkpointing_reentrant:
                    if te_checkpoint_wrapper:
                        try:
                            import transformer_engine.pytorch as te
                        except ModuleNotFoundError:
                            raise ModuleNotFoundError('Please install transformer-engine to use TE checkpoint wrapper',)

                        # RNG state tracker for checkpointing
                        CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
                        CUDA_RNG_STATES_TRACKER.add('fsdp-rng', te_rng_seed)

                        def get_cuda_rng_tracker():
                            return CUDA_RNG_STATES_TRACKER

                        first_wrap_fn = lambda m: checkpoint_wrapper(
                            m,
                            context_fn=te.distributed.get_activation_recompute_contexts,
                            checkpoint_fn=te.distributed.checkpoint,
                            use_reentrant=False,
                            get_rng_state_tracker=get_cuda_rng_tracker,
                        )
                    else:
                        first_wrap_fn = lambda m: checkpoint_wrapper(
                            m,
                            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                        ) if activation_checkpointing else (lambda module: module)
                    second_wrap_fn = (
                        lambda module: offload_wrapper(
                            first_wrap_fn(module) if activation_checkpointing else module,  # type: ignore reportGeneralTypeIssues
                        )
                    ) if activation_cpu_offload else first_wrap_fn
                else:
                    first_wrap_fn = lambda m: checkpoint_wrapper(
                        m,
                        checkpoint_impl=CheckpointImpl.REENTRANT,
                    ) if activation_checkpointing else (lambda module: module)
                    second_wrap_fn = (
                        lambda module: offload_wrapper(
                            first_wrap_fn(module) if activation_checkpointing else module,  # type: ignore reportGeneralTypeIssues
                        )
                    ) if activation_cpu_offload else first_wrap_fn

                # Choose which modules to activation checkpoint according to the following priority:
                # If module has attribute `module._activation_checkpointing = ...`, always respect it
                # Otherwise checkpoint if root object `obj.activation_checkpointing_fn(module)` is true
                def _check_fn(module: torch.nn.Module) -> bool:
                    if isinstance(module, FullyShardedDataParallel):
                        return False
                    if hasattr(module, '_activation_checkpointing'):
                        return bool(module._activation_checkpointing)
                    if hasattr(
                        obj,
                        'activation_checkpointing_fn',
                    ) and isinstance(obj.activation_checkpointing_fn, Callable):
                        return obj.activation_checkpointing_fn(module)
                    return False

                apply_activation_checkpointing(
                    fsdp_obj,
                    checkpoint_wrapper_fn=second_wrap_fn,  # type: ignore
                    check_fn=_check_fn,  # type: ignore
                )

            setattr(model, obj_name, fsdp_obj)

    # Print FSDP wrapped model and FSDP config if `verbose=True`
    if fsdp_config.verbose:
        log.info(f'FSDP: Wrapped model: {model}')
        log.info(f'FSDP: Using sharding_strategy={sharding_strategy}')
        log.info(f'FSDP: Using cpu_offload={cpu_offload}')
        log.info(f'FSDP: Using mixed_precision={mixed_precision}')
        log.info(f'FSDP: Using backward_prefetch={backward_prefetch}')
        log.info(f'FSDP: Using activation_checkpointing={activation_checkpointing}')
        log.info(f'FSDP: Using activation_cpu_offload={activation_cpu_offload}')
        log.info(f'FSDP: Using te_checkpoint_wrapper={te_checkpoint_wrapper}')
        log.info(f'FSDP: Using te_shard_fp8_weight={te_shard_fp8_weight}')
        log.info(f'FSDP: Using sync_module_states={sync_module_states}')
        log.info(f'FSDP: Using forward_prefetch={forward_prefetch}')
        log.info(f'FSDP: Using limit_all_gathers={limit_all_gathers}')
        log.info(f'FSDP: Using state_dict_type={state_dict_type}')
        log.info(f'FSDP: Using sharded_ckpt_prefix_dir={sharded_ckpt_prefix_dir}')

    # Rebuild optimizer now that parameters are sharded
    if optimizers:
        optim = ensure_tuple(optimizers)[0]
        optim.param_groups.clear()

        if single_param_group_opt_info is not None:
            single_param_group_opt_info.update({'params': list(model.parameters())})
            optim.add_param_group(single_param_group_opt_info)
        elif fsdp_config.use_orig_params:
            assert param_name_to_group_num is not None
            assert group_num_to_opt_group_info is not None

            param_groups = _recreate_fsdp_param_groups_from_unwrapped_opt_info(
                model.named_parameters(),
                param_name_to_group_num,
                group_num_to_opt_group_info,
            )
            for param_group in param_groups:
                optim.add_param_group(param_group)

    return hook_handles, fsdp_obj_named_modules
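
# Illustrative usage sketch (not part of this module): the Trainer normally calls
# prepare_fsdp_module after the model and distributed process group are set up. The
# FSDPConfig field values and the `amp_bf16` precision below are examples, not
# recommendations; unspecified fields fall back to FSDPConfig defaults.
#
#     fsdp_config = FSDPConfig(sharding_strategy='FULL_SHARD', use_orig_params=True)
#     hook_handles, fsdp_named_modules = prepare_fsdp_module(
#         model,
#         optimizers=optimizer,
#         fsdp_config=fsdp_config,
#         precision='amp_bf16',
#         device='gpu',
#         auto_microbatching=False,
#     )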