deepset-ai / haystack

AI orchestration framework to build customizable, production-ready LLM applications. Connect components (models, vector DBs, file converters) to pipelines or agents that can interact with your data. With advanced retrieval methods, it's best suited for building RAG, question answering, semantic search or conversational agent chatbots.
https://haystack.deepset.ai
Apache License 2.0

PromptBuilder yields huggingface_hub.errors.ValidationError: Input validation error: `inputs` must have less than 4095 tokens. Given: 4701 #8188

Open fhamborg opened 3 months ago

fhamborg commented 3 months ago

Describe the bug

When using a standard RAG pipeline, I get the error above.

Error message

  File "/home/felix/PycharmProjects/anychat/src/anychat/analysis/rag.py", line 124, in query_rag_in_document_store
    result = self.llm_pipeline.run(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/core/pipeline/pipeline.py", line 197, in run
    res = comp.run(**last_inputs[name])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 187, in run
    return self._run_non_streaming(prompt, generation_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 211, in _run_non_streaming
    tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_client.py", line 2061, in text_generation
    raise_text_generation_error(e)
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_common.py", line 457, in raise_text_generation_error
    raise exception from http_error
huggingface_hub.errors.ValidationError: Input validation error: `inputs` must have less than 4095 tokens. Given: 4701

Expected behavior

I would expect built-in truncation that limits the input so that no more tokens than the model accepts are passed to it. Ideally, the input should not be truncated at the end of the prompt (which would cut off the question) but at a specific part, e.g., by truncating the retrieved documents instead of including all tokens of my top_k=10 documents.
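For illustration, something along these lines is what I have in mind: a hypothetical custom component (not part of Haystack) that trims the retrieved documents to an overall budget before they reach the PromptBuilder. It is only a rough sketch; the whitespace-based word count stands in for a real tokenizer, and the 3000-token budget is arbitrary.

from typing import List

from haystack import Document, component


@component
class DocumentTruncator:
    """Rough sketch: cap the total size of the retrieved documents before prompt building."""

    def __init__(self, max_tokens: int = 3000):
        self.max_tokens = max_tokens

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        budget = self.max_tokens
        truncated = []
        for doc in documents:
            if budget <= 0:
                break
            # Placeholder "token" count; a real implementation would use the model's tokenizer.
            words = (doc.content or "").split()
            keep = words[:budget]
            budget -= len(keep)
            truncated.append(Document(content=" ".join(keep), meta=doc.meta))
        return {"documents": truncated}

Such a component could be wired between the retriever and the prompt builder, e.g. pipeline.connect("retriever", "truncator") followed by pipeline.connect("truncator", "prompt_builder.documents").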

To Reproduce

query_template = """Beantworte die Frage basierend auf dem nachfolgenden Kontext und Chatverlauf. Antworte so detailliert wie möglich, aber nur mit Informationen aus dem Kontext. Wenn du eine Antwort nicht weißt, sag, dass du sie nicht kennst. Wenn sich eine Frage nicht auf den Kontext bezieht, sag, dass du die Frage nicht beantworten kannst und höre auf. Stelle nie selbst eine Frage.

Kontext:
{% for document in documents %}
    {{ document.content }}
{% endfor %}

Vorheriger Chatverlauf:
{{ history }}

Frage: {{ question }}
Antwort: """

    def _create_generator(self):
        if AnyChatConfig.hf_use_local_generator:
            return HuggingFaceLocalGenerator(
                model=self.model_id,
                task="text2text-generation",
                device=ComponentDevice.from_str("cuda:0"),
                huggingface_pipeline_kwargs={
                    "device_map": "auto",
                    "model_kwargs": {
                        "load_in_4bit": True,
                        "bnb_4bit_use_double_quant": True,
                        "bnb_4bit_quant_type": "nf4",
                        "bnb_4bit_compute_dtype": torch.bfloat16,
                    },
                },
                generation_kwargs={"max_new_tokens": 350},
            )
        else:
            return HuggingFaceAPIGenerator(
                api_type="text_generation_inference",
                api_params={"url": AnyChatConfig.hf_api_generator_url},
            )

    def create_llm_pipeline(self, document_store):
        """
        Creates an LLM pipeline that employs RAG on a document store that must have been set up before.
        :return:
        """
        # create the pipeline with the individual components
        self.llm_pipeline = Pipeline()
        self.llm_pipeline.add_component(
            "embedder",
            SentenceTransformersTextEmbedder(
                model=DocumentManager.embedding_model_id,
                device=ComponentDevice.from_str(
                    AnyChatConfig.hf_device_rag_text_embedder
                ),
            ),
        )
        self.llm_pipeline.add_component(
            "retriever",
            InMemoryEmbeddingRetriever(document_store=document_store, top_k=8),
        )
        self.llm_pipeline.add_component(
            "prompt_builder", PromptBuilder(template=query_template)
        )

        self.llm_pipeline.add_component("llm", self._create_generator())

        # connect the individual nodes to create the final pipeline
        self.llm_pipeline.connect("embedder.embedding", "retriever.query_embedding")
        self.llm_pipeline.connect("retriever", "prompt_builder.documents")
        self.llm_pipeline.connect("prompt_builder", "llm")

    def _get_formatted_history(self):
        history = ""
        for message in self.conversation_history:
            history += f"{message[0]}: {message[1]}\n"
        history = history.strip()

        return history

    def query_rag_in_document_store(self, query):
        """
        Uses the LLM and RAG to provide an answer to the given query based on the documents in the document store.
        :param query:
        :return:
        """
        logger.debug("querying using rag with: {}", query)

        # run the query through the pipeline
        result = self.llm_pipeline.run(
            {
                "embedder": {"text": query},
                "prompt_builder": {
                    "question": query,
                    "history": self._get_formatted_history(),
                },
                "llm": {"generation_kwargs": {"max_new_tokens": 350}},
            }
        )
        response = result["llm"]["replies"][0]
        post_processed_response = self._post_process_response(response, query)
        logger.debug(post_processed_response)

        self.conversation_history.append(("Frage", query))
        self.conversation_history.append(("Antwort", post_processed_response))

        return post_processed_response


anakin87 commented 3 months ago

Related to #6593

julian-risch commented 3 months ago

Thank you @fhamborg for the suggestion to truncate specific parts of the prompt. We are tracking this with #6593. Regarding the error caused by the input length, you could set a maximum length for truncation as part of the generation_kwargs of the HuggingFaceAPIGenerator. Does that work for you as a workaround? https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
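For example, something like this might work (untested sketch; the endpoint URL and the token budget are placeholders, and truncate is the input-side parameter of InferenceClient.text_generation that generation_kwargs is forwarded to):

from haystack.components.generators import HuggingFaceAPIGenerator

generator = HuggingFaceAPIGenerator(
    api_type="text_generation_inference",
    api_params={"url": "http://localhost:8080"},  # placeholder TGI endpoint
    # "truncate" caps the number of input tokens on the TGI side; 3700 leaves
    # room for max_new_tokens within the 4095-token limit from the error above.
    generation_kwargs={"truncate": 3700, "max_new_tokens": 350},
)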

fhamborg commented 3 months ago

Thanks @julian-risch for the quick reply! As for setting the truncation parameter to some value, I guess it would help to avoid the error above, but it would cut off the actual question in cases where the input is too long (since the question is the last item in my prompt), which would be worse.

Is there a way to retrieve the actual input to the LLM (or rather the text that is converted to that input), i.e., the potentially truncated input? That way I could compare my full prompt with the actual one (after potential truncation), and if it was in fact truncated, I could rerun the pipeline with the retriever's top_k set one lower, for example. Or would you think it'd be better to just catch the exception above and then rerun with a decreased top_k?

EDIT: I just realized that the top_k parameter has to be set when the pipeline is created, not when it is run. So the above idea unfortunately wouldn't work (unless I recreated the pipeline each time this situation occurs). Do you have any idea how to avoid the error above without cutting off the question, other than setting top_k to a very low value (in which case the error would still come up at some point, e.g., once the chat history gets long)?
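For reference, the fallback I have in mind would look roughly like this (sketch only; it assumes a hypothetical create_llm_pipeline(document_store, top_k=...) variant and a stored self.document_store, neither of which exists in my code above):

from huggingface_hub.errors import ValidationError

    def query_with_fallback(self, query, top_k=8, min_top_k=1):
        # Sketch: retry with fewer retrieved documents whenever the input is too long.
        while True:
            try:
                return self.query_rag_in_document_store(query)
            except ValidationError:
                if top_k <= min_top_k:
                    raise
                top_k -= 1
                # Hypothetical variant of create_llm_pipeline that accepts top_k;
                # recreating the pipeline each time is the drawback mentioned above.
                self.create_llm_pipeline(self.document_store, top_k=top_k)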

chris-brightbeam commented 2 months ago

@julian-risch

> Thank you @fhamborg for the suggestion to truncate specific parts of the prompt. We are tracking this with #6593. Regarding the error caused by the input length, you could set a maximum length for truncation as part of the generation_kwargs of the HuggingFaceAPIGenerator. Does that work for you as a workaround? https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation

What is the exact name of the parameter? The linked page contains neither a max_length nor a truncation parameter.

My use case is slightly different, as I'm trying to work around this bug with HuggingFaceLocalGenerator. Setting max_length in generation_kwargs there only applies to the output length and does not truncate the input, so my application crashes.
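A possible stopgap (rough sketch; model_id and max_input_tokens are placeholders for the local setup) would be to truncate the built prompt with the model's own tokenizer before it reaches the generator:

from transformers import AutoTokenizer


def truncate_prompt(prompt: str, model_id: str, max_input_tokens: int) -> str:
    # Cut the prompt down to a fixed token budget using the model's tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    input_ids = tokenizer.encode(prompt, truncation=True, max_length=max_input_tokens)
    return tokenizer.decode(input_ids, skip_special_tokens=True)

Note that this still cuts from the end of the prompt, so the more targeted truncation discussed above (and tracked in #6593) would be preferable.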