langchain-ai / langchain-google

MIT License
117 stars 149 forks source link

Usage Metadata returns wrong input_tokens for streamed output #576

Closed boriswang01 closed 3 weeks ago

boriswang01 commented 3 weeks ago
import asyncio
import os

from langchain.schema import HumanMessage
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)

# Set your Google API key
# NOTE(review): "key" is a placeholder — replace with a real API key to reproduce.
os.environ["GOOGLE_API_KEY"] = "key"

# Initialize the ChatGoogleGenerativeAI model
# All four harm categories are set to BLOCK_NONE so safety filtering
# cannot truncate the streamed response mid-generation.
model = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro-002",
    temperature=0.2,
    max_output_tokens=2000,
    safety_settings={
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }
)

# Multi-part prompt chosen to elicit a long (multi-chunk) streamed answer,
# which is the condition under which the inflated input_tokens appears.
message_content = """
What are the three primary colors?

Additionally, can you explain:
1. How these colors are used in art?
2. What secondary colors can be created from them?
3. Are there any differences between primary colors in light vs. pigment?

Please provide a detailed explanation for each point.
"""

async def process_message():
    """Stream one chat completion and compare token accounting.

    Prints the model-reported ``usage_metadata`` taken from the final
    stream event, then prints locally computed token counts (via
    ``model.get_num_tokens``) for the prompt and the generated text so
    the two can be compared side by side.
    """
    messages = [HumanMessage(content=message_content)]
    prompt_tokens = model.get_num_tokens(message_content)
    completion_tokens = 0

    async for event in model.astream_events(
            messages,
            version="v2",
    ):
        # Only the terminal event carries the full output with usage data.
        if event["event"] != "on_chat_model_end":
            continue

        final_output = event["data"]["output"]
        print(f"usage_metadata: {final_output.usage_metadata}")

        # Re-count the generated text locally for comparison.
        completion_tokens = model.get_num_tokens(final_output.content)

    # Print usage data
    print("\nUsage Data:")
    print(f"Input Tokens: {prompt_tokens}")
    print(f"Output Tokens: {completion_tokens}")
    print(f"Total Tokens: {prompt_tokens + completion_tokens}")

async def main():
    """Async entry point: run the streaming token-count comparison."""
    await process_message()


if __name__ == "__main__":
    asyncio.run(main())

When run, it returns the following:

usage_metadata: {'input_tokens': 990, 'output_tokens': 662, 'total_tokens': 1652}

Usage Data:
Input Tokens: 65
Output Tokens: 662
Total Tokens: 727

The input_tokens returned by usage_metadata is overly inflated compared with the correct Input Tokens value shown under "Usage Data".

Currently using langchain-google-genai version 2.0.1.