instructlab / training

InstructLab Training Library

[Epic] Replace DeepSpeed with PyTorch FSDP for Model Training #197

Open ktam3 opened 1 week ago

ktam3 commented 1 week ago

Feature Overview

This feature card covers transitioning our model training infrastructure from DeepSpeed to PyTorch's Fully Sharded Data Parallel (FSDP) to improve training metrics visibility, broaden accelerator support, and maintain performance parity.

Goals

Requirements

  1. Implement PyTorch FSDP as the primary distributed training framework, replacing DeepSpeed (a rough setup sketch follows this list).
  2. Integrate PyTorch FSDP with Weights & Biases for comprehensive training metrics collection and visualization.
  3. Ensure compatibility with a broad range of accelerators (e.g., NVIDIA GPUs, AMD GPUs, TPUs).
  4. Achieve performance parity or improvement compared to DeepSpeed on GPU configurations.
  5. Implement and test CPU offload capabilities.
  6. Update all relevant training scripts and documentation to reflect the transition to PyTorch FSDP.
  7. Ensure security measures are in place for data handling during distributed training.
  8. Maintain or improve the scalability of the training process.
  9. (if applicable) Provide clear documentation on how to use the new PyTorch FSDP setup for different training scenarios.
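
To make requirements 1, 2, and 5 a bit more concrete, here is a minimal, non-authoritative sketch of what an FSDP setup with W&B logging and optional CPU offload could look like. It assumes torch>=2.0, a single-node torchrun launch, and an already-built `torch.nn.Module`; the wandb project name is a placeholder, not an existing value in our code.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
import wandb


def setup_fsdp(model: torch.nn.Module, offload_params: bool = False) -> FSDP:
    # Standard torchrun-style process-group setup (single node assumed).
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Requirement 2: training metrics visibility -- log from rank 0 only.
    if dist.get_rank() == 0:
        wandb.init(project="instructlab-training")  # placeholder project name

    # Requirements 1 and 5: FSDP wrapping with optional CPU offload of parameters.
    return FSDP(
        model.cuda(),
        cpu_offload=CPUOffload(offload_params=offload_params),
    )
```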

Completion Checklist:

Questions to Answer

  1. What is the performance impact of PyTorch FSDP on our specific model architectures?
  2. How does the CPU offload capability of PyTorch FSDP compare to DeepSpeed?
  3. Are there any specific optimizations needed for different accelerator types?
  4. What changes are required in our CI/CD pipeline to accommodate this transition?

Out of Scope

Background

Our current training infrastructure uses DeepSpeed for distributed training. While DeepSpeed has been effective, transitioning to PyTorch FSDP offers strategic advantages in metrics visibility, accelerator support, and potential performance improvements.

User Considerations

Documentation Considerations

Additional notes wrt FSDP -

ktam3 commented 1 week ago

Summary from discussion:

@RobotSail @JamesKunstle @Maxusmusti - to follow up, create issues, and link them to this epic as the work is being done

ktam3 commented 1 week ago

@RobotSail - additional notes wrt support

RobotSail commented 1 week ago

With regard to FSDP, the main risks we still need to overcome are:

LoRA: getting LoRA to work properly. DeepSpeed was very compatible with running PEFT models, whereas FSDP will require more work on our end to get this working.
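
For reference, a rough sketch of one way a PEFT LoRA model could be wrapped with FSDP is below (assuming a Hugging Face causal LM and the peft library; the base-model name and target modules are placeholders). The auto-wrap policy and the `use_orig_params` handling are exactly the pieces that will need extra validation on our side.

```python
import functools

from peft import LoraConfig, get_peft_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Apply LoRA first, then shard; the base weights stay frozen.
model = AutoModelForCausalLM.from_pretrained("org/base-model")  # placeholder
model = get_peft_model(
    model, LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
)

wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
)
fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    # Lets frozen base parameters and trainable LoRA parameters coexist
    # inside the same FSDP flat parameter.
    use_orig_params=True,
)
```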

Checkpointing

In our current implementation, we run DeepSpeed with ZeRO stage-2, which allows us to save a model checkpoint by taking its state from one of the running GPUs, because the model parameters are simply replicated across all GPUs. DeepSpeed implements all ZeRO stages, but we are only using stage-2 at the moment.
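
Put differently, the ZeRO stage-2 save path can be as simple as this sketch (assuming `model_engine` is the object returned by `deepspeed.initialize()`; the path is illustrative):

```python
import torch
import torch.distributed as dist


def save_checkpoint_zero2(model_engine, path: str) -> None:
    # Parameters are replicated under ZeRO-2, so rank 0 can serialize
    # its local copy of the underlying torch.nn.Module directly.
    if dist.get_rank() == 0:
        torch.save(model_engine.module.state_dict(), path)
    dist.barrier()
```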

ZeRO stages listed for reference:

FSDP, on the other hand, only supports ZeRO stage-3-style full sharding or no sharding at all. For this reason, it wouldn't be straightforward to feature-gate DeepSpeed as-is without also providing ZeRO-3 support there as well.
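
Under full sharding, the full state dict has to be gathered before it can be written out. A minimal sketch of what that could look like with the FSDP state-dict APIs (torch>=2.0 assumed; the function and path names are illustrative):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def save_checkpoint_fsdp(model: FSDP, path: str) -> None:
    # Gather the sharded parameters into a full state dict on rank 0,
    # offloading to CPU to avoid exhausting GPU memory.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, path)
    dist.barrier()
```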

We'll also need to make sure that this is all tested against the full matrix of the devices we intend to support.

RobotSail commented 4 days ago

The following issues are now a part of this epic:

JamesKunstle commented 3 days ago

Adding the following general issue as well:

I'll be working on converting / testing this code on Gaudi 2 cards in the multi-GPU case as well.