Bedrock token count callbacks #20

Closed NAPTlME closed 1 month ago

NAPTlME commented 2 months ago

Updated both the BedrockLLM and ChatBedrock classes to yield token counts and stop reasons upon generation/call. This works for both streaming and non-streaming calls, and for both message-based and raw-text prompts.

The goal behind this is to take the input/output tokens and stop reasons directly from the Bedrock call and use them in a CallbackHandler on_llm_end.

Example use

from typing import Any
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationChain
from langchain.memory import ConversationTokenBufferMemory

from langchain_core.runnables import RunnableConfig

from langchain_aws.chat_models import ChatBedrock
from langchain_core.outputs import LLMResult

class BedrockHandler(BaseCallbackHandler):

    def __init__(self, initial_text=""):
        self.text = initial_text
        self.input_token_count = 0
        self.output_token_count = 0
        self.stop_reason = None

    def on_llm_new_token(self, token: str, **kwargs):
        self.text += token
        # do something

    def on_llm_end(
        self,
        response: LLMResult,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        # llm_output carries the usage and stop reason reported by Bedrock
        if response.llm_output is not None:
            self.input_token_count = response.llm_output.get("usage", {}).get("prompt_tokens", None)
            self.output_token_count = response.llm_output.get("usage", {}).get("completion_tokens", None)
            self.stop_reason = response.llm_output.get("stop_reason", None)

llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0", streaming=True)

memory = ConversationTokenBufferMemory(llm=llm)

chain = ConversationChain(llm=llm, memory=memory)

callback = BedrockHandler()

input_prompt = "Write an explanation of math in 3 sentences"
stop_sequences = ["\n\nHuman:"]

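# NOTE: the call was truncated in the original post. The input and config
# arguments below are a reconstruction from the otherwise unused names
# defined above (input_prompt, callback, RunnableConfig).
response = chain.invoke(
    {"input": input_prompt},
    config=RunnableConfig(callbacks=[callback]),
    max_tokens_to_sample=1024,
)

print(f"Input tokens: {callback.input_token_count}, Output tokens: {callback.output_token_count}, Stop reason: {callback.stop_reason}")
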
3coins commented 2 months ago

@NAPTlME Thanks for submitting this update. Please fix the lint and test errors from the CI.

NAPTlME commented 2 months ago

@3coins Apologies, I'm on Windows, so I have been running Ruff and the unit tests manually rather than via make. I also don't have the integration tests set up, but I have tested the changes with my current project. If anything further fails, I will look at my options for running the makefile.

NAPTlME commented 2 months ago

@3coins I went the WSL route to run the makefile. I fixed some references in the makefile that appear to be holdovers from when this was a part of langchain. Currently, all of those checks pass.

The only item I was unable to run was the integration test (I assume due to what I have provisioned using my AWS credentials).


NAPTlME commented 2 months ago

@3coins Hoping to kick off that workflow again. I'm expecting no further issues, but will address them if CI picks anything else up.

NAPTlME commented 2 months ago

@3coins Addressed minor (import and readme) conflicts introduced by recent changes to main.

NAPTlME commented 2 months ago

@3coins Are there any issues with merging this? (Or am I missing any part of your process for contributions?) Thanks.

DanielWhite95 commented 1 month ago

Hello, thanks for the suggestion. This could be useful to have in the package. Only one note: I think it should add to the token count on each llm_end rather than overwriting it with the new value. That would be useful when the application makes several calls to the model with the same callback.

NAPTlME commented 1 month ago

@DanielWhite95 Thanks for taking a look at it. For my particular use case we are logging each transaction to track costs/chargeback (thus the need to overwrite the token counts). We also call our logger from on_llm_end to log each call and its usage. In addition, we use the input + output token counts to know how many input tokens the context will account for in the next query.

That said, if you wanted to increment instead, this is only in the callback example, so you would modify the callback handler to do so:

        if response.llm_output is not None:
            self.input_token_count += response.llm_output.get("usage", {}).get("prompt_tokens", 0)
            self.output_token_count += response.llm_output.get("usage", {}).get("completion_tokens", 0)
            self.stop_reason = response.llm_output.get("stop_reason", None)

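
The two behaviours discussed above (overwriting per call for logging versus accumulating across calls) are not mutually exclusive. The following is a minimal sketch that keeps both, using only the standard library so it runs without AWS access. `FakeResult` and `UsageTracker` are hypothetical names, not part of langchain or langchain-aws, and the `llm_output` layout (`usage.prompt_tokens`, `usage.completion_tokens`, `stop_reason`) is assumed from the example in this PR. The token numbers and stop reasons are made-up sample values.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeResult:
    """Hypothetical stand-in for LLMResult; only llm_output is modelled."""
    llm_output: Optional[dict] = None


@dataclass
class UsageTracker:
    """Keeps the most recent call's usage and running totals across calls."""
    last_input_tokens: int = 0
    last_output_tokens: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    stop_reason: Optional[str] = None

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        if response.llm_output is None:
            return  # nothing reported for this call
        usage = response.llm_output.get("usage") or {}
        # `or 0` guards against a key that is present but set to None.
        self.last_input_tokens = usage.get("prompt_tokens") or 0
        self.last_output_tokens = usage.get("completion_tokens") or 0
        # Per-call values are overwritten above; totals are accumulated here.
        self.total_input_tokens += self.last_input_tokens
        self.total_output_tokens += self.last_output_tokens
        self.stop_reason = response.llm_output.get("stop_reason")


tracker = UsageTracker()
tracker.on_llm_end(FakeResult({
    "usage": {"prompt_tokens": 12, "completion_tokens": 30},
    "stop_reason": "end_turn",
}))
tracker.on_llm_end(FakeResult({
    "usage": {"prompt_tokens": 50, "completion_tokens": 8},
    "stop_reason": "max_tokens",
}))
print(tracker.last_input_tokens, tracker.total_input_tokens, tracker.total_output_tokens)
```

In a real handler the same body would sit inside a BaseCallbackHandler subclass, with the per-call values going to the logger and the totals kept for the lifetime of the callback.
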
NAPTlME commented 1 month ago

@3coins @efriis Hey. Not trying to bug you here, but I would like to know if it is possible to contribute to this project. If so, is there anything further you need from me for this PR? Thanks.

efriis commented 1 month ago

@3coins is your man!

3coins commented 1 month ago

@NAPTlME Thanks for making all the updates. I was out sick this past week, so I could not get to this earlier. Your code looks good overall, though I would prefer fewer nested blocks if possible. Also, there seem to be a lot of changes here for me to process, so give me until end of tomorrow to merge this.

Can you do one last update and convert your sample code into an integration/unit test?

3coins commented 1 month ago

@NAPTlME There are a few integration test errors with this PR. Can you check these and make sure they pass? I have added details here.

NAPTlME commented 1 month ago

@3coins Thanks for the update. Hope you are feeling better.

It looks like merge_dicts doesn't allow for int types. I looked into updating this to merge integers by addition, but I see that would break some tests in which integers are used in a nominal fashion.

To get around this, I am now putting "usage" token counts into lists and summing them when combining.

Feel free to kick off the workflow. I believe everything is good. Let me know if there are any further areas you feel need to be modified.
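
The list-based workaround described above can be sketched roughly as follows. This is not the langchain merge_dicts implementation or the code in this PR. The helper names (`wrap_usage`, `merge_usage`, `total_usage`) are hypothetical, and the sketch assumes only that lists are concatenated when chunks are merged, which is the behaviour the workaround relies on. The chunk values are made-up sample data.

```python
from typing import Dict, List


def wrap_usage(usage: Dict[str, int]) -> Dict[str, List[int]]:
    """Wrap each count in a one-element list so merging never adds ints."""
    return {key: [value] for key, value in usage.items()}


def merge_usage(
    left: Dict[str, List[int]], right: Dict[str, List[int]]
) -> Dict[str, List[int]]:
    """Merge two wrapped usage dicts by concatenating their per-key lists."""
    merged = {key: list(values) for key, values in left.items()}
    for key, values in right.items():
        merged.setdefault(key, []).extend(values)
    return merged


def total_usage(wrapped: Dict[str, List[int]]) -> Dict[str, int]:
    """Collapse the lists into final counts once all chunks are combined."""
    return {key: sum(values) for key, values in wrapped.items()}


# Usage as it might arrive across three streamed chunks (sample data).
chunks = [
    {"prompt_tokens": 25, "completion_tokens": 0},
    {"prompt_tokens": 0, "completion_tokens": 7},
    {"prompt_tokens": 0, "completion_tokens": 5},
]

combined: Dict[str, List[int]] = {}
for chunk in chunks:
    combined = merge_usage(combined, wrap_usage(chunk))

totals = total_usage(combined)
print(combined)
print(totals)
```

Deferring the sum until the end keeps the merge step free of integer arithmetic, so integers used as labels elsewhere are unaffected.
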