Open jason571 opened 2 months ago
This feels like it could be an issue with the `@tool` decorator being applied to a method. If so, it's a langchain-core issue.
But in the code snippet that you shared, why would this work?
result = manager.run_query(query)
I don't see `run_query` defined anywhere.
class SQLToolManager:
    """Wrap the SQLDatabaseToolkit tools and a custom query tool behind helper methods.

    ``__init__`` creates the current chat model, the toolkit-provided tools, and
    direct handles to the ``sql_db_list_tables`` and ``sql_db_schema`` tools.
    """

    def __init__(self):
        self.interface = Interface()
        self.db_handler = SQLiteHandler()
        self.llm = self.interface.get_current_model()
        self.toolkit = SQLDatabaseToolkit(
            db=self.db_handler.get_sql_database(), llm=self.llm
        )
        self.tools = self.toolkit.get_tools()
        # The loop variable is named ``t`` so it does not shadow the ``tool``
        # decorator used further down in this class.
        self.list_tables_tool = next(
            t for t in self.tools if t.name == "sql_db_list_tables"
        )
        self.get_schema_tool = next(
            t for t in self.tools if t.name == "sql_db_schema"
        )

    def create_tool_node_with_fallback(self, tools: list) -> RunnableWithFallbacks[Any, dict]:
        """
        Create a ToolNode with a fallback to handle errors and surface them to the agent.
        """
        return ToolNode(tools).with_fallbacks(
            [RunnableLambda(self.handle_tool_error)], exception_key="error"
        )

    def handle_tool_error(self, state) -> dict:
        """Turn the error stored in ``state`` into one ToolMessage per pending tool call.

        Reads ``state["error"]`` (may be absent) and the ``tool_calls`` of the
        last entry in ``state["messages"]``; returns a ``{"messages": [...]}`` dict.
        """
        error = state.get("error")
        tool_calls = state["messages"][-1].tool_calls
        return {
            "messages": [
                ToolMessage(
                    content=f"Error: {repr(error)}\n please fix your mistakes.",
                    tool_call_id=tc["id"],
                )
                for tc in tool_calls
            ]
        }

    # NOTE: ``@tool`` is applied directly to an instance method here. This is
    # the case the report is about and is kept as the reproduction: the
    # generated input schema contains ``self`` as a required field, so
    # ``db_query_tool.invoke({"query": ...})`` fails validation with
    # "self Field required".
    @tool
    def db_query_tool(self, query: str) -> str:
        """
        Execute a SQL query against the database and get back the result.
        If the query is not correct, an error message will be returned.
        """
        mylogging.info(f"Executing query: {query}")
        result = self.db_handler.db.run_no_throw(query)
        if not result:
            return "Error: Query failed. Please rewrite your query and try again."
        return result

    def create_query_check(self):
        """Build the prompt | model chain that reviews a query before it is run.

        The model is bound to ``db_query_tool`` with ``tool_choice="required"``.
        """
        query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
