huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

[RFC] Support FSDP2 #3231

Open kmehant opened 2 weeks ago

kmehant commented 2 weeks ago

What does this PR do?

Prototype implementation for porting from FSDP V1 to FSDP V2. There are a couple of open questions in this PR that need comments and discussion:

  1. Do we want to keep FSDP V1 as is and add an experimental FSDP V2 path in parallel?
  2. If we maintain both versions, should we keep separate FSDP plugins and distributed types for each version?
  3. For HF/transformers users who configure FSDP through `fsdp_config`, how do we want to let them choose between the two versions?
  4. How do we want to prepare the 2D device mesh for HSDP? Should it be an input from the user? (One option is sketched right after this list.)
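
For context, here is a minimal sketch of the API gap this port has to bridge, and one way the 2D HSDP mesh from question 4 could be built. This is not code from this PR: `replicate_size`, `shard_size`, and the `model.layers` attribute are placeholders, and the FSDP2 import assumes torch >= 2.4, where `fully_shard` still lives under the private `_composable` namespace.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

# FSDP V1: a wrapper module; parameters are flattened into FlatParameters.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def shard_v1(model: nn.Module) -> nn.Module:
    return FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

# FSDP V2: composable fully_shard; the model keeps its original class and
# its parameters become DTensors sharded over the given mesh.
from torch.distributed._composable.fsdp import fully_shard

def shard_v2(model: nn.Module, replicate_size: int, shard_size: int) -> nn.Module:
    # A 2D mesh makes fully_shard run HSDP: replicate across the first
    # dim, shard across the second. Whether accelerate builds this mesh
    # itself or takes it as user input is exactly question 4 above.
    mesh = init_device_mesh(
        "cuda",
        (replicate_size, shard_size),
        mesh_dim_names=("replicate", "shard"),
    )
    # Apply to each transformer block first, then to the root module.
    # `model.layers` is a placeholder for the real block container.
    for block in model.layers:
        fully_shard(block, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model
```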

Preliminary run of this PR and results

The current version of the PR has been tested for basic functionality (full shard) and compared with the previous FSDP V1 implementation.

| Key | Value |
| --- | --- |
| Model | Maykeye/TinyLLama-v0 |
| Mesh size | 2 GPUs |
| Sharding | full shard |
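
For reference, a hedged sketch of how the memory side of such a comparison might be instrumented; this is not the script behind the numbers below, and the sharding and training steps are elided:

```python
import torch
from transformers import AutoModelForCausalLM

# Model from the table above.
model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0")

# ... shard with FSDP V1 or V2 and run identical training steps on 2 GPUs ...

# Peak allocated CUDA memory on this rank, comparable across the two runs.
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
```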

Memory

(screenshot: memory usage, FSDP V1 vs FSDP V2)

Loss Parity

(screenshot: loss curves, FSDP V1 vs FSDP V2)

Throughput

TODO

Fixes #2873


Who can review?

@muellerzr

raghukiran1224 commented 1 week ago

@ByronHsu FYI - thoughts?