jadore801120 / attention-is-all-you-need-pytorch

A PyTorch implementation of the Transformer model in "Attention is All You Need".
MIT License

nan loss when training #5

Closed. sliedes closed this issue 7 years ago.

sliedes commented 7 years ago

Training and validation loss is nan (using commit e21800a6):

$ python3 preprocess.py -train_src data/multi30k/train.en -train_tgt data/multi30k/train.de -valid_src data/multi30k/val.en -valid_tgt data/multi30k/val.de -output data/multi30k/data.pt
$ python3 train.py -data data/multi30k/data.pt -save_model trained -save_model best
[ Epoch 0 ]
  - (Training)   loss:      nan, accuracy: 3.7 %
  - (Validation) loss:      nan, accuracy: 10.0 %
    - [Info] The checkpoint file has been updated.
[ Epoch 1 ]
  - (Training)   loss:      nan, accuracy: 9.09 %
  - (Validation) loss:      nan, accuracy: 9.87 %
[ Epoch 2 ]
  - (Training)   loss:      nan, accuracy: 9.09 %
  - (Validation) loss:      nan, accuracy: 9.83 %
[ Epoch 3 ]
  - (Training)   loss:      nan, accuracy: 9.1 %
  - (Validation) loss:      nan, accuracy: 9.92 %
[ Epoch 4 ]
  - (Training)   loss:      nan, accuracy: 9.09 %
  - (Validation) loss:      nan, accuracy: 9.91 %
jadore801120 commented 7 years ago

Hi @sliedes , thanks for the report. I am also facing the same problem. Once I fix it, I will update the info here. Any further inspection would be appreciated!

sliedes commented 7 years ago

While I do not yet fully understand either the code or the theory, I think the problem is in ScaledDotProductAttention.forward(). It sets the masked values to -Inf before passing them to nn.Softmax, and I suspect nn.Softmax does not deal well with -Inf. Indeed, while I usually got a nan loss already during the first epoch, after modifying ScaledDotProductAttention.forward() to set the masked values to -100.0 instead of -Inf, I have now trained for five epochs without seeing nans.
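
For illustration (not code from this repository), the small sketch below reproduces the failure mode described above: if every entry in a row of attention scores is -inf, the softmax denominator is zero and the output is nan, while a large finite negative fill value such as -100.0 keeps the result well defined.

import torch
import torch.nn as nn

# A row whose entries are all -inf: softmax computes exp(-inf) / sum(exp(-inf)) = 0/0.
scores = torch.full((1, 4), float('-inf'))
print(nn.Softmax(dim=-1)(scores))   # tensor([[nan, nan, nan, nan]])

# The workaround mentioned above: a large but finite negative fill value.
scores = torch.full((1, 4), -100.0)
print(nn.Softmax(dim=-1)(scores))   # tensor([[0.25, 0.25, 0.25, 0.25]])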

jadore801120 commented 7 years ago

Hi @sliedes , sorry for the late update! You are right that the problem is in the softmax step, but it is not because of the -Inf value itself. I misplaced the q/k pair in the attention mask calculation routine (see 94aae68). Please pull the newest commit to fix this bug. Thank you!
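
As a minimal sketch of what such a padding-mask helper has to get right (the function name and PAD constant here are assumptions, not the repository's exact code): the mask should have shape (batch, len_q, len_k), where the query sequence only supplies the number of rows and the padding test is applied to the key sequence. Swapping the two can leave rows that are masked everywhere, which then turn into nan after softmax.

import torch

PAD = 0  # hypothetical padding index

def get_attn_padding_mask(seq_q, seq_k):
    """Mask key positions that are padding; returns shape (batch, len_q, len_k)."""
    len_q = seq_q.size(1)
    # The padding test must run on the *key* sequence ...
    pad_mask = seq_k.eq(PAD).unsqueeze(1)    # (batch, 1, len_k)
    # ... while the query sequence only determines how many rows to expand to.
    return pad_mask.expand(-1, len_q, -1)    # (batch, len_q, len_k)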

jadore801120 commented 7 years ago

I think this fix eliminates the NaN error, and I am sorry for the confusion so far. Let me close this issue now. However, if another fatal NaN error emerges, feel free to open a new issue.

xdwang0726 commented 4 years ago

Hi, it seems that the problem still exists when using the solution above.