run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License
33.4k stars 4.68k forks source link

[Bug]: TextToCypherRetriever raises an error if allowed_output_fields is not specified #14450

Open ootkin opened 1 week ago

ootkin commented 1 week ago

Bug Description

I'm using an existing neo4j graph and trying to run TextToCypherRetriever

When I run a query, the retriever raises an UnboundLocalError.

Version

0.10.50

Steps to Reproduce

import os
from pprint import pprint

from llama_index.core import PropertyGraphIndex
from llama_index.core.indices.property_graph import PGRetriever, TextToCypherRetriever
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore

# Connect to an existing Neo4j graph; connection settings are read from the
# environment so no credentials are hard-coded in the repro.
graph_store = Neo4jPropertyGraphStore(
    username=os.getenv('GRAPH_USERNAME'),
    password=os.getenv('GRAPH_PASSWORD'),
    url=os.getenv('GRAPH_URI'),
)

# Wrap the existing graph without (re-)embedding its nodes.
index = PropertyGraphIndex.from_existing(
    property_graph_store=graph_store,
    embed_kg_nodes=False,
)

# Note: no `allowed_output_fields` is passed, which is what triggers the bug.
text_to_cypher = TextToCypherRetriever(
    index.property_graph_store,
)

retriever = PGRetriever(sub_retrievers=[text_to_cypher])

if __name__ == '__main__':
    nodes = retriever.retrieve("query")
    pprint(nodes)

Relevant Logs/Tracebacks

File "llama_index/core/async_utils.py", line 47, in asyncio_run
    raise type(e)(f"Error running coroutine: {e!s}") from e
UnboundLocalError: Error running coroutine: cannot access local variable 'parsed_cypher_query' where it is not associated with a value
dosubot[bot] commented 1 week ago

To resolve the neo4j.exceptions.CypherSyntaxError and UnboundLocalError when using the TextToCypherRetriever in version 0.10.50, ensure that the Cypher query is correctly constructed and validated. Additionally, handle the UnboundLocalError by properly initializing all variables before use.

Here is the corrected code snippet:

def run_request(self, question, history=None, retry=True):
    """
    Executes a Cypher query based on a given question.

    Args:
        question (str): The question to execute the Cypher query for.
        history (list, optional): A list of previous interactions for context. Defaults to None.
        retry (bool, optional): Whether to retry once (with the syntax error fed back
            to the model) in case of a syntax error. Defaults to True.

    Returns:
        list/str: The query results, or the string "Invalid Cypher syntax" if the
        query still fails with a syntax error after the retry.
    """
    import logging

    from neo4j.exceptions import CypherSyntaxError

    logger = logging.getLogger(__name__)

    # Construct Cypher statement
    cypher = self.construct_cypher_query(question, history)
    # Validate Cypher statement
    if self.cypher_query_corrector:
        cypher = self.cypher_query_corrector(cypher)
    # Library code should not print; emit the generated query at DEBUG level.
    logger.debug("Generated Cypher: %s", cypher)
    try:
        return self.query_graph_db(cypher)
    # Self-healing flow
    except CypherSyntaxError as e:
        # If out of retries
        if not retry:
            return "Invalid Cypher syntax"
        # Self-healing Cypher flow by
        # providing specific error to GPT-4
        logger.debug("Retrying")
        return self.run_request(
            question,
            [
                # Keep the caller's earlier context; previously it was dropped
                # and only the two messages below were sent on retry.
                *(history or []),
                ChatMessage(role=MessageRole.ASSISTANT, content=cypher),
                ChatMessage(
                    role=MessageRole.SYSTEM,
                    content=f"This query returns an error: {e!s}\n"
                    "Give me a improved query that works without any explanations or apologies",
                ),
            ],
            retry=False,
        )

