JavisPeng opened this issue 5 years ago
In the `v1` function of the `CoAttention` class, the computation of `v_att` differs from the formula in the paper. Is a `sum` operation missing?
```python
def v1(self, avg_features, semantic_features, h_sent) -> object:
    """
    only training
    :rtype: object
    """
    W_v = self.bn_v(self.W_v(avg_features))
    W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))
    alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
    v_att = torch.mul(alpha_v, avg_features)  # <-- here: no .sum(1), unlike a_att below

    W_a_h = self.bn_a_h(self.W_a_h(h_sent))
    W_a = self.bn_a(self.W_a(semantic_features))
    alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
    a_att = torch.mul(alpha_a, semantic_features).sum(1)  # semantic branch does sum over dim 1

    ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))

    return ctx, alpha_v, alpha_a
```
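For context, the paper's co-attention computes the visual context as a weighted sum over the visual feature vectors, `v_att = Σ_i α_{v,i} · v_i`, which is exactly the pattern the semantic branch follows via `.sum(1)`. Below is a minimal, hypothetical sketch of what the summed version would look like, assuming the visual features are kept as a set of `N` region vectors of shape `[batch, N, d]` (an assumption on my side; in this repo `avg_features` may already be pooled to `[batch, d]`, in which case an element-wise product has no region axis left to sum over):

```python
import torch

# Assumed shapes for illustration: N region feature vectors per image,
# NOT necessarily what avg_features holds in this repo.
batch, N, d = 2, 49, 512
visual_features = torch.randn(batch, N, d)                # [batch, N, d]
alpha_v = torch.softmax(torch.randn(batch, N, 1), dim=1)  # attention over the N regions

# Paper-style visual context: weight each region vector, then sum over
# the region axis, mirroring what a_att does for semantic_features.
v_att = torch.mul(alpha_v, visual_features).sum(1)        # [batch, d]
print(v_att.shape)  # torch.Size([2, 512])
```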