run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: token_counting.get_tokens_from_response dies trying to convert response.raw (None) to dict #15787

Open Shura1oplot opened 1 week ago

Shura1oplot commented 1 week ago

Bug Description

TokenCountingHandler dies trying to calculate the token count (get_tokens_from_response) for a response produced by MockLLM.

MockLLM.complete produces a CompletionResponse with only the text parameter set:

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response_text = (
            self._generate_text(self.max_tokens) if self.max_tokens else prompt
        )

        return CompletionResponse(
            text=response_text,  # <------------------
        )

get_tokens_from_response then tries to convert CompletionResponse.raw, which is None, to a dict and dies:

def get_tokens_from_response(
    response: Union["CompletionResponse", "ChatResponse"]
) -> Tuple[int, int]:
    """Get the token counts from a raw response."""

    raw_response = response.raw
    if not isinstance(raw_response, dict):
        raw_response = dict(raw_response)  # <------------------
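
For illustration, a minimal standalone sketch of the failure (assuming llama-index-core 0.11.x; the import path for CompletionResponse is an assumption):

from llama_index.core.llms import CompletionResponse

resp = CompletionResponse(text="mock output")
print(resp.raw)  # None -- MockLLM never sets raw

try:
    dict(resp.raw)
except TypeError as exc:
    print(exc)  # 'NoneType' object is not iterable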

Possible fix: replace raw_response = dict(raw_response) with:

try:
    raw_response = dict(raw_response)
except (TypeError, ValueError):
    raw_response = {}

Version

0.11.3

Steps to Reproduce

llm = MockLLM(max_tokens=512)
embed_model = MockEmbedding(embed_dim=3072)

query_engine = index.as_query_engine(
    llm=llm,
    embed_model=embed_model,
    text_qa_template=text_qa_template,  # ChatPromptTemplate
    refine_template=refine_template,    # ChatPromptTemplate
)

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)
query_engine.callback_manager.add_handler(token_counter)

query_engine.query(prompt)

print(token_counter.completion_llm_token_count)
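
For reference, a fully self-contained variant of the repro; the in-memory index, the placeholder document text, and the query string are additions for illustration, and the import paths assume llama-index-core 0.11.x:

import tiktoken

from llama_index.core import Document, MockEmbedding, VectorStoreIndex
from llama_index.core.callbacks import TokenCountingHandler
from llama_index.core.llms import MockLLM

llm = MockLLM(max_tokens=512)
embed_model = MockEmbedding(embed_dim=3072)

# Throwaway index over a single placeholder document
index = VectorStoreIndex.from_documents(
    [Document(text="Some placeholder text to index.")],
    embed_model=embed_model,
)

token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)

query_engine = index.as_query_engine(llm=llm)
query_engine.callback_manager.add_handler(token_counter)

# Raises TypeError: 'NoneType' object is not iterable inside
# get_tokens_from_response, because MockLLM leaves CompletionResponse.raw as None
query_engine.query("What does the document say?")
print(token_counter.completion_llm_token_count)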

Relevant Logs/Tracebacks

Traceback (most recent call last):
  File "...\Lib\site-packages\gradio\queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\gradio\route_utils.py", line 321, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\gradio\blocks.py", line 1935, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\gradio\blocks.py", line 1520, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\anyio\to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\anyio\_backends\_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
           ^^^^^^^^^^^^
  File "...\Lib\site-packages\anyio\_backends\_asyncio.py", line 859, in run
    result = context.run(func, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\gradio\utils.py", line 826, in wrapper
    response = f(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^
  File "C:\Users\s0meu\Dropbox\Wisdom\gradio_query.py", line 332, in index_query
    query_engine.query(prompt)
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\base\base_query_engine.py", line 52, in query
    query_result = self._query(str_or_query_bundle)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\query_engine\retriever_query_engine.py", line 176, in _query
    response = self._response_synthesizer.synthesize(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\response_synthesizers\base.py", line 241, in synthesize
    response_str = self.get_response(
                   ^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\response_synthesizers\compact_and_refine.py", line 43, in get_response
    return super().get_response(
           ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\response_synthesizers\refine.py", line 172, in get_response
    response = self._give_response_single(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\response_synthesizers\refine.py", line 227, in _give_response_single
    program(
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\response_synthesizers\refine.py", line 79, in __call__
    answer = self._llm.predict(
             ^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\llms\llm.py", line 579, in predict
    response = self.complete(formatted_prompt, formatted=True)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\llms\callbacks.py", line 431, in wrapped_llm_predict
    f_return_val = f(_self, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\s0meu\Dropbox\Wisdom\gradio_query.py", line 255, in complete
    completion_response = super().complete(*args, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\instrumentation\dispatcher.py", line 261, in wrapper
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\llms\callbacks.py", line 485, in wrapped_llm_predict
    callback_manager.on_event_end(
  File "...\Lib\site-packages\llama_index\core\callbacks\base.py", line 136, in on_event_end
    handler.on_event_end(event_type, payload, event_id=event_id, **kwargs)
  File "...\Lib\site-packages\llama_index\core\callbacks\token_counting.py", line 206, in on_event_end
    get_llm_token_counts(
  File "...\Lib\site-packages\llama_index\core\callbacks\token_counting.py", line 87, in get_llm_token_counts
    prompt_tokens, completion_tokens = get_tokens_from_response(completion)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\Lib\site-packages\llama_index\core\callbacks\token_counting.py", line 46, in get_tokens_from_response
    raw_response = dict(raw_response)
                   ^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable
dosubot[bot] commented 1 week ago

Hello, @Shura1oplot! I'm here to help you with any bugs, questions, or contributions you have.

To fix the TypeError: 'NoneType' object is not iterable in token_counting.get_tokens_from_response, you need to add a check for the case where CompletionResponse.raw is None. Here is an example of how you can modify the code to include this check:

def get_tokens_from_response(response: Union["CompletionResponse", "ChatResponse"]) -> Tuple[int, int]:
    """Get the token counts from a raw response."""
    raw_response = response.raw
    if raw_response is None:
        return 0, 0  # Handle the None case appropriately

    if not isinstance(raw_response, dict):
        raw_response = dict(raw_response)

    usage = raw_response.get("usage", {})
    if usage is None:
        usage = response.additional_kwargs

    if not usage:
        return 0, 0

    if not isinstance(usage, dict):
        usage = usage.model_dump()

    possible_input_keys = ("prompt_tokens", "input_tokens")
    possible_output_keys = ("completion_tokens", "output_tokens")

    prompt_tokens = 0
    for input_key in possible_input_keys:
        if input_key in usage:
            prompt_tokens = usage[input_key]
            break

    completion_tokens = 0
    for output_key in possible_output_keys:
        if output_key in usage:
            completion_tokens = usage[output_key]
            break

    return prompt_tokens, completion_tokens

This modification ensures that if response.raw is None, the function returns (0, 0) instead of raising the TypeError [1].
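
As a quick sanity check of that patch (a sketch, assuming the patched helper stays importable from llama_index.core.callbacks.token_counting and that CompletionResponse is importable as below):

from llama_index.core.callbacks.token_counting import get_tokens_from_response
from llama_index.core.llms import CompletionResponse

# A MockLLM-style response: text only, raw left as None
resp = CompletionResponse(text="mock output")

# With the early-return guard in place this yields (0, 0) instead of raising
assert get_tokens_from_response(resp) == (0, 0)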

To continue talking to Dosu, mention @dosu.

logan-markewich commented 1 week ago

@Shura1oplot this is actually already fixed, we just need to make a release: https://github.com/run-llama/llama_index/blob/1d4d2c653da27b7ac7d2cf53701a73ea7e5d51bb/llama-index-core/llama_index/core/callbacks/token_counting.py#L45

We recently added mypy running on core again, which caught a lot of small bugs.