run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License
34.67k stars 4.9k forks

[Bug]: ReActAgentWorker._arun_step_stream drops tokens that arrive before is_done=True #14307

Open Tshimanga opened 2 months ago

Tshimanga commented 2 months ago

Bug Description

In ReActAgentWorker._arun_step_stream, any tokens generated before the is_done condition triggers are dropped.

Version

0.10.38 through 0.10.46

Steps to Reproduce

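```python
import asyncio

from llama_index.core.agent import ReActAgent
from llama_index.llms.openai import OpenAI


async def main() -> None:
    llm = OpenAI(
        temperature=0,
        api_key="...",
        model="gpt-4-turbo",
        streaming=True,
    )
    agent = ReActAgent.from_tools(llm=llm, tools=[])

    # astream_chat returns a streaming response; iterate over its async token generator
    chat_stream = await agent.astream_chat("Write a couple sentences about the beach.")
    async for token in chat_stream.async_response_gen():
        print(token)


# top-level `await` is only valid in notebooks/async REPLs, so run the coroutine explicitly
asyncio.run(main())
```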

Relevant Logs/Tracebacks

No response

dosubot[bot] commented 2 months ago

The ReActAgentWorker._arun_step_stream method is designed to drop tokens that arrive before the is_done=True condition triggers to ensure that only the final response is processed and returned. This behavior is intentional and is implemented to break out of the token stream loop as soon as the final answer is detected:

```python
async for latest_chunk in chat_stream:
    full_response = latest_chunk
    is_done = self._infer_stream_chunk_is_final(latest_chunk)
    if is_done:
        break
```

In this loop, the method iterates over the chat_stream and updates full_response with the latest chunk. It then checks if the chunk is the final one by calling _infer_stream_chunk_is_final(latest_chunk). If is_done is set to True, the loop breaks, effectively stopping the processing of any further tokens that arrive after this point [1].
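To make the loop above concrete, here is a minimal, self-contained simulation. It assumes, as a simplification, that each streamed chunk carries both the text accumulated so far and the newly added delta; `Chunk`, `fake_stream`, and `is_final` are hypothetical stand-ins rather than llama_index APIs, and `is_final` is far simpler than `_infer_stream_chunk_is_final`. It shows that `full_response` holds the accumulated text at the moment of the `break`, and that breaking out of the loop does not discard the later chunks: they stay in the suspended generator and can still be consumed.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional, Tuple


@dataclass
class Chunk:
    """Hypothetical stand-in for one streamed chat response."""

    content: str  # text accumulated so far
    delta: str  # text added by this chunk


async def fake_stream(tokens: List[str]) -> AsyncIterator[Chunk]:
    content = ""
    for token in tokens:
        content += token
        yield Chunk(content=content, delta=token)


def is_final(chunk: Chunk) -> bool:
    # Much simpler than _infer_stream_chunk_is_final: final once "Answer:" appears.
    return "Answer:" in chunk.content


async def main() -> Tuple[str, List[str]]:
    stream = fake_stream(["Thought: done.", " Answer:", " Hi", " there"])
    full_response: Optional[Chunk] = None
    async for latest_chunk in stream:
        full_response = latest_chunk
        if is_final(latest_chunk):
            break
    assert full_response is not None
    # `break` leaves the generator suspended; the remaining chunks are still there.
    remaining = [c.delta async for c in stream]
    return full_response.content, remaining


print(asyncio.run(main()))
```

Running it prints `('Thought: done. Answer:', [' Hi', ' there'])`.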

If you want to see the intermediate reasoning, you can set the verbose parameter to True when initializing ReActAgentWorker, which prints the intermediate steps as the agent runs [1][2].

To continue talking to Dosu, mention @dosu.

Tshimanga commented 2 months ago

> to ensure that only the final response is processed and returned

@dosu I don't think the code is functioning as intended since the tokens that are getting dropped are part of the actual final response. For instance, "The beach is a nice place to tan. Just make sure to put sunscreen" gets truncated to " beach is a nice place to tan. Just make sure to put sunscreen"

dosubot[bot] commented 2 months ago

To ensure that the entire final response is captured without truncation when using the ReActAgent with streaming in versions 0.10.38 through 0.10.46, you should use the is_streaming parameter in the parse method of the ReActOutputParser class. This parameter is designed to handle streaming outputs correctly.

Here is the relevant part of the code that shows how the is_streaming parameter is used:

```python
from llama_index.core.agent.react.output_parser import (
    extract_final_response,
    parse_action_reasoning_step,
)
from llama_index.core.agent.react.types import BaseReasoningStep, ResponseReasoningStep
from llama_index.core.types import BaseOutputParser


class ReActOutputParser(BaseOutputParser):
    """ReAct Output parser."""

    def parse(self, output: str, is_streaming: bool = False) -> BaseReasoningStep:
        """Parse output from ReAct agent.

        We expect the output to be in one of the following formats:
        1. If the agent needs to use a tool to answer the question:
            ```
            Thought: <thought>
            Action: <action>
            Action Input: <action_input>
            ```
        2. If the agent can answer the question without any tools:
            ```
            Thought: <thought>
            Answer: <answer>
            ```
        """
        if "Thought:" not in output:
            # NOTE: handle the case where the agent directly outputs the answer
            # instead of following the thought-answer format
            return ResponseReasoningStep(
                thought="(Implicit) I can answer without any more tools!",
                response=output,
                is_streaming=is_streaming,
            )

        if "Answer:" in output:
            thought, answer = extract_final_response(output)
            return ResponseReasoningStep(
                thought=thought, response=answer, is_streaming=is_streaming
            )

        if "Action:" in output:
            return parse_action_reasoning_step(output)

        raise ValueError(f"Could not parse output: {output}")
```

By setting `is_streaming=True` when calling the `parse` method, you ensure that the parser handles the streaming output correctly and captures the entire final response without truncation [[1]](https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/agent/react/output_parser.py)[[2]](https://github.com/run-llama/llama_index/blob/main/llama-index-legacy/llama_index/legacy/agent/react/output_parser.py).
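The branch order in `parse` can be summarized with a simplified, standalone sketch. `classify_react_output` is a hypothetical helper that returns a `(kind, payload)` pair instead of llama_index reasoning-step objects, and it only approximates `extract_final_response` by taking the text after the first `"Answer:"`. Note that, in the quoted code, none of the branches depends on `is_streaming`; the flag is only passed through to `ResponseReasoningStep`.

```python
from typing import Tuple


def classify_react_output(output: str) -> Tuple[str, str]:
    """Hypothetical, simplified mirror of the branch order in ReActOutputParser.parse."""
    if "Thought:" not in output:
        # No ReAct formatting at all: the whole output is treated as the answer.
        return "implicit_answer", output
    if "Answer:" in output:
        # Approximation of extract_final_response: keep the text after "Answer:".
        return "answer", output.split("Answer:", 1)[1].strip()
    if "Action:" in output:
        # A tool call; the real parser builds an action step from this text.
        return "action", output
    raise ValueError(f"Could not parse output: {output}")


print(classify_react_output("Thought: easy.\nAnswer: The beach is lovely."))
```

A plain reply such as `"The beach is nice."` falls into the first branch and is returned unchanged as an implicit answer.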

Additionally, you can refer to the test cases provided in the `llama-index-legacy/tests/agent/react/test_react_agent.py` file to see how streaming is handled in practice. For example, the `test_stream_chat_basic` function demonstrates how to use the `ReActAgent` with a mock streaming LLM to ensure the entire response is captured:

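```python
# Excerpt from llama-index-legacy/tests/agent/react/test_react_agent.py. The add_tool
# fixture, MockStreamChatLLM, and the MOCK_* constants are not part of this excerpt.
def test_stream_chat_basic(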
    add_tool: FunctionTool,
) -> None:
    mock_llm = MockStreamChatLLM(
        responses=[
            ChatMessage(
                content=MOCK_ACTION_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
            ChatMessage(
                content=MOCK_STREAM_FINAL_RESPONSE,
                role=MessageRole.ASSISTANT,
            ),
        ]
    )

    agent = ReActAgent.from_tools(
        tools=[add_tool],
        llm=mock_llm,
    )
    response = agent.stream_chat("What is 1 + 1?")
    assert isinstance(response, StreamingAgentChatResponse)

    # exhaust stream
    for delta in response.response_gen:
        continue
    expected_answer = MOCK_STREAM_FINAL_RESPONSE.split("Answer: ")[-1].strip()
    assert response.response == expected_answer

    assert agent.chat_history == [
        ChatMessage(
            content="What is 1 + 1?",
            role=MessageRole.USER,
        ),
        ChatMessage(
            content="2 is the final answer.",
            role=MessageRole.ASSISTANT,
        ),
    ]
```

This test ensures that the entire final response is captured correctly when using streaming [3][4].

Tshimanga commented 2 months ago

@dosu that parameter is already being set to True when you call ReActAgent.astream_chat.

dosubot[bot] commented 2 months ago

To ensure that the entire final response is captured without truncation when using the ReActAgentWorker._arun_step_stream method in versions 0.10.38 through 0.10.46 of the llama_index library, you need to iterate over the stream and check for the final answer using the _infer_stream_chunk_is_final method. Here is the relevant code snippet:

```python
# Excerpt of the ReActAgentWorker method (imports and the end of the method omitted).
async def _arun_step_stream(
    self,
    step: TaskStep,
    task: Task,
) -> TaskStepOutput:
    """Run step."""
    if step.input is not None:
        add_user_step_to_reasoning(
            step,
            task.extra_state["new_memory"],
            task.extra_state["current_reasoning"],
            verbose=self._verbose,
        )
    tools = self.get_tools(task.input)

    input_chat = self._react_chat_formatter.format(
        tools,
        chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(),
        current_reasoning=task.extra_state["current_reasoning"],
    )

    chat_stream = await self._llm.astream_chat(input_chat)

    # iterate over stream, break out if is final answer after the "Answer: "
    full_response = ChatResponse(
        message=ChatMessage(content=None, role="assistant")
    )
    is_done = False
    async for latest_chunk in chat_stream:
        full_response = latest_chunk
        is_done = self._infer_stream_chunk_is_final(latest_chunk)
        if is_done:
            break

    if not is_done:
        reasoning_steps, is_done = await self._aprocess_actions(
            task, tools=tools, output=full_response, is_streaming=True
        )
        task.extra_state["current_reasoning"].extend(reasoning_steps)
        agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response(
            task.extra_state["current_reasoning"], task.extra_state["sources"]
        )
    else:
        response_stream = self._async_add_back_chunk_to_stream(
            chunk=latest_chunk, chat_stream=chat_stream
        )
        # ... (rest of the method not shown)
```

In this code, the async for latest_chunk in chat_stream loop ensures that the entire response is captured by iterating over the stream until the final chunk is identified by the _infer_stream_chunk_is_final method. If the final chunk is not reached, the method continues processing actions and appending reasoning steps to ensure the response is complete. If the final chunk is reached, the response stream is updated to include the final chunk using the _async_add_back_chunk_to_stream method [1][2].
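The snippet hands `latest_chunk` and the partially consumed `chat_stream` to `_async_add_back_chunk_to_stream`. Assuming that helper simply re-yields the given chunk and then the rest of the stream (an assumption; its body is not shown in this thread), its behavior can be sketched with plain async generators. `add_back_chunk`, `tokens`, and `demo` are hypothetical names. The point of the sketch is that only the one chunk passed back is replayed; anything consumed earlier in the loop is not.

```python
import asyncio
from typing import AsyncIterator, List, TypeVar

T = TypeVar("T")


async def add_back_chunk(chunk: T, stream: AsyncIterator[T]) -> AsyncIterator[T]:
    """Hypothetical sketch: re-emit one already-consumed item, then the rest of the stream."""
    yield chunk
    async for item in stream:
        yield item


async def tokens(items: List[str]) -> AsyncIterator[str]:
    for item in items:
        yield item


async def demo() -> List[str]:
    stream = tokens(["a", "b", "c", "d"])
    # Consume two items, as the detection loop might before it breaks.
    consumed = [await stream.__anext__(), await stream.__anext__()]
    # Only the last consumed item ("b") is put back; "a" is gone for the consumer.
    return [x async for x in add_back_chunk(consumed[-1], stream)]


print(asyncio.run(demo()))  # ['b', 'c', 'd']
```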

logan-markewich commented 2 months ago

@Tshimanga don't we want to drop those tokens? The idea is the stream shouldn't include all the agent's thoughts, just the final response.

Tshimanga commented 2 months ago

@logan-markewich What you say makes sense, but it looks like I came to premature conclusions when debugging.

The issue I've been encountering is that too much of the response gets dropped. I'm ending up with sentences like (as a response to the reproducer above):

" beach is a magical place where golden sands stretch underfoot and the ocean whispers in a constant, soothing murmur. It's an ideal spot for families and friends to gather, share moments, and indulge in the simple joys of nature, from watching the sunset to collecting seashells along the shore."

Here, the leading "The" got dropped.

When I was initially stepping through the debugger and testing fixes, I didn't see any preceding text for the thought process in latest_token.message.content. That is, latest_token.message.content was "The beach" and not "Thought: ... Answer: The beach" in my local testing, but I've since seen the second case as well.

Tshimanga commented 2 months ago

@logan-markewich ah ok, I think I found the actual issue. In ReActAgentWorker._infer_stream_chunk_is_final there is the following condition:

```python
...
if len(latest_content) > len("Thought") and not latest_content.startswith(
    "Thought"
):
    return True
...
```

If the LLM doesn't follow the Thought/Action format and the response begins with a short token like "The", then ReActAgentWorker._infer_stream_chunk_is_final doesn't recognize that first chunk as final (len("The") is 3, which is not greater than len("Thought"), which is 7), so the first token ends up being skipped.

Does this seem more like the actual issue here?
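The failure mode described in this comment can be reproduced without an LLM. This sketch assumes that each chunk's `content` is the accumulated text and that the user-facing stream is built from each chunk's `delta`, starting at the chunk that gets added back; `Chunk`, `fake_llm_stream`, `looks_final`, and `streamed_answer` are hypothetical, and `looks_final` implements only the one condition quoted here (the real method contains more logic). With the tokens `"The"`, `" beach"`, ..., the first chunk is too short to pass the check, the second one passes, and only its delta is replayed, so the leading "The" is lost.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List


@dataclass
class Chunk:
    content: str  # accumulated text so far (cf. ChatResponse.message.content)
    delta: str  # text added by this chunk (cf. ChatResponse.delta)


async def fake_llm_stream(tokens: List[str]) -> AsyncIterator[Chunk]:
    content = ""
    for token in tokens:
        content += token
        yield Chunk(content=content, delta=token)


def looks_final(chunk: Chunk) -> bool:
    # Only the quoted condition; the real method contains more logic.
    latest_content = chunk.content
    return len(latest_content) > len("Thought") and not latest_content.startswith(
        "Thought"
    )


async def streamed_answer(tokens: List[str]) -> str:
    stream = fake_llm_stream(tokens)
    async for latest_chunk in stream:
        if looks_final(latest_chunk):
            break
    else:
        return ""  # never detected as final
    # Replay only the chunk that passed the check, then drain the rest,
    # joining deltas the way the user-facing stream is assumed to.
    deltas = [latest_chunk.delta] + [c.delta async for c in stream]
    return "".join(deltas)


print(repr(asyncio.run(streamed_answer(["The", " beach", " is", " nice."]))))
# prints ' beach is nice.'
```

A first token longer than seven characters, such as `"Sunshine"`, passes the check immediately, so nothing is dropped: `streamed_answer(["Sunshine", " and", " sand."])` produces `"Sunshine and sand."`.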

logan-markewich commented 1 month ago

@Tshimanga I think so, and I think https://github.com/run-llama/llama_index/pull/14587 is an attempt at solving this
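For illustration only, here is one possible way to avoid the loss: when a final-looking chunk is detected, replay the whole answer accumulated so far rather than just the latest delta. This is a hypothetical sketch under simplified assumptions (each streamed item is an `(accumulated_content, delta)` pair); it is not taken from the linked PR, whose approach may differ, and `fake_llm_stream`, `looks_final`, `replay_text`, and `streamed_answer` are illustrative names.

```python
import asyncio
from typing import AsyncIterator, List, Tuple


async def fake_llm_stream(tokens: List[str]) -> AsyncIterator[Tuple[str, str]]:
    """Yield (accumulated_content, delta) pairs, like a streaming chat LLM."""
    content = ""
    for token in tokens:
        content += token
        yield content, token


def looks_final(content: str) -> bool:
    # Hypothetical, simplified check: an explicit "Answer:" marker, or enough
    # text that clearly does not start with "Thought".
    if "Answer:" in content:
        return True
    return len(content) > len("Thought") and not content.startswith("Thought")


def replay_text(content: str) -> str:
    # Re-emit everything the user should see so far, not just the latest delta:
    # the text after "Answer:" for ReAct-formatted output, otherwise all of it.
    if "Answer:" in content:
        return content.split("Answer:", 1)[1]
    return content


async def streamed_answer(tokens: List[str]) -> str:
    stream = fake_llm_stream(tokens)
    async for content, _delta in stream:
        if looks_final(content):
            break
    else:
        return ""  # the stream never produced a final-looking chunk
    rest = [delta async for _content, delta in stream]
    # Strip the whitespace that usually follows "Answer:".
    return (replay_text(content) + "".join(rest)).lstrip()


print(asyncio.run(streamed_answer(["The", " beach", " is", " nice."])))
```

With the same four tokens as the bug report's pattern, this prints `The beach is nice.`; for ReAct-formatted output such as `["Thought: easy.", " Answer:", " The", " beach."]` it returns `"The beach."` without the thought text.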