EdGENetworks / attention-networks-for-classification

Hierarchical Attention Networks for Document Classification in PyTorch

Dimensionalities of word minibatch and Embedding layer don't match #13

Open gabrer opened 6 years ago

gabrer commented 6 years ago

I wonder whether there is an error related to the input format PyTorch expects for the nn.Embedding module.

In the function train_data(), it's written:

for i in range(max_sents):
    _s, state_word, _ = word_attn_model(mini_batch[i,:,:].transpose(0,1), state_word)

After the .transpose(0,1), the slice passed to word_attn_model has shape (max_tokens, batch_size).

However, the first call in forward() is self.lookup(embed), which expects input of shape (batch_size, list_of_indices).

Currently, the lookup (wrongly!?) extracts first the embeddings of the first word of every sentence in the minibatch, then the embeddings of every second word, and so on. The fix would be to drop the .transpose(0,1).

If this is correct, all of the downstream code would need to be adjusted as well.
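For reference, here is a minimal shape check with made-up sizes (vocab_size, embed_dim, batch_size and max_tokens are illustrative; the repository's own model classes are not used). nn.Embedding accepts an index tensor of any shape and returns that shape with embedding_dim appended, so the transpose does not change which vectors are looked up, only how they are laid out:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
vocab_size, embed_dim = 20, 4
batch_size, max_tokens = 3, 5

lookup = nn.Embedding(vocab_size, embed_dim)

# Word indices for one sentence position across the batch: (batch_size, max_tokens).
sentences = torch.randint(0, vocab_size, (batch_size, max_tokens))

# Variant used in train_data(): transpose to (max_tokens, batch_size) first.
time_major = lookup(sentences.transpose(0, 1))
# Variant proposed above: no transpose.
batch_major = lookup(sentences)

print(time_major.shape)   # torch.Size([5, 3, 4])
print(batch_major.shape)  # torch.Size([3, 5, 4])

# Same vectors, different layout: time_major[t, b] is token t of sentence b.
assert torch.equal(time_major, batch_major.transpose(0, 1))
```

Whether one layout is wrong therefore depends on the layer that consumes the embeddings: PyTorch's recurrent layers (nn.GRU, nn.LSTM) expect (seq_len, batch, input_size) unless they are constructed with batch_first=True.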

Sandeep42 commented 6 years ago

I am not quite sure there is an issue; it's been a long time. I believe I'm also doing a transpose while creating the mini-batch itself, so there are two transposes that sort of cancel each other out. I don't remember why I chose to do it this way.

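import numpy as np
import torch

def pad_batch(mini_batch):
    # mini_batch: list of documents; each document is a list of sentences,
    # each sentence a list of word indices.
    mini_batch_size = len(mini_batch)
    # Note: the *mean* (not the max) sentence and token counts are used, so longer
    # documents/sentences are truncated and shorter ones are zero-padded.
    max_sent_len = int(np.mean([len(x) for x in mini_batch]))
    max_token_len = int(np.mean([len(val) for sublist in mini_batch for val in sublist]))
    main_matrix = np.zeros((mini_batch_size, max_sent_len, max_token_len), dtype=np.int64)
    for i in range(main_matrix.shape[0]):
        for j in range(main_matrix.shape[1]):
            for k in range(main_matrix.shape[2]):
                try:
                    main_matrix[i, j, k] = mini_batch[i][j][k]
                except IndexError:
                    pass  # keep 0 as padding where the document/sentence is shorter
    # Transpose to (max_sents, batch_size, max_tokens) so that mini_batch[i]
    # selects sentence i of every document in the batch.
    return torch.from_numpy(main_matrix).transpose(0, 1)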
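A small sketch that traces the shapes through both transposes may help settle the question. It uses a hypothetical toy batch with equal-length sentences (so pad_batch's padding/truncation plays no role) and a stand-in nn.GRU; that the repository's word-level GRU is built with PyTorch's default batch_first=False is an assumption to check against the model code. Rather than cancelling, the two transposes swap different axes, and together they produce a time-major (max_tokens, batch_size) index tensor for each sentence position:

```python
import torch
import torch.nn as nn

# Hypothetical toy batch (not from the repository): 2 documents, each with
# 3 sentences of 4 word indices, so no padding or truncation is needed.
docs = torch.tensor([
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 0], [1, 3, 5, 7]],
])  # (batch_size=2, max_sents=3, max_tokens=4)

# Mirrors the final transpose in pad_batch: (max_sents, batch_size, max_tokens).
mini_batch = docs.transpose(0, 1)

embed_dim, hidden = 8, 6
lookup = nn.Embedding(20, embed_dim)
# Assumption: the word-level GRU uses PyTorch's default batch_first=False.
word_gru = nn.GRU(embed_dim, hidden)

for i in range(mini_batch.size(0)):
    # Same slicing as in train_data(): sentence i of every document.
    word_ids = mini_batch[i, :, :].transpose(0, 1)  # (max_tokens, batch_size)
    embedded = lookup(word_ids)                     # (max_tokens, batch_size, embed_dim)
    output, h_n = word_gru(embedded)                # output: (max_tokens, batch_size, hidden)
    # Column b of word_ids is sentence i of document b, read in word order.
    assert torch.equal(word_ids[:, 1], docs[1, i])
    print(i, tuple(word_ids.shape), tuple(embedded.shape), tuple(output.shape))
```

Under that assumption, the "first words of every sentence, then second words" order described in the issue is exactly the time-step-by-time-step layout the GRU consumes; if the GRU were instead created with batch_first=True, the inner transpose would indeed have to go.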