getmetal / motorhead

🧠 Motorhead is a memory and information retrieval server for LLMs.
https://getmetal.io
Apache License 2.0

Sweep: Connection to openai via OPENAI_API_BASE doesn't seem to work #108

Closed: kevin-littlejohn closed this issue 6 months ago.

kevin-littlejohn commented 6 months ago

Details

I believe there's a logic error in `models.rs` where it creates the client connection: if `OPENAI_API_BASE` is provided, it doesn't connect to the specified server and instead still creates a default OpenAI connection. This means you can't use a local or self-hosted model. I think the logic is wrong in that it only reaches that block if the Azure env variables are set, but I'm not confident enough in Rust to say for sure.
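Here is a minimal sketch of the precedence this issue expects: use a custom base URL when `OPENAI_API_BASE` is set, use Azure only when all four `AZURE_*` variables are present, and otherwise fall back to the default OpenAI endpoint. A few parts are assumptions rather than Motorhead's code:

- `Backend` and `select_backend` are hypothetical names, not Motorhead's API. The real server builds `AnyOpenAIClient` values from async-openai configs.
- Putting `OPENAI_API_BASE` ahead of the Azure variables when both are set is an assumed ordering.
- Environment lookups are passed in as a closure, so the ordering can be checked without touching the real process environment.

```rust
use std::collections::HashMap;

/// Hypothetical summary of which backend would be used; Motorhead's real code
/// builds `AnyOpenAIClient` values with async-openai configs instead.
#[derive(Debug, PartialEq)]
enum Backend {
    /// OpenAI-compatible server at a custom base URL (e.g. a local model server).
    CustomBase(String),
    /// Azure OpenAI, chosen only when all four Azure variables are present.
    Azure {
        api_base: String,
        deployment_id: String,
        deployment_id_ada: String,
    },
    /// The default public OpenAI endpoint.
    DefaultOpenAI,
}

/// Chooses a backend from environment lookups. `get` is injected so the
/// precedence can be tested without touching the real process environment.
fn select_backend<F>(get: F) -> Backend
where
    F: Fn(&str) -> Option<String>,
{
    // 1. An explicit OPENAI_API_BASE wins, even if Azure variables are also set
    //    (assumed ordering, matching the change proposed in this thread).
    if let Some(base) = get("OPENAI_API_BASE") {
        return Backend::CustomBase(base);
    }
    // 2. Azure needs all four variables; matching on a tuple binds them
    //    directly instead of calling is_some() and then unwrap().
    match (
        get("AZURE_API_KEY"),
        get("AZURE_DEPLOYMENT_ID"),
        get("AZURE_DEPLOYMENT_ID_ADA"),
        get("AZURE_API_BASE"),
    ) {
        (Some(_api_key), Some(deployment_id), Some(deployment_id_ada), Some(api_base)) => {
            Backend::Azure { api_base, deployment_id, deployment_id_ada }
        }
        // 3. Anything else falls back to the default OpenAI endpoint.
        _ => Backend::DefaultOpenAI,
    }
}

fn main() {
    // In a server this would be `select_backend(|k| std::env::var(k).ok())`.
    let vars: HashMap<&str, &str> =
        HashMap::from([("OPENAI_API_BASE", "http://localhost:8080/v1")]);
    let backend = select_backend(|k| vars.get(k).map(|v| v.to_string()));
    println!("{:?}", backend);
}
```

Run as-is, this prints `CustomBase("http://localhost:8080/v1")`. The tuple `match` binds all four Azure values in one pattern, so the `is_some()` checks and `unwrap()` calls are unnecessary.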

Checklist:
- [X] Modify `src/models.rs` ✓ https://github.com/getmetal/motorhead/commit/abf0e546bda733eaa3948c35fd99dd7dae847174
- [X] Running GitHub Actions for `src/models.rs` ✓
sweep-ai[bot] commented 6 months ago

🚀 Here's the PR! #109

See Sweep's progress at the progress dashboard!
⚡ Sweep Basic Tier: I'm using GPT-4. You have 5 GPT-4 tickets left for the month and 3 for the day. (tracking ID: 5e0552f48c)



Actions

GitHub Actions ✓

Here are the GitHub Actions logs prior to making any changes:

Sandbox logs for 47b6dad
Checking src/models.rs for syntax errors... ✅ src/models.rs has no syntax errors! 1/1 ✓

Sandbox passed on the latest main, so sandbox checks will be enabled for this issue.


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I think are relevant, in decreasing order of relevance. If a file is missing from here, you can mention its path in the ticket description. https://github.com/getmetal/motorhead/blob/47b6dad303f88f0f196f3b416b6a669fd725e2bd/src/models.rs#L20-L76

Step 2: ⌨️ Coding

```rust
async fn create(&self) -> Result<AnyOpenAIClient, MotorheadError> {
    let openai_api_base = env::var("OPENAI_API_BASE").ok();
    let azure_api_key = env::var("AZURE_API_KEY").ok();
    let azure_deployment_id = env::var("AZURE_DEPLOYMENT_ID").ok();
    let azure_deployment_id_ada = env::var("AZURE_DEPLOYMENT_ID_ADA").ok();
    let azure_api_base = env::var("AZURE_API_BASE").ok();

    let openai_client = if let Some(api_base) = openai_api_base {
        let embedding_config = OpenAIConfig::default().with_api_base(&api_base);
        let completion_config = OpenAIConfig::default().with_api_base(&api_base);

        AnyOpenAIClient::OpenAI {
            embedding_client: Client::with_config(embedding_config),
            completion_client: Client::with_config(completion_config),
        }
    } else if azure_api_key.is_some() && azure_deployment_id.is_some() && azure_deployment_id_ada.is_some() && azure_api_base.is_some() {
        let config = AzureConfig::new()
            .with_api_base(azure_api_base.as_ref().unwrap())
            .with_api_key(azure_api_key.as_ref().unwrap())
            .with_deployment_id(azure_deployment_id.unwrap())
            .with_api_version("2023-05-15");

        let config_ada = AzureConfig::new()
            .with_api_base(azure_api_base.as_ref().unwrap())
            .with_api_key(azure_api_key.as_ref().unwrap())
            .with_deployment_id(azure_deployment_id_ada.unwrap())
            .with_api_version("2023-05-15");

        AnyOpenAIClient::Azure {
            embedding_client: Client::with_config(config_ada),
            completion_client: Client::with_config(config),
        }
    } else {
        AnyOpenAIClient::OpenAI {
            embedding_client: Client::new(),
            completion_client: Client::new(),
        }
    };

    Ok(openai_client)
}
```

• This change makes `OPENAI_API_BASE` take priority in the client configuration, so users can connect to a custom OpenAI API base when one is provided.

<pre>--- 
+++ 
@@ -26,32 +26,46 @@
     type Error = MotorheadError;

     async fn create(&self) -> Result<AnyOpenAIClient, MotorheadError> {
-        let openai_client = match (
-            env::var("AZURE_API_KEY"),
-            env::var("AZURE_DEPLOYMENT_ID"),
-            env::var("AZURE_DEPLOYMENT_ID_ADA"),
-            env::var("AZURE_API_BASE"),
-        ) {
-            (
-                Ok(azure_api_key),
-                Ok(azure_deployment_id),
-                Ok(azure_deployment_id_ada),
-                Ok(azure_api_base),
-            ) => {
-                let config = AzureConfig::new()
-                    .with_api_base(&azure_api_base)
-                    .with_api_key(&azure_api_key)
-                    .with_deployment_id(azure_deployment_id)
-                    .with_api_version("2023-05-15");
-
-                let config_ada = AzureConfig::new()
-                    .with_api_base(&azure_api_base)
-                    .with_api_key(&azure_api_key)
-                    .with_deployment_id(azure_deployment_id_ada)
-                    .with_api_version("2023-05-15");
-
-                AnyOpenAIClient::Azure {
-                    embedding_client: Client::with_config(config_ada),
+        let openai_api_base = env::var("OPENAI_API_BASE").ok();
+        let azure_api_key = env::var("AZURE_API_KEY").ok();
+        let azure_deployment_id = env::var("AZURE_DEPLOYMENT_ID").ok();
+        let azure_deployment_id_ada = env::var("AZURE_DEPLOYMENT_ID_ADA").ok();
+        let azure_api_base = env::var("AZURE_API_BASE").ok();
+
+        let openai_client = if let Some(api_base) = openai_api_base {
+            let embedding_config = OpenAIConfig::default().with_api_base(&api_base);
+            let completion_config = OpenAIConfig::default().with_api_base(&api_base);
+
+            AnyOpenAIClient::OpenAI {
+                embedding_client: Client::with_config(embedding_config),
+                completion_client: Client::with_config(completion_config),
+            }
+        } else if azure_api_key.is_some() && azure_deployment_id.is_some() && azure_deployment_id_ada.is_some() && azure_api_base.is_some() {
+            let config = AzureConfig::new()
+                .with_api_base(azure_api_base.as_ref().unwrap())
+                .with_api_key(azure_api_key.as_ref().unwrap())
+                .with_deployment_id(azure_deployment_id.unwrap())
+                .with_api_version("2023-05-15");
+
+            let config_ada = AzureConfig::new()
+                .with_api_base(azure_api_base.as_ref().unwrap())
+                .with_api_key(azure_api_key.as_ref().unwrap())
+                .with_deployment_id(azure_deployment_id_ada.unwrap())
+                .with_api_version("2023-05-15");
+
+            AnyOpenAIClient::Azure {
+                embedding_client: Client::with_config(config_ada),
+                completion_client: Client::with_config(config),
+            }
+        } else {
+            AnyOpenAIClient::OpenAI {
+                embedding_client: Client::new(),
+                completion_client: Client::new(),
+            }
+        };
+
+        Ok(openai_client)
+    }
                     completion_client: Client::with_config(config),
                 }
             }
</pre>
</blockquote>

- [X] Running GitHub Actions for `src/models.rs` ✓
<blockquote>Check src/models.rs with contents:

Ran GitHub Actions for <a href="https://github.com/getmetal/motorhead/commit/abf0e546bda733eaa3948c35fd99dd7dae847174">abf0e546bda733eaa3948c35fd99dd7dae847174</a>:

</blockquote>

---
## Step 3: 🔍 Code Review
I have finished reviewing the code for completeness. I did not find errors for [`sweep/connection_to_openai_via_openai_api_base`](https://github.com/getmetal/motorhead/commits/sweep/connection_to_openai_via_openai_api_base).

---


πŸ’‘ To recreate the pull request edit the issue title or description. To tweak the pull request, leave a comment on the pull request.<sup>Something wrong? [Let us know](https://discord.gg/sweep).</sup>

*This is an automated message generated by [Sweep AI](https://sweep.dev).*
kevin-littlejohn commented 6 months ago

No, I think I got the issue wrong: the logic here is correct, but I'm still seeing Motorhead try to connect to the main OpenAI servers instead of my base URL. This PR isn't the fix, and I'm still debugging why it doesn't honour `OPENAI_API_BASE` properly.
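When the selection logic looks right but requests still go to the default endpoint, one thing worth ruling out is whether the server process actually sees the variables. For example, a process manager or container may supply its own environment. Below is a minimal standard-library-only diagnostic; `describe_env` is a hypothetical helper, not part of Motorhead.

It reports each variable the snippets above read as unset, empty, non-Unicode, or set. Those cases matter because the `env::var(...).ok()` calls in the snippets blur them:

- `.ok()` maps both `VarError::NotPresent` and `VarError::NotUnicode` to `None`.
- An empty `OPENAI_API_BASE` still becomes `Some("")`.

Values of variables whose name contains `KEY` are masked.

```rust
use std::env::VarError;

/// Variables that, per the snippets in this thread, decide which client is built.
const VARS: [&str; 5] = [
    "OPENAI_API_BASE",
    "AZURE_API_KEY",
    "AZURE_DEPLOYMENT_ID",
    "AZURE_DEPLOYMENT_ID_ADA",
    "AZURE_API_BASE",
];

/// Hypothetical helper: describes each variable as the process sees it.
/// `get` is injected (normally `std::env::var`) so the helper is testable.
/// Values of variables whose name contains "KEY" are never printed.
fn describe_env<F>(get: F) -> Vec<String>
where
    F: Fn(&str) -> Result<String, VarError>,
{
    VARS.iter()
        .map(|&name| {
            let status = match get(name) {
                // Set but empty: `.ok()` would still yield Some("") here.
                Ok(v) if v.is_empty() => "set but EMPTY".to_string(),
                Ok(v) if name.contains("KEY") => format!("set ({} bytes, masked)", v.len()),
                Ok(v) => format!("set to {:?}", v),
                // `.ok()` maps both of these errors to None.
                Err(VarError::NotPresent) => "NOT SET".to_string(),
                Err(VarError::NotUnicode(_)) => "set but not valid Unicode".to_string(),
            };
            format!("{}: {}", name, status)
        })
        .collect()
}

fn main() {
    // Run this in the same environment the server is launched from.
    for line in describe_env(|k| std::env::var(k)) {
        println!("{}", line);
    }
}
```

Printing the value with `{:?}` makes stray whitespace or quotes around the URL visible.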