foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0
114 stars, 18 forks

fix meta device initialization for very large models #54

Closed. mayank31398 closed this 3 months ago

mayank31398 commented 3 months ago

I am trying to replicate the current code in my training codebase. It seems to get stuck. Can you see if it's the same case for you? I am also opening this PR with a fix that works for me.

lchu-ibm commented 3 months ago

@mayank31398

In short, we currently need a post-init that matches a true init, which is done with reset_parameters(). The approach you proposed here (which is actually what we did before) requires that post-init to run after the FSDP wrapping call, and we haven't had the bandwidth to address that yet. See https://github.com/foundation-model-stack/fms-fsdp/issues/15

For details, see https://github.com/foundation-model-stack/fms-fsdp/issues/6
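
As a rough illustration of the meta-device approach discussed above (not the fms-fsdp code itself), the sketch below builds a toy model on the meta device, then materializes it and runs reset_parameters() as the post-init. The helper names `build_on_meta` and `materialize_and_init` and the toy `nn.Sequential` model are assumptions made for this example, and in a real FSDP run this step would be tied to the FSDP wrapping call rather than run on CPU. It assumes PyTorch 2.0 or later for the `torch.device` context manager.

```python
import torch
import torch.nn as nn


def build_on_meta(hidden: int = 8) -> nn.Module:
    # Hypothetical toy model standing in for the real one.
    # Under the meta device, parameters have shapes and dtypes but no storage.
    with torch.device("meta"):
        return nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )


def materialize_and_init(model: nn.Module, device: str = "cpu") -> nn.Module:
    # Allocate real (uninitialized) storage for every parameter and buffer.
    model.to_empty(device=device)
    # Post-init: run each submodule's own reset_parameters(), if it defines one,
    # so the values match what a normal (non-meta) construction would produce.
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()
    return model


model = build_on_meta()
print(all(p.is_meta for p in model.parameters()))   # no storage allocated yet
materialize_and_init(model)
print(any(p.is_meta for p in model.parameters()))   # parameters are now real tensors
```

The point of the post-init is that `to_empty()` only allocates memory; without the `reset_parameters()` pass the weights would hold arbitrary values.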

lchu-ibm commented 3 months ago

> It seems to get stuck.

For a large model like 70B, this will take a decent amount of time (roughly 10-20 minutes, or even a little more), but I think it should work.

In the future we will revert to the old implementation (the one you proposed here) once the above issue is fixed.

mayank31398 commented 3 months ago

@lchu-ibm Hmm, weird, I am not seeing https://github.com/foundation-model-stack/fms-fsdp/issues/15 on my end. But I am not using FMS, so that could be a factor; maybe there is some difference in the modeling code. I didn't understand #6 though, as I am not familiar with torch's TP implementation. I can try my own TP implementation with FSDP. Feel free to close this issue though :)

lchu-ibm commented 3 months ago

@mayank31398

Yes, it is a little bit complicated. But in short:

What you proposed is definitely correct, and we also used it that way in the past. But because of a certain issue we have, we need to make sure at least one full, correct copy of the model is present before the FSDP call, so we use the rank==0 trick. This reduces the number of CPU model copies from world_size to just 1, but it is still less efficient than meta-device init on all ranks.

Once that issue is fixed, we will revert to the old implementation (the one you have here).

mayank31398 commented 3 months ago

Makes sense. Closing the issue.

lchu-ibm commented 3 months ago

@mayank31398 this should happen soon: https://github.com/foundation-model-stack/fms-fsdp/issues/64