run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: async bug (llms/huggingface/TextGenerationInference) and SchemaLLMPathExtractor pydantic validation #14211

Open mphipps2 opened 1 week ago

mphipps2 commented 1 week ago

Bug Description

I'm using TGI (newest version, 0.7.0) hosted locally, and I'm trying to run the PropertyGraphIndex routines. SimpleLLMPathExtractor and ImplicitPathExtractor seem to work fine, but I'm having issues with SchemaLLMPathExtractor.

Issue 1) In llms/huggingface/TextGenerationInference.py, on line 1044 in the achat_with_tools() function, the response is obtained like this:

response = self.achat(
    messages=messages,
    tools=tool_specs,
    tool_choice=resolve_tool_choice(tool_specs, tool_choice),
    **kwargs,
)

I believe there should be an "await" in front of that achat() call. Without it I end up with empty responses and the Issue 1 error shown below. After adding the await, I get past this point.

Issue 2) After fixing the first issue, I run into the Issue 2 error shown below. I think it is related to the on-the-fly pydantic schema defined in SchemaLLMPathExtractor, but it may be related to the TGI integration as well.
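Regarding Issue 1: here is a minimal, self-contained sketch of why a missing await produces that traceback. Calling an async def method without await returns a coroutine object instead of the chat response, so the later response.message lookup raises AttributeError. FakeLLM, FakeResponse and FakeMessage are hypothetical stand-ins, not llama_index or TGI classes.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeMessage:
    # Hypothetical stand-in for a chat message carrying tool calls.
    additional_kwargs: dict = field(default_factory=dict)


@dataclass
class FakeResponse:
    # Hypothetical stand-in for a chat response.
    message: FakeMessage


class FakeLLM:
    async def achat(self, **kwargs) -> FakeResponse:
        # Pretend the server returned a single tool call.
        return FakeResponse(FakeMessage({"tool_calls": [{"id": "0"}]}))

    async def achat_with_tools_broken(self) -> object:
        # Bug: no await, so `response` is a coroutine, not a FakeResponse.
        response = self.achat()
        return response

    async def achat_with_tools_fixed(self) -> FakeResponse:
        # Fix: await the coroutine to get the actual response object.
        response = await self.achat()
        return response


async def main() -> tuple:
    llm = FakeLLM()
    broken = await llm.achat_with_tools_broken()
    has_message = hasattr(broken, "message")  # False: it is a coroutine
    broken.close()  # close it explicitly instead of leaving it un-awaited
    fixed = await llm.achat_with_tools_fixed()
    return has_message, fixed.message.additional_kwargs["tool_calls"]


if __name__ == "__main__":
    print(asyncio.run(main()))  # (False, [{'id': '0'}])
```

The broken variant mirrors the AttributeError: 'coroutine' object has no attribute 'message' from the log, and the "was never awaited" RuntimeWarning comes from such coroutines being garbage-collected without ever running.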

Any idea what's going on here?

Version

llama-index 0.10.45.post1; llama-index-llms-huggingface 0.2.3

Steps to Reproduce

kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    extract_prompt=DEFAULT_SCHEMA_PATH_EXTRACT_PROMPT,
    possible_entities=DEFAULT_ENTITIES,
    possible_relations=DEFAULT_RELATIONS,
    kg_validation_schema=DEFAULT_VALIDATION_SCHEMA,
    kg_schema_cls=None,
    num_workers=4,
    max_triplets_per_chunk=5,
    # if False, allows values outside of the schema; useful for using the schema as a suggestion
    strict=True,
)

graph_store = Neo4jPropertyGraphStore(
    username=username,
    password=password,
    url=url,
    database=database,
)

index = PropertyGraphIndex(
    nodes=nodes,
    embed_model=embedding_model,
    llm=llm,
    vector_store=None,
    kg_extractors=[kg_extractor],
    property_graph_store=graph_store,
    embed_kg_nodes=True,
    show_progress=True,
)

Relevant Logs/Tracebacks

Issue 1 error: 

  File "/lib/python3.9/site-packages/llama_index/core/llms/function_calling.py", line 141, in apredict_and_call
    response = await self.achat_with_tools(
  File "/lib/python3.9/site-packages/llama_index/llms/huggingface/base.py", line 1042, in achat_with_tools
    force_single_tool_call(response)
  File "/lib/python3.9/site-packages/llama_index/llms/huggingface/utils.py", line 45, in force_single_tool_call
    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
AttributeError: 'coroutine' object has no attribute 'message'
sys:1: RuntimeWarning: coroutine 'Dispatcher.span.<locals>.async_wrapper' was never awaited

Issue 2 error:

  File "/lib/python3.9/site-packages/text_generation/client.py", line 531, in chat
    return await self._chat_single_response(request)
  File "/lib/python3.9/site-packages/text_generation/client.py", line 544, in _chat_single_response
    raise parse_error(resp.status, payload)
text_generation.errors.GenerationError: Request failed during generation: Server error: '/definitions/Triplet' does not exist within {'$functions': {'KGSchema': {'definitions': {'Entity': {'properties': {'name': {'title': 'Name', 'type': 'string'}, 'type': {'description': "Entity in a knowledge graph. Only extract entities with types that are listed as valid: typing.Literal['PRODUCT', 'MARKET', 'TECHNOLOGY', 'EVENT', 'CONCEPT', 'ORGANIZATION', 'PERSON', 'LOCATION', 'TIME', 'MISCELLANEOUS']", 'enum': ['PRODUCT', 'MARKET', 'TECHNOLOGY', 'EVENT', 'CONCEPT', 'ORGANIZATION', 'PERSON', 'LOCATION', 'TIME', 'MISCELLANEOUS'], 'title': 'Type', 'type': 'string'}}, 'required': ['type', 'name'], 'title': 'Entity', 'type': 'object'}, 'Relation': {'properties': {'type': {'description': "Relation in a knowledge graph. Only extract relations with types that are listed as valid: typing.Literal['USED_BY', 'USED_FOR', 'LOCATED_IN', 'PART_OF', 'WORKED_ON', 'HAS', 'IS_A', 'BORN_IN', 'DIED_IN', 'HAS_ALIAS']", 'enum': ['USED_BY', 'USED_FOR', 'LOCATED_IN', 'PART_OF', 'WORKED_ON', 'HAS', 'IS_A', 'BORN_IN', 'DIED_IN', 'HAS_ALIAS'], 'title': 'Type', 'type': 'string'}}, 'required': ['type'], 'title': 'Relation', 'type': 'object'}, 'Triplet': {'properties': {'object': {'$ref': '#/definitions/Entity'}, 'relation': {'$ref': '#/definitions/Relation'}, 'subject': {'$ref': '#/definitions/Entity'}}, 'required': ['subject', 'relation', 'object'], 'title': 'Triplet', 'type': 'object'}}, 'description': 'Knowledge Graph Schema.', 'properties': {'_name': {'const': 'KGSchema', 'type': 'string'}, 'triplets': {'items': {'$ref': '#/definitions/Triplet'}, 'title': 'Triplets', 'type': 'array'}}, 'required': ['triplets', '_name'], 'type': 'object'}, 'notify_error': {'properties': {'_name': {'const': 'notify_error', 'type': 'string'}, 'error': {'description': 'The error or issue to notify', 'type': 'string'}}, 'required': ['error', '_name'], 'type': 'object'}}, 'properties': {'function': {'anyOf': [{'$ref': '#/$functions/KGSchema'}, {'$ref': '#/$functions/notify_error'}]}}}
sys:1: RuntimeWarning: coroutine 'SchemaLLMPathExtractor._aextract' was never awaited
Extracting paths from text with schema:   0%|
dosubot[bot] commented 1 week ago

To resolve the issues with the SchemaLLMPathExtractor in TGI version 0.7.0, you need to address both the missing "await" in the achat_with_tools() function and the GenerationError related to the on-the-fly pydantic schema.

Issue 1: Missing "await" in achat_with_tools()

You need to ensure that achat_with_tools() is properly awaited. Here is the apredict_and_call() function, which awaits achat_with_tools():

async def apredict_and_call(
    self,
    tools: List["BaseTool"],
    user_msg: Optional[Union[str, ChatMessage]] = None,
    chat_history: Optional[List[ChatMessage]] = None,
    verbose: bool = False,
    allow_parallel_tool_calls: bool = False,
    **kwargs: Any,
) -> "AgentChatResponse":
    """Predict and call the tool."""
    from llama_index.core.tools.calling import acall_tool_with_selection
    from llama_index.core.chat_engine.types import AgentChatResponse

    if not self.metadata.is_function_calling_model:
        return await super().apredict_and_call(
            tools,
            user_msg=user_msg,
            chat_history=chat_history,
            verbose=verbose,
            **kwargs,
        )

    response = await self.achat_with_tools(
        tools,
        user_msg,
        chat_history,
        verbose=verbose,
        allow_parallel_tool_calls=allow_parallel_tool_calls,
        **kwargs,
    )
    tool_calls = self.get_tool_calls_from_response(response)
    tool_tasks = [
        acall_tool_with_selection(tool_call, tools, verbose=verbose)
        for tool_call in tool_calls
    ]
    tool_outputs = await asyncio.gather(*tool_tasks)
    if allow_parallel_tool_calls:
        output_text = "\n\n".join(
            [tool_output.content for tool_output in tool_outputs]
        )
        return AgentChatResponse(response=output_text, sources=tool_outputs)
    else:
        if len(tool_outputs) > 1:
            raise ValueError("Invalid")
        return AgentChatResponse(
            response=tool_outputs[0].content, sources=tool_outputs
        )

This code awaits achat_with_tools() [1]. Note, however, that the missing await reported above is on the self.achat() call inside achat_with_tools(), so that call must be awaited as well.

Issue 2: GenerationError Related to Pydantic Schema

The GenerationError is likely due to the on-the-fly pydantic schema generation in the SchemaLLMPathExtractor. Ensure that the schema is correctly defined and used. Here is an example of how the SchemaLLMPathExtractor handles schema validation and extraction:

class SchemaLLMPathExtractor(TransformComponent):
    # ... (other parts of the class)

    def _prune_invalid_triplets(self, kg_schema: Any) -> List[Triplet]:
        """Prune invalid triplets."""
        assert isinstance(kg_schema, self.kg_schema_cls)

        valid_triplets = []
        for triplet in kg_schema.triplets:
            subject = triplet.subject.name
            subject_type = triplet.subject.type

            relation = triplet.relation.type

            obj = triplet.object.name
            obj_type = triplet.object.type

            # check relations
            if relation not in self.kg_validation_schema.get(
                subject_type, [relation]
            ) and relation not in self.kg_validation_schema.get(obj_type, [relation]):
                continue

            # remove self-references
            if subject.lower() == obj.lower():
                continue

            subj_node = EntityNode(label=subject_type, name=subject)
            obj_node = EntityNode(label=obj_type, name=obj)
            rel_node = Relation(
                label=relation, source_id=subj_node.id, target_id=obj_node.id
            )
            valid_triplets.append((subj_node, rel_node, obj_node))

        return valid_triplets

    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triplets from a node."""
        assert hasattr(node, "text")

        text = node.get_content(metadata_mode="llm")
        try:
            kg_schema = await self.llm.astructured_predict(
                self.kg_schema_cls,
                self.extract_prompt,
                text=text,
                max_triplets_per_chunk=self.max_triplets_per_chunk,
            )
            triplets = self._prune_invalid_triplets(kg_schema)
        except ValueError:
            triplets = []

        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])

        metadata = node.metadata.copy()
        for subj, rel, obj in triplets:
            subj.properties = metadata
            obj.properties = metadata
            rel.properties = metadata

            existing_relations.append(rel)
            existing_nodes.append(subj)
            existing_nodes.append(obj)

        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations

        return node

    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triplets from nodes async."""
        jobs = []
        for node in nodes:
            jobs.append(self._aextract(node))
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text with schema",
        )

This class ensures that the extracted paths conform to the specified schema and handles the extraction process efficiently using asynchronous methods [2].
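To make the relation check in _prune_invalid_triplets concrete: a triplet is kept if its relation is allowed for either the subject type or the object type, an entity type missing from the validation schema accepts any relation (the .get(type, [relation]) default), and self-references are dropped. The sketch below restates that rule with plain strings; is_valid_triplet and the two-entry schema are hypothetical illustrations, not llama_index's DEFAULT_VALIDATION_SCHEMA.

```python
from typing import Dict, List


def is_valid_triplet(
    subject: str,
    subject_type: str,
    relation: str,
    obj: str,
    obj_type: str,
    validation_schema: Dict[str, List[str]],
) -> bool:
    """Restate the checks from _prune_invalid_triplets above."""
    # A type missing from the schema defaults to [relation], i.e. "anything goes".
    allowed_for_subject = validation_schema.get(subject_type, [relation])
    allowed_for_object = validation_schema.get(obj_type, [relation])
    if relation not in allowed_for_subject and relation not in allowed_for_object:
        return False
    # Drop self-references such as ("Paris", ..., "paris").
    return subject.lower() != obj.lower()


schema = {"PERSON": ["BORN_IN", "WORKED_ON"], "LOCATION": ["LOCATED_IN", "PART_OF"]}
print(is_valid_triplet("Ada", "PERSON", "BORN_IN", "London", "LOCATION", schema))  # True
print(is_valid_triplet("Ada", "PERSON", "PART_OF", "London", "LOCATION", schema))  # True (allowed for LOCATION)
print(is_valid_triplet("Ada", "PERSON", "HAS_ALIAS", "London", "LOCATION", schema))  # False
```

Because either side can "vouch" for the relation, this check is looser than requiring both entity types to allow it.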

By addressing these two issues, you should be able to resolve the problems with the SchemaLLMPathExtractor in TGI version 0.7.0.


logan-markewich commented 1 week ago

Hmm, the second issue is related to the format of the schema being sent 🤔 It will be tricky to debug without access to TGI myself, but I can try and take a look.

mphipps2 commented 1 week ago

Hi @logan-markewich

I've made some progress, but I don't understand what the kg_schema_cls in schema_llm.py is supposed to look like. I was leaving kg_schema_cls set to None in SchemaLLMPathExtractor:

kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    extract_prompt=DEFAULT_SCHEMA_PATH_EXTRACT_PROMPT,
    possible_entities=DEFAULT_ENTITIES,
    possible_relations=DEFAULT_RELATIONS,
    kg_validation_schema=DEFAULT_VALIDATION_SCHEMA,
    kg_schema_cls=None,
    num_workers=4,
    max_triplets_per_chunk=5,
    strict=True,
)

This triggers the "Build a pydantic model on the fly" path in schema_llm.py, which results in the following schema being sent to the model through TGI's async client (text_generation/client.py):

{
    "$functions": {
        "KGSchema": {
            "definitions": {
                "Entity": {
                    "properties": {
                        "name": {
                            "title": "Name",
                            "type": "string"
                        },
                        "type": {
                            "description": "Entity in a knowledge graph. Only extract entities with types that are listed as valid: typing.Literal['PRODUCT', 'MARKET', 'TECHNOLOGY', 'EVENT', 'CONCEPT', 'ORGANIZATION', 'PERSON', 'LOCATION', 'TIME', 'MISCELLANEOUS']",
                            "enum": [
                                "PRODUCT",
                                "MARKET",
                                "TECHNOLOGY",
                                "EVENT",
                                "CONCEPT",
                                "ORGANIZATION",
                                "PERSON",
                                "LOCATION",
                                "TIME",
                                "MISCELLANEOUS"
                            ],
                            "title": "Type",
                            "type": "string"
                        }
                    },
                    "required": ["type", "name"],
                    "title": "Entity",
                    "type": "object"
                },
                "Relation": {
                    "properties": {
                        "type": {
                            "description": "Relation in a knowledge graph. Only extract relations with types that are listed as valid: typing.Literal['USED_BY', 'USED_FOR', 'LOCATED_IN', 'PART_OF', 'WORKED_ON', 'HAS', 'IS_A', 'BORN_IN', 'DIED_IN', 'HAS_ALIAS']",
                            "enum": [
                                "USED_BY",
                                "USED_FOR",
                                "LOCATED_IN",
                                "PART_OF",
                                "WORKED_ON",
                                "HAS",
                                "IS_A",
                                "BORN_IN",
                                "DIED_IN",
                                "HAS_ALIAS"
                            ],
                            "title": "Type",
                            "type": "string"
                        }
                    },
                    "required": ["type"],
                    "title": "Relation",
                    "type": "object"
                },
                "Triplet": {
                    "properties": {
                        "object": {
                            "$ref": "#/definitions/Entity"
                        },
                        "relation": {
                            "$ref": "#/definitions/Relation"
                        },
                        "subject": {
                            "$ref": "#/definitions/Entity"
                        }
                    },
                    "required": ["subject", "relation", "object"],
                    "title": "Triplet",
                    "type": "object"
                }
            },
            "description": "Knowledge Graph Schema.",
            "properties": {
                "_name": {
                    "const": "KGSchema",
                    "type": "string"
                },
                "triplets": {
                    "items": {
                        "$ref": "#/definitions/Triplet"
                    },
                    "title": "Triplets",
                    "type": "array"
                }
            },
            "required": ["triplets", "_name"],
            "type": "object"
        },
        "notify_error": {
            "properties": {
                "_name": {
                    "const": "notify_error",
                    "type": "string"
                },
                "error": {
                    "description": "The error or issue to notify",
                    "type": "string"
                }
            },
            "required": ["error", "_name"],
            "type": "object"
        }
    },
    "properties": {
        "function": {
            "anyOf": [
                {
                    "$ref": "#/$functions/KGSchema"
                },
                {
                    "$ref": "#/$functions/notify_error"
                }
            ]
        }
    }
}

But then that results in the Issue 2 from the original post:

src.text_generation.errors.GenerationError: Request failed during generation: Server error: '/definitions/Triplet' does not exist within {'$functions': {'KGSchema': {'definitions': {'Entity': {'properties': {'name': {'title': ...

As you can see above, '/definitions/Triplet' is there, but the schema sent by the TGI client has the additional $functions/KGSchema level above it. If I change the $ref in the original kg_schema_cls to "#/$functions/KGSchema/definitions/Triplet", I get past this issue and get a model response back. However, at that point I have an issue with the parameters not being returned (likely because my updated schema is not formatted correctly):

KeyError: "'parameters' key not found in tool_call['function']:  

Can you give any guidance on how kg_schema_cls should look? Is the "Build a pydantic model on the fly" section of schema_llm.py failing to prefix the $refs with "#/$functions/KGSchema/", or is this a downstream unpacking issue in the TGI integration?
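A small sketch of why the relative $ref fails: in the usual case, a $ref of the form "#/..." is a JSON Pointer fragment (RFC 6901) resolved against the root of the document that contains it. In the schema above, "definitions" is nested under "$functions/KGSchema", so "#/definitions/Triplet" has nothing to point at once the tool schema is embedded in TGI's wrapper, while the absolute path resolves. resolve_pointer is a hypothetical helper written for illustration (it skips RFC 6901's ~0/~1 escaping), and the schema is trimmed to the relevant keys.

```python
from typing import Any


def resolve_pointer(document: dict, ref: str) -> Any:
    """Resolve a local "#/a/b/c" reference against the document root."""
    if not ref.startswith("#/"):
        raise ValueError(f"only local refs are supported: {ref}")
    node: Any = document
    for part in ref[2:].split("/"):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"'{ref[1:]}' does not exist")
        node = node[part]
    return node


# Trimmed version of the wrapper sent to TGI (see the JSON above).
wrapped = {
    "$functions": {
        "KGSchema": {
            "definitions": {"Triplet": {"title": "Triplet", "type": "object"}},
            "properties": {
                "triplets": {"items": {"$ref": "#/definitions/Triplet"}, "type": "array"}
            },
        }
    },
    "properties": {"function": {"anyOf": [{"$ref": "#/$functions/KGSchema"}]}},
}

# Resolved from the root, "#/definitions/Triplet" points at a key that is not there...
try:
    resolve_pointer(wrapped, "#/definitions/Triplet")
except KeyError as err:
    print("relative ref failed:", err)

# ...while the absolute path through the wrapper resolves.
print(resolve_pointer(wrapped, "#/$functions/KGSchema/definitions/Triplet")["title"])
```

With OpenAI the same pydantic schema is sent as the root of the tool's parameters, so "#/definitions/Triplet" does resolve there, which would be consistent with the AzureOpenAI observation below.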

mphipps2 commented 1 week ago

@logan-markewich an update: SchemaLLMPathExtractor works as expected with AzureOpenAI, so the schema bug is related to the tool-call setup in the HuggingFace/TGI integration.

mphipps2 commented 1 week ago

@logan-markewich I've got SchemaLLMPathExtractor working with TGI function calling now, but it took some hacking of the TGI schema:

1) The first bug from the original post: in huggingface/text_generation_inference.py, the self.achat() call inside achat_with_tools() needs to be awaited:

        response = await self.achat(
            messages=messages,
            tools=tool_specs,
            tool_choice=resolve_tool_choice(tool_specs, tool_choice),
            **kwargs,
        )

2) The second bug from the original post was that the $ref paths needed to be changed to absolute paths, for some reason. So instead of "$ref": "#/definitions/Triplet", I use "$ref": "#/$functions/KGSchema/definitions/Triplet". I don't fully understand why this is necessary, though, since the OpenAI schema uses "$ref": "#/definitions/Triplet". Anyway, this update is in core/tools/types.py:

@dataclass
class ToolMetadata:
    description: str
    name: Optional[str] = None
    fn_schema: Optional[Type[BaseModel]] = DefaultToolFnSchema
    return_direct: bool = False

    def get_parameters_dict(self) -> dict:

        if self.fn_schema is None:
            parameters = {
                "type": "object",
                "title": "KGSchema",
                "description": "Knowledge Graph Schema.",
                "properties": {
                    "input": {"title": "input query string", "type": "string"},
                },
                "required": ["input"],
            }
        else:
            parameters = self.fn_schema.schema()
            parameters = {
                k: v
                for k, v in parameters.items()
                if k in ["type", "properties", "required", "definitions"]
            }
            # Update $ref paths in properties
            if "properties" in parameters:
                for prop_key, prop_value in parameters["properties"].items():
                    if isinstance(prop_value, dict) and "$ref" in prop_value:
                        prop_value["$ref"] = prop_value["$ref"].replace(
                            "#/definitions", "#/$functions/KGSchema/definitions"
                        )
                    elif isinstance(prop_value, dict) and "items" in prop_value and "$ref" in prop_value["items"]:
                        prop_value["items"]["$ref"] = prop_value["items"]["$ref"].replace(
                            "#/definitions", "#/$functions/KGSchema/definitions"
                        )

            # Update $ref paths in definitions
            if "definitions" in parameters:
                for definition_key, definition_value in parameters["definitions"].items():
                    if "properties" in definition_value:
                        for prop_key, prop_value in definition_value["properties"].items():
                            if "$ref" in prop_value:
                                prop_value["$ref"] = prop_value["$ref"].replace(
                                    "#/definitions", "#/$functions/KGSchema/definitions"
                                )
            # Add title and description
            parameters["title"] = "KGSchema"
            parameters["description"] = "Knowledge Graph Schema."
        return parameters

3) After making these updates, response generation works, but there is one last issue when unpacking the response in huggingface/text_generation_inference.py. The code expects tool_call["function"]["parameters"] to exist, but for some reason the key comes back as "arguments". So I updated the function below to check for both:

    def get_tool_calls_from_response(
        self,
        response: "AgentChatResponse",
        error_on_no_tool_call: bool = True,
    ) -> List[ToolSelection]:
        """Predict and call the tool."""
        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
        if len(tool_calls) < 1:
            if error_on_no_tool_call:
                raise ValueError(
                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
                )
            else:
                return []

        tool_selections = []
        for tool_call in tool_calls:

            # TODO Add typecheck with ToolCall from TGI once the client is updated
            if tool_call and (tc_type := tool_call["type"]) != "function":
                raise ValueError(
                    f"Invalid tool type: got {tc_type}, expect 'function'."
                )
            # Accept either 'parameters' or 'arguments' (TGI returned the latter here)
            function_details = tool_call.get("function", {})
            argument_dict = function_details.get("parameters", function_details.get("arguments"))

            if argument_dict is None:
                raise KeyError(f"'parameters' or 'arguments' key not found in tool_call['function']: {tool_call}")

            tool_selections.append(
                ToolSelection(
                    tool_id=tool_call["id"],
                    tool_name=tool_call["function"][
                        "name"
                    ],  # NOTE for now the tool_name is hardcoded 'tools' in TGI
                    tool_kwargs=argument_dict,
                )
            )

        return tool_selections
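The parameters/arguments fallback in 3) can be factored into a small helper. The sketch below also decodes the arguments when they arrive as a JSON-encoded string, which is how OpenAI's chat API encodes function.arguments; whether TGI returns a string or an object here is an assumption to check against your server. extract_tool_kwargs is a hypothetical name, not part of the integration.

```python
import json
from typing import Any, Dict


def extract_tool_kwargs(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Return a tool call's arguments as a dict, accepting either key name."""
    function_details = tool_call.get("function") or {}
    args = function_details.get("parameters")
    if args is None:
        args = function_details.get("arguments")
    if args is None:
        raise KeyError(
            f"'parameters' or 'arguments' key not found in tool_call['function']: {tool_call}"
        )
    # OpenAI-style servers send arguments as a JSON-encoded string.
    if isinstance(args, str):
        args = json.loads(args)
    if not isinstance(args, dict):
        raise TypeError(f"tool arguments must be an object, got {type(args).__name__}")
    return args


print(extract_tool_kwargs({"function": {"name": "KGSchema", "arguments": '{"triplets": []}'}}))
# {'triplets': []}
```

Inside get_tool_calls_from_response, the argument_dict lookup would then become argument_dict = extract_tool_kwargs(tool_call).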

Issue 1 looks like an obvious bug. The fixes for issues 2 and 3 are hacks to make this work; there's still some underlying schema issue.
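On the underlying schema issue, here are two hedged sketches of more general workarounds; neither is part of llama_index or TGI, and both function names are hypothetical. rewrite_refs generalizes the hack in 2) by walking the whole schema and taking the tool name as a parameter instead of hardcoding KGSchema; it assumes the server nests each tool under "#/$functions/<tool name>", as the JSON in this thread shows. inline_definitions is a different technique: it replaces every local "#/definitions/..." reference with the definition's body, so the schema no longer depends on where the server embeds it. It cannot handle self-referencing models, and whether TGI's grammar copes as well with the larger inlined schema is untested here.

```python
from typing import Any, Dict

DEFS_PREFIX = "#/definitions/"


def rewrite_refs(schema: Any, tool_name: str) -> Any:
    """Copy `schema`, re-rooting "#/definitions/X" under "#/$functions/<tool_name>/definitions/X"."""
    new_prefix = f"#/$functions/{tool_name}/definitions/"
    if isinstance(schema, dict):
        return {
            key: (
                new_prefix + value[len(DEFS_PREFIX):]
                if key == "$ref" and isinstance(value, str) and value.startswith(DEFS_PREFIX)
                else rewrite_refs(value, tool_name)
            )
            for key, value in schema.items()
        }
    if isinstance(schema, list):
        return [rewrite_refs(item, tool_name) for item in schema]
    return schema


def inline_definitions(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Copy `schema` with every "#/definitions/X" ref replaced by the body of X."""
    definitions = schema.get("definitions", {})

    def walk(node: Any, seen: tuple) -> Any:
        if isinstance(node, dict):
            ref = node.get("$ref")
            if isinstance(ref, str) and ref.startswith(DEFS_PREFIX):
                name = ref[len(DEFS_PREFIX):]
                if name in seen:
                    raise ValueError(f"self-referencing definition cannot be inlined: {name}")
                body = walk(definitions[name], seen + (name,))
                # Keep sibling keywords such as "description" next to the inlined body.
                siblings = {k: walk(v, seen) for k, v in node.items() if k != "$ref"}
                return {**body, **siblings}
            return {k: walk(v, seen) for k, v in node.items()}
        if isinstance(node, list):
            return [walk(item, seen) for item in node]
        return node

    # Drop the top-level "definitions" block; the inlined refs no longer need it.
    return walk({k: v for k, v in schema.items() if k != "definitions"}, ())


# Trimmed version of the KGSchema parameters from this thread.
parameters = {
    "type": "object",
    "properties": {"triplets": {"type": "array", "items": {"$ref": "#/definitions/Triplet"}}},
    "definitions": {
        "Triplet": {"type": "object", "properties": {"subject": {"$ref": "#/definitions/Entity"}}},
        "Entity": {"type": "object", "properties": {"name": {"type": "string"}}},
    },
}

print(rewrite_refs(parameters, "KGSchema")["properties"]["triplets"]["items"]["$ref"])
# #/$functions/KGSchema/definitions/Triplet
flat = inline_definitions(parameters)
print(flat["properties"]["triplets"]["items"]["properties"]["subject"])
# {'type': 'object', 'properties': {'name': {'type': 'string'}}}
```

Both helpers return new dicts and leave the input schema untouched, so they could be applied to the output of fn_schema.schema() without mutating anything shared.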