spring-projects / spring-ai

An Application Framework for AI Engineering
https://docs.spring.io/spring-ai/reference/index.html
Apache License 2.0
3.17k stars 789 forks source link

Cannot set ResponseFormat with OpenAI #347

Closed raphaeldelio closed 8 months ago

raphaeldelio commented 8 months ago

Bug description If I try to call the client with a prompt with OpenAiChatOptions it complains that it cannot find the class: java.lang.ClassNotFoundException: org.springframework.ai.chat.prompt.ChatOptions

Environment

import org.jetbrains.kotlin.gradle.tasks.KotlinCompile

plugins {
    id("org.springframework.boot") version "3.2.2"
    id("io.spring.dependency-management") version "1.1.4"
    kotlin("jvm") version "1.9.22"
    kotlin("plugin.spring") version "1.9.22"
}

group = "demo.springai"
version = "0.0.1-SNAPSHOT"
val springAiVersion = "0.8.0-SNAPSHOT"

java {
    sourceCompatibility = JavaVersion.VERSION_21
}

repositories {
    mavenCentral()
    maven("https://repo.spring.io/milestone")
    maven("https://repo.spring.io/snapshot")
}

dependencies {
    implementation("org.springframework.boot:spring-boot-starter-web")
    implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310")
    implementation("org.springframework.ai:spring-ai-openai-spring-boot-starter:0.8.0-SNAPSHOT")

    implementation("com.fasterxml.jackson.module:jackson-module-kotlin")
    implementation("org.jetbrains.kotlin:kotlin-reflect")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
}

tasks.withType<KotlinCompile> {
    kotlinOptions {
        freeCompilerArgs += "-Xjsr305=strict"
        jvmTarget = "21"
    }
}

tasks.withType<Test> {
    useJUnitPlatform()
}

Steps to reproduce

        val response = chatClient.call(
            Prompt(
                listOf(UserMessage(reviews.toJson()), templateMessage),
                OpenAiChatOptions.builder()
                    .withResponseFormat(ResponseFormat("json"))
                    .build()
            )
        )

Expected behavior Without chat options, it works just fine.

java.lang.ClassNotFoundException: org.springframework.ai.chat.prompt.ChatOptions
    at java.base/jdk.internal.loader.BuiltinClassLoader.loadClass(BuiltinClassLoader.java:641) ~[na:na]
    at java.base/jdk.internal.loader.ClassLoaders$AppClassLoader.loadClass(ClassLoaders.java:188) ~[na:na]
    at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:526) ~[na:na]
    at demo.springai.springaidemo.ReviewAnalyzerService.analyzeReviews(ReviewAnalyzerService.kt:32) ~[main/:na]
    at demo.springai.springaidemo.ReviewAnalyzerController.analyze(ReviewAnalyzerController.kt:24) ~[main/:na]
    at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103) ~[na:na]
    at java.base/java.lang.reflect.Method.invoke(Method.java:580) ~[na:na]
    at kotlin.reflect.jvm.internal.calls.CallerImpl$Method.callMethod(CallerImpl.kt:97) ~[kotlin-reflect-1.9.22.jar:1.9.22-release-704]
    at kotlin.reflect.jvm.internal.calls.CallerImpl$Method$Instance.call(CallerImpl.kt:113) ~[kotlin-reflect-1.9.22.jar:1.9.22-release-704]
    at kotlin.reflect.jvm.internal.KCallableImpl.callDefaultMethod$kotlin_reflection(KCallableImpl.kt:207) ~[kotlin-reflect-1.9.22.jar:1.9.22-release-704]
    at kotlin.reflect.jvm.internal.KCallableImpl.callBy(KCallableImpl.kt:112) ~[kotlin-reflect-1.9.22.jar:1.9.22-release-704]
    at org.springframework.web.method.support.InvocableHandlerMethod$KotlinDelegate.invokeFunction(InvocableHandlerMethod.java:338) ~[spring-web-6.1.3.jar:6.1.3]
    at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:258) ~[spring-web-6.1.3.jar:6.1.3]
    at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:189) ~[spring-web-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:118) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:917) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:829) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1089) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:979) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1014) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at org.springframework.web.servlet.FrameworkServlet.doGet(FrameworkServlet.java:903) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:564) ~[tomcat-embed-core-10.1.18.jar:6.0]
    at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:885) ~[spring-webmvc-6.1.3.jar:6.1.3]
    at jakarta.servlet.http.HttpServlet.service(HttpServlet.java:658) ~[tomcat-embed-core-10.1.18.jar:6.0]
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:205) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:51) ~[tomcat-embed-websocket-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:100) ~[spring-web-6.1.3.jar:6.1.3]
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116) ~[spring-web-6.1.3.jar:6.1.3]
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.springframework.web.filter.FormContentFilter.doFilterInternal(FormContentFilter.java:93) ~[spring-web-6.1.3.jar:6.1.3]
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116) ~[spring-web-6.1.3.jar:6.1.3]
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:201) ~[spring-web-6.1.3.jar:6.1.3]
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:116) ~[spring-web-6.1.3.jar:6.1.3]
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:174) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:149) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:167) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:90) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:482) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:115) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:93) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:74) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:340) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:391) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:63) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:896) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1744) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:52) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.tomcat.util.threads.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1191) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.tomcat.util.threads.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:659) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61) ~[tomcat-embed-core-10.1.18.jar:10.1.18]
    at java.base/java.lang.Thread.run(Thread.java:1583) ~[na:na]
