Open shouldsee opened 1 year ago
Thanks for sharing! I just noticed that Attention.get_att_weight computes the attention scores in a Python for-loop, which looks rather slow, doesn't it?
4-2.Seq2Seq(Attention)/Seq2Seq(Attention).ipynb
def get_att_weight(self, dec_output, enc_outputs):  # get attention weight one 'dec_output' with 'enc_outputs'
    n_step = len(enc_outputs)
    attn_scores = torch.zeros(n_step)  # attn_scores : [n_step]

    for i in range(n_step):
        attn_scores[i] = self.get_att_score(dec_output, enc_outputs[i])

    # Normalize scores to weights in range 0 to 1
    return F.softmax(attn_scores).view(1, 1, -1)

def get_att_score(self, dec_output, enc_output):  # enc_outputs [batch_size, num_directions(=1) * n_hidden]
    score = self.attn(enc_output)  # score : [batch_size, n_hidden]
    return torch.dot(dec_output.view(-1), score.view(-1))  # inner product make scalar value
Suggested parallel version:
def get_att_weight(self, dec_output, enc_outputs):  # get attention weight one 'dec_output' with 'enc_outputs'
    n_step = len(enc_outputs)
    attn_scores = torch.zeros(n_step, device=self.device)  # attn_scores : [n_step]
    enc_t = self.attn(enc_outputs)  # enc_t : [n_step, batch_size, n_hidden]
    # score : [batch_size, 1, n_step] -- all n_step scores computed in a single bmm call
    score = dec_output.transpose(1, 0).bmm(enc_t.transpose(1, 0).transpose(2, 1))
    out1 = score.softmax(-1)  # normalize scores to weights along the n_step dimension
    return out1
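For reference, here is a minimal, self-contained sketch (not from the notebook) that checks the loop-based scoring and the batched bmm version produce the same weights. The names attn, n_step and n_hidden below are stand-ins for the model's self.attn layer and the tutorial's hyperparameters, and the shapes follow the tutorial's convention (enc_outputs [n_step, 1, n_hidden], dec_output [1, 1, n_hidden]).

import torch
import torch.nn as nn
import torch.nn.functional as F

n_step, n_hidden = 5, 128
attn = nn.Linear(n_hidden, n_hidden)        # stand-in for self.attn
enc_outputs = torch.randn(n_step, 1, n_hidden)
dec_output = torch.randn(1, 1, n_hidden)

# loop version (as in the notebook)
scores = torch.zeros(n_step)
for i in range(n_step):
    scores[i] = torch.dot(dec_output.view(-1), attn(enc_outputs[i]).view(-1))
loop_weights = F.softmax(scores, dim=0).view(1, 1, -1)

# batched version (as suggested above)
enc_t = attn(enc_outputs)                                                       # [n_step, 1, n_hidden]
score = dec_output.transpose(1, 0).bmm(enc_t.transpose(1, 0).transpose(2, 1))  # [1, 1, n_step]
bmm_weights = score.softmax(-1)

print(torch.allclose(loop_weights, bmm_weights, atol=1e-6))  # expect True

The batched form replaces n_step separate dot products with one batched matrix multiplication, which is the main source of the speedup on GPU.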
You can create a pull request to update the code