allenai / unifiedqa

UnifiedQA: Crossing Format Boundaries With a Single QA System
https://arxiv.org/abs/2005.00700
Apache License 2.0

error from generate_from_string method for BART #17

Open timhartill opened 3 years ago

timhartill commented 3 years ago

Hi, I'm attempting to run the BART model example as given in the readme:

import torch
from transformers import BartTokenizer
from bart import MyBart

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt" # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
model = MyBart.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()

x = model.generate_from_string("Which is best conductor? \n (A) iron (B) feather", tokenizer=tokenizer)

The .from_pretrained line executes fine, but the .generate_from_string(...) call fails with:

TypeError: forward() got an unexpected keyword argument 'past_key_values'

I tried the run_model(...) method from the main repo page and it gives exactly the same error.

Any idea what might be causing this and how to fix it?

I am using Python 3.8.5 with transformers 4.4.2 and PyTorch 1.7.1.

danyaljj commented 3 years ago

Thoughts @shmsw25 ?

timhartill commented 3 years ago

After a bit more digging, it appears that the arguments to the forward method in modeling_bart.py in transformers 4.4.2 are rather different from the arguments accepted by the forward method in the UnifiedQA bart.py. I'm thinking I may need to update bart.py to match the latest modeling_bart.py to make this work. If I manage to do so, would you like a copy of the updated version?
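For illustration, here is a minimal, self-contained sketch of that kind of mismatch; the class and argument names below are simplified stand-ins, not the actual MyBart or transformers signatures:

# Simplified stand-in for a forward() override written against an older
# argument name; none of these names are the real MyBart/transformers signatures.
class LegacyForward:
    def forward(self, input_ids, past=None):  # older-style keyword
        return input_ids

model = LegacyForward()

# Newer library code passes past_key_values=..., which the old forward() rejects:
try:
    model.forward([1, 2, 3], past_key_values=None)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'past_key_values'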

danyaljj commented 3 years ago

Sure, thank you! 🙏

timhartill commented 3 years ago

I forked your code and updated bart.py and also run.py. I've run it a few times and it seems to work. I've generally marked my changes with comments starting with #TJH.

You can access at: https://github.com/timhartill/unifiedqa-tjh

danyaljj commented 3 years ago

Appreciate it! Will look into your changes.

tshrjn commented 3 years ago

transformers 4.x brought breaking changes; the argument previously named past is now passed as past_key_values.

But it shouldn't be an issue if you use HF's BartForConditionalGeneration class rather than the derived class in this repo.

Example of how generation would look:


import torch
from transformers import BartTokenizer, BartForConditionalGeneration

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt" # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
model = BartForConditionalGeneration.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()

def generate_text(text, model, tokenizer):
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors='pt')

    output_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=50, early_stopping=True)
    return ' '.join([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in output_ids])

text = "Which is best conductor? \\n (A) iron (B) feather"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))

text = "What is the sum of 3 and 5? \\n (A) 8 (B) 3 (C) 5 (D) 10"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))

text = "What is 42 ? \\n 42 is the answer to life, the universe and everything"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))

Or one could also use HF's pipelines as follows:

# Using Pipeline
from transformers import pipeline

text = "What is 42 ? \\n 42 is the answer to life, the universe and everything"
print('\n', text)

text2text_generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
print(text2text_generator(text))

shmsw25 commented 3 years ago

Hi @timhartill and @tshrjn, it looks like the error comes from a discrepancy in HF versions. The code was written against an older version of HF; please see the README. @tshrjn's solution looks like a good workaround for running inference on a newer version. However, if you want to run finetuning, I recommend following the version in the README, since finetuning with a newer version is not guaranteed to reproduce the results in the paper.
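As a rough, hypothetical guard, a finetuning script could check the installed version up front; the pinned version string below is a placeholder for whatever the README specifies:

import warnings
import transformers

# Placeholder: substitute the transformers version pinned in the repo's README.
README_PINNED_VERSION = "<version from README>"

if transformers.__version__ != README_PINNED_VERSION:
    warnings.warn(
        "Installed transformers %s differs from the README-pinned version %s; "
        "finetuning results may not reproduce the paper."
        % (transformers.__version__, README_PINNED_VERSION)
    )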

@danyaljj I was thinking that keeping the version as it is in the repo is better, since the HF library will keep being updated and it would not be easy to update the code every time while still guaranteeing that the numbers in the paper reproduce. Or we could update the inference code only and add a note that finetuning is only tested with the version in the README. What do you think?