tzolov commented 8 months ago

@raphaeldelio , perhaps you have an outdated dependencies?
Could you please try to re-build with -U options?

I've successfully run the OpenAiChatClient2IT test below against the latest 0.8.0 SNAPSHOT and it produces expected JSON output like:

Response content: {
  "planets": [   "Mercury",    "Venus",    "Earth",    "Mars",    "Jupiter",    "Saturn",    "Uranus",    "Neptune"  ]
}

Mind that according to the OpenAI Spec the response_format must be one of text or json_object and your user message MUST mention the JSON world.

Test:

package org.springframework.ai.openai.chat;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(classes = OpenAiChatClient2IT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiChatClient2IT {

    private final Logger logger = LoggerFactory.getLogger(getClass());

    @Autowired
    private OpenAiChatClient openAiChatClient;

    @Test
    void responseFormatTest() throws JsonMappingException, JsonProcessingException {

        // 400 - ResponseError[error=Error[message='json' is not one of ['json_object', 'text'] -
        // 'response_format.type', type=invalid_request_error, param=null, code=null]]

        // 400 - ResponseError[error=Error[message='messages' must contain the word 'json' in some form, to use
        // 'response_format' of type 'json_object'., type=invalid_request_error, param=messages, code=null]]

        Prompt prompt = new Prompt("List 8 planets. Use JSON response", OpenAiChatOptions.builder()
                .withResponseFormat(new ChatCompletionRequest.ResponseFormat("json_object"))
                .build());

        ChatResponse response = this.openAiChatClient.call(prompt);

        assertThat(response).isNotNull();

        String content = response.getResult().getOutput().getContent();

        logger.info("Response content: {}", content);

        assertThat(isValidJson(content)).isTrue();
    }

    private static ObjectMapper MAPPER = new ObjectMapper()
            .enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS);

    public static boolean isValidJson(String json) {
        try {
            MAPPER.readTree(json);
        }
        catch (JacksonException e) {
            return false;
        }
        return true;
    }

    @SpringBootConfiguration
    static class Config {
        @Bean
        public OpenAiApi chatCompletionApi() {
            return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
        }

        @Bean
        public OpenAiChatClient openAiClient(OpenAiApi openAiApi) {
            return new OpenAiChatClient(openAiApi);
        }
    }
}
markpollack commented 8 months ago

closing as we were not able to reproduce. Please reopen if necessary. Thanks.