64bit / async-openai

Rust library for OpenAI
https://docs.rs/async-openai
MIT License
1.11k stars 165 forks source link

Changing base url - Usage with open source LLM - Invalid status code: 404 Not Found #173

Closed louis030195 closed 8 months ago

louis030195 commented 8 months ago

Hey I'm trying to use async-openai with axum and open source LLMs through perplexity.ai in my test.

Basically my endpoint would route the request to OpenAI API or an OpenAI API like API changing the URL of the API based on the given model like gpt-4 would go to openai and mistralai/whatever would go to MODEL_URL

Getting a 404. Not sure if I'm doing something wrong or this use case is implemented?

These are my env var

MODEL_API_KEY="pplx-..."
MODEL_URL="https://api.perplexity.ai/chat/completions"

My code:

use async_openai::Client;
use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest};
use axum::{
    extract::{Extension, Json, Path, State},
    http::StatusCode,
    response::IntoResponse,
    response::Json as JsonResponse,
};

use async_stream::try_stream;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::Stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::error::Error;
use std::io::{stdout, Write};
use tokio::sync::broadcast::Receiver;
use tokio_stream::wrappers::BroadcastStream;
use url::Url;

fn extract_base_url(model_url: &str) -> Result<String, url::ParseError> {
    let url = Url::parse(model_url)?;
    let base_url = url.join("/")?;
    Ok(base_url.as_str().to_string())
}
// copied from https://github.com/tokio-rs/axum/discussions/1670

pub async fn stream_chat_handler(
    Json(request): Json<CreateChatCompletionRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, String)> {
    let model_name = &request.model;
    let model_url = std::env::var("MODEL_URL")
        .unwrap_or_else(|_| String::from("http://localhost:8000/v1/chat/completions"));
    let base_url = extract_base_url(&model_url).unwrap_or_else(|_| model_url);
    let (api_key, base_url) = if model_name.contains("/") {
        // Open Source model
        (std::env::var("MODEL_API_KEY").unwrap_or_default(), base_url)
    } else {
        // OpenAI model
        (
            std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            String::from("https://api.openai.com"),
        )
    };
    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_key(&api_key)
            .with_api_base(&base_url),
    );

    let mut stream = client
        .chat()
        .create_stream(request)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

    let sse_stream = try_stream! {
        while let Some(result) = stream.next().await {
            match result {
                Ok(response) => {
                    for chat_choice in response.choices.iter() {
                        if let Some(ref content) = chat_choice.delta.content {
                            yield Event::default().data(content.clone());
                        }
                    }
                }
                Err(err) => {
                    println!("Error: {}", err);
                    tracing::error!("Error: {}", err);
                }
            }
        }
    };

    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{self, Request};
    use axum::response::Response;
    use axum::routing::post;
    use axum::Router;
    use dotenv::dotenv;
    use serde_json::json;
    use std::convert::Infallible;
    use tower::{Service, ServiceExt};
    use tower_http::trace::TraceLayer;

    fn app() -> Router {
        Router::new()
            .route("/chat/completions", post(stream_chat_handler))
            .layer(TraceLayer::new_for_http())
    }

    #[tokio::test]
    async fn test_stream_chat_handler() {
        dotenv().ok();
        let app = app();

        let chat_input = json!({
            "model": "mistralai/mixtral-8x7b-instruct",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ]
        });

        let request = Request::builder()
            .method(http::Method::POST)
            .uri("/chat/completions")
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(json!(chat_input).to_string()))
            .unwrap();

        let response = app.clone().oneshot(request).await.unwrap();

        assert_eq!(
            response.status(),
            StatusCode::OK,
            "response: {:?}",
            hyper::body::to_bytes(response.into_body()).await.unwrap()
        );

        let response = hyper::body::to_bytes(response.into_body()).await.unwrap();
        println!("response: {:?}", response);
    }
}

Invalid status code: 404 Not Found

Any help appreciated 🙏

louis030195 commented 8 months ago

400 with this:

#[tokio::test]
async fn test_stream() {
    dotenv().ok();

    let messages = match ChatCompletionRequestUserMessageArgs::default()
        .content("Write a marketing blog praising and introducing Rust library async-openai")
        .build()
    {
        Ok(msg) => msg.into(),
        Err(e) => {
            println!("Error: {}", e);
            assert!(false);
            return;
        }
    };
    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_key(&std::env::var("MODEL_API_KEY").unwrap_or_default())
            .with_api_base("https://api.perplexity.ai"),
    );
    let request = match CreateChatCompletionRequestArgs::default()
        .model("mistralai/mixtral-8x7b-instruct")
        .max_tokens(512u16)
        .messages([messages])
        .build()
    {
        Ok(req) => req,
        Err(e) => {
            println!("Error: {}", e);
            assert!(false);
            return;
        }
    };

    let stream_result = client.chat().create_stream(request).await;
    let mut stream = match stream_result {
        Ok(s) => s,
        Err(e) => {
            println!("Error: {}", e);
            assert!(false);
            return;
        }
    };

    let mut lock = stdout().lock();
    while let Some(result) = stream.next().await {
        match result {
            Ok(response) => {
                response.choices.iter().for_each(|chat_choice| {
                    if let Some(ref content) = chat_choice.delta.content {
                        write!(lock, "{}", content).unwrap();
                    }
                });
            }
            Err(err) => {
                writeln!(lock, "error: {err}").unwrap();
            }
        }
        match stdout().flush() {
            Ok(_) => (),
            Err(e) => {
                println!("Error: {}", e);
                assert!(false);
                return;
            }
        }
    }
}
louis030195 commented 8 months ago

interesting, using mistral api, different results:

