AIPHES / DiscoScore

DiscoScore: Evaluating Text Generation with BERT and Discourse Coherence

Bug in DiscoScorer initialization #5

Closed: UntotaufUrlaub closed this issue 1 year ago

UntotaufUrlaub commented 1 year ago

Hi,

there might be a bug in the initialization code of DiscoScorer in scorer.py. Lines 27 and 33 use self.we, which is never set. Here is the class with the one-line fix I added at the end of __init__:

class DiscoScorer: 
    def __init__(self, device='cuda:0', model_name='bert-base-uncased', we=None):

        config = BertConfig.from_pretrained(model_name, output_hidden_states=True, output_attentions=True, return_dict=True)
        self.tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
        self.model = BertModel.from_pretrained(model_name, config=config)
        self.model.encoder.layer = torch.nn.ModuleList([layer for layer in self.model.encoder.layer[:8]])
        self.model.eval()
        self.model.to(device)  
        if we is not None:
            we = load_embeddings('deps', we)
        self.we = we    # I added this line.

    def LexicalChain(self, sys, ref):
        return discourse.LexicalChain(sys, ref)

    def DS_Focus_NN(self, sys, ref):
        return discourse.DS_Focus(self.model, self.tokenizer, sys, ref, is_semantic_entity=False)

    def DS_Focus_Entity(self, sys, ref):
        return discourse.DS_Focus(self.model, self.tokenizer, sys, ref, is_semantic_entity=True, we=self.we, threshold = 0.8) # uses self.we

    def DS_SENT_NN(self, sys, ref):
        return discourse.DS_Sent(self.model, self.tokenizer, sys, ref, is_lexical_graph=False)

    def DS_SENT_Entity(self, sys, ref):
        return discourse.DS_Sent(self.model, self.tokenizer, sys, ref, is_lexical_graph=True, we=self.we, threshold = 0.5) # uses self.we

       ...

Is this a bug, or have I missed something?
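
To illustrate what I mean, here is a minimal sketch of the same pattern without any DiscoScore or BERT dependency. The class names ScorerWithoutFix and ScorerWithFix are hypothetical and only mimic the structure above; they are not part of the library.

```python
# Hypothetical stand-ins for DiscoScorer; they only reproduce the attribute pattern.

class ScorerWithoutFix:
    def __init__(self, we=None):
        # `we` is just a local name here; nothing is stored on the instance.
        if we is not None:
            we = {"loaded": we}

    def uses_we(self):
        # Fails because self.we was never assigned in __init__.
        return self.we


class ScorerWithFix:
    def __init__(self, we=None):
        if we is not None:
            we = {"loaded": we}
        self.we = we  # the added line: keep the value on the instance

    def uses_we(self):
        return self.we


try:
    ScorerWithoutFix().uses_we()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

fixed_value = ScorerWithFix().uses_we()

print(error_message)
print(fixed_value)
```

Without the assignment, any method that reads self.we raises an AttributeError, even when the caller is happy with the default we=None.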

I used the pip setup described in the README. My calling code looks essentially like this:

from disco_score import DiscoScorer

def disco_score(summ, doc, metric_type):
    disco_scorer = DiscoScorer(device='cuda:0', model_name='bert-base-uncased')
    summ = summ.lower()
    doc = doc.lower()
    return disco_scorer.DS_SENT_Entity(summ, [doc])

Kind regards

P.S. thank you very much for sharing this interesting metric!

andyweizhao commented 1 year ago

Yes, that is a bug. It was accidentally introduced while I cleaned up the repository. Thanks for spotting it!

I think the calling code is right. For better results, you could replace bert-base-uncased with Conpono or BERT-NLI, as mentioned in the README.