You will call the appropriate tool to execute the query after running this check."""
        query_check_prompt = ChatPromptTemplate.from_messages(
            [("system", query_check_system), ("placeholder", "{messages}")]
        )
        return query_check_prompt | self.llm.bind_tools([self.db_query_tool], tool_choice="required")

    def list_tables(self):
        """Invoke the ``sql_db_list_tables`` tool with an empty input."""
        return self.list_tables_tool.invoke("")

    def get_schema(self, table_name: str):
        """Invoke the ``sql_db_schema`` tool for ``table_name``."""
        return self.get_schema_tool.invoke(table_name)

    def run_query(self, query: str):
        """Invoke ``db_query_tool`` with ``{"query": query}``."""
        return self.db_query_tool.invoke({"query": query})

    def check_and_run_query(self, query: str):
        """Send ``query`` as a user message through the query-check chain."""
        # ``self.query_check`` is never assigned anywhere in this class, so the
        # chain is built on demand instead of reading a missing attribute.
        return self.create_query_check().invoke({"messages": [("user", query)]})
# Script entry point: exercise each helper of SQLToolManager in turn.
if __name__ == "__main__":
    manager = SQLToolManager()
    print(manager.list_tables())
    print(manager.get_schema_tool.invoke("Artist"))

    # Call db_query_tool directly through the tool's invoke() interface.
    input_query = "SELECT * FROM Artist LIMIT 10;"
    result = manager.db_query_tool.invoke({"query": input_query})
    print(result)

    # Same query through the run_query() wrapper.
    result = manager.run_query(input_query)
    print(result)

    # Check and run the query.
    result = manager.check_and_run_query(input_query)
    print(result)
File "/mnt/c/workspace/pr_train/LLMs/src/sqlAgent/sqlTools.py", line 145, in <module>
result = manager.db_query_tool.invoke({"query": input_query})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 489, in invoke
return self.run(tool_input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 692, in run
raise error_to_raise
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 655, in run
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 578, in _to_args_and_kwargs
tool_input = self._parse_input(tool_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 520, in _parse_input
result = input_args.model_validate(tool_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/pydantic/main.py", line 595, in model_validate
return cls.__pydantic_validator__.validate_python(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic_core._pydantic_core.ValidationError: 1 validation error for db_query_tool
self
  Field required [type=missing, input_value={'query': 'SELECT * FROM Artist LIMIT 10;'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.9/v/missing
OS: Linux
OS Version: #3672-Microsoft Fri Jan 01 08:00:00 PST 2016
Python Version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]
langchain_core: 0.3.0
langchain: 0.3.0
langchain_community: 0.3.0
langsmith: 0.1.128
langchain_cli: 0.0.31
langchain_cohere: 0.3.0
langchain_experimental: 0.3.0
langchain_google_community: 2.0.0
langchain_huggingface: 0.1.0
langchain_milvus: 0.1.5
langchain_openai: 0.2.0
langchain_text_splitters: 0.3.0
langgraph: 0.2.22
langserve: 0.2.2
input_query = "SELECT * FROM Artist LIMIT 10;"
result = manager.db_query_tool({"query": input_query})
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/pydantic/main.py", line 595, in model_validate
return cls.__pydantic_validator__.validate_python(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic_core._pydantic_core.ValidationError: 1 validation error for db_query_tool
self
  Field required [type=missing, input_value={'query': 'SELECT * FROM Artist LIMIT 10;'}, input_type=dict]
# Reproduction variant: call the tool object directly with the bare SQL string
# rather than going through .invoke() with a dict (that form is commented out below).
input_query = "SELECT * FROM Artist LIMIT 10;"
result = manager.db_query_tool(input_query)
#result = manager.db_query_tool.invoke({"query": input_query})
print(result)
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/langchain_core/tools/base.py", line 513, in _parse_input
input_args.model_validate({key_: tool_input})
File "/home/flyang/anaconda3/envs/LLMs/lib/python3.11/site-packages/pydantic/main.py", line 595, in model_validate
return cls.__pydantic_validator__.validate_python(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
pydantic_core._pydantic_core.ValidationError: 1 validation error for db_query_tool
query
  Field required [type=missing, input_value={'self': 'SELECT * FROM Artist LIMIT 10;'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.9/v/missing
The `@tool` decorator is not currently supported directly on methods (the parent instance is not bound at the time of decoration).
I believe something like this works:
@property
def db_query_tool(self):
    # Workaround for using ``@tool`` inside a class: decorate a plain nested
    # function whose only parameter is ``query`` and let it reach the instance
    # through the closure over ``self``. Each access to this property defines
    # and returns a new decorated function.
    # NOTE(review): the tool is presumably registered under the inner
    # function's name (``query_db``) rather than ``db_query_tool`` - confirm
    # that is acceptable for callers that look tools up by name.
    @tool
    def query_db(query: str):
        """
        Execute a SQL query against the database and get back the result.
        If the query is not correct, an error message will be returned.
        """
        mylogging.info(f"Executing query: {query}")
        result = self.db_handler.db.run_no_throw(query)
        if not result:
            return "Error: Query failed. Please rewrite your query and try again."
        return result

    return query_db
@jason571 did the suggestion above resolve your issue?
Checked other resources
Example Code
Error Message and Stack Trace (if applicable)
Description
Example Code https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/sql-agent.ipynb
System Info
System Information
Package Information
Other Dependencies