I've encountered an issue with NaN losses when using AMP in MambaVision, particularly when I reduce the training epochs to 30. The problem seems to stem from the selective_scan_fn function. I've tried switching to float32 for training, which resolves the NaN issue, but this approach is more resource-intensive compared to using AMP. Could anyone suggest more flexible solutions?
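Ideally, instead of moving the whole run to float32, I'd like something that runs only the sensitive op in float32 while the rest of the model stays under autocast. A sketch of what I mean (the `run_scan_in_fp32` wrapper is hypothetical; `selective_scan_fn`, or any callable, would be passed in as `scan_fn`):

```python
import torch

def run_scan_in_fp32(scan_fn, *tensors, **kwargs):
    """Run a numerically sensitive op (e.g. selective_scan_fn) in float32
    even when the surrounding forward pass is under AMP autocast."""
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # Disable autocast locally and upcast the tensor inputs to float32.
    with torch.autocast(device_type=device_type, enabled=False):
        cast = [t.float() if torch.is_tensor(t) else t for t in tensors]
        return scan_fn(*cast, **kwargs)
```

This keeps the memory/speed benefits of AMP everywhere except the one op that overflows in half precision.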
It's hard to pinpoint the issue knowing only that the total number of epochs has been reduced. However, my best bet would be to decrease the learning rate while still keeping AMP.
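A minimal sketch of that setup, assuming the standard PyTorch AMP loop (the learning rate here is illustrative, not a recommended value). Note that `GradScaler` already skips optimizer steps whose gradients contain inf/NaN, which helps ride out occasional loss spikes:

```python
import torch

# Toy model standing in for MambaVision; the point is the AMP + lower-LR pattern.
model = torch.nn.Linear(8, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # lower than the default schedule
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

x, y = torch.randn(4, 8), torch.randn(4, 1)
device_type = "cuda" if torch.cuda.is_available() else "cpu"

# Forward under autocast; backward/step through the scaler so that
# steps with non-finite gradients are skipped instead of corrupting weights.
with torch.autocast(device_type=device_type, enabled=torch.cuda.is_available()):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
```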