microsoft / chat-copilot

MIT License
2.02k stars 686 forks source link

Bug: The CopilotChatWebApi backend should invoke CreateCollectionAsync at least once for each chat session when using PostgresMemoryStore with a local PostgreSQL database as the vector MemoryStore #215

Closed AwesomeYuer closed 1 year ago

AwesomeYuer commented 1 year ago

Describe the bug

The WebAPI backend should invoke CreateCollectionAsync at least once for each chat session when using PostgresMemoryStore with a local PostgreSQL database as the vector MemoryStore.

To Reproduce Steps to reproduce the behavior:

  1. Run a local Postgres server
  2. modify CopilotChatWebApi appsettings.json
    "MemoryStore": {
    "Type": "postgres",
    "Qdrant": {
    "Host": "http://localhost",
    "Port": "6333",
    "VectorSize": 1536
    // "Key":  ""
    },
    "AzureCognitiveSearch": {
    "Endpoint": ""
    // "Key": ""
    },
    "Chroma": {
    "Host": "http://localhost",
    "Port": "8000"
    },
    "Postgres": {
    "VectorSize": 1536
    // "ConnectionString": // dotnet user-secrets set "MemoryStore:Postgres:ConnectionString" "MY_POSTGRES_CONNECTION_STRING"
    }
    },
  3. modify CopilotChatWebApi user secrets.json
    {
    "AIService:Key": "eeXXXXXXXXXXXXXXXXXXXXXXX",
    "MemoryStore:Postgres:ConnectionString": "Host=localhost;Database=postgres;User Id=postgres;Password=password01!"
    }
  4. Debug CopilotChatWebApi in VS2022
  5. Run webapp in VSCode using command
    yarn serve
  6. chat in new session in webapp UI
  7. See the same error in the webapp frontend and in the webapi backend console:

    
    fail: Microsoft.SemanticKernel.IKernel[0]
      Something went wrong while executing the native function. Function: <GetDelegateInfo>b__0. Error: 42P01: relation "public.d8ff415b-167b-4b1a-8dd8-7d62f97f120d-LongTermMemory" does not exist
    
      POSITION: 115
      Npgsql.PostgresException (0x80004005): 42P01: relation "public.d8ff415b-167b-4b1a-8dd8-7d62f97f120d-LongTermMemory" does not exist
    
      POSITION: 115
         at Npgsql.Internal.NpgsqlConnector.<ReadMessage>g__ReadMessageLong|221_0(NpgsqlConnector connector, Boolean async, DataRowLoadingMode dataRowLoadingMode, Boolean readingNotifications, Boolean isReadingPrependedMessage)
         at Npgsql.NpgsqlDataReader.NextResult(Boolean async, Boolean isConsuming, CancellationToken cancellationToken)
         at Npgsql.NpgsqlDataReader.NextResult(Boolean async, Boolean isConsuming, CancellationToken cancellationToken)
         at Npgsql.NpgsqlCommand.ExecuteReader(CommandBehavior behavior, Boolean async, CancellationToken cancellationToken)
         at Npgsql.NpgsqlCommand.ExecuteReader(CommandBehavior behavior, Boolean async, CancellationToken cancellationToken)
         at Microsoft.SemanticKernel.Connectors.Memory.Postgres.PostgresDbClient.GetNearestMatchesAsync(String tableName, Vector embedding, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+MoveNext() in D:\MyGitHub\semantic-kernel@microshaoft-Debug\dotnet\src\Connectors\Connectors.Memory.Postgres\PostgresDbClient.cs:line 166
         at Microsoft.SemanticKernel.Connectors.Memory.Postgres.PostgresDbClient.GetNearestMatchesAsync(String tableName, Vector embedding, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+MoveNext() in D:\MyGitHub\semantic-kernel@microshaoft-Debug\dotnet\src\Connectors\Connectors.Memory.Postgres\PostgresDbClient.cs:line 173
         at Microsoft.SemanticKernel.Connectors.Memory.Postgres.PostgresDbClient.GetNearestMatchesAsync(String tableName, Vector embedding, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+System.Threading.Tasks.Sources.IValueTaskSource<System.Boolean>.GetResult()
         at Microsoft.SemanticKernel.Connectors.Memory.Postgres.PostgresMemoryStore.GetNearestMatchesAsync(String collectionName, ReadOnlyMemory`1 embedding, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+MoveNext() in D:\MyGitHub\semantic-kernel@microshaoft-Debug\dotnet\src\Connectors\Connectors.Memory.Postgres\PostgresMemoryStore.cs:line 160
         at Microsoft.SemanticKernel.Connectors.Memory.Postgres.PostgresMemoryStore.GetNearestMatchesAsync(String collectionName, ReadOnlyMemory`1 embedding, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+MoveNext() in D:\MyGitHub\semantic-kernel@microshaoft-Debug\dotnet\src\Connectors\Connectors.Memory.Postgres\PostgresMemoryStore.cs:line 160
         at Microsoft.SemanticKernel.Connectors.Memory.Postgres.PostgresMemoryStore.GetNearestMatchesAsync(String collectionName, ReadOnlyMemory`1 embedding, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+System.Threading.Tasks.Sources.IValueTaskSource<System.Boolean>.GetResult()
         at Microsoft.SemanticKernel.Memory.SemanticTextMemory.SearchAsync(String collection, String query, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+MoveNext() in D:\MyGitHub\semantic-kernel@microshaoft-Debug\dotnet\src\SemanticKernel\Memory\SemanticTextMemory.cs:line 114
         at Microsoft.SemanticKernel.Memory.SemanticTextMemory.SearchAsync(String collection, String query, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+MoveNext() in D:\MyGitHub\semantic-kernel@microshaoft-Debug\dotnet\src\SemanticKernel\Memory\SemanticTextMemory.cs:line 114
         at Microsoft.SemanticKernel.Memory.SemanticTextMemory.SearchAsync(String collection, String query, Int32 limit, Double minRelevanceScore, Boolean withEmbeddings, CancellationToken cancellationToken)+System.Threading.Tasks.Sources.IValueTaskSource<System.Boolean>.GetResult()
         at CopilotChat.WebApi.Skills.ChatSkills.SemanticChatMemorySkill.QueryMemoriesAsync(String query, String chatId, Int32 tokenLimit, ISemanticTextMemory textMemory) in D:\MyGitHub\semantic-kernel@microshaoft-Debug\chat-copilot\webapi\Skills\ChatSkills\SemanticChatMemorySkill.cs:line 78
         at CopilotChat.WebApi.Skills.ChatSkills.SemanticChatMemorySkill.QueryMemoriesAsync(String query, String chatId, Int32 tokenLimit, ISemanticTextMemory textMemory) in D:\MyGitHub\semantic-kernel@microshaoft-Debug\chat-copilot\webapi\Skills\ChatSkills\SemanticChatMemorySkill.cs:line 78
         at CopilotChat.WebApi.Skills.ChatSkills.ChatSkill.GetChatResponseAsync(String chatId, String userId, SKContext chatContext, CancellationToken cancellationToken) in D:\MyGitHub\semantic-kernel@microshaoft-Debug\chat-copilot\webapi\Skills\ChatSkills\ChatSkill.cs:line 377
         at CopilotChat.WebApi.Skills.ChatSkills.ChatSkill.ChatAsync(String message, String userId, String userName, String chatId, String messageType, String planJson, String messageId, SKContext context, CancellationToken cancellationToken) in D:\MyGitHub\semantic-kernel@microshaoft-Debug\chat-copilot\webapi\Skills\ChatSkills\ChatSkill.cs:line 299         at Microsoft.SemanticKernel.SkillDefinition.NativeFunction.InvokeAsync(SKContext context, CompleteRequestSettings settings, CancellationToken cancellationToken) in D:\MyGitHub\semantic-kernel@microshaoft-Debug\dotnet\src\SemanticKernel\SkillDefinition\NativeFunction.cs:line 165
        Exception data:
          Severity: ERROR
          SqlState: 42P01
          MessageText: relation "public.d8ff415b-167b-4b1a-8dd8-7d62f97f120d-LongTermMemory" does not exist
          Position: 115
          File: parse_relation.c
          Line: 1371
          Routine: parserOpenTable


**Expected behavior**
A clear and concise description of what you expected to happen.

1. Before searching, the related memory-store collection should be created first.
2. The CreateCollectionAsync method of other IMemoryStore implementations (such as QdrantMemoryStore and VolatileMemoryStore) is invoked properly, but PostgresMemoryStore's CreateCollectionAsync method is never invoked when a local PostgreSQL database is used as the memory store.

**Screenshots**
If applicable, add screenshots to help explain your problem.

1. Error when searching vectors in PostgresMemoryStore (this may not be the root cause):
![image](https://github.com/microsoft/chat-copilot/assets/1026479/a22bec1d-b58d-42d6-af22-5f32d10be55c)

2. For reference, the code below is in the `Connectors.Memory.Postgres` project, which belongs to `Semantic-Kernel`, not to `chat-copilot`.
I have debugged both the chat-copilot and semantic-kernel code, and the breakpoints shown in the screenshot below are never hit:
![image](https://github.com/microsoft/chat-copilot/assets/1026479/d57cb004-e32b-4466-9499-67067b666c80)

**Platform**
 - OS: windows 11
 - IDE: Visual Studio 2022, VS Code
 - Language: C#
 - Source: [e.g. NuGet package version 0.1.0, pip package version 0.1.0, main branch of repository]

**Additional context**
Add any other context about the problem here.
crickman commented 1 year ago

Thank you for reporting this. I appreciate your detailed analysis and will follow up.

We're always open to contributions from the community.

crickman commented 1 year ago

FYI - I'm working to land a fix for this with the core SK team that will address all connectors

crickman commented 1 year ago

@AwesomeYuer - We are days away from merging this PR, which will remove the issue you've observed:

https://github.com/microsoft/chat-copilot/pull/152

crickman commented 1 year ago

I believe this addresses the issue: https://github.com/microsoft/chat-copilot/pull/365

AwesomeYuer commented 1 year ago

I found a temporary workaround, shown below:

Instead of the default PostgresMemoryStore, use a separate implementation that ensures the collection table exists before accessing it.

It can't inherit from the default PostgresMemoryStore, because that class's methods are non-virtual and can't be overridden!

AwesomePostgresMemoryStore.cs

// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Connectors.Memory.Postgres;
using Npgsql;
using Pgvector;

namespace Microshaoft.SemanticKernel.Connectors.Memory.Postgres;

/// <summary>
/// An implementation of <see cref="IMemoryStore"/> backed by a Postgres database with pgvector extension.
/// </summary>
/// <remarks>
/// The embedded data is saved to the Postgres database specified in the constructor.
/// Similarity search capability is provided through the pgvector extension. Use Postgres's "Table" to implement "Collection".
/// Every collection-scoped read/write first makes sure the backing table exists (creating it on demand),
/// so operating on a not-yet-created collection does not fail with Postgres error 42P01
/// ("relation ... does not exist"). Collections already verified are cached per store instance so the
/// existence check costs one round-trip per collection rather than one per call.
/// </remarks>
public class AwesomePostgresMemoryStore : IMemoryStore
{
    internal const string DefaultSchema = "public";

    /// <summary>
    /// Initializes a new instance of the <see cref="AwesomePostgresMemoryStore"/> class.
    /// </summary>
    /// <param name="dataSource">Postgres data source.</param>
    /// <param name="vectorSize">Embedding vector size.</param>
    /// <param name="schema">Database schema of collection tables. The default value is "public".</param>
    public AwesomePostgresMemoryStore(NpgsqlDataSource dataSource, int vectorSize, string schema = DefaultSchema)
        : this(new PostgresDbClient(dataSource, schema, vectorSize))
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="AwesomePostgresMemoryStore"/> class.
    /// </summary>
    /// <param name="postgresDbClient">An instance of <see cref="IPostgresDbClient"/>.</param>
    public AwesomePostgresMemoryStore(IPostgresDbClient postgresDbClient)
    {
        this._postgresDbClient = postgresDbClient;
    }

    /// <inheritdoc/>
    /// <exception cref="ArgumentException"><paramref name="collectionName"/> is null, empty, or whitespace.</exception>
    public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        ValidateCollectionName(collectionName);

        await this._postgresDbClient.CreateTableAsync(collectionName, cancellationToken).ConfigureAwait(false);

        this.RememberCollection(collectionName);
    }

    /// <inheritdoc/>
    /// <exception cref="ArgumentException"><paramref name="collectionName"/> is null, empty, or whitespace.</exception>
    public async Task<bool> DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        ValidateCollectionName(collectionName);

        return await this._postgresDbClient.DoesTableExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<string> GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        await foreach (string collection in this._postgresDbClient.GetTablesAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return collection;
        }
    }

    /// <inheritdoc/>
    /// <exception cref="ArgumentException"><paramref name="collectionName"/> is null, empty, or whitespace.</exception>
    public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        ValidateCollectionName(collectionName);

        // Only drop a table that is actually there; never create a table just to drop it.
        if (await this._postgresDbClient.DoesTableExistsAsync(collectionName, cancellationToken).ConfigureAwait(false))
        {
            await this._postgresDbClient.DeleteTableAsync(collectionName, cancellationToken).ConfigureAwait(false);
        }

        // Drop the cache entry afterwards, otherwise later operations would skip re-creation and fail with 42P01.
        this.ForgetCollection(collectionName);
    }

    /// <inheritdoc/>
    public async Task<string> UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default)
    {
        await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

        return await this.InternalUpsertAsync(collectionName, record, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<string> UpsertBatchAsync(string collectionName, IEnumerable<MemoryRecord> records,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Ensure once for the whole batch; InternalUpsertAsync assumes the table exists.
        await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

        foreach (MemoryRecord record in records)
        {
            yield return await this.InternalUpsertAsync(collectionName, record, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc/>
    public async Task<MemoryRecord?> GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default)
    {
        await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

        PostgresMemoryEntry? entry = await this._postgresDbClient.ReadAsync(collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false);

        if (!entry.HasValue)
        {
            return null;
        }

        return this.GetMemoryRecordFromEntry(entry.Value);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(string collectionName, IEnumerable<string> keys, bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

        await foreach (PostgresMemoryEntry entry in this._postgresDbClient.ReadBatchAsync(collectionName, keys, withEmbeddings, cancellationToken).ConfigureAwait(false))
        {
            yield return this.GetMemoryRecordFromEntry(entry);
        }
    }

    /// <inheritdoc/>
    public async Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default)
    {
        await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

        await this._postgresDbClient.DeleteAsync(collectionName, key, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancellationToken = default)
    {
        await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

        await this._postgresDbClient.DeleteBatchAsync(collectionName, keys, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
        string collectionName,
        ReadOnlyMemory<float> embedding,
        int limit,
        double minRelevanceScore = 0,
        bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // A single cached ensure is enough; a second DoesCollectionExistAsync here would cost a round-trip per search.
        await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

        if (limit <= 0)
        {
            yield break;
        }

        IAsyncEnumerable<(PostgresMemoryEntry, double)> results = this._postgresDbClient.GetNearestMatchesAsync(
            tableName: collectionName,
            embedding: new Vector(GetOrCreateArray(embedding)),
            limit: limit,
            minRelevanceScore: minRelevanceScore,
            withEmbeddings: withEmbeddings,
            cancellationToken: cancellationToken);

        await foreach ((PostgresMemoryEntry entry, double cosineSimilarity) in results.ConfigureAwait(false))
        {
            yield return (this.GetMemoryRecordFromEntry(entry), cosineSimilarity);
        }
    }

    /// <inheritdoc/>
    public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, ReadOnlyMemory<float> embedding, double minRelevanceScore = 0, bool withEmbedding = false,
        CancellationToken cancellationToken = default)
    {
        // GetNearestMatchesAsync already ensures the collection exists.
        return await this.GetNearestMatchesAsync(
            collectionName: collectionName,
            embedding: embedding,
            limit: 1,
            minRelevanceScore: minRelevanceScore,
            withEmbeddings: withEmbedding,
            cancellationToken: cancellationToken).FirstOrDefaultAsync(cancellationToken).ConfigureAwait(false);
    }

    #region private ================================================================================

    private readonly IPostgresDbClient _postgresDbClient;

    // Collections known to have a backing table; guarded by _knownCollectionsLock because the store
    // instance may be used by several concurrent callers.
    private readonly HashSet<string> _knownCollections = new(StringComparer.Ordinal);
    private readonly object _knownCollectionsLock = new();

    /// <summary>
    /// Makes sure the table backing <paramref name="collectionName"/> exists, creating it if necessary.
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="collectionName"/> is null, empty, or whitespace.</exception>
    private async Task EnsureCollectionExistsAsync(string collectionName, CancellationToken cancellationToken)
    {
        ValidateCollectionName(collectionName);

        if (this.IsKnownCollection(collectionName))
        {
            return;
        }

        if (!await this._postgresDbClient.DoesTableExistsAsync(collectionName, cancellationToken).ConfigureAwait(false))
        {
            // NOTE(review): two concurrent callers can both reach this line for the same new collection;
            // this relies on CreateTableAsync tolerating an already-existing table - confirm in PostgresDbClient.
            await this._postgresDbClient.CreateTableAsync(collectionName, cancellationToken).ConfigureAwait(false);
        }

        this.RememberCollection(collectionName);
    }

    /// <summary>Returns whether the collection is already cached as existing.</summary>
    private bool IsKnownCollection(string collectionName)
    {
        lock (this._knownCollectionsLock)
        {
            return this._knownCollections.Contains(collectionName);
        }
    }

    /// <summary>Caches the collection as existing.</summary>
    private void RememberCollection(string collectionName)
    {
        lock (this._knownCollectionsLock)
        {
            this._knownCollections.Add(collectionName);
        }
    }

    /// <summary>Removes the collection from the existence cache.</summary>
    private void ForgetCollection(string collectionName)
    {
        lock (this._knownCollectionsLock)
        {
            this._knownCollections.Remove(collectionName);
        }
    }

    /// <summary>Rejects null, empty, or whitespace-only collection names.</summary>
    private static void ValidateCollectionName(string collectionName)
    {
        if (string.IsNullOrWhiteSpace(collectionName))
        {
            throw new ArgumentException("Collection name must not be null, empty, or whitespace.", nameof(collectionName));
        }
    }

    /// <summary>
    /// Writes a record to the collection's table, keyed by its metadata id. Assumes the table already exists.
    /// </summary>
    private async Task<string> InternalUpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken)
    {
        record.Key = record.Metadata.Id;

        await this._postgresDbClient.UpsertAsync(
            tableName: collectionName,
            key: record.Key,
            metadata: record.GetSerializedMetadata(),
            embedding: new Vector(GetOrCreateArray(record.Embedding)),
            timestamp: record.Timestamp?.UtcDateTime,
            cancellationToken: cancellationToken).ConfigureAwait(false);

        return record.Key;
    }

    /// <summary>Converts a database row into a <see cref="MemoryRecord"/>.</summary>
    private MemoryRecord GetMemoryRecordFromEntry(PostgresMemoryEntry entry)
    {
        return MemoryRecord.FromJsonMetadata(
            json: entry.MetadataString,
            embedding: entry.Embedding?.ToArray() ?? ReadOnlyMemory<float>.Empty,
            key: entry.Key,
            timestamp: entry.Timestamp?.ToLocalTime());
    }

    /// <summary>
    /// Returns the backing array when <paramref name="memory"/> spans a whole array, otherwise a copy.
    /// </summary>
    private static float[] GetOrCreateArray(ReadOnlyMemory<float> memory) =>
        MemoryMarshal.TryGetArray(memory, out ArraySegment<float> array) &&
        array.Count == array.Array!.Length ?
            array.Array :
            memory.ToArray();

    #endregion
}

AsyncEnumerable.cs

// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
#pragma warning disable CA1510 // Use 'ArgumentNullException.ThrowIfNull' (.NET 8)

// Used for compatibility with System.Linq.Async Nuget pkg
namespace System.Linq;

/// <summary>
/// A minimal subset of the System.Linq.Async operators over <see cref="IAsyncEnumerable{T}"/>,
/// kept signature-compatible so the NuGet package can replace it later.
/// </summary>
internal static class AsyncEnumerable
{
    /// <summary>Returns a shared, already-completed empty async sequence.</summary>
    public static IAsyncEnumerable<T> Empty<T>() => EmptyAsyncEnumerable<T>.Instance;

    /// <summary>
    /// Exposes an async sequence as a synchronous one.
    /// </summary>
    /// <remarks>Blocks the calling thread on every MoveNextAsync and on DisposeAsync.</remarks>
    public static IEnumerable<T> ToEnumerable<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
    {
        var enumerator = source.GetAsyncEnumerator(cancellationToken);
        try
        {
            while (enumerator.MoveNextAsync().AsTask().GetAwaiter().GetResult())
            {
                yield return enumerator.Current;
            }
        }
        finally
        {
            enumerator.DisposeAsync().AsTask().GetAwaiter().GetResult();
        }
    }

    /// <summary>Wraps a synchronous sequence as an async sequence.</summary>
    public static async IAsyncEnumerable<T> ToAsyncEnumerable<T>(this IEnumerable<T> source)
    {
        foreach (var item in source)
        {
            yield return item;
        }
    }

    /// <summary>Returns the first element, or <c>default</c> when the sequence is empty.</summary>
    public static async ValueTask<T?> FirstOrDefaultAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
    {
        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            return item;
        }

        return default;
    }

    /// <summary>Returns the last element, or <c>default</c> when the sequence is empty.</summary>
    public static async ValueTask<T?> LastOrDefaultAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
    {
        var last = default(T)!; // NB: Only matters when hasLast is set to true.
        var hasLast = false;

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            hasLast = true;
            last = item;
        }

        return hasLast ? last! : default;
    }

    /// <summary>Materializes the whole sequence into a new list.</summary>
    public static async ValueTask<List<T>> ToListAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
    {
        var result = new List<T>();

        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            result.Add(item);
        }

        return result;
    }

    /// <summary>
    /// Determines whether the sequence contains <paramref name="value"/>, using the default equality comparer.
    /// </summary>
    /// <param name="source">The sequence to search.</param>
    /// <param name="value">The value to look for.</param>
    /// <param name="cancellationToken">Token passed to the source enumerator, consistent with the other operators here.</param>
    /// <returns><c>true</c> as soon as a matching element is found; otherwise <c>false</c>.</returns>
    public static async ValueTask<bool> ContainsAsync<T>(this IAsyncEnumerable<T> source, T value, CancellationToken cancellationToken = default)
    {
        await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            if (EqualityComparer<T>.Default.Equals(item, value))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>Counts the elements of the sequence.</summary>
    /// <exception cref="OverflowException">The sequence has more than <see cref="int.MaxValue"/> elements.</exception>
    public static async ValueTask<int> CountAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
    {
        int count = 0;
        await foreach (var _ in source.WithCancellation(cancellationToken).ConfigureAwait(false))
        {
            checked { count++; }
        }

        return count;
    }

    /// <summary>
    /// Determines whether any element of an async-enumerable sequence satisfies a condition.
    /// </summary>
    /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
    /// <param name="source">An async-enumerable sequence whose elements to apply the predicate to.</param>
    /// <param name="predicate">A function to test each element for a condition.</param>
    /// <param name="cancellationToken">The optional cancellation token to be used for cancelling the sequence at any time.</param>
    /// <returns>An async-enumerable sequence containing a single element determining whether any elements in the source sequence pass the test in the specified predicate.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
    /// <remarks>The return type of this operator differs from the corresponding operator on IEnumerable in order to retain asynchronous behavior.</remarks>
    public static ValueTask<bool> AnyAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken = default)
    {
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        if (predicate == null)
        {
            throw new ArgumentNullException(nameof(predicate));
        }

        // Arguments are validated eagerly above; the enumeration itself is deferred to the async core.
        return Core(source, predicate, cancellationToken);

        static async ValueTask<bool> Core(IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken)
        {
            await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
            {
                if (predicate(item))
                {
                    return true;
                }
            }

            return false;
        }
    }

    /// <summary>
    /// A stateless empty sequence that serves as its own enumerator; it never yields an element.
    /// </summary>
    private sealed class EmptyAsyncEnumerable<T> : IAsyncEnumerable<T>, IAsyncEnumerator<T>
    {
        public static readonly EmptyAsyncEnumerable<T> Instance = new();
        public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default) => this;
        public ValueTask<bool> MoveNextAsync() => new(false);
        public T Current => default!;
        public ValueTask DisposeAsync() => default;
    }
}

SemanticKernelExtensions.cs

// Copyright (c) Microsoft. All rights reserved.

using System;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using CopilotChat.WebApi.Hubs;
using CopilotChat.WebApi.Options;
using CopilotChat.WebApi.Services;
using CopilotChat.WebApi.Skills.ChatSkills;
using CopilotChat.WebApi.Storage;
using Microshaoft.SemanticKernel.Connectors.Memory.Postgres;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticKernel.Connectors.Memory.AzureCognitiveSearch;
using Microsoft.SemanticKernel.Connectors.Memory.Chroma;
using Microsoft.SemanticKernel.Connectors.Memory.Qdrant;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Skills.Core;
using Npgsql;
using Pgvector.Npgsql;
using static CopilotChat.WebApi.Options.MemoryStoreOptions;

namespace CopilotChat.WebApi.Extensions;

/// <summary>
/// Extension methods for registering Semantic Kernel related services
/// (kernel, planner, semantic memory, content safety) with the DI container.
/// </summary>
internal static class SemanticKernelExtensions
{
    /// <summary>
    /// Delegate to register skills with a Semantic Kernel.
    /// </summary>
    /// <param name="sp">Service provider used to resolve skill dependencies.</param>
    /// <param name="kernel">Kernel that the skills are imported into.</param>
    public delegate Task RegisterSkillsWithKernel(IServiceProvider sp, IKernel kernel);

    /// <summary>
    /// Add Semantic Kernel services: a scoped <see cref="IKernel"/> wired to the configured
    /// completion/embedding backends and skills, the semantic memory, and Azure Content Safety
    /// (when enabled).
    /// </summary>
    /// <param name="services">The service collection to add to.</param>
    /// <returns>The same service collection, for chaining.</returns>
    internal static IServiceCollection AddSemanticKernelServices(this IServiceCollection services)
    {
        // Semantic Kernel
        services.AddScoped<IKernel>(sp =>
        {
            IKernel kernel = Kernel.Builder
                .WithLoggerFactory(sp.GetRequiredService<ILoggerFactory>())
                .WithMemory(sp.GetRequiredService<ISemanticTextMemory>())
                .WithCompletionBackend(sp.GetRequiredService<IOptions<AIServiceOptions>>().Value)
                .WithEmbeddingBackend(sp.GetRequiredService<IOptions<AIServiceOptions>>().Value)
                .Build();

            // NOTE(review): DI factories are synchronous, so the returned Task is not awaited.
            // RegisterSkillsAsync below completes synchronously, so its exceptions still surface
            // here. A truly asynchronous delegate would have its failures go unobserved.
            sp.GetRequiredService<RegisterSkillsWithKernel>()(sp, kernel);
            return kernel;
        });

        // Semantic memory
        services.AddSemanticTextMemory();

        // Azure Content Safety
        services.AddContentSafety();

        // Register skills
        services.AddScoped<RegisterSkillsWithKernel>(sp => RegisterSkillsAsync);

        return services;
    }

    /// <summary>
    /// Add Planner services: a scoped <see cref="CopilotChatPlanner"/> backed by its own kernel
    /// that uses the planner model.
    /// </summary>
    /// <param name="services">The service collection to add to.</param>
    /// <returns>The same service collection, for chaining.</returns>
    public static IServiceCollection AddPlannerServices(this IServiceCollection services)
    {
        services.AddScoped<CopilotChatPlanner>(sp =>
        {
            // Resolve the options from the application's container. Building a throw-away
            // provider at registration time would duplicate singletons and never be disposed.
            PlannerOptions? plannerOptions = sp.GetService<IOptions<PlannerOptions>>()?.Value;

            IKernel plannerKernel = Kernel.Builder
                .WithLoggerFactory(sp.GetRequiredService<ILoggerFactory>())
                .WithMemory(sp.GetRequiredService<ISemanticTextMemory>())
                // TODO: [sk Issue #2046] verify planner has AI service configured
                .WithPlannerBackend(sp.GetRequiredService<IOptions<AIServiceOptions>>().Value)
                .Build();
            return new CopilotChatPlanner(plannerKernel, plannerOptions, sp.GetRequiredService<ILogger<CopilotChatPlanner>>());
        });

        // Register Planner skills (AI plugins) here.
        // TODO: [sk Issue #2046] Move planner skill registration from ChatController to this location.

        return services;
    }

    /// <summary>
    /// Register the chat skill with the kernel.
    /// </summary>
    /// <param name="kernel">Kernel to import the chat skill into.</param>
    /// <param name="sp">Service provider used to resolve the skill's dependencies.</param>
    /// <returns>The same kernel, for chaining.</returns>
    public static IKernel RegisterChatSkill(this IKernel kernel, IServiceProvider sp)
    {
        // Chat skill. Content safety is optional: it resolves to null unless AddContentSafety
        // registered it.
        kernel.ImportSkill(new ChatSkill(
                kernel: kernel,
                chatMessageRepository: sp.GetRequiredService<ChatMessageRepository>(),
                chatSessionRepository: sp.GetRequiredService<ChatSessionRepository>(),
                messageRelayHubContext: sp.GetRequiredService<IHubContext<MessageRelayHub>>(),
                promptOptions: sp.GetRequiredService<IOptions<PromptsOptions>>(),
                documentImportOptions: sp.GetRequiredService<IOptions<DocumentMemoryOptions>>(),
                contentSafety: sp.GetService<AzureContentSafety>(),
                planner: sp.GetRequiredService<CopilotChatPlanner>(),
                logger: sp.GetRequiredService<ILogger<ChatSkill>>()),
            nameof(ChatSkill));

        return kernel;
    }

    /// <summary>
    /// Propagate an exception raised within a semantic function.
    /// </summary>
    /// <remarks>
    /// When <see cref="SKContext.ErrorOccurred"/> is set, the error is logged and
    /// <see cref="SKContext.LastException"/> is rethrown with its original stack trace preserved.
    /// </remarks>
    /// <param name="context">Context produced by a kernel/function invocation.</param>
    /// <exception cref="InvalidOperationException">
    /// The context reports an error but carries no exception.
    /// </exception>
    public static void ThrowIfFailed(this SKContext context)
    {
        if (!context.ErrorOccurred)
        {
            return;
        }

        Exception? exception = context.LastException;
        ILogger logger = context.LoggerFactory.CreateLogger(nameof(SKContext));
        logger.LogError(exception, "{ErrorMessage}", exception?.Message);

        if (exception is null)
        {
            throw new InvalidOperationException("The Semantic Kernel context reported an error without an exception.");
        }

        // `throw exception;` would reset the stack trace to this frame. Capture/Throw keeps the
        // trace from where the exception was originally thrown.
        ExceptionDispatchInfo.Capture(exception).Throw();
    }

    /// <summary>
    /// Register the skills with the kernel: the chat skill, the time skill, and every semantic
    /// skill found as a sub-directory of the configured semantic skills directory.
    /// </summary>
    /// <param name="sp">Service provider used to resolve options and skill dependencies.</param>
    /// <param name="kernel">Kernel to import the skills into.</param>
    /// <returns>An already-completed task; all work is done synchronously.</returns>
    private static Task RegisterSkillsAsync(IServiceProvider sp, IKernel kernel)
    {
        // Copilot chat skills
        kernel.RegisterChatSkill(sp);

        // Time skill
        kernel.ImportSkill(new TimeSkill(), nameof(TimeSkill));

        // Semantic skills: each sub-directory is imported under its directory name.
        ServiceOptions options = sp.GetRequiredService<IOptions<ServiceOptions>>().Value;
        if (!string.IsNullOrWhiteSpace(options.SemanticSkillsDirectory))
        {
            ILogger? logger = null;
            foreach (string subDir in Directory.GetDirectories(options.SemanticSkillsDirectory))
            {
                try
                {
                    kernel.ImportSemanticSkillFromDirectory(options.SemanticSkillsDirectory, Path.GetFileName(subDir)!);
                }
                catch (SKException ex)
                {
                    // Log and skip a broken skill so the remaining skills still load. The
                    // exception object is passed so the stack trace reaches the log.
                    logger ??= kernel.LoggerFactory.CreateLogger(nameof(Kernel));
                    logger.LogError(ex, "Could not load skill from {Directory}: {Message}", subDir, ex.Message);
                }
            }
        }

        return Task.CompletedTask;
    }

    /// <summary>
    /// Add the semantic memory: a singleton <see cref="IMemoryStore"/> chosen by
    /// <see cref="MemoryStoreOptions.Type"/>, and a scoped <see cref="ISemanticTextMemory"/>
    /// over it.
    /// </summary>
    /// <param name="services">The service collection to add to.</param>
    /// <exception cref="InvalidOperationException">
    /// The selected store's configuration section is missing, or the store type is unknown.
    /// </exception>
    private static void AddSemanticTextMemory(this IServiceCollection services)
    {
        MemoryStoreOptions config = services.BuildServiceProvider().GetRequiredService<IOptions<MemoryStoreOptions>>().Value;

        switch (config.Type)
        {
            case MemoryStoreType.Volatile:
                services.AddSingleton<IMemoryStore, VolatileMemoryStore>();
                break;

            case MemoryStoreType.Qdrant:
                if (config.Qdrant == null)
                {
                    throw new InvalidOperationException("MemoryStore type is Qdrant and Qdrant configuration is null.");
                }

                services.AddSingleton<IMemoryStore>(sp =>
                {
                    HttpClient httpClient = new(new HttpClientHandler { CheckCertificateRevocationList = true });
                    if (!string.IsNullOrWhiteSpace(config.Qdrant.Key))
                    {
                        httpClient.DefaultRequestHeaders.Add("api-key", config.Qdrant.Key);
                    }

                    var endPointBuilder = new UriBuilder(config.Qdrant.Host);
                    endPointBuilder.Port = config.Qdrant.Port;

                    return new QdrantMemoryStore(
                        httpClient: httpClient,
                        config.Qdrant.VectorSize,
                        endPointBuilder.ToString(),
                        loggerFactory: sp.GetRequiredService<ILoggerFactory>()
                    );
                });
                break;

            case MemoryStoreType.AzureCognitiveSearch:
                if (config.AzureCognitiveSearch == null)
                {
                    throw new InvalidOperationException("MemoryStore type is AzureCognitiveSearch and AzureCognitiveSearch configuration is null.");
                }

                services.AddSingleton<IMemoryStore>(sp =>
                {
                    return new AzureCognitiveSearchMemoryStore(config.AzureCognitiveSearch.Endpoint, config.AzureCognitiveSearch.Key);
                });
                break;

            case MemoryStoreType.Chroma:
                if (config.Chroma == null)
                {
                    throw new InvalidOperationException("MemoryStore type is Chroma and Chroma configuration is null.");
                }

                services.AddSingleton<IMemoryStore>(sp =>
                {
                    HttpClient httpClient = new(new HttpClientHandler { CheckCertificateRevocationList = true });
                    var endPointBuilder = new UriBuilder(config.Chroma.Host);
                    endPointBuilder.Port = config.Chroma.Port;

                    return new ChromaMemoryStore(
                        httpClient: httpClient,
                        endpoint: endPointBuilder.ToString(),
                        loggerFactory: sp.GetRequiredService<ILoggerFactory>()
                    );
                });
                break;

            case MemoryStoreType.Postgres:
                if (config.Postgres == null)
                {
                    throw new InvalidOperationException("MemoryStore type is Postgres and Postgres configuration is null.");
                }

                // The pgvector type mapping must be registered on the data source so that
                // embedding columns can be read and written.
                var dataSourceBuilder = new NpgsqlDataSourceBuilder(config.Postgres.ConnectionString);
                dataSourceBuilder.UseVector();

                // NOTE(review): Postgres has been reported to fail with 42P01 ("relation ... does
                // not exist") when a chat's collection is queried before CreateCollectionAsync has
                // run for it. Confirm that callers create the collection first for this store.
                services.AddSingleton<IMemoryStore>(sp =>
                {
                    return new AwesomePostgresMemoryStore(
                        dataSource: dataSourceBuilder.Build(),
                        vectorSize: config.Postgres.VectorSize
                    );
                });

                break;

            default:
                throw new InvalidOperationException($"Invalid 'MemoryStore' type '{config.Type}'.");
        }

        services.AddScoped<ISemanticTextMemory>(sp => new SemanticTextMemory(
            sp.GetRequiredService<IMemoryStore>(),
            sp.GetRequiredService<IOptions<AIServiceOptions>>().Value
                .ToTextEmbeddingsService(loggerFactory: sp.GetRequiredService<ILoggerFactory>())));
    }

    /// <summary>
    /// Adds Azure Content Safety when it is enabled in configuration.
    /// </summary>
    /// <remarks>
    /// The service is registered as the concrete <see cref="AzureContentSafety"/> type, which
    /// <see cref="RegisterChatSkill"/> resolves. <see cref="IContentSafetyService"/> is forwarded
    /// to the same singleton instance.
    /// </remarks>
    /// <param name="services">The service collection to add to.</param>
    internal static void AddContentSafety(this IServiceCollection services)
    {
        IConfiguration configuration = services.BuildServiceProvider().GetRequiredService<IConfiguration>();
        ContentSafetyOptions? configured = configuration.GetSection(ContentSafetyOptions.PropertyName).Get<ContentSafetyOptions>();

        // A missing configuration section binds to null. Treat it as "disabled" rather than
        // failing with a NullReferenceException at startup.
        if (configured is null || !configured.Enabled)
        {
            return;
        }

        ContentSafetyOptions options = configured;
        services.AddSingleton<AzureContentSafety>(sp => new AzureContentSafety(new Uri(options.Endpoint), options.Key, options));
        services.AddSingleton<IContentSafetyService>(sp => sp.GetRequiredService<AzureContentSafety>());
    }

    /// <summary>
    /// Add the chat completion backend (the configured completion model) to the kernel config.
    /// </summary>
    /// <exception cref="ArgumentException">The AI service type is not supported.</exception>
    private static KernelBuilder WithCompletionBackend(this KernelBuilder kernelBuilder, AIServiceOptions options)
    {
        return options.Type switch
        {
            AIServiceOptions.AIServiceType.AzureOpenAI
                => kernelBuilder.WithAzureChatCompletionService(options.Models.Completion, options.Endpoint, options.Key),
            AIServiceOptions.AIServiceType.OpenAI
                => kernelBuilder.WithOpenAIChatCompletionService(options.Models.Completion, options.Key),
            _
                => throw new ArgumentException($"Invalid {nameof(options.Type)} value in '{AIServiceOptions.PropertyName}' settings."),
        };
    }

    /// <summary>
    /// Add the text embedding backend (the configured embedding model) to the kernel config.
    /// </summary>
    /// <exception cref="ArgumentException">The AI service type is not supported.</exception>
    private static KernelBuilder WithEmbeddingBackend(this KernelBuilder kernelBuilder, AIServiceOptions options)
    {
        return options.Type switch
        {
            AIServiceOptions.AIServiceType.AzureOpenAI
                => kernelBuilder.WithAzureTextEmbeddingGenerationService(options.Models.Embedding, options.Endpoint, options.Key),
            AIServiceOptions.AIServiceType.OpenAI
                => kernelBuilder.WithOpenAITextEmbeddingGenerationService(options.Models.Embedding, options.Key),
            _
                => throw new ArgumentException($"Invalid {nameof(options.Type)} value in '{AIServiceOptions.PropertyName}' settings."),
        };
    }

    /// <summary>
    /// Add the chat completion backend for the planner (the configured planner model) to the
    /// kernel config.
    /// </summary>
    /// <exception cref="ArgumentException">The AI service type is not supported.</exception>
    private static KernelBuilder WithPlannerBackend(this KernelBuilder kernelBuilder, AIServiceOptions options)
    {
        return options.Type switch
        {
            AIServiceOptions.AIServiceType.AzureOpenAI => kernelBuilder.WithAzureChatCompletionService(options.Models.Planner, options.Endpoint, options.Key),
            AIServiceOptions.AIServiceType.OpenAI => kernelBuilder.WithOpenAIChatCompletionService(options.Models.Planner, options.Key),
            _ => throw new ArgumentException($"Invalid {nameof(options.Type)} value in '{AIServiceOptions.PropertyName}' settings."),
        };
    }

    /// <summary>
    /// Construct IEmbeddingGeneration from <see cref="AIServiceOptions"/>
    /// </summary>
    /// <param name="options">The service configuration</param>
    /// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
    /// <param name="loggerFactory">Custom <see cref="ILoggerFactory"/> for logging.</param>
    /// <exception cref="ArgumentException">The AI service type is not supported.</exception>
    private static ITextEmbeddingGeneration ToTextEmbeddingsService(this AIServiceOptions options,
        HttpClient? httpClient = null,
        ILoggerFactory? loggerFactory = null)
    {
        return options.Type switch
        {
            AIServiceOptions.AIServiceType.AzureOpenAI
                => new AzureTextEmbeddingGeneration(options.Models.Embedding, options.Endpoint, options.Key, httpClient: httpClient, loggerFactory: loggerFactory),
            AIServiceOptions.AIServiceType.OpenAI
                => new OpenAITextEmbeddingGeneration(options.Models.Embedding, options.Key, httpClient: httpClient, loggerFactory: loggerFactory),
            _
                => throw new ArgumentException("Invalid AIService value in embeddings backend settings"),
        };
    }
}