spring-projects / spring-ai

An Application Framework for AI Engineering
https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/index.html
Apache License 2.0

Vertex Gemini: Unable to submit request because function parameters schema should be of type OBJECT #1373

Open coderphonui opened 3 days ago

coderphonui commented 3 days ago

Bug description

I am using a snapshot build from the master branch, and I get this error when trying to make a function call with Vertex Gemini.

Full error log:

java.lang.RuntimeException: Failed to generate content

    at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.getContentResponse(VertexAiGeminiChatModel.java:532)
    at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.call(VertexAiGeminiChatModel.java:173)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultChatClientRequestSpec$2.aroundCall(DefaultChatClient.java:722)
    at org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain.lambda$nextAroundCall$1(DefaultAroundAdvisorChain.java:92)
    at io.micrometer.observation.Observation.observe(Observation.java:565)
    at org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain.nextAroundCall(DefaultAroundAdvisorChain.java:92)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.doGetChatResponse(DefaultChatClient.java:372)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.lambda$doGetObservableChatResponse$1(DefaultChatClient.java:342)
    at io.micrometer.observation.Observation.observe(Observation.java:565)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.doGetObservableChatResponse(DefaultChatClient.java:341)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.doGetChatResponse(DefaultChatClient.java:329)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.chatResponse(DefaultChatClient.java:389)
    at com.cdpn.agentcore.it.ReproduceVertexGeminiBugTest.reproduceFunctionCallingBug(ReproduceVertexGeminiBugTest.java:19)
    at java.base/java.lang.reflect.Method.invoke(Method.java:568)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
Caused by: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Unable to submit request because function parameters schema should be of type OBJECT. Learn more: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling
    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:92)
    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:41)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:86)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:66)
    at com.google.api.gax.grpc.GrpcExceptionCallable$ExceptionTransformingFuture.onFailure(GrpcExceptionCallable.java:97)
    at com.google.api.core.ApiFutures$1.onFailure(ApiFutures.java:84)
    at com.google.common.util.concurrent.Futures$CallbackListener.run(Futures.java:1130)
    at com.google.common.util.concurrent.DirectExecutor.execute(DirectExecutor.java:31)
    at com.google.common.util.concurrent.AbstractFuture.executeListener(AbstractFuture.java:1298)
    at com.google.common.util.concurrent.AbstractFuture.complete(AbstractFuture.java:1059)
    at com.google.common.util.concurrent.AbstractFuture.setException(AbstractFuture.java:809)
    at io.grpc.stub.ClientCalls$GrpcFuture.setException(ClientCalls.java:568)
    at io.grpc.stub.ClientCalls$UnaryStreamToFuture.onClose(ClientCalls.java:538)
    at io.grpc.PartialForwardingClientCallListener.onClose(PartialForwardingClientCallListener.java:39)
    at io.grpc.ForwardingClientCallListener.onClose(ForwardingClientCallListener.java:23)
    at io.grpc.ForwardingClientCallListener$SimpleForwardingClientCallListener.onClose(ForwardingClientCallListener.java:40)
    at com.google.api.gax.grpc.ChannelPool$ReleasingClientCall$1.onClose(ChannelPool.java:570)
    at io.grpc.internal.DelayedClientCall$DelayedListener$3.run(DelayedClientCall.java:489)
    at io.grpc.internal.DelayedClientCall$DelayedListener.delayOrExecute(DelayedClientCall.java:453)
    at io.grpc.internal.DelayedClientCall$DelayedListener.onClose(DelayedClientCall.java:486)
    at io.grpc.internal.ClientCallImpl.closeObserver(ClientCallImpl.java:574)
    at io.grpc.internal.ClientCallImpl.access$300(ClientCallImpl.java:72)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInternal(ClientCallImpl.java:742)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInContext(ClientCallImpl.java:723)
    at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
    at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
    at java.base/java.lang.Thread.run(Thread.java:840)
    Suppressed: com.google.api.gax.rpc.AsyncTaskException: Asynchronous task failed
        at com.google.api.gax.rpc.ApiExceptions.callAndTranslateApiException(ApiExceptions.java:57)
        at com.google.api.gax.rpc.UnaryCallable.call(UnaryCallable.java:112)
        at com.google.cloud.vertexai.generativeai.GenerativeModel.generateContent(GenerativeModel.java:400)
        at com.google.cloud.vertexai.generativeai.GenerativeModel.generateContent(GenerativeModel.java:387)
        at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.getContentResponse(VertexAiGeminiChatModel.java:529)
        at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.call(VertexAiGeminiChatModel.java:173)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultChatClientRequestSpec$2.aroundCall(DefaultChatClient.java:722)
        at org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain.lambda$nextAroundCall$1(DefaultAroundAdvisorChain.java:92)
        at io.micrometer.observation.Observation.observe(Observation.java:565)
        at org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain.nextAroundCall(DefaultAroundAdvisorChain.java:92)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.doGetChatResponse(DefaultChatClient.java:372)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.lambda$doGetObservableChatResponse$1(DefaultChatClient.java:342)
        at io.micrometer.observation.Observation.observe(Observation.java:565)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.doGetObservableChatResponse(DefaultChatClient.java:341)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.doGetChatResponse(DefaultChatClient.java:329)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallResponseSpec.chatResponse(DefaultChatClient.java:389)
        at com.cdpn.agentcore.it.ReproduceVertexGeminiBugTest.reproduceFunctionCallingBug(ReproduceVertexGeminiBugTest.java:19)
        at java.base/java.lang.reflect.Method.invoke(Method.java:568)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
Caused by: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Unable to submit request because function parameters schema should be of type OBJECT. Learn more: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling
    at io.grpc.Status.asRuntimeException(Status.java:533)
    ... 17 more

Environment

Spring AI snapshot version built from master branch.

Steps to reproduce

Code to reproduce:

Test case

import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

@SpringBootTest
public class ReproduceVertexGeminiBugTest {
    @Autowired
    private VertexAiGeminiChatModel chatModel;

    @Test
    public void reproduceFunctionCallingBug() {
        ChatClient.Builder clientBuilder = ChatClient.builder(chatModel);
        clientBuilder.defaultFunctions("weatherFunction");
        ChatClient chatClient = clientBuilder.build();
        chatClient.prompt().user("What is the weather in New York in Celsius?").call().chatResponse();
    }
}

Configuration

import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class FunctionCallingTestConfig {
    @Bean
    public FunctionCallback weatherFunctionInfo() {
        return FunctionCallbackWrapper.builder(new MockWeatherService())
                .withName("weatherFunction")
                .withDescription("Get the weather in location")
                .build();
    }

//  Note that if I use this way of configuration, it will work normally. However, using FunctionCallback will throw the exception.
//    @Bean
//    @Description("Get the weather in location")
//    public Function<MockWeatherService.Request, MockWeatherService.Response> weatherFunction() {
//        return new MockWeatherService();
//    }

}

MockWeatherService

import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;

import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;

@JsonClassDescription("Get the weather in location")
public class MockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> {

    /**
     * Weather Function request.
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    @JsonClassDescription("Weather API request")
    public record Request(@JsonProperty(required = true,
            value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location,
                          @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) {
    }

    /**
     * Temperature units.
     */
    public enum Unit {

        /**
         * Celsius.
         */
        C("metric"),
        /**
         * Fahrenheit.
         */
        F("imperial");

        /**
         * Human readable unit name.
         */
        public final String unitName;

        Unit(String text) {
            this.unitName = text;
        }

    }

    /**
     * Weather Function response.
     */
    public record Response(double temp, double feels_like, double temp_min, double temp_max, int pressure, int humidity,
                           Unit unit) {
    }

    @Override
    public Response apply(Request request) {
        double temperature = ThreadLocalRandom.current().nextDouble(15, 31);
        Unit unit = request.unit();
        int humidity = ThreadLocalRandom.current().nextInt(50, 101);
        return new Response(temperature, 15, 20, 2, 53, humidity, unit);
    }

}

Expected behavior

The function call should work normally when the function is registered via FunctionCallbackWrapper.

coderphonui commented 1 day ago

Adding

.withSchemaType(FunctionCallbackWrapper.Builder.SchemaType.OPEN_API_SCHEMA)

to the builder fixes the issue: Vertex Gemini requires an OpenAPI-style parameters schema whose top-level type is OBJECT, while the wrapper defaults to JSON_SCHEMA.

The final code should be like this:

    @Bean
    public FunctionCallback weatherFunctionInfo() {
        return FunctionCallbackWrapper.builder(new MockWeatherService())
                .withName("weatherFunction")
                .withDescription("Get the weather in location")
                .withSchemaType(FunctionCallbackWrapper.Builder.SchemaType.OPEN_API_SCHEMA)
                .build();
    }

I see this is mentioned in the documentation: https://docs.spring.io/spring-ai/reference/api/chat/functions/vertexai-gemini-chat-functions.html

However, I am not satisfied with this solution. In my opinion, FunctionCallbackWrapper should be abstract enough to handle schema compatibility for each model.

To map this to my scenario:

I am building an agent system that lets users configure agents with a prompt and an LLM. I also provide some built-in functions, which I load into Spring AI using FunctionCallbackWrapper. When implementing these built-in functions, I don't know which LLM the user will configure for the agent that uses them. Fortunately, OPEN_API_SCHEMA covers most cases today, but what happens if another LLM in the future accepts only JSON_SCHEMA?

So I would suggest removing withSchemaType from the builder and letting Spring AI's core logic handle schema compatibility itself.
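
For illustration only, this is roughly what I mean; the helper class and the provider-to-schema mapping below are made up for this example and are not part of Spring AI:

import java.util.function.Function;

import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;

// Hypothetical helper: function authors register plain functions, and the schema type
// is chosen from the target provider rather than hard-coded by the author.
public final class ProviderAwareFunctionCallbacks {

    public static <I, O> FunctionCallback wrap(String providerName, String name, String description,
            Function<I, O> function) {
        return FunctionCallbackWrapper.builder(function)
                .withName(name)
                .withDescription(description)
                // Vertex AI Gemini expects an OpenAPI-style parameters schema (type OBJECT),
                // so the schema type is derived from the provider instead of being fixed.
                .withSchemaType(schemaTypeFor(providerName))
                .build();
    }

    private static FunctionCallbackWrapper.Builder.SchemaType schemaTypeFor(String providerName) {
        // Made-up mapping; a real implementation would live inside each ChatModel.
        return "vertex-ai".equals(providerName)
                ? FunctionCallbackWrapper.Builder.SchemaType.OPEN_API_SCHEMA
                : FunctionCallbackWrapper.Builder.SchemaType.JSON_SCHEMA;
    }

}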

coderphonui commented 16 hours ago

Finally, I came up with a solution that wraps the function definition as shown below:

import java.util.function.Function;

import org.springframework.ai.model.function.FunctionCallbackWrapper;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

@Data
@Builder
@AllArgsConstructor
public class FunctionDefinition<I, O> {

    private String name;
    private String description;
    private Function<I, O> function;

    // The schema type is supplied by the caller, which knows the target model.
    public FunctionCallbackWrapper<I, O> toFunctionCallbackWrapper(FunctionCallbackWrapper.Builder.SchemaType schemaType) {
        return FunctionCallbackWrapper.builder(function)
                .withName(name)
                .withDescription(description)
                .withSchemaType(schemaType)
                .build();
    }

}
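
For example, using the MockWeatherService from the reproduction above, a built-in function can then be declared once without committing to a schema type (the config class and bean name here are just illustrative):

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class BuiltInFunctionsConfig {

    // Declares the built-in function once, with no schema type attached yet.
    @Bean
    public FunctionDefinition<MockWeatherService.Request, MockWeatherService.Response> weatherFunctionDefinition() {
        return FunctionDefinition.<MockWeatherService.Request, MockWeatherService.Response>builder()
                .name("weatherFunction")
                .description("Get the weather in location")
                .function(new MockWeatherService())
                .build();
    }

}

The schema type is only supplied later, via toFunctionCallbackWrapper(...), by the builder that knows which model it is targeting.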

I add a hook that builds the ChatModel for each model I need to support and rebuilds the FunctionCallback list before passing it to the model constructor. Below is example code:

@Slf4j
@Component
public class VertexGeminiChatClientBuilder implements ChatClientBuilder {
    private final VertexAI vertexAi;
    private final List<FunctionCallback> toolFunctionCallbacks;
    private final ApplicationContext context;
    private final List<FunctionDefinition> functionDefinitions;

    public VertexGeminiChatClientBuilder(VertexAI vertexAi, List<FunctionCallback> toolFunctionCallbacks, ApplicationContext context, List<FunctionDefinition> functionDefinitions) {
        this.vertexAi = vertexAi;
        this.toolFunctionCallbacks = toolFunctionCallbacks;
        this.context = context;
        this.functionDefinitions = functionDefinitions;
    }

    @Override
    public String getProviderName() {
        return "vertex-ai";
    }

    @Override
    public ChatModel buildChatModel(AgentConfig agentConfig) {
        if(agentConfig == null || agentConfig.getLlmConfig() == null) {
            return null;
        }
        FunctionCallbackContext functionCallbackContext = this.springAiFunctionManager(context);
        List<FunctionCallback> clonedToolFunctionCallbacks = new ArrayList<>(List.copyOf(toolFunctionCallbacks));
        if(functionDefinitions != null) {
            functionDefinitions.forEach(definition -> {
                clonedToolFunctionCallbacks.add(definition.toFunctionCallbackWrapper(FunctionCallbackWrapper.Builder.SchemaType.OPEN_API_SCHEMA));
            });
        }
        return new VertexAiGeminiChatModel(vertexAi, buildVertexGeminiChatOptions(agentConfig),
                functionCallbackContext, clonedToolFunctionCallbacks);
    }

    @Override
    public ChatOptions buildChatOptions(AgentConfig agentConfig) {
        return buildVertexGeminiChatOptions(agentConfig);
    }

    private VertexAiGeminiChatOptions buildVertexGeminiChatOptions(AgentConfig agentConfig) {
        if(agentConfig == null) {
            return null;
        }
        LLMConfig modelConfig = agentConfig.getLlmConfig();
        if(modelConfig == null) {
            return null;
        }
        if(modelConfig.getModelName() == null) {
            modelConfig.setModelName(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH.getValue());
        }
        VertexAiGeminiChatOptions.Builder builder =  VertexAiGeminiChatOptions.builder()
                .withModel(modelConfig.getModelName())
                .withMaxOutputTokens(modelConfig.getMaxTokens())
                .withTemperature(modelConfig.getTemperature());
        if(!agentConfig.getTools().isEmpty()) {
            agentConfig.getTools().forEach(builder::withFunction);
        }
        return builder.build();
    }

    private FunctionCallbackContext springAiFunctionManager(ApplicationContext context) {
        FunctionCallbackContext manager = new FunctionCallbackContext();
        manager.setSchemaType(FunctionCallbackWrapper.Builder.SchemaType.OPEN_API_SCHEMA);
        manager.setApplicationContext(context);
        return manager;
    }

}

With this approach, the SchemaType is bound exactly where it is needed: when the ChatModel is initialized.
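
For completeness, here is a rough sketch of how these provider-specific builders can be selected at runtime; the ChatClientFactory class is illustrative and reuses the ChatClientBuilder and AgentConfig types from my code above:

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Component;

@Component
public class ChatClientFactory {

    private final Map<String, ChatClientBuilder> buildersByProvider;

    public ChatClientFactory(List<ChatClientBuilder> builders) {
        // Index the provider-specific builders (e.g. VertexGeminiChatClientBuilder) by provider name.
        this.buildersByProvider = builders.stream()
                .collect(Collectors.toMap(ChatClientBuilder::getProviderName, Function.identity()));
    }

    public ChatClient createChatClient(String providerName, AgentConfig agentConfig) {
        ChatClientBuilder builder = buildersByProvider.get(providerName);
        if (builder == null) {
            throw new IllegalArgumentException("No ChatClientBuilder registered for provider: " + providerName);
        }
        // The schema type has already been bound inside buildChatModel, so callers never deal with it.
        return ChatClient.builder(builder.buildChatModel(agentConfig)).build();
    }

}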