mayank31398 closed this issue 3 months ago.
@mayank31398
In short, we currently need a post-init step to match a true init, which is done with reset_parameters(). The approach you proposed here (which is actually what we did before) would require that post-init to run after the FSDP wrapping call, and we haven't had the bandwidth to address that. See https://github.com/foundation-model-stack/fms-fsdp/issues/15
For details, see https://github.com/foundation-model-stack/fms-fsdp/issues/6
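For readers unfamiliar with the post-init step: when a model is built on PyTorch's meta device, its parameters have shapes but no storage, so once they are materialized (here with to_empty()) they hold arbitrary values until each submodule's reset_parameters() is called again. Below is a minimal sketch of that pattern, assuming PyTorch 2.x. The toy model and the helper names build_toy_model and init_after_materialize are hypothetical, and in the real flow the post-init would have to run after the FSDP wrapping call rather than on the full, unwrapped model as done here.

```python
import torch
import torch.nn as nn


def build_toy_model() -> nn.Module:
    # Hypothetical stand-in for the real (e.g. 70B) model.
    return nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))


def init_after_materialize(model: nn.Module) -> None:
    # Post-init: re-run each submodule's own initializer so the
    # materialized (uninitialized) storage matches a "true" init.
    for module in model.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()


# 1) Build on the meta device: shapes only, no parameter memory allocated.
with torch.device("meta"):
    model = build_toy_model()
assert all(p.is_meta for p in model.parameters())

# 2) Allocate real but uninitialized storage. With FSDP this would
#    happen per shard after wrapping, not on the full model.
model.to_empty(device="cpu")

# 3) Post-init with reset_parameters() so values are well defined.
init_after_materialize(model)
```

The design point is that the meta device defers all allocation, so nothing is initialized until reset_parameters() runs; that is why the post-init has to be placed after whatever step materializes the parameters.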
Regarding "It seems to get stuck": for a large model like 70B, this will take a decent amount of time (10-20 minutes, or even a little more), but I think it should work.
In the future, we will revert to the old implementation (the one you proposed here) once the above issue is fixed.
@lchu-ibm Hmm, weird, I am not seeing https://github.com/foundation-model-stack/fms-fsdp/issues/15 on my end. But I am not using FMS, so maybe that is a factor; it could be some difference in the modeling code. I didn't understand #6, though, since I am not familiar with PyTorch's TP (tensor parallel) implementation. I can try my own TP implementation with FSDP. Feel free to close this issue though :)
@mayank31398
Yes, it is a little bit complicated, but in short:
What you proposed is definitely correct, and we also used it that way in the past. But because of a certain issue we have, we need to make sure at least one full, correct copy of the model is present before the FSDP call, so we use the rank==0 trick. This reduces the CPU model copies from world_size copies to only one, but it is still less efficient than putting the model on the meta device on all ranks.
Once that issue is fixed, we will revert to the old implementation (the one you have here).
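The rank==0 trick can be sketched roughly as follows, assuming PyTorch FSDP: rank 0 builds one full, correctly initialized copy on CPU, every other rank builds the model on the meta device, and FSDP's sync_module_states=True broadcasts rank 0's weights after sharding. This mirrors a common low-CPU-memory FSDP recipe and is not necessarily the exact fms-fsdp code; build_model, wrap_with_fsdp and cpu_model_bytes are hypothetical helpers. wrap_with_fsdp needs an initialized process group and one GPU per rank (and PyTorch 2.1+ for recurse=False), so it is only defined here, not called.

```python
from typing import Callable

import torch
import torch.nn as nn


def build_model(rank: int, factory: Callable[[], nn.Module]) -> nn.Module:
    # rank==0 trick: only rank 0 holds one full, correct copy on CPU.
    if rank == 0:
        return factory()  # real, initialized weights on CPU
    with torch.device("meta"):
        return factory()  # shapes only, no CPU memory for parameters


def wrap_with_fsdp(model: nn.Module, rank: int) -> nn.Module:
    # Requires torch.distributed to be initialized and a CUDA device per rank.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    device = torch.device("cuda", torch.cuda.current_device())
    return FSDP(
        model,
        device_id=device,
        # Broadcast rank 0's weights to all other ranks after sharding.
        sync_module_states=True,
        # Non-zero ranks only need empty storage; rank 0's values overwrite it.
        param_init_fn=(
            None if rank == 0 else (lambda m: m.to_empty(device=device, recurse=False))
        ),
    )


def cpu_model_bytes(num_params: int, bytes_per_param: int, world_size: int, strategy: str) -> int:
    # Total CPU memory for full-model copies across all ranks, ignoring
    # the per-rank shards that are materialized after FSDP wrapping.
    copies = {"all_ranks_cpu": world_size, "rank0": 1, "meta_all_ranks": 0}[strategy]
    return num_params * bytes_per_param * copies
```

For example, a 70B-parameter model in bf16 (2 bytes per parameter) on 8 ranks needs about 1.12 TB of CPU memory for full copies if every rank loads one, about 140 GB with the rank==0 trick, and no full CPU copy when all ranks start from the meta device.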
Makes sense. Closing the issue.
@mayank31398 this should happen soon: https://github.com/foundation-model-stack/fms-fsdp/issues/64
I am trying to replicate the current code in my training codebase, and it seems to get stuck. Can you check whether it's the same for you? I am also opening this PR with a fix that works for me.