
[RAG] RagSequenceForGeneration: Running "retriever separately example" giving error #7829

Closed · lalitpagaria closed this issue 4 years ago

lalitpagaria commented 4 years ago

Who can help

@patrickvonplaten @LysandreJik

Information

Model I am using (Bert, XLNet ...): RAG

To reproduce

Steps to reproduce the behavior:

  1. Execute the code snippet provided below (a partially modified version of the example script from https://huggingface.co/transformers/master/model_doc/rag.html)

Code snippets:

!pip install git+https://github.com/huggingface/transformers.git
!pip install datasets
!pip install faiss-cpu

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, RagSequenceForGeneration
import torch
import faiss

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
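# Note: the model is loaded without a retriever on purpose; retrieval is done manually below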
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)

input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
input_ids = input_dict["input_ids"]

# Calling the retriever separately

# 1. Encode
question_hidden_states = model.question_encoder(input_ids)[0]
# 2. Retrieve
docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
print(docs_dict)
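# Score each retrieved document against the question embedding (inner product)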
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
# 3. Forward to generator
outputs = model.generate(input_ids=input_ids, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)

generated_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(generated_string)

Stack trace:

AssertionError                            Traceback (most recent call last)
<ipython-input-5-9f622b1f6353> in <module>()
      7 doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
      8 # 3. Forward to generator
----> 9 outputs = model.generate(input_ids=input_ids, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
     10 generated_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)
     11 print(generated_string)

5 frames
/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     13         def decorate_context(*args, **kwargs):
     14             with self:
---> 15                 return func(*args, **kwargs)
     16         return decorate_context
     17 

/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in generate(self, input_ids, attention_mask, context_input_ids, do_deduplication, num_return_sequences, num_beams, **kwargs)
    902             # then, run model forwards to get nll scores:
    903             new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
--> 904             outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
    905             top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
    906 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, context_input_ids, context_attention_mask, doc_scores, use_cache, output_attentions, output_hidden_states, output_retrieved, exclude_bos_score, reduce_loss, labels, **kwargs)
    767             output_attentions=output_attentions,
    768             output_hidden_states=output_hidden_states,
--> 769             output_retrieved=output_retrieved,
    770         )
    771 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, doc_scores, context_input_ids, context_attention_mask, use_cache, output_attentions, output_hidden_states, output_retrieved)
    589                 assert (
    590                     context_input_ids is not None
--> 591                 ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
    592                 assert (
    593                     context_attention_mask is not None

AssertionError: Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.

I suspect `context_input_ids` is not passed through to the `forward` method inside `generate`. And since the model was not initialised with a retriever, `forward` then complains about the missing `context_input_ids`/retriever. I am referring to the following piece of code in `RagSequenceForGeneration.generate`:

            # then, run model forwards to get nll scores:
            new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
            outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
            top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
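As the assertion message itself suggests, one workaround is to attach the retriever to the model so that these internal forward calls can retrieve on their own. A minimal, untested sketch (reusing the tokenizer, model, and retriever from the snippet above):

# Attach the retriever so that generate()'s internal forward() calls can
# fetch context documents themselves instead of requiring context_input_ids
model.set_retriever(retriever)
outputs = model.generate(input_ids=input_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))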

Expected behavior

It should work as intended, as `RagTokenForGeneration` does.

patrickvonplaten commented 4 years ago

Hey @lalitpagaria - spot on! Thanks a lot for your issue, you're 100% correct here.

I actually noticed that the `RagSequenceForGeneration.generate` function is a bit more complex, so we cannot do the decomposed (embed, retrieve, generate) example here...

The PR linked to this issue removes that use case from the examples and fixes the one for `RagTokenForGeneration`.
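For reference, the decomposed flow does work with `RagTokenForGeneration`, since its `generate` accepts the precomputed retrieval outputs directly. A minimal sketch along the lines of the docs example, assuming the `facebook/rag-token-nq` checkpoint:

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import torch

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")

input_ids = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", return_tensors="pt")["input_ids"]

# 1. Encode
question_hidden_states = model.question_encoder(input_ids)[0]
# 2. Retrieve
docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
# 3. Generate, passing the retrieval outputs explicitly
outputs = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))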

lalitpagaria commented 4 years ago

UPDATE: @patrickvonplaten Sorry for my misunderstanding. Yes, with your PR this is fixed once I avoid calling `generate` directly. Thank you very much for the fix.


@patrickvonplaten Thank you for the update. I tried your changes with my code snippet and still got the same error. If you look at my example, I am passing `context_input_ids`.
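For anyone landing on this issue: a sketch of the forward-call variant referred to in the update above (untested; based on the `forward` signature visible in the stack trace and on the docs example, with `input_dict` and `docs_dict` taken from the snippet at the top):

# Unlike RagSequenceForGeneration.generate(), forward() accepts the
# precomputed retrieval outputs directly, so no retriever needs to be set
outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])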