allenai / PRIMER

The official code for PRIMERA: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization

Question about inference on the Multi-News dataset #26

Open zhaoxh16 opened 1 year ago

zhaoxh16 commented 1 year ago

Hi, thank you for sharing your code. I am having trouble running inference on the Multi-News dataset. I followed the code in Evaluation_Example.ipynb with "use_stemmer=True" to evaluate on the Multi-News test set, but got ROUGE scores of mid rouge1 fmeasure = 49.87, mid rouge2 fmeasure = 20.61, and mid rougeL fmeasure = 25.59, which are lower than your reported results. Could you please tell me how to fix this? Thank you.

Here is my code.

# %%
from transformers import (
    AutoTokenizer,
    LEDForConditionalGeneration,
)
from datasets import load_dataset
import torch

# %%
# load the Multi-News dataset from the Hugging Face hub
dataset = load_dataset('multi_news')

# %%
# load the PRIMERA model fine-tuned on Multi-News, its tokenizer, and the special token ids
PRIMER_path = 'allenai/PRIMERA-multinews'
TOKENIZER = AutoTokenizer.from_pretrained(PRIMER_path)
MODEL = LEDForConditionalGeneration.from_pretrained(PRIMER_path)
MODEL.cuda()
PAD_TOKEN_ID = TOKENIZER.pad_token_id
DOCSEP_TOKEN_ID = TOKENIZER.convert_tokens_to_ids("<doc-sep>")

# %%
def process_document(documents):
    input_ids_all = []
    for data in documents:
        # each Multi-News example concatenates its source documents with "|||||"
        all_docs = data.split("|||||")
        for i, doc in enumerate(all_docs):
            doc = doc.replace("\n", " ")
            doc = " ".join(doc.split())
            all_docs[i] = doc

        #### concat with global attention on doc-sep
        input_ids = []
        for i, doc in enumerate(all_docs):
            input_ids.extend(
                TOKENIZER.encode(
                    doc,
                    truncation=True,
                    # give each document an equal share of the 4096-token input window
                    max_length=4096 // len(all_docs),
                )[1:-1]  # drop the per-document <s> and </s> tokens
            )
            if i != len(all_docs) - 1:
                input_ids.append(DOCSEP_TOKEN_ID)
        input_ids = (
            [TOKENIZER.bos_token_id]
            + input_ids
            + [TOKENIZER.eos_token_id]
        )
        input_ids_all.append(torch.tensor(input_ids))
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=PAD_TOKEN_ID
    )
    return input_ids

def batch_process(batch):
    input_ids = process_document(batch['document']).cuda()
    # build the global attention mask alongside the input ids
    global_attention_mask = torch.zeros_like(input_ids)
    # put global attention on the <s> token and on every <doc-sep> token
    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1
    generated_ids = MODEL.generate(
        input_ids=input_ids,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=1024,
        num_beams=5,
    )
    generated_str = TOKENIZER.batch_decode(
            generated_ids.tolist(), skip_special_tokens=True
        )
    result = {}
    result['generated_summaries'] = generated_str
    result['gt_summaries'] = batch['summary']
    return result

# run batched inference over the test set and save the outputs, one summary per line
result_all = dataset['test'].map(batch_process, batched=True, batch_size=8)
with open("generated_summaries.txt", 'w') as wf1, open("gt_summaries.txt", 'w') as wf2:
    for generated_summary in result_all['generated_summaries']:
        wf1.write(generated_summary + '\n')
    for gt_summary in result_all['gt_summaries']:
        wf2.write(gt_summary + '\n')
from datasets import load_metric

# score the saved summaries with ROUGE
rouge = load_metric("rouge")
with open("generated_summaries.txt") as f:
    generated_summaries = []
    for line in f:
        generated_summaries.append(line.strip())
with open("gt_summaries.txt") as f:
    gt_summaries = []
    for line in f:
        gt_summaries.append(line.strip())
result = rouge.compute(
    predictions=generated_summaries,
    references=gt_summaries,
    rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"],
    use_stemmer=True,
)
print("ROUGE scores:")
print(result)

And the result is:

ROUGE scores:
{'rouge1': AggregateScore(low=Score(precision=0.5241600177789812, recall=0.49415454039406814, fmeasure=0.49612791043579474), mid=Score(precision=0.5276629476118126, recall=0.4973478489591434, fmeasure=0.4987075089885308), high=Score(precision=0.5310748692379784, recall=0.5002904986992729, fmeasure=0.5012372252869416)),
 'rouge2': AggregateScore(low=Score(precision=0.2153434014630018, recall=0.20143954489550434, fmeasure=0.2030364177857921), mid=Score(precision=0.21863796353503595, recall=0.2046911123876728, fmeasure=0.20606699168573342), high=Score(precision=0.22209823319305308, recall=0.20779096705177613, fmeasure=0.2091556653585579)),
 'rougeL': AggregateScore(low=Score(precision=0.2681362240477783, recall=0.25188841094800846, fmeasure=0.25310320754179366), mid=Score(precision=0.2713920792728549, recall=0.25498739732693815, fmeasure=0.2559033080441543), high=Score(precision=0.27440427911451565, recall=0.25785737610997284, fmeasure=0.2585799444804941)),
 'rougeLsum': AggregateScore(low=Score(precision=0.2680113119171556, recall=0.2518150046006951, fmeasure=0.25266660590838774), mid=Score(precision=0.2713474291715047, recall=0.25480955392265, fmeasure=0.2558016718700355), high=Score(precision=0.27479809392676074, recall=0.25794410881227786, fmeasure=0.25880025528901673))}

raymondsim commented 1 year ago

Hi, I think it's because the line max_length=4096//len(all_docs) sets the maximum length of the output summary based on your batch size.

For example, 4096 // 4 = 1024, 4096 // 8 = 512, etc.

Therefore, different batch sizes can produce slightly different output summaries (with different lengths), and ROUGE is sensitive to summary length, which causes minor differences between your result and the reported scores.
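
As a toy illustration (made-up sentences, scored with the same load_metric("rouge") API used in the script above), cutting a candidate summary short lowers recall and therefore the F-measure:

from datasets import load_metric

rouge = load_metric("rouge")
reference = ["the model summarizes several news articles about the same event into one short report"]
full = ["the model summarizes several news articles about the same event into one short report"]
truncated = ["the model summarizes several news articles"]  # the same summary, cut short

# the full candidate scores perfectly; the truncated one loses recall and F-measure
print(rouge.compute(predictions=full, references=reference, rouge_types=["rouge1"], use_stemmer=True)["rouge1"].mid)
print(rouge.compute(predictions=truncated, references=reference, rouge_types=["rouge1"], use_stemmer=True)["rouge1"].mid)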

Hope this helps.

zhaoxh16 commented 1 year ago

Sorry, I don't quite understand what you mean. I think the line max_length=4096//len(all_docs) sets the maximum length of each input document based on the number of documents in that example, which is not related to the batch size or to the summary length.
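
For reference, a minimal sketch (with a made-up two-document string) of what len(all_docs) counts in process_document: it is the number of source documents inside a single example, so the per-document budget 4096 // len(all_docs) does not change with the batch size.

example = "first source article text ||||| second source article text"  # hypothetical Multi-News style input
all_docs = example.split("|||||")
print(len(all_docs))          # 2 documents in this example
print(4096 // len(all_docs))  # 2048 input tokens per document, regardless of batch_size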