Closed AwesomeYuer closed 1 year ago
Thank you for reporting this. I appreciate your detailed analysis and will follow up.
We're always open to contributions from the community.
FYI - I'm working to land a fix for this with the core SK team that will address all connectors
@AwesomeYuer - We are days away from merging this PR which will remove the issue you've observed:
I believe this addresses the issue: https://github.com/microsoft/chat-copilot/pull/365
I found a workaround, as a temporary solution, described below:
Don't use the default PostgresMemoryStore
; use another implementation that ensures the collection table exists before accessing it.
You can't inherit from the default PostgresMemoryStore
, because its non-virtual methods can't be overridden!
AwesomePostgresMemoryStore.cs
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Connectors.Memory.Postgres;
using Npgsql;
using Pgvector;
namespace Microshaoft.SemanticKernel.Connectors.Memory.Postgres;
/// <summary>
/// An implementation of <see cref="IMemoryStore"/> backed by a Postgres database with pgvector extension.
/// </summary>
/// <remarks>
/// The embedded data is saved to the Postgres database specified in the constructor.
/// Similarity search capability is provided through the pgvector extension. Use Postgres's "Table" to implement "Collection".
/// </remarks>
public class AwesomePostgresMemoryStore : IMemoryStore
{
// add by Awesome Yuer
// Names of collections this instance has already verified (or created).
// Every access is guarded by a lock on the set itself: HashSet<T> is not safe for
// concurrent mutation, and the store may be shared across concurrent requests
// (for example when it is registered as a singleton).
private readonly HashSet<string> _collections = new();

/// <summary>
/// Makes sure the collection exists before it is accessed, creating it when
/// <see cref="DoesCollectionExistAsync"/> reports that it is missing.
/// </summary>
/// <param name="collectionName">Name of the collection to verify.</param>
/// <param name="cancellationToken">Token forwarded to the existence check and to the create call.</param>
/// <remarks>
/// The result is remembered per instance, so the database is only queried the first time a name is seen.
/// The lock is never held across an await; two concurrent first-time callers may therefore both run the
/// check/create sequence. NOTE(review): this assumes creating an already existing table is harmless - confirm
/// against the <see cref="IPostgresDbClient"/> implementation.
/// </remarks>
private async Task EnsureCollectionExistsAsync(string collectionName, CancellationToken cancellationToken = default)
{
    lock (this._collections)
    {
        if (this._collections.Contains(collectionName))
        {
            return;
        }
    }

    bool exists = await this.DoesCollectionExistAsync(collectionName, cancellationToken).ConfigureAwait(false);
    if (!exists)
    {
        await this.CreateCollectionAsync(collectionName, cancellationToken).ConfigureAwait(false);
    }

    lock (this._collections)
    {
        this._collections.Add(collectionName);
    }
}
/// <summary>
/// Schema name used when the caller does not supply one.
/// </summary>
internal const string DefaultSchema = "public";
/// <summary>
/// Initializes a new instance of the <see cref="AwesomePostgresMemoryStore"/> class.
/// </summary>
/// <param name="dataSource">Postgres data source.</param>
/// <param name="vectorSize">Embedding vector size.</param>
/// <param name="schema">Database schema of collection tables. The default value is "public".</param>
/// <remarks>
/// Wraps the arguments in a new <see cref="PostgresDbClient"/> and forwards to the constructor that accepts an <see cref="IPostgresDbClient"/>.
/// </remarks>
public AwesomePostgresMemoryStore(NpgsqlDataSource dataSource, int vectorSize, string schema = DefaultSchema)
: this(new PostgresDbClient(dataSource, schema, vectorSize))
{
}
/// <summary>
/// Initializes a new instance of the <see cref="AwesomePostgresMemoryStore"/> class.
/// </summary>
/// <param name="postgresDbClient">An instance of <see cref="IPostgresDbClient"/>.</param>
/// <exception cref="ArgumentNullException"><paramref name="postgresDbClient"/> is null.</exception>
public AwesomePostgresMemoryStore(IPostgresDbClient postgresDbClient)
{
    // Fail at construction time instead of with a NullReferenceException on the first store operation.
    this._postgresDbClient = postgresDbClient ?? throw new ArgumentNullException(nameof(postgresDbClient));
}
/// <inheritdoc/>
/// <remarks>Forwards the request to <see cref="IPostgresDbClient.CreateTableAsync"/>.</remarks>
public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
{
    var createTable = this._postgresDbClient.CreateTableAsync(collectionName, cancellationToken);
    await createTable.ConfigureAwait(false);
}
/// <inheritdoc/>
/// <remarks>Forwards the question to <see cref="IPostgresDbClient.DoesTableExistsAsync"/>.</remarks>
public async Task<bool> DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default)
{
    bool tableExists = await this._postgresDbClient
        .DoesTableExistsAsync(collectionName, cancellationToken)
        .ConfigureAwait(false);

    return tableExists;
}
/// <inheritdoc/>
/// <remarks>Streams every name produced by <see cref="IPostgresDbClient.GetTablesAsync"/>.</remarks>
public async IAsyncEnumerable<string> GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    IAsyncEnumerable<string> tableNames = this._postgresDbClient.GetTablesAsync(cancellationToken);

    await foreach (string tableName in tableNames.ConfigureAwait(false))
    {
        yield return tableName;
    }
}
/// <inheritdoc/>
/// <remarks>
/// Deletes the backing table only when it currently exists, and always evicts the name from the
/// per-instance existence cache. Without the eviction, later operations on the same name would skip
/// re-creating the table because the cache would still claim it exists.
/// </remarks>
public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
{
    try
    {
        // Check instead of "ensure": creating a missing table just to drop it again is wasted work.
        bool exists = await this.DoesCollectionExistAsync(collectionName, cancellationToken).ConfigureAwait(false);
        if (exists)
        {
            await this._postgresDbClient.DeleteTableAsync(collectionName, cancellationToken).ConfigureAwait(false);
        }
    }
    finally
    {
        // Forgetting the name is always safe: the next access simply re-checks the database.
        lock (this._collections)
        {
            this._collections.Remove(collectionName);
        }
    }
}
/// <inheritdoc/>
/// <remarks>The collection is created on demand before the record is written.</remarks>
public async Task<string> UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default)
{
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    string storedKey = await this.InternalUpsertAsync(collectionName, record, cancellationToken).ConfigureAwait(false);
    return storedKey;
}
/// <inheritdoc/>
/// <remarks>
/// The collection is created on demand, then the records are written one at a time in enumeration order;
/// the key of each record is yielded as soon as that record has been written.
/// </remarks>
public async IAsyncEnumerable<string> UpsertBatchAsync(
    string collectionName,
    IEnumerable<MemoryRecord> records,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    foreach (MemoryRecord item in records)
    {
        string storedKey = await this.InternalUpsertAsync(collectionName, item, cancellationToken).ConfigureAwait(false);
        yield return storedKey;
    }
}
/// <inheritdoc/>
/// <remarks>The collection is created on demand before the lookup; a missing key yields <c>null</c>.</remarks>
public async Task<MemoryRecord?> GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default)
{
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    PostgresMemoryEntry? found = await this._postgresDbClient
        .ReadAsync(collectionName, key, withEmbedding, cancellationToken)
        .ConfigureAwait(false);

    return found.HasValue
        ? this.GetMemoryRecordFromEntry(found.Value)
        : null;
}
/// <inheritdoc/>
/// <remarks>The collection is created on demand; each entry returned by the client is converted and yielded as it arrives.</remarks>
public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(
    string collectionName,
    IEnumerable<string> keys,
    bool withEmbeddings = false,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    IAsyncEnumerable<PostgresMemoryEntry> rows =
        this._postgresDbClient.ReadBatchAsync(collectionName, keys, withEmbeddings, cancellationToken);

    await foreach (PostgresMemoryEntry row in rows.ConfigureAwait(false))
    {
        MemoryRecord converted = this.GetMemoryRecordFromEntry(row);
        yield return converted;
    }
}
/// <inheritdoc/>
/// <remarks>The collection is created on demand before the delete is issued through <see cref="IPostgresDbClient.DeleteAsync"/>.</remarks>
public async Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default)
{
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    var deleteRow = this._postgresDbClient.DeleteAsync(collectionName, key, cancellationToken);
    await deleteRow.ConfigureAwait(false);
}
/// <inheritdoc/>
/// <remarks>The collection is created on demand before the delete is issued through <see cref="IPostgresDbClient.DeleteBatchAsync"/>.</remarks>
public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancellationToken = default)
{
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    var deleteRows = this._postgresDbClient.DeleteBatchAsync(collectionName, keys, cancellationToken);
    await deleteRows.ConfigureAwait(false);
}
/// <inheritdoc/>
/// <remarks>
/// The collection is created on demand before searching. When <paramref name="limit"/> is zero or negative
/// the sequence is empty and no search is sent to the database.
/// </remarks>
public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
    string collectionName,
    ReadOnlyMemory<float> embedding,
    int limit,
    double minRelevanceScore = 0,
    bool withEmbeddings = false,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // A single call is enough: EnsureCollectionExistsAsync already performs the exists-check and the create,
    // so repeating both here only added an extra database round trip to every search.
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    if (limit <= 0)
    {
        yield break;
    }

    IAsyncEnumerable<(PostgresMemoryEntry, double)> results = this._postgresDbClient.GetNearestMatchesAsync(
        tableName: collectionName,
        embedding: new Vector(GetOrCreateArray(embedding)),
        limit: limit,
        minRelevanceScore: minRelevanceScore,
        withEmbeddings: withEmbeddings,
        cancellationToken: cancellationToken);

    await foreach ((PostgresMemoryEntry entry, double cosineSimilarity) in results.ConfigureAwait(false))
    {
        yield return (this.GetMemoryRecordFromEntry(entry), cosineSimilarity);
    }
}
/// <inheritdoc/>
/// <remarks>
/// Returns the first result of <see cref="GetNearestMatchesAsync"/> with a limit of one, or <c>null</c> when the
/// search yields nothing. The sequence is enumerated directly rather than through a generic "first or default"
/// helper: for a value-tuple element type such a helper returns <c>default((MemoryRecord, double))</c>, which
/// converts to a nullable that has a value (with a null record) instead of to <c>null</c>.
/// </remarks>
public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, ReadOnlyMemory<float> embedding, double minRelevanceScore = 0, bool withEmbedding = false,
    CancellationToken cancellationToken = default)
{
    // GetNearestMatchesAsync ensures the collection exists before it searches.
    IAsyncEnumerable<(MemoryRecord, double)> matches = this.GetNearestMatchesAsync(
        collectionName: collectionName,
        embedding: embedding,
        limit: 1,
        minRelevanceScore: minRelevanceScore,
        withEmbeddings: withEmbedding,
        cancellationToken: cancellationToken);

    await foreach ((MemoryRecord, double) match in matches.ConfigureAwait(false))
    {
        return match;
    }

    return null;
}
#region private ================================================================================
private readonly IPostgresDbClient _postgresDbClient;
/// <summary>
/// Writes one record to the collection and returns the key it was stored under.
/// </summary>
/// <param name="collectionName">Collection (table) to write to.</param>
/// <param name="record">Record to store; its <c>Key</c> is overwritten with its metadata id.</param>
/// <param name="cancellationToken">Token forwarded to the database calls.</param>
/// <returns>The record key after it has been set from the metadata id.</returns>
private async Task<string> InternalUpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken)
{
    await this.EnsureCollectionExistsAsync(collectionName, cancellationToken).ConfigureAwait(false);

    // The metadata id is the source of truth for the storage key.
    record.Key = record.Metadata.Id;

    var rowKey = record.Key;
    var serializedMetadata = record.GetSerializedMetadata();
    var vector = new Vector(GetOrCreateArray(record.Embedding));
    var utcTimestamp = record.Timestamp?.UtcDateTime;

    await this._postgresDbClient.UpsertAsync(
        tableName: collectionName,
        key: rowKey,
        metadata: serializedMetadata,
        embedding: vector,
        timestamp: utcTimestamp,
        cancellationToken: cancellationToken).ConfigureAwait(false);

    return record.Key;
}
/// <summary>
/// Converts a database entry into a <see cref="MemoryRecord"/>.
/// </summary>
/// <param name="entry">Entry read from the database.</param>
/// <returns>A record built from the entry's metadata JSON, embedding (empty when the entry has none), key and timestamp.</returns>
private MemoryRecord GetMemoryRecordFromEntry(PostgresMemoryEntry entry)
{
    var metadataJson = entry.MetadataString;
    ReadOnlyMemory<float> vector = entry.Embedding?.ToArray() ?? ReadOnlyMemory<float>.Empty;
    var recordKey = entry.Key;
    var localTimestamp = entry.Timestamp?.ToLocalTime();

    return MemoryRecord.FromJsonMetadata(
        json: metadataJson,
        embedding: vector,
        key: recordKey,
        timestamp: localTimestamp);
}
/// <summary>
/// Returns a <c>float[]</c> holding the contents of <paramref name="memory"/>.
/// </summary>
/// <param name="memory">The values to expose as an array.</param>
/// <returns>
/// The array backing <paramref name="memory"/> when the memory spans that whole array;
/// otherwise a newly allocated copy.
/// </returns>
private static float[] GetOrCreateArray(ReadOnlyMemory<float> memory)
{
    if (MemoryMarshal.TryGetArray(memory, out ArraySegment<float> segment))
    {
        float[] backing = segment.Array!;

        // Reuse is only valid when the segment covers the entire backing array.
        if (segment.Count == backing.Length)
        {
            return backing;
        }
    }

    return memory.ToArray();
}
#endregion
}
AsyncEnumerable.cs
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
#pragma warning disable CA1510 // Use 'ArgumentNullException.ThrowIfNull' (.NET 8)
// Used for compatibility with System.Linq.Async Nuget pkg
namespace System.Linq;
internal static class AsyncEnumerable
{
/// <summary>
/// Returns the shared <see cref="EmptyAsyncEnumerable{T}"/> instance, an async sequence with no elements.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
public static IAsyncEnumerable<T> Empty<T>() => EmptyAsyncEnumerable<T>.Instance;
/// <summary>
/// Exposes an async sequence as a synchronous one by blocking the calling thread on every step.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
/// <param name="source">The async sequence to enumerate.</param>
/// <param name="cancellationToken">Token handed to the async enumerator.</param>
/// <returns>A lazily evaluated sequence of the elements of <paramref name="source"/>.</returns>
public static IEnumerable<T> ToEnumerable<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
{
    IAsyncEnumerator<T> cursor = source.GetAsyncEnumerator(cancellationToken);
    try
    {
        while (true)
        {
            // Blocks until the next element (or the end of the sequence) is available.
            bool advanced = cursor.MoveNextAsync().AsTask().GetAwaiter().GetResult();
            if (!advanced)
            {
                yield break;
            }

            yield return cursor.Current;
        }
    }
    finally
    {
        cursor.DisposeAsync().AsTask().GetAwaiter().GetResult();
    }
}
/// <summary>
/// Exposes a synchronous sequence as an async one; every element is produced without awaiting.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
/// <param name="source">The sequence to wrap.</param>
/// <returns>An async sequence yielding the elements of <paramref name="source"/> in order.</returns>
public static async IAsyncEnumerable<T> ToAsyncEnumerable<T>(this IEnumerable<T> source)
{
    using IEnumerator<T> cursor = source.GetEnumerator();
    while (cursor.MoveNext())
    {
        yield return cursor.Current;
    }
}
/// <summary>
/// Returns the first element of an async sequence, or the default value of <typeparamref name="T"/> when it is empty.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
/// <param name="source">The async sequence to read from.</param>
/// <param name="cancellationToken">Token used while enumerating.</param>
public static async ValueTask<T?> FirstOrDefaultAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
{
    var cursor = source.WithCancellation(cancellationToken).ConfigureAwait(false).GetAsyncEnumerator();
    try
    {
        // Only the first step matters; the enumerator is disposed right after it.
        if (await cursor.MoveNextAsync())
        {
            return cursor.Current;
        }

        return default;
    }
    finally
    {
        await cursor.DisposeAsync();
    }
}
/// <summary>
/// Returns the last element of an async sequence, or the default value of <typeparamref name="T"/> when it is empty.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
/// <param name="source">The async sequence to read to its end.</param>
/// <param name="cancellationToken">Token used while enumerating.</param>
public static async ValueTask<T?> LastOrDefaultAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
{
    bool sawAny = false;
    T latest = default(T)!; // Only meaningful once sawAny is true.

    await foreach (T current in source.WithCancellation(cancellationToken).ConfigureAwait(false))
    {
        latest = current;
        sawAny = true;
    }

    if (!sawAny)
    {
        return default;
    }

    return latest!;
}
/// <summary>
/// Collects every element of an async sequence into a new list, in enumeration order.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
/// <param name="source">The async sequence to drain.</param>
/// <param name="cancellationToken">Token used while enumerating.</param>
public static async ValueTask<List<T>> ToListAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
{
    List<T> collected = new();

    await foreach (T element in source.WithCancellation(cancellationToken).ConfigureAwait(false))
    {
        collected.Add(element);
    }

    return collected;
}
/// <summary>
/// Determines whether an async sequence contains a value, compared with the default equality comparer for <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
/// <param name="source">The async sequence to search.</param>
/// <param name="value">The value to look for.</param>
/// <param name="cancellationToken">
/// Optional token used while enumerating. Added so this operator can be cancelled like the other
/// operators in this class; existing callers that omit it behave as before.
/// </param>
/// <returns><c>true</c> as soon as a matching element is found; <c>false</c> when the sequence ends without one.</returns>
public static async ValueTask<bool> ContainsAsync<T>(this IAsyncEnumerable<T> source, T value, CancellationToken cancellationToken = default)
{
    EqualityComparer<T> comparer = EqualityComparer<T>.Default;

    await foreach (var item in source.WithCancellation(cancellationToken).ConfigureAwait(false))
    {
        if (comparer.Equals(item, value))
        {
            return true;
        }
    }

    return false;
}
/// <summary>
/// Counts the elements of an async sequence.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
/// <param name="source">The async sequence to count.</param>
/// <param name="cancellationToken">Token used while enumerating.</param>
/// <exception cref="OverflowException">The sequence has more than <see cref="int.MaxValue"/> elements.</exception>
public static async ValueTask<int> CountAsync<T>(this IAsyncEnumerable<T> source, CancellationToken cancellationToken = default)
{
    int total = 0;

    await foreach (T _ in source.WithCancellation(cancellationToken).ConfigureAwait(false))
    {
        total = checked(total + 1);
    }

    return total;
}
/// <summary>
/// Determines whether at least one element of an async sequence satisfies a predicate.
/// </summary>
/// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
/// <param name="source">The async sequence whose elements are tested.</param>
/// <param name="predicate">The test applied to each element.</param>
/// <param name="cancellationToken">Token used while enumerating.</param>
/// <returns><c>true</c> as soon as an element passes the test; <c>false</c> when the sequence ends without one.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
/// <remarks>Arguments are validated synchronously, before any asynchronous work starts.</remarks>
public static ValueTask<bool> AnyAsync<TSource>(this IAsyncEnumerable<TSource> source, Func<TSource, bool> predicate, CancellationToken cancellationToken = default)
{
    _ = source ?? throw new ArgumentNullException(nameof(source));
    _ = predicate ?? throw new ArgumentNullException(nameof(predicate));

    return Scan(source, predicate, cancellationToken);

    // Kept as a separate async function so the argument checks above throw eagerly.
    static async ValueTask<bool> Scan(IAsyncEnumerable<TSource> sequence, Func<TSource, bool> test, CancellationToken token)
    {
        await foreach (TSource candidate in sequence.WithCancellation(token).ConfigureAwait(false))
        {
            if (!test(candidate))
            {
                continue;
            }

            return true;
        }

        return false;
    }
}
/// <summary>
/// An async sequence with no elements. The single shared instance acts as both the sequence and its enumerator.
/// </summary>
/// <typeparam name="T">The element type of the sequence.</typeparam>
private sealed class EmptyAsyncEnumerable<T> : IAsyncEnumerable<T>, IAsyncEnumerator<T>
{
    /// <summary>The shared instance.</summary>
    public static readonly EmptyAsyncEnumerable<T> Instance = new EmptyAsyncEnumerable<T>();

