NVlabs / MambaVision

Official PyTorch Implementation of MambaVision: A Hybrid Mamba-Transformer Vision Backbone
https://arxiv.org/abs/2407.08083

Training loss NaN #38

Open · FelixMessi opened this issue 1 week ago

FelixMessi commented 1 week ago

Hi, thanks for the amazing work!

I've encountered NaN losses when training MambaVision with AMP, particularly after reducing the training schedule to 30 epochs. The problem seems to stem from the selective_scan_fn function. Switching the whole run to float32 resolves the NaNs, but it is considerably more resource-intensive than AMP. Could anyone suggest a more flexible solution?
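(For reference, a minimal sketch of a middle-ground workaround, assuming the scan kernel is the usual culprit: keep AMP for the rest of the network and disable autocast only around the scan so it runs in float32. `selective_scan_fn` here stands in for the fused kernel, e.g. from the mamba-ssm package; the argument list is illustrative, not the exact upstream signature.)

```python
import torch
from torch.amp import autocast

def selective_scan_fp32(selective_scan_fn, u, delta, A, B, C, D=None, z=None):
    """Run the selective scan in float32 while the surrounding model
    stays under AMP. The wrapped kernel and its arguments are assumed,
    not taken from this repo."""
    # Disable autocast locally so the scan sees true float32 inputs.
    with autocast(device_type="cuda", enabled=False):
        args = [t.float() if t is not None else None
                for t in (u, delta, A, B, C, D, z)]
        return selective_scan_fn(*args)
```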

ahatamiz commented 1 week ago

Hi @FelixMessi

It's hard to pinpoint the issue exactly knowing only that the total number of epochs has been reduced. However, my best bet would be to decrease the learning rate while still keeping AMP.
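(A minimal sketch of what that could look like, assuming a standard PyTorch AMP loop; the model, data, and learning rate below are placeholders, not values from this repo. Unscaling before gradient clipping and letting `GradScaler` skip steps with inf/NaN gradients are common ways to keep AMP stable.)

```python
import torch
from torch import nn
from torch.amp import autocast, GradScaler

# Stand-ins so the sketch runs; replace with MambaVision and your own data.
model = nn.Linear(8, 2).cuda()
criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(4, 8).cuda(), torch.randint(0, 2, (4,)).cuda())]

# Illustrative value: e.g. halve whatever learning rate produced the NaNs.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()

for images, targets in loader:
    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type="cuda"):
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()
    # Unscale first so clipping measures real gradient magnitudes.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # the step is skipped if inf/NaN grads are found
    scaler.update()
```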