Accelerate documentation

Fully Sharded Data Parallel utilities

You are viewing version v1.2.1. A newer version, v1.3.0, is available.

enable_fsdp_ram_efficient_loading

accelerate.utils.enable_fsdp_ram_efficient_loading

( )

Enables RAM efficient loading of Hugging Face models for FSDP in the environment.

disable_fsdp_ram_efficient_loading

accelerate.utils.disable_fsdp_ram_efficient_loading

( )

Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
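Both helpers act on the process environment, so one way to see what they do is to diff `os.environ` before and after each call. The sketch below assumes nothing about which variables are touched; `env_changes` and `show_ram_efficient_loading_toggle` are hypothetical helpers written for this illustration, not part of Accelerate.

```python
import os


def env_changes(before: dict, after: dict) -> dict:
    """Return {name: (old, new)} for every variable whose value differs."""
    names = set(before) | set(after)
    return {
        name: (before.get(name), after.get(name))
        for name in sorted(names)
        if before.get(name) != after.get(name)
    }


def show_ram_efficient_loading_toggle() -> None:
    # Imported lazily so the helper above can be used without accelerate installed.
    from accelerate.utils import (
        disable_fsdp_ram_efficient_loading,
        enable_fsdp_ram_efficient_loading,
    )

    snapshot = dict(os.environ)
    enable_fsdp_ram_efficient_loading()
    print("after enable:", env_changes(snapshot, dict(os.environ)))

    snapshot = dict(os.environ)
    disable_fsdp_ram_efficient_loading()
    print("after disable:", env_changes(snapshot, dict(os.environ)))


if __name__ == "__main__":
    show_ram_efficient_loading_toggle()
```

Because the setting lives in the environment, it presumably needs to be applied before the model is loaded in order to have any effect.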

merge_fsdp_weights

accelerate.utils.merge_fsdp_weights

( checkpoint_dir: str, output_path: str, safe_serialization: bool = True, remove_checkpoint_dir: bool = False )

Parameters

  • checkpoint_dir (str) — The directory containing the FSDP checkpoints (can be either the model or optimizer).
  • output_path (str) — The path to save the merged checkpoint.
  • safe_serialization (bool, optional, defaults to True) — Whether to save the merged weights with safetensors (recommended).
  • remove_checkpoint_dir (bool, optional, defaults to False) — Whether to remove the checkpoint directory after merging.

Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if SHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors if safe_serialization is True, otherwise to {output_path}/pytorch_model.bin.

Note: this is a CPU-bound process.
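A minimal sketch of how a call might look, assuming a sharded model checkpoint already exists on disk. `expected_merged_file` and `merge_sharded_checkpoint` are hypothetical helpers introduced here; the first encodes one reading of the output-location rule described above, and any paths passed in are placeholders.

```python
from pathlib import Path


def expected_merged_file(output_path: str, safe_serialization: bool = True) -> Path:
    """Hypothetical helper: where the merged weights should land per the description above."""
    filename = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    return Path(output_path) / filename


def merge_sharded_checkpoint(checkpoint_dir: str, output_path: str) -> Path:
    """Merge a sharded FSDP checkpoint and return the expected output file."""
    # Imported lazily so the path helper can be used without accelerate installed.
    from accelerate.utils import merge_fsdp_weights

    merge_fsdp_weights(
        checkpoint_dir,
        output_path,
        safe_serialization=True,      # write safetensors (the documented default)
        remove_checkpoint_dir=False,  # keep the shards until the result is verified
    )
    return expected_merged_file(output_path, safe_serialization=True)
```

Keeping `remove_checkpoint_dir=False` until the merged file has been loaded successfully is a cautious default, since the shards cannot be recovered once deleted.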

FullyShardedDataParallelPlugin

class accelerate.FullyShardedDataParallelPlugin

( sharding_strategy: typing.Union[str, ForwardRef('torch.distributed.fsdp.ShardingStrategy')] = None, backward_prefetch: typing.Union[str, ForwardRef('torch.distributed.fsdp.BackwardPrefetch')] = None, mixed_precision_policy: typing.Union[dict, ForwardRef('torch.distributed.fsdp.MixedPrecision'), NoneType] = None, auto_wrap_policy: typing.Union[typing.Callable, typing.Literal['transformer_based_wrap', 'size_based_wrap', 'no_wrap'], NoneType] = None, cpu_offload: typing.Union[bool, ForwardRef('torch.distributed.fsdp.CPUOffload')] = None, ignored_modules: typing.Optional[typing.Iterable[torch.nn.modules.module.Module]] = None, state_dict_type: typing.Union[str, ForwardRef('torch.distributed.fsdp.StateDictType')] = None, state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedStateDictConfig'), NoneType] = None, optim_state_dict_config: typing.Union[ForwardRef('torch.distributed.fsdp.FullOptimStateDictConfig'), ForwardRef('torch.distributed.fsdp.ShardedOptimStateDictConfig'), NoneType] = None, limit_all_gathers: bool = True, use_orig_params: bool = None, param_init_fn: typing.Optional[typing.Callable[[torch.nn.modules.module.Module], NoneType]] = None, sync_module_states: bool = None, forward_prefetch: bool = None, activation_checkpointing: bool = None, cpu_ram_efficient_loading: bool = None, transformer_cls_names_to_wrap: typing.Optional[typing.List[str]] = None, min_num_params: typing.Optional[int] = None )

Parameters

  • sharding_strategy (Union[str, torch.distributed.fsdp.ShardingStrategy], defaults to 'FULL_SHARD') — Sharding strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy.
  • backward_prefetch (Union[str, torch.distributed.fsdp.BackwardPrefetch], defaults to 'NO_PREFETCH') — Backward prefetch strategy to use. Should be either a str or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch.
  • mixed_precision_policy (Optional[Union[dict, torch.distributed.fsdp.MixedPrecision]], defaults to None) — A config to enable mixed precision training with FullyShardedDataParallel. If passing in a dict, it should have the following keys: param_dtype, reduce_dtype, and buffer_dtype.
  • auto_wrap_policy (Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]], defaults to NO_WRAP) — A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of transformer_based_wrap, size_based_wrap, or no_wrap. See torch.distributed.fsdp.wrap.size_based_auto_wrap_policy for an example of what a callable policy should look like.
  • cpu_offload (Union[bool, torch.distributed.fsdp.CPUOffload], defaults to False) — Whether to offload parameters to CPU. Should be either a bool or an instance of torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload.
  • ignored_modules (Optional[Iterable[torch.nn.Module]], defaults to None) — A list of modules to ignore when wrapping with FSDP.
  • state_dict_type (Union[str, torch.distributed.fsdp.StateDictType], defaults to 'FULL_STATE_DICT') — State dict type to use. If a string, it must be one of full_state_dict, local_state_dict, or sharded_state_dict.
  • state_dict_config (Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]], defaults to None) — State dict config to use. Determined from state_dict_type if not passed in.
  • optim_state_dict_config (Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]], defaults to None) — Optimizer state dict config to use. Determined from state_dict_type if not passed in.
  • limit_all_gathers (bool, defaults to True) — Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries.
  • use_orig_params (bool, defaults to False) — Whether to use the original parameters for the optimizer.
  • param_init_fn (Optional[Callable[[torch.nn.Module], None]], defaults to None) — A Callable[[torch.nn.Module], None] that specifies how modules that are currently on the meta device should be initialized onto an actual device. Only applicable when sync_module_states is True. Defaults to a lambda that calls to_empty on the module.
  • sync_module_states (bool, defaults to False) — Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. Defaults to False unless cpu_ram_efficient_loading is True, in which case it is forcibly enabled.
  • forward_prefetch (bool, defaults to False) — Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing in the forward pass. Only use with static graphs.
  • activation_checkpointing (bool, defaults to False) — A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time for reduced memory usage.
  • cpu_ram_efficient_loading (bool, defaults to None) — If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. Only applicable for Transformers. When using this, sync_module_states needs to be True.
  • transformer_cls_names_to_wrap (Optional[List[str]], defaults to None) — A list of transformer layer class names to wrap. Only applicable when auto_wrap_policy is transformer_based_wrap.
  • min_num_params (Optional[int], defaults to None) — The minimum number of parameters a module must have to be wrapped. Only applicable when auto_wrap_policy is size_based_wrap.

This plugin is used to enable fully sharded data parallelism.
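As a rough sketch of how the parameters above fit together, the plugin can be built from keyword arguments and handed to `Accelerator` through its `fsdp_plugin` argument. The values below are illustrative choices, not recommendations; `PLUGIN_KWARGS` and `build_accelerator` are hypothetical names, and the script is assumed to be started by a distributed launcher such as `accelerate launch`.

```python
# Keyword arguments taken from the parameter list above; the values are illustrative.
PLUGIN_KWARGS = {
    "sharding_strategy": "FULL_SHARD",
    "auto_wrap_policy": "size_based_wrap",
    "min_num_params": 1_000_000,  # arbitrary threshold; only used by size_based_wrap
    "state_dict_type": "sharded_state_dict",
    "cpu_offload": False,
    "limit_all_gathers": True,
    "use_orig_params": True,
}


def build_accelerator():
    """Create an Accelerator configured for FSDP (assumes torch and accelerate are installed)."""
    # Imported lazily so the configuration above can be inspected without these packages.
    import torch
    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    # When a dict is passed, it carries the three keys named in the parameter list.
    mixed_precision_policy = {
        "param_dtype": torch.bfloat16,
        "reduce_dtype": torch.bfloat16,
        "buffer_dtype": torch.bfloat16,
    }
    plugin = FullyShardedDataParallelPlugin(
        mixed_precision_policy=mixed_precision_policy,
        **PLUGIN_KWARGS,
    )
    return Accelerator(fsdp_plugin=plugin)
```

Choosing `sharded_state_dict` here means checkpoints are written as shards, which is the case `merge_fsdp_weights` is described as handling.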

set_auto_wrap_policy

( model )

Given model, creates an auto_wrap_policy based on the passed-in policy and on whether transformer_cls_to_wrap can be used.

set_mixed_precision

( mixed_precision, buffer_autocast = False, override = False )

Sets the mixed precision policy for FSDP.

set_state_dict_type

( state_dict_type = None )

Set the state dict config based on the StateDictType.
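One plausible use is to train with a sharded state dict and switch to a full one just before the final save. The sketch below assumes the active plugin is reachable as `accelerator.state.fsdp_plugin`; `switch_to_full_state_dict` is a hypothetical helper, not an Accelerate API.

```python
def switch_to_full_state_dict(accelerator) -> None:
    """Point the active FSDP plugin at FULL_STATE_DICT, e.g. before a final save."""
    plugin = accelerator.state.fsdp_plugin
    if plugin is None:
        raise RuntimeError("This accelerator was not configured with an FSDP plugin.")
    # The state dict config is re-derived from the new type.
    plugin.set_state_dict_type("FULL_STATE_DICT")
```

Calling it right before saving the final model keeps intermediate checkpoints sharded while producing a single consolidated state dict at the end.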
