TRI-ML / prismatic-vlms

A flexible and efficient codebase for training visually-conditioned language models (VLMs)
MIT License

`clip_grad_norm_` in fsdp #35

Closed: zeyuanyin closed this issue 1 month ago

zeyuanyin commented 1 month ago

Hi @siddk, I modified the VLM model, but the program hangs when it reaches `self.clip_grad_norm()` in `base_strategy.py`:

https://github.com/TRI-ML/prismatic-vlms/blob/874c5bbff52b248294a3ab97006491a7faa698e6/prismatic/training/strategies/base_strategy.py#L214

The related code in `fsdp.py` is:

https://github.com/TRI-ML/prismatic-vlms/blob/874c5bbff52b248294a3ab97006491a7faa698e6/prismatic/training/strategies/fsdp.py#L241-L243

But I can't find any implementation of `vlm.clip_grad_norm_`.

Does `clip_grad_norm` actually run, or has it been turned off in `fsdp.py`?

siddk commented 1 month ago

Hey @zeyuanyin; the call to `clip_grad_norm_` is actually invoking the method of the same name on PyTorch's native FSDP wrapper class (`FullyShardedDataParallel`).
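For context, here is a minimal pure-Python sketch of what clipping by global gradient norm does conceptually. It is not PyTorch's or this repo's implementation: the function name `clip_by_global_norm`, the list-of-lists gradient layout, and the `eps` default are illustrative assumptions. The sharded FSDP version additionally has to combine per-rank norms across processes, which this sketch leaves out.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients so their combined L2 norm is at most max_norm.

    `grads` is a list of gradient vectors (plain lists of floats); this
    layout is a stand-in for per-parameter gradient tensors.
    Returns the (possibly scaled) gradients and the norm before clipping.
    """
    # Global L2 norm over every element of every gradient vector.
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Never scale up: the coefficient is capped at 1.0.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm
```

The norm is taken over all parameters at once rather than per parameter, so every gradient is scaled by the same factor and the update direction is preserved.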

The issue with this function is that it expects all parts of the model to have the same dtype. I did something semi-dumb in the code: assuming the ViT would also be frozen, I cast its weights directly to bfloat16, which I realize is inflexible with other training schemes.
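One way to check whether a modified model ends up with mixed dtypes is to group its parameters by dtype before wrapping it in FSDP. The sketch below is a hypothetical helper, not part of prismatic-vlms; it assumes only that each parameter exposes `.dtype` and `.requires_grad`, as PyTorch parameters do, and that trainable parameters are the ones that matter for gradient clipping.

```python
from collections import defaultdict


def summarize_param_dtypes(named_params):
    """Group parameter names by (dtype, trainable) pair.

    `named_params` is an iterable of (name, param) pairs, e.g. the
    result of `model.named_parameters()` on a PyTorch module.
    """
    groups = defaultdict(list)
    for name, param in named_params:
        key = (str(param.dtype), bool(param.requires_grad))
        groups[key].append(name)
    return dict(groups)


def has_mixed_trainable_dtypes(named_params):
    """Return True if trainable parameters span more than one dtype."""
    trainable = {str(p.dtype) for _, p in named_params if p.requires_grad}
    return len(trainable) > 1
```

With a PyTorch model this could be called as `has_mixed_trainable_dtypes(vlm.named_parameters())`, where `vlm` stands for the unwrapped model. A `True` result would be consistent with the dtype mismatch described above, though it does not prove that the mismatch causes the hang.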

I'm going to find some time to make this a config flag, but in the meantime, could you try removing the offending casting line and see if things work ok?