timhartill opened this issue 3 years ago
Thoughts, @shmsw25?
After a bit more digging, it appears that the arguments to the forward method in modeling_bart.py in transformers 4.4.2 are rather different from the arguments passed to the forward method in the unifiedqa bart.py. I'm thinking I may need to update bart.py to match the latest modeling_bart.py to make this work. If I manage to do so, would you like a copy of the updated version?
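Roughly, the change looks like this (an illustrative sketch only, not my actual diff; the class name and delegation here are assumptions): generate() in transformers 4.x passes past_key_values to forward(), so a derived class whose forward() still declares the old past argument raises a TypeError.

from transformers import BartForConditionalGeneration

class PatchedBart(BartForConditionalGeneration):
    # hypothetical sketch: accept the renamed 4.x keyword and delegate
    # to the stock implementation instead of declaring the old 'past' arg
    def forward(self, input_ids=None, attention_mask=None,
                decoder_input_ids=None, past_key_values=None, **kwargs):
        return super().forward(input_ids=input_ids,
                               attention_mask=attention_mask,
                               decoder_input_ids=decoder_input_ids,
                               past_key_values=past_key_values,
                               **kwargs)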
Sure, thank you! 🙏
I forked your code and updated bart.py and run.py. I've run it a few times and it seems to work. I've generally marked my changes with comments starting with #TJH.
You can access at: https://github.com/timhartill/unifiedqa-tjh
Appreciate it! Will look into your changes.
transformers 4.x brought breaking changes; among them, the past argument was renamed to past_key_values. But it shouldn't be an issue if you use HF's model class directly rather than the derived class here.
Example of how generation would look:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt"  # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
# load the unifiedqa checkpoint weights into the stock HF model class
model = BartForConditionalGeneration.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
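# note (assumption): on a CPU-only machine you may need torch.load(unifiedqa_path, map_location='cpu') above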
model.eval()

def generate_text(text, model, tokenizer):
    # encode the question, beam-search a short answer, and decode it back to text
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors='pt')
    output_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=50, early_stopping=True)
    return ' '.join([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in output_ids])
text = "Which is best conductor? \\n (A) iron (B) feather"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))
text = "What is the sum of 3 and 5? \\n (A) 8 (B) 3 (C) 5 (D) 10"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))
text = "What is 42 ? \\n 42 is the answer to life, the universe and everything"
print('\n', text)
print("generated_text:", generate_text(text, model, tokenizer))
Alternatively, one could use HF's pipelines as follows:
# Using Pipeline
from transformers import pipeline
text = "What is 42 ? \\n 42 is the answer to life, the universe and everything"
print('\n', text)
text2text_generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
print(text2text_generator(text))
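For reference, the pipeline call returns a list of dicts, e.g. [{'generated_text': '...'}], rather than a plain string.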
Hi @timhartill and @tshrjn, it looks like the error comes from a discrepancy in HF versions. The code was written against an older version of HF; please see the README. @tshrjn's solution looks like a good workaround for running inference on a newer version. However, if you want to run finetuning, I recommend following the version in the README, as finetuning with a newer version is not guaranteed to reproduce the results in the paper.
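A quick way to confirm which version you are actually running against the one pinned in the README (plain Python, nothing repo-specific):

import transformers
print(transformers.__version__)  # compare this against the version listed in the README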
@danyaljj I was thinking that keeping the version as it is in the repo is better, since the HF library will keep being updated and it would not be easy to update the code every time while guaranteeing the numbers in the paper are reproduced. Or we could update the inference code only and add a note that finetuning is only tested with the version in the README. What do you think?
Hi, I'm attempting to run the BART model example as given in the README:
import torch
from transformers import BartTokenizer
from bart import MyBart

base_model = "facebook/bart-large"
unifiedqa_path = "unifiedQA-uncased/best-model.pt"  # path to the downloaded checkpoint

tokenizer = BartTokenizer.from_pretrained(base_model)
model = MyBart.from_pretrained(base_model, state_dict=torch.load(unifiedqa_path))
model.eval()

x = model.generate_from_string("Which is best conductor? \n (A) iron (B) feather", tokenizer=tokenizer)
The .from_pretrained line executes fine, but the .generate_from_string(..) call errors out with:
TypeError: forward() got an unexpected keyword argument 'past_key_values'
I tried using the run_model(..) method from the main repo page and it gives exactly the same error.
Any idea what might be causing this and how to fix it?
I am using Python 3.8.5 with transformers 4.4.2 and PyTorch 1.7.1.