betalgo / openai

OpenAI .NET sdk - Azure OpenAI, ChatGPT, Whisper, and DALL-E
https://betalgo.github.io/openai/
MIT License
2.84k stars 513 forks source link

Stream returns no message after tool call. #556

Closed SebastianStehle closed 1 month ago

SebastianStehle commented 1 month ago

Describe the bug: I am trying to migrate from the normal API to the function-calling method, but after the tool has been called, the stream does not return a result. I am not sure whether this is a bug or a mistake on my side.

Your code piece

using Microsoft.Extensions.Options;
using OpenAI;
using OpenAI.Managers;
using OpenAI.ObjectModels;
using OpenAI.ObjectModels.RequestModels;
using OpenAI.ObjectModels.SharedModels;
using System;

namespace ConsoleApp5
{
    /// <summary>
    /// Reproduction for the "stream returns no message after tool call" report: runs the same
    /// "add two numbers" tool-calling conversation once with the non-streaming chat completion API
    /// and once with the streaming API so the two outputs can be compared.
    /// </summary>
    internal static class Program
    {
        /// <summary>Maximum number of request round-trips before a tool-calling loop gives up.</summary>
        private const int MaxToolRuns = 5;

        /// <summary>Entry point: runs the non-streaming test, then the streaming test.</summary>
        /// <remarks>
        /// The API key is read from the <c>OPENAI_API_KEY</c> environment variable so no secret has to
        /// be committed with the sample; an empty key is used when the variable is not set.
        /// </remarks>
        static async Task Main(string[] args)
        {
            var service = new OpenAIService(new OpenAiOptions
            {
                ApiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? "",
            });

            await TestNormal(service);
            Console.WriteLine();
            await TestStreaming(service);
        }

        /// <summary>
        /// Builds a fresh request asking the model to add 10 and 12, exposing a single <c>add</c> tool.
        /// Each test needs its own instance because the tool loop appends to
        /// <see cref="ChatCompletionCreateRequest.Messages"/>.
        /// </summary>
        private static ChatCompletionCreateRequest CreateRequest() => new()
        {
            Messages =
            [
                ChatMessage.FromUser("What is 10 plus 12?")
            ],
            Model = Models.Gpt_3_5_Turbo,
            N = 1,
            Seed = 42,
            Temperature = 0,
            ToolChoice = ToolChoice.Auto,
            Tools =
            [
                ToolDefinition.DefineFunction(
                    new FunctionDefinition
                    {
                        Name = "add",
                        Parameters = new PropertyDefinition
                        {
                            // Function parameters are described by a JSON Schema whose root is an object.
                            Type = "object",
                            Properties = new Dictionary<string, PropertyDefinition>
                            {
                                ["lhs"] = PropertyDefinition.DefineNumber(),
                                ["rhs"] = PropertyDefinition.DefineNumber(),
                            },
                        },
                    })
            ],
        };

        /// <summary>
        /// Runs the tool-calling loop with the non-streaming API. It prints the final answer, or
        /// prints "Tool Called" and re-sends the conversation whenever the model requests tools.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The API reported an error, returned no choices, or still asked for tools on the last run.
        /// </exception>
        static async Task TestNormal(OpenAIService service)
        {
            var request = CreateRequest();

            for (var run = 1; run <= MaxToolRuns; run++)
            {
                var response = await service.ChatCompletion.CreateCompletion(request);

                if (response.Error != null)
                {
                    throw new InvalidOperationException($"Request failed with internal error: {response.Error.Message}");
                }

                var message = response.Choices?.FirstOrDefault()?.Message
                    ?? throw new InvalidOperationException("Response contained no choices.");

                if (message.ToolCalls is not { Count: > 0 })
                {
                    if (!string.IsNullOrWhiteSpace(message.Content))
                    {
                        Console.WriteLine(message.Content);
                    }

                    break;
                }

                if (run == MaxToolRuns)
                {
                    throw new InvalidOperationException($"Exceeded max tool runs.");
                }

                var toolsResults = await ExecuteToolsAsync(message);

                // The assistant message carrying the tool calls must precede the tool results.
                request.Messages.Add(message);
                request.Messages.AddRange(toolsResults);
                Console.WriteLine("Tool Called");
            }
        }

