pandeykartikey / Hierarchical-Attention-Network

Implementation of Hierarchical Attention Networks in PyTorch
https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf

THIS IMPLEMENTATION's GOT at least 2 apparent MISTAKEs! #5

Open SimZhou opened 4 years ago

SimZhou commented 4 years ago

From what I've read so far, the attention implementation at both the word level and the sentence level is WRONG:

## The word RNN model for generating a sentence vector
class WordRNN(nn.Module):
    def __init__(self, vocab_size,embedsize, batch_size, hid_size):
        super(WordRNN, self).__init__()
        self.batch_size = batch_size
        self.embedsize = embedsize
        self.hid_size = hid_size
        ## Word Encoder
        self.embed = nn.Embedding(vocab_size, embedsize)
        self.wordRNN = nn.GRU(embedsize, hid_size, bidirectional=True)
        ## Word Attention
        self.wordattn = nn.Linear(2*hid_size, 2*hid_size)
        self.attn_combine = nn.Linear(2*hid_size, 2*hid_size,bias=False)
    def forward(self,inp, hid_state):
        emb_out  = self.embed(inp)

        out_state, hid_state = self.wordRNN(emb_out, hid_state)

        word_annotation = self.wordattn(out_state)
        attn = F.softmax(self.attn_combine(word_annotation),dim=1)

        sent = attention_mul(out_state,attn)
        return sent, hid_state

Look at line 4 from the bottom: attn = F.softmax(self.attn_combine(word_annotation),dim=1).

In PyTorch, if you don't pass batch_first=True to the GRU, out_state has shape (n_steps, batch_size, out_dims).
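
The shape claim can be checked with a small sketch. The sizes below are arbitrary illustrative values, not taken from the repository; the GRU is built the same way as in the quoted WordRNN.

```python
import torch
import torch.nn as nn

n_steps, batch_size, embedsize, hid_size = 7, 3, 5, 4  # arbitrary illustrative sizes

# Same construction as in WordRNN: no batch_first, so tensors are time-major.
gru = nn.GRU(embedsize, hid_size, bidirectional=True)
emb_out = torch.randn(n_steps, batch_size, embedsize)

out_state, hid_state = gru(emb_out)  # initial hidden state defaults to zeros

print(out_state.shape)  # (n_steps, batch_size, 2 * hid_size)
print(hid_state.shape)  # (2, batch_size, hid_size): one slice per direction
```

Here out_dims is 2 * hid_size because the GRU is bidirectional, so dim 0 of out_state indexes time steps and dim 1 indexes examples in the batch.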

As the paper states, the softmax should be taken over the time steps (so that the attention weights of all time steps sum to 1), whereas THE IMPLEMENTATION APPLIES F.softmax OVER THE BATCH DIMENSION (dim=1), which is incorrect!!! (It should be changed to dim=0.)
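
A minimal sketch of the difference, using a random time-major tensor as a stand-in for the attention scores (the sizes and variable names are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_steps, batch_size, dims = 4, 2, 3  # arbitrary illustrative sizes
scores = torch.randn(n_steps, batch_size, dims)  # time-major, like out_state

over_batch = F.softmax(scores, dim=1)  # what the quoted code does
over_time = F.softmax(scores, dim=0)   # what this issue proposes

# dim=1: weights sum to 1 across the batch, mixing unrelated examples.
print(over_batch.sum(dim=1))
# dim=0: weights sum to 1 across time steps, separately for each example.
print(over_time.sum(dim=0))
```

With dim=1, the weight a word receives depends on which other examples happen to share its batch, which is not what attention over a sentence is meant to compute.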

The sentence-level attention has the same problem.
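
Since both levels share the pattern, one possible fix is a single pooling helper that normalizes over time. This is only a sketch under assumptions: attention_pool is a hypothetical name, the two linear layers mirror the quoted WordRNN, and because attention_mul is not shown in this issue, a weighted sum over time steps stands in for it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def attention_pool(out_state, attn_layer, combine_layer):
    """Pool a time-major sequence (n_steps, batch, dims) into (batch, dims).

    Hypothetical helper: it reuses the two linear layers from the quoted
    WordRNN but normalizes over dim=0 (time) instead of dim=1 (batch).
    """
    annotation = attn_layer(out_state)
    attn = F.softmax(combine_layer(annotation), dim=0)  # sums to 1 over time
    # Assumed stand-in for attention_mul: weighted sum over time steps.
    pooled = (attn * out_state).sum(dim=0)
    return pooled, attn


hid_size = 4  # arbitrary illustrative size
wordattn = nn.Linear(2 * hid_size, 2 * hid_size)
attn_combine = nn.Linear(2 * hid_size, 2 * hid_size, bias=False)

out_state = torch.randn(6, 3, 2 * hid_size)  # (n_steps, batch_size, 2 * hid_size)
sent, attn = attention_pool(out_state, wordattn, attn_combine)
print(sent.shape)  # (batch_size, 2 * hid_size)
```

The same helper could be applied to the sentence-level GRU outputs, assuming they are also time-major; the sentence model itself is not quoted here, so that part is untested.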

This may be one reason for the fluctuating, non-converging test accuracy. I am reading through the code and trying to put together a corrected version of this implementation; I will get back later.

SimZhou commented 4 years ago

Sorry, I've given up on correcting the code; it is a bit redundant...

sumba101 commented 2 years ago

Sorry, I've given up on correcting the code; it is a bit redundant...

Do you have a corrected version of the code that you can share?