langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

Missing key error - Using PromptTemplate and GraphCypherQAChain. #24260

Open pierreoberholzer opened 1 month ago

pierreoberholzer commented 1 month ago

Example Code

from langchain.prompts.prompt import PromptTemplate
from langchain.chains import GraphCypherQAChain
from langchain_openai import ChatOpenAI

CYPHER_QA_TEMPLATE = """
You're an AI cook formulating Cypher statements to navigate through a recipe database.

Schema: {schema}

Examples: {examples}

Question: {question}
"""

CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "examples", "question"],
    template=CYPHER_QA_TEMPLATE)

# graph (e.g. a Neo4jGraph), examples, and question are defined elsewhere
model = ChatOpenAI(temperature=0, model_name="gpt-4-0125-preview")
chain = GraphCypherQAChain.from_llm(
    graph=graph, llm=model, verbose=True,
    validate_cypher=True, cypher_prompt=CYPHER_GENERATION_PROMPT)
res = chain.invoke({"schema": graph.schema, "examples": examples, "question": question})

Error Message and Stack Trace (if applicable)

> Entering new GraphCypherQAChain chain...
Traceback (most recent call last):
  File "/Users/<path_to_my_project>/src/text2cypher_langchain.py", line 129, in <module>
    res = chain.invoke({"schema": graph.schema,"examples" : examples,"question":question})
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/<path_to_my_project>/venv/lib/python3.11/site-packages/langchain/chains/base.py", line 166, in invoke
    raise e
  File "/Users/<path_to_my_project>/venv/lib/python3.11/site-packages/langchain/chains/base.py", line 154, in invoke
    self._validate_inputs(inputs)
  File "/Users/<path_to_my_project>/venv/lib/python3.11/site-packages/langchain/chains/base.py", line 284, in _validate_inputs
    raise ValueError(f"Missing some input keys: {missing_keys}")
ValueError: Missing some input keys: {'query'}

Description

I'm getting a missing-key error when passing custom input variables through PromptTemplate to GraphCypherQAChain. This seems similar to #19560, now closed.

System Info

RafaelXokito commented 1 month ago

The problem is that the GraphCypherQAChain class has an input_key field, which is "query" by default:

input_key: str = "query"

This field is then used to retrieve the question. See the code in the _call method from the GraphCypherQAChain class:

        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph_schema}, callbacks=callbacks
        )

Your code works if you change it to:

res = chain.invoke({"schema": graph.schema, "examples": examples, "query": question})
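The check that raises this error can be illustrated with a minimal, self-contained sketch (illustrative only, not the actual langchain source): the chain derives its expected input keys from `input_key`, and `invoke` rejects any input dict that lacks them.

```python
# Minimal sketch of how a chain validates inputs against its declared
# input keys before running (not the real langchain implementation).
class MiniChain:
    # mirrors GraphCypherQAChain's default: input_key: str = "query"
    input_key = "query"

    @property
    def input_keys(self):
        return [self.input_key]

    def _validate_inputs(self, inputs):
        missing = set(self.input_keys) - set(inputs)
        if missing:
            raise ValueError(f"Missing some input keys: {missing}")

    def invoke(self, inputs):
        self._validate_inputs(inputs)
        return inputs[self.input_key]


chain = MiniChain()
chain.invoke({"query": "What can I cook with basil?"})  # passes validation
# chain.invoke({"question": "..."}) would raise:
# ValueError: Missing some input keys: {'query'}
```

Renaming the key in the invoke call satisfies this check; the separate validation of the prompt's own variables is what then complains about "examples".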
pierreoberholzer commented 1 month ago

Thanks @RafaelXokito. Problem persists.

Code

res = chain.invoke({"schema": graph.schema, "examples": examples, "query": question})

Error (similar as above)

...
  File "/Users/pierreoberholzer/code/knowledge_graphs/venv/lib/python3.11/site-packages/langchain/chains/base.py", line 284, in _validate_inputs
    raise ValueError(f"Missing some input keys: {missing_keys}")
ValueError: Missing some input keys: {'examples'}
RafaelXokito commented 1 month ago

At the moment the _call method doesn't handle an "examples" field/input.

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph_schema}, callbacks=callbacks
        )

My suggestion is to override the _call method and change the code to the following. Tell me if you want help doing it.

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph_schema, "examples": inputs["examples"]}, callbacks=callbacks
        )
pierreoberholzer commented 1 month ago

Thanks @RafaelXokito. Yes, it would be highly appreciated if you could provide a patch.

RafaelXokito commented 1 month ago

Hi @pierreoberholzer,

Here's a simple example:

class GraphCypherQAChainAux(GraphCypherQAChain):
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in the database and answer the question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph_schema, "examples": inputs["examples"]}, callbacks=callbacks
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        # Correct Cypher query if enabled
        if self.cypher_query_corrector:
            generated_cypher = self.cypher_query_corrector(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results
        # Generated Cypher can be null if query corrector identifies an invalid schema
        if generated_cypher:
            context = self.graph.query(generated_cypher)[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

And here is how you can use it:

CYPHER_QA_TEMPLATE = """
You're an AI cook formulating Cypher statements to navigate through a recipe database.

Schema: {schema}

Examples: {examples}

Question: {question}
"""

CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "examples", "question"],
    template=CYPHER_QA_TEMPLATE)

chain = GraphCypherQAChainAux.from_llm(
    graph=graph, llm=model, verbose=True, validate_cypher=True,
    cypher_prompt=CYPHER_GENERATION_PROMPT, input_key="question")

res = chain.invoke({"examples": examples, "question": question})

Take into account that I didn't include the schema because the original _call method already gets it.

Please give me your feedback on this.
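A related idea, without subclassing, is prompt partialing: pre-fill variables like examples up front so the template only needs the keys the chain itself provides. The helper below is a pure-Python sketch of the mechanism (the helper name is made up for illustration; langchain's PromptTemplate exposes a partial() method with similar semantics).

```python
# Pure-Python sketch of "partialing" a template: fill in the known
# variables now, leave the rest as placeholders for later.
def partial_format(template, **known):
    class KeepMissing(dict):
        def __missing__(self, key):
            return "{" + key + "}"  # leave unknown placeholders intact
    return template.format_map(KeepMissing(**known))


TEMPLATE = "Schema: {schema}\nExamples: {examples}\nQuestion: {question}"
prefilled = partial_format(TEMPLATE, examples="MATCH (r:Recipe) RETURN r LIMIT 3")
# prefilled still contains {schema} and {question} for the chain to fill
```

Whether partialing sidesteps the chain's input validation in this specific case would need testing; the subclass above is the verified route.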

pierreoberholzer commented 1 month ago

Many thanks @RafaelXokito! Your patched version of the class is working, but it's quite a workaround..

Any change in sight in the package?

RafaelXokito commented 1 month ago

Thank you, @pierreoberholzer, for your feedback!

I am considering making changes to how inputs are used in the cypher_generation_chain. Specifically, I aim to concatenate the inputs dynamically to streamline the process:

question = inputs[self.input_key]
args = {
    "question": question,
    "schema": self.graph_schema,
}
args.update(inputs)

intermediate_steps: List = []

generated_cypher = self.cypher_generation_chain.run(
    args, callbacks=callbacks
)

However, I am concerned about the potential impact this change might have on existing users of this method. @ccurme, could you please provide your insights on this proposed modification?

Thank you!
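The proposal above can be sketched as a small standalone function (the helper name and schema string are made up for illustration, not langchain code): start from the chain-managed values, then merge in everything the caller passed.

```python
# Sketch of the proposed input handling: build the default args, then
# merge the caller's inputs so extra keys like "examples" pass through.
def build_cypher_args(inputs, input_key="query", graph_schema=""):
    args = {"question": inputs[input_key], "schema": graph_schema}
    args.update(inputs)  # extra prompt variables flow through unchanged
    return args


args = build_cypher_args(
    {"query": "Which dishes use basil?", "examples": "MATCH (r:Recipe) RETURN r"},
    graph_schema="(:Recipe)-[:HAS_INGREDIENT]->(:Ingredient)",
)
# args now holds question, schema, and examples for the generation prompt
```

Note that args.update(inputs) also lets a caller-supplied "schema" or "question" key override the chain's own values, which is part of the backward-compatibility concern raised above.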

supreme-core commented 1 month ago

I ended up doing something similar to get it working for generating SPARQL. I am also working with a TTL file, so you would also need to override the RdfGraph class.