        /// <summary>
        /// Runs the same tool-calling loop as <see cref="TestNormal"/> but with the streaming API.
        /// Content fragments are written contiguously as they arrive.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// A chunk reported an error, a run's stream produced no message at all, or the model still
        /// asked for tools on the last run.
        /// </exception>
        static async Task TestStreaming(OpenAIService service)
        {
            var request = CreateRequest();

            for (var run = 1; run <= MaxToolRuns; run++)
            {
                var isToolCall = false;
                var receivedMessage = false;
                var lineOpen = false;

                await foreach (var response in service.ChatCompletion.CreateCompletionAsStream(request))
                {
                    if (response.Error != null)
                    {
                        throw new InvalidOperationException($"Request failed with internal error: {response.Error.Message}");
                    }

                    // Some chunks carry no choices; skip them instead of indexing out of range.
                    var delta = response.Choices?.FirstOrDefault()?.Message;
                    if (delta is null)
                    {
                        continue;
                    }

                    receivedMessage = true;

                    if (delta.ToolCalls is not { Count: > 0 })
                    {
                        // Streamed text arrives in fragments: write them without line breaks and keep
                        // whitespace-only fragments, which are part of the answer.
                        if (!string.IsNullOrEmpty(delta.Content))
                        {
                            Console.Write(delta.Content);
                            lineOpen = true;
                        }
                    }
                    else if (run == MaxToolRuns)
                    {
                        throw new InvalidOperationException($"Exceeded max tool runs.");
                    }
                    else
                    {
                        // Only continue with the outer loop if we have a tool call.
                        isToolCall = true;

                        var toolsResults = await ExecuteToolsAsync(delta);

                        request.Messages.Add(delta);
                        request.Messages.AddRange(toolsResults);

                        if (lineOpen)
                        {
                            Console.WriteLine();
                            lineOpen = false;
                        }

                        Console.WriteLine("Tool Called");
                    }
                }

                if (lineOpen)
                {
                    Console.WriteLine();
                }

                // An empty stream would otherwise end the loop silently and hide the failure.
                if (!receivedMessage)
                {
                    throw new InvalidOperationException(
                        $"Stream for run {run} completed without returning any message; the request may have failed without an error being reported.");
                }

                if (!isToolCall)
                {
                    break;
                }
            }
        }

        /// <summary>Appends every element of <paramref name="items"/> to <paramref name="list"/> in order.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="list"/> or <paramref name="items"/> is null.</exception>
        public static void AddRange<T>(this IList<T> list, IEnumerable<T> items)
        {
            ArgumentNullException.ThrowIfNull(list);
            ArgumentNullException.ThrowIfNull(items);

            // The instance method on List<T> presizes and also handles items being the list itself.
            if (list is List<T> concrete)
            {
                concrete.AddRange(items);
                return;
            }

            foreach (var item in items)
            {
                list.Add(item);
            }
        }

        /// <summary>
        /// Produces one tool-result message per tool call in <paramref name="choice"/>, preserving call order.
        /// </summary>
        /// <remarks>
        /// The "tools" are placeholders: the arguments in each call are ignored and every call answers
        /// with a fixed value (10 + 12 + 42). The +42 presumably makes it obvious whether the model
        /// relayed the tool output rather than doing the arithmetic itself.
        /// </remarks>
        /// <exception cref="ArgumentException"><paramref name="choice"/> has no tool calls.</exception>
        /// <exception cref="InvalidOperationException">A tool call has no function name.</exception>
        private static async Task<ChatMessage[]> ExecuteToolsAsync(ChatMessage choice)
        {
            ArgumentNullException.ThrowIfNull(choice);

            var toolCalls = choice.ToolCalls
                ?? throw new ArgumentException("Message has no tool calls.", nameof(choice));

            var validCalls = new List<(int Index, string Id, FunctionCall Call)>(toolCalls.Count);

            foreach (var call in toolCalls)
            {
                if (call.FunctionCall is not { } function || string.IsNullOrWhiteSpace(function.Name))
                {
                    throw new InvalidOperationException($"Tool response has no function name.");
                }

                // Index = position in the original list, so results line up with calls.
                validCalls.Add((validCalls.Count, call.Id!, function));
            }

            var results = new ChatMessage[validCalls.Count];

            await Parallel.ForEachAsync(validCalls, CancellationToken.None, async (job, ct) =>
            {
                var lhs = 10;
                var rhs = 12;

                await Task.Delay(1, ct);

                results[job.Index] = ChatMessage.FromTool($"The result {(lhs + rhs) + 42}. Return this value to the user.", job.Id);
            });

            return results;
        }
    }
}

Result: The code above runs the same function-calling conversation twice:

  1. With the normal (non-streaming) completion API.
  2. With the streaming completion API.

I would expect both runs to produce the same result.

Furthermore, I noticed that the streaming endpoint does not return errors. I am not sure whether this is a bug in the SDK or just how OpenAI behaves, but if I pass a wrong model name, the stream is also just empty, which makes it very hard to find actual bugs.

Expected behavior Streaming after tool should return result.

Screenshots If applicable, add screenshots to help explain your problem.

Desktop (please complete the following information):

Additional context Add any other context about the problem here.

SebastianStehle commented 1 month ago

The actual root cause is the lack of error handling in the streaming method. I have fixed that in the PR.