foundation-model-stack / fms-acceleration

🚀 Collection of libraries used with fms-hf-tuning to accelerate fine-tuning and training of large models.
Apache License 2.0

ScatterMoE Gradient Norm Needs to be Properly Computed When Used With FSDP #109

Open fabianlim opened 2 weeks ago

fabianlim commented 2 weeks ago

When ScatterMoE is used together with FSDP, HF Accelerate calls FSDP's clip_grad_norm_ (see here), which does not know how to correctly compute the gradient norm over the ScatterMoE parameter shards.
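The issue does not say exactly where FSDP's computation goes wrong for ScatterMoE. Whatever the cause, any correct replacement has to satisfy one invariant: the global p-norm is the p-th root of the sum of |g|^p over all shards. It is not one rank's shard norm, and it is not the sum of the per-shard norms. The sketch below is a single-process simulation. It assumes each rank holds a disjoint shard and uses a plain Python sum where a real run would do an all-reduce. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def local_pow_sum(shard: Sequence[float], p: float) -> float:
    """Sum of |g|^p over the gradient entries one rank holds (hypothetical helper)."""
    return sum(abs(g) ** p for g in shard)


def global_grad_norm(shards: List[Sequence[float]], p: float = 2.0) -> float:
    """p-norm of the full gradient, given disjoint per-rank shards.

    In a real distributed run each rank would compute only its own local
    term, and an all-reduce would combine them. Here the ranks are simulated
    as a list, so a plain sum/max stands in for the collective.
    """
    if math.isinf(p):
        # inf-norm: combine ranks with MAX instead of SUM
        return max((max((abs(g) for g in s), default=0.0) for s in shards), default=0.0)
    total = sum(local_pow_sum(s, p) for s in shards)  # stands in for all_reduce(SUM)
    return total ** (1.0 / p)


# Two ranks, each holding half of one expert's gradient.
shards = [[3.0, 0.0], [4.0, 0.0]]
print(global_grad_norm(shards))                                # 5.0: correct global norm
print(math.sqrt(local_pow_sum(shards[0], 2.0)))                # 3.0: rank 0's shard alone
print(sum(math.sqrt(local_pow_sum(s, 2.0)) for s in shards))   # 7.0: summing shard norms is also wrong
```

Reducing sums of |g|^p rather than norms keeps the combination associative, so the same rule works no matter how the expert weights are split across ranks.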

We need a way to hook in custom logic so that the gradient norm is computed correctly for the ScatterMoE parameters before clipping is applied.
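One possible shape for such a hook is a clipping routine whose norm computation is pluggable. Clipping itself stays the same, and only the norm function changes for ScatterMoE. The sketch below is hypothetical and is not the fms-acceleration or Accelerate API. It simulates all ranks in one process, and its scaling rule matches torch.nn.utils.clip_grad_norm_, which only shrinks gradients and never grows them.

```python
from typing import Callable, List

# Hypothetical hook signature: (per-rank gradient shards, norm_type) -> global norm.
NormHook = Callable[[List[List[float]], float], float]


def default_norm_hook(rank_grads: List[List[float]], norm_type: float) -> float:
    # Sum |g|^p across all shards (an all-reduce in a real run), then take the p-th root.
    total = sum(abs(g) ** norm_type for grads in rank_grads for g in grads)
    return total ** (1.0 / norm_type)


def clip_grad_norm_with_hook(
    rank_grads: List[List[float]],
    max_norm: float,
    norm_type: float = 2.0,
    norm_hook: NormHook = default_norm_hook,
    eps: float = 1e-6,
) -> float:
    """Clip gradients in place using a pluggable norm; returns the pre-clip norm."""
    total_norm = norm_hook(rank_grads, norm_type)
    # Scale only when the norm exceeds max_norm.
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for grads in rank_grads:
            for i, g in enumerate(grads):
                grads[i] = g * clip_coef
    return total_norm


grads = [[3.0], [4.0]]
norm = clip_grad_norm_with_hook(grads, max_norm=1.0)
print(norm)  # 5.0; grads are now approximately [[0.6], [0.8]]
```

Keeping the hook at the norm computation, rather than replacing the whole clipping step, means a ScatterMoE-aware function only has to reduce over the right shards, and the rescaling logic is shared.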