allenai / PRIMER

The official code for PRIMERA: Pyramid-based Masked Sentence Pre-training for Multi-document Summarization
Apache License 2.0

Evaluation example for multi_x_science_sum dataset #31

Open shubhamagarwal92 opened 1 year ago

shubhamagarwal92 commented 1 year ago

Hi!

I am trying to replicate the results of https://github.com/allenai/PRIMER/blob/main/Evaluation_Example.ipynb for the multi_x_science_sum dataset.

I downloaded the PRIMERA model described in https://github.com/allenai/PRIMER/tree/main#usage-of-primera. However, I am getting the following ValueError:

ValueError: The state dictionary of the model you are trying to load is corrupted. Are you sure it was properly saved?

Code I am using:

import os
os.environ["HF_HOME"] = "/mnt/home/cached/"
os.environ["TORCH_HOME"] = "/mnt/home/cached/"

from transformers import AutoTokenizer, LEDForConditionalGeneration
import torch
# from longformer import LongformerEncoderDecoderForConditionalGeneration
# from longformer import LongformerEncoderDecoderConfig
from datasets import load_dataset, load_metric

PRIMER_path = './PRIMERA_model/PRIMER'
TOKENIZER = AutoTokenizer.from_pretrained(PRIMER_path)
MODEL = LEDForConditionalGeneration.from_pretrained(PRIMER_path)
# gradient checkpointing is a training-time memory saver; harmless here, but
# not needed for pure inference
MODEL.gradient_checkpointing_enable()

# TOKENIZER = AutoTokenizer.from_pretrained(PRIMER_path)
# config = LongformerEncoderDecoderConfig.from_pretrained(PRIMER_path)
# MODEL = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(
#             PRIMER_path, config=config)
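# (The commented lines above are the original repo's loading path, which I
# believe requires allenai's `longformer` fork to be installed; the active
# LEDForConditionalGeneration route loads the converted Hugging Face
# checkpoint instead.)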

PAD_TOKEN_ID = TOKENIZER.pad_token_id
DOCSEP_TOKEN_ID = TOKENIZER.convert_tokens_to_ids("<doc-sep>")
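# (<doc-sep> is PRIMERA's special separator between source documents; it is
# part of the released tokenizer's vocabulary, so this lookup should resolve
# to a real id rather than the unknown-token id.)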

def process_document(documents):
    input_ids_all = []
    for data in documents:
        # source documents: the paper's own abstract plus its cited references' abstracts
        all_docs = [data["abstract"]]
        for d in data["ref_abstract"]["abstract"]:
            if len(d) > 0:
                all_docs.append(d)
        # https://github.com/allenai/PRIMER/blob/main/script/dataloader.py#L66C1-L70C1
        # all_docs = data.split("|||||")[:-1]
        # for i, doc in enumerate(all_docs):
        #     doc = doc.replace("\n", " ")
        #     doc = " ".join(doc.split())
        #     all_docs[i] = doc

        #### concat with global attention on doc-sep
        input_ids = []
        for doc in all_docs:
            input_ids.extend(
                TOKENIZER.encode(
                    doc,
                    truncation=True,
                    max_length=4096 // len(all_docs),
                )[1:-1]
            )
            input_ids.append(DOCSEP_TOKEN_ID)
        input_ids = (
            [TOKENIZER.bos_token_id]
            + input_ids
            + [TOKENIZER.eos_token_id]
        )
        input_ids_all.append(torch.tensor(input_ids))
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=PAD_TOKEN_ID
    )
    return input_ids
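# Toy shape check (hypothetical two-document example, not taken from the dataset):
#   toy = [{"abstract": "Doc A.", "ref_abstract": {"abstract": ["Doc B."]}}]
#   process_document(toy) returns a single padded row laid out as
#   <s> tokens(Doc A) <doc-sep> tokens(Doc B) <doc-sep> </s>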

def batch_process(batch):
    # map(batched=True) passes a dict of lists, so re-zip it into per-example
    # dicts that process_document can iterate over
    examples = [
        {"abstract": a, "ref_abstract": r}
        for a, r in zip(batch["abstract"], batch["ref_abstract"])
    ]
    input_ids = process_document(examples)
    # build the global attention mask: LED attends locally everywhere except
    # the positions flagged here, i.e. the <s> token and every <doc-sep> token
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1
    generated_ids = MODEL.generate(
        input_ids=input_ids,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=1024,
        num_beams=5,
    )
    generated_str = TOKENIZER.batch_decode(
        generated_ids.tolist(), skip_special_tokens=True
    )
    result = {}
    result["generated_summaries"] = generated_str
    result["gt_summaries"] = batch["related_work"]

    return result

dataset = load_dataset('multi_x_science_sum')
# load_dataset returns a DatasetDict, so pick a split before calling select()
dataset_small = dataset['test'].select(range(100))

result_small = dataset_small.map(batch_process, batched=True, batch_size=2)

rouge = load_metric("rouge")

score = rouge.compute(
    predictions=result_small["generated_summaries"],
    references=result_small["gt_summaries"],
)
print(score['rouge1'].mid)
print(score['rouge2'].mid)
print(score['rougeL'].mid)
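# each entry is an AggregateScore with low/mid/high Score tuples; to report
# just the mid F1 values:
# print(score['rouge1'].mid.fmeasure, score['rouge2'].mid.fmeasure, score['rougeL'].mid.fmeasure)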

Do we need to use a different PRIMERA model? Am I doing something wrong in trying to replicate the results?
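Or should the fine-tuned checkpoint be loaded directly from the Hugging Face Hub instead of the converted local copy? A minimal sketch of what I mean, assuming allenai/PRIMERA-multixscience is the Hub id of the multi_x_science checkpoint:

import os
os.environ["HF_HOME"] = "/mnt/home/cached/"

from transformers import AutoTokenizer, LEDForConditionalGeneration

# assumption: this is the Hub id of the multi_x_science fine-tuned checkpoint
hub_id = "allenai/PRIMERA-multixscience"
TOKENIZER = AutoTokenizer.from_pretrained(hub_id)
MODEL = LEDForConditionalGeneration.from_pretrained(hub_id)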