langchain-ai / langchain-google

Forced function calling does not work for Gemini Flash. #330

Open akos-sch opened 4 days ago

akos-sch commented 4 days ago

Here is a minimal script that reproduces the problem I ran into with forced function calling on Gemini Flash:

from langchain_core.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field

from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(
    model="gemini-1.5-flash-preview-0514",
    temperature=0,
    max_retries=2,
)

class SumResult(BaseModel):
    """Result of the sum tool."""
    result: int = Field(description="The sum of the two numbers.")

sum_llm = llm.with_structured_output(SumResult)

system = "You are an assistant helping in adding two numbers. You will be given two numbers and you need to add them together."
human = "Add {num_1} and {num_2}."

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", human),
    ]
)
chain = prompt | sum_llm
print(chain.invoke({"num_1": 3, "num_2": 5}))

As you can see, this setup exercises structured output parsing through with_structured_output. Running it raises the following error:

google.api_core.exceptions.InvalidArgument: 400 Unable to submit request because the forced function calling (mode = ANY) is only supported for Gemini 1.5 Pro models. Learn more: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling

According to this error, within the Gemini family on Vertex AI only 1.5 Pro supports forced function calling (mode = ANY).
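To illustrate that constraint, here is a rough sketch of a guard that could run before choosing forced function calling. The helper name `supports_forced_function_calling` and the prefix check on the model name are my own assumptions, derived only from the wording of the 400 error above; this is not something the library exposes, and the real set of supported models may differ.

```python
def supports_forced_function_calling(model_name: str) -> bool:
    """Hypothetical helper: guess whether a model accepts mode = ANY.

    Assumption: only models whose name starts with "gemini-1.5-pro"
    qualify, which is inferred from the error message in this issue.
    """
    return model_name.lower().startswith("gemini-1.5-pro")


if __name__ == "__main__":
    # The second name is the model used in the reproduction script.
    for name in ("gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"):
        print(name, supports_forced_function_calling(name))
```

With a check like this, the Flash model from the script would be routed away from forced function calling instead of hitting the 400 at request time.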

The error occurs with library versions 1.0.5 and 1.0.6. I verified that the same script works with 1.0.4.
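For anyone who wants to check whether they are on one of those versions, a minimal sketch follows. It assumes the installed distribution is named `langchain-google-vertexai`; the set of affected versions is only what I tested above, and the helper names are made up for illustration.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Versions reported as failing in this issue (1.0.4 was reported as working).
REPORTED_AFFECTED = {"1.0.5", "1.0.6"}


def is_reported_affected(installed: str) -> bool:
    """Hypothetical helper: True if the version is one reported as failing here."""
    return installed in REPORTED_AFFECTED


def installed_vertexai_version() -> Optional[str]:
    """Return the installed version string, or None if the package is absent."""
    try:
        return version("langchain-google-vertexai")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    current = installed_vertexai_version()
    if current is None:
        print("langchain-google-vertexai is not installed")
    else:
        print(current, "affected" if is_reported_affected(current) else "not in the reported set")
```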