ZexinYan / Medical-Report-Generation

A PyTorch implementation of "On the Automatic Generation of Medical Imaging Reports".
199 stars, 65 forks

Is a sum missing? #8

Open JavisPeng opened 5 years ago

JavisPeng commented 5 years ago

In the v1 method of class CoAttention, v_att is computed differently from the formula in the paper. Is a sum missing?

    def v1(self, avg_features, semantic_features, h_sent) -> object:
        """
        only training
        :rtype: object
        """
        W_v = self.bn_v(self.W_v(avg_features))
        W_v_h = self.bn_v_h(self.W_v_h(h_sent.squeeze(1)))

        alpha_v = self.softmax(self.bn_v_att(self.W_v_att(self.tanh(W_v + W_v_h))))
        v_att = torch.mul(alpha_v, avg_features)  # <-- here: no .sum(1), unlike a_att below

        W_a_h = self.bn_a_h(self.W_a_h(h_sent))
        W_a = self.bn_a(self.W_a(semantic_features))
        alpha_a = self.softmax(self.bn_a_att(self.W_a_att(self.tanh(torch.add(W_a_h, W_a)))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1) 

        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))

        return ctx, alpha_v, alpha_a
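
To make the question concrete, here is a minimal sketch of the two cases. It compares an attention-weighted sum over regions (my reading of the paper's formula) with an element-wise product on a single pooled vector (what the v1 line appears to do). The tensor sizes, the random stand-in scores, and the assumption that avg_features is already pooled to shape (batch, dim) are mine, not taken from the repository.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, regions, dim = 2, 49, 8

# Case 1: one feature vector per image region.
# Attention gives one weight per region, then the regions are summed.
region_features = torch.randn(batch, regions, dim)
scores = torch.randn(batch, regions, 1)        # stand-in for the attention network output
alpha_regions = torch.softmax(scores, dim=1)   # weights sum to 1 across regions
v_att_summed = (alpha_regions * region_features).sum(dim=1)  # shape: (batch, dim)

# Case 2: features already average-pooled to one vector per image (assumed).
# There is no region axis left, so the product acts as per-channel gating.
avg_features = torch.randn(batch, dim)
alpha_channels = torch.softmax(torch.randn(batch, dim), dim=1)
v_att_gated = alpha_channels * avg_features    # shape: (batch, dim)

print(v_att_summed.shape, v_att_gated.shape)
```

If avg_features really is a pooled vector, both cases yield a (batch, dim) tensor, so the torch.cat call still works. But the second case weights channels instead of attending over spatial regions, and adding .sum(1) there would collapse each vector to a single scalar.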