PetrochukM / PyTorch-NLP

Basic Utilities for PyTorch Natural Language Processing (NLP)
https://pytorchnlp.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Attention's softmax dimension is wrong? #2

Closed: joelxiangnanchen closed this issue 6 years ago

joelxiangnanchen commented 6 years ago

Hi, I found that the weights computed by nn.Attention are always 1, for example:

```python
import torch
from torch.autograd import Variable
from torchnlp import nn  # assumption: `nn` in this snippet refers to torchnlp.nn

atten = nn.Attention(3, attention_type='dot')
a = Variable(torch.randn(1, 1, 3))  # query: [batch_size, output_len, dimensions]
b = Variable(torch.randn(1, 2, 3))  # context: [batch_size, query_len, dimensions]
output, weights = atten(a, b)
```

and the output is:

```
Variable containing:
(0 ,.,.) =
  1  1
[torch.FloatTensor of size 1x1x2]
```

The weights are always 1 no matter what I feed into the attention layer, as long as the input has shape [1, n, m], that is, whenever the batch size is 1. Checking the code, I found nn.Softmax(dim=0). As I understand it, this applies the softmax along the first dimension (for a 3-D tensor of shape [z, y, x], along z). But the first dimension of the reshaped scores is batch_size * output_len, so the softmax runs in the wrong direction. It should run along the last dimension, query_len, which holds the actual scores for every context position in an instance.
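The direction matters because a softmax normalizes along only one axis. Below is a minimal pure-Python sketch of that effect; the helper `softmax_2d` is hypothetical and only mimics what nn.Softmax(dim=...) does on a 2-D score matrix of shape [batch_size * output_len, query_len].

```python
import math


def softmax_2d(matrix, dim):
    """Softmax over a 2-D list of lists along dim 0 (columns) or dim 1 / -1 (rows).

    Hypothetical helper that mimics nn.Softmax(dim=...) on a 2-D tensor.
    """
    if dim == 0:
        # Normalize each column: transpose, softmax the rows, transpose back.
        columns = [list(col) for col in zip(*matrix)]
        return [list(row) for row in zip(*softmax_2d(columns, dim=1))]
    if dim in (1, -1):
        result = []
        for row in matrix:
            peak = max(row)  # subtract the max for numerical stability
            exps = [math.exp(x - peak) for x in row]
            total = sum(exps)
            result.append([e / total for e in exps])
        return result
    raise ValueError("dim must be 0, 1 or -1")


# One query position (batch_size * output_len == 1) scored against two context positions.
scores = [[0.3, 2.1]]
print(softmax_2d(scores, dim=0))   # each column has a single entry, so every weight is 1.0
print(softmax_2d(scores, dim=-1))  # a proper distribution over the two context positions
```

With a single row, normalizing along dim 0 divides each score's exponential by itself, while normalizing along the last dimension compares the context positions against each other, which is what attention needs.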

E.g.

Given a query of shape [1, 1, 3] and a context of shape [1, 2, 3] (so the context is a 2-row, 3-column matrix), the context is transposed to [1, 3, 2] and multiplied with the query, giving scores of shape [1, 1, 2]. In your code, the scores are then reshaped to [1, 2] and the softmax is taken along the first dimension. Because that dimension has size 1, each weight is divided only by itself, so every weight comes out as exactly 1.
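The last step can be written out explicitly. Let S denote the reshaped [1, 2] score matrix (the symbol S is introduced here for illustration). Softmax along dimension 0 sums over a single row, whereas softmax along the last dimension sums over the two context positions:

```latex
\operatorname{softmax}_{0}(S)_{0j}
  = \frac{e^{S_{0j}}}{\sum_{i=0}^{0} e^{S_{ij}}}
  = \frac{e^{S_{0j}}}{e^{S_{0j}}} = 1,
\qquad
\operatorname{softmax}_{-1}(S)_{0j}
  = \frac{e^{S_{0j}}}{e^{S_{00}} + e^{S_{01}}},
\qquad j \in \{0, 1\}
```

Only the second form depends on the scores, and its two entries sum to 1.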

I changed it to nn.Softmax(dim=-1), which operates along the last dimension (query_len), and it worked:

```
Variable containing:
(0 ,.,.) =
  0.1426  0.8574
[torch.FloatTensor of size 1x1x2]
```
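For reference, here is a minimal pure-Python sketch of dot-product attention with the softmax taken along the last dimension, using the same shapes as the example above. The function `dot_attention` is hypothetical and is not the PyTorch-NLP implementation: it stops at the weighted sum of the context vectors and does not model any output projection the library may apply afterwards.

```python
import math


def dot_attention(query, context):
    """Dot-product attention over nested lists (a sketch, not the PyTorch-NLP code).

    query:   [batch_size][output_len][dimensions]
    context: [batch_size][query_len][dimensions]
    Returns (mix, weights) with shapes [batch_size][output_len][dimensions]
    and [batch_size][output_len][query_len].
    """
    mixes, weights = [], []
    for q_batch, c_batch in zip(query, context):
        batch_mix, batch_weights = [], []
        for q in q_batch:
            # Scores of this query vector against every context vector: length query_len.
            scores = [sum(a * b for a, b in zip(q, c)) for c in c_batch]
            # Softmax along the last dimension (query_len), as in the fix.
            peak = max(scores)
            exps = [math.exp(s - peak) for s in scores]
            total = sum(exps)
            w = [e / total for e in exps]
            # Weighted sum of the context vectors, one value per dimension.
            mix = [sum(w_k * c[d] for w_k, c in zip(w, c_batch)) for d in range(len(q))]
            batch_weights.append(w)
            batch_mix.append(mix)
        mixes.append(batch_mix)
        weights.append(batch_weights)
    return mixes, weights


# Same shapes as the issue: query [1, 1, 3], context [1, 2, 3].
query = [[[0.5, -1.0, 2.0]]]
context = [[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]
mix, weights = dot_attention(query, context)
print(weights)  # shape [1][1][2]; the two weights sum to 1 instead of both being 1
```

Because each query position is normalized on its own, the result no longer depends on the batch size being greater than 1.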

Thanks for your great work! -Chen

PetrochukM commented 6 years ago

Fixed this! Thank you for the very detailed explanation.

Updated the tests as well.