Open mphipps2 opened 1 week ago
To resolve the issues with the SchemaLLMPathExtractor in TGI version 0.7.0, you need to address both the missing "await" in the achat_with_tools() function and the GenerationError related to the on-the-fly pydantic schema.
achat_with_tools()
You need to ensure that the achat_with_tools() function is properly awaited. Here is the corrected code, shown in apredict_and_call(), where achat_with_tools() is awaited:
async def apredict_and_call(
    self,
    tools: List["BaseTool"],
    user_msg: Optional[Union[str, ChatMessage]] = None,
    chat_history: Optional[List[ChatMessage]] = None,
    verbose: bool = False,
    allow_parallel_tool_calls: bool = False,
    **kwargs: Any,
) -> "AgentChatResponse":
    """Predict and call the tool."""
    from llama_index.core.tools.calling import acall_tool_with_selection
    from llama_index.core.chat_engine.types import AgentChatResponse
    if not self.metadata.is_function_calling_model:
        return await super().apredict_and_call(
            tools,
            user_msg=user_msg,
            chat_history=chat_history,
            verbose=verbose,
            **kwargs,
        )
    response = await self.achat_with_tools(
        tools,
        user_msg,
        chat_history,
        verbose=verbose,
        allow_parallel_tool_calls=allow_parallel_tool_calls,
        **kwargs,
    )
    tool_calls = self.get_tool_calls_from_response(response)
    tool_tasks = [
        acall_tool_with_selection(tool_call, tools, verbose=verbose)
        for tool_call in tool_calls
    ]
    tool_outputs = await asyncio.gather(*tool_tasks)
    if allow_parallel_tool_calls:
        output_text = "\n\n".join(
            [tool_output.content for tool_output in tool_outputs]
        )
        return AgentChatResponse(response=output_text, sources=tool_outputs)
    else:
        if len(tool_outputs) > 1:
            raise ValueError("Invalid")
        return AgentChatResponse(
            response=tool_outputs[0].content, sources=tool_outputs
        )
This code ensures that the achat_with_tools() function is properly awaited, which should resolve the missing "await" issue [1].
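To illustrate why a missing "await" produces empty-looking responses, here is a minimal, self-contained sketch. The function names (fetch_reply, without_await, with_await) are hypothetical stand-ins and are not part of llama-index or TGI; the sketch only shows the general asyncio behavior.

```python
import asyncio
import inspect


async def fetch_reply() -> str:
    """Hypothetical stand-in for an async LLM call such as achat()."""
    await asyncio.sleep(0)
    return "tool call payload"


async def without_await():
    # Bug pattern: the coroutine object is returned but never run.
    return fetch_reply()


async def with_await() -> str:
    # Fixed pattern: the coroutine runs and its result is returned.
    return await fetch_reply()


async def main():
    pending = await without_await()
    is_coro = inspect.iscoroutine(pending)
    pending.close()  # suppress the "coroutine was never awaited" warning
    result = await with_await()
    return is_coro, result


is_coro, result = asyncio.run(main())
print(is_coro, result)
```

Without the await, the caller receives a coroutine object rather than a chat response, so any code that then looks for tool calls on it finds nothing useful.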
GenerationError Related to Pydantic Schema
The GenerationError is likely due to the on-the-fly pydantic schema generation in the SchemaLLMPathExtractor. Ensure that the schema is correctly defined and used. Here is an example of how the SchemaLLMPathExtractor handles schema validation and extraction:
class SchemaLLMPathExtractor(TransformComponent):
    # ... (other parts of the class)
    def _prune_invalid_triplets(self, kg_schema: Any) -> List[Triplet]:
        """Prune invalid triplets."""
        assert isinstance(kg_schema, self.kg_schema_cls)
        valid_triplets = []
        for triplet in kg_schema.triplets:
            subject = triplet.subject.name
            subject_type = triplet.subject.type
            relation = triplet.relation.type
            obj = triplet.object.name
            obj_type = triplet.object.type
            # check relations
            if relation not in self.kg_validation_schema.get(
                subject_type, [relation]
            ) and relation not in self.kg_validation_schema.get(obj_type, [relation]):
                continue
            # remove self-references
            if subject.lower() == obj.lower():
                continue
            subj_node = EntityNode(label=subject_type, name=subject)
            obj_node = EntityNode(label=obj_type, name=obj)
            rel_node = Relation(
                label=relation, source_id=subj_node.id, target_id=obj_node.id
            )
            valid_triplets.append((subj_node, rel_node, obj_node))
        return valid_triplets
    async def _aextract(self, node: BaseNode) -> BaseNode:
        """Extract triplets from a node."""
        assert hasattr(node, "text")
        text = node.get_content(metadata_mode="llm")
        try:
            kg_schema = await self.llm.astructured_predict(
                self.kg_schema_cls,
                self.extract_prompt,
                text=text,
                max_triplets_per_chunk=self.max_triplets_per_chunk,
            )
            triplets = self._prune_invalid_triplets(kg_schema)
        except ValueError:
            triplets = []
        existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
        existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
        metadata = node.metadata.copy()
        for subj, rel, obj in triplets:
            subj.properties = metadata
            obj.properties = metadata
            rel.properties = metadata
            existing_relations.append(rel)
            existing_nodes.append(subj)
            existing_nodes.append(obj)
        node.metadata[KG_NODES_KEY] = existing_nodes
        node.metadata[KG_RELATIONS_KEY] = existing_relations
        return node
    async def acall(
        self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Extract triplets from nodes async."""
        jobs = []
        for node in nodes:
            jobs.append(self._aextract(node))
        return await run_jobs(
            jobs,
            workers=self.num_workers,
            show_progress=show_progress,
            desc="Extracting paths from text with schema",
        )
This class ensures that the extracted paths conform to the specified schema and handles the extraction process efficiently using asynchronous methods [2].
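The pruning rule in _prune_invalid_triplets can be tried in isolation. The sketch below mirrors only the two checks from the excerpt (allowed relation for the subject or object type, and no self-references) on plain tuples; prune_triplets and the sample data are hypothetical and do not use llama-index classes.

```python
from typing import Dict, List, Tuple

# (subject, subject_type, relation, object, object_type)
RawTriplet = Tuple[str, str, str, str, str]


def prune_triplets(
    triplets: List[RawTriplet], validation_schema: Dict[str, List[str]]
) -> List[RawTriplet]:
    """Keep triplets whose relation is allowed and that are not self-references."""
    kept = []
    for subj, subj_type, rel, obj, obj_type in triplets:
        # A type missing from the schema defaults to [rel], i.e. it allows anything.
        allowed_for_subj = validation_schema.get(subj_type, [rel])
        allowed_for_obj = validation_schema.get(obj_type, [rel])
        if rel not in allowed_for_subj and rel not in allowed_for_obj:
            continue
        # Drop self-references, compared case-insensitively.
        if subj.lower() == obj.lower():
            continue
        kept.append((subj, subj_type, rel, obj, obj_type))
    return kept


schema = {"PERSON": ["BORN_IN", "WORKED_ON"], "LOCATION": ["LOCATED_IN"]}
candidates = [
    ("Ada", "PERSON", "BORN_IN", "London", "LOCATION"),
    ("Ada", "PERSON", "USED_FOR", "London", "LOCATION"),
    ("London", "LOCATION", "LOCATED_IN", "london", "LOCATION"),
    ("Engine", "PRODUCT", "USED_FOR", "Math", "CONCEPT"),
]
kept = prune_triplets(candidates, schema)
print(kept)
```

Note that a triplet is only dropped when the relation is disallowed for both the subject type and the object type, and that types absent from the validation schema are not restricted at all.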
By addressing these two issues, you should be able to resolve the problems with the SchemaLLMPathExtractor in TGI version 0.7.0.
Hmm, the second issue is something related to the format of the schema being sent 🤔 will be tricky to debug without access to TGI myself, but I can try and take a look
Hi @logan-markewich
I've made some progress but don't understand how the kg_schema_cls in schema_llm.py is supposed to look. By default I was setting the kg_schema_cls to None in the SchemaLLMPathExtractor:
kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    extract_prompt=DEFAULT_SCHEMA_PATH_EXTRACT_PROMPT,
    possible_entities=DEFAULT_ENTITIES,
    possible_relations=DEFAULT_RELATIONS,
    kg_validation_schema=DEFAULT_VALIDATION_SCHEMA,
    kg_schema_cls=None,
    num_workers=4,
    max_triplets_per_chunk=5,
    strict=True,
)
So this triggers the on-the-fly pydantic model in schema_llm.py. This then results in the following schema being sent to the model through TGI's async client (text_generation/client.py):
{
  "$functions": {
    "KGSchema": {
      "definitions": {
        "Entity": {
          "properties": {
            "name": {
              "title": "Name",
              "type": "string"
            },
            "type": {
              "description": "Entity in a knowledge graph. Only extract entities with types that are listed as valid: typing.Literal['PRODUCT', 'MARKET', 'TECHNOLOGY', 'EVENT', 'CONCEPT', 'ORGANIZATION', 'PERSON', 'LOCATION', 'TIME', 'MISCELLANEOUS']",
              "enum": [
                "PRODUCT",
                "MARKET",
                "TECHNOLOGY",
                "EVENT",
                "CONCEPT",
                "ORGANIZATION",
                "PERSON",
                "LOCATION",
                "TIME",
                "MISCELLANEOUS"
              ],
              "title": "Type",
              "type": "string"
            }
          },
          "required": ["type", "name"],
          "title": "Entity",
          "type": "object"
        },
        "Relation": {
          "properties": {
            "type": {
              "description": "Relation in a knowledge graph. Only extract relations with types that are listed as valid: typing.Literal['USED_BY', 'USED_FOR', 'LOCATED_IN', 'PART_OF', 'WORKED_ON', 'HAS', 'IS_A', 'BORN_IN', 'DIED_IN', 'HAS_ALIAS']",
              "enum": [
                "USED_BY",
                "USED_FOR",
                "LOCATED_IN",
                "PART_OF",
                "WORKED_ON",
                "HAS",
                "IS_A",
                "BORN_IN",
                "DIED_IN",
                "HAS_ALIAS"
              ],
              "title": "Type",
              "type": "string"
            }
          },
          "required": ["type"],
          "title": "Relation",
          "type": "object"
        },
        "Triplet": {
          "properties": {
            "object": {
              "$ref": "#/definitions/Entity"
            },
            "relation": {
              "$ref": "#/definitions/Relation"
            },
            "subject": {
              "$ref": "#/definitions/Entity"
            }
          },
          "required": ["subject", "relation", "object"],
          "title": "Triplet",
          "type": "object"
        }
      },
      "description": "Knowledge Graph Schema.",
      "properties": {
        "_name": {
          "const": "KGSchema",
          "type": "string"
        },
        "triplets": {
          "items": {
            "$ref": "#/definitions/Triplet"
          },
          "title": "Triplets",
          "type": "array"
        }
      },
      "required": ["triplets", "_name"],
      "type": "object"
    },
    "notify_error": {
      "properties": {
        "_name": {
          "const": "notify_error",
          "type": "string"
        },
        "error": {
          "description": "The error or issue to notify",
          "type": "string"
        }
      },
      "required": ["error", "_name"],
      "type": "object"
    }
  },
  "properties": {
    "function": {
      "anyOf": [
        {
          "$ref": "#/$functions/KGSchema"
        },
        {
          "$ref": "#/$functions/notify_error"
        }
      ]
    }
  }
}
But that then results in Issue 2 from the original post:
src.text_generation.errors.GenerationError: Request failed during generation: Server error: '/definitions/Triplet' does not exist within {'$functions': {'KGSchema': {'definitions': {'Entity': {'properties': {'name': {'title': ...
As you can see above, '/definitions/Triplet' is there, but the schema being sent by the TGI client has the additional /$functions/KGSchema hierarchy above it. If I edit the $ref in the original kg_schema_cls to "#/$functions/KGSchema/definitions/Triplet", I get past this issue and get a model response back. At that point, however, the parameters are not returned (likely because my updated schema is not formatted correctly):
KeyError: "'parameters' key not found in tool_call['function']:
Any help you can give in terms of how the kg_schema_cls should look? Is there an issue with the "Build a pydantic model on the fly" section of schema_llm.py not initializing the $refs with the "#/$functions/KGSchema/" prefix or is this a downstream unpacking issue in the TGI integration?
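For context on why the relative path fails: a "$ref" beginning with "#/" is a JSON Pointer that resolves from the root of the whole document. The sketch below is a minimal resolver written for illustration (resolve_pointer and can_resolve are hypothetical helpers, not TGI or llama-index code), and it assumes the server resolves references against the wrapped document, which is what the error message suggests.

```python
from typing import Any


def resolve_pointer(document: dict, ref: str) -> Any:
    """Resolve a local reference such as '#/definitions/Triplet' from the document root."""
    if not ref.startswith("#/"):
        raise ValueError(f"Only local '#/...' references are handled: {ref}")
    node: Any = document
    for part in ref[2:].split("/"):
        # JSON Pointer escapes: "~1" stands for "/", "~0" stands for "~".
        part = part.replace("~1", "/").replace("~0", "~")
        node = node[part]
    return node


def can_resolve(document: dict, ref: str) -> bool:
    try:
        resolve_pointer(document, ref)
        return True
    except KeyError:
        return False


# Same shape as the payload above, trimmed to the relevant keys.
wrapped = {
    "$functions": {
        "KGSchema": {
            "definitions": {"Triplet": {"title": "Triplet", "type": "object"}}
        }
    }
}

relative_ok = can_resolve(wrapped, "#/definitions/Triplet")
absolute_ok = can_resolve(wrapped, "#/$functions/KGSchema/definitions/Triplet")
print(relative_ok, absolute_ok)
```

Under that assumption, the references are valid inside the standalone KGSchema but stop resolving once the schema is nested under "$functions", which would explain why only the TGI path is affected.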
@logan-markewich an update: the SchemaLLMPathExtractor works like it's supposed to with AzureOpenAI, so the schema bug is related to the tool call initialization in the Huggingface/TGI integration
@logan-markewich I'm able to get SchemaLLMPathExtractor working now with TGI function calling but it took some hacking to the TGI schema:
1) The first bug from the original post: the self.achat() call inside achat_with_tools() in huggingface/text_generation_inference.py needs to be awaited:
response = await self.achat(
    messages=messages,
    tools=tool_specs,
    tool_choice=resolve_tool_choice(tool_specs, tool_choice),
    **kwargs,
)
2) The second bug from the original post was that the $ref paths needed to be updated to absolute paths for some reason. So instead of "$ref": "#/definitions/Triplet", I use "$ref": "#/$functions/KGSchema/definitions/Triplet". I don't fully understand why this is necessary, though, since the OpenAI schema uses "$ref": "#/definitions/Triplet". Anyway, this update is in core/tools/types.py:
@dataclass
class ToolMetadata:
    description: str
    name: Optional[str] = None
    fn_schema: Optional[Type[BaseModel]] = DefaultToolFnSchema
    return_direct: bool = False
    def get_parameters_dict(self) -> dict:
        if self.fn_schema is None:
            print('fn_schema none')
            parameters = {
                "type": "object",
                "title": "KGSchema",
                "description": "Knowledge Graph Schema.",
                "properties": {
                    "input": {"title": "input query string", "type": "string"},
                },
                "required": ["input"],
            }
        else:
            parameters = self.fn_schema.schema()
            parameters = {
                k: v
                for k, v in parameters.items()
                if k in ["type", "properties", "required", "definitions"]
            }
        # Update $ref paths in properties
        if "properties" in parameters:
            for prop_key, prop_value in parameters["properties"].items():
                if isinstance(prop_value, dict) and "$ref" in prop_value:
                    prop_value["$ref"] = prop_value["$ref"].replace(
                        "#/definitions", "#/$functions/KGSchema/definitions"
                    )
                elif isinstance(prop_value, dict) and "items" in prop_value and "$ref" in prop_value["items"]:
                    prop_value["items"]["$ref"] = prop_value["items"]["$ref"].replace(
                        "#/definitions", "#/$functions/KGSchema/definitions"
                    )
        # Update $ref paths in definitions
        if "definitions" in parameters:
            for definition_key, definition_value in parameters["definitions"].items():
                if "properties" in definition_value:
                    for prop_key, prop_value in definition_value["properties"].items():
                        if "$ref" in prop_value:
                            prop_value["$ref"] = prop_value["$ref"].replace(
                                "#/definitions", "#/$functions/KGSchema/definitions"
                            )
        # Add title and description
        parameters["title"] = "KGSchema"
        parameters["description"] = "Knowledge Graph Schema."
        return parameters
3) After making these updates, the response generation works but we get one last issue in unpacking the response in huggingface/text_generation_inference.py. The problem here is that we expect tool_call["parameters"] to exist. However for some reason it comes out as tool_call["arguments"]. So I update the function below to check for both:
def get_tool_calls_from_response(
    self,
    response: "AgentChatResponse",
    error_on_no_tool_call: bool = True,
) -> List[ToolSelection]:
    """Predict and call the tool."""
    tool_calls = response.message.additional_kwargs.get("tool_calls", [])
    print('response: ', response)
    print('tool_calls: ', tool_calls)
    print('message: ', response.message)
    if len(tool_calls) < 1:
        if error_on_no_tool_call:
            raise ValueError(
                f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
            )
        else:
            return []
    tool_selections = []
    for tool_call in tool_calls:
        # TODO Add typecheck with ToolCall from TGI once the client is updated
        if tool_call and (tc_type := tool_call["type"]) != "function":
            raise ValueError(
                f"Invalid tool type: got {tc_type}, expect 'function'."
            )
        # Ensure the 'parameters' key exists
        function_details = tool_call.get("function", {})
        argument_dict = function_details.get("parameters", function_details.get("arguments"))
        if argument_dict is None:
            raise KeyError(f"'parameters' or 'arguments' key not found in tool_call['function']: {tool_call}")
        tool_selections.append(
            ToolSelection(
                tool_id=tool_call["id"],
                tool_name=tool_call["function"][
                    "name"
                ],  # NOTE for now the tool_name is hardcoded 'tools' in TGI
                tool_kwargs=argument_dict,
            )
        )
    return tool_selections
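The key fallback from step 3 can be pulled out into a small helper for testing on its own. This is only a sketch: extract_tool_kwargs is a hypothetical name, and the JSON-string branch is an assumption added because OpenAI-style responses encode arguments as a JSON string; the TGI payload described here already arrives as a dict.

```python
import json
from typing import Any, Dict


def extract_tool_kwargs(tool_call: Dict[str, Any]) -> Dict[str, Any]:
    """Return the tool arguments stored under 'parameters' or 'arguments'."""
    function = tool_call.get("function") or {}
    for key in ("parameters", "arguments"):
        value = function.get(key)
        if value is None:
            continue
        # Assumption: some servers send the arguments as a JSON-encoded string.
        if isinstance(value, str):
            value = json.loads(value)
        return value
    raise KeyError(
        f"'parameters' or 'arguments' key not found in tool_call['function']: {tool_call}"
    )


from_dict = extract_tool_kwargs(
    {"id": 0, "type": "function", "function": {"name": "tools", "arguments": {"triplets": []}}}
)
from_string = extract_tool_kwargs(
    {"id": 1, "type": "function", "function": {"name": "tools", "arguments": '{"triplets": []}'}}
)
print(from_dict, from_string)
```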
Issue 1 looks like an obvious bug. Issues 2 and 3 are hacks to make this work but there's some underlying schema issue
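If the $ref rewrite from step 2 is kept as a workaround, a recursive walk avoids hard-coding how deep the references sit (for example inside "items" or "anyOf"). The helper below is a sketch only: prefix_refs is a hypothetical name, and the prefix strings are the ones used in this thread rather than anything defined by TGI.

```python
import copy
from typing import Any


def prefix_refs(schema: Any, old: str, new: str) -> Any:
    """Return a deep copy in which every "$ref" starting with `old` starts with `new`."""

    def walk(node: Any) -> None:
        if isinstance(node, dict):
            for key, value in node.items():
                if key == "$ref" and isinstance(value, str) and value.startswith(old):
                    # Replacing the value of an existing key is safe while iterating.
                    node[key] = new + value[len(old):]
                else:
                    walk(value)
        elif isinstance(node, list):
            for item in node:
                walk(item)

    result = copy.deepcopy(schema)
    walk(result)
    return result


schema = {
    "properties": {
        "triplets": {"type": "array", "items": {"$ref": "#/definitions/Triplet"}}
    },
    "definitions": {
        "Triplet": {
            "properties": {
                "subject": {"$ref": "#/definitions/Entity"},
                "note": {"anyOf": [{"$ref": "#/definitions/Entity"}, {"type": "null"}]},
            }
        }
    },
}
rewritten = prefix_refs(schema, "#/definitions", "#/$functions/KGSchema/definitions")
print(rewritten["properties"]["triplets"]["items"]["$ref"])
```

Working on a deep copy leaves the original pydantic schema dict untouched, so the same model can still be used unchanged with backends that accept the relative paths.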
Bug Description
I'm using TGI (newest version 0.7.0) to host locally, and I'm trying to run the PropertyGraphIndex routines. SimpleLLMPathExtractor and ImplicitPathExtractor seem to be fine, but I'm having issues with SchemaLLMPathExtractor.
Issue 1) in llms/huggingface/TextGenerationInference.py on line 1044 in the achat_with_tools() function, the response is called like:
I believe there should be an "await" in front of that achat call. Without the "await" I end up with empty responses and the error shown below for issue 1. After adding the await I get past this point
Issue 2) after fixing the first issue, I end up with the issue 2 error shown below, which I think is related to the on the fly pydantic schema defined in the SchemaLLMPathExtractor? But maybe it's related to the TGI integration as well
Any idea what's going on here?
Version
0.10.45.post1 and for llama-index-llms-huggingface: 0.2.3
Steps to Reproduce
Relevant Logs/Tracebacks