foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

RuntimeError: CUDA driver error: an illegal memory access was encountered #15

Closed: lchu-ibm closed this 3 months ago

lchu-ibm commented 4 months ago

@daviswer It seems that calling model.reset_parameters() after the FSDP call raises the following error.

Can you take a look?

[rank8]: Traceback (most recent call last):
[rank8]:   File "/lustre/lchu/fms-fsdp/main_training.py", line 168, in <module>
[rank8]:     fire.Fire(main)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
[rank8]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
[rank8]:     component, remaining_args = _CallAndUpdateTrace(
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
[rank8]:     component = fn(*varargs, **kwargs)
[rank8]:   File "/lustre/lchu/fms-fsdp/main_training.py", line 130, in main
[rank8]:     model.reset_parameters()
[rank8]:   File "/lustre/t5/public/foundation-model-stack/fms/models/llama.py", line 237, in reset_parameters
[rank8]:     m.reset_parameters()
[rank8]:   File "/lustre/t5/public/foundation-model-stack/fms/modules/embedding.py", line 95, in reset_parameters
[rank8]:     nn.init.trunc_normal_(getattr(self, layer).weight, mean=0.0, std=0.02)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/torch/nn/init.py", line 205, in trunc_normal_
[rank8]:     return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
[rank8]:   File "/home/lchu/.conda/envs/latest/lib/python3.9/site-packages/torch/nn/init.py", line 47, in _no_grad_trunc_normal_
[rank8]:     tensor.erfinv_()
[rank8]: RuntimeError: CUDA driver error: an illegal memory access was encountered

Moving the call before FSDP wrapping does not trigger this error.

daviswer commented 4 months ago

It turns out that FSDP requires a slightly different approach when initializing after sharding. I'll open a PR to fix this on Monday.

lchu-ibm commented 3 months ago

@daviswer did we revisit this thread? Just checking whether this issue is still there.

daviswer commented 3 months ago

I think #18 changed the calculus for how we were planning to handle this, and it never got revisited after that. Not sure whether the issue is still relevant.

lchu-ibm commented 3 months ago

@daviswer Yes. That PR works around this issue by moving reset_parameters() before the FSDP call. But ultimately we want the call to happen after the FSDP call, because that saves us from initializing a full model on CPU before wrapping, which can take 20 minutes for large models like 70B.
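To illustrate where the saving comes from, here is a minimal sketch, assuming PyTorch 2.0 or newer, of building a model on the meta device (shapes only, no memory, no init kernels) and materializing it later. The toy model and the `build_on_meta` / `materialize` helpers are hypothetical stand-ins, not fms code.

```python
import torch
import torch.nn as nn


def build_on_meta(vocab: int = 1000, dim: int = 64) -> nn.Module:
    # Hypothetical toy model standing in for a large model (e.g. 70B).
    # Under the meta device only shapes and dtypes are recorded: no storage
    # is allocated and no init kernels run, so construction is nearly instant.
    with torch.device("meta"):
        return nn.Sequential(nn.Embedding(vocab, dim), nn.Linear(dim, dim))


def materialize(model: nn.Module, device: torch.device) -> nn.Module:
    # Allocate real but uninitialized storage, then let each submodule
    # re-run its own init. Under FSDP this work would instead be done
    # per wrapped unit (e.g. via param_init_fn) on the target device.
    model.to_empty(device=device)
    with torch.no_grad():
        for m in model.modules():
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()
    return model


model = build_on_meta()
print(all(p.is_meta for p in model.parameters()))  # True
materialize(model, torch.device("cpu"))
print(any(p.is_meta for p in model.parameters()))  # False
```

Note that `to_empty` leaves the storage uninitialized, so skipping the `reset_parameters()` step would leave arbitrary values in the weights.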

So this will have to be fixed eventually.

daviswer commented 3 months ago

So the way we'd want to do this is to add the (various) reset_parameters() calls to this portion of the FSDP call in main_training.py:

        param_init_fn=lambda module: (
            module.to_empty(device=torch.device("cuda"), recurse=False)
            if cfg.low_cpu_fsdp
            else None
        ),

But we need to make sure it keeps playing nicely with the low_cpu_fsdp flag. Since you know that portion and I know the model init portion, we should probably coordinate, @lchu-ibm.

lchu-ibm commented 3 months ago

@daviswer Technically, since you named it reset_parameters(), we can set param_init_fn=None, because under the hood FSDP will call reset_parameters() if no specific param_init_fn is passed. However, I vaguely remember this not working as expected last time (well, last time was the end of last year, so it may be worth revisiting).

Can you prepare a small validation snippet (to be called after the FSDP call) that checks whether the model is initialized as expected?

For example, it should pass with the current code, fail with the current code minus model.reset_parameters(), and pass with a correct param_init_fn.

daviswer commented 3 months ago

OK, I opened a branch off fms main, fsdp_init_check, which adds a check_weights() function to Llama. It should error out on any improper init and return silently on success.
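For reference, a minimal sketch of such a check, written independently of that branch and under its own assumptions: `check_init` is hypothetical, treats every parameter with two or more dims as a zero-mean draw with std 0.02 (what `trunc_normal_(mean=0.0, std=0.02)` produces), and scales its tolerances with tensor size. It follows the same contract (raise on a bad init, return silently otherwise). Because `to_empty` leaves memory with unpredictable contents, the usage lines use a zero fill as a deterministic stand-in for a skipped `reset_parameters()`.

```python
import math

import torch
import torch.nn as nn


def check_init(model: nn.Module, std: float = 0.02) -> None:
    """Hypothetical post-wrap init check (not the branch's check_weights()).

    Raises ValueError on a bad init and returns None otherwise. Assumes
    every parameter with >= 2 dims should look like a zero-mean draw
    with the given std.
    """
    for name, p in model.named_parameters():
        if p.is_meta:
            raise ValueError(f"{name}: still on the meta device")
        if p.dim() < 2:
            continue  # biases / norm weights use other inits
        w = p.detach().float()
        if not torch.isfinite(w).all():
            raise ValueError(f"{name}: contains NaN or inf")
        n = w.numel()
        # Sample std of n draws deviates by about std / sqrt(2n); allow 5x that.
        got_std = w.std().item()
        if abs(got_std - std) > 5 * std / math.sqrt(2 * n):
            raise ValueError(f"{name}: std {got_std:.4g}, expected about {std}")
        # Sample mean deviates by about std / sqrt(n); allow 5x that.
        got_mean = w.mean().item()
        if abs(got_mean) > 5 * std / math.sqrt(n):
            raise ValueError(f"{name}: mean {got_mean:.4g}, expected about 0")


torch.manual_seed(0)
layer = nn.Linear(64, 64, bias=False)

nn.init.trunc_normal_(layer.weight, mean=0.0, std=0.02)
check_init(layer)  # passes silently: the intended init ran

nn.init.zeros_(layer.weight)  # stand-in for "reset_parameters() never ran"
try:
    check_init(layer)  # should raise: std is 0
except ValueError as err:
    print("caught:", err)
```

A freshly constructed `nn.Linear(64, 64)` also fails this check, since its default Kaiming-uniform init has a std of roughly 0.07, which is the kind of mismatch the pitfall in the previous sketch would produce.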

lchu-ibm commented 3 months ago

@daviswer great. I will start working on this.

lchu-ibm commented 3 months ago

Closing this one in favor of the new issue: https://github.com/foundation-model-stack/fms-fsdp/issues/64