joe-siyuan-qiao / WeightStandardization

Standardizing weights to accelerate micro-batch training

The loss is NaN #8

Open Maycbj opened 5 years ago

Maycbj commented 5 years ago

This is very nice work, but I have run into some problems in my experiments. Training easily suffers from gradient explosion and the loss becomes NaN, even when my learning rate is set to 0. Could you give me some advice?

joe-siyuan-qiao commented 5 years ago

Thanks for your interest. It's hard to diagnose the problem based on the provided information. It seems that the problem may be due to numerical issues. Please make sure all outputs are properly normalized. You can also provide more information so we can offer better help.

Maycbj commented 5 years ago

Oh, I have found the reason: the NaN loss was caused by the initialization.
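The thread does not say which part of the initialization was at fault, so the following is only one plausible mechanism, offered as an assumption. If a filter is initialized to a constant (for example all zeros), its standard deviation is 0, and weight standardization without a small epsilon computes 0/0 = NaN in the forward pass. Because this happens before any update, it would also explain a NaN loss with a learning rate of 0. The helper `standardize_weight` and the `eps` value are hypothetical; this NumPy sketch is not the repository's implementation.

```python
import numpy as np


def standardize_weight(weight, eps=1e-5):
    """Weight-standardize a conv weight of shape (out_channels, in_channels, kH, kW).

    Each output filter is shifted to zero mean and divided by its standard
    deviation (plus eps) over its in_channels * kH * kW entries.
    Hypothetical helper for illustration only.
    """
    flat = weight.reshape(weight.shape[0], -1)
    mean = flat.mean(axis=1, keepdims=True)
    std = flat.std(axis=1, keepdims=True)
    standardized = (flat - mean) / (std + eps)
    return standardized.reshape(weight.shape)


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3, 3, 3))
w[0] = 0.0  # hypothetical bad init: filter 0 is constant, so its std is 0

with np.errstate(invalid="ignore", divide="ignore"):
    unsafe = standardize_weight(w, eps=0.0)  # 0 / 0 -> NaN for filter 0
safe = standardize_weight(w)  # eps keeps filter 0 finite (all zeros)

print(np.isnan(unsafe).any(), np.isnan(safe).any())  # True False
```

The epsilon only guards the degenerate case; a constant filter still standardizes to all zeros, so a sensible random initialization matters either way.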

So I now use your pre-trained model. However, in your MaskRCNN-benchmark, the pre-trained models for Faster R-CNN and Mask R-CNN are different. Why are the models pre-trained on ImageNet different?

My guess is that the ResNet parameters are trained on ImageNet, and then you convert the randomly initialized parameters (FPN and RCNN head) into weight-standardized parameters?



joe-siyuan-qiao commented 5 years ago

Good to know that you found the reason.

I'm not sure I understand your question. The models pre-trained on ImageNet for Faster-RCNN and Mask-RCNN are the same -- they all point to "catalog://WeightStandardization/R-50-GN-WS", for example. The pre-trained models only contain the parameters of the backbones. Other parts such as heads are not included in the pre-trained models.
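To illustrate what "the pre-trained models only contain the parameters of the backbones" means in practice, here is a framework-agnostic sketch using plain dictionaries: parameters whose names appear in the checkpoint are copied, and everything else (FPN, box head, and so on) keeps whatever initialization the model config gave it. The `load_backbone` helper and all key names are hypothetical and are not maskrcnn-benchmark's actual loader or parameter names.

```python
def load_backbone(model_params, checkpoint):
    """Copy matching entries from checkpoint into model_params.

    Returns (loaded, left_initialized): names taken from the checkpoint and
    names that keep their freshly initialized values. Hypothetical helper.
    """
    loaded, left_initialized = [], []
    for name in model_params:
        if name in checkpoint:
            model_params[name] = checkpoint[name]  # overwrite value, keys unchanged
            loaded.append(name)
        else:
            left_initialized.append(name)
    return loaded, left_initialized


# Hypothetical parameter names; string values stand in for tensors.
model = {
    "backbone.body.conv1.weight": "random",
    "backbone.body.layer1.0.conv1.weight": "random",
    "backbone.fpn.inner1.weight": "random",
    "roi_heads.box.predictor.cls_score.weight": "random",
}
imagenet_ckpt = {
    "backbone.body.conv1.weight": "pretrained",
    "backbone.body.layer1.0.conv1.weight": "pretrained",
}

loaded, fresh = load_backbone(model, imagenet_ckpt)
print(loaded)  # ['backbone.body.conv1.weight', 'backbone.body.layer1.0.conv1.weight']
print(fresh)   # ['backbone.fpn.inner1.weight', 'roi_heads.box.predictor.cls_score.weight']
```

Under this reading, whether the heads use weight-standardized convolutions is decided by the model definition, not by the ImageNet checkpoint.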

zql-seu commented 5 years ago

Hello! I have also run into loss explosion. I used WS + GN to pre-train on ImageNet without any problems. But when I used the pre-trained model as the backbone for semantic segmentation, the loss became NaN once it had dropped to a certain level. I tried freezing the backbone and adding WS + GN to the decoding path, but as the loss decreases I still get NaN. I suspect something is wrong somewhere.
[Screenshots of the training loss: epochs 0-8, epoch 8, epochs 9-10]

chenxi116 commented 5 years ago

In my experience, NaN is caused either by too large a learning rate or by inappropriate batch norm layer statistics. Based on your screenshot, the former is unlikely, since the loss is actually decreasing.

I recommend adding checks such as if np.isnan(...): pdb.set_trace() to diagnose the cause. For example, you can check the logits that are fed into the loss function.
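A minimal sketch of that kind of check, assuming the logits can be converted to a NumPy array. It uses `np.isfinite`, which also catches infinities that often appear just before a NaN, and only opens `pdb` when `debug=True`, so it can stay in a training loop. The helper name `check_finite` and the dummy logits are hypothetical.

```python
import pdb

import numpy as np


def check_finite(name, array, debug=False):
    """Return True if every entry of array is finite; otherwise report it.

    With debug=True, a non-finite value opens pdb so the surrounding
    variables (e.g. the logits passed to the loss) can be inspected.
    """
    array = np.asarray(array)
    if np.isfinite(array).all():
        return True
    print(f"{name}: {int(np.isnan(array).sum())} NaN, "
          f"{int(np.isinf(array).sum())} Inf out of {array.size} values")
    if debug:
        pdb.set_trace()
    return False


logits = np.array([[2.0, -1.0, 0.5], [np.nan, 3.0, np.inf]])
check_finite("logits", logits)        # prints: logits: 1 NaN, 1 Inf out of 6 values
check_finite("logits[0]", logits[0])  # returns True, prints nothing
```

With PyTorch, for example, one could call check_finite("logits", logits.detach().cpu().numpy(), debug=True) each iteration, and do the same for the loss, to stop at the first step where values go bad.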

MarcoForte commented 5 years ago

My reply in issue #1 might help.