stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License
16.86k stars 1.3k forks source link

ReAct calls custom tools twice #804

Closed timothyvinzent closed 5 months ago

timothyvinzent commented 5 months ago

Hi all, I know ReAct is due for some improvement as mentioned in previous issues. #748 #703

However, I've built a ReAct agent for process-mining QA with two custom tools at its disposal (text-to-SQL for retrieval, and Python code execution to enrich the DB with new columns). It runs on OpenAI GPT-3.5 in chat mode. Every time it calls either tool in an Action step, regardless of the iteration, it calls the tool twice, which is a problem when the first call has already executed the Python code that adds a column to the DB. Is this a bug in ReAct's behavior, or have I just implemented it wrong? Since none of the documentation really covers this use case, I wanted to ask and share.

Code for the ReAct Agent:

# Signature handed to dspy.ReAct by PM_JOINT_v4: two inputs (the user question
# and a textual description of the database columns) and one output (the answer).
# NOTE(review): DSPy presumably uses the class docstring below as the task
# instructions sent to the model, so treat it as runtime text rather than
# documentation -- confirm before rewording it.
class React_logic(dspy.Signature):
    """Answer the question and provide the observation in the answer. If the columns contain the information required to answer the question, use the sql_tool by providing the question. If the columns do not contain the information required to answer the question, use the create_new_column by providing detailed instructions to create a new column."""
    # Inputs supplied by PM_JOINT_v4.forward.
    question = dspy.InputField()
    column_description = dspy.InputField()
    # Final answer produced by the agent.
    answer = dspy.OutputField()

class PM_JOINT_v4(dspy.Module):
    """ReAct agent wired up with the SQL retrieval tool and the column-creation tool."""

    def __init__(self, conn):
        """Build both tools around a shared column description and hand them to ReAct.

        Args:
            conn: database connection forwarded to the ``Retrieve`` tool.
        """
        super().__init__()

        self.column_description = """The database contains only one table event_log... """

        # Both tools receive the same schema description.
        sql_tool = Retrieve(conn=conn, column_description=self.column_description)
        column_tool = create_new_column(column_description=self.column_description)

        self.logic = dspy.ReAct(
            signature=React_logic,
            tools=[sql_tool, column_tool],
        )

    def forward(self, question):
        """Run the ReAct loop for ``question`` and return its result unchanged."""
        return self.logic(
            question=question,
            column_description=self.column_description,
        )

Code for text-to-SQL tool:

  # Typed container for the `answer` output field of SQL_format.
  # Documented with `#` comments on purpose: a class docstring on a pydantic
  # model may end up in the generated schema -- TODO confirm.
  class CodeOutput(BaseModel):
      sql : str  # query text; Retrieve.__call__ reads it as `result.answer.sql`

  # Signature used by Retrieve (through dspy.TypedChainOfThought) to turn a
  # question plus a column description into a SQL query wrapped in CodeOutput.
  # NOTE(review): the docstring and the `desc=` strings are presumably sent to
  # the model as part of the prompt -- confirm before rewording them.
  class SQL_format(dspy.Signature):
      """Based on a question, generates an SQL LITE query to answer it."""
      column_description = dspy.InputField(desc="Information about the database and its tables")
      question = dspy.InputField(desc="Question which this function transforms into an SQL query")
      # Typed output: the `.sql` attribute of this field is what gets executed.
      answer: CodeOutput = dspy.OutputField(desc="SQL LITE query which will be executed to answer the question.")

  class Retrieve:
      """ReAct tool that answers a question by generating and running a SQL query.

      The class attributes ``name``, ``input_variable`` and ``desc`` describe the
      tool to the agent; ``__call__`` performs the actual work.
      """

      name = "Retrieve"
      input_variable = ["Question", "Columns to be used in query"]
      desc = "Takes as input a question and a detailed column description and provides back infromation in a pipe-delimited format."

      def __init__(self, conn, column_description, max_length = 2000):
          """Store the connection and build the query-generating predictor.

          Args:
              conn: open database connection exposing ``cursor()``.
              column_description: base description of the database schema.
              max_length: soft cap, in characters, on the returned text.
          """
          self.conn = conn
          self.max_length = max_length
          self.generated_query = dspy.TypedChainOfThought(SQL_format)
          self.column_description = column_description

      def __call__(self, input_variable, *args, **kwargs):
          """Generate a SQL query, run it, and return pipe-delimited rows.

          Args:
              input_variable: string holding a Python literal list of the form
                  ``[question, columns_to_use]``.

          Returns:
              The header line followed by as many rows as fit within
              ``max_length`` characters, an ``"Error executing query: ..."``
              message when SQLite rejects the query, or a short notice when the
              statement produced no result set.

          Raises:
              ValueError, SyntaxError: if ``input_variable`` is not a valid
                  Python literal (propagated from ``ast.literal_eval``).
          """
          var_lst = ast.literal_eval(input_variable)
          print(f"From SQL: input variable received: {var_lst}, type: {type(var_lst)}")
          question, columns = var_lst[0], var_lst[1]

          # Build the prompt text locally. The previous version appended to
          # self.column_description on every call, so the description grew with
          # each tool invocation and earlier column hints leaked into later ones.
          col_desc = (
              self.column_description
              + "\n" + columns
              + "\n" + "Columns to be used in query" + columns
          )
          prediction = self.generated_query(question = question, column_description = col_desc)
          query = prediction.answer.sql
          print(f"From SQL: Query that will be executed: {query}")

          result = ""
          cur = self.conn.cursor()
          try:
              cur.execute(query)
              if cur.description is None:
                  # Statements without a result set (e.g. DDL) have no column
                  # metadata; report that instead of failing on None.
                  return "Query executed; it returned no result set."

              header = "|".join(column[0] for column in cur.description)
              lines = [header]
              current_length = len(header)

              # Iterate the cursor lazily: only the rows that fit are fetched.
              for row in cur:
                  row_data = "|".join(str(cell) for cell in row)
                  # +1 accounts for the newline that joins this row to the rest.
                  if current_length + len(row_data) + 1 > self.max_length:
                      break  # Keep within the max_length limit.
                  lines.append(row_data)
                  current_length += len(row_data) + 1
              result = "\n".join(lines)
          except sqlite3.Error as e:
              result = "Error executing query: " + str(e)
              print(f"From SQL: result: {result}")
          finally:
              # Always release the cursor, whatever happened above.
              cur.close()

          return result

