Closed lchu-ibm closed 3 months ago
It turns out that FSDP requires a slightly different approach when initializing after sharding. I'll open a PR to fix this on Monday.
@daviswer did we revisit this thread? Just making sure this issue is still there.
I think #18 changed the calculus for how we were planning to handle this, and after that it never got revisited. Not sure if the issue is still relevant.
@daviswer yes. That PR helps this issue by moving `reset_parameters()` before the FSDP call. But ultimately we would want the call to be made after the FSDP call, since that lets us skip fully initializing the model on CPU beforehand, which can take 20 minutes for large models like 70b.
So this will have to be fixed eventually.
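To make the motivation concrete, here is a minimal sketch of the meta-device pattern being discussed (the `TinyModel` class and its sizes are illustrative, not the repo's actual model): constructing on the meta device allocates no storage and runs no init math, so construction is effectively free even for very large models, and the real initialization happens only after storage is materialized.

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Illustrative stand-in for a large model."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def reset_parameters(self):
        # Re-run the real initialization once storage exists.
        nn.init.kaiming_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

# Build on the meta device: no real memory is allocated and the init code
# is a no-op, so even a 70b-parameter model "constructs" almost instantly.
with torch.device("meta"):
    model = TinyModel()
assert model.proj.weight.is_meta  # no real storage yet

# Later (e.g. inside FSDP's param_init_fn, shard by shard), storage is
# allocated with to_empty() and filled in by reset_parameters():
model = model.to_empty(device="cpu")
model.reset_parameters()
assert not model.proj.weight.is_meta
```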
So the way we'd want to do this is to add the (various) `reset_parameters()` calls to this portion of the FSDP call in main_training.py:
```python
param_init_fn=lambda module: (
    module.to_empty(device=torch.device("cuda"), recurse=False)
    if cfg.low_cpu_fsdp
    else None
),
```
But we need to make sure it keeps playing nicely with the `low_cpu_fsdp` flag. Since you know that portion and I know the model init portion, we should probably coordinate @lchu-ibm
@daviswer technically, since you named it `reset_parameters()`, we can set `param_init_fn=None`, since under the hood FSDP will call `reset_parameters()` if no specific `param_init_fn` is passed. However, I vaguely remember this wasn't working as expected last time (well, last time was the end of last year, so it may be worth revisiting).
Can you prepare a small validation code snippet (to be called after the FSDP call) to check that the model is initialized as expected? E.g. it should pass with the current code, it should fail with the current code once `model.reset_parameters()` is removed, and it should pass if we supply a good `param_init_fn`.
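A minimal version of such a check might look like the following. This is only a sketch of the idea (the `check_init` name and its thresholds are illustrative, and the real `check_weights()` added later in the thread may work differently): `to_empty()` leaves uninitialized garbage, which in practice often shows up as non-finite or implausibly large values.

```python
import torch
import torch.nn as nn

def check_init(module: nn.Module, max_abs: float = 1e3):
    """Heuristic sanity check (illustrative thresholds): properly
    re-initialized weights should be finite and have plausible
    magnitudes; uninitialized storage from to_empty() typically
    does not. Raises on suspicious parameters, returns silently
    on success."""
    for name, p in module.named_parameters():
        if not torch.isfinite(p).all():
            raise ValueError(f"{name} contains non-finite values")
        if p.numel() > 1 and p.abs().max() > max_abs:
            raise ValueError(f"{name} looks uninitialized "
                             f"(max abs = {p.abs().max():.3e})")

lin = nn.Linear(8, 8)  # normally initialized
check_init(lin)        # returns silently on success
```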
OK, I opened a branch off fms main, `fsdp_init_check`, which adds a `check_weights()` function to Llama. It should error out for any improper init and return silently on success.
@daviswer great. I will start working on this.
closing this one in favor of the new issue: https://github.com/foundation-model-stack/fms-fsdp/issues/64
@daviswer It seems calling `model.reset_parameters()` after the FSDP call raises the following error.
Can you take a look?
Moving it before the FSDP call doesn't trigger the error.
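Without seeing the traceback it's hard to be certain, but a plausible cause is that after wrapping, FSDP replaces module parameters with flattened 1-D shards, so shape-dependent initializers (e.g. fan-in based ones) can fail. One commonly used workaround is `FSDP.summon_full_params`, which temporarily restores the full, unsharded parameters and writes modifications back to the shards. A hedged sketch (not validated against this repo, and it requires an initialized process group to actually run):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def reinit_after_fsdp(model):
    """Hypothetical workaround: re-run init on a wrapped model.

    Inside summon_full_params, parameters have their original shapes on
    each rank, so shape-dependent init code works; writeback (the
    default) propagates the new values back to the local shards.
    """
    with FSDP.summon_full_params(model):
        model.reset_parameters()
```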