graykode / nlp-tutorial

Natural Language Processing Tutorial for Deep Learning Researchers
https://www.reddit.com/r/MachineLearning/comments/amfinl/project_nlptutoral_repository_who_is_studying/

Question about tensor.view operation in Bi-LSTM(Attention) #38

Open iamxpy opened 4 years ago

iamxpy commented 4 years ago

https://github.com/graykode/nlp-tutorial/blob/cb4881ebf6683dc6970c53a2cf50d5fd01edf118/4-3.Bi-LSTM(Attention)/Bi-LSTM(Attention)-Torch.py#L50
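For context, the line in question (also quoted in a reply below) is:

hidden = final_state.view(batch_size, -1, 1)  # intended: [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]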

Hi, this repo is awesome, but there might be something wrong in the code linked above. According to the comment on that line, the snippet intends to change a tensor from shape [num_layers(=1) * num_directions(=2), batch_size, n_hidden] to shape [batch_size, n_hidden * num_directions(=2), 1(=n_layer)], i.e. to concatenate the two hidden vectors from the two directions for every data example in a batch (by "data example" I mean one of the batch_size examples in a batch). But I think the code above mixes up the data examples in a batch and leads to unexpected results.
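
As a sanity check on the layout, here is a minimal sketch (dimensions chosen arbitrarily) confirming that the final hidden state of a bidirectional nn.LSTM has shape [num_layers * num_directions, batch_size, n_hidden], i.e. dimension 0 indexes the directions, not the batch:

import torch
import torch.nn as nn

# a single-layer bidirectional LSTM with arbitrary sizes
lstm = nn.LSTM(input_size=4, hidden_size=5, bidirectional=True, batch_first=True)
x = torch.randn(3, 7, 4)            # [batch_size=3, seq_len=7, input_size=4]
output, (h_n, c_n) = lstm(x)
print(h_n.shape)                    # torch.Size([2, 3, 5]) = [num_directions, batch_size, n_hidden]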

For example, we can use IPython to check the effect of the snippet above.

# create a tensor with shape [num_layers(=1) * num_directions(=2), batch_size, n_hidden]                                                                                           
In [10]: a=torch.arange(2*3*5).reshape(2,3,5) 

In [11]: a                                                             
Out[11]:                                                               
tensor([[[ 0,  1,  2,  3,  4],                                         
         [ 5,  6,  7,  8,  9],                                         
         [10, 11, 12, 13, 14]],                                        

        [[15, 16, 17, 18, 19],                                         
         [20, 21, 22, 23, 24],                                         
         [25, 26, 27, 28, 29]]])                                       

In [12]: a.view(-1,10,1)                                               
Out[12]:                                                               
tensor([[[ 0],                                                         
         [ 1],                                                         
         [ 2],                                                         
         [ 3],                                                         
         [ 4],                                                         
         [ 5],                                                         
         [ 6],                                                         
         [ 7],                                                         
         [ 8],                                                         
         [ 9]],                                                        

        [[10],                                                         
         [11],                                                         
         [12],                                                         
         [13],                                                         
         [14],                                                         
         [15],                                                         
         [16],                                                         
         [17],                                                         
         [18],                                                         
         [19]],                                                        

        [[20],                                                         
         [21],                                                         
         [22],                                                         
         [23],                                                         
         [24],                                                         
         [25],                                                         
         [26],                                                         
         [27],                                                         
         [28],                                                         
         [29]]])                                                       

As you can see, we created a tensor with batch_size=3 and n_hidden=5. For example, [ 0, 1, 2, 3, 4] and [15, 16, 17, 18, 19] belong to the same data example in the batch but come from different directions, so what we want is to concatenate them in the resulting tensor. What the code actually does, however, is concatenate [ 0, 1, 2, 3, 4] and [ 5, 6, 7, 8, 9], which come from different data examples in the batch.

I think it can be fixed by changing that line to hidden = torch.cat([final_state[0], final_state[1]], 1).view(-1, 10, 1), where 10 is n_hidden * num_directions(=2) in this toy example.

The effect of the new code can be shown as follows:

In [13]: torch.cat([a[0],a[1]],1).view(-1,10,1)
Out[13]:
tensor([[[ 0],
         [ 1],
         [ 2],
         [ 3],
         [ 4],
         [15],
         [16],
         [17],
         [18],
         [19]],

        [[ 5],
         [ 6],
         [ 7],
         [ 8],
         [ 9],
         [20],
         [21],
         [22],
         [23],
         [24]],

        [[10],
         [11],
         [12],
         [13],
         [14],
         [25],
         [26],
         [27],
         [28],
         [29]]])
liuxiaoqun commented 3 years ago

I think it needs to change hidden = final_state.view(batch_size, -1, 1) to hidden = final_state.transpose(0, 1).reshape(batch_size, -1, 1).
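
For illustration, applying this to the toy tensor a from the first comment gives the same pairing as the torch.cat fix. Note that reshape is needed here instead of view, because the transposed tensor is no longer contiguous and .view would raise a RuntimeError:

In [14]: a.transpose(0, 1).reshape(3, 10, 1).squeeze(-1)
Out[14]:
tensor([[ 0,  1,  2,  3,  4, 15, 16, 17, 18, 19],
        [ 5,  6,  7,  8,  9, 20, 21, 22, 23, 24],
        [10, 11, 12, 13, 14, 25, 26, 27, 28, 29]])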