zhudotexe / kani

kani (カニ) is a highly hackable microframework for chat-based language models with tool use/function calling. (NLP-OSS @ EMNLP 2023)
https://kani.readthedocs.io
MIT License
556 stars, 30 forks

Tracking token usage? #29

Open oneilsh opened 10 months ago

oneilsh commented 10 months ago

I see that the API supports .message_token_len() for an individual ChatMessage; it would be nice to be able to query total token usage over the course of a conversation for cost tracking purposes.

I'm not entirely sure of the best way to handle it. Maybe something like .next_message_tokens_cost(message: ChatMessage), which would return the total prompt tokens (system + function defs + chat history) plus the tokens in message that would be incurred? If it could be done over the course of a chat (accumulating after each full round), maybe something like .conversation_history_total_prompt_tokens() and .conversation_history_total_response_tokens(), so a user could compute a running chat cost?
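
A running chat cost of this kind comes down to multiplying accumulated prompt and response token counts by per-token prices. A minimal sketch, assuming a hypothetical RoundUsage record per chat round and caller-supplied per-1K-token prices (the numbers below are placeholders, not any provider's actual rates):

```python
from dataclasses import dataclass


@dataclass
class RoundUsage:
    """Hypothetical record of one full chat round's token usage."""

    prompt_tokens: int
    completion_tokens: int


def running_cost(rounds, prompt_price_per_1k, completion_price_per_1k):
    """Return the cumulative cost after each round.

    The prices are caller-supplied placeholders; real rates depend on the
    provider and model.
    """
    total = 0.0
    costs = []
    for r in rounds:
        total += r.prompt_tokens / 1000 * prompt_price_per_1k
        total += r.completion_tokens / 1000 * completion_price_per_1k
        costs.append(total)
    return costs


# Each round's prompt re-sends the system prompt, function definitions and
# full history, so prompt tokens tend to grow from round to round.
rounds = [RoundUsage(1200, 150), RoundUsage(1400, 200)]
print(running_cost(rounds, prompt_price_per_1k=0.5, completion_price_per_1k=1.5))
```

Returning the cumulative total after every round (rather than only the final sum) is what makes a "running" cost display possible.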

Thanks for considering, and for developing Kani! It really is the 'right' API interface to tool-enabled LLMs in my opinion :)

zhudotexe commented 10 months ago

Thanks for the kind words! I should note that the internals of LLM providers (in particular OpenAI) are often a bit of a mystery, so Kani's token counting is really just a best guess, accurate to within a couple of percent.

You have a couple options if you want to track tokens as accurately as possible, which I'll lay out here:

  1. Override Kani.get_model_completion: this is the method the Kani instance uses to call the underlying LLM, and it returns a Completion, which includes the prompt and completion token counts as returned by the engine. You could, for example, add tokens_used_prompt and tokens_used_completion attributes in a subclass of Kani and increment them after a super call; the disadvantage is that this counting is post hoc (it happens only after the request has been made). I use a similar approach in one of my projects here: https://github.com/zhudotexe/kanpai/blob/cc603705d353e4e9b9aa3cf9fbb12e3a46652c55/kanpai/base_kani.py#L48
    1. You could also use an estimation like sum(self.message_token_len(m) for m in await self.get_prompt()) + self.engine.token_reserve + self.engine.function_token_reserve(list(self.functions.values())) if you wanted a token estimation before sending it to the LLM. The instance caches the token lengths so this won't result in a major slowdown.
  2. Use an external gateway: in our lab we've been trying out Helicone for token counting. If you're using OpenAI, you can integrate it with Kani pretty easily by specifying the api_base and headers when constructing an OpenAIEngine. I've also been interested in Cloudflare AI Gateway, though I haven't used it yet. These solutions do require a bit more engineering, though, and I believe they are also post-hoc only.
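
Option 1 can be sketched as follows. To keep it runnable without kani installed, StandInKani and Completion are hypothetical stand-ins for kani's Kani and Completion; in real code you would subclass kani.Kani directly. The attribute names tokens_used_prompt and tokens_used_completion follow the suggestion in option 1, and arguments are passed straight through to the parent call so the sketch does not depend on the exact signature.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Completion:
    """Stand-in for kani's Completion: only the two token fields matter here."""

    text: str
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None


class StandInKani:
    """Hypothetical stand-in for Kani; returns a canned completion."""

    async def get_model_completion(self, *args, **kwargs) -> Completion:
        return Completion("Hello!", prompt_tokens=42, completion_tokens=3)


class TokenCountingKani(StandInKani):
    """Post-hoc counting: add the engine-reported counts after each call."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokens_used_prompt = 0
        self.tokens_used_completion = 0

    async def get_model_completion(self, *args, **kwargs) -> Completion:
        # Let the parent call the LLM, then read the counts it reported.
        completion = await super().get_model_completion(*args, **kwargs)
        # Some engines may not report counts; treat missing values as 0.
        self.tokens_used_prompt += completion.prompt_tokens or 0
        self.tokens_used_completion += completion.completion_tokens or 0
        return completion


async def main():
    ai = TokenCountingKani()
    await ai.get_model_completion()
    await ai.get_model_completion()
    print(ai.tokens_used_prompt, ai.tokens_used_completion)  # 84 6


asyncio.run(main())
```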

I'll have to think a bit more about how to implement an official token counting interface if we decide to - maybe Kani.prompt_len_estimate(msgs: list[ChatMessage]) -> int to perform the estimation detailed above?
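
The pre-request estimate from the sub-item under option 1, packaged as a prompt_len_estimate(msgs) -> int method, might look roughly like this. It is a hypothetical sketch, not an existing kani method: ChatMessage, StandInEngine and PromptEstimator are stand-ins, and the word-count "tokenizer" and reserve values are placeholders chosen only to keep the arithmetic easy to follow.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ChatMessage:
    """Stand-in for kani's ChatMessage: just a role and some text."""

    role: str
    content: str


class StandInEngine:
    """Hypothetical engine. The 'tokenizer' and reserve values are placeholders."""

    token_reserve = 7  # placeholder per-request overhead

    def message_len(self, message: ChatMessage) -> int:
        # Crude proxy: whitespace-separated words plus a fixed per-message cost.
        # A real engine would use the provider's tokenizer.
        return len(message.content.split()) + 4

    def function_token_reserve(self, functions: list) -> int:
        # Placeholder: a flat cost per exposed function definition.
        return 20 * len(functions)


class PromptEstimator:
    """Hypothetical home for a prompt_len_estimate(msgs) -> int method."""

    def __init__(self, engine: StandInEngine, functions=()):
        self.engine = engine
        self.functions = list(functions)
        self._len_cache = {}

    def message_token_len(self, message: ChatMessage) -> int:
        # Cache each message's length so repeated estimates stay cheap.
        if message not in self._len_cache:
            self._len_cache[message] = self.engine.message_len(message)
        return self._len_cache[message]

    def prompt_len_estimate(self, msgs: list) -> int:
        # Messages + fixed token reserve + reserve for function definitions.
        return (
            sum(self.message_token_len(m) for m in msgs)
            + self.engine.token_reserve
            + self.engine.function_token_reserve(self.functions)
        )


msgs = [
    ChatMessage("system", "You are a helpful assistant."),
    ChatMessage("user", "What is the airspeed velocity?"),
]
print(PromptEstimator(StandInEngine()).prompt_len_estimate(msgs))  # 25
```

With one function definition exposed (functions=["get_weather"]), the same messages would estimate to 45 under these placeholder numbers.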

oneilsh commented 10 months ago

Wonderful, thank you! Post-hoc is fine for my case, I used your first suggestion and it works great. I did need to remember to update the counts manually when calling out to sub-kanis. (Maybe engine-level counting?)
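
Engine-level counting could look roughly like this: if the parent kani and its sub-kanis are all constructed with the same engine object, every request passes through one place, so no manual bookkeeping is needed per sub-kani. This is a hypothetical sketch, not part of kani: Completion and StandInEngine are stand-ins, and it assumes the wrapped engine exposes an async predict(messages, ...) method that returns a completion with prompt_tokens and completion_tokens. Streaming requests would need the same treatment in whatever method the engine uses to stream.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Completion:
    """Stand-in for kani's Completion."""

    text: str
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None


class StandInEngine:
    """Hypothetical engine that returns canned completions."""

    max_context_size = 4096

    async def predict(self, messages, *args, **kwargs) -> Completion:
        return Completion("ok", prompt_tokens=10 * len(messages), completion_tokens=2)


class CountingEngine:
    """Wraps another engine and totals the token counts it reports."""

    def __init__(self, inner):
        self.inner = inner
        self.prompt_tokens = 0
        self.completion_tokens = 0

    async def predict(self, messages, *args, **kwargs) -> Completion:
        completion = await self.inner.predict(messages, *args, **kwargs)
        # Missing counts (None) are added as 0.
        self.prompt_tokens += completion.prompt_tokens or 0
        self.completion_tokens += completion.completion_tokens or 0
        return completion

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself:
        # forward them (e.g. max_context_size) to the wrapped engine.
        return getattr(self.inner, name)


async def main():
    shared = CountingEngine(StandInEngine())
    # Imagine a parent kani and a sub-kani both built with `shared`.
    await shared.predict(["system", "user"])          # parent's request
    await shared.predict(["system", "user", "tool"])  # sub-kani's request
    print(shared.prompt_tokens, shared.completion_tokens)  # 50 4


asyncio.run(main())
```

Forwarding unknown attributes keeps the wrapper usable wherever the original engine's other attributes are read, though a real engine may have further methods that would need wrapping too.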

oneilsh commented 4 months ago

Hello again :) I've started playing with the newer streaming functionality, and it's been nice! However, it looks like get_model_completion isn't called when streaming, so the override described above doesn't see those token counts. Perhaps the new get_model_stream could be overridden similarly? I suppose I could increment the completion token count as tokens are yielded, but I'm not sure how to get an accurate prompt token count.

Thanks as always-

zhudotexe commented 4 months ago

Good call-out; this is a little less elegant with streaming. The get_model_stream method is lower-level (it returns a mixed iterator that is managed by a StreamManager), with no guarantee that a Completion will be yielded or that each yield is exactly one token.
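
To see why counting yields is unreliable, here is a hypothetical stand-in for such a mixed stream (it is not kani's actual get_model_stream or StreamManager): it yields text chunks that may span several tokens, and may or may not end with a completion object carrying the engine's counts. The token numbers in the stand-in are made up for illustration.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Completion:
    """Stand-in carrying engine-reported counts."""

    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None


async def stand_in_stream(with_completion: bool):
    """Hypothetical mixed stream: text chunks, then possibly a Completion."""
    for chunk in ["The air", "speed is ", "unknown."]:
        yield chunk
    if with_completion:
        yield Completion(prompt_tokens=18, completion_tokens=6)


async def consume(stream):
    """Split a mixed stream into its text and its (optional) Completion."""
    parts = []
    completion = None
    async for item in stream:
        if isinstance(item, Completion):
            completion = item
        else:
            parts.append(item)
    return "".join(parts), len(parts), completion


async def main():
    text, n_chunks, completion = await consume(stand_in_stream(with_completion=True))
    print(text)  # The airspeed is unknown.
    # 3 chunks vs. 6 reported completion tokens: counting yields undercounts here.
    print(n_chunks, completion.completion_tokens)  # 3 6
    _, _, missing = await consume(stand_in_stream(with_completion=False))
    print(missing)  # None: no counts available for this stream


asyncio.run(main())
```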

In the current version (1.0.1), your best option is probably to look at await stream.completion(), which should have the prompt_tokens and completion_tokens attributes set, with some caveats:

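Example (run inside an async function, where ai is a Kani instance):

# stream the reply; each yielded piece is a text chunk, not necessarily one token
stream = ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?")
async for token in stream:
    print(token, end="")
# once the stream is consumed, the Completion holds the engine-reported counts
completion = await stream.completion()
prompt_tokens = completion.prompt_tokens
completion_tokens = completion.completion_tokens
# ...
# msg = await stream.message()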

In v1.0.0 I added a private _add_completion_to_history method that acts like add_to_history, but is only called on model completions with a Completion rather than on each message. I'll make this method public since it's called on both stream and non-stream completions and token counting is a good use case for it.

I'll update this thread with new code snippets (probably later today?) once that's done.

zhudotexe commented 4 months ago

As of v1.0.2, Kani.add_completion_to_history is called after each completion for both streaming and non-streaming generations. You can implement token counting by overriding it like so:

from kani import Kani

class TokenCountingKani(Kani):
    # ...

    async def add_completion_to_history(self, completion):
        # counts reported by the engine for this completion (may be None, see note below)
        prompt_tokens = completion.prompt_tokens
        completion_tokens = completion.completion_tokens
        # ...
        return await super().add_completion_to_history(completion)

Note that completion.[prompt|completion]_tokens might be None for user-implemented engines and llama.cpp streams (https://github.com/abetlen/llama-cpp-python/issues/1498), but it should be present for all kani-bundled OpenAI, Anthropic, and Hugging Face engines.
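
A fuller version of this override might keep running totals and also record how many completions arrived without counts, so that any undercount in the totals stays visible. As before, StandInKani and Completion are hypothetical stand-ins so the sketch runs on its own; with kani installed you would subclass kani.Kani instead. The attribute names are illustrative.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Completion:
    """Stand-in for kani's Completion."""

    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None


class StandInKani:
    """Hypothetical stand-in whose hook just records the completion."""

    def __init__(self):
        self.chat_history = []

    async def add_completion_to_history(self, completion):
        self.chat_history.append(completion)
        return completion


class TokenCountingKani(StandInKani):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt_tokens = 0
        self.completion_tokens = 0
        # Completions whose engine reported no counts (e.g. a llama.cpp stream).
        self.uncounted_completions = 0

    async def add_completion_to_history(self, completion):
        if completion.prompt_tokens is None or completion.completion_tokens is None:
            self.uncounted_completions += 1
        # Add whatever was reported; a None contributes 0.
        self.prompt_tokens += completion.prompt_tokens or 0
        self.completion_tokens += completion.completion_tokens or 0
        return await super().add_completion_to_history(completion)


async def main():
    ai = TokenCountingKani()
    await ai.add_completion_to_history(Completion(120, 30))
    await ai.add_completion_to_history(Completion())  # no counts reported
    print(ai.prompt_tokens, ai.completion_tokens, ai.uncounted_completions)  # 120 30 1


asyncio.run(main())
```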

oneilsh commented 4 months ago

Wonderful, I will give it a try, thank you!