class TextToCypherRetriever(BasePGRetriever):
    """A Text-to-Cypher retriever that uses a language model to generate Cypher queries.

    NOTE: Executing arbitrary cypher has its risks. Ensure you take the needed measures
    (read-only roles, sandboxed env, etc.) to ensure safe usage in a production environment.

    Args:
        graph_store (PropertyGraphStore):
            The graph store to retrieve data from. Must support structured queries.
        llm (Optional[LLM], optional):
            The language model to use. Defaults to Settings.llm.
        text_to_cypher_template (Optional[Union[PromptTemplate, str]], optional):
            The template to use for the text-to-cypher query. A plain string is
            wrapped in a PromptTemplate. Defaults to the graph store's
            ``text_to_cypher_template``.
        response_template (Optional[str], optional):
            The template used to render the node text; it is formatted with the
            ``query`` and ``response`` keys. Defaults to DEFAULT_RESPONSE_TEMPLATE.
        cypher_validator (Optional[callable], optional):
            A callable that receives the generated Cypher query and returns the
            query to execute. Applied whenever it is provided. Defaults to None.
        allowed_output_fields (Optional[List[str]], optional):
            The fields to keep in the query output. When None, the query output
            is returned unfiltered. Defaults to None.

    Raises:
        ValueError: If ``graph_store`` does not support structured queries.
    """

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        llm: Optional[LLM] = None,
        text_to_cypher_template: Optional[Union[PromptTemplate, str]] = None,
        response_template: Optional[str] = None,
        cypher_validator: Optional[callable] = None,
        allowed_output_fields: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        if not graph_store.supports_structured_queries:
            raise ValueError(
                "The provided graph store does not support cypher queries."
            )

        self.llm = llm or Settings.llm

        if isinstance(text_to_cypher_template, str):
            text_to_cypher_template = PromptTemplate(text_to_cypher_template)

        self.response_template = response_template or DEFAULT_RESPONSE_TEMPLATE
        self.text_to_cypher_template = (
            text_to_cypher_template or graph_store.text_to_cypher_template
        )
        self.cypher_validator = cypher_validator
        self.allowed_output_fields = allowed_output_fields
        # NOTE(review): **kwargs is accepted but not forwarded to the base
        # class; confirm whether BasePGRetriever expects any of them.
        super().__init__(graph_store=graph_store, include_text=False)

    def _parse_generated_cyher(self, cypher_query: str) -> str:
        """Apply the optional user-supplied validator to a generated query.

        Returns the query unchanged when no validator is configured.
        """
        # The misspelled method name is kept for backward compatibility.
        if self.cypher_validator is not None:
            return self.cypher_validator(cypher_query)
        return cypher_query

    def _clean_query_output(self, query_output: Any) -> Any:
        """Iterate the cypher response, keeping only the allowed fields.

        When ``allowed_output_fields`` is None no filtering was requested, and the
        output is returned unchanged. Otherwise:

        - dict keys in the allowlist are kept verbatim;
        - dict/list values under other keys are filtered recursively and dropped
          if nothing survives;
        - scalars not reachable through an allowed key are discarded
          (returned as None).
        """
        if self.allowed_output_fields is None:
            # Without this, scalar items (or a non-list/dict output) were
            # replaced by None even though no filtering had been requested.
            return query_output

        if isinstance(query_output, dict):
            filtered = {}
            for key, value in query_output.items():
                if key in self.allowed_output_fields:
                    filtered[key] = value
                elif isinstance(value, (dict, list)):
                    filtered_value = self._clean_query_output(value)
                    if filtered_value:
                        filtered[key] = filtered_value
            return filtered

        if isinstance(query_output, list):
            cleaned_items = (self._clean_query_output(item) for item in query_output)
            # Drop items that were filtered down to nothing.
            return [item for item in cleaned_items if item]

        return None

    def _build_response_nodes(
        self, cypher_query: str, query_output: Any
    ) -> List[NodeWithScore]:
        """Render the executed query and its filtered output as a single scored node."""
        cleaned_query_output = self._clean_query_output(query_output)

        node_text = self.response_template.format(
            query=cypher_query,
            response=str(cleaned_query_output),
        )

        return [
            NodeWithScore(
                node=TextNode(
                    text=node_text,
                ),
                score=1.0,
            )
        ]

    def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Generate Cypher for the question, execute it, and wrap the result in a node."""
        schema = self._graph_store.get_schema_str()
        question = query_bundle.query_str

        response = self.llm.predict(
            self.text_to_cypher_template,
            schema=schema,
            question=question,
        )

        # Always go through the validator hook; it is a no-op without a
        # validator. Gating it on allowed_output_fields silently skipped a
        # user-supplied validator.
        parsed_cypher_query = self._parse_generated_cyher(response)

        query_output = self._graph_store.structured_query(parsed_cypher_query)

        return self._build_response_nodes(parsed_cypher_query, query_output)

    async def aretrieve_from_graph(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Async counterpart of ``retrieve_from_graph``."""
        schema = await self._graph_store.aget_schema_str()
        question = query_bundle.query_str

        response = await self.llm.apredict(
            self.text_to_cypher_template,
            schema=schema,
            question=question,
        )

        # Same unconditional validator hook as the sync path.
        parsed_cypher_query = self._parse_generated_cyher(response)

        query_output = await self._graph_store.astructured_query(parsed_cypher_query)

        return self._build_response_nodes(parsed_cypher_query, query_output)

Ensure that the construct_cypher_query method correctly constructs the Cypher query and that the cypher_query_corrector is properly initialized if validate_cypher is set to True. This should help in resolving the CypherSyntaxError. The UnboundLocalError can be avoided by ensuring that all variables are properly initialized and used within their scope [1][2].

To continue talking to Dosu, mention @dosu.

ootkin commented 1 week ago

@dosu i just updated my comment

dosubot[bot] commented 1 week ago

@logan-markewich, this one's a bit out of my wheelhouse, could you step in and provide your guidance?

To continue talking to Dosu, mention @dosu.