spring-projects / spring-ai

An Application Framework for AI Engineering
https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/index.html
Apache License 2.0

[OpenAiChatModel] Inability to Pass `tool_calls` or `tool_call_id` and Discrepancy in Message Roles #877

Closed RikJux closed 4 days ago

RikJux commented 3 months ago

Description

There are two main issues with the current implementation of the createRequest method in the OpenAiChatModel class (see Additional Context below):

  1. Inability to Pass tool_calls or tool_call_id: The ChatCompletionMessage constructor called in the createRequest method does not accept tool_calls or tool_call_id. It delegates to the full constructor with null for the remaining three parameters, so tool_call_id and tool_calls are always null.

    public ChatCompletionMessage(Object content, Role role) {
       this(content, role, null, null, null);
    }
  2. Discrepancy in Message Roles: The MessageType enum in AbstractMessage and the Role enum in OpenAI messages do not line up: MessageType defines FUNCTION, whereas Role defines TOOL instead.

    MessageType in AbstractMessage:

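    public enum MessageType {
       USER("user"),
       ASSISTANT("assistant"),
       SYSTEM("system"),
       FUNCTION("function");

       private final String value; // role name, e.g. "function"
       MessageType(String value) { this.value = value; }
       public String getValue() { return value; }
    }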

    Role in OpenAI messages:

    public enum Role {
       @JsonProperty("system")
       SYSTEM,
       @JsonProperty("user")
       USER,
       @JsonProperty("assistant")
       ASSISTANT,
       @JsonProperty("tool")
       TOOL
    }

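    Because createRequest converts the message type with Role.valueOf(m.getMessageType().name()) and Role has no FUNCTION constant, this results in a java.lang.IllegalArgumentException.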
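One way to remove the mismatch is to map roles explicitly instead of looking them up by constant name. The sketch below uses simplified stand-in enums (not the actual Spring AI types), and the `toRole` helper is hypothetical. It maps `FUNCTION` to `TOOL`, the role OpenAI uses for tool results.

```java
public class RoleMapping {

    // Simplified stand-in for Spring AI's MessageType (illustration only)
    enum MessageType { USER, ASSISTANT, SYSTEM, FUNCTION }

    // Simplified stand-in for OpenAiApi.ChatCompletionMessage.Role
    enum Role { SYSTEM, USER, ASSISTANT, TOOL }

    // Hypothetical helper: an exhaustive switch expression over the enum, so adding a
    // new MessageType constant without a mapping fails at compile time instead of
    // throwing IllegalArgumentException at runtime like Role.valueOf(type.name())
    static Role toRole(MessageType type) {
        return switch (type) {
            case USER -> Role.USER;
            case ASSISTANT -> Role.ASSISTANT;
            case SYSTEM -> Role.SYSTEM;
            case FUNCTION -> Role.TOOL; // OpenAI expects the "tool" role for tool results
        };
    }

    public static void main(String[] args) {
        // Print the mapping for every message type
        for (MessageType type : MessageType.values()) {
            System.out.println(type + " -> " + toRole(type));
        }
    }
}
```

Compared with a lookup by name, the explicit switch keeps the two enums decoupled: their constant names no longer have to match.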

Steps to Reproduce

Passing a FunctionMessage in a prompt results in an error:

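import com.fasterxml.jackson.core.JsonProcessingException;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.FunctionMessage;
import org.springframework.ai.chat.prompt.Prompt;

@Test
public void testPromptWithTool() throws JsonProcessingException {
    Prompt prompt = new Prompt(new FunctionMessage(""));
    openAiChatModel.call(prompt); // openAiChatModel: an OpenAiChatModel set up in the test class
}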

Error:

java.lang.IllegalArgumentException: No enum constant org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role.FUNCTION

    at java.base/java.lang.Enum.valueOf(Enum.java:293)
    at org.springframework.ai.openai.api.OpenAiApi$ChatCompletionMessage$Role.valueOf(OpenAiApi.java:488)
    at org.springframework.ai.openai.OpenAiChatModel.lambda$createRequest$9(OpenAiChatModel.java:267)
    at java.base/java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:197)
    at java.base/java.util.Collections$2.tryAdvance(Collections.java:5073)
    at java.base/java.util.Collections$2.forEachRemaining(Collections.java:5081)
    at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:509)
    at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:499)
    at java.base/java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:575)
    at java.base/java.util.stream.AbstractPipeline.evaluateToArrayNode(AbstractPipeline.java:260)
    at java.base/java.util.stream.ReferencePipeline.toArray(ReferencePipeline.java:616)
    at java.base/java.util.stream.ReferencePipeline.toArray(ReferencePipeline.java:622)
    at java.base/java.util.stream.ReferencePipeline.toList(ReferencePipeline.java:627)
    at org.springframework.ai.openai.OpenAiChatModel.createRequest(OpenAiChatModel.java:268)
    at org.springframework.ai.openai.OpenAiChatModel.call(OpenAiChatModel.java:140)
    at it.ai.foundation.ServiceChatModelTest.testPromptWithTool(ServiceChatModelTest.java:143)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
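The failure does not depend on Spring or the OpenAI API; it comes from the name-based enum lookup alone. The minimal sketch below reproduces it with simplified stand-in enums (the class name `ValueOfMismatch` is illustrative).

```java
public class ValueOfMismatch {

    // Stand-ins mirroring the two enums quoted above (no Jackson annotations)
    enum MessageType { USER, ASSISTANT, SYSTEM, FUNCTION }
    enum Role { SYSTEM, USER, ASSISTANT, TOOL }

    // Same pattern as createRequest: look up the Role by the MessageType constant's name
    static Role byName(MessageType type) {
        return Role.valueOf(type.name());
    }

    // Returns the exception message for a failed lookup, or null if the lookup succeeds
    static String failureMessage(MessageType type) {
        try {
            byName(type);
            return null;
        } catch (IllegalArgumentException e) {
            return e.getMessage();
        }
    }

    public static void main(String[] args) {
        System.out.println(byName(MessageType.USER)); // names match, lookup succeeds
        // FUNCTION has no Role counterpart; the message has the same shape as in the trace above
        System.out.println(failureMessage(MessageType.FUNCTION));
    }
}
```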

Expected Behavior

The current implementation of the createRequest method does not cover all possible prompts that a user could define. The expected behavior should be as follows:

  1. Support for Tool Calls:

    • The createRequest method should allow passing tool_calls and tool_call_id to the ChatCompletionMessage constructor. This will enable the proper handling of tool calls within the chat prompt.
  2. Consistent Message Roles:

    • The MessageType enum in AbstractMessage should be aligned with the Role enum in OpenAI messages. Alternatively, a mapping should be provided to ensure that all message types are correctly translated and no errors occur due to missing enum constants.
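To illustrate why the message type needs both fields, here is a minimal sketch of the two messages in a tool-calling round trip: an assistant message carrying tool_calls, then a tool message that answers it through tool_call_id. The `ToolCall` and `ChatMessage` records and their factory methods are hypothetical simplifications, not the Spring AI `ChatCompletionMessage` API; the call id and function name are made-up values.

```java
import java.util.List;

public class ToolMessages {

    enum Role { SYSTEM, USER, ASSISTANT, TOOL }

    // Hypothetical minimal tool call: id plus function name and JSON-encoded arguments
    record ToolCall(String id, String functionName, String arguments) { }

    // Hypothetical simplified message: toolCallId and toolCalls are null when unused
    record ChatMessage(String content, Role role, String toolCallId, List<ToolCall> toolCalls) {

        // Assistant message that requests one or more tool calls
        static ChatMessage assistantToolCalls(List<ToolCall> calls) {
            return new ChatMessage(null, Role.ASSISTANT, null, List.copyOf(calls));
        }

        // Tool result message: must reference the id of the call it answers
        static ChatMessage toolResult(String toolCallId, String result) {
            if (toolCallId == null || toolCallId.isBlank()) {
                throw new IllegalArgumentException("tool messages require a tool_call_id");
            }
            return new ChatMessage(result, Role.TOOL, toolCallId, null);
        }
    }

    public static void main(String[] args) {
        ToolCall call = new ToolCall("call_1", "getWeather", "{\"city\":\"Rome\"}");
        ChatMessage request = ChatMessage.assistantToolCalls(List.of(call));
        ChatMessage response = ChatMessage.toolResult(call.id(), "{\"tempC\":21}");
        System.out.println(request.role() + " " + request.toolCalls().size());
        System.out.println(response.role() + " " + response.toolCallId());
    }
}
```

A constructor that always passes null for these fields, like the two-argument one quoted earlier, can build neither message.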

Additional Context

Here is the current implementation of the createRequest method:

ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
    Set<String> functionsForThisRequest = new HashSet<>();
    List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m -> {
        // Add text content.
        List<MediaContent> contents = new ArrayList<>(List.of(new MediaContent(m.getContent())));
        if (!CollectionUtils.isEmpty(m.getMedia())) {
            // Add media content.
            contents.addAll(m.getMedia()
                .stream()
                .map(media -> new MediaContent(
                    new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
                .toList());
        }
        // Throws IllegalArgumentException for MessageType.FUNCTION (no matching Role constant)
        return new ChatCompletionMessage(contents, ChatCompletionMessage.Role.valueOf(m.getMessageType().name()));
    }).toList();
    // rest of the code...
}
ThomasVitale commented 1 month ago

Thanks for reporting this issue. The function calling APIs have been recently refactored and improved. The changes should have handled both points you reported.

  1. The ChatCompletionMessage in the OpenAiApi class accepts both toolCallId and toolCalls. See: https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java#L519
  2. The MessageType enum uses tool as the role name instead of function. See: https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/MessageType.java#L51

Can you confirm this issue is fixed?

csterwa commented 1 week ago

@RikJux please let us know if this is resolved. Will close in 7 days if no response. Thank you.

csterwa commented 4 days ago

Closing for now.