langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

SQL Agent extracts the table name with a \n line break and the next-line word 'Observation' #23585

Open kbatsuren opened 2 months ago

kbatsuren commented 2 months ago

Checked other resources

Example Code

from langchain.agents import AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.utilities import SQLDatabase
from langchain_huggingface import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    endpoint_url="endpoint_url",
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
)

db = SQLDatabase.from_uri("sqlite:///Chinook.db?isolation_level=IMMEDIATE")

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
agent_executor.invoke("How many genres are there?")

Error Message and Stack Trace (if applicable)

Entering new SQL Agent Executor chain...
I need to know the table that contains the genres.
Action: sql_db_list_tables
Action Input:
ObservationAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Now I know the table that contains the genres is Genre.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made a mistake, I should remove the Observation part.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n and the space.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n and the space and the error message.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n and the space and the error message and the table_names.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n and the space and the error message and the table_names and the Observation.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n and the space and the error message and the table_names and the Observation and the Error.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n and the space and the error message and the table_names and the Observation and the Error and the colon.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database
I made another mistake, I should remove the newline character and the Observation part and the curly brackets and the single quotes and the \n and the space and the error message and the table_names and the Observation and the Error and the colon and the table_names.
Action: sql_db_schema
Action Input: Genre
ObservationError: table_names {'Genre\nObservation'} not found in database

Finished chain.
{'input': 'How many genres are there?', 'output': 'Agent stopped due to iteration limit or time limit.'}

Description

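The SQL agent extracts the table name together with a \n line break and the word 'Observation' from the next line, so the tool receives 'Genre\nObservation' instead of 'Genre'. The agent then retries the same call until it hits the iteration limit.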
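The following self-contained sketch shows how a ReAct-style parser can turn the model's text into the tool input seen in the error. The regex is a reconstruction of the Action / Action Input pattern used by LangChain's ReAct output parser (an assumption: it is copied here rather than imported), and `parse_action` is a hypothetical helper. Everything after `Action Input:` is captured with `re.DOTALL`, and only spaces and double quotes are stripped, so a trailing newline plus `Observation` in the model output ends up in the tool input.

```python
import re

# Reconstruction (assumption) of the ReAct "Action / Action Input" pattern;
# defined locally so the example needs no LangChain install.
ACTION_RE = re.compile(
    r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)",
    re.DOTALL,
)


def parse_action(llm_text: str) -> tuple[str, str]:
    """Hypothetical helper: return (tool_name, tool_input) from ReAct text."""
    match = ACTION_RE.search(llm_text)
    if match is None:
        raise ValueError("no Action / Action Input found")
    tool = match.group(1).strip()
    # Only spaces and double quotes are stripped, so a trailing
    # "\nObservation" written by the model survives into the tool input.
    tool_input = match.group(2).strip(" ").strip('"')
    return tool, tool_input


# Model output that still contains the "Observation" line, as in the log.
raw = (
    "Now I know the table that contains the genres is Genre.\n"
    "Action: sql_db_schema\n"
    "Action Input: Genre\n"
    "Observation"
)
print(parse_action(raw))  # ('sql_db_schema', 'Genre\nObservation')
```

This matches the reported error: the schema tool splits the whole captured string on commas and looks up a single table literally named `Genre\nObservation`.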

System Info

System Information

OS: Linux
OS Version: #1 SMP PREEMPT_DYNAMIC Fri May 24 14:06:39 UTC 2024
Python Version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]

Package Information

langchain_core: 0.2.10
langchain: 0.2.6
langchain_community: 0.2.6
langsmith: 0.1.82
langchain_experimental: 0.0.62
langchain_huggingface: 0.0.3
langchain_mistralai: 0.1.8
langchain_openai: 0.1.10
langchain_text_splitters: 0.2.2

Packages not installed (Not Necessarily a Problem)

The following packages were not found:

langgraph
langserve

kbatsuren commented 2 months ago

A similar issue was raised four months ago in this discussion, https://github.com/langchain-ai/langchain/discussions/17945, but it received no answer.

keenborder786 commented 2 months ago

Okay, one issue was that with the ZERO_SHOT_REACT_DESCRIPTION agent type the correct prompts were not being used; this has been fixed in the PR above. However, even with that fix (the PR is still needed, since the prompts should be the SQL-specific ones), the agent was still appending \nObservation to the Action Input. After much debugging and trying different models, the only workaround I could come up with was to create custom tools whose _run method removes \nObservation from the input:


from __future__ import annotations

from typing import (
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain.agents import create_react_agent
from langchain.agents.agent import AgentExecutor, RunnableAgent
from langchain.agents.agent_types import AgentType
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    BaseCallbackManager,
    CallbackManagerForToolRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.tools import BaseTool
from langchain_community.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
)
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy.engine import Result

class BaseSQLDatabaseTool(BaseModel):
    """Base tool for interacting with a SQL database."""

    db: SQLDatabase = Field(exclude=True)

    class Config(BaseTool.Config):
        pass

class _QuerySQLDataBaseToolInput(BaseModel):
    query: str = Field(..., description="A detailed and correct SQL query.")

class CustomQuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for querying a SQL database."""

    name: str = "sql_db_query"
    description: str = """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result]:
        """Execute the query, return the results or an error message."""
        # Workaround: drop the "\nObservation" that weaker LLMs leak into the input.
        query = query.replace("\nObservation", "")
        return self.db.run_no_throw(query)

class _InfoSQLDatabaseToolInput(BaseModel):
    table_names: str = Field(
        ...,
        description=(
            "A comma-separated list of the table names for which to return the schema. "
            "Example input: 'table1, table2, table3'"
        ),
    )

class CustomInfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting metadata about a SQL database."""

    name: str = "sql_db_schema"
    description: str = "Get the schema and sample rows for the specified SQL tables."
    args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list."""
        # Workaround: drop the "\nObservation" that weaker LLMs leak into the input.
        table_names = table_names.replace("\nObservation", "")
        return self.db.get_table_info_no_throw(
            [t.strip() for t in table_names.split(",")]
        )

class _ListSQLDataBaseToolInput(BaseModel):
    tool_input: str = Field("", description="An empty string")

class CustomListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting tables names."""

    name: str = "sql_db_list_tables"
    description: str = "Input is an empty string, output is a comma-separated list of tables in the database."
    args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get a comma-separated list of table names."""
        return ", ".join(self.db.get_usable_table_names())

class _QuerySQLCheckerToolInput(BaseModel):
    query: str = Field(..., description="A detailed and correct SQL query to be checked.")

class CustomQuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
    """Use an LLM to check if a query is correct.
    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""

    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    llm_chain: Any = Field(init=False)
    name: str = "sql_db_query_checker"
    description: str = """
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with sql_db_query!
    """
    args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "llm_chain" not in values:
            from langchain.chains.llm import LLMChain

            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),  # type: ignore[arg-type]
                prompt=PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["dialect", "query"]
                ),
            )

        if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must have input variables ['dialect', 'query']"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query."""
        return self.llm_chain.predict(
            query=query,
            dialect=self.db.dialect,
            callbacks=run_manager.get_child() if run_manager else None,
        )

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return await self.llm_chain.apredict(
            query=query,
            dialect=self.db.dialect,
            callbacks=run_manager.get_child() if run_manager else None,
        )


class CustomSQLDatabaseToolkit(SQLDatabaseToolkit):
    """SQLDatabaseToolkit variant whose get_tools returns the custom tools."""

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""
        list_sql_database_tool = CustomListSQLDatabaseTool(db=self.db)
        info_sql_database_tool_description = (
            "Input to this tool is a comma-separated list of tables, output is the "
            "schema and sample rows for those tables. "
            "Be sure that the tables actually exist by calling "
            f"{list_sql_database_tool.name} first! "
            "Example Input: table1, table2, table3"
        )
        info_sql_database_tool = CustomInfoSQLDatabaseTool(
            db=self.db, description=info_sql_database_tool_description
        )
        query_sql_database_tool_description = (
            "Input to this tool is a detailed and correct SQL query, output is a "
            "result from the database. If the query is not correct, an error message "
            "will be returned. If an error is returned, rewrite the query, check the "
            "query, and try again. If you encounter an issue with Unknown column "
            f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
            "to query the correct table fields."
        )
        query_sql_database_tool = CustomQuerySQLDataBaseTool(
            db=self.db, description=query_sql_database_tool_description
        )
        query_sql_checker_tool_description = (
            "Use this tool to double check if your query is correct before executing "
            "it. Always use this tool before executing a query with "
            f"{query_sql_database_tool.name}!"
        )
        query_sql_checker_tool = CustomQuerySQLCheckerTool(
            db=self.db, llm=self.llm, description=query_sql_checker_tool_description
        )
        return [
            query_sql_database_tool,
            info_sql_database_tool,
            list_sql_database_tool,
            query_sql_checker_tool,
        ]

def custom_create_sql_agent(
    llm: BaseLanguageModel,
    toolkit: Optional[SQLDatabaseToolkit] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    format_instructions: Optional[str] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    *,
    db: Optional[SQLDatabase] = None,
    prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
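) -> AgentExecutor:
    """Create a ReAct SQL agent that uses the custom tools defined above.

    Adapted from ``create_sql_agent`` in langchain_community. Pass exactly one
    of ``toolkit`` or ``db``.
    """
    if (toolkit is None) == (db is None):
        raise ValueError("Must provide exactly one of 'toolkit' or 'db'.")
    if toolkit is None:
        toolkit = CustomSQLDatabaseToolkit(llm=llm, db=db)

    tools = toolkit.get_tools()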
    if prompt is None:
        prefix = prefix or SQL_PREFIX
        prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
    else:
        if "top_k" in prompt.input_variables:
            prompt = prompt.partial(top_k=str(top_k))
        if "dialect" in prompt.input_variables:
            prompt = prompt.partial(dialect=toolkit.dialect)
        if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
            db_context = toolkit.get_context()
            if "table_info" in prompt.input_variables:
                prompt = prompt.partial(table_info=db_context["table_info"])
                # The custom tools do not subclass the stock ones, so check both types.
                tools = [
                    tool
                    for tool in tools
                    if not isinstance(tool, (InfoSQLDatabaseTool, CustomInfoSQLDatabaseTool))
                ]
            if "table_names" in prompt.input_variables:
                prompt = prompt.partial(table_names=db_context["table_names"])
                tools = [
                    tool
                    for tool in tools
                    if not isinstance(tool, (ListSQLDatabaseTool, CustomListSQLDatabaseTool))
                ]

    if prompt is None:
        from langchain.agents.mrkl import prompt as react_prompt

        format_instructions = (
            format_instructions or react_prompt.FORMAT_INSTRUCTIONS
        )
        template = "\n\n".join(
            [
                prefix,
                "{tools}",
                format_instructions,
                suffix or SQL_SUFFIX,
            ]
        )
        prompt = PromptTemplate.from_template(template)
    agent = RunnableAgent(
        runnable=create_react_agent(llm, tools, prompt),
        input_keys_arg=["input"],
        return_keys_arg=["output"],
        **kwargs,
    )

    return AgentExecutor(
        name="SQL Agent Executor",
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        handle_parsing_errors=True,
        **(agent_executor_kwargs or {}),
    )

llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-v0.1",
    huggingfacehub_api_token='',  # your Hugging Face API token
    max_new_tokens=512,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
)

db = SQLDatabase.from_uri("sqlite:///Chinook.db?isolation_level=IMMEDIATE")

agent_executor = custom_create_sql_agent(
    llm=llm,
    verbose=True,
    toolkit=CustomSQLDatabaseToolkit(llm=llm, db=db),
    # Passed through **kwargs and not used to pick the agent:
    # custom_create_sql_agent always builds a ReAct agent.
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
agent_executor.invoke("How many genres are there?")

Most of the code is taken from LangChain itself; the key change is the custom tools that strip \nObservation from their input, and that seems to work.
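The workaround above calls `str.replace` on the exact substring `\nObservation`. A slightly more defensive variant, sketched here as an assumption rather than as LangChain code, truncates the input at the first `\nObservation` and trims whitespace, so any text the model wrote after the marker is discarded too. `clean_tool_input` and `parse_table_names` are hypothetical helpers; one could call `clean_tool_input` at the top of each tool's `_run`, including `CustomQuerySQLCheckerTool._run`, which in the code above does not strip the marker.

```python
# Hypothetical helpers, not part of LangChain.
OBSERVATION_MARKER = "\nObservation"


def clean_tool_input(tool_input: str) -> str:
    """Keep only the text before the first "\\nObservation" marker, trimmed."""
    # partition() returns (whole_string, "", "") when the marker is absent.
    head, _marker, _rest = tool_input.partition(OBSERVATION_MARKER)
    return head.strip()


def parse_table_names(tool_input: str) -> list[str]:
    """Split a cleaned, comma-separated list of table names."""
    cleaned = clean_tool_input(tool_input)
    return [name.strip() for name in cleaned.split(",") if name.strip()]


print(repr(clean_tool_input("Genre\nObservation")))  # 'Genre'
print(parse_table_names("Genre, Track\nObservation"))  # ['Genre', 'Track']
```

Compared with `replace`, truncating also drops whatever follows the marker. That is usually what you want when the trailing text is the model's own continuation, but it would also cut a legitimate input that happened to contain a line starting with "Observation".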

PS: I tried Llama 3 and it works fine in that case. The stray \nObservation in the tool input seems to be a problem only with weaker LLMs.