nlpyang / BertSum

Code for paper Fine-tune BERT for Extractive Summarization
Apache License 2.0
1.46k stars 422 forks

Help needed #99

Closed Pradhy729 closed 4 years ago

Pradhy729 commented 4 years ago

Hi, can someone help me with this? I would like to understand what the forward function in model_builder.py is doing. Specifically, what are the expected shapes of top_vec, sents_vec, clss, and mask_cls?

def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):
    top_vec = self.bert(x, segs, mask)
    sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
    sents_vec = sents_vec * mask_cls[:, :, None].float()
    sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
    return sent_scores, mask_cls

My BERT output (top_vec) is a (512, 768) tensor. If that is correct, what shape should sents_vec have?

Pradhy729 commented 4 years ago

Never mind, I figured it out.
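
For readers with the same question, the shapes in the forward function quoted above can be traced with dummy tensors. This is only a sketch, not the repository's code. It assumes top_vec is batched as (batch, seq_len, hidden), which is what indexing with torch.arange(top_vec.size(0)) implies, so a single (512, 768) document would need a leading batch dimension of 1. The sizes, the [CLS] positions in clss, and the use of 0 as the padding index are made up for illustration.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, hidden = 2, 512, 768

# Stand-in for self.bert(x, segs, mask): one hidden vector per token.
top_vec = torch.randn(batch_size, seq_len, hidden)

# Token positions of the [CLS] marker that opens each sentence.
# The second document has only two sentences, so its last slot is padding
# (padding index 0 is an assumption made for this sketch).
clss = torch.tensor([[0, 20, 45],
                     [0, 30, 0]])        # (batch, max_sents)
mask_cls = torch.tensor([[1, 1, 1],
                         [1, 1, 0]])     # (batch, max_sents), 1 = real sentence

# Row indices (batch, 1) broadcast against clss (batch, max_sents),
# picking one hidden vector per sentence from each document.
rows = torch.arange(top_vec.size(0)).unsqueeze(1)
sents_vec = top_vec[rows, clss]          # (batch, max_sents, hidden)

# Zero out the vectors that belong to padded sentence slots.
sents_vec = sents_vec * mask_cls[:, :, None].float()

print(top_vec.shape)    # (batch, seq_len, hidden)
print(sents_vec.shape)  # (batch, max_sents, hidden)
```

Under these assumptions sents_vec holds one 768-dimensional vector per sentence, and sent_scores would then have shape (batch, max_sents) after the final squeeze(-1), provided the encoder emits one score per sentence.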