The following snippet (the `RagModel` example from the RAG docs) triggers the error:

```python
from transformers import RagTokenizer, RagRetriever, RagModel
import torch

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
)
# initialize with RagRetriever to do everything in one forward call
model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)

# this call fails, see "To reproduce" below
input_dict = tokenizer.prepare_seq2seq_batch(
    "How many people live in Paris?",
    "In Paris, there are 10 million people.",
    return_tensors="pt",
)
input_ids = input_dict["input_ids"]
outputs = model(input_ids=input_ids)
```
## Environment info

- `transformers` version: 4.3.2

## Who can help
Models: RAG
## Information

Model I am using (Bert, XLNet ...): RAG

The problem arises when using: the official example scripts (the examples in the RAG documentation).

The task I am working on is: running the RAG examples from the documentation.
## To reproduce

Steps to reproduce the behavior:

1. Run any of the scripts in the examples at https://huggingface.co/transformers/model_doc/rag.html#overview, e.g. the snippet at the top of this issue.
2. Get an error at https://github.com/huggingface/transformers/blob/master/src/transformers/models/rag/tokenization_rag.py#L77 about how `super()` does not have `prepare_seq2seq_batch()`.
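For reference, here is the failure mode in isolation (a minimal sketch with a hypothetical `Wrapper` class, not the actual transformers source): `RagTokenizer` wraps two tokenizers and inherits only from `object`, so a `super().prepare_seq2seq_batch(...)` call inside it has nothing to resolve to.

```python
# Minimal sketch of the failure mode: calling super().method() when no base
# class defines the method raises AttributeError. `Wrapper` is a stand-in for
# a class that, like RagTokenizer, inherits only from object.
class Wrapper:
    def prepare_seq2seq_batch(self, *args, **kwargs):
        # super() here resolves to object, which has no such method
        return super().prepare_seq2seq_batch(*args, **kwargs)

Wrapper().prepare_seq2seq_batch("How many people live in Paris?")
# AttributeError: 'super' object has no attribute 'prepare_seq2seq_batch'
```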
## Expected behavior

RAG works properly.
Note that if I copy/paste the version of the file from before https://github.com/huggingface/transformers/pull/9524, it works fine. CC @sgugger, the author of that change.
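In the meantime, a possible workaround (a sketch, assuming only the encoder inputs are needed, as in the snippet above) is to bypass `prepare_seq2seq_batch()` and call the wrapped question-encoder tokenizer directly via the `question_encoder` attribute of `RagTokenizer`:

```python
# Workaround sketch: tokenize the question with the wrapped question-encoder
# tokenizer instead of RagTokenizer.prepare_seq2seq_batch(). Reuses the
# `tokenizer` and `model` objects from the snippet at the top of this issue.
input_dict = tokenizer.question_encoder(
    "How many people live in Paris?", return_tensors="pt"
)
outputs = model(input_ids=input_dict["input_ids"])
```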