diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs
index fa2a1df4fbe..944ce0995f8 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
@@ -18,7 +19,7 @@ public static class EmbeddingGeneratorExtensions
/// The embedding generation options to configure the request.
/// The to monitor for cancellation requests. The default is .
/// The generated embedding for the specified .
- public static Task> GenerateAsync(
+ public static async Task GenerateAsync(
this IEmbeddingGenerator generator,
TValue value,
EmbeddingGenerationOptions? options = null,
@@ -28,6 +29,12 @@ public static Task> GenerateAsync>>([result])
};
- Assert.Same(result, (await service.GenerateAsync("hello"))[0]);
+ Assert.Same(result, await service.GenerateAsync("hello"));
}
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs
index 29502f926c6..1929869c487 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs
@@ -44,7 +44,7 @@ public virtual async Task GenerateEmbedding_CreatesEmbeddingSuccessfully()
{
SkipIfNotEnabled();
- var embeddings = await _embeddingGenerator.GenerateAsync("Using AI with .NET");
+ var embeddings = await _embeddingGenerator.GenerateAsync(["Using AI with .NET"]);
Assert.NotNull(embeddings.Usage);
Assert.NotNull(embeddings.Usage.InputTokenCount);
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs
index 2b4370222c6..9a5086a146d 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs
@@ -44,17 +44,15 @@ public async Task CachesSuccessResultsAsync()
// Make the initial request and do a quick sanity check
var result1 = await outer.GenerateAsync("abc");
- Assert.Single(result1);
- AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
+ AssertEmbeddingsEqual(_expectedEmbedding, result1);
Assert.Equal(1, innerCallCount);
// Act
var result2 = await outer.GenerateAsync("abc");
// Assert
- Assert.Single(result2);
Assert.Equal(1, innerCallCount);
- AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
+ AssertEmbeddingsEqual(_expectedEmbedding, result2);
// Act/Assert 2: Cache misses do not return cached results
await outer.GenerateAsync(["def"]);
@@ -144,13 +142,13 @@ public async Task AllowsConcurrentCallsAsync()
Assert.False(result1.IsCompleted);
Assert.False(result2.IsCompleted);
completionTcs.SetResult(true);
- AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]);
- AssertEmbeddingsEqual(_expectedEmbedding, (await result2)[0]);
+ AssertEmbeddingsEqual(_expectedEmbedding, await result1);
+ AssertEmbeddingsEqual(_expectedEmbedding, await result2);
// Act 2: Subsequent calls after completion are resolved from the cache
var result3 = await outer.GenerateAsync("abc");
Assert.Equal(2, innerCallCount);
- AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]);
+ AssertEmbeddingsEqual(_expectedEmbedding, await result1);
}
[Fact]
@@ -218,9 +216,8 @@ public async Task DoesNotCacheCanceledResultsAsync()
// Act/Assert: Second call can succeed
var result2 = await outer.GenerateAsync("abc");
- Assert.Single(result2);
Assert.Equal(2, innerCallCount);
- AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
+ AssertEmbeddingsEqual(_expectedEmbedding, result2);
}
[Fact]
@@ -254,11 +251,9 @@ public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync()
});
// Assert: Same result
- Assert.Single(result1);
- Assert.Single(result2);
Assert.Equal(1, innerCallCount);
- AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
- AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
+ AssertEmbeddingsEqual(_expectedEmbedding, result1);
+ AssertEmbeddingsEqual(_expectedEmbedding, result2);
}
[Fact]
@@ -292,11 +287,9 @@ public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
});
// Assert: Different results
- Assert.Single(result1);
- Assert.Single(result2);
Assert.Equal(2, innerCallCount);
- AssertEmbeddingsEqual(_expectedEmbedding, result1[0]);
- AssertEmbeddingsEqual(_expectedEmbedding, result2[0]);
+ AssertEmbeddingsEqual(_expectedEmbedding, result1);
+ AssertEmbeddingsEqual(_expectedEmbedding, result2);
}
[Fact]