Status: closed by KinglittleQ 6 years ago.
class AttentionRNN(nn.Module):
    '''
    GRU over the decoder inputs followed by additive (Bahdanau-style) attention
    over the encoder memory: score[n, t, s] = v^T tanh(W o[n, t] + U m[n, s]).

    input:
        inputs: [N, T_y, E//2]
        memory: [N, T_x, E]
        prev_hidden: [1, N, E] or None (initial GRU state; nn.GRU uses zeros when None)
        mask: optional bool tensor [N, T_x]; True marks valid (non-padded) encoder
              steps. Masked steps receive zero attention weight. A row that is
              entirely False yields NaN weights, so every row needs >= 1 valid step.
    output:
        attn_weights: [N, T_y, T_x] (softmax over T_x)
        outputs: [N, T_y, E] (GRU outputs)
        hidden: [1, N, E] (final GRU state)

    Note: only the weights are returned. A context vector, if the caller needs
    one, would be torch.bmm(attn_weights, memory) -> [N, T_y, E].

    T_x --- encoder len
    T_y --- decoder len
    N --- batch_size
    E --- hidden_size (embedding size)
    '''

    def __init__(self, E=None):
        super().__init__()
        # Fall back to the project-wide hyperparameter so existing callers that
        # construct AttentionRNN() are unaffected; pass E to decouple from hp.
        E = hp.E if E is None else E
        self.gru = nn.GRU(input_size=E // 2, hidden_size=E, batch_first=True, bidirectional=False)
        self.W = nn.Linear(in_features=E, out_features=E, bias=False)
        self.U = nn.Linear(in_features=E, out_features=E, bias=False)
        self.v = nn.Linear(in_features=E, out_features=1, bias=False)

    def forward(self, inputs, memory, prev_hidden=None, mask=None):
        outputs, hidden = self.gru(inputs, prev_hidden)  # outputs: [N, T_y, E], hidden: [1, N, E]
        # Singleton axes let the two projections broadcast to [N, T_y, T_x, E]
        # in the addition below; no explicit expand() is required.
        w = self.W(outputs).unsqueeze(2)  # [N, T_y, 1, E]
        u = self.U(memory).unsqueeze(1)  # [N, 1, T_x, E]
        # Tensor.tanh() replaces the deprecated F.tanh.
        scores = self.v((w + u).tanh()).squeeze(-1)  # [N, T_y, T_x]
        if mask is not None:
            # Broadcast the [N, T_x] mask across T_y; -inf becomes 0 after softmax.
            scores = scores.masked_fill(~mask.bool().unsqueeze(1), float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        return attn_weights, outputs, hidden
I wrote an implementation of attention, but it seems to have a problem. Is there anything wrong with it?
Change it to:
Does it work?