huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Unable to specify HYBRID_SHARD for FSDP which requires process group or device_mesh to be passed #2671

Open npuichigo opened 2 months ago

npuichigo commented 2 months ago

System Info

- `Accelerate` version: 0.29.1
- Platform: Linux-5.19.0-46-generic-x86_64-with-glibc2.35
- `accelerate` bash location: /home/yuchao/miniconda3/envs/TorchTTS/bin/accelerate
- Python version: 3.10.13
- Numpy version: 1.23.5
- PyTorch version (GPU?): 2.2.2+cu118 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 125.48 GB
- GPU type: NVIDIA GeForce RTX 4090
- `Accelerate` default config:
fsdp_plugin:
  sharding_strategy: 4

Reproduction

Set the FSDP sharding strategy to `ShardingStrategy.HYBRID_SHARD` or `ShardingStrategy._HYBRID_SHARD_ZERO2` and prepare a model (see the sketch below).
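
A minimal sketch of this setup, assuming a multi-GPU run launched via `accelerate launch --use_fsdp` (the tiny model is only a stand-in):

```python
# Minimal reproduction sketch; the Linear module is only a stand-in.
import torch
from torch.distributed.fsdp import ShardingStrategy
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # or ShardingStrategy._HYBRID_SHARD_ZERO2
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(8, 8)
# Fails inside FSDP with "Manual wrapping with ShardingStrategy.HYBRID_SHARD
# requires explicit specification of process group or device_mesh."
model = accelerator.prepare(model)
```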

File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/accelerate/accelerator.py", line 1434, in prepare_model
    return self.prepare_model(obj, device_placement=device_placement)
  File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/accelerate/accelerator.py", line 1434, in prepare_model
    model = FSDP(model, **kwargs)
  File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 448, in __init__
    model = FSDP(model, **kwargs)
  File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 448, in __init__
    _init_process_group_state(
  File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 118, in _init_process_group_state
    _init_process_group_state(
  File "/home/yuchao/miniconda3/envs/TorchTTS/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 118, in _init_process_group_state
    raise ValueError(
ValueError: ('Manual wrapping with ShardingStrategy.HYBRID_SHARD', 'requires explicit specification of process group or device_mesh.')
    raise ValueError(
ValueError: ('Manual wrapping with ShardingStrategy.HYBRID_SHARD', 'requires explicit specification of process group or device_mesh.')

Expected behavior

Accelerate should provide a way to pass the process group or device_mesh, since they are parameters of FSDP: https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel
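
In the meantime, the only way to pass a device_mesh seems to be wrapping with FSDP manually rather than relying on `accelerator.prepare`. A rough workaround sketch, assuming torch 2.2, an already-initialized process group, and an illustrative 2-node x 8-GPU layout:

```python
# Workaround sketch: wrap with FSDP directly (outside accelerator.prepare)
# so a 2D device_mesh can be supplied. Assumes torch>=2.2 and that
# torch.distributed is already initialized; the layout below is illustrative.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

num_nodes, gpus_per_node = 2, 8  # illustrative layout
# 2D mesh: outer dim replicates across nodes, inner dim shards within a node.
mesh = init_device_mesh("cuda", (num_nodes, gpus_per_node))

model = torch.nn.Linear(8, 8)  # stand-in module
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,
    device_id=torch.cuda.current_device(),
)
```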

muellerzr commented 2 months ago

I think the issue here is you're not setting an auto_wrap_policy, which is required? cc @pacman100 to confirm

npuichigo commented 2 months ago

An auto_wrap_policy may be another way, but I can find no way to set process_group or device_mesh, which are part of the FSDP interface:

torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)
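
For hybrid strategies, `process_group` can also be given as a pair of process groups (the group to shard within and the group to replicate across). A rough sketch of building such a pair with torch.distributed, with an illustrative node layout:

```python
# Sketch of the (shard, replicate) process-group pair that HYBRID_SHARD
# accepts via `process_group`; the node layout is illustrative and
# torch.distributed must already be initialized.
import torch.distributed as dist

gpus_per_node = 8  # illustrative
world_size = dist.get_world_size()
rank = dist.get_rank()
num_nodes = world_size // gpus_per_node

# Intra-node groups (sharding). Note: every rank must call new_group for
# every group, even the ones it does not belong to.
shard_group = None
for node in range(num_nodes):
    ranks = list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
    group = dist.new_group(ranks)
    if rank in ranks:
        shard_group = group

# Inter-node groups (replication): same local rank index across nodes.
replicate_group = None
for local_rank in range(gpus_per_node):
    ranks = list(range(local_rank, world_size, gpus_per_node))
    group = dist.new_group(ranks)
    if rank in ranks:
        replicate_group = group

# FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD,
#      process_group=(shard_group, replicate_group))
```
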
muellerzr commented 2 months ago

True, we need to enable this in the FullyShardedDataParallelPlugin and make it work nicely when things are set via accelerate launch. (In which case, this would not be a param that can be set there, for obvious reasons.)

lidingsnyk commented 1 month ago

Sounds like one implication here is that the transformers Trainer's --fsdp hybrid_shard / --fsdp hybrid_shard_zero2 can't run successfully either, since process_group is not among the Trainer arguments.

github-actions[bot] commented 4 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.