    /// <summary>Always the default value; there is never a current element.</summary>
    public T Current
    {
        get { return default!; }
    }

    /// <summary>Returns this same instance; the cancellation token is not used.</summary>
    public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        return this;
    }

    /// <summary>Completes synchronously with <c>false</c>: there is nothing to move to.</summary>
    public ValueTask<bool> MoveNextAsync()
    {
        return new ValueTask<bool>(false);
    }

    /// <summary>Completes synchronously; there is nothing to release.</summary>
    public ValueTask DisposeAsync()
    {
        return default(ValueTask);
    }
}
}
SemanticKernelExtensions.cs
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using CopilotChat.WebApi.Hubs;
using CopilotChat.WebApi.Options;
using CopilotChat.WebApi.Services;
using CopilotChat.WebApi.Skills.ChatSkills;
using CopilotChat.WebApi.Storage;
using Microshaoft.SemanticKernel.Connectors.Memory.Postgres;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticKernel.Connectors.Memory.AzureCognitiveSearch;
using Microsoft.SemanticKernel.Connectors.Memory.Chroma;
using Microsoft.SemanticKernel.Connectors.Memory.Qdrant;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.Skills.Core;
using Npgsql;
using Pgvector.Npgsql;
using static CopilotChat.WebApi.Options.MemoryStoreOptions;
namespace CopilotChat.WebApi.Extensions;
/// <summary>
/// Extension methods for registering Semantic Kernel related services.
/// </summary>
internal static class SemanticKernelExtensions
{
/// <summary>
/// Delegate to register skills with a Semantic Kernel
/// </summary>
/// <param name="sp">Service provider passed to the registration callback.</param>
/// <param name="kernel">Kernel that the skills are registered with.</param>
/// <returns>A task representing the registration work.</returns>
public delegate Task RegisterSkillsWithKernel(IServiceProvider sp, IKernel kernel);
/// <summary>
/// Registers the Semantic Kernel and the services it depends on.
/// </summary>
/// <param name="services">The service collection to add registrations to.</param>
/// <returns>The same <paramref name="services"/> instance, to allow chaining.</returns>
/// <remarks>
/// Registers, in order: a scoped <see cref="IKernel"/>, the semantic text memory, content safety,
/// and the scoped <see cref="RegisterSkillsWithKernel"/> callback used to populate each kernel.
/// </remarks>
internal static IServiceCollection AddSemanticKernelServices(this IServiceCollection services)
{
    // One kernel per scope, wired with logging, memory and both AI backends.
    services.AddScoped<IKernel>(BuildKernel);

    services.AddSemanticTextMemory();
    services.AddContentSafety();

    // The callback that BuildKernel resolves to import skills into a freshly built kernel.
    services.AddScoped<RegisterSkillsWithKernel>(_ => RegisterSkillsAsync);

    return services;

    static IKernel BuildKernel(IServiceProvider provider)
    {
        ILoggerFactory loggerFactory = provider.GetRequiredService<ILoggerFactory>();
        ISemanticTextMemory memory = provider.GetRequiredService<ISemanticTextMemory>();
        AIServiceOptions completionOptions = provider.GetRequiredService<IOptions<AIServiceOptions>>().Value;
        AIServiceOptions embeddingOptions = provider.GetRequiredService<IOptions<AIServiceOptions>>().Value;

        IKernel built = Kernel.Builder
            .WithLoggerFactory(loggerFactory)
            .WithMemory(memory)
            .WithCompletionBackend(completionOptions)
            .WithEmbeddingBackend(embeddingOptions)
            .Build();

        RegisterSkillsWithKernel register = provider.GetRequiredService<RegisterSkillsWithKernel>();
        register(provider, built);

        return built;
    }
}
/// <summary>
/// Add Planner services
/// </summary>
/// <param name="services">The service collection to add registrations to.</param>
/// <returns>The same <paramref name="services"/> instance, to allow chaining.</returns>
/// <remarks>
/// Registers a scoped <see cref="CopilotChatPlanner"/> backed by its own kernel. The planner options are resolved
/// from the scope's service provider when the planner is created, instead of from a throwaway provider built at
/// registration time: building a provider during registration creates a second container with its own singleton
/// instances and only sees registrations made before this call.
/// </remarks>
public static IServiceCollection AddPlannerServices(this IServiceCollection services)
{
    services.AddScoped<CopilotChatPlanner>(sp =>
    {
        IKernel plannerKernel = Kernel.Builder
            .WithLoggerFactory(sp.GetRequiredService<ILoggerFactory>())
            .WithMemory(sp.GetRequiredService<ISemanticTextMemory>())
            // TODO: [sk Issue #2046] verify planner has AI service configured
            .WithPlannerBackend(sp.GetRequiredService<IOptions<AIServiceOptions>>().Value)
            .Build();

        // Optional: the planner receives null options when none are registered.
        IOptions<PlannerOptions>? plannerOptions = sp.GetService<IOptions<PlannerOptions>>();

        return new CopilotChatPlanner(plannerKernel, plannerOptions?.Value, sp.GetRequiredService<ILogger<CopilotChatPlanner>>());
    });

    // Register Planner skills (AI plugins) here.
    // TODO: [sk Issue #2046] Move planner skill registration from ChatController to this location.
    return services;
}
/// <summary>
/// Register the chat skill with the kernel.
/// </summary>
/// <param name="kernel">The kernel to import the chat skill into.</param>
/// <param name="sp">Service provider used to resolve the chat skill's dependencies.</param>
/// <returns>The same <paramref name="kernel"/> instance, to allow chaining.</returns>
/// <remarks>
/// Content safety is optional. It is registered under <see cref="IContentSafetyService"/> by
/// <c>AddContentSafety</c>, so asking the container for the concrete <see cref="AzureContentSafety"/> type alone
/// always yields null; the interface registration is therefore consulted as a fallback.
/// </remarks>
public static IKernel RegisterChatSkill(this IKernel kernel, IServiceProvider sp)
{
    // Prefer a direct registration of the concrete type if one exists, otherwise use the interface registration.
    AzureContentSafety? contentSafety =
        sp.GetService<AzureContentSafety>() ?? (sp.GetService<IContentSafetyService>() as AzureContentSafety);

    // Chat skill
    kernel.ImportSkill(new ChatSkill(
        kernel: kernel,
        chatMessageRepository: sp.GetRequiredService<ChatMessageRepository>(),
        chatSessionRepository: sp.GetRequiredService<ChatSessionRepository>(),
        messageRelayHubContext: sp.GetRequiredService<IHubContext<MessageRelayHub>>(),
        promptOptions: sp.GetRequiredService<IOptions<PromptsOptions>>(),
        documentImportOptions: sp.GetRequiredService<IOptions<DocumentMemoryOptions>>(),
        contentSafety: contentSafety,
        planner: sp.GetRequiredService<CopilotChatPlanner>(),
        logger: sp.GetRequiredService<ILogger<ChatSkill>>()),
        nameof(ChatSkill));

    return kernel;
}
/// <summary>
/// Propagate exception from within semantic function
/// </summary>
/// <param name="context">The context to inspect.</param>
/// <remarks>
/// Does nothing unless <c>context.ErrorOccurred</c> is set. Otherwise the last exception is logged and rethrown
/// with its original stack trace preserved, so the point of failure inside the semantic function stays visible.
/// </remarks>
/// <exception cref="InvalidOperationException">The context reports an error but carries no exception.</exception>
public static void ThrowIfFailed(this SKContext context)
{
    if (!context.ErrorOccurred)
    {
        return;
    }

    var failure = context.LastException;

    var logger = context.LoggerFactory.CreateLogger(nameof(SKContext));
    logger.LogError(failure, "{0}", failure?.Message);

    if (failure is null)
    {
        // "throw null" would surface as an unrelated NullReferenceException.
        throw new InvalidOperationException("The context reported an error but no exception was recorded.");
    }

    // A plain "throw failure;" would overwrite the stack trace captured when the exception was first thrown.
    System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(failure).Throw();
}
/// <summary>
/// Register the skills with the kernel.
/// </summary>
/// <param name="sp">Service provider used to resolve skill dependencies and options.</param>
/// <param name="kernel">The kernel to import the skills into.</param>
/// <returns>An already completed task; all work is done synchronously.</returns>
/// <remarks>
/// Imports the chat skill, the time skill and, when a semantic skills directory is configured, one semantic skill
/// per sub-directory. A sub-directory that fails to load with an <see cref="SKException"/> is logged and skipped
/// so the remaining skills are still registered.
/// </remarks>
private static Task RegisterSkillsAsync(IServiceProvider sp, IKernel kernel)
{
    // Copilot chat skills
    kernel.RegisterChatSkill(sp);

    // Time skill
    kernel.ImportSkill(new TimeSkill(), nameof(TimeSkill));

    // Semantic skills
    ServiceOptions options = sp.GetRequiredService<IOptions<ServiceOptions>>().Value;
    if (!string.IsNullOrWhiteSpace(options.SemanticSkillsDirectory))
    {
        foreach (string subDir in Directory.GetDirectories(options.SemanticSkillsDirectory))
        {
            try
            {
                kernel.ImportSemanticSkillFromDirectory(options.SemanticSkillsDirectory, Path.GetFileName(subDir)!);
            }
            catch (SKException ex)
            {
                var logger = kernel.LoggerFactory.CreateLogger(nameof(Kernel));

                // Pass the exception itself so the log entry keeps the stack trace and inner exceptions.
                logger.LogError(ex, "Could not load skill from {Directory}: {Message}", subDir, ex.Message);
            }
        }
    }

    return Task.CompletedTask;
}
/// <summary>
/// Add the semantic memory.
/// </summary>
/// <remarks>
/// Registers a singleton <see cref="IMemoryStore"/> chosen by the configured <see cref="MemoryStoreOptions"/> type,
/// then a scoped <see cref="ISemanticTextMemory"/> that combines that store with an embedding service built from
/// <see cref="AIServiceOptions"/>.
/// </remarks>
/// <exception cref="InvalidOperationException">
/// The configured store type is unknown, or the configuration section for the selected store type is null.
/// </exception>
private static void AddSemanticTextMemory(this IServiceCollection services)
{
// NOTE(review): this builds a temporary service provider at registration time just to read the options;
// only registrations made before this call are visible to it.
MemoryStoreOptions config = services.BuildServiceProvider().GetRequiredService<IOptions<MemoryStoreOptions>>().Value;
switch (config.Type)
{
case MemoryStoreType.Volatile:
services.AddSingleton<IMemoryStore, VolatileMemoryStore>();
break;
case MemoryStoreType.Qdrant:
if (config.Qdrant == null)
{
throw new InvalidOperationException("MemoryStore type is Qdrant and Qdrant configuration is null.");
}
services.AddSingleton<IMemoryStore>(sp =>
{
HttpClient httpClient = new(new HttpClientHandler { CheckCertificateRevocationList = true });
// The API key header is only sent when a key is configured.
if (!string.IsNullOrWhiteSpace(config.Qdrant.Key))
{
httpClient.DefaultRequestHeaders.Add("api-key", config.Qdrant.Key);
}
// Combine the configured host and port into the endpoint string.
var endPointBuilder = new UriBuilder(config.Qdrant.Host);
endPointBuilder.Port = config.Qdrant.Port;
return new QdrantMemoryStore(
httpClient: httpClient,
config.Qdrant.VectorSize,
endPointBuilder.ToString(),
loggerFactory: sp.GetRequiredService<ILoggerFactory>()
);
});
break;
case MemoryStoreType.AzureCognitiveSearch:
if (config.AzureCognitiveSearch == null)
{
throw new InvalidOperationException("MemoryStore type is AzureCognitiveSearch and AzureCognitiveSearch configuration is null.");
}
services.AddSingleton<IMemoryStore>(sp =>
{
return new AzureCognitiveSearchMemoryStore(config.AzureCognitiveSearch.Endpoint, config.AzureCognitiveSearch.Key);
});
break;
case MemoryStoreOptions.MemoryStoreType.Chroma:
if (config.Chroma == null)
{
throw new InvalidOperationException("MemoryStore type is Chroma and Chroma configuration is null.");
}
services.AddSingleton<IMemoryStore>(sp =>
{
HttpClient httpClient = new(new HttpClientHandler { CheckCertificateRevocationList = true });
// Combine the configured host and port into the endpoint string.
var endPointBuilder = new UriBuilder(config.Chroma.Host);
endPointBuilder.Port = config.Chroma.Port;
return new ChromaMemoryStore(
httpClient: httpClient,
endpoint: endPointBuilder.ToString(),
loggerFactory: sp.GetRequiredService<ILoggerFactory>()
);
});
break;
case MemoryStoreOptions.MemoryStoreType.Postgres:
if (config.Postgres == null)
{
throw new InvalidOperationException("MemoryStore type is Postgres and Postgres configuration is null.");
}
// The builder is configured once here; the data source itself is built when the singleton store is first created.
var dataSourceBuilder = new NpgsqlDataSourceBuilder(config.Postgres.ConnectionString);
dataSourceBuilder.UseVector();
services.AddSingleton<IMemoryStore>(sp =>
{
// Uses the custom store that creates collection tables on demand instead of the default Postgres store.
return new AwesomePostgresMemoryStore(
dataSource: dataSourceBuilder.Build(),
vectorSize: config.Postgres.VectorSize
);
});
break;
default:
throw new InvalidOperationException($"Invalid 'MemoryStore' type '{config.Type}'.");
}
// The text memory is scoped, while the underlying store registered above is a singleton.
services.AddScoped<ISemanticTextMemory>(sp => new SemanticTextMemory(
sp.GetRequiredService<IMemoryStore>(),
sp.GetRequiredService<IOptions<AIServiceOptions>>().Value
.ToTextEmbeddingsService(loggerFactory: sp.GetRequiredService<ILoggerFactory>())));
}
/// <summary>
/// Adds Azure Content Safety
/// </summary>
/// <param name="services">The service collection to add the registration to.</param>
/// <remarks>
/// Registers a singleton <see cref="IContentSafetyService"/> only when the content safety configuration section
/// is present and enabled. A missing section is treated the same as a disabled one; binding a missing section
/// yields null, which previously caused a <see cref="NullReferenceException"/> here.
/// </remarks>
internal static void AddContentSafety(this IServiceCollection services)
{
    // NOTE(review): this builds a temporary service provider at registration time just to read configuration.
    IConfiguration configuration = services.BuildServiceProvider().GetRequiredService<IConfiguration>();
    var options = configuration.GetSection(ContentSafetyOptions.PropertyName).Get<ContentSafetyOptions>();

    if (options is not { Enabled: true })
    {
        return;
    }

    services.AddSingleton<IContentSafetyService, AzureContentSafety>(sp => new AzureContentSafety(new Uri(options.Endpoint), options.Key, options));
}
/// <summary>
/// Adds the chat completion service selected by the AI service options to the kernel builder.
/// </summary>
/// <param name="kernelBuilder">The builder to configure.</param>
/// <param name="options">AI service settings; the completion model name is used.</param>
/// <returns>The builder returned by the selected registration call.</returns>
/// <exception cref="ArgumentException">The configured service type is neither AzureOpenAI nor OpenAI.</exception>
private static KernelBuilder WithCompletionBackend(this KernelBuilder kernelBuilder, AIServiceOptions options)
{
    switch (options.Type)
    {
        case AIServiceOptions.AIServiceType.AzureOpenAI:
            return kernelBuilder.WithAzureChatCompletionService(options.Models.Completion, options.Endpoint, options.Key);

        case AIServiceOptions.AIServiceType.OpenAI:
            return kernelBuilder.WithOpenAIChatCompletionService(options.Models.Completion, options.Key);

        default:
            throw new ArgumentException($"Invalid {nameof(options.Type)} value in '{AIServiceOptions.PropertyName}' settings.");
    }
}
/// <summary>
/// Adds the text embedding generation service selected by the AI service options to the kernel builder.
/// </summary>
/// <param name="kernelBuilder">The builder to configure.</param>
/// <param name="options">AI service settings; the embedding model name is used.</param>
/// <returns>The builder returned by the selected registration call.</returns>
/// <exception cref="ArgumentException">The configured service type is neither AzureOpenAI nor OpenAI.</exception>
private static KernelBuilder WithEmbeddingBackend(this KernelBuilder kernelBuilder, AIServiceOptions options)
{
    switch (options.Type)
    {
        case AIServiceOptions.AIServiceType.AzureOpenAI:
            return kernelBuilder.WithAzureTextEmbeddingGenerationService(options.Models.Embedding, options.Endpoint, options.Key);

        case AIServiceOptions.AIServiceType.OpenAI:
            return kernelBuilder.WithOpenAITextEmbeddingGenerationService(options.Models.Embedding, options.Key);

        default:
            throw new ArgumentException($"Invalid {nameof(options.Type)} value in '{AIServiceOptions.PropertyName}' settings.");
    }
}
/// <summary>
/// Adds the chat completion service used by the planner, selected by the AI service options, to the kernel builder.
/// </summary>
/// <param name="kernelBuilder">The builder to configure.</param>
/// <param name="options">AI service settings; the planner model name is used.</param>
/// <returns>The builder returned by the selected registration call.</returns>
/// <exception cref="ArgumentException">The configured service type is neither AzureOpenAI nor OpenAI.</exception>
private static KernelBuilder WithPlannerBackend(this KernelBuilder kernelBuilder, AIServiceOptions options)
{
    switch (options.Type)
    {
        case AIServiceOptions.AIServiceType.AzureOpenAI:
            return kernelBuilder.WithAzureChatCompletionService(options.Models.Planner, options.Endpoint, options.Key);

        case AIServiceOptions.AIServiceType.OpenAI:
            return kernelBuilder.WithOpenAIChatCompletionService(options.Models.Planner, options.Key);

        default:
            throw new ArgumentException($"Invalid {nameof(options.Type)} value in '{AIServiceOptions.PropertyName}' settings.");
    }
}
/// <summary>
/// Creates the text embedding generation service described by <see cref="AIServiceOptions"/>.
/// </summary>
/// <param name="options">The service configuration; the embedding model name is used.</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="loggerFactory">Custom <see cref="ILoggerFactory"/> for logging.</param>
/// <returns>An Azure OpenAI or OpenAI embedding generator, depending on the configured service type.</returns>
/// <exception cref="ArgumentException">The configured service type is neither AzureOpenAI nor OpenAI.</exception>
private static ITextEmbeddingGeneration ToTextEmbeddingsService(this AIServiceOptions options,
    HttpClient? httpClient = null,
    ILoggerFactory? loggerFactory = null)
{
    switch (options.Type)
    {
        case AIServiceOptions.AIServiceType.AzureOpenAI:
            return new AzureTextEmbeddingGeneration(options.Models.Embedding, options.Endpoint, options.Key, httpClient: httpClient, loggerFactory: loggerFactory);

        case AIServiceOptions.AIServiceType.OpenAI:
            return new OpenAITextEmbeddingGeneration(options.Models.Embedding, options.Key, httpClient: httpClient, loggerFactory: loggerFactory);

        default:
            throw new ArgumentException("Invalid AIService value in embeddings backend settings");
    }
}
}
Describe the bug
The WebAPI backend should invoke CreateCollectionAsync at least once for each chat session when using PostgresMemoryStore with a local PostgreSQL database as the vector MemoryStore.
To Reproduce Steps to reproduce the behavior:
CopilotChatWebApi
appsettings.json - see the same error in the webapp frontend and in the webapi backend console