From 21371fb58cb74c401c8ad879248d30c75039603e Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 4 Apr 2025 17:39:01 -0400 Subject: [PATCH] Remove use of ConfigureAwait from Microsoft.Extensions.AI.dll for AIFunction invocations We try to use ConfigureAwait(false) throughout our libraries. However, we exempt ourselves from that in cases where user code is expected to be called back from within the async code, and there's a reasonable presumption that such code might care about the synchronization context. AIFunction fits that bill. And FunctionInvokingChatClient needs to invoke such functions, which means that we need to be able to successfully flow the context from where user code calls Get{Streaming}ResponseAsync through into wherever a FunctionInvokingChatClient is in the middleware pipeline. We could try to selectively avoid ConfigureAwait(false) on the path through middleware that could result in calls to FICC.Get{Streaming}ResponseAsync, but that's fairly brittle and hard to maintain. Instead, this PR just removes ConfigureAwait use from the M.E.AI library. It also fixes a few places where tasks were explicitly being created and queued to the thread pool. --- .../AnonymousDelegatingChatClient.cs | 23 ++++--- .../ChatCompletion/CachingChatClient.cs | 18 ++--- .../ChatClientBuilderChatClientExtensions.cs | 1 - .../ChatClientStructuredOutputExtensions.cs | 2 +- .../ConfigureOptionsChatClient.cs | 4 +- .../DistributedCachingChatClient.cs | 8 +-- .../FunctionInvokingChatClient.cs | 27 ++++---- .../ChatCompletion/LoggingChatClient.cs | 6 +- .../ChatCompletion/OpenTelemetryChatClient.cs | 4 +- .../AnonymousDelegatingEmbeddingGenerator.cs | 2 +- .../Embeddings/CachingEmbeddingGenerator.cs | 12 ++-- .../ConfigureOptionsEmbeddingGenerator.cs | 2 +- .../DistributedCachingEmbeddingGenerator.cs | 4 +- ...atorBuilderEmbeddingGeneratorExtensions.cs | 1 - .../Embeddings/LoggingEmbeddingGenerator.cs | 3 +- .../OpenTelemetryEmbeddingGenerator.cs | 2 +- .../Functions/AIFunctionFactory.cs | 34 ++++----- .../Microsoft.Extensions.AI.csproj | 10 +++ .../ConfigureOptionsSpeechToTextClient.cs | 4 +- .../SpeechToText/LoggingSpeechToTextClient.cs | 6 +- ...ientBuilderSpeechToTextClientExtensions.cs | 1 - .../FunctionInvokingChatClientTests.cs | 69 ++++++++++++++++++- 22 files changed, 157 insertions(+), 86 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs index a906d57c870..db256e94916 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs @@ -4,7 +4,9 @@ using System; using System.Collections.Generic; using System.Diagnostics; +#if !NET9_0_OR_GREATER using System.Runtime.CompilerServices; +#endif using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -100,8 +102,8 @@ async Task GetResponseViaSharedAsync( ChatResponse? response = null; await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - response = await InnerClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - }, cancellationToken).ConfigureAwait(false); + response = await InnerClient.GetResponseAsync(messages, options, cancellationToken); + }, cancellationToken); if (response is null) { @@ -133,20 +135,19 @@ public override IAsyncEnumerable GetStreamingResponseAsync( { var updates = Channel.CreateBounded(1); -#pragma warning disable CA2016 // explicitly not forwarding the cancellation token, as we need to ensure the channel is always completed - _ = Task.Run(async () => -#pragma warning restore CA2016 + _ = ProcessAsync(); + async Task ProcessAsync() { Exception? error = null; try { await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken)) { - await updates.Writer.WriteAsync(update, cancellationToken).ConfigureAwait(false); + await updates.Writer.WriteAsync(update, cancellationToken); } - }, cancellationToken).ConfigureAwait(false); + }, cancellationToken); } catch (Exception ex) { @@ -157,7 +158,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken { _ = updates.Writer.TryComplete(error); } - }); + } #if NET9_0_OR_GREATER return updates.Reader.ReadAllAsync(cancellationToken); @@ -166,7 +167,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken static async IAsyncEnumerable ReadAllAsync( ChannelReader channel, [EnumeratorCancellation] CancellationToken cancellationToken) { - while (await channel.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + while (await channel.WaitToReadAsync(cancellationToken)) { while (channel.TryRead(out var update)) { @@ -187,7 +188,7 @@ static async IAsyncEnumerable ReadAllAsync( static async IAsyncEnumerable GetStreamingResponseAsyncViaGetResponseAsync(Task task) { - ChatResponse response = await task.ConfigureAwait(false); + ChatResponse response = await task; foreach (var update in response.ToChatResponseUpdates()) { yield return update; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index 61421b005e7..6fed2157b0b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -55,10 +55,10 @@ public override async Task GetResponseAsync( // concurrent callers might trigger duplicate requests, but that's acceptable. var cacheKey = GetCacheKey(messages, options, _boxedFalse); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result) + if (await ReadCacheAsync(cacheKey, cancellationToken) is not { } result) { - result = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); + result = await base.GetResponseAsync(messages, options, cancellationToken); + await WriteCacheAsync(cacheKey, result, cancellationToken); } return result; @@ -77,7 +77,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one. var cacheKey = GetCacheKey(messages, options, _boxedTrue); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatResponse) + if (await ReadCacheAsync(cacheKey, cancellationToken) is { } chatResponse) { // Yield all of the cached items. foreach (var chunk in chatResponse.ToChatResponseUpdates()) @@ -89,20 +89,20 @@ public override async IAsyncEnumerable GetStreamingResponseA { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken)) { capturedItems.Add(chunk); yield return chunk; } // Write the captured items to the cache as a non-streaming result. - await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken); } } else { var cacheKey = GetCacheKey(messages, options, _boxedTrue); - if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) + if (await ReadCacheStreamingAsync(cacheKey, cancellationToken) is { } existingChunks) { // Yield all of the cached items. string? chatThreadId = null; @@ -116,14 +116,14 @@ public override async IAsyncEnumerable GetStreamingResponseA { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken)) { capturedItems.Add(chunk); yield return chunk; } // Write the captured items to the cache. - await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false); + await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken); } } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs index b4e1e7f280f..a43bf5fac75 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 7ad8ea1d279..915b86b4ee3 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -221,7 +221,7 @@ public static async Task> GetResponseAsync( messages = [.. messages, promptAugmentation]; } - var result = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + var result = await chatClient.GetResponseAsync(messages, options, cancellationToken); return new ChatResponse(result, serializerOptions) { IsWrappedInObject = isWrappedInObject }; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs index 5a5dfea06c3..50da3928157 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs @@ -36,13 +36,13 @@ public ConfigureOptionsChatClient(IChatClient innerClient, Action c /// public override async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - await base.GetResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false); + await base.GetResponseAsync(messages, Configure(options), cancellationToken); /// public override async IAsyncEnumerable GetStreamingResponseAsync( IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken)) { yield return update; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs index 2312eadcb0d..158c560de14 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs @@ -52,7 +52,7 @@ public JsonSerializerOptions JsonSerializerOptions _ = Throw.IfNull(key); _jsonSerializerOptions.MakeReadOnly(); - if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson) { return (ChatResponse?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse))); } @@ -66,7 +66,7 @@ public JsonSerializerOptions JsonSerializerOptions _ = Throw.IfNull(key); _jsonSerializerOptions.MakeReadOnly(); - if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson) { return (IReadOnlyList?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); } @@ -82,7 +82,7 @@ protected override async Task WriteCacheAsync(string key, ChatResponse value, Ca _jsonSerializerOptions.MakeReadOnly(); var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse))); - await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + await _storage.SetAsync(key, newJson, cancellationToken); } /// @@ -93,7 +93,7 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList _jsonSerializerOptions.MakeReadOnly(); var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); - await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + await _storage.SetAsync(key, newJson, cancellationToken); } /// Computes a cache key for the specified values. diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index ad88ba90265..6978a01dd44 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -13,7 +13,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; -using static Microsoft.Extensions.AI.OpenTelemetryConsts.GenAI; #pragma warning disable CA2213 // Disposable fields should be disposed #pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test @@ -233,7 +232,7 @@ public override async Task GetResponseAsync( functionCallContents?.Clear(); // Make the call to the inner client. - response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + response = await base.GetResponseAsync(messages, options, cancellationToken); if (response is null) { Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); @@ -279,7 +278,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -325,7 +324,7 @@ public override async IAsyncEnumerable GetStreamingResponseA updates.Clear(); functionCallContents?.Clear(); - await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken)) { if (update is null) { @@ -356,7 +355,7 @@ public override async IAsyncEnumerable GetStreamingResponseA FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -534,7 +533,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin if (functionCallContents.Count == 1) { FunctionInvocationResult result = await ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false); + messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken); IList added = CreateResponseMessages([result]); ThrowIfNoFunctionResultsAdded(added); @@ -549,13 +548,15 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin if (AllowConcurrentInvocation) { - // Schedule the invocation of every function. - // In this case we always capture exceptions because the ordering is nondeterministic + // Rather than await'ing each function before invoking the next, invoke all of them + // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, + // but if a function invocation completes asynchronously, its processing can overlap + // with the processing of other the other invocation invocations. results = await Task.WhenAll( from i in Enumerable.Range(0, functionCallContents.Count) - select Task.Run(() => ProcessFunctionCallAsync( + select ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, captureExceptions: true, cancellationToken))).ConfigureAwait(false); + iteration, i, captureExceptions: true, cancellationToken)); } else { @@ -565,7 +566,7 @@ select Task.Run(() => ProcessFunctionCallAsync( { results[i] = await ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false); + iteration, i, captureCurrentIterationExceptions, cancellationToken); } } @@ -663,7 +664,7 @@ private async Task ProcessFunctionCallAsync( object? result; try { - result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); + result = await InvokeFunctionAsync(context, cancellationToken); } catch (Exception e) when (!cancellationToken.IsCancellationRequested) { @@ -763,7 +764,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul try { CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit - result = await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false); + result = await context.Function.InvokeAsync(context.Arguments, cancellationToken); } catch (Exception e) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs index 51ca5a8f6d1..b5f43f5385b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs @@ -60,7 +60,7 @@ public override async Task GetResponseAsync( try { - var response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + var response = await base.GetResponseAsync(messages, options, cancellationToken); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -127,7 +127,7 @@ public override async IAsyncEnumerable GetStreamingResponseA { try { - if (!await e.MoveNextAsync().ConfigureAwait(false)) + if (!await e.MoveNextAsync()) { break; } @@ -164,7 +164,7 @@ public override async IAsyncEnumerable GetStreamingResponseA } finally { - await e.DisposeAsync().ConfigureAwait(false); + await e.DisposeAsync(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index df1717b4faa..c74bd3aa3c1 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -145,7 +145,7 @@ public override async Task GetResponseAsync( Exception? error = null; try { - response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + response = await base.GetResponseAsync(messages, options, cancellationToken); return response; } catch (Exception ex) @@ -183,7 +183,7 @@ public override async IAsyncEnumerable GetStreamingResponseA throw; } - var responseEnumerator = updates.ConfigureAwait(false).GetAsyncEnumerator(); + var responseEnumerator = updates.GetAsyncEnumerator(cancellationToken); List trackedUpdates = []; Exception? error = null; try diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs index 0f6c696bd0d..a3a068b9c34 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs @@ -39,6 +39,6 @@ public override async Task> GenerateAsync( { _ = Throw.IfNull(values); - return await _generateFunc(values, options, InnerGenerator, cancellationToken).ConfigureAwait(false); + return await _generateFunc(values, options, InnerGenerator, cancellationToken); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs index 43a983d7fd4..2c880d7a22c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -42,19 +42,19 @@ public override async Task> GenerateAsync( // In the expected common case where we can cheaply tell there's only a single value and access it, // we can avoid all the overhead of splitting the list and reassembling it. var cacheKey = GetCacheKey(valuesList[0], options); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding e) + if (await ReadCacheAsync(cacheKey, cancellationToken) is TEmbedding e) { return [e]; } else { - var generated = await base.GenerateAsync(valuesList, options, cancellationToken).ConfigureAwait(false); + var generated = await base.GenerateAsync(valuesList, options, cancellationToken); if (generated.Count != 1) { Throw.InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); } - await WriteCacheAsync(cacheKey, generated[0], cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(cacheKey, generated[0], cancellationToken); return generated; } } @@ -72,7 +72,7 @@ public override async Task> GenerateAsync( // concurrent callers might trigger duplicate requests, but that's acceptable. var cacheKey = GetCacheKey(input, options); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding existing) + if (await ReadCacheAsync(cacheKey, cancellationToken) is TEmbedding existing) { results.Add(existing); } @@ -87,12 +87,12 @@ public override async Task> GenerateAsync( if (uncached is not null) { // Now make a single call to the wrapped generator to generate embeddings for all of the uncached inputs. - var uncachedResults = await base.GenerateAsync(uncached.Select(e => e.Input), options, cancellationToken).ConfigureAwait(false); + var uncachedResults = await base.GenerateAsync(uncached.Select(e => e.Input), options, cancellationToken); // Store the resulting embeddings into the cache individually. for (int i = 0; i < uncachedResults.Count; i++) { - await WriteCacheAsync(uncached[i].CacheKey, uncachedResults[i], cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(uncached[i].CacheKey, uncachedResults[i], cancellationToken); } // Fill in the gaps with the newly generated results. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs index 8332064f22a..7d7ef140af7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs @@ -46,7 +46,7 @@ public override async Task> GenerateAsync( EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) { - return await base.GenerateAsync(values, Configure(options), cancellationToken).ConfigureAwait(false); + return await base.GenerateAsync(values, Configure(options), cancellationToken); } /// Creates and configures the to pass along to the inner client. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs index d6c20ffb2f5..cd26879d040 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -57,7 +57,7 @@ public JsonSerializerOptions JsonSerializerOptions _ = Throw.IfNull(key); _jsonSerializerOptions.MakeReadOnly(); - if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson) { return JsonSerializer.Deserialize(existingJson, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); } @@ -73,7 +73,7 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc _jsonSerializerOptions.MakeReadOnly(); var newJson = JsonSerializer.SerializeToUtf8Bytes(value, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); - await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + await _storage.SetAsync(key, newJson, cancellationToken); } /// Computes a cache key for the specified values. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs index 84d4815cb23..751a5edd443 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs index 90553ca5411..924ee362633 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -62,7 +61,7 @@ public override async Task> GenerateAsync(IEnume try { - var embeddings = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + var embeddings = await base.GenerateAsync(values, options, cancellationToken); LogCompleted(embeddings.Count); diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs index f6983408b85..14332d1253f 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -104,7 +104,7 @@ public override async Task> GenerateAsync(IEnume Exception? error = null; try { - response = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + response = await base.GenerateAsync(values, options, cancellationToken); } catch (Exception ex) { diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index 6537f3aa3ab..41550ba0451 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -303,7 +303,7 @@ private ReflectionAIFunction( } return await FunctionDescriptor.ReturnParameterMarshaller( - ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken).ConfigureAwait(false); + ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken); } finally { @@ -311,7 +311,7 @@ private ReflectionAIFunction( { if (target is IAsyncDisposable ad) { - await ad.DisposeAsync().ConfigureAwait(false); + await ad.DisposeAsync(); } else if (target is IDisposable d) { @@ -599,14 +599,14 @@ static bool IsAsyncMethod(MethodInfo method) { return async (result, cancellationToken) => { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); - return await marshalResult(null, null, cancellationToken).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(result)); + return await marshalResult(null, null, cancellationToken); }; } return async static (result, _) => { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(result)); return null; }; } @@ -618,14 +618,14 @@ static bool IsAsyncMethod(MethodInfo method) { return async (result, cancellationToken) => { - await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); - return await marshalResult(null, null, cancellationToken).ConfigureAwait(false); + await ((ValueTask)ThrowIfNullResult(result)); + return await marshalResult(null, null, cancellationToken); }; } return async static (result, _) => { - await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); + await ((ValueTask)ThrowIfNullResult(result)); return null; }; } @@ -640,18 +640,18 @@ static bool IsAsyncMethod(MethodInfo method) { return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(taskObj)); object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false); + return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken); }; } returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(taskObj)); object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken); }; } @@ -666,9 +666,9 @@ static bool IsAsyncMethod(MethodInfo method) return async (taskObj, cancellationToken) => { var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task.ConfigureAwait(false); + await task; object? result = ReflectionInvoke(asTaskResultGetter, task, null); - return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false); + return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken); }; } @@ -676,9 +676,9 @@ static bool IsAsyncMethod(MethodInfo method) return async (taskObj, cancellationToken) => { var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task.ConfigureAwait(false); + await task; object? result = ReflectionInvoke(asTaskResultGetter, task, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken); }; } } @@ -702,7 +702,7 @@ static bool IsAsyncMethod(MethodInfo method) // Serialize asynchronously to support potential IAsyncEnumerable responses. using PooledMemoryStream stream = new(); - await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken); Utf8JsonReader reader = new(stream.GetBuffer()); return JsonElement.ParseValue(ref reader); } diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index c851ccfb846..3b621827213 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -16,6 +16,16 @@ $(TargetFrameworks);netstandard2.0 $(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253 + + + $(NoWarn);CA2007 + true true diff --git a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs index 85833a3c171..1601b3c5073 100644 --- a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs @@ -40,14 +40,14 @@ public ConfigureOptionsSpeechToTextClient(ISpeechToTextClient innerClient, Actio public override async Task GetTextAsync( Stream audioSpeechStream, SpeechToTextOptions? options = null, CancellationToken cancellationToken = default) { - return await base.GetTextAsync(audioSpeechStream, Configure(options), cancellationToken).ConfigureAwait(false); + return await base.GetTextAsync(audioSpeechStream, Configure(options), cancellationToken); } /// public override async IAsyncEnumerable GetStreamingTextAsync( Stream audioSpeechStream, SpeechToTextOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (var update in base.GetStreamingTextAsync(audioSpeechStream, Configure(options), cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingTextAsync(audioSpeechStream, Configure(options), cancellationToken)) { yield return update; } diff --git a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs index 4494d319dc0..6c5bf0ed929 100644 --- a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs @@ -63,7 +63,7 @@ public override async Task GetTextAsync( try { - var response = await base.GetTextAsync(audioSpeechStream, options, cancellationToken).ConfigureAwait(false); + var response = await base.GetTextAsync(audioSpeechStream, options, cancellationToken); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -130,7 +130,7 @@ public override async IAsyncEnumerable GetStreamingT { try { - if (!await e.MoveNextAsync().ConfigureAwait(false)) + if (!await e.MoveNextAsync()) { break; } @@ -167,7 +167,7 @@ public override async IAsyncEnumerable GetStreamingT } finally { - await e.DisposeAsync().ConfigureAwait(false); + await e.DisposeAsync(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs index 29569c55207..650282949f8 100644 --- a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 30332cb3e3c..67b2025b7de 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -122,15 +122,22 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn [Fact] public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() { - using var barrier = new Barrier(2); + int remaining = 2; + var tcs = new TaskCompletionSource(); var options = new ChatOptions { Tools = [ - AIFunctionFactory.Create((string arg) => + AIFunctionFactory.Create(async (string arg) => { - barrier.SignalAndWait(); + if (Interlocked.Decrement(ref remaining) == 0) + { + tcs.SetResult(true); + } + + await tcs.Task; + return arg + arg; }, "Func"), ] @@ -867,6 +874,62 @@ public async Task FunctionInvocations_PassesServices() await InvokeAndAssertAsync(options, plan, services: expected); } + [Fact] + public async Task FunctionInvocations_InvokedOnOriginalSynchronizationContext() + { + SynchronizationContext ctx = new CustomSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(ctx); + + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg"] = "value1" }), + new FunctionCallContent("callId2", "Func1", new Dictionary { ["arg"] = "value2" }), + ]), + new ChatMessage(ChatRole.Tool, + [ + new FunctionResultContent("callId2", result: "value1"), + new FunctionResultContent("callId2", result: "value2") + ]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(async (string arg, CancellationToken cancellationToken) => + { + await Task.Delay(1, cancellationToken); + Assert.Same(ctx, SynchronizationContext.Current); + return arg; + }, "Func1")] + }; + + Func configurePipeline = builder => builder + .Use(async (messages, options, next, cancellationToken) => + { + await Task.Delay(1, cancellationToken); + await next(messages, options, cancellationToken); + }) + .UseOpenTelemetry() + .UseFunctionInvocation(configure: c => { c.AllowConcurrentInvocation = true; c.IncludeDetailedErrors = true; }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configurePipeline); + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configurePipeline); + } + + private sealed class CustomSynchronizationContext : SynchronizationContext + { + public override void Post(SendOrPostCallback d, object? state) + { + ThreadPool.QueueUserWorkItem(delegate + { + SetSynchronizationContext(this); + d(state); + }); + } + } + private static async Task> InvokeAndAssertAsync( ChatOptions options, List plan,