pbloem / former

Simple transformer implementation from scratch in pytorch.
http://peterbloem.nl/blog/transformers
MIT License

Issue with masking #10

Closed: mc-robinson closed this 4 years ago

mc-robinson commented 4 years ago

Thanks for your great tutorial. I think your code on GitHub is probably fine, but I believe there is an error in the mask as defined in the post.

I believe the masking line should read

indices = torch.triu_indices(t, t, offset=1)

In the current implementation, dot is t x t. Furthermore, an offset of 0 masks the diagonal as well, so the first row becomes all -inf, which produces a complete row of NaNs when fed into softmax. See https://github.com/pytorch/pytorch/issues/24816
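To see the difference concretely, here is a minimal standalone sketch (t and the random dot matrix are just example stand-ins):

import torch
import torch.nn.functional as F

t = 4
dot = torch.randn(t, t)  # stand-in for the t-by-t matrix of raw attention scores

# offset=1 masks strictly above the diagonal, so every row,
# including the first, keeps at least one finite entry
indices = torch.triu_indices(t, t, offset=1)
dot[indices[0], indices[1]] = float('-inf')
print(F.softmax(dot, dim=1))  # finite everywhere

# with offset=0 the diagonal is masked too: the first row becomes
# all -inf, and softmax over an all -inf row produces NaNs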

pbloem commented 4 years ago

You're absolutely right. The dot matrix is t by t, and the offset should be 1 so that the diagonal stays unmasked (as you say, it was correct in the code, but not in the blog post).
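For context, a sketch of how the fixed mask sits inside a batched self-attention forward pass (the shapes and the sqrt-k scaling here are my assumptions for the example, not the post's exact code):

import torch
import torch.nn.functional as F

b, t, k = 2, 4, 8  # example batch size, sequence length, embedding dim
queries = torch.randn(b, t, k)
keys = torch.randn(b, t, k)

# one t-by-t score matrix per batch element
dot = torch.bmm(queries, keys.transpose(1, 2)) / (k ** 0.5)

# offset=1 leaves the diagonal unmasked, so every row keeps
# at least one finite score and softmax stays NaN-free
indices = torch.triu_indices(t, t, offset=1)
dot[:, indices[0], indices[1]] = float('-inf')

attention = F.softmax(dot, dim=2)
assert not torch.isnan(attention).any()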

Thanks for the pointer, I've fixed it in the post.