googleapis / google-cloud-java

Google Cloud Client Library for Java
https://cloud.google.com/java/docs/reference
Apache License 2.0

[vertexai] Unexpected SafetyRating Format in Vertex AI GenerateContentResponse #10564

Closed: RikJux closed this issue 5 months ago

RikJux commented 6 months ago

This issue reports a discrepancy in the format of the SafetyRating protobuf message within the GenerateContentResponse object. When using the Vertex AI Java library (google-cloud-vertexai), each SafetyRating lacks fields such as probabilityScore, severity and severityScore that are present in the actual response returned by the API. The library output shows them only as numbered entries (5, 6 and 7) instead of named fields.

The following sections detail the steps to reproduce the issue, give environment information, and include a code example using the Vertex AI library, plus an equivalent WebClient request (with JSON-formatted request/response) for comparison. References to the relevant API documentation are listed at the end.

Environment details

  1. Library: google-cloud-vertexai, imported via the Maven BOM com.google.cloud:libraries-bom 26.34.0 (see the guide linked below)
  2. OS type and version: Windows 11
  3. Java version: 21

Steps to reproduce

  1. Instantiate a GenerativeModel object.
  2. Call the generateContent method.
  3. For comparison, send an equivalent request through a WebClient using JSON.

Code example

Request made using the google-cloud-vertexai library


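// Imports assumed for this excerpt (JUnit 5 assumed for @Test)
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.ServiceAccountCredentials;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import org.junit.jupiter.api.Test;

private static String jsonFilePath = ""; // path to the service-account key JSON
private static String apiUrl = "";       // REST endpoint for the WebClient example (see the references below)

@Test
public void example() throws IOException {
    // Load service-account credentials, closing the key file afterwards.
    GoogleCredentials credentials;
    try (InputStream keyStream = new FileInputStream(jsonFilePath)) {
        credentials = ServiceAccountCredentials.fromStream(keyStream)
                .createScoped(List.of("https://www.googleapis.com/auth/cloud-platform"));
    }
    credentials.refreshIfExpired();

    try (VertexAI vertexAi = new VertexAI("my-project", "europe-west9", credentials)) {
        GenerationConfig generationConfig = GenerationConfig.newBuilder()
                .setMaxOutputTokens(2048)
                .setTemperature(0.9F)
                .setTopP(1F)
                .build();
        GenerativeModel model = new GenerativeModel("gemini-1.0-pro", generationConfig, vertexAi);

        // Printing a protobuf message uses the protobuf text format.
        GenerateContentResponse response = model.generateContent("say Hi!");
        System.out.println(response);
    }
}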

Response received (printed in protobuf text format)

candidates {
  content {
    role: "model"
    parts {
      text: "Hi there! How can I assist you today?"
    }
  }
  finish_reason: STOP
  safety_ratings {
    category: HARM_CATEGORY_HATE_SPEECH
    probability: NEGLIGIBLE
    5: 0x3cef378b
    6: 1
    7: 0x3d1dda8b
  }
  safety_ratings {
    category: HARM_CATEGORY_DANGEROUS_CONTENT
    probability: NEGLIGIBLE
    5: 0x3db8d6b0
    6: 1
    7: 0x3cf5a644
  }
  safety_ratings {
    category: HARM_CATEGORY_HARASSMENT
    probability: NEGLIGIBLE
    5: 0x3d3f0787
    6: 1
    7: 0x3c6205f4
  }
  safety_ratings {
    category: HARM_CATEGORY_SEXUALLY_EXPLICIT
    probability: NEGLIGIBLE
    5: 0x3e4374af
    6: 1
    7: 0x3d401970
  }
}
usage_metadata {
  prompt_token_count: 3
  candidates_token_count: 10
  total_token_count: 13
}
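The numbered entries can be decoded by hand. This is a minimal sketch, assuming the newer SafetyRating schema (the google.cloud.aiplatform.v1 content.proto) numbers the missing fields 5 = probability_score (float), 6 = severity (HarmSeverity enum) and 7 = severity_score (float). In text format a float field printed by number appears as the hex of its raw IEEE-754 bits, so Float.intBitsToFloat recovers it. severityName and its enum numbering are assumptions, labeled as such in the code.

```java
public class DecodeSafetyScores {
    // Assumed field mapping (newer aiplatform v1 SafetyRating):
    // 5 = probability_score (float), 6 = severity (enum), 7 = severity_score (float).

    // A fixed32 value printed as hex holds the raw IEEE-754 bits of a float.
    static float fixed32ToFloat(int bits) {
        return Float.intBitsToFloat(bits);
    }

    // Hypothetical helper: assumed HarmSeverity enum numbering.
    static String severityName(int value) {
        switch (value) {
            case 0: return "HARM_SEVERITY_UNSPECIFIED";
            case 1: return "HARM_SEVERITY_NEGLIGIBLE";
            case 2: return "HARM_SEVERITY_LOW";
            case 3: return "HARM_SEVERITY_MEDIUM";
            case 4: return "HARM_SEVERITY_HIGH";
            default: return "UNKNOWN(" + value + ")";
        }
    }

    public static void main(String[] args) {
        String[] categories = {"HATE_SPEECH", "DANGEROUS_CONTENT", "HARASSMENT", "SEXUALLY_EXPLICIT"};
        // Raw {5, 6, 7} values copied from the library output above.
        int[][] raw = {
                {0x3cef378b, 1, 0x3d1dda8b},
                {0x3db8d6b0, 1, 0x3cf5a644},
                {0x3d3f0787, 1, 0x3c6205f4},
                {0x3e4374af, 1, 0x3d401970},
        };
        for (int i = 0; i < raw.length; i++) {
            System.out.printf("%s: probabilityScore=%.4f severity=%s severityScore=%.4f%n",
                    categories[i], fixed32ToFloat(raw[i][0]),
                    severityName(raw[i][1]), fixed32ToFloat(raw[i][2]));
        }
    }
}
```

For the hate-speech rating this gives roughly 0.029 and 0.039, in the same range as the JSON scores below. Those come from a separate call, so they are not expected to match exactly.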

Request made using the Spring Boot WebClient

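// Imports assumed for this excerpt (JUnit 5, Spring WebFlux, Reactor Netty, Jackson, protobuf-java-util)
import com.fasterxml.jackson.databind.JsonNode;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.ServiceAccountCredentials;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.protobuf.util.JsonFormat;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.springframework.http.MediaType;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;
import reactor.netty.tcp.SslProvider;

@Test
public void exampleWithWebClient() throws IOException {
    GenerateContentRequest generateContentRequest = getGenerateContentRequest();

    // Reactor Netty client with TLS and a generous response timeout.
    HttpClient client = HttpClient.create()
            .responseTimeout(Duration.ofSeconds(100))
            .secure(SslProvider.defaultClientProvider());

    // Load service-account credentials and obtain an OAuth2 access token.
    GoogleCredentials credentials;
    try (InputStream keyStream = new FileInputStream(jsonFilePath)) {
        credentials = ServiceAccountCredentials.fromStream(keyStream)
                .createScoped(List.of("https://www.googleapis.com/auth/cloud-platform"));
    }
    credentials.refreshIfExpired(); // refreshes only if needed; no second forced refresh
    String accessToken = credentials.getAccessToken().getTokenValue();

    WebClient webClient = WebClient.builder()
            .baseUrl(apiUrl)
            .clientConnector(new ReactorClientHttpConnector(client))
            .defaultHeader("Authorization", "Bearer " + accessToken)
            .build();

    // Send the request proto serialized as JSON and read the raw JSON response.
    Mono<JsonNode> monoJson = webClient.post()
            .contentType(MediaType.APPLICATION_JSON)
            .bodyValue(JsonFormat.printer().print(generateContentRequest))
            .retrieve()
            .bodyToMono(JsonNode.class);

    System.out.println(monoJson.block().toPrettyString());
}

private GenerateContentRequest getGenerateContentRequest() {
    Part textPart = Part.newBuilder()
            .setText("Say hi!")
            .build();

    Content content = Content.newBuilder()
            .setRole("user")
            .addParts(textPart)
            .build();

    GenerationConfig generationConfig = GenerationConfig.newBuilder()
            .setTemperature(0.4f)
            .setTopP(1.0f)
            .setTopK(32)
            .setMaxOutputTokens(2048)
            .build();

    // The model is identified by apiUrl, so the request's model field is left unset.
    return GenerateContentRequest.newBuilder()
            .addContents(content)
            .setGenerationConfig(generationConfig)
            .build();
}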

Response received (JSON)

{
  "candidates" : [ {
    "content" : {
      "role" : "model",
      "parts" : [ {
        "text" : "Hi there! How can I assist you today?"
      } ]
    },
    "finishReason" : "STOP",
    "safetyRatings" : [ {
      "category" : "HARM_CATEGORY_HATE_SPEECH",
      "probability" : "NEGLIGIBLE",
      "probabilityScore" : 0.030502057,
      "severity" : "HARM_SEVERITY_NEGLIGIBLE",
      "severityScore" : 0.03782129
    }, {
      "category" : "HARM_CATEGORY_DANGEROUS_CONTENT",
      "probability" : "NEGLIGIBLE",
      "probabilityScore" : 0.09334688,
      "severity" : "HARM_SEVERITY_NEGLIGIBLE",
      "severityScore" : 0.03204008
    }, {
      "category" : "HARM_CATEGORY_HARASSMENT",
      "probability" : "NEGLIGIBLE",
      "probabilityScore" : 0.047869004,
      "severity" : "HARM_SEVERITY_NEGLIGIBLE",
      "severityScore" : 0.015365342
    }, {
      "category" : "HARM_CATEGORY_SEXUALLY_EXPLICIT",
      "probability" : "NEGLIGIBLE",
      "probabilityScore" : 0.20386598,
      "severity" : "HARM_SEVERITY_NEGLIGIBLE",
      "severityScore" : 0.047074173
    } ]
  } ],
  "usageMetadata" : {
    "promptTokenCount" : 3,
    "candidatesTokenCount" : 10,
    "totalTokenCount" : 13
  }
}

External references such as API reference guides

The apiUrl value comes from the REST endpoint documented here: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
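The apiUrl field is left empty in both snippets. Here is a minimal sketch of assembling it, assuming the regional endpoint format shown on the Gemini model-reference page linked above (check the page for the current API version). buildApiUrl is a hypothetical helper. The project, region and model values are the ones used in the library example.

```java
public class GenerateContentUrl {
    // Assumed endpoint format:
    // https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT}/locations/{LOCATION}
    //     /publishers/google/models/{MODEL}:generateContent
    static String buildApiUrl(String project, String location, String model) {
        return String.format(
                "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s"
                        + "/publishers/google/models/%s:generateContent",
                location, project, location, model);
    }

    public static void main(String[] args) {
        // Same project, region and model as the GenerativeModel example.
        System.out.println(buildApiUrl("my-project", "europe-west9", "gemini-1.0-pro"));
    }
}
```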

Guide to the Maven import: https://cloud.google.com/java/docs/reference/google-cloud-vertexai/latest/overview

meltsufin commented 6 months ago

@ZhenyiQ PTAL

ZhenyiQ commented 6 months ago

Thanks @RikJux for raising the issue! This is likely due to the client lagging the API version. We'll update the client in our next release.
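The "client lagging the API version" explanation matches how protobuf handles schema drift. A parser generated from an older .proto keeps fields with unrecognised numbers as unknown fields instead of failing, and the text format prints those by number. That is why the library output shows `5:`, `6:` and `7:`. The toy decoder below illustrates this with the Java standard library only. It is a simplified sketch that handles only varint and fixed32 wire types, not protobuf-java's actual parser. The "old schema" that knows only fields 1 and 2 is hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

public class UnknownFieldDemo {
    // Fields known to this hypothetical "old" schema: 1 = category, 2 = probability.
    static final Set<Integer> KNOWN_FIELDS = Set.of(1, 2);

    // Decodes a base-128 varint at pos[0] and advances pos[0] past it.
    static long readVarint(byte[] buf, int[] pos) {
        long result = 0;
        int shift = 0;
        while (true) {
            byte b = buf[pos[0]++];
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
            shift += 7;
        }
    }

    // Decodes a little-endian 32-bit value (wire type 5) and advances pos[0] by 4.
    static int readFixed32(byte[] buf, int[] pos) {
        int p = pos[0];
        int v = (buf[p] & 0xFF)
                | (buf[p + 1] & 0xFF) << 8
                | (buf[p + 2] & 0xFF) << 16
                | (buf[p + 3] & 0xFF) << 24;
        pos[0] = p + 4;
        return v;
    }

    // Returns known fields; unrecognised field numbers go into `unknown` instead of failing.
    static Map<Integer, Long> parse(byte[] buf, Map<Integer, Long> unknown) {
        Map<Integer, Long> known = new LinkedHashMap<>();
        int[] pos = {0};
        while (pos[0] < buf.length) {
            long tag = readVarint(buf, pos);
            int field = (int) (tag >>> 3);
            int wireType = (int) (tag & 0x7);
            long value;
            if (wireType == 0) {
                value = readVarint(buf, pos);
            } else if (wireType == 5) {
                value = readFixed32(buf, pos) & 0xFFFFFFFFL;
            } else {
                throw new IllegalArgumentException("unsupported wire type " + wireType);
            }
            (KNOWN_FIELDS.contains(field) ? known : unknown).put(field, value);
        }
        return known;
    }

    // Converts int literals to bytes so the test message can be written in hex.
    static byte[] bytes(int... values) {
        byte[] out = new byte[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (byte) values[i];
        }
        return out;
    }

    public static void main(String[] args) {
        // Hand-encoded rating mirroring the HARM_CATEGORY_HATE_SPEECH entry above.
        byte[] rating = bytes(
                0x08, 0x01,                   // field 1, varint 1
                0x10, 0x01,                   // field 2, varint 1
                0x2D, 0x8B, 0x37, 0xEF, 0x3C, // field 5, fixed32 0x3cef378b
                0x30, 0x01,                   // field 6, varint 1
                0x3D, 0x8B, 0xDA, 0x1D, 0x3D  // field 7, fixed32 0x3d1dda8b
        );
        Map<Integer, Long> unknown = new LinkedHashMap<>();
        Map<Integer, Long> known = parse(rating, unknown);
        System.out.println("known fields:   " + known.keySet());
        System.out.println("unknown fields: " + unknown.keySet());
    }
}
```

Running it prints known fields [1, 2] and unknown fields [5, 6, 7]. An updated client, generated from the newer schema, would decode 5, 6 and 7 as named fields instead.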

ZhenyiQ commented 5 months ago

The client version was updated.