64bit / async-openai

Rust library for OpenAI
https://docs.rs/async-openai
MIT License

The type of messages in deserialized CreateChatCompletionRequest are all SystemMessage #216

Closed sontallive closed 2 months ago

sontallive commented 4 months ago

I want to deserialize request JSON into CreateChatCompletionRequest, but I found that all the messages deserialize as System.

code

use async_openai::types::{
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
    CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let request: CreateChatCompletionRequest = CreateChatCompletionRequestArgs::default()
        .messages([
            ChatCompletionRequestSystemMessageArgs::default()
                .content("your are a calculator")
                .build()?
                .into(),
            ChatCompletionRequestUserMessageArgs::default()
                .content("what is the result of 1+1")
                .build()?
                .into(),
        ])
        .build()?;
    // serialize the request
    let serialized = serde_json::to_string(&request)?;
    println!("{}", serialized);
    // deserialize the request
    let deserialized: CreateChatCompletionRequest = serde_json::from_str(&serialized)?;
    println!("{:?}", deserialized);
    Ok(())
}

result

{"messages":[{"content":"your are a calculator","role":"system"},{"content":"what is the result of 1+1","role":"user"}],"model":""}

CreateChatCompletionRequest { messages: [System(ChatCompletionRequestSystemMessage { content: "your are a calculator", role: System, name: None }), System(ChatCompletionRequestSystemMessage { content: "what is the result of 1+1", role: User, name: None })], model: "", frequency_penalty: None, logit_bias: None, logprobs: None, top_logprobs: None, max_tokens: None, n: None, presence_penalty: None, response_format: None, seed: None, stop: None, stream: None, temperature: None, top_p: None, tools: None, tool_choice: None, user: None, function_call: None, functions: None }

djmango commented 3 months ago

I also have this issue, using actix_web.

djmango commented 3 months ago

I was banging my head on this for a bit, but I just pushed a fix on my branch.

Thanks to coco.codes from the NAMTAO Discord!

To solve the parent issue of the messages always being System, we add the attribute #[serde(tag = "role", rename_all = "lowercase")] to ChatCompletionRequestMessage.

This maps the role key to the appropriate variant of ChatCompletionRequestMessage. What tripped me up, however, is that doing so consumes the role key, so the child ChatCompletionRequestUserMessage fails during deserialization because it can no longer see the role key.

I solved this by deleting the role field in the child and implementing it in the parent as a method that matches on the enum variant (not even really needed; it turns out the role is not actually used anywhere in the lib or in my codebase).

I've verified that this works in prod across a bunch of different model providers. I'm happy with this solution, though I don't know if it will be merged. You're free to merge from my fork if you like.

sontallive commented 3 months ago

Thank you, I will give it a try.

digitalscyther commented 3 months ago

I wrote a custom wrapper for serialization and deserialization:

use async_openai::types::ChatCompletionRequestMessage;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;

/// Newtype wrapper that (de)serializes a message as {"type": ..., "content": ...}.
#[derive(Debug)]
pub struct Message(ChatCompletionRequestMessage);

impl Message {
    pub fn from_original(enum_val: ChatCompletionRequestMessage) -> Self {
        Message(enum_val)
    }

    pub fn into_original(self) -> ChatCompletionRequestMessage {
        self.0
    }
}

impl Serialize for Message {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Pick the tag for the variant and serialize the inner message to a JSON value.
        let (msg_type, content) = match &self.0 {
            ChatCompletionRequestMessage::System(msg) => ("system", serde_json::to_value(msg)),
            ChatCompletionRequestMessage::User(msg) => ("user", serde_json::to_value(msg)),
            ChatCompletionRequestMessage::Assistant(msg) => ("assistant", serde_json::to_value(msg)),
            ChatCompletionRequestMessage::Tool(msg) => ("tool", serde_json::to_value(msg)),
            ChatCompletionRequestMessage::Function(msg) => ("function", serde_json::to_value(msg)),
        };
        // Report a failure as a serializer error instead of panicking.
        let content = content.map_err(<S::Error as serde::ser::Error>::custom)?;

        let mut state = serializer.serialize_struct("Message", 2)?;
        state.serialize_field("type", msg_type)?;
        state.serialize_field("content", &content)?;
        state.end()
    }
}

impl<'de> Deserialize<'de> for Message {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value: Value = Deserialize::deserialize(deserializer)?;

        let msg_type = value.get("type").and_then(Value::as_str).ok_or_else(|| {
            <D::Error as serde::de::Error>::custom("Missing or invalid `type` field")
        })?;
        let content = value["content"].clone();

        // Deserialize the payload into the variant named by the `type` tag.
        let inner = match msg_type {
            "system" => serde_json::from_value(content).map(ChatCompletionRequestMessage::System),
            "user" => serde_json::from_value(content).map(ChatCompletionRequestMessage::User),
            "assistant" => serde_json::from_value(content).map(ChatCompletionRequestMessage::Assistant),
            "tool" => serde_json::from_value(content).map(ChatCompletionRequestMessage::Tool),
            "function" => serde_json::from_value(content).map(ChatCompletionRequestMessage::Function),
            _ => {
                return Err(serde::de::Error::unknown_variant(
                    msg_type,
                    &["system", "user", "assistant", "tool", "function"],
                ))
            }
        };
        // Report a failure as a deserializer error instead of panicking.
        inner.map(Message).map_err(<D::Error as serde::de::Error>::custom)
    }
}

64bit commented 2 months ago

Instead of complex ser-de implementations, the types have been updated for proper serialization and deserialization in v0.23.0.

Thank you @sontallive for contributing the test too. It's included as part of the tests in https://github.com/64bit/async-openai/blob/main/async-openai/tests/ser_de.rs