oneilsh opened this issue 11 months ago
Thanks for the kind words! I should note that the internals of the LLM providers (in particular OpenAI) are often a bit of a mystery, so Kani's token counting is really just a best guess, usually within a couple of percent.
You have a couple of options if you want to track tokens as accurately as possible, which I'll lay out here:

1. `Kani.get_model_completion` - this is the method the Kani instance uses to call the underlying LLM, and it returns a `Completion`, which includes the prompt token length and completion token length as returned by the engine. You could, for example, add `tokens_used_prompt` and `tokens_used_completion` attributes in a subclass of `Kani` and increment them after a super call (see the sketch below); this has the disadvantage of being post-hoc counting, though. I use a similar approach in one of my projects here: https://github.com/zhudotexe/kanpai/blob/cc603705d353e4e9b9aa3cf9fbb12e3a46652c55/kanpai/base_kani.py#L48
2. If you want a token estimate *before* sending the prompt to the LLM, you can compute `sum(self.message_token_len(m) for m in await self.get_prompt()) + self.engine.token_reserve + self.engine.function_token_reserve(list(self.functions.values()))`. The instance caches message token lengths, so this won't result in a major slowdown.
3. You could also count tokens outside of kani, e.g. by passing `api_base` and `headers` when constructing an `OpenAIEngine` to route requests through a proxy or gateway that logs usage. I've also been interested in Cloudflare AI Gateway, though I haven't used it yet. These solutions require a bit more engineering though, and I believe they're also only post-hoc.

I'll have to think a bit more about how to implement an official token counting interface if we decide to - maybe `Kani.prompt_len_estimate(msgs: list[ChatMessage]) -> int` to perform the estimation detailed above?
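A minimal sketch of the first option (the class and the `tokens_used_*` attribute names are just illustrative, not part of kani):

```python
from kani import Kani


class TokenCountingKani(Kani):
    """Accumulates token usage reported by the engine after each request (post-hoc counting)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokens_used_prompt = 0
        self.tokens_used_completion = 0

    async def get_model_completion(self, *args, **kwargs):
        completion = await super().get_model_completion(*args, **kwargs)
        # some engines may not report usage, so treat missing values as 0
        self.tokens_used_prompt += completion.prompt_tokens or 0
        self.tokens_used_completion += completion.completion_tokens or 0
        return completion
```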
Wonderful, thank you! Post-hoc is fine for my case; I used your first suggestion and it works great. I did need to remember to update the counts manually when calling out to sub-kanis. (Maybe engine-level counting?)
Hello again :) I've started playing with the newer streaming functionality, and it's been nice! However, it looks like `get_model_completion` isn't called when streaming, so the override described above isn't processing token counts. Perhaps the new `get_model_stream` could be overridden similarly? I suppose I could increment the completion token count as tokens are yielded, but I'm not sure how to get the prompt token count accurately.

Thanks as always-
Good callout - this is a little less elegant with streaming. The `get_model_stream` method is a little lower-level (it returns a mixed iterator which is managed by a `StreamManager`), with no guarantee that a `Completion` will be yielded or that each yield is exactly one token.

In the current version (1.0.1), your best option is probably to look at `await stream.completion()`, which should have the `prompt_tokens` and `completion_tokens` attributes set, with some caveats.
Example:

```python
stream = ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?")
async for token in stream:
    print(token, end="")

completion = await stream.completion()
prompt_tokens = completion.prompt_tokens
completion_tokens = completion.completion_tokens
# ...
# msg = await stream.message()
```
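If you want running totals in the meantime, a small wrapper over the stream works; this is only a sketch, and `tokens_used_prompt` / `tokens_used_completion` are the illustrative counter attributes from earlier in the thread, not kani attributes:

```python
async def stream_and_count(ai, stream):
    """Print a streamed response, then add its reported usage to the example counters on `ai`."""
    async for token in stream:
        print(token, end="")
    print()
    completion = await stream.completion()
    # usage may be None depending on the engine; treat missing values as 0
    ai.tokens_used_prompt += completion.prompt_tokens or 0
    ai.tokens_used_completion += completion.completion_tokens or 0
    return await stream.message()
```

Usage would then be something like `msg = await stream_and_count(ai, ai.chat_round_stream(...))`.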
In v1.0.0 I added a private `_add_completion_to_history` method that acts like `add_to_history`, but is only called on model completions with a `Completion` rather than on each message. I'll make this method public, since it's called on both stream and non-stream completions and token counting is a good use case for it.

I'll update this thread with new code snippets (probably later today?) once that's done.
As of v1.0.2, `Kani.add_completion_to_history` is called after each completion for both streaming and non-streaming generations. You can implement token counting by overriding it like so:
```python
class TokenCountingKani(Kani):
    # ...

    async def add_completion_to_history(self, completion):
        prompt_tokens = completion.prompt_tokens
        completion_tokens = completion.completion_tokens
        # ...
        return await super().add_completion_to_history(completion)
```
Note that `completion.prompt_tokens` and `completion.completion_tokens` might be `None` for user-implemented engines and llama.cpp streams (https://github.com/abetlen/llama-cpp-python/issues/1498), but they should be present for all kani-bundled OpenAI, Anthropic, and Hugging Face engines.
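Putting the override together with that caveat, a filled-in version might look something like the following - just a sketch, using the same illustrative counter attributes as before:

```python
from kani import Kani


class TokenCountingKani(Kani):
    """Tracks cumulative token usage reported by the engine, including streamed completions."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokens_used_prompt = 0
        self.tokens_used_completion = 0

    async def add_completion_to_history(self, completion):
        # usage can be None (user-implemented engines, llama.cpp streams), so guard before accumulating
        if completion.prompt_tokens is not None:
            self.tokens_used_prompt += completion.prompt_tokens
        if completion.completion_tokens is not None:
            self.tokens_used_completion += completion.completion_tokens
        return await super().add_completion_to_history(completion)
```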
Wonderful, I will give it a try, thank you!
I see that the API supports `.message_token_len()` for an individual `ChatMessage`; it would be nice to be able to query total token usage over the course of a conversation for cost-tracking purposes.

I'm not entirely sure of the best way to handle it - maybe something like a `.next_message_tokens_cost(message: ChatMessage)` that would return the total prompt tokens (system + function defs + chat history) plus the tokens in `message` that would be incurred? If it could be done over the course of a chat (accumulating after each full round), maybe something like `.conversation_history_total_prompt_tokens()` and `.conversation_history_total_response_tokens()` so a user could compute a running chat cost?

Thanks for considering, and for developing Kani! It really is the 'right' API interface to tool-enabled LLMs in my opinion :)