Hey @zeyuanyin; the call to `clip_grad_norm_` is actually invoking this function in PyTorch's native FSDP wrapper class. The issue with this function is that it expects all parts of the model to have the same dtype. I did something semi-dumb in the code by assuming the ViT would always be frozen and casting its weights directly to bfloat16 here, which I realize is inflexible for other training schemes. I'm going to find some time to make this a config flag, but in the meantime, could you try removing the offending casting line and see if things work OK?
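In case it helps, here's a minimal sketch of what gating that cast behind a flag could look like; the helper name and the `cast_to_bf16` flag are hypothetical placeholders, not the repo's actual config fields:

```python
import torch
import torch.nn as nn

def maybe_cast_vision_backbone(vision_backbone: nn.Module, cast_to_bf16: bool = True) -> None:
    """Hypothetical helper: cast the ViT to bfloat16 only when it stays frozen.

    When the vision backbone is being trained, skip the cast so every submodule
    keeps the same dtype and FSDP's `clip_grad_norm_` doesn't see mixed dtypes.
    """
    if cast_to_bf16:
        vision_backbone.to(torch.bfloat16)
```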
Hi @siddk, I modified the VLM model, but the program gets stuck when running `self.clip_grad_norm()` in `base_strategy.py`:
https://github.com/TRI-ML/prismatic-vlms/blob/874c5bbff52b248294a3ab97006491a7faa698e6/prismatic/training/strategies/base_strategy.py#L214
The related code in `fsdp.py` is:
https://github.com/TRI-ML/prismatic-vlms/blob/874c5bbff52b248294a3ab97006491a7faa698e6/prismatic/training/strategies/fsdp.py#L241-L243
But I can't find any `vlm.clip_grad_norm_` implementation.

![image](https://github.com/TRI-ML/prismatic-vlms/assets/51396847/ad8c3d2e-8ef5-4dfd-b392-adacace61450)

Did `clip_grad_norm` work for you, or has it been turned off in `fsdp.py`?
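For context on where `vlm.clip_grad_norm_` comes from: as noted in the comment above, it isn't implemented in this repo but is provided by PyTorch's native FSDP wrapper class once the VLM is wrapped. A minimal sketch, where the wrapping is purely illustrative and assumes a distributed process group has already been initialized (e.g. via torchrun):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# `clip_grad_norm_` is a method of the FSDP wrapper class, not of the wrapped VLM.
assert hasattr(FSDP, "clip_grad_norm_")

def clip_after_backward(vlm: nn.Module, max_grad_norm: float = 1.0) -> torch.Tensor:
    """Illustrative only: wrap the VLM, then call the wrapper's `clip_grad_norm_`.

    Requires `torch.distributed` to be initialized (e.g. launched via torchrun);
    per the comment above, the sharded parameters/gradients need a uniform dtype
    for this call to behave.
    """
    wrapped = FSDP(vlm)
    # ... forward pass, loss.backward() ...
    return wrapped.clip_grad_norm_(max_norm=max_grad_norm)
```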