TheoKanning / openai-java

OpenAI Api Client in Java
MIT License
4.68k stars 1.16k forks source link

funcation call + stream进行调用时返回的ChatFunctionCall对应中arguments丢失 #505

Open I-am-DJ opened 1 month ago

I-am-DJ commented 1 month ago

这个是我的示例代码:

     public static void main(String[] args) throws UnknownHostException, InterruptedException {
        try {
            ObjectMapper mapper = defaultObjectMapper();
            OkHttpClient client = defaultClient("", Duration.of(10000L,ChronoUnit.SECONDS))
                    .newBuilder()
                    .build();
            Retrofit retrofit = defaultRetrofit(client, mapper);
            Class<Retrofit> clazz = Retrofit.class;
            Field baseUrl = clazz.getDeclaredField("baseUrl");
            baseUrl.setAccessible(true);
            baseUrl.set(retrofit, HttpUrl.get(BASE_URL));
            OpenAiApi api = retrofit.create(OpenAiApi.class);
            OpenAiService service = new OpenAiService(api);
            List<ChatMessage> messages = Lists.newArrayList();
            messages.add(new ChatMessage("system", "Please use the functions provided below to determine what function needs to be called for the user's problem. " +
                    "If the necessary parameters are missing when calling the function, please return to the user in this format and prompt the user to pass the necessary parameters:\n" +
                    "We also need the following information to complete your request: Required Parameter 1, Required Parameter 2\n" +
                    "Make sure your prompts are accurate, polite, and the directly relevant information is obvious and understandable to users"));
            Scanner scanner = new Scanner(System.in);
            //"Tell me the weather"
            messages.add(new ChatMessage("user", scanner.nextLine()));
            while (true) {
                ChatFunctionDynamic chatFunctionDynamic = getChatFunctionDynamic();
                ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
                        .builder()
                        .model("qwen15-110b.credit-llm")
                        .messages(messages)
                        .n(1)
                        .maxTokens(256)
                        .functions(Lists.newArrayList(chatFunctionDynamic))
                        .functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto"))
                        .build();

                Flowable<ChatCompletionChunk> flowable = service.streamChatCompletion(chatCompletionRequest);
                AtomicBoolean isFirst = new AtomicBoolean(true);
                ChatMessage responseMessage = service.mapStreamToAccumulator(flowable).doOnNext(accumulator -> {
                            if (accumulator.isFunctionCall()) {
                                ChatFunctionCall functionCall = accumulator.getAccumulatedChatFunctionCall();
                                if (isFirst.getAndSet(false)) {
                                    System.out.println("Executing function " + functionCall.getName() + "...");
                                }
                            } else {
                                if (isFirst.getAndSet(false)) {

                                    System.out.print("Response: ");
                                }
                                if (accumulator.getMessageChunk().getContent() != null) {
                                    System.out.print(accumulator.getMessageChunk().getContent());
                                }
                            }
                        })
                        .doOnComplete(System.out::println)
                        .lastElement()
                        .blockingGet()
                        .getAccumulatedMessage();
                messages.add(responseMessage);
                ChatFunctionCall functionCall = responseMessage.getFunctionCall();
                if (functionCall != null) {
                    if (functionCall.getName().equals("get_weather")) {
                        String location = functionCall.getArguments().get("location").asText();
                        String unit = functionCall.getArguments().get("unit").asText();
                        WeatherResponse weather = getWeather(location, unit);
                        ChatMessage weatherMessage = new ChatMessage(ChatMessageRole.FUNCTION.value(), JSON.toJSONString(weather), "get_weather");
                        messages.add(weatherMessage);
                        continue;
                    }
                }
                System.out.print("Next Query: ");

                String nextLine = scanner.nextLine();
                if (nextLine.equalsIgnoreCase("exit")) {
                    System.exit(0);
                }

                messages.add(new ChatMessage(ChatMessageRole.USER.value(), nextLine));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

    }

    private static WeatherResponse getWeather(String location, String unit) {
        return new WeatherResponse(location, WeatherUnit.valueOf(unit), new Random().nextInt(40), "sunny");
    }

    public static ChatFunctionDynamic getChatFunctionDynamic() {
        return ChatFunctionDynamic.builder()
                .name("get_weather")
                .description("Get the current weather of a location")
                .addProperty(ChatFunctionProperty.builder()
                        .name("location")
                        .type("string")
                        .description("City and state, for example: León, Guanajuato")
                        .build())
                .addProperty(ChatFunctionProperty.builder()
                        .name("unit")
                        .type("string")
                        .description("The temperature unit, can be 'CELSIUS' or 'FAHRENHEIT'")
                        .enumValues(new HashSet<>(Arrays.asList("CELSIUS", "FAHRENHEIT")))
                        .required(true)
                        .build())
                .build();
    }

对应的报错信息,在String location = functionCall.getArguments().get("location").asText();该行报错

java.lang.NullPointerException
    at com.mybank.bkinfocenter.common.recognition.web.Test.main(Test.java:96)

debug代码查看 com.theokanning.openai.service.OpenAiService#mapStreamToAccumulator方法中messageChunk中的arguments类型为objectNode,从而导致asText()方法返回的结果为"" image

请问我可以用什么简单的方法在不修改源代码的情况下来解决这个问题,非常感谢!

I-am-DJ commented 1 month ago

类似的问题:https://github.com/TheoKanning/openai-java/issues/463

I-am-DJ commented 1 month ago

我尝试修改了下com.theokanning.openai.service.OpenAiService#mapStreamToAccumulator和com.theokanning.openai.service.ChatFunctionCallArgumentsSerializerAndDeserializer.Deserializer#deserialize代码:

public static class Deserializer extends JsonDeserializer<JsonNode> {

        private Deserializer() {
        }

        @Override
        public JsonNode deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
            String json = p.getValueAsString();

            if (json == null || p.currentToken() == JsonToken.VALUE_NULL) {
                return null;
            }
            // ADDED
            json = MAPPER.writeValueAsString(json);
           // END ADDED
            try {
                JsonNode node = null;
                try {
                    node = MAPPER.readTree(json);
                } catch (JsonParseException ignored) {
                }
                if (node == null || node.getNodeType() == JsonNodeType.MISSING) {
                    node = MAPPER.readTree(p);
                }
                return node;
            } catch (Exception ex) {
                ex.printStackTrace();
                return null;
            }
        }
    }
     public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatCompletionChunk> flowable) {
        ChatFunctionCall functionCall = new ChatFunctionCall(null, null);
        ChatMessage accumulatedMessage = new ChatMessage(ChatMessageRole.ASSISTANT.value(), null);

        return flowable.map(chunk -> {
            ChatMessage messageChunk = chunk.getChoices().get(0).getMessage();
            ChatFunctionCall chunkFunctionCall = new ChatFunctionCall(null, null);
            if (messageChunk.getFunctionCall() != null) {
                if (messageChunk.getFunctionCall().getName() != null) {
                    String namePart = messageChunk.getFunctionCall().getName();
                    chunkFunctionCall.setName((functionCall.getName() == null ? "" : functionCall.getName()) + namePart);
                }
                if (messageChunk.getFunctionCall().getArguments() != null) {
                    String argumentsPart = messageChunk.getFunctionCall().getArguments().asText();
                    chunkFunctionCall.setArguments(new TextNode((functionCall.getArguments() == null ? "" : functionCall.getArguments().asText()) + argumentsPart));
                }
                accumulatedMessage.setFunctionCall(functionCall);
            } else {
                accumulatedMessage.setContent((accumulatedMessage.getContent() == null ? "" : accumulatedMessage.getContent()) + (messageChunk.getContent() == null ? "" : messageChunk.getContent()));
            }

            if (chunk.getChoices().get(0).getFinishReason() != null) { // last
                if (chunkFunctionCall.getArguments() != null) {
                    functionCall.setName(chunkFunctionCall.getName());
                    functionCall.setArguments(mapper.readTree(chunkFunctionCall.getArguments().asText()));
                    accumulatedMessage.setFunctionCall(functionCall);
                }
            }

            return new ChatMessageAccumulator(messageChunk, accumulatedMessage);
        });

修改mapStreamToAccumulator的主要原因是flow会返回两次带有finishReason的情况,第一次存在functioncall,第二次没有,导致原有代码会进入两次last注释下的代码,原有的readTree函数返回的Node类型为ObjectNode,再次进入后调用asText()函数会导致数据为空字符串,所以采用chunkFunctionCall局部变量的方式fix该问题

Lambdua commented 1 month ago

这个其实根本原因是序列化的问题. 这个库在序列化时对于<",">这个字段序列化有问题,所有会有各种textNode和ObjectNode转换问题. 我fork后的库修复了这个问题. 欢迎使用 openai4j.