MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on Mamba
MIT License

About training on ImageNet #165

LMMMEng closed this issue 5 months ago

LMMMEng commented 5 months ago

Hi,

I noticed in your ImageNet training log that the training loss is normal, but the test-phase loss is NaN in many epochs. I have also encountered this problem; could you please tell me how to solve it?

MzeroMiko commented 5 months ago

The test loss is NaN because, during that period of training, inference under torch.no_grad() ran with all of the data in float16. We fixed that later.
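To illustrate how running everything in float16 can turn a finite value into NaN: float16 tops out at 65504, so an intermediate such as a squared activation of 300 (90000) overflows to inf, and a later inf - inf yields NaN. The thread does not say which operation overflowed in VMamba, so this is a generic toy, not the model's actual computation. It simulates float16 rounding in pure Python via struct's 'e' (half-precision) format; `to_fp16` and `naive_variance` are hypothetical helpers.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value; out-of-range values overflow to +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def naive_variance(xs, cast=float):
    """Var[x] = E[x^2] - E[x]^2, rounding every intermediate result with `cast`."""
    acc, acc_sq = 0.0, 0.0
    for x in xs:
        x = cast(x)
        acc = cast(acc + x)
        acc_sq = cast(acc_sq + cast(x * x))  # 300 * 300 = 90000 > 65504 -> inf in float16
    n = len(xs)
    mean = cast(acc / n)
    mean_sq = cast(acc_sq / n)
    return cast(mean_sq - cast(mean * mean))  # inf - inf -> nan in float16

xs = [300.0, 310.0, 290.0]
print(naive_variance(xs))                # about 66.67 in double precision
print(naive_variance(xs, cast=to_fp16))  # nan when every step is float16
```

The same input is harmless at higher precision; only the reduced dynamic range of float16 produces the NaN.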

LMMMEng commented 5 months ago

Does that mean this is caused by using AMP during inference?

MzeroMiko commented 5 months ago

In our model, even if you are using AMP, we still force some parts of the data to be in float32.
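A minimal sketch of the general idea of keeping inputs and outputs in float16 while computing an overflow-prone step in float32. In PyTorch this is usually done by locally disabling autocast (`torch.autocast(..., enabled=False)`) and calling `.float()` on the inputs; the pure-Python version below simulates both precisions with struct ('e' for float16, 'f' for float32) so it runs anywhere. The unstabilized softmax and the names `mixed_precision_softmax` and `compute_in_fp32` are illustrative assumptions; which part of VMamba is kept in float32 is not shown in this thread.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to float16; values beyond +/-65504 overflow to +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def to_fp32(x: float) -> float:
    """Round x to float32; values beyond about 3.4e38 overflow to +/-inf."""
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def softmax_with(logits, cast):
    # Unstabilized softmax (no max subtraction); the exponentials,
    # their sum and each quotient are rounded with `cast`.
    exps = [cast(math.exp(v)) for v in logits]
    total = cast(sum(exps))
    return [cast(e / total) for e in exps]

def mixed_precision_softmax(logits_fp16, compute_in_fp32=True):
    """Hypothetical helper: float16 in, float16 out. When compute_in_fp32 is set,
    the overflow-prone part runs in float32 and only the result is cast back down."""
    cast = to_fp32 if compute_in_fp32 else to_fp16
    probs = softmax_with([cast(v) for v in logits_fp16], cast)
    return [to_fp16(p) for p in probs]

logits = [to_fp16(v) for v in (12.0, 1.0, -3.0)]
print(mixed_precision_softmax(logits, compute_in_fp32=True))   # finite probabilities
print(mixed_precision_softmax(logits, compute_in_fp32=False))  # first entry is nan: inf / inf in float16
```

Here exp(12) is about 1.6e5, which fits in float32 but not in float16, so only the float16-everywhere path breaks.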

LMMMEng commented 5 months ago

I see. I still don't understand why the training loss is normal while the test loss becomes NaN; according to your code, both seem to use AMP.

Journey7331 commented 5 months ago

Hi, @MzeroMiko @LMMMEng, I noticed grad_norm 1.5292 (inf) in the log vssm_tiny_0230.txt that MzeroMiko provided. I also encounter this inf and loss=NaN in my own training. Is there any relation between the inf and the NaN?
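One plausible reading of `grad_norm 1.5292 (inf)`, assuming the log prints each metric as `current value (running average)`: a single step whose gradient norm was inf (for example an AMP overflow step) makes the running average inf for the rest of the averaging window, even though the current value is finite again. Under that assumption, the "(inf)" alone does not mean the loss is NaN. The `AverageMeter` class and `format_entry` function below are a hypothetical minimal version of such a meter, and the values other than 1.5292 are made up.

```python
import math

class AverageMeter:
    """Tracks the latest value and the running average of a logged metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n  # one inf sample makes the sum inf from here on
        self.count += n
        self.avg = self.sum / self.count

def format_entry(name, meter):
    # "current (running average)", e.g. "grad_norm 1.5292 (inf)"
    return f"{name} {meter.val:.4f} ({meter.avg:.4f})"

norm_meter = AverageMeter()
for grad_norm in [1.21, math.inf, 1.5292]:  # illustrative per-step gradient norms
    norm_meter.update(grad_norm)
print(format_entry("grad_norm", norm_meter))  # grad_norm 1.5292 (inf)
```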

LMMMEng commented 5 months ago

Hi, @Journey7331. It is usually normal for the gradient to occasionally become inf, because AMP automatically detects such infs and adjusts the loss scale. However, if the gradient keeps being inf even after the loss scale has been automatically reduced to a very small value, it may never return to normal, and the loss will then become NaN. Please feel free to correct me if I'm wrong.
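The mechanism described above is dynamic loss scaling. Below is a minimal pure-Python sketch modeled on the documented default policy of PyTorch's `GradScaler`: initial scale 2**16; on inf/NaN gradients, skip the optimizer step and multiply the scale by 0.5; after 2000 consecutive clean steps, multiply it by 2.0. `DynamicLossScaler` and its `step` method are hypothetical and are not `GradScaler`'s API. The sketch shows that a one-off overflow costs only one skipped update, whereas a persistent one gets every update skipped while the scale keeps shrinking.

```python
import math

class DynamicLossScaler:
    """Hypothetical stand-in for an AMP loss scaler (not torch's GradScaler)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, scaled_grads):
        """Return the unscaled gradients to apply, or None if the step is skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow (inf or nan): skip this update and shrink the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        unscaled = [g / self.scale for g in scaled_grads]
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return unscaled

scaler = DynamicLossScaler()
print(scaler.step([math.inf, 1.0]), scaler.scale)  # None 32768.0 (step skipped, scale halved)
print(scaler.step([2.0**15, -2.0**14]))            # [1.0, -0.5]
```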

Journey7331 commented 5 months ago

I see. By the way, during my initial attempt to train the model, I got inf and NaN, and training collapsed at a random epoch. I then started another run with AMP and force_fp32=True on this line: https://github.com/MzeroMiko/VMamba/blob/dd8fc09433cb7c1e7972fee02f165b73b63701e8/classification/models/vmamba.py#L451 I still got some inf, but no more collapses, and inference gives the same results with force_fp32=True and force_fp32=False. Thanks for your conversation above, guys. :)