tencent-ailab / OASum


Pretrained Checkpoint Usage Guide #2

Open nishan-chatterjee opened 5 months ago

nishan-chatterjee commented 5 months ago

Hi, can you provide any code for using your pretrained checkpoint for zero-shot summarization? And can the aspects be provided as slots or are they automatically extracted? Thanks :D

KaiQiangSong commented 4 months ago

The aspects are prepended to the inputs as demonstrated in the paper.
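
For concreteness, a minimal sketch of that input construction (the aspect name, document text, and the ": " separator below are illustrative assumptions; see the paper for the exact format used during training):

aspect = "career"                         # target aspect (hypothetical)
document = "Jane Doe is a physicist ..."  # source document (hypothetical)
input_text = f"{aspect}: {document}"      # aspect prepended to the document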

nishan-chatterjee commented 2 months ago

Hey, thanks for the feedback. Do you know if this is the correct way to use the model? The code below prepends the target aspect to each line of the document and concatenates the per-line summaries generated by the model. Or should I prepend the aspect to the whole document instead?

import torch
from transformers import LEDTokenizer, LEDForConditionalGeneration
import warnings
warnings.filterwarnings("ignore")

# Load the tokenizer and pre-trained model
tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')
model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')

# Load the checkpoint on CPU; move the model afterwards with model.to(device) if a GPU is available
checkpoint = torch.load("../models/epoch=19-step=34300.ckpt.weights", map_location='cpu')

# The file may wrap the weights in a 'state_dict' key or be a state dict directly
if 'state_dict' in checkpoint:
    model.load_state_dict(checkpoint['state_dict'])
else:
    model.load_state_dict(checkpoint)

model.eval()  # disable dropout for inference

# Function to summarize text based on an aspect
def aspect_based_summarization(aspect, document):
    # Concatenate aspect and document
    input_text = f"{aspect}: {document}"

    # Tokenize the input (no padding needed for a single sequence)
    inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=4096)

    # LED expects global attention on at least the first token
    global_attention_mask = torch.zeros_like(inputs['input_ids'])
    global_attention_mask[:, 0] = 1

    # Generate summary
    summary_ids = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        global_attention_mask=global_attention_mask,
        max_length=150,
        num_beams=5,
        early_stopping=True,
    )

    # Decode the summary
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return summary

# Example usage
aspect = "Opening"
document = dataset[0]['document']  # 'document' is a list of sentences in the OASum data

summary = ""
for line in document:
    summary += aspect_based_summarization(aspect, line) + " "

print(summary)

Or should I do it like this, where I prepend the aspect to the whole document?

# Example usage

aspect = "Opening"
# combine all the lines of the document into one string
document = " ".join(dataset[0]['document'])
summary = aspect_based_summarization(aspect, document)
print(summary)
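
As a side note, once the helper above is in place, providing aspects as slots just means repeated calls with different aspect strings; a minimal sketch reusing the function and document defined above (the aspect names are hypothetical):

# query several aspects against the same document
for aspect in ["Opening", "Career", "Personal life"]:
    print(f"[{aspect}] {aspect_based_summarization(aspect, document)}")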