Yingyue-L opened this issue 7 months ago
Interesting - have you modified `train_mamba.py` in some way?
I have only modified the model path to point to the `local_dir`; everything else remained unchanged.
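For reference, the change on my side looks roughly like this (a minimal sketch: the `local_dir` path is a placeholder, and I'm assuming the `mamba_ssm` package's `MambaLMHeadModel` loader, which may differ from what `train_mamba.py` actually uses):

```python
# Sketch of the only modification: loading from a local directory
# instead of a Hugging Face hub ID. Path is a placeholder.
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # assumed import

# model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b")  # original hub ID (assumed)
model = MambaLMHeadModel.from_pretrained("/path/to/local_dir")  # local copy
```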
I encountered the same problem.
Okay, so I think I've finally gotten to the bottom of the issue. I had switched from bfloat16 to float16 because bfloat16 wouldn't work on my 2070 Super, and that's what was causing my logits to become NaN. Switching to float32 made the issue go away, and the learning rate no longer suddenly drops to 0.
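In case it helps anyone else, here is a minimal sketch of how I now pick the training dtype (the fallback order is my own choice, not something from `train_mamba.py`):

```python
import torch

# Pick a training dtype based on hardware support. Turing cards like the
# 2070 Super report no bfloat16 support, and in my case float16 produced
# NaN logits, so the safe fallback here is float32.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16   # Ampere (e.g. 3090) and newer
else:
    dtype = torch.float32    # avoids the float16 NaN-logits issue I hit

model = model.to(dtype)  # `model` as loaded above; sketch only
```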
Thanks for the great work! I tried to fine-tune the Mamba model on four 3090 GPUs following your code, but there's a problem: the loss drops to zero after just two steps. Can you help me figure out what's going wrong and how to fix it?
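Given the NaN-logits finding above, one quick way to confirm whether the same thing is happening on the 3090s is to check the logits during the first few steps. A minimal sketch (`model`, `input_ids`, and `step` are placeholders, not the actual `train_mamba.py` loop):

```python
import torch

def check_logits(logits: torch.Tensor, step: int) -> None:
    """Fail fast if the forward pass produces NaN/Inf logits, which can
    show up downstream as the loss or learning rate collapsing to 0."""
    if not torch.isfinite(logits).all():
        raise RuntimeError(f"non-finite logits at step {step}; "
                           "try float32 instead of float16/bfloat16")

# Inside the training loop (placeholder names):
# logits = model(input_ids).logits
# check_logits(logits, step)
```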