AuCson / PyTorch-Batch-Attention-Seq2seq

PyTorch implementation of batched bi-RNN encoder and attention-decoder.

Would it be faster (or still correct) to use expand() instead of repeat() in the Attention model? #6

Closed · KinglittleQ closed this issue 6 years ago

KinglittleQ commented 6 years ago
H = hidden.repeat(max_len, 1, 1).transpose(0, 1)

change to

H = hidden.expand(max_len * hidden.size(0), -1, -1).transpose(0, 1)

Does it work?
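For context, a minimal sketch of the difference between the two calls, assuming hidden has shape [1, N, E] as in this decoder: repeat() materializes the copies, while expand() returns a stride-0 view over the same storage, so no data is duplicated.

import torch

max_len, N, E = 7, 4, 16
hidden = torch.randn(1, N, E)  # decoder hidden state, [1, N, E]

# repeat() materializes max_len copies of the underlying data
H_repeat = hidden.repeat(max_len, 1, 1).transpose(0, 1)                     # [N, max_len, E]

# expand() returns a stride-0 view over the same storage -- no copy is made
H_expand = hidden.expand(max_len * hidden.size(0), -1, -1).transpose(0, 1)  # [N, max_len, E]

print(torch.equal(H_repeat, H_expand))           # True: identical values
print(H_expand.data_ptr() == hidden.data_ptr())  # True: expand shares memory
print(H_repeat.data_ptr() == hidden.data_ptr())  # False: repeat copied

Since the attention score computation only reads H, the view should be safe; an in-place write to an expanded tensor raises a RuntimeError rather than silently corrupting data. On speed, expand() skips one [max_len, N, E] copy per step, so it should be no slower and typically slightly faster.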

KinglittleQ commented 6 years ago
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionRNN(nn.Module):
    '''
    input:
        inputs: [N, T_y, E//2]
        memory: [N, T_x, E]

    output:
        attn_weights: [N, T_y, T_x]
        outputs: [N, T_y, E]
        hidden: [1, N, E]

    T_x --- encoder len
    T_y --- decoder len
    N --- batch_size
    E --- hidden_size (embedding size)
    '''
    def __init__(self):
        super().__init__()
        # hp is the project's hyperparameter module; hp.E is the hidden size
        self.gru = nn.GRU(input_size=hp.E // 2, hidden_size=hp.E, batch_first=True, bidirectional=False)
        # additive (Bahdanau-style) scoring: score = v^T tanh(W s + U h)
        self.W = nn.Linear(in_features=hp.E, out_features=hp.E, bias=False)
        self.U = nn.Linear(in_features=hp.E, out_features=hp.E, bias=False)
        self.v = nn.Linear(in_features=hp.E, out_features=1, bias=False)

    def forward(self, inputs, memory, prev_hidden=None):
        T_x = memory.size(1)
        T_y = inputs.size(1)

        outputs, hidden = self.gru(inputs, prev_hidden)  # outputs: [N, T_y, E]  hidden: [1, N, E]
        w = self.W(outputs).unsqueeze(2).expand(-1, -1, T_x, -1)  # [N, T_y, T_x, E]
        u = self.U(memory).unsqueeze(1).expand(-1, T_y, -1, -1)   # [N, T_y, T_x, E]
        attn_weights = self.v(torch.tanh(w + u)).squeeze(3)       # [N, T_y, T_x]
        attn_weights = F.softmax(attn_weights, dim=2)             # normalize over encoder steps

        return attn_weights, outputs, hidden

I wrote an implementation of attention, but it seems to have some problem. Can anyone see what is wrong with it?
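One quick sanity check is to run a forward pass and verify the shapes against the docstring. A minimal sketch, assuming the class above is defined in the same script and using a hypothetical stand-in for the hp module (the name HP and the size E = 16 are made up for illustration):

import torch

class HP:        # hypothetical stand-in for the hp hyperparameter module
    E = 16
hp = HP()

model = AttentionRNN()
N, T_x, T_y = 4, 11, 7
inputs = torch.randn(N, T_y, hp.E // 2)   # decoder-side inputs
memory = torch.randn(N, T_x, hp.E)        # encoder outputs

attn_weights, outputs, hidden = model(inputs, memory)
print(attn_weights.shape)       # torch.Size([4, 7, 11])  -> [N, T_y, T_x]
print(outputs.shape)            # torch.Size([4, 7, 16])  -> [N, T_y, E]
print(hidden.shape)             # torch.Size([1, 4, 16])  -> [1, N, E]
print(attn_weights.sum(dim=2))  # every row sums to 1 after the softmax

The shapes and the softmax normalization come out as documented, so the tensor algebra looks consistent. What a shape check cannot tell you is whether the module does what you intend: as written, forward returns attn_weights but never applies them to memory to form a context vector that feeds the GRU, so the recurrence itself never conditions on the encoder states; whether that is a problem depends on how the caller uses the returned weights.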