lipiji / TranSummar

Transformer for abstractive summarization on CNN/Daily Mail and Gigaword
MIT License
140 stars, 20 forks

It seems that "coverage" isn't used? #4

Closed: tungloong closed this issue 5 years ago

tungloong commented 5 years ago

Code in model.py in this repository:

    def forward(self, x, y_inp, y_tgt, mask_x, mask_y, x_ext, y_ext, max_ext_len):
        hs, src_padding_mask = self.encode(x)
        if self.copy:
            y_pred, _ = self.decode(y_inp, mask_x, mask_y, hs, src_padding_mask, x_ext, max_ext_len)
            cost = self.label_smotthing_loss(y_pred, y_ext, mask_y, self.avg_nll)
        else:
            y_pred, _ = self.decode(y_inp, mask_x, mask_y, hs, src_padding_mask)
            cost = self.nll_loss(y_pred, y_tgt, mask_y, self.avg_nll)

        return y_pred, cost

Code in model.py in neural-summ-cnndm-pytorch:

    def forward(self, x, len_x, y, mask_x, mask_y, x_ext, y_ext, max_ext_len):

        hs, dec_init_state = self.encode(x, len_x, mask_x)

        y_emb = self.w_rawdata_emb(y)
        y_shifted = y_emb[:-1, :, :]
        y_shifted = T.cat((Variable(torch.zeros(1, *y_shifted[0].size())).to(self.device), y_shifted), 0)
        h0 = dec_init_state
        if self.cell == "lstm":
            h0 = (dec_init_state, dec_init_state)
        if self.coverage:
            acc_att = Variable(torch.zeros(T.transpose(x, 0, 1).size())).to(self.device) # B * len(x)

        if self.copy and self.coverage:
            hcs, dec_status, atted_context, att_dist, xids, C = self.decoder(y_shifted, hs, h0, mask_x, mask_y, x_ext, acc_att)
        elif self.copy:
            hcs, dec_status, atted_context, att_dist, xids = self.decoder(y_shifted, hs, h0, mask_x, mask_y, xid=x_ext)
        elif self.coverage:
            hcs, dec_status, atted_context, att_dist, C = self.decoder(y_shifted, hs, h0, mask_x, mask_y, init_coverage=acc_att)
        else:
            hcs, dec_status, atted_context = self.decoder(y_shifted, hs, h0, mask_x, mask_y)

        if self.copy:
            y_pred = self.word_prob(dec_status, atted_context, y_shifted, att_dist, xids, max_ext_len)
            cost = self.nll_loss(y_pred, y_ext, mask_y, self.avg_nll)
        else:
            y_pred = self.word_prob(dec_status, atted_context, y_shifted)
            cost = self.nll_loss(y_pred, y, mask_y, self.avg_nll)

        if self.coverage:
            cost_c = T.mean(T.sum(T.min(att_dist, C), 2))
            return y_pred, cost, cost_c
        else:
            return y_pred, cost, None
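For reference, the coverage term returned as cost_c in the RNN version follows the coverage loss of See et al. (2017): the coverage vector c^t is the sum of the attention distributions of all previous decoder steps, and each step is charged sum_i min(a_i^t, c_i^t). The sketch below is a minimal pure-Python illustration for a single example, assuming att_dist is a list of per-step attention distributions over source positions. The function name coverage_loss is hypothetical, and the sketch does not reproduce the repository's tensor shapes, batching or masking.

```python
from typing import List


def coverage_loss(att_dist: List[List[float]]) -> float:
    """Coverage loss for one example (sketch, after See et al., 2017).

    att_dist[t][i] is the attention weight on source position i at
    decoder step t. The coverage vector c^t is the sum of the attention
    distributions of all previous steps (c^0 is all zeros), and the loss
    at step t is sum_i min(a_i^t, c_i^t).
    """
    src_len = len(att_dist[0])
    coverage = [0.0] * src_len  # c^0: nothing has been attended yet
    step_losses = []
    for a_t in att_dist:
        # Penalise attention on source positions that are already covered.
        step_losses.append(sum(min(a, c) for a, c in zip(a_t, coverage)))
        # Accumulate coverage: c^{t+1} = c^t + a^t
        coverage = [c + a for c, a in zip(coverage, a_t)]
    # Average over decoder steps, analogous to the T.mean in the RNN code.
    return sum(step_losses) / len(step_losses)
```

For example, two steps that both put all attention on the first source token give a loss of 0.5, while two steps that attend to different tokens give 0.0, so repeated attention is what gets penalised.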

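Is the "coverage" mechanism deprecated? Unlike neural-summ-cnndm-pytorch, which returns a coverage loss (cost_c), the forward pass here computes no coverage term at all, yet the results shown in README.md are labeled "with copy and coverage".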

lipiji commented 5 years ago

@tungloong For the Transformer, coverage is only used during beam search, not in the training loss: https://github.com/lipiji/TranSummar/blob/master/main.py#L221
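The thread does not show what main.py#L221 actually does. One common way to use coverage only at decoding time is the coverage penalty from Google's NMT system (Wu et al., 2016), cp = beta * sum_i log(min(sum_t a_i^t, 1.0)). It is added to each hypothesis score so that beams leaving source tokens unattended rank lower. The sketch below is an assumption for illustration, not TranSummar's code: coverage_penalty, rescore, beta and eps are hypothetical names, and the default for beta is arbitrary.

```python
import math
from typing import List


def coverage_penalty(att_dist: List[List[float]], beta: float = 0.2,
                     eps: float = 1e-9) -> float:
    """GNMT-style coverage penalty (sketch, after Wu et al., 2016).

    att_dist[t][i] is the attention on source position i at decoder
    step t. The penalty is 0 when every source position has received a
    total attention of at least 1, and grows more negative as more
    source positions are left uncovered.
    """
    src_len = len(att_dist[0])
    penalty = 0.0
    for i in range(src_len):
        total = sum(a_t[i] for a_t in att_dist)
        # Clamp below at eps so an unattended position yields a large
        # negative penalty instead of log(0) = -inf.
        penalty += math.log(max(min(total, 1.0), eps))
    return beta * penalty


def rescore(log_prob: float, att_dist: List[List[float]],
            beta: float = 0.2) -> float:
    """Hypothetical beam-search score: log-probability plus coverage penalty."""
    return log_prob + coverage_penalty(att_dist, beta)
```

With equal log-probabilities, a hypothesis that attends to every source token scores higher than one that attends to the same token twice. GNMT additionally divides the log-probability by a length penalty, which this sketch omits.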