openai / openai-dotnet

The official .NET library for the OpenAI API
https://www.nuget.org/packages/OpenAI
MIT License
693 stars 59 forks source link

Code Generation for Tool Definitions/Dispatching #63

Open daeken opened 1 week ago

daeken commented 1 week ago

Prior to the release of this library, I built my own (much more limited) library to access the OpenAI API, but I have a code generator package that automatically builds the JSON for tools, as well as dispatching code to deserialize incoming parameters and serialize output. Is that something that would be appreciated for this? It'd be relatively easy to port over and I'd be happy to put in a PR, but it would have to be a separate NuGet package I believe, due to the way code generators are referenced.

Below is an example usage of the code generator I have in place. I think this is a higher level of abstraction than would be appropriate for this library, but the core tool code generation could be helpful.

public interface TestMixin {
    /// <summary>
    /// Tool implementation that applies ROT13 to <paramref name="value"/>; applying it twice restores the input.
    /// Declared <c>async</c> although the body is synchronous, so the result is delivered through the returned task.
    /// </summary>
    /// <param name="value">Text to transform. Characters outside a-z / A-Z are passed through unchanged.</param>
    /// <returns>The transformed text.</returns>
    [Tool("Ciphers/deciphers a string")]
    async Task<string> Cipher(
        [Desc("string", "The string to encrypt or decrypt")] string value
    ) {
        Console.WriteLine($"Ciphering '{value}'");

        var letters = value.ToCharArray();
        for(var i = 0; i < letters.Length; i++) {
            letters[i] = Rotate13(letters[i]);
        }
        return new string(letters);

        // Shifts an ASCII letter 13 places within its own case, wrapping around the alphabet.
        static char Rotate13(char c) {
            if(c >= 'a' && c <= 'z') {
                return (char)('a' + (c - 'a' + 13) % 26);
            }
            if(c >= 'A' && c <= 'Z') {
                return (char)('A' + (c - 'A' + 13) % 26);
            }
            return c;
        }
    }
}

// Example host class: each [Tool]-annotated method below is a tool the code generator wires up.
partial class GptTest : GPT, TestMixin {
    // Optional Model override, left disabled in this example.
    //public override string Model => "gpt-4-turbo";
    public override string SystemMessage {
        get { return "You are part of an automated test suite for GPT integration. Please utilize the tools accessible to you."; }
    }

    /// <summary>Logs the input and returns its characters in reverse order.</summary>
    [Tool("reverse", "Reverses a given string")]
    string Reverse(
        [Desc("The string to reverse")] string value
    ) {
        Console.WriteLine($"Reversing '{value}'");
        var backwards = value.Reverse();
        return string.Concat(backwards);
    }

    /// <summary>Logs the supplied animal names as a comma-separated list and reports success.</summary>
    [Tool("register_animals", "Registers a set of animals with the system")]
    bool RegisterAnimals(
        [Desc("The animals to register")] IReadOnlyList<string> animals
    ) {
        var listing = string.Join(", ", animals);
        Console.WriteLine($"The list of animals: {listing}");
        return true;
    }

    /// <summary>Logs every city/temperature pair as "city -> temp"; produces no result.</summary>
    [Tool("register_weather", "Informs the user of average weather in a collection of cities")]
    void ExpectedWeather(
        [Desc("A set of city names and their expected high temperatures")] IReadOnlyDictionary<string, float> temps
    ) {
        var pairs = temps.Select(entry => $"{entry.Key} -> {entry.Value}");
        var listing = string.Join(", ", pairs);
        Console.WriteLine($"The cities and temps: {listing}");
    }

    // Parameterless tools: a single attribute argument supplies the description only.
    [Tool("Smile to show your happiness")]
    void Smile() => Console.WriteLine("GPT is smiling at you");

    [Tool("Frown to show your displeasure")]
    void Frown() => Console.WriteLine("GPT is frowning at you");
}

Generates:

