spring-projects / spring-ai

An Application Framework for AI Engineering
https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/index.html
Apache License 2.0

Gemini Function call with multiple functions doesn't work #619

Closed. Grogdunn closed this issue 2 months ago.

Grogdunn commented 3 months ago

Bug description: defining multiple functions in a single "call" to Gemini leads to an INVALID_ARGUMENT server error.

Environment: tested on spring-ai 0.8.1; the same bug is present in the 1.0.0-SNAPSHOT code.

Steps to reproduce: pass multiple functions to a "VertexAiGeminiChatClient.call".

Expected behavior: no server error, and ideally every function gets called.

From a quick trial, Vertex seems to support one tool containing multiple functions, but not multiple tools with a single function each. So the "getFunctionTools" method must produce a single tool with all the functions inside: https://github.com/spring-projects/spring-ai/blob/563f9b48af60a176b32ee4fa03877123f076b9a7/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java#L350
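The two request shapes described above can be contrasted with a small plain-Java sketch. The records FunctionDecl and ToolSpec and all methods below are hypothetical stand-ins for the Vertex AI SDK's FunctionDeclaration and Tool types (they are not Spring AI or Google API), and the JSON rendering only approximates the "tools" array of the REST request body, listing function names only.

```java
import java.util.List;
import java.util.stream.Collectors;

public class SingleToolGrouping {

    // Hypothetical stand-in for a function declaration (name + description).
    record FunctionDecl(String name, String description) {}

    // Hypothetical stand-in for a tool, which holds a list of function declarations.
    record ToolSpec(List<FunctionDecl> functionDeclarations) {}

    // Shape the report says Vertex rejects: one tool per function.
    static List<ToolSpec> oneToolPerFunction(List<FunctionDecl> decls) {
        return decls.stream().map(d -> new ToolSpec(List.of(d))).toList();
    }

    // Shape the report says Vertex accepts: a single tool holding every declaration.
    static List<ToolSpec> singleTool(List<FunctionDecl> decls) {
        return List.of(new ToolSpec(List.copyOf(decls)));
    }

    // Renders an approximate "tools" JSON array (function names only) for comparison.
    static String toolsJson(List<ToolSpec> tools) {
        return tools.stream()
                .map(t -> t.functionDeclarations().stream()
                        .map(d -> "{\"name\":\"" + d.name() + "\"}")
                        .collect(Collectors.joining(",", "{\"functionDeclarations\":[", "]}")))
                .collect(Collectors.joining(",", "{\"tools\":[", "]}"));
    }

    public static void main(String[] args) {
        List<FunctionDecl> decls = List.of(
                new FunctionDecl("getCurrentWeather", "Get the current weather in a given location"),
                new FunctionDecl("theAnswerToTheUniverse", "the answer to the ultimate question of life, the universe, and everything"));
        System.out.println("rejected: " + toolsJson(oneToolPerFunction(decls)));
        System.out.println("accepted: " + toolsJson(singleTool(decls)));
    }
}
```

Under the reporter's observation, a fix would collect every resolved function callback into one tool instead of mapping each callback to its own tool.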

tzolov commented 3 months ago

@Grogdunn, thank you for the feedback. Can you please write a reproducible example?

Grogdunn commented 2 months ago

Sure. When I have spare time I'll try to fix it; meanwhile, put this test in VertexAiGeminiChatClientFunctionCallingIT:

    // Meant to be pasted into VertexAiGeminiChatClientFunctionCallingIT, which already
    // provides vertexGeminiClient, logger and MockWeatherService.
    @Test
    public void apiFail() {

        UserMessage userMessage = new UserMessage(
                "What's the weather like in San Francisco, in Paris and in Tokyo, Japan? Use Multi-turn function calling. Provide answer for all requested locations.");

        List<Message> messages = new ArrayList<>(List.of(userMessage));

        // First function: the existing mock weather service.
        final var weatherFunction = FunctionCallbackWrapper.builder(new MockWeatherService())
                .withSchemaType(SchemaType.OPEN_API_SCHEMA)
                .withName("getCurrentWeather")
                .withDescription("Get the current weather in a given location")
                .build();
        // Second function: a trivial callback that always returns 42.
        final var theAnswer = FunctionCallbackWrapper.<String, Integer>builder(o -> 42)
                .withSchemaType(SchemaType.OPEN_API_SCHEMA)
                .withName("theAnswerToTheUniverse")
                .withDescription("the answer to the ultimate question of life, the universe, and everything")
                .build();
        // Registering both callbacks in the same request is what triggers the 400 INVALID_ARGUMENT.
        var promptOptions = VertexAiGeminiChatOptions.builder()
                .withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
                .withFunctionCallbacks(List.of(weatherFunction, theAnswer))
                .build();

        Flux<ChatResponse> response = vertexGeminiClient.stream(new Prompt(messages, promptOptions));

        // Concatenate the content of all streamed generations into one string.
        String responseString = response.collectList()
                .block()
                .stream()
                .map(ChatResponse::getResults)
                .flatMap(List::stream)
                .map(Generation::getOutput)
                .map(AssistantMessage::getContent)
                .collect(Collectors.joining());

        logger.info("Response: {}", responseString);

        // One assertion per city; the expected values come from MockWeatherService.
        assertThat(responseString).containsAnyOf("15.0", "15");
        assertThat(responseString).containsAnyOf("30.0", "30");
        assertThat(responseString).containsAnyOf("10.0", "10");

    }

The response is:

com.google.api.gax.rpc.InvalidArgumentException: Bad Request

    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:52)
    at com.google.api.gax.httpjson.HttpJsonApiExceptionFactory.createApiException(HttpJsonApiExceptionFactory.java:76)
    at com.google.api.gax.httpjson.HttpJsonApiExceptionFactory.create(HttpJsonApiExceptionFactory.java:54)
    at com.google.api.gax.httpjson.HttpJsonExceptionResponseObserver.onErrorImpl(HttpJsonExceptionResponseObserver.java:82)
    at com.google.api.gax.rpc.StateCheckingResponseObserver.onError(StateCheckingResponseObserver.java:84)
    at com.google.api.gax.httpjson.HttpJsonDirectStreamController$ResponseObserverAdapter.onClose(HttpJsonDirectStreamController.java:125)
    at com.google.api.gax.httpjson.HttpJsonClientCallImpl$OnCloseNotificationTask.call(HttpJsonClientCallImpl.java:552)
    at com.google.api.gax.httpjson.HttpJsonClientCallImpl.notifyListeners(HttpJsonClientCallImpl.java:391)
    at com.google.api.gax.httpjson.HttpJsonClientCallImpl.deliver(HttpJsonClientCallImpl.java:318)
    at com.google.api.gax.httpjson.HttpJsonClientCallImpl.setResult(HttpJsonClientCallImpl.java:164)
    at com.google.api.gax.httpjson.HttpRequestRunnable.run(HttpRequestRunnable.java:149)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:539)
    at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
    at java.base/java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:304)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
    at java.base/java.lang.Thread.run(Thread.java:840)
    Suppressed: java.lang.RuntimeException: Asynchronous task failed
        at com.google.api.gax.rpc.ServerStreamIterator.hasNext(ServerStreamIterator.java:105)
        at com.google.cloud.vertexai.generativeai.ResponseStreamIteratorWithHistory.hasNext(ResponseStreamIteratorWithHistory.java:37)
        at java.base/java.util.Spliterators$IteratorSpliterator.tryAdvance(Spliterators.java:1855)
        at reactor.core.publisher.FluxIterable$IterableSubscription.hasNext(FluxIterable.java:271)
        at reactor.core.publisher.FluxIterable.subscribe(FluxIterable.java:187)
        at reactor.core.publisher.FluxStream.subscribe(FluxStream.java:69)
        at reactor.core.publisher.Mono.subscribe(Mono.java:4568)
        at reactor.core.publisher.Mono.block(Mono.java:1778)
        at org.springframework.ai.vertexai.gemini.function.VertexAiGeminiChatClientFunctionCallingIT.apiFail(VertexAiGeminiChatClientFunctionCallingIT.java:205)
        at java.base/java.lang.reflect.Method.invoke(Method.java:568)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
    Suppressed: java.lang.Exception: #block terminated with an error
        at reactor.core.publisher.BlockingSingleSubscriber.blockingGet(BlockingSingleSubscriber.java:104)
        at reactor.core.publisher.Mono.block(Mono.java:1779)
        at org.springframework.ai.vertexai.gemini.function.VertexAiGeminiChatClientFunctionCallingIT.apiFail(VertexAiGeminiChatClientFunctionCallingIT.java:205)
        at java.base/java.lang.reflect.Method.invoke(Method.java:568)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
Caused by: com.google.api.client.http.HttpResponseException: 400 Bad Request
POST https://europe-west3-aiplatform.googleapis.com:443/v1/projects/[OMISSIS]/locations/europe-west3/publishers/google/models/gemini-pro:streamGenerateContent
[{
  "error": {
    "code": 400,
    "message": "Request contains an invalid argument.",
    "errors": [
      {
        "message": "Request contains an invalid argument.",
        "domain": "global",
        "reason": "badRequest"
      }
    ],
    "status": "INVALID_ARGUMENT"
  }
}
]
    at com.google.api.client.http.HttpResponseException$Builder.build(HttpResponseException.java:293)
    at com.google.api.client.http.HttpRequest.execute(HttpRequest.java:1118)
    at com.google.api.gax.httpjson.HttpRequestRunnable.run(HttpRequestRunnable.java:115)
    ... 6 more

Also, many of the function-calling IT tests don't work, due to a Google Gemini limitation (I think).

tzolov commented 2 months ago

Resolved by 9f32a3ca9cd119116b112e050b942e5e94855914.