Closed: lalitpagaria closed this issue 4 years ago.
Hey @lalitpagaria, spot on! Thanks a lot for your issue; you're 100% correct here.
I noticed that the RagSequence generate function is a bit more complex, so we cannot support the decomposed (embed, retrieve, generate) example for it here.
The PR linked to this issue removes that use case from the RagSequence examples and fixes the one for RagToken.
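A rough way to see why RagSequence is harder to decompose: the two RAG variants marginalize over the retrieved documents differently. The sketch below is a minimal, hypothetical pure-Python illustration of the two formulas from the RAG paper, using toy probabilities rather than the transformers implementation; the function names and inputs are assumptions made for illustration. RAG-Token mixes the documents at every decoding step, while RAG-Sequence scores a complete answer under each document before mixing.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    """Turn retrieval scores into p(z | x)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def rag_sequence_prob(doc_scores: Sequence[float], token_probs: List[List[float]]) -> float:
    """RAG-Sequence: p(y|x) = sum_z p(z|x) * prod_i p(y_i | x, z, y_<i).

    token_probs[z][i] is a toy probability of answer token i given document z.
    """
    doc_probs = softmax(doc_scores)
    # Score the whole answer under each document, then mix the documents.
    return sum(pz * math.prod(probs) for pz, probs in zip(doc_probs, token_probs))


def rag_token_prob(doc_scores: Sequence[float], token_probs: List[List[float]]) -> float:
    """RAG-Token: p(y|x) = prod_i sum_z p(z|x) * p(y_i | x, z, y_<i)."""
    doc_probs = softmax(doc_scores)
    total = 1.0
    for i in range(len(token_probs[0])):
        # Mix the documents at each decoding step before moving to the next token.
        total *= sum(pz * probs[i] for pz, probs in zip(doc_probs, token_probs))
    return total


if __name__ == "__main__":
    scores = [0.0, 0.0]  # two equally relevant documents
    probs = [[0.9, 0.8], [0.1, 0.2]]
    print(rag_sequence_prob(scores, probs))  # about 0.37
    print(rag_token_prob(scores, probs))  # about 0.25
```

With two equally scored documents, the sequence-level mixture gives about 0.37 and the token-level mixture about 0.25. Because RAG-Sequence only mixes after a full answer exists, a candidate can only be scored once it is complete, which is why its generation needs an extra scoring pass.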
UPDATE: @patrickvonplaten Sorry for my misunderstanding. Yes, with your PR, not calling generate directly fixes this. Thank you very much for the fix.
@patrickvonplaten Thank you for the update. I tried your changes on my code snippet and still got the same error. As you can see in my example, I am passing context_input_ids.
Environment info
transformers version: 3.3.1
Who can help
@patrickvonplaten @LysandreJik
Information
Model I am using (Bert, XLNet ...): RAG
The problem arises when using:
The task I am working on is:
dummy_dataset
To reproduce
Steps to reproduce the behavior:
Code snippets:
Stacktrace:
I suspect context_input_ids is not passed to the forward method. If the model is not initialised with a retriever, forward then complains about missing context_input_ids or retriever. I am referring to the following piece of code in the RagSequenceForGeneration class and its generate function.
Expected behavior
It should work as intended, as it does for RagTokenForGeneration.
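To make the suspected cause concrete: RAG-Sequence generation, roughly as the transformers implementation does it, runs beam search separately on each retrieved context and then rescores every candidate under all contexts with an extra forward pass. The sketch below is a hypothetical, model-free version of that "thorough decoding" step; contexts stands in for context_input_ids, and seq_prob is an assumed placeholder for the model's per-document sequence likelihood. It shows why the rescoring pass cannot work from the question alone when no retriever is attached: it needs every retrieved context plus the doc_scores.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Candidate = Tuple[int, ...]  # a generated sequence of token ids


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def thorough_decode(
    contexts: List[str],
    doc_scores: Sequence[float],
    candidates_per_doc: List[List[Candidate]],
    seq_prob: Callable[[str, Candidate], float],
) -> Tuple[Candidate, Dict[Candidate, float]]:
    """Toy RAG-Sequence rescoring step (hypothetical, not the transformers code).

    candidates_per_doc stands in for the output of per-document beam search.
    Each candidate is rescored under *every* retrieved context, so this step
    needs all contexts and doc_scores, not just the original question.
    """
    doc_probs = softmax(doc_scores)
    # Union of candidates from all documents, deduplicated but order-preserving.
    candidates = list(dict.fromkeys(c for cands in candidates_per_doc for c in cands))
    scores = {
        c: sum(pz * seq_prob(ctx, c) for ctx, pz in zip(contexts, doc_probs))
        for c in candidates
    }
    best = max(scores, key=scores.get)
    return best, scores


if __name__ == "__main__":
    # Hypothetical per-document likelihoods p(y | x, z).
    table = {
        ("doc A", (1, 2)): 0.6, ("doc A", (3,)): 0.1,
        ("doc B", (1, 2)): 0.3, ("doc B", (3,)): 0.5,
    }
    best, scores = thorough_decode(
        contexts=["doc A", "doc B"],
        doc_scores=[0.0, 0.0],
        candidates_per_doc=[[(1, 2)], [(3,)]],
        seq_prob=lambda ctx, c: table[(ctx, c)],
    )
    print(best)  # (1, 2)
```

Here (3,) only comes out of doc B's beam, yet it is still scored under doc A. If the contexts are not handed to that rescoring call, there is nothing to score against, which matches the kind of failure described above.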