langgenius / dify

Dify is an open-source LLM app development platform. Dify's intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
https://dify.ai

The total_price in the multi-iteration agent run's usage is incorrect. #6687

Closed: PedroGomes02 closed this issue 1 month ago.

PedroGomes02 commented 1 month ago


Dify version

0.6.15

Cloud or Self Hosted

Self Hosted (Docker)

Steps to reproduce

In a multi-iteration agent run, when I check the run's final usage, the total_price reflects only the first iteration.

✔️ Expected Behavior

I expect the total_price to reflect the usage of all iterations.

❌ Actual Behavior

The total_price reflects only the first iteration.
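As a worked example of the gap between the expected and actual behavior, the sketch below uses hypothetical per-iteration prices (illustrative values, not taken from this report). The final total_price should be the sum over every iteration; the bug reports only the first one.

```python
from decimal import Decimal

# Hypothetical per-iteration prices for a three-iteration agent run
# (illustrative values only, not data from the issue).
iteration_prices = [Decimal("0.0020"), Decimal("0.0035"), Decimal("0.0010")]

# Expected: the run's total_price is the sum over all iterations.
expected_total = sum(iteration_prices, Decimal("0"))

# Actual (reported bug): only the first iteration's price survives.
reported_total = iteration_prices[0]

print(expected_total)  # 0.0065
print(reported_total)  # 0.0020
```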

dosubot[bot] commented 1 month ago

@PedroGomes02 Hello there! I'm here to help you with any bugs, questions, or contributions you might have for the repository. Let me know how I can assist you!

The total_price reflects only the cost of the first iteration because the llm_usage dictionary is re-initialized in each iteration, so the usage from previous iterations is lost.

To fix this, move the initialization of llm_usage outside the loop so that it accumulates usage across all iterations. Here is the corrected code snippet:

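# Initialize the run-level usage accumulator once, before the loop,
# so that usage from every iteration is kept.
llm_usage = {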
    'usage': None
}

while function_call_state and iteration_step <= max_iteration_steps:
    # continue to run until there is not any tool call
    function_call_state = False

    if iteration_step == max_iteration_steps:
        # the last iteration, remove all tools
        self._prompt_messages_tools = []

    message_file_ids = []

    agent_thought = self.create_agent_thought(
        message_id=message.id,
        message='',
        tool_name='',
        tool_input='',
        messages_ids=message_file_ids
    )

    if iteration_step > 1:
        self.queue_manager.publish(QueueAgentThoughtEvent(
            agent_thought_id=agent_thought.id
        ), PublishFrom.APPLICATION_MANAGER)

    # recalc llm max tokens
    prompt_messages = self._organize_prompt_messages()
    self.recalc_llm_max_tokens(self.model_config, prompt_messages)
    # invoke model
    chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=app_generate_entity.model_conf.parameters,
        tools=[],
        stop=app_generate_entity.model_conf.stop,
        stream=True,
        user=self.user_id,
        callbacks=[],
    )

    # check llm result
    if not chunks:
        raise ValueError("failed to invoke llm")

    usage_dict = {}
    react_chunks = CotAgentOutputParser.handle_react_stream_output(
        chunks, usage_dict)
    scratchpad = AgentScratchpadUnit(
        agent_response='',
        thought='',
        action_str='',
        observation='',
        action=None,
    )

    # publish agent thought if it's first iteration
    if iteration_step == 1:
        self.queue_manager.publish(QueueAgentThoughtEvent(
            agent_thought_id=agent_thought.id
        ), PublishFrom.APPLICATION_MANAGER)

    for chunk in react_chunks:
        if isinstance(chunk, AgentScratchpadUnit.Action):
            action = chunk
            # detect action
            scratchpad.agent_response += json.dumps(chunk.model_dump())
            scratchpad.action_str = json.dumps(chunk.model_dump())
            scratchpad.action = action
        else:
            scratchpad.agent_response += chunk
            scratchpad.thought += chunk
            yield LLMResultChunk(
                model=self.model_config.model,
                prompt_messages=prompt_messages,
                system_fingerprint='',
                delta=LLMResultChunkDelta(
                    index=0,
                    message=AssistantPromptMessage(
                        content=chunk
                    ),
                    usage=None
                )
            )

    scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
    self._agent_scratchpad.append(scratchpad)
    # get this iteration's llm usage and add it to the run-level total in llm_usage
    if 'usage' in usage_dict:
        increase_usage(llm_usage, usage_dict['usage'])
    else:
        usage_dict['usage'] = LLMUsage.empty_usage()

    self.save_agent_thought(
        agent_thought=agent_thought,
        tool_name=scratchpad.action.action_name if scratchpad.action else '',
        tool_input={
            scratchpad.action.action_name: scratchpad.action.action_input
        } if scratchpad.action else {},
        tool_invoke_meta={},
        thought=scratchpad.thought,
        observation='',
        answer=scratchpad.agent_response,
        messages_ids=[],
        llm_usage=usage_dict['usage']
    )

    if not scratchpad.is_final():
        self.queue_manager.publish(QueueAgentThoughtEvent(
            agent_thought_id=agent_thought.id
        ), PublishFrom.APPLICATION_MANAGER)

    if not scratchpad.action:
        # failed to extract action, return final answer directly
        final_answer = ''
    else:
        if scratchpad.action.action_name.lower() == "final answer":
            # action is final answer, return final answer directly
            try:
                if isinstance(scratchpad.action.action_input, dict):
                    final_answer = json.dumps(
                        scratchpad.action.action_input)
                elif isinstance(scratchpad.action.action_input, str):
                    final_answer = scratchpad.action.action_input
                else:
                    final_answer = f'{scratchpad.action.action_input}'
            except (TypeError, ValueError):
                # json.dumps reports unserializable input with TypeError/ValueError, never JSONDecodeError
                final_answer = f'{scratchpad.action.action_input}'
        else:
            function_call_state = True
            # action is tool call, invoke tool
            tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                action=scratchpad.action,
                tool_instances=tool_instances,
                message_file_ids=message_file_ids,
                trace_manager=trace_manager,
            )
            scratchpad.observation = tool_invoke_response
            scratchpad.agent_response = tool_invoke_response
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=scratchpad.action.action_name,
                tool_input={
                    scratchpad.action.action_name: scratchpad.action.action_input},
                thought=scratchpad.thought,
                observation={
                    scratchpad.action.action_name: tool_invoke_response},
                tool_invoke_meta={
                    scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                answer=scratchpad.agent_response,
                messages_ids=message_file_ids,
                llm_usage=usage_dict['usage']
            )

            self.queue_manager.publish(QueueAgentThoughtEvent(
                agent_thought_id=agent_thought.id
            ), PublishFrom.APPLICATION_MANAGER)

        # update prompt tool message
        for prompt_tool in self._prompt_messages_tools:
            self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

    iteration_step += 1

yield LLMResultChunk(
    model=model_instance.model,
    prompt_messages=prompt_messages,
    delta=LLMResultChunkDelta(
        index=0,
        message=AssistantPromptMessage(
            content=final_answer
        ),
        usage=llm_usage['usage']
    ),
    system_fingerprint=''
)

# save agent thought
self.save_agent_thought(
    agent_thought=agent_thought,
    tool_name='',
    tool_input={},
    tool_invoke_meta={},
    thought=final_answer,
    observation={},
    answer=final_answer,
    messages_ids=[]
)

This change ensures that the total_price reflects the accumulated usage across all iterations [1].
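Note that the snippet above already creates llm_usage before the while loop, and an accumulator reset on every iteration would keep only the last iteration's usage rather than the first. The reported symptom is also consistent with a merge step that adds token counts and component prices but never adds total_price, so the value copied from the first iteration is never updated; this thread does not confirm which cause applies. The following is a minimal sketch of a merge that sums every field, assuming a simplified, hypothetical Usage dataclass and accumulate helper rather than Dify's actual LLMUsage or increase_usage.

```python
from dataclasses import dataclass, replace
from decimal import Decimal
from typing import Optional


@dataclass
class Usage:
    """Hypothetical stand-in for a per-call usage record (not Dify's LLMUsage)."""
    prompt_tokens: int
    completion_tokens: int
    prompt_price: Decimal
    completion_price: Decimal
    total_price: Decimal


def accumulate(total: Optional[Usage], usage: Usage) -> Usage:
    """Merge one iteration's usage into the running total (hypothetical helper).

    Every field is summed, including total_price. A merge that skipped
    total_price would keep the first iteration's value for the whole run.
    """
    if total is None:
        # First iteration: copy the record so later in-place additions
        # do not mutate the caller's per-iteration object.
        return replace(usage)
    total.prompt_tokens += usage.prompt_tokens
    total.completion_tokens += usage.completion_tokens
    total.prompt_price += usage.prompt_price
    total.completion_price += usage.completion_price
    total.total_price += usage.total_price
    return total


# Usage: two hypothetical iterations of one agent run.
runs = [
    Usage(100, 20, Decimal("0.0010"), Decimal("0.0010"), Decimal("0.0020")),
    Usage(150, 40, Decimal("0.0015"), Decimal("0.0020"), Decimal("0.0035")),
]
total = None
for u in runs:
    total = accumulate(total, u)

print(total.total_price)  # 0.0055
```

Copying the first record instead of storing it directly keeps each iteration's own usage intact, which matters when the per-iteration record is also saved separately, as save_agent_thought does in the snippet above.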


PedroGomes02 commented 1 month ago

I have opened a PR for this issue: https://github.com/langgenius/dify/pull/6688