Summary from discussion:
@RobotSail @JamesKunstle @Maxusmusti - to follow up and create issues and link to this epic as the work is being done
@RobotSail - additional notes regarding support
With regard to FSDP, the main risks we still need to overcome are:
LoRA: DeepSpeed was very compatible with running PEFT models, so getting LoRA to work properly under FSDP will require more work on our end (a rough sketch of the approach is included below).
Checkpointing
In our current implementation, we run DeepSpeed with ZeRO stage-2, which allows us to save a model checkpoint by taking its state from any one of the running GPUs, since the model parameters are simply replicated across all GPUs. DeepSpeed implements all ZeRO stages, but we are only using stage-2 at the moment.
ZeRO stages, listed for reference: stage-1 shards optimizer states, stage-2 additionally shards gradients, and stage-3 additionally shards the model parameters themselves.
FSDP, on the other hand, only supports ZeRO stage-3-style full sharding or no sharding at all. For this reason, it wouldn't be straightforward to feature-gate DeepSpeed as-is without also providing ZeRO-3 support there as well.
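For reference, here's a rough sketch of what saving a checkpoint could look like under FSDP full sharding, where the full state dict has to be gathered onto rank 0 before writing. The model and file path are placeholders (not our actual training code), and the process group is assumed to already be initialized by the launcher:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
    FullStateDictConfig,
)

# Placeholder model; stands in for whatever module the trainer actually builds.
# Assumes dist.init_process_group() has already been called.
model = torch.nn.Linear(4096, 4096).cuda()

# Under FULL_SHARD (ZeRO-3-style), parameters are sharded across ranks,
# so no single GPU holds a complete copy of the model.
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

# Gather the full (unsharded) state dict onto rank 0, offloaded to CPU,
# before writing it out.
save_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_cfg):
    state = fsdp_model.state_dict()

if dist.get_rank() == 0:
    torch.save(state, "checkpoint.pt")  # placeholder path
```

That extra gather step is exactly what ZeRO stage-2 lets us skip today, since every rank already holds the full parameters.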
We'll also need to make sure this is all tested against the full matrix of devices we intend to support.
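On the LoRA side, here's a minimal sketch of wrapping a PEFT model with FSDP, assuming we keep using PEFT for the adapters. The base model name and target_modules values below are illustrative placeholders:

```python
from peft import LoraConfig, get_peft_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

# Placeholder model name -- substitute whichever base model we train.
base = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")

# Attach LoRA adapters; only these small matrices end up trainable,
# while the base weights stay frozen.
lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base, lora_cfg)

# use_orig_params=True lets FSDP shard modules that mix frozen and
# trainable parameters, which is the main wrinkle with LoRA + FSDP.
fsdp_model = FSDP(peft_model, use_orig_params=True)
```

In practice we'd also want a sensible auto-wrap policy per transformer block, but the mixed frozen/trainable parameter handling is the part that needs validating first.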
The following issues are now part of this epic:
Adding the following general issue as well:
I'll be working on converting / testing this code on Gaudi 2 cards in the multi-GPU case as well.
@Maxusmusti @RobotSail It sounds like this is being solved by @aldopareja's PR that uses Accelerate, and Mustafa's work that enables LoRA checkpointing. What do we need to do to finish this / get it tested?
@JamesKunstle we should sync on this tomorrow, either before or after meetings, to make sure we have everything. Checkpoint resuming and a lot of FSDP testing will definitely be needed this week, and we still need to bring back padding-free training via the HF Transformers Granite model class.
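For context, one way the Accelerate-based approach can be pointed at FSDP is via its FSDP plugin. This is only a sketch; the exact fields and values used in @aldopareja's PR may differ, and some plugin arguments vary across Accelerate versions:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import FullStateDictConfig, ShardingStrategy

# Illustrative settings only -- the real values belong to the PR's config.
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    use_orig_params=True,  # keeps LoRA-style mixed frozen/trainable params workable
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```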
@JamesKunstle - can we close this epic if it's done?
Closing this as discussed in chat. Feel free to reopen if that's incorrect.
Feature Overview
This Feature card is for transitioning our model training infrastructure from DeepSpeed to PyTorch's Fully Sharded Data Parallel (FSDP) to enhance training metrics visibility, broaden accelerator support, and maintain performance parity.
Goals
Requirements
Completion Checklist:
Questions to Answer
Out of Scope
Background
Our current training infrastructure uses DeepSpeed for distributed training. While effective, transitioning to PyTorch FSDP offers strategic advantages in terms of metrics visibility, accelerator support, and potential performance improvements.
User Considerations
Documentation Considerations
Additional notes regarding FSDP -