#[tokio::test]
async fn test_stream() {
    dotenv().ok();

    let messages = match ChatCompletionRequestUserMessageArgs::default()
        .content("Write a marketing blog praising and introducing Rust library async-openai")
        .build()
    {
        Ok(msg) => msg.into(),
        Err(e) => {
            println!("Error: {}", e);
            assert!(false);
            return;
        }
    };
    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_key(&std::env::var("MODEL_API_KEY").unwrap_or_default())
            .with_api_base("https://api.mistral.ai/v1"),
    );
    let request = match CreateChatCompletionRequestArgs::default()
        // .model("mistralai/mixtral-8x7b-instruct")
        .model("mistral-tiny")
        .max_tokens(512u16)
        .messages([messages])
        .build()
    {
        Ok(req) => req,
        Err(e) => {
            println!("Error: {}", e);
            assert!(false);
            return;
        }
    };

    let stream_result = client.chat().create_stream(request).await;
    let mut stream = match stream_result {
        Ok(s) => s,
        Err(e) => {
            println!("Error: {}", e);
            assert!(false);
            return;
        }
    };

    let mut lock = stdout().lock();
    while let Some(result) = stream.next().await {
        match result {
            Ok(response) => {
                response.choices.iter().for_each(|chat_choice| {
                    if let Some(ref content) = chat_choice.delta.content {
                        write!(lock, "{}", content).unwrap();
                    }
                });
            }
            Err(err) => {
                println!("Error: {}", err);
                // jsonify error 
                let err = json!({
                    "error": err.to_string()
                });
                println!("error: {}", err);
                writeln!(lock, "error: {err}").unwrap();
            }
        }
        match stdout().flush() {
            Ok(_) => (),
            Err(e) => {
                println!("Error: {}", e);
                assert!(false);
                return;
            }
        }
    }
}

running 1 test Error: failed to deserialize api response: missing field created at line 1 column 154 error: {"error":"failed to deserialize api response: missing field created at line 1 column 154"} error: {"error":"failed to deserialize api response: missing field created at line 1 column 154"} Title: Unleashing Creativity and Productivity: An Introduction to async-openai, the Game-Changing Rust Library for Interacting with OpenAI

As technology continues to evolve at an unprecedented pace, developers are constantly on the lookout for tools that can help them build applications faster, more efficiently, and with greater creativity. One such tool that has been generating buzz in the Rust community is async-openai, an innovative library for interacting with OpenAI's powerful language models. In this blog post, we'll explore why async-openai is a must-have addition to any Rust developer's toolkit and how it can help you build AI-powered applications with ease.

First, let's talk about OpenAI. OpenAI is a leading artificial intelligence research laboratory, and its language models, such as the popular ChatGPT, have captured the imagination of people around the world with their ability to generate human-like text based on given prompts. Interacting with OpenAI's models, however, can be a complex and time-consuming process, especially for those who want to build applications that can leverage these models in real-time. This is where async-openai comes in.

async-openai is an asynchronous Rust library for interacting with OpenAI's models. Asynchronous programming is a programming paradigm that allows multiple tasks to be executed concurrently, which is essential for building high-performance applications. By using async-openai, Rust developers can easily integrate OpenAI's models into their applications, enabling real-time, responsive interactions with the models.

One of the key benefits of using async-openai is its simplicity. The library provides a straightforward and intuitive API for sending requests to OpenAI's models

jacksongoode commented 6 months ago

I think this isn't out of scope anymore since #191 depends on it.

oleander commented 4 months ago

This seems to work based on @louis030195's code using llama3:latest:

use std::io::{stdout, Write};

use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use async_openai::config::OpenAIConfig;
use llm_chain::output::StreamExt;
use async_openai::Client;
use serde_json::json;

#[tokio::main]
async fn main() {
  dotenv::dotenv().ok();

  let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
  let api_host = "http://localhost:11434/v1";
  let api_model = "llama3:latest";
  let max_tokens = 512u16;

  let messages = match ChatCompletionRequestUserMessageArgs::default()
    .content("Write a marketing blog praising and introducing Rust library async-openai")
    .build()
  {
    Ok(msg) => msg.into(),
    Err(e) => {
      println!("Error: {}", e);
      assert!(false);
      return;
    }
  };
  let client = Client::with_config(
    OpenAIConfig::new()
      .with_api_key(&api_key)
      .with_api_base(api_host)
  );
  let request = match CreateChatCompletionRequestArgs::default()
    .model(api_model)
    .max_tokens(max_tokens)
    .messages([messages])
    .build()
  {
    Ok(req) => req,
    Err(e) => {
      println!("Error: {}", e);
      assert!(false);
      return;
    }
  };

  let stream_result = client.chat().create_stream(request).await;
  let mut stream = match stream_result {
    Ok(s) => s,
    Err(e) => {
      println!("Error: {}", e);
      assert!(false);
      return;
    }
  };

  let mut lock = stdout().lock();

  while let Some(result) = stream.next().await {
    match result {
      Ok(response) => {
        response.choices.iter().for_each(|chat_choice| {
          if let Some(ref content) = chat_choice.delta.content {
            write!(lock, "{}", content).unwrap();
          }
        });
      }
      Err(err) => {
        println!("Error: {}", err);
        // jsonify error
        let err = json!({
            "error": err.to_string()
        });
        println!("error: {}", err);
        writeln!(lock, "error: {err}").unwrap();
      }
    }
    match stdout().flush() {
      Ok(_) => (),
      Err(e) => {
        println!("Error: {}", e);
        assert!(false);
        return;
      }
    }
  }
}
jondot commented 3 months ago

for anyone bumping into this, you only need to drop the last / in the URL.