Code for python code exec:



# Typed container for the `generated_code` output field of Generate and
# Regenerate. Documented with `#` comments on purpose: a class docstring on a
# pydantic model may end up in the generated schema -- TODO confirm.
class PythonCode(BaseModel):
    python : str  # source text; create_new_column passes it to exec()

# Signature for the first code-generation attempt in create_new_column:
# a column description and an instruction go in, Python code comes out.
# NOTE(review): the docstring and `desc=` strings are presumably prompt text
# sent to the model -- confirm before rewording them.
class Generate(dspy.Signature):
    """Based on an instruction, generate Python code to execute instruction"""
    column_description = dspy.InputField(desc="Information about the database and its tables")
    instruction = dspy.InputField()
    # Typed output: the `.python` attribute holds the code to execute.
    generated_code: PythonCode = dspy.OutputField(desc="Python code which will be executed to fulfill instruction")

# Signature for the repair step in create_new_column: given the code that
# failed and the exception text, produce a complete replacement program.
# NOTE(review): the docstring and `desc=` strings are presumably prompt text
# sent to the model -- confirm before rewording them.
class Regenerate(dspy.Signature):
    """You will be given previous_generated_code and error_message due to an error in previous code. Your task is to correct the error and provide the entire new 'generated_code'. Not just the corrected snippet"""
    previous_generated_code = dspy.InputField()
    error_message = dspy.InputField()
    # Typed output: the `.python` attribute holds the corrected code.
    generated_code: PythonCode = dspy.OutputField(desc="Python code which will be executed to fulfill instruction")

# Signature for the final step in create_new_column: describe the column that
# the executed code added, given that code and the original instruction.
# NOTE(review): the docstring still mentions a code output, but that input is
# disabled below, so only the code and the instruction are passed in.
class Answer(dspy.Signature):
    """Given the final generated_code, the instruction and the final code_ouput, provide a description of the new column added to the dataframe"""
    generated_code = dspy.InputField()
    instruction = dspy.InputField()
    # Disabled input: create_new_column does not pass a code output to this signature.
    #code_output = dspy.InputField()
    description = dspy.OutputField(desc="Description of the new column added to the dataframe")

class create_new_column:
    """ReAct tool that generates and runs Python code to add a column to the DB.

    The class attributes ``name``, ``input_variable`` and ``desc`` describe the
    tool to the agent; ``__call__`` generates code, executes it (asking for a
    corrected version when it raises) and returns a description of the column.
    """

    name = "create_new_column"
    input_variable = ["Instruction", "column_description"]
    desc = """Creates a new column in the dataframe based on detailed instructions provided
      in variable question, provide lots of details. Updates the database with the new column 
      and provides back a description of the new column."""

    def __init__(self, column_description, max_regenerations = 3):
        """Build the predictors and reset the bookkeeping counters.

        Args:
            column_description: base description of the database schema.
            max_regenerations: how many times failing code may be sent back
                for correction before the tool gives up (default 3, matching
                the previous hard-coded behaviour).
        """
        self.db_path = "my_database.db"
        self.GENERATE = dspy.TypedChainOfThought(Generate)
        self.REGENERATE = dspy.TypedChainOfThought(Regenerate)
        self.ANSWER = dspy.ChainOfThought(Answer)
        # Boilerplate prepended to every instruction so the generated code
        # reads from and writes back to the database.
        self.read_write = """ The following lines of code are required to read from the database and write to the database:
        import sqlite3
        import pandas as pd
        # import additional libraries if needed

        # connect to the SQLite database
        conn = sqlite3.connect("my_database.db")
        query = "SELECT * FROM event_log"
        df = pd.read_sql_query(query, conn)
        print(df.head())

        # Perform operation on the DataFrame such as creating a new column based on insturctions

        # operation is performed we update the database with the entire DataFrame you worked on
        df.to_sql("event_log", conn, if_exists="replace", index=False)
        conn.close()"""
        self.column_description = column_description
        self.max_regenerations = max_regenerations
        self.counter = 0       # completed runs
        self.num_errors = 0    # execution failures in the most recent run
        self.num_gen = 0       # calls to GENERATE
        self.num_reg = 0       # calls to REGENERATE
        self.num_ans = 0       # calls to ANSWER

    def __call__(self, input_variable, *args, **kwargs):
        """Generate Python code that adds a column, run it, and describe the result.

        Args:
            input_variable: string holding a Python literal list of the form
                ``[instruction, column_to_create]``.

        Returns:
            The model's description of the new column, or an ``"Error: ..."``
            string when the code still fails after ``max_regenerations``
            corrections.

        Raises:
            ValueError, SyntaxError: if ``input_variable`` is not a valid
                Python literal (propagated from ``ast.literal_eval``).
        """
        var_lst = ast.literal_eval(input_variable)
        print(f"FROM CNC: input variable received: {var_lst}, type: {type(var_lst)}")

        instruct = self.read_write + "\n" + var_lst[0]
        col_description = self.column_description + "\n" + "column to be created:" + var_lst[1]

        result = self.GENERATE(instruction = instruct, column_description = col_description)
        print("From CNC: result has been generated")
        self.num_gen += 1
        python_code = result.generated_code.python
        self.num_errors = 0

        # exec() always returns None; the name is kept only so the log line
        # below keeps its original shape.
        code_output = None

        # One initial attempt plus up to max_regenerations corrected attempts.
        for attempt in range(self.max_regenerations + 1):
            try:
                # SECURITY: this executes model-generated code with the full
                # privileges of the current process. Only use with trusted
                # models and data, or run it inside a sandbox.
                # A fresh globals dict is used so that functions defined by the
                # generated code can see its own top-level imports and names.
                exec(python_code, {"__name__": "__main__"})
                break
            except Exception as e:
                if attempt == self.max_regenerations:
                    return "Error: " + str(e) + "function failed to execute"
                self.num_errors += 1
                result = self.REGENERATE(previous_generated_code = python_code, error_message = str(e))
                self.num_reg += 1
                python_code = result.generated_code.python

        result = self.ANSWER(generated_code = python_code, instruction = var_lst[0])
        self.num_ans += 1
        print(f"FROM CNC: Code Output: {code_output}, Code: {python_code}")
        self.counter += 1
        print(f"FROM CNC: RUN COMPLETED: {self.counter} with {self.num_errors} errors")
        print(f"FROM CNC: Number of Generate: {self.num_gen}, Number of Regenerate: {self.num_reg}, Number of Answer: {self.num_ans}")
        return result.description

Any help or feedback would be greatly appreciated. Otherwise I will create my own implementation of ReAct.
okhat commented 5 months ago

Hmm it calls each tool twice. I think I can guess the cause. Great catch!

https://github.com/stanfordnlp/dspy/blob/1a56e69465ef14ab9b7554184e99712766dd77dd/dspy/predict/react.py#L94

See this code block. It has try/except, but the tool is called in try and in except. We should move the call outside the try/except block.