pytorch / tutorials

PyTorch tutorials.
https://pytorch.org/tutorials/
BSD 3-Clause "New" or "Revised" License

[BUG] Language Translation with nn.Transformer and torchtext - mask with -inf leads to NaNs #2988

Open danielegana opened 1 month ago

danielegana commented 1 month ago

Add Link

https://pytorch.org/tutorials/beginner/translation_transformer.html

Describe the bug

Running the tutorial on language translation with transformers leads to NaNs when training on the first batch of the first epoch, and even when evaluating an untrained model on some input sequences.

I found this issue simply by copy-pasting the tutorial to my local machine and starting the training process. The issue seems to stem from the target mask. Replacing the line

mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

with

mask = mask.float().masked_fill(mask == 0, float('-1e9')).masked_fill(mask == 1, float(0.0))

allows the model to train for a few batches of the first epoch without NaN outputs. A possibly related problem was pointed out in https://github.com/pytorch/pytorch/issues/41508#issuecomment-1723119580

However, even with this "fix", the model's loss increases with training and eventually becomes NaN as well.
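For intuition on the failure mode (my own sketch, not code from the tutorial): softmax over an attention-score row whose entries are all -inf degenerates to 0/0. The causal mask alone always leaves at least one position unmasked, but once it is combined with a padding mask, a row can become entirely -inf, which is the situation discussed in the linked issue. A minimal pure-Python demonstration:

```python
import math

def softmax(row):
    # Numerically stable softmax over one attention-score row.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# Row with at least one attendable position: the -inf entry gets zero weight.
ok_row = [0.5, float('-inf'), 1.2]

# Fully masked row (every score -inf), as can happen when the causal mask is
# combined with a padding mask: max(row) is -inf, so (x - m) is -inf - (-inf)
# = NaN, and the whole softmax output is NaN.
bad_row = [float('-inf')] * 3
```

Replacing -inf with a large negative constant like -1e9 avoids the 0/0, which is why the workaround above delays the NaNs, but it does not remove the underlying fully-masked rows.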

Describe your environment

Running on macOS. I am using PyTorch 2.2.2 and Python 3.9.7.

cc @pytorch/team-text-core @Nayef211

lmntrx-sys commented 1 month ago

Suggested Steps

Review Mask Implementation: Investigate the mask implementation in the tutorial and ensure it is correctly applied during training and evaluation.

Gradient Clipping: Implement gradient clipping to prevent exploding gradients, which could lead to NaNs.
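One way to apply the gradient-clipping suggestion, sketched on a toy model (the layer sizes, data, and optimizer settings here are placeholders, not the tutorial's; the clipping call itself is the standard `torch.nn.utils.clip_grad_norm_`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the tutorial's Seq2SeqTransformer (hypothetical sizes).
model = nn.Linear(8, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(16, 8)
target = torch.randn(16, 4)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), target)
loss.backward()

# Clip the global gradient norm to 1.0 before the optimizer step.
# Returns the norm *before* clipping, useful for logging.
pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# After clipping, the global gradient norm is at most max_norm.
post_clip_norm = torch.norm(
    torch.stack([p.grad.norm() for p in model.parameters()])
)

optimizer.step()
```

Note that clipping caps the update size but does not fix NaNs that originate in the attention softmax itself; once a NaN enters the graph, clipping cannot remove it.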

Learning Rate: Experiment with different learning rates to see if a lower learning rate stabilizes the training process.

Check Data Preprocessing: Ensure that the input data is correctly preprocessed and normalized.

Monitor NaN Values: Add checks to monitor NaN values during training and identify when and where they first appear.
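A minimal sketch of such a check; `assert_finite` is a hypothetical helper (not part of the tutorial) that fails fast and names the offending tensor, so the first NaN is caught at its source rather than after it has propagated through the loss:

```python
import torch

def assert_finite(tensor, name):
    # Raise as soon as a NaN/Inf shows up, naming the offending tensor.
    if not torch.isfinite(tensor).all():
        raise RuntimeError(f"non-finite values detected in {name}")

# Example: a logits tensor with a NaN, as produced by a fully masked
# attention row, trips the check immediately.
logits = torch.tensor([[0.2, -1.3], [float('nan'), 0.0]])
caught = False
try:
    assert_finite(logits, "logits")
except RuntimeError:
    caught = True
```

In the training loop one would call `assert_finite(logits, "logits")` and `assert_finite(loss, "loss")` each step; `torch.autograd.set_detect_anomaly(True)` can additionally point at the backward op that first produced the NaN, at some runtime cost.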