spring-projects / spring-ai

An Application Framework for AI Engineering
https://docs.spring.io/spring-ai/reference/index.html
Apache License 2.0

Cannot create the response with Vertex Gemini from the second time #1340

coderphonui closed this issue 2 months ago

coderphonui commented 2 months ago

Bug description

I am using Spring AI in my project and generating content several times with the same ChatClient. With Vertex Gemini, the second invocation of call() throws an exception. Here is the stack trace:

java.lang.RuntimeException: Failed to generate content

    at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.getContentResponse(VertexAiGeminiChatModel.java:532)
    at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.call(VertexAiGeminiChatModel.java:173)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallPromptResponseSpec.doGetChatResponse(DefaultChatClient.java:959)
    at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallPromptResponseSpec.chatResponse(DefaultChatClient.java:955)
    at com.cdpn.agentcore.it.VertexBugReproduceTest.cannot_produce_content_at_the_second_call(VertexBugReproduceTest.java:40)
    at java.base/java.lang.reflect.Method.invoke(Method.java:568)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
Caused by: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Unable to submit request because it must include at least one parts field, which describes the prompt input. Learn more: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:92)
    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:41)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:86)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:66)
    at com.google.api.gax.grpc.GrpcExceptionCallable$ExceptionTransformingFuture.onFailure(GrpcExceptionCallable.java:97)
    at com.google.api.core.ApiFutures$1.onFailure(ApiFutures.java:84)
    at com.google.common.util.concurrent.Futures$CallbackListener.run(Futures.java:1130)
    at com.google.common.util.concurrent.DirectExecutor.execute(DirectExecutor.java:31)
    at com.google.common.util.concurrent.AbstractFuture.executeListener(AbstractFuture.java:1298)
    at com.google.common.util.concurrent.AbstractFuture.complete(AbstractFuture.java:1059)
    at com.google.common.util.concurrent.AbstractFuture.setException(AbstractFuture.java:809)
    at io.grpc.stub.ClientCalls$GrpcFuture.setException(ClientCalls.java:568)
    at io.grpc.stub.ClientCalls$UnaryStreamToFuture.onClose(ClientCalls.java:538)
    at io.grpc.PartialForwardingClientCallListener.onClose(PartialForwardingClientCallListener.java:39)
    at io.grpc.ForwardingClientCallListener.onClose(ForwardingClientCallListener.java:23)
    at io.grpc.ForwardingClientCallListener$SimpleForwardingClientCallListener.onClose(ForwardingClientCallListener.java:40)
    at com.google.api.gax.grpc.ChannelPool$ReleasingClientCall$1.onClose(ChannelPool.java:570)
    at io.grpc.internal.ClientCallImpl.closeObserver(ClientCallImpl.java:574)
    at io.grpc.internal.ClientCallImpl.access$300(ClientCallImpl.java:72)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInternal(ClientCallImpl.java:742)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImpl$1StreamClosed.runInContext(ClientCallImpl.java:723)
    at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
    at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
    at java.base/java.lang.Thread.run(Thread.java:840)
    Suppressed: com.google.api.gax.rpc.AsyncTaskException: Asynchronous task failed
        at com.google.api.gax.rpc.ApiExceptions.callAndTranslateApiException(ApiExceptions.java:57)
        at com.google.api.gax.rpc.UnaryCallable.call(UnaryCallable.java:112)
        at com.google.cloud.vertexai.generativeai.GenerativeModel.generateContent(GenerativeModel.java:400)
        at com.google.cloud.vertexai.generativeai.GenerativeModel.generateContent(GenerativeModel.java:387)
        at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.getContentResponse(VertexAiGeminiChatModel.java:529)
        at org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel.call(VertexAiGeminiChatModel.java:173)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallPromptResponseSpec.doGetChatResponse(DefaultChatClient.java:959)
        at org.springframework.ai.chat.client.DefaultChatClient$DefaultCallPromptResponseSpec.chatResponse(DefaultChatClient.java:955)
        at com.cdpn.agentcore.it.VertexBugReproduceTest.cannot_produce_content_at_the_second_call(VertexBugReproduceTest.java:40)
        at java.base/java.lang.reflect.Method.invoke(Method.java:568)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
        at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
Caused by: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Unable to submit request because it must include at least one parts field, which describes the prompt input. Learn more: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
    at io.grpc.Status.asRuntimeException(Status.java:533)
    ... 14 more
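
The top-level message in the trace above ("Failed to generate content") hides the actual reason, which only appears in the innermost "Caused by" entry. A minimal sketch of pulling that innermost cause out programmatically follows. The `RootCause` class and its `rootCause` helper are hypothetical names introduced for illustration, and the exceptions in `main` are plain JDK stand-ins for the gRPC and gax types shown in the trace.

```java
public class RootCause {

    // Walks the cause chain and returns the innermost exception.
    static Throwable rootCause(Throwable t) {
        Throwable current = t;
        while (current.getCause() != null) {
            current = current.getCause();
        }
        return current;
    }

    public static void main(String[] args) {
        // Stand-ins for the wrapped exceptions in the trace (not the real gRPC types).
        Throwable inner = new IllegalArgumentException(
                "INVALID_ARGUMENT: request must include at least one parts field");
        Throwable outer = new RuntimeException("Failed to generate content", inner);

        // Logs the innermost message instead of the generic wrapper message.
        System.out.println(rootCause(outer).getMessage());
    }
}
```

In a test like the one in this report, the same helper could be applied inside a catch block around the second call, so the INVALID_ARGUMENT text is surfaced directly.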

Environment

I am using Java 17, Spring AI 1.0.0-SNAPSHOT, and spring-ai-vertex-ai-gemini-spring-boot-starter.

Steps to reproduce

Code to reproduce the issue:

package com.cdpn.agentcore.it;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import java.util.ArrayList;
import java.util.List;

@SpringBootTest
public class VertexBugReproduceTest {
    @Autowired
    VertexAiGeminiChatModel vertexAiGeminiChatModel;

    @EnabledIfEnvironmentVariable(named = "GEMINI_PROJECT_ID", matches = ".*")
    @Test
    public void cannot_produce_content_at_the_second_call() {
        ChatClient.Builder chatClientBuilder = ChatClient.builder(vertexAiGeminiChatModel);
        chatClientBuilder.defaultSystem("You are a funny and lovely girlfriend of mine. Your name is Anna. Please talk to me as my darling");
        ChatClient chatClient = chatClientBuilder.build();
        List<Message> messages = new ArrayList<>();
        messages.add(new UserMessage("hello"));
        ChatResponse result = chatClient.prompt(new Prompt(messages)).call().chatResponse();
        String answer = result.getResult().getOutput().getContent();
        System.out.println(answer); // The first call succeeds

        messages.add(new AssistantMessage(answer));
        messages.add(new UserMessage("I want to hear a funny story"));
        result = chatClient.prompt(new Prompt(messages)).call().chatResponse(); // The second call fails with INVALID_ARGUMENT
        answer = result.getResult().getOutput().getContent();
        System.out.println(answer);
    }
}

Expected behavior

The ChatClient should be usable for repeated calls as long as the prompt is valid.

coderphonui commented 2 months ago

I have figured out the issue. In my current version of Spring AI, the code on the highlighted line of VertexAiGeminiChatModel does not add the part for the message.

I see it's fixed on the main branch:

https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java#L434

[Screenshot: VertexAiGeminiChatModel-bug]

However, it seems the snapshot release is still using the old version.
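
The failure mode can be illustrated with a small, self-contained sketch. It does not use the real Spring AI or Vertex AI classes: `Msg`, `Content`, `toContentsBuggy`, `toContentsFixed`, and `isValid` are hypothetical stand-ins. The assumption that the older converter omitted the part only for assistant messages is inferred from the symptom (the first call, which has no assistant message, succeeds) and is not confirmed against the library source.

```java
import java.util.ArrayList;
import java.util.List;

public class PartsSketch {

    // Hypothetical stand-ins for a chat message and a Gemini-style content entry.
    record Msg(String role, String text) {}
    record Content(String role, List<String> parts) {}

    // Assumed buggy mapping: assistant turns get a role but no parts.
    static List<Content> toContentsBuggy(List<Msg> messages) {
        List<Content> out = new ArrayList<>();
        for (Msg m : messages) {
            if (m.role().equals("assistant")) {
                out.add(new Content("model", List.of()));
            } else {
                out.add(new Content("user", List.of(m.text())));
            }
        }
        return out;
    }

    // Assumed fixed mapping: every turn carries its text as a part.
    static List<Content> toContentsFixed(List<Msg> messages) {
        List<Content> out = new ArrayList<>();
        for (Msg m : messages) {
            String role = m.role().equals("assistant") ? "model" : "user";
            out.add(new Content(role, List.of(m.text())));
        }
        return out;
    }

    // Mimics the server-side rule from the error message: each content needs at least one part.
    static boolean isValid(List<Content> contents) {
        return contents.stream().noneMatch(c -> c.parts().isEmpty());
    }

    public static void main(String[] args) {
        List<Msg> firstCall = List.of(new Msg("user", "hello"));
        List<Msg> secondCall = List.of(
                new Msg("user", "hello"),
                new Msg("assistant", "Hi there!"),
                new Msg("user", "I want to hear a funny story"));

        System.out.println(isValid(toContentsBuggy(firstCall)));  // true
        System.out.println(isValid(toContentsBuggy(secondCall))); // false
        System.out.println(isValid(toContentsFixed(secondCall))); // true
    }
}
```

Under these assumptions, the first request passes validation because it contains only a user turn, while the second is rejected as soon as the conversation history includes an assistant turn without parts.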