// Generator output for GptTest: the tool dispatcher (CallTool), the JSON tool
// definitions (ToolObj / SToolObj) and the HasTools flag.
partial class GptTest {
    /// <summary>
    /// Routes a tool call to the matching method: validates and unpacks the JSON
    /// arguments, invokes the method, and wraps its result in a JToken.
    /// </summary>
    /// <param name="toolName">Tool name to dispatch on; compared against the case labels below.</param>
    /// <param name="args">Arguments for the tool; must be a JObject for every tool that takes parameters.</param>
    /// <returns>
    /// The tool's result as a JValue (a null JValue for void tools), or a JValue holding an
    /// error string when the tool name is unknown or the arguments fail validation.
    /// </returns>
    protected override async Task<JToken> CallTool(string toolName, JToken args) {
        switch(toolName) {
            case "reverse": {
                if(args is not JObject argo) return new JValue("Invalid parameters");
                try {
                    // Validation failures throw ArgumentException carrying the parameter name as the message;
                    // the catch below turns that into the error text returned to the caller.
                    if(argo["value"] is not JValue __0_v || __0_v.Value is not string __0) throw new ArgumentException("value");
                    var ret = ((GptTest) this).Reverse(__0);
                    return new JValue(ret);
                } catch(ArgumentException e) {
                    return new JValue($"Invalid parameter type for '{e.Message}'");
                }
            }
            case "register_animals": {
                if(args is not JObject argo) return new JValue("Invalid parameters");
                try {
                    // "animals" must be a JSON array whose elements are all strings.
                    if(argo["animals"] is not JArray __0_a) throw new ArgumentException("animals");
                    var __0 = __0_a.Select(__1 => {
                        if(__1 is not JValue __2_v || __2_v.Value is not string __2) throw new ArgumentException("animals");
                        return __2;
                    }).ToList();
                    var ret = ((GptTest) this).RegisterAnimals(__0);
                    return new JValue(ret);
                } catch(ArgumentException e) {
                    return new JValue($"Invalid parameter type for '{e.Message}'");
                }
            }
            case "register_weather": {
                if(args is not JObject argo) return new JValue("Invalid parameters");
                try {
                    // The dictionary parameter arrives as an array of { "key": ..., "value": ... } objects
                    // and is rebuilt into a dictionary before the call.
                    if(argo["temps"] is not JArray __0_a) throw new ArgumentException("temps");
                    var __0 = __0_a.Select(__1 => {
                        if(__1 is not JObject __1_o || !__1_o.ContainsKey("key") || !__1_o.ContainsKey("value")) throw new ArgumentException("temps");
                        if(__1_o["key"] is not JValue __2_v || __2_v.Value is not string __2) throw new ArgumentException("temps");
                        if(__1_o["value"] is not JValue __3_v) throw new ArgumentException("temps");
                        // NOTE(review): the value is only checked to be a JValue, not a number. ToObject<float>()
                        // on a non-numeric value may throw something other than ArgumentException, which the
                        // catch below would not handle — confirm.
                        return (Key: __2, Value: __3_v.ToObject<float>());
                    // NOTE(review): duplicate "key" entries would presumably make ToDictionary throw an
                    // ArgumentException whose message is not the parameter name — confirm the resulting error text.
                    }).ToDictionary(__1 => __1.Key, __1 => __1.Value);
                    ((GptTest) this).ExpectedWeather(__0);
                    // Void tool: a null JValue is returned as the result.
                    return new JValue((object) null);
                } catch(ArgumentException e) {
                    return new JValue($"Invalid parameter type for '{e.Message}'");
                }
            }
            // Smile and Frown take no parameters, so args is never inspected for them.
            case "Smile": {
                try {
                    ((GptTest) this).Smile();
                    return new JValue((object) null);
                } catch(ArgumentException e) {
                    return new JValue($"Invalid parameter type for '{e.Message}'");
                }
            }
            case "Frown": {
                try {
                    ((GptTest) this).Frown();
                    return new JValue((object) null);
                } catch(ArgumentException e) {
                    return new JValue($"Invalid parameter type for '{e.Message}'");
                }
            }
            case "Cipher": {
                if(args is not JObject argo) return new JValue("Invalid parameters");
                try {
                    // The argument is read under the key "string" (the name given in the [Desc] attribute),
                    // not under the C# parameter name.
                    if(argo["string"] is not JValue __0_v || __0_v.Value is not string __0) throw new ArgumentException("string");
                    // Cipher is declared on the TestMixin interface, so it is invoked through that type and awaited.
                    var ret = await ((TestMixin) this).Cipher(__0);
                    return new JValue(ret);
                } catch(ArgumentException e) {
                    return new JValue($"Invalid parameter type for '{e.Message}'");
                }
            }
            default:
                // Unknown tool names are reported as an error string rather than thrown.
                return new JValue($"INVALID TOOL '{toolName}'");
        }
    }

    // Tool definitions exposed to the base class; backed by one static array shared by all instances.
    protected override JArray ToolObj => SToolObj;
    // One "function" entry per tool, in the same order as the cases in CallTool.
    readonly static JArray SToolObj = new JArray {
        new JObject {
            ["type"] = "function",
            ["function"] = new JObject {
                ["name"] = "reverse",
                ["description"] = "Reverses a given string",
                ["parameters"] = new JObject {
                    ["type"] = "object",
                    ["properties"] = new JObject {
                        ["value"] = new JObject {
                            ["description"] = "The string to reverse",
                            ["type"] = "string",
                        },
                    },
                    ["required"] = new JArray(new JValue("value"))
                }
            }
        },
        new JObject {
            ["type"] = "function",
            ["function"] = new JObject {
                ["name"] = "register_animals",
                ["description"] = "Registers a set of animals with the system",
                ["parameters"] = new JObject {
                    ["type"] = "object",
                    ["properties"] = new JObject {
                        ["animals"] = new JObject {
                            ["description"] = "The animals to register",
                            ["type"] = "array",
                            ["items"] = new JObject {
                                ["type"] = "string",
                            },
                        },
                    },
                    ["required"] = new JArray(new JValue("animals"))
                }
            }
        },
        new JObject {
            ["type"] = "function",
            ["function"] = new JObject {
                ["name"] = "register_weather",
                ["description"] = "Informs the user of average weather in a collection of cities",
                ["parameters"] = new JObject {
                    ["type"] = "object",
                    ["properties"] = new JObject {
                        ["temps"] = new JObject {
                            ["description"] = "A set of city names and their expected high temperatures",
                            ["type"] = "array",
                            // Dictionary parameters are described as an array of key/value objects.
                            // No "required" list is emitted for the item object, although CallTool
                            // rejects items that lack either "key" or "value".
                            ["items"] = new JObject {
                                ["type"] = "object",
                                ["properties"] = new JObject {
                                    ["key"] = new JObject {
                                        ["type"] = "string",
                                    },
                                    ["value"] = new JObject {
                                        ["type"] = "number",
                                    },
                                },
                            },
                        },
                    },
                    ["required"] = new JArray(new JValue("temps"))
                }
            }
        },
        // Parameterless tools are emitted without a "parameters" object.
        new JObject {
            ["type"] = "function",
            ["function"] = new JObject {
                ["name"] = "Smile",
                ["description"] = "Smile to show your happiness",
            }
        },
        new JObject {
            ["type"] = "function",
            ["function"] = new JObject {
                ["name"] = "Frown",
                ["description"] = "Frown to show your displeasure",
            }
        },
        new JObject {
            ["type"] = "function",
            ["function"] = new JObject {
                ["name"] = "Cipher",
                ["description"] = "Ciphers/deciphers a string",
                ["parameters"] = new JObject {
                    ["type"] = "object",
                    ["properties"] = new JObject {
                        // Property name matches the key CallTool reads for Cipher ("string").
                        ["string"] = new JObject {
                            ["description"] = "The string to encrypt or decrypt",
                            ["type"] = "string",
                        },
                    },
                    ["required"] = new JArray(new JValue("string"))
                }
            }
        },
    };

    // Always true in this generated class, which defines six tools above.
    protected override bool HasTools => true;
}
KrzysztofCwalina commented 1 week ago

Yes, we are thinking about adding something that would help with tool definitions and with tool calls. You can see an early prototype at https://github.com/openai/openai-dotnet/pull/42.