diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientMetadata.cs index ace4dead9e3..027817ddff8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientMetadata.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientMetadata.cs @@ -14,10 +14,10 @@ public class ChatClientMetadata /// appropriate name defined in the OpenTelemetry Semantic Conventions for Generative AI systems. /// /// The URL for accessing the chat provider, if applicable. - /// The ID of the chat model used, if applicable. - public ChatClientMetadata(string? providerName = null, Uri? providerUri = null, string? modelId = null) + /// The ID of the chat model used by default, if applicable. + public ChatClientMetadata(string? providerName = null, Uri? providerUri = null, string? defaultModelId = null) { - ModelId = modelId; + DefaultModelId = defaultModelId; ProviderName = providerName; ProviderUri = providerUri; } @@ -32,10 +32,10 @@ public ChatClientMetadata(string? providerName = null, Uri? providerUri = null, /// Gets the URL for accessing the chat provider. public Uri? ProviderUri { get; } - /// Gets the ID of the model used by this chat provider. + /// Gets the ID of the default model used by this chat client. /// - /// This value can be null if either the name is unknown or there are multiple possible models associated with this instance. + /// This value can be null if no default model is set on the corresponding . /// An individual request may override this value via . /// - public string? ModelId { get; } + public string? 
DefaultModelId { get; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorMetadata.cs index a3f5181648b..5f5f9d8c5c2 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorMetadata.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorMetadata.cs @@ -9,20 +9,19 @@ namespace Microsoft.Extensions.AI; public class EmbeddingGeneratorMetadata { /// Initializes a new instance of the class. - /// /// The name of the embedding generation provider, if applicable. Where possible, this should map to the /// appropriate name defined in the OpenTelemetry Semantic Conventions for Generative AI systems. /// /// The URL for accessing the embedding generation provider, if applicable. - /// The ID of the embedding generation model used, if applicable. - /// The number of dimensions in vectors produced by this generator, if applicable. - public EmbeddingGeneratorMetadata(string? providerName = null, Uri? providerUri = null, string? modelId = null, int? dimensions = null) + /// The ID of the default embedding generation model used, if applicable. + /// The number of dimensions in vectors produced by the default model, if applicable. + public EmbeddingGeneratorMetadata(string? providerName = null, Uri? providerUri = null, string? defaultModelId = null, int? defaultModelDimensions = null) { - ModelId = modelId; + DefaultModelId = defaultModelId; ProviderName = providerName; ProviderUri = providerUri; - Dimensions = dimensions; + DefaultModelDimensions = defaultModelDimensions; } /// Gets the name of the embedding generation provider. @@ -35,17 +34,17 @@ public EmbeddingGeneratorMetadata(string? providerName = null, Uri? providerUri /// Gets the URL for accessing the embedding generation provider. public Uri? 
ProviderUri { get; } - /// Gets the ID of the model used by this embedding generation provider. + /// Gets the ID of the default model used by this embedding generator. /// - /// This value can be null if either the name is unknown or there are multiple possible models associated with this instance. + /// This value can be null if no default model is set on the corresponding embedding generator. /// An individual request may override this value via . /// - public string? ModelId { get; } + public string? DefaultModelId { get; } - /// Gets the number of dimensions in the embeddings produced by this instance. + /// Gets the number of dimensions in the embeddings produced by the default model. /// - /// This value can be null if either the number of dimensions is unknown or there are multiple possible lengths associated with this instance. + /// This value can be null if either the number of dimensions is unknown or there are multiple possible lengths associated with this model. /// An individual request may override this value via . /// - public int? Dimensions { get; } + public int? 
DefaultModelDimensions { get; } } diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs index 721af6e0d1c..c4ccafe098e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -59,7 +59,7 @@ public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, s var providerUrl = typeof(ChatCompletionsClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) ?.GetValue(chatCompletionsClient) as Uri; - _metadata = new("az.ai.inference", providerUrl, modelId); + _metadata = new ChatClientMetadata("az.ai.inference", providerUrl, modelId); } /// Gets or sets to use for any serialization activities related to tool call arguments and results. @@ -288,7 +288,7 @@ private ChatCompletionsOptions ToAzureAIOptions(IEnumerable chatCon { ChatCompletionsOptions result = new(ToAzureAIInferenceChatMessages(chatContents)) { - Model = options?.ModelId ?? _metadata.ModelId ?? throw new InvalidOperationException("No model id was provided when either constructing the client or in the chat options.") + Model = options?.ModelId ?? _metadata.DefaultModelId ?? 
throw new InvalidOperationException("No model id was provided when either constructing the client or in the chat options.") }; if (options is not null) diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs index 5cadc200869..b3fadc7bf54 100644 --- a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs @@ -69,7 +69,7 @@ public AzureAIInferenceEmbeddingGenerator( var providerUrl = typeof(EmbeddingsClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) ?.GetValue(embeddingsClient) as Uri; - _metadata = new("az.ai.inference", providerUrl, modelId, dimensions); + _metadata = new EmbeddingGeneratorMetadata("az.ai.inference", providerUrl, modelId, dimensions); } /// @@ -167,7 +167,7 @@ private EmbeddingsOptions ToAzureAIOptions(IEnumerable inputs, Embedding EmbeddingsOptions result = new(inputs) { Dimensions = options?.Dimensions ?? _dimensions, - Model = options?.ModelId ?? _metadata.ModelId, + Model = options?.ModelId ?? _metadata.DefaultModelId, EncodingFormat = format, }; diff --git a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs index ba8e0361c6e..a2e3341c40e 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ReportingConfiguration.cs @@ -229,7 +229,7 @@ private static IEnumerable GetCachingKeysForChatClient(IChatClient chatC yield return providerUri.AbsoluteUri; } - string? modelId = metadata?.ModelId; + string? 
modelId = metadata?.DefaultModelId; if (!string.IsNullOrWhiteSpace(modelId)) { yield return modelId!; diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs index 62e8c383f4e..f42f1e1edfb 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -69,7 +69,7 @@ public OllamaChatClient(Uri endpoint, string? modelId = null, HttpClient? httpCl _apiChatEndpoint = new Uri(endpoint, "api/chat"); _httpClient = httpClient ?? OllamaUtilities.SharedClient; - _metadata = new("ollama", endpoint, modelId); + _metadata = new ChatClientMetadata("ollama", endpoint, modelId); } /// Gets or sets to use for any serialization activities related to tool call arguments and results. @@ -111,7 +111,7 @@ public async Task GetResponseAsync( { CreatedAt = DateTimeOffset.TryParse(response.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, FinishReason = ToFinishReason(response), - ModelId = response.Model ?? options?.ModelId ?? _metadata.ModelId, + ModelId = response.Model ?? options?.ModelId ?? _metadata.DefaultModelId, ResponseId = responseId, Usage = ParseOllamaChatResponseUsage(response), }; @@ -158,7 +158,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( continue; } - string? modelId = chunk.Model ?? _metadata.ModelId; + string? modelId = chunk.Model ?? _metadata.DefaultModelId; ChatResponseUpdate update = new() { @@ -306,7 +306,7 @@ private OllamaChatRequest ToOllamaChatRequest(IEnumerable messages, { Format = ToOllamaChatResponseFormat(options?.ResponseFormat), Messages = messages.SelectMany(ToOllamaChatRequestMessages).ToArray(), - Model = options?.ModelId ?? _metadata.ModelId ?? string.Empty, + Model = options?.ModelId ?? _metadata.DefaultModelId ?? 
string.Empty, Stream = stream, Tools = options?.ToolMode is not NoneChatToolMode && options?.Tools is { Count: > 0 } tools ? tools.OfType().Select(ToOllamaTool) : null, }; diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs index 0b63491ddc2..0b0d4d3b344 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs @@ -89,7 +89,7 @@ public async Task>> GenerateAsync( // Create request. string[] inputs = values.ToArray(); - string? requestModel = options?.ModelId ?? _metadata.ModelId; + string? requestModel = options?.ModelId ?? _metadata.DefaultModelId; var request = new OllamaEmbeddingRequest { Model = requestModel ?? string.Empty, diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs index 7852a87c2e1..c85b202f7e8 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -56,7 +56,7 @@ public OpenAIChatClient(OpenAIClient openAIClient, string modelId) Uri providerUrl = typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) ?.GetValue(openAIClient) as Uri ?? DefaultOpenAIEndpoint; - _metadata = new("openai", providerUrl, modelId); + _metadata = new ChatClientMetadata("openai", providerUrl, modelId); } /// Initializes a new instance of the class for the specified . 
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 59bd70eefc6..7ad8ea1d279 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -31,10 +31,10 @@ public static class ChatClientStructuredOutputExtensions /// The . /// The chat content to send. /// The chat options to configure the request. - /// + /// /// Optionally specifies whether to set a JSON schema on the . /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. - /// If not specified, the underlying provider's default will be used. + /// If not specified, the default value is . /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. @@ -43,18 +43,19 @@ public static Task> GetResponseAsync( this IChatClient chatClient, IEnumerable messages, ChatOptions? options = null, - bool? useNativeJsonSchema = null, + bool? useJsonSchema = null, CancellationToken cancellationToken = default) => - GetResponseAsync(chatClient, messages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken); + GetResponseAsync(chatClient, messages, AIJsonUtilities.DefaultOptions, options, useJsonSchema, cancellationToken); /// Sends a user chat text message, requesting a response matching the type . /// The . /// The text content for the chat message to send. /// The chat options to configure the request. - /// + /// /// Optionally specifies whether to set a JSON schema on the . /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. 
- /// If not specified, the underlying provider's default will be used. + /// If not specified, the default value is determined by the implementation. + /// If a specific value is required, it must be specified by the caller. /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. @@ -63,18 +64,18 @@ public static Task> GetResponseAsync( this IChatClient chatClient, string chatMessage, ChatOptions? options = null, - bool? useNativeJsonSchema = null, + bool? useJsonSchema = null, CancellationToken cancellationToken = default) => - GetResponseAsync(chatClient, new ChatMessage(ChatRole.User, chatMessage), options, useNativeJsonSchema, cancellationToken); + GetResponseAsync(chatClient, new ChatMessage(ChatRole.User, chatMessage), options, useJsonSchema, cancellationToken); /// Sends a chat message, requesting a response matching the type . /// The . /// The chat message to send. /// The chat options to configure the request. - /// + /// /// Optionally specifies whether to set a JSON schema on the . /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. - /// If not specified, the underlying provider's default will be used. + /// If not specified, the default value is . /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. @@ -83,19 +84,19 @@ public static Task> GetResponseAsync( this IChatClient chatClient, ChatMessage chatMessage, ChatOptions? options = null, - bool? useNativeJsonSchema = null, + bool? useJsonSchema = null, CancellationToken cancellationToken = default) => - GetResponseAsync(chatClient, [chatMessage], options, useNativeJsonSchema, cancellationToken); + GetResponseAsync(chatClient, [chatMessage], options, useJsonSchema, cancellationToken); /// Sends a user chat text message, requesting a response matching the type . /// The . 
/// The text content for the chat message to send. /// The JSON serialization options to use. /// The chat options to configure the request. - /// + /// /// Optionally specifies whether to set a JSON schema on the . /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. - /// If not specified, the underlying provider's default will be used. + /// If not specified, the default value is . /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. @@ -105,19 +106,19 @@ public static Task> GetResponseAsync( string chatMessage, JsonSerializerOptions serializerOptions, ChatOptions? options = null, - bool? useNativeJsonSchema = null, + bool? useJsonSchema = null, CancellationToken cancellationToken = default) => - GetResponseAsync(chatClient, new ChatMessage(ChatRole.User, chatMessage), serializerOptions, options, useNativeJsonSchema, cancellationToken); + GetResponseAsync(chatClient, new ChatMessage(ChatRole.User, chatMessage), serializerOptions, options, useJsonSchema, cancellationToken); /// Sends a chat message, requesting a response matching the type . /// The . /// The chat message to send. /// The JSON serialization options to use. /// The chat options to configure the request. - /// + /// /// Optionally specifies whether to set a JSON schema on the . /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. - /// If not specified, the underlying provider's default will be used. + /// If not specified, the default value is . /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. @@ -127,19 +128,19 @@ public static Task> GetResponseAsync( ChatMessage chatMessage, JsonSerializerOptions serializerOptions, ChatOptions? options = null, - bool? 
useNativeJsonSchema = null, + bool? useJsonSchema = null, CancellationToken cancellationToken = default) => - GetResponseAsync(chatClient, [chatMessage], serializerOptions, options, useNativeJsonSchema, cancellationToken); + GetResponseAsync(chatClient, [chatMessage], serializerOptions, options, useJsonSchema, cancellationToken); /// Sends chat messages, requesting a response matching the type . /// The . /// The chat content to send. /// The JSON serialization options to use. /// The chat options to configure the request. - /// + /// /// Optionally specifies whether to set a JSON schema on the . /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. - /// If not specified, the underlying provider's default will be used. + /// If not specified, the default value is . /// /// The to monitor for cancellation requests. The default is . /// The response messages generated by the client. @@ -152,7 +153,7 @@ public static async Task> GetResponseAsync( IEnumerable messages, JsonSerializerOptions serializerOptions, ChatOptions? options = null, - bool? useNativeJsonSchema = null, + bool? useJsonSchema = null, CancellationToken cancellationToken = default) { _ = Throw.IfNull(chatClient); @@ -186,16 +187,17 @@ public static async Task> GetResponseAsync( { "type", "object" }, { "properties", new JsonObject { { "data", JsonElementToJsonNode(schemaElement) } } }, { "additionalProperties", false }, + { "required", new JsonArray("data") }, }, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonObject))); } ChatMessage? promptAugmentation = null; options = options is not null ? options.Clone() : new(); - // Currently there's no way for the inner IChatClient to specify whether structured output - // is supported, so we always default to false. In the future, some mechanism of declaring - // capabilities may be added (e.g., on ChatClientMetadata). 
- if (useNativeJsonSchema.GetValueOrDefault(false)) + // We default to assuming that models support JSON schema because developers will normally use + // GetResponseAsync only with models that do. If the model doesn't support JSON schema, it may + // throw or it may ignore the schema. In these cases developers should pass useJsonSchema: false. + if (useJsonSchema.GetValueOrDefault(true)) { // When using native structured output, we don't add any additional prompt, because // the LLM backend is meant to do whatever's needed to explain the schema to the LLM. @@ -208,7 +210,7 @@ public static async Task> GetResponseAsync( { options.ResponseFormat = ChatResponseFormat.Json; - // When not using native structured output, augment the chat messages with a schema prompt + // When not using native JSON schema, augment the chat messages with a schema prompt promptAugmentation = new ChatMessage(ChatRole.User, $$""" Respond with a JSON value conforming to the following schema: ``` diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index d9bdb479a44..5dc215e151c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -37,7 +37,7 @@ public sealed partial class OpenTelemetryChatClient : DelegatingChatClient private readonly Histogram _tokenUsageHistogram; private readonly Histogram _operationDurationHistogram; - private readonly string? _modelId; + private readonly string? _defaultModelId; private readonly string? _system; private readonly string? _serverAddress; private readonly int _serverPort; @@ -57,7 +57,7 @@ public OpenTelemetryChatClient(IChatClient innerClient, ILogger? 
logger = null, if (innerClient!.GetService() is ChatClientMetadata metadata) { - _modelId = metadata.ModelId; + _defaultModelId = metadata.DefaultModelId; _system = metadata.ProviderName; _serverAddress = metadata.ProviderUri?.GetLeftPart(UriPartial.Path); _serverPort = metadata.ProviderUri?.Port ?? 0; @@ -129,7 +129,7 @@ public override async Task GetResponseAsync( using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; - string? requestModelId = options?.ModelId ?? _modelId; + string? requestModelId = options?.ModelId ?? _defaultModelId; LogChatMessages(messages); @@ -160,7 +160,7 @@ public override async IAsyncEnumerable GetStreamingResponseA using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; - string? requestModelId = options?.ModelId ?? _modelId; + string? requestModelId = options?.ModelId ?? _defaultModelId; LogChatMessages(messages); @@ -217,7 +217,7 @@ public override async IAsyncEnumerable GetStreamingResponseA Activity? activity = null; if (_activitySource.HasListeners()) { - string? modelId = options?.ModelId ?? _modelId; + string? modelId = options?.ModelId ?? _defaultModelId; activity = _activitySource.StartActivity( string.IsNullOrWhiteSpace(modelId) ? 
OpenTelemetryConsts.GenAI.Chat : $"{OpenTelemetryConsts.GenAI.Chat} {modelId}", diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs index cfcd032bbd2..cf896a472c1 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -31,11 +31,11 @@ public sealed class OpenTelemetryEmbeddingGenerator : Delega private readonly Histogram _operationDurationHistogram; private readonly string? _system; - private readonly string? _modelId; + private readonly string? _defaultModelId; + private readonly int? _defaultModelDimensions; private readonly string? _modelProvider; private readonly string? _endpointAddress; private readonly int _endpointPort; - private readonly int? _dimensions; /// /// Initializes a new instance of the class. @@ -53,11 +53,11 @@ public OpenTelemetryEmbeddingGenerator(IEmbeddingGenerator i if (innerGenerator!.GetService() is EmbeddingGeneratorMetadata metadata) { _system = metadata.ProviderName; - _modelId = metadata.ModelId; + _defaultModelId = metadata.DefaultModelId; + _defaultModelDimensions = metadata.DefaultModelDimensions; _modelProvider = metadata.ProviderName; _endpointAddress = metadata.ProviderUri?.GetLeftPart(UriPartial.Path); _endpointPort = metadata.ProviderUri?.Port ?? 0; - _dimensions = metadata.Dimensions; } string name = string.IsNullOrEmpty(sourceName) ? OpenTelemetryConsts.DefaultSourceName : sourceName!; @@ -89,7 +89,7 @@ public override async Task> GenerateAsync(IEnume using Activity? activity = CreateAndConfigureActivity(options); Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; - string? requestModelId = options?.ModelId ?? _modelId; + string? requestModelId = options?.ModelId ?? _defaultModelId; GeneratedEmbeddings? 
response = null; Exception? error = null; @@ -128,7 +128,7 @@ protected override void Dispose(bool disposing) Activity? activity = null; if (_activitySource.HasListeners()) { - string? modelId = options?.ModelId ?? _modelId; + string? modelId = options?.ModelId ?? _defaultModelId; activity = _activitySource.StartActivity( string.IsNullOrWhiteSpace(modelId) ? OpenTelemetryConsts.GenAI.Embeddings : $"{OpenTelemetryConsts.GenAI.Embeddings} {modelId}", @@ -149,9 +149,9 @@ protected override void Dispose(bool disposing) .AddTag(OpenTelemetryConsts.Server.Port, _endpointPort); } - if (_dimensions is int dimensions) + if ((options?.Dimensions ?? _defaultModelDimensions) is int dimensionsValue) { - _ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.EmbeddingDimensions, dimensions); + _ = activity.AddTag(OpenTelemetryConsts.GenAI.Request.EmbeddingDimensions, dimensionsValue); } if (options is not null && diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs index 43e24e61f8e..68c61bfc32d 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs @@ -11,19 +11,19 @@ public class ChatClientMetadataTests [Fact] public void Constructor_NullValues_AllowedAndRoundtrip() { - ChatClientMetadata metadata = new(null, null, null); - Assert.Null(metadata.ProviderName); - Assert.Null(metadata.ProviderUri); - Assert.Null(metadata.ModelId); + ChatClientMetadata providerMetadata = new(null, null, null); + Assert.Null(providerMetadata.ProviderName); + Assert.Null(providerMetadata.ProviderUri); + Assert.Null(providerMetadata.DefaultModelId); } [Fact] public void Constructor_Value_Roundtrips() { var uri = new Uri("https://example.com"); - ChatClientMetadata metadata = 
new("providerName", uri, "theModel"); - Assert.Equal("providerName", metadata.ProviderName); - Assert.Same(uri, metadata.ProviderUri); - Assert.Equal("theModel", metadata.ModelId); + ChatClientMetadata providerMetadata = new("providerName", uri, "theModel"); + Assert.Equal("providerName", providerMetadata.ProviderName); + Assert.Same(uri, providerMetadata.ProviderUri); + Assert.Equal("theModel", providerMetadata.DefaultModelId); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs index b3cd0d59abb..7905951c3f1 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs @@ -11,21 +11,21 @@ public class EmbeddingGeneratorMetadataTests [Fact] public void Constructor_NullValues_AllowedAndRoundtrip() { - EmbeddingGeneratorMetadata metadata = new(null, null, null, null); - Assert.Null(metadata.ProviderName); - Assert.Null(metadata.ProviderUri); - Assert.Null(metadata.ModelId); - Assert.Null(metadata.Dimensions); + EmbeddingGeneratorMetadata providerMetadata = new(null, null, null); + Assert.Null(providerMetadata.ProviderName); + Assert.Null(providerMetadata.ProviderUri); + Assert.Null(providerMetadata.DefaultModelId); + Assert.Null(providerMetadata.DefaultModelDimensions); } [Fact] public void Constructor_Value_Roundtrips() { var uri = new Uri("https://example.com"); - EmbeddingGeneratorMetadata metadata = new("providerName", uri, "theModel", 42); + EmbeddingGeneratorMetadata metadata = new EmbeddingGeneratorMetadata("providerName", uri, "theModel", 42); Assert.Equal("providerName", metadata.ProviderName); Assert.Same(uri, metadata.ProviderUri); - Assert.Equal("theModel", metadata.ModelId); - Assert.Equal(42, metadata.Dimensions); + 
Assert.Equal("theModel", metadata.DefaultModelId); + Assert.Equal(42, metadata.DefaultModelDimensions); } } diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs index 834d1092bdc..92918ce1c51 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -66,7 +66,7 @@ public void AsChatClient_ProducesExpectedMetadata() var metadata = chatClient.GetService(); Assert.Equal("az.ai.inference", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceEmbeddingGeneratorTests.cs index c092bc87ced..aacffe591b8 100644 --- a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceEmbeddingGeneratorTests.cs @@ -52,7 +52,7 @@ public void AsEmbeddingGenerator_OpenAIClient_ProducesExpectedMetadata() var metadata = embeddingGenerator.GetService(); Assert.Equal("az.ai.inference", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs index 236a8428733..df84f2c9d9b 100644 --- 
a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs @@ -2,8 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Threading.Tasks; using Azure; using Azure.AI.Inference; +using Azure.Core; +using Azure.Core.Pipeline; namespace Microsoft.Extensions.AI; @@ -21,12 +24,49 @@ internal static class IntegrationTestHelpers /// Gets a to use for testing, or null if the associated tests should be disabled. public static ChatCompletionsClient? GetChatCompletionsClient() => _apiKey is string apiKey ? - new ChatCompletionsClient(new Uri(_endpoint), new AzureKeyCredential(apiKey)) : + new ChatCompletionsClient(new Uri(_endpoint), new AzureKeyCredential(apiKey), CreateOptions()) : null; /// Gets an to use for testing, or null if the associated tests should be disabled. public static EmbeddingsClient? GetEmbeddingsClient() => _apiKey is string apiKey ? - new EmbeddingsClient(new Uri(_endpoint), new AzureKeyCredential(apiKey)) : + new EmbeddingsClient(new Uri(_endpoint), new AzureKeyCredential(apiKey), CreateOptions()) : null; + + private static AzureAIInferenceClientOptions CreateOptions() + { + var result = new AzureAIInferenceClientOptions(); + + // The API version set here corresponds to the value used by AzureOpenAIClientOptions + // if the AZURE_OPENAI_GA flag is set during its compilation. This API version is the + // minimum required for structured output with JSON schema.
+ result.AddPolicy(new OverrideApiVersionPolicy("2024-08-01-preview"), HttpPipelinePosition.PerCall); + + return result; + } + + // From https://github.com/Azure/azure-sdk-for-net/issues/48405#issuecomment-2704360548 + private class OverrideApiVersionPolicy : HttpPipelinePolicy + { + private string ApiVersion { get; } + + public OverrideApiVersionPolicy(string apiVersion) + { + ApiVersion = apiVersion; + } + + public override void Process(HttpMessage message, ReadOnlyMemory pipeline) + { + message.Request.Uri.Query = $"?api-version={ApiVersion}"; + ProcessNext(message, pipeline); + } + + public override ValueTask ProcessAsync(HttpMessage message, ReadOnlyMemory pipeline) + { + message.Request.Uri.Query = $"?api-version={ApiVersion}"; + var task = ProcessNextAsync(message, pipeline); + + return task; + } + } } diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index a5533c0ae63..57b7a224b98 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -731,7 +731,7 @@ public virtual async Task GetResponseAsync_StructuredOutputBool_False() var response = await _chatClient.GetResponseAsync(""" Jimbo Smith is a 35-year-old software developer from Cardiff, Wales. - Can we be sure that he is a medical doctor? + Reply true if the previous statement indicates that he is a medical doctor, otherwise false. 
"""); Assert.False(response.Result); @@ -781,30 +781,31 @@ public virtual async Task GetResponseAsync_StructuredOutput_WithFunctions() } [ConditionalFact] - public virtual async Task GetResponseAsync_StructuredOutput_Native() + public virtual async Task GetResponseAsync_StructuredOutput_NonNative() { SkipIfNotEnabled(); - var capturedCalls = new List>(); + var capturedOptions = new List(); var captureOutputChatClient = _chatClient.AsBuilder() .Use((messages, options, nextAsync, cancellationToken) => { - capturedCalls.Add([.. messages]); + capturedOptions.Add(options); return nextAsync(messages, options, cancellationToken); }) .Build(); var response = await captureOutputChatClient.GetResponseAsync(""" - Supply a JSON object to represent Jimbo Smith from Cardiff. - """, useNativeJsonSchema: true); + Supply an object to represent Jimbo Smith from Cardiff. + """, useJsonSchema: false); Assert.Equal("Jimbo Smith", response.Result.FullName); Assert.Contains("Cardiff", response.Result.HomeTown); - // Verify it used *native* structured output, i.e., no prompt augmentation - Assert.All( - Assert.Single(capturedCalls), - message => Assert.DoesNotContain("schema", message.Text)); + // Verify it used *non-native* structured output, i.e., response format Json with no schema + var responseFormat = Assert.IsType(Assert.Single(capturedOptions)!.ResponseFormat); + Assert.Null(responseFormat.Schema); + Assert.Null(responseFormat.SchemaName); + Assert.Null(responseFormat.SchemaDescription); } private class Person @@ -819,7 +820,7 @@ private class Person private enum JobType { - Surgeon, + Wombat, PopStar, Programmer, Unknown, diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs index bc5aaac88a7..b2a0d7a2c94 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs +++ 
b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs @@ -50,7 +50,7 @@ public virtual async Task GenerateEmbedding_CreatesEmbeddingSuccessfully() Assert.NotNull(embeddings.Usage.InputTokenCount); Assert.NotNull(embeddings.Usage.TotalTokenCount); Assert.Single(embeddings); - Assert.Equal(_embeddingGenerator.GetService()?.ModelId, embeddings[0].ModelId); + Assert.Equal(_embeddingGenerator.GetService()?.DefaultModelId, embeddings[0].ModelId); Assert.NotEmpty(embeddings[0].Vector.ToArray()); } @@ -71,7 +71,7 @@ public virtual async Task GenerateEmbeddings_CreatesEmbeddingsSuccessfully() Assert.NotNull(embeddings.Usage.TotalTokenCount); Assert.All(embeddings, embedding => { - Assert.Equal(_embeddingGenerator.GetService()?.ModelId, embedding.ModelId); + Assert.Equal(_embeddingGenerator.GetService()?.DefaultModelId, embedding.ModelId); Assert.NotEmpty(embedding.Vector.ToArray()); }); } diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs index 502979d336f..1f92fea6bc6 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -72,9 +72,10 @@ public void AsChatClient_ProducesExpectedMetadata() using IChatClient chatClient = new OllamaChatClient(endpoint, model); var metadata = chatClient.GetService(); - Assert.Equal("ollama", metadata?.ProviderName); - Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.NotNull(metadata); + Assert.Equal("ollama", metadata.ProviderName); + Assert.Equal(endpoint, metadata.ProviderUri); + Assert.Equal(model, metadata.DefaultModelId); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs 
index bc6d5500bd9..9ccdd79197c 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs @@ -53,7 +53,7 @@ public void AsEmbeddingGenerator_ProducesExpectedMetadata() var metadata = generator.GetService(); Assert.Equal("ollama", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs index 55946590b4e..fe9a5c019f1 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs @@ -75,13 +75,13 @@ public void AsChatClient_OpenAIClient_ProducesExpectedMetadata(bool useAzureOpen var metadata = chatClient.GetService(); Assert.Equal("openai", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); chatClient = client.GetChatClient(model).AsChatClient(); metadata = chatClient.GetService(); Assert.Equal("openai", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs index 3ceba5a9d00..72864576529 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs @@ -58,12 +58,12 @@ public void 
AsEmbeddingGenerator_OpenAIClient_ProducesExpectedMetadata(bool useA var metadata = embeddingGenerator.GetService(); Assert.Equal("openai", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); embeddingGenerator = client.GetEmbeddingClient(model).AsEmbeddingGenerator(); Assert.Equal("openai", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientTests.cs index 5f52f0da7de..5ca7fc7272f 100644 --- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIResponseClientTests.cs @@ -48,7 +48,7 @@ public void AsChatClient_ProducesExpectedMetadata(bool useAzureOpenAI) var metadata = chatClient.GetService(); Assert.Equal("openai", metadata?.ProviderName); Assert.Equal(endpoint, metadata?.ProviderUri); - Assert.Equal(model, metadata?.ModelId); + Assert.Equal(model, metadata?.DefaultModelId); } [Fact] diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs index 0a499ab644d..4477c3cdb26 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -4,9 +4,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; -using System.Linq; using System.Text.Json; -using System.Text.RegularExpressions; using System.Threading.Tasks; using Xunit; @@ 
-15,7 +13,89 @@ namespace Microsoft.Extensions.AI; public class ChatClientStructuredOutputExtensionsTests { [Fact] - public async Task SuccessUsage() + public async Task SuccessUsage_Default() + { + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var expectedResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))) + { + ResponseId = "test", + CreatedAt = DateTimeOffset.UtcNow, + ModelId = "someModel", + RawRepresentation = new object(), + Usage = new(), + }; + + using var client = new TestChatClient + { + GetResponseAsyncCallback = (messages, options, cancellationToken) => + { + var responseFormat = Assert.IsType(options!.ResponseFormat); + Assert.Equal(""" + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "description": "Some test description", + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "fullName": { + "type": [ + "string", + "null" + ] + }, + "species": { + "type": "string", + "enum": [ + "Bear", + "Tiger", + "Walrus" + ] + } + }, + "additionalProperties": false, + "required": [ + "id", + "fullName", + "species" + ] + } + """, responseFormat.Schema.ToString()); + Assert.Equal(nameof(Animal), responseFormat.SchemaName); + Assert.Equal("Some test description", responseFormat.SchemaDescription); + + // The inner client receives the prompt with no augmentation + Assert.Collection(messages, + message => Assert.Equal("Hello", message.Text)); + + return Task.FromResult(expectedResponse); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.GetResponseAsync(chatHistory); + + // The response contains the deserialized result and other response properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + Assert.Equal(expectedResponse.ResponseId, response.ResponseId); + 
Assert.Equal(expectedResponse.CreatedAt, response.CreatedAt); + Assert.Equal(expectedResponse.ModelId, response.ModelId); + Assert.Same(expectedResponse.RawRepresentation, response.RawRepresentation); + Assert.Same(expectedResponse.Usage, response.Usage); + + // TryGetResult returns the same value + Assert.True(response.TryGetResult(out var tryGetResultOutput)); + Assert.Same(response.Result, tryGetResultOutput); + + // Doesn't mutate history (or at least, reverts any changes) + Assert.Equal("Hello", Assert.Single(chatHistory).Text); + } + + [Fact] + public async Task SuccessUsage_NoJsonSchema() { var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; var expectedResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))) @@ -55,7 +135,7 @@ public async Task SuccessUsage() }; var chatHistory = new List { new(ChatRole.User, "Hello") }; - var response = await client.GetResponseAsync(chatHistory); + var response = await client.GetResponseAsync(chatHistory, useJsonSchema: false); // The response contains the deserialized result and other response properties Assert.Equal(1, response.Result.Id); @@ -85,8 +165,7 @@ public async Task WrapsNonObjectValuesInDataProperty() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - var suppliedSchemaMatch = Regex.Match(messages.Last().Text!, "```(.*?)```", RegexOptions.Singleline); - Assert.True(suppliedSchemaMatch.Success); + var responseFormat = Assert.IsType(options!.ResponseFormat); Assert.Equal(""" { "$schema": "https://json-schema.org/draft/2020-12/schema", @@ -97,9 +176,12 @@ public async Task WrapsNonObjectValuesInDataProperty() "type": "integer" } }, - "additionalProperties": false + "additionalProperties": false, + "required": [ + "data" + ] } - """, suppliedSchemaMatch.Groups[1].Value.Trim()); + """, responseFormat.Schema.ToString()); return Task.FromResult(expectedResponse); }, }; @@ -165,50 +247,6 @@ public async Task 
FailureUsage_NoJsonInResponse() Assert.Null(tryGetResult); } - [Fact] - public async Task CanUseNativeStructuredOutput() - { - var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; - var expectedResponse = new ChatResponse(new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))); - - using var client = new TestChatClient - { - GetResponseAsyncCallback = (messages, options, cancellationToken) => - { - var responseFormat = Assert.IsType(options!.ResponseFormat); - Assert.Equal(nameof(Animal), responseFormat.SchemaName); - Assert.Equal("Some test description", responseFormat.SchemaDescription); - - var responseFormatJsonSchema = JsonSerializer.Serialize(responseFormat.Schema, TestJsonSerializerContext.Default.JsonElement); - Assert.Contains("https://json-schema.org/draft/2020-12/schema", responseFormatJsonSchema); - foreach (Species v in Enum.GetValues(typeof(Species))) - { - Assert.Contains(v.ToString(), responseFormatJsonSchema); // All enum values are described as strings - } - - // The chat history isn't mutated any further, since native structured output is used instead of a prompt - Assert.Equal("Hello", Assert.Single(messages).Text); - - return Task.FromResult(expectedResponse); - }, - }; - - var chatHistory = new List { new(ChatRole.User, "Hello") }; - var response = await client.GetResponseAsync(chatHistory, useNativeJsonSchema: true); - - // The response contains the deserialized result and other response properties - Assert.Equal(1, response.Result.Id); - Assert.Equal("Tigger", response.Result.FullName); - Assert.Equal(Species.Tiger, response.Result.Species); - - // TryGetResult returns the same value - Assert.True(response.TryGetResult(out var tryGetResultOutput)); - Assert.Same(response.Result, tryGetResultOutput); - - // History remains unmutated - Assert.Equal("Hello", Assert.Single(chatHistory).Text); - } - [Fact] public async Task CanUseNativeStructuredOutputWithSanitizedTypeName() { @@ -228,7 
+266,7 @@ public async Task CanUseNativeStructuredOutputWithSanitizedTypeName() }; var chatHistory = new List { new(ChatRole.User, "Hello") }; - var response = await client.GetResponseAsync>(chatHistory, useNativeJsonSchema: true); + var response = await client.GetResponseAsync>(chatHistory); // The response contains the deserialized result and other response properties Assert.Equal(1, response.Result!.Value!.Id); @@ -256,7 +294,7 @@ public async Task CanUseNativeStructuredOutputWithArray() }; var chatHistory = new List { new(ChatRole.User, "Hello") }; - var response = await client.GetResponseAsync(chatHistory, useNativeJsonSchema: true); + var response = await client.GetResponseAsync(chatHistory); // The response contains the deserialized result and other response properties Assert.Single(response.Result!); @@ -285,17 +323,37 @@ public async Task CanSpecifyCustomJsonSerializationOptions() { GetResponseAsyncCallback = (messages, options, cancellationToken) => { - Assert.Collection(messages, - message => Assert.Equal("Hello", message.Text), - message => + // In the schema below, note that: + // - The property is named full_name, because we specified SnakeCaseLower + // - The species value is an integer instead of a string, because we didn't use enum-to-string conversion + var responseFormat = Assert.IsType(options!.ResponseFormat); + Assert.Equal(""" { - Assert.Equal(ChatRole.User, message.Role); - Assert.Contains("Respond with a JSON value", message.Text); - Assert.Contains("https://json-schema.org/draft/2020-12/schema", message.Text); - Assert.DoesNotContain(nameof(Animal.FullName), message.Text); // The JSO uses snake_case - Assert.Contains("full_name", message.Text); // The JSO uses snake_case - Assert.DoesNotContain(nameof(Species.Tiger), message.Text); // The JSO doesn't use enum-to-string conversion - }); + "$schema": "https://json-schema.org/draft/2020-12/schema", + "description": "Some test description", + "type": "object", + "properties": { + "id": { + 
"type": "integer" + }, + "full_name": { + "type": [ + "string", + "null" + ] + }, + "species": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "id", + "full_name", + "species" + ] + } + """, responseFormat.Schema.ToString()); return Task.FromResult(expectedResponse); }, diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/OpenTelemetryEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/OpenTelemetryEmbeddingGeneratorTests.cs index b1d7221e552..847bf49be06 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/OpenTelemetryEmbeddingGeneratorTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/OpenTelemetryEmbeddingGeneratorTests.cs @@ -15,8 +15,10 @@ namespace Microsoft.Extensions.AI; public class OpenTelemetryEmbeddingGeneratorTests { - [Fact] - public async Task ExpectedInformationLogged_Async() + [Theory] + [InlineData(null)] + [InlineData("replacementmodel")] + public async Task ExpectedInformationLogged_Async(string? perRequestModelId) { var sourceName = Guid.NewGuid().ToString(); var activities = new List(); @@ -48,7 +50,7 @@ public async Task ExpectedInformationLogged_Async() }; }, GetServiceCallback = (serviceType, serviceKey) => - serviceType == typeof(EmbeddingGeneratorMetadata) ? new EmbeddingGeneratorMetadata("testservice", new Uri("http://localhost:12345/something"), "amazingmodel", 384) : + serviceType == typeof(EmbeddingGeneratorMetadata) ? 
new EmbeddingGeneratorMetadata("testservice", new Uri("http://localhost:12345/something"), "defaultmodel", 1234) : null, }; @@ -59,7 +61,7 @@ public async Task ExpectedInformationLogged_Async() var options = new EmbeddingGenerationOptions { - ModelId = "replacementmodel", + ModelId = perRequestModelId, AdditionalProperties = new() { ["service_tier"] = "value1", @@ -70,6 +72,7 @@ public async Task ExpectedInformationLogged_Async() await generator.GenerateEmbeddingVectorAsync("hello", options); var activity = Assert.Single(activities); + var expectedModelName = perRequestModelId ?? "defaultmodel"; Assert.NotNull(activity.Id); Assert.NotEmpty(activity.Id); @@ -77,10 +80,11 @@ public async Task ExpectedInformationLogged_Async() Assert.Equal("http://localhost:12345/something", activity.GetTagItem("server.address")); Assert.Equal(12345, (int)activity.GetTagItem("server.port")!); - Assert.Equal("embeddings replacementmodel", activity.DisplayName); + Assert.Equal($"embeddings {expectedModelName}", activity.DisplayName); Assert.Equal("testservice", activity.GetTagItem("gen_ai.system")); - Assert.Equal("replacementmodel", activity.GetTagItem("gen_ai.request.model")); + Assert.Equal(expectedModelName, activity.GetTagItem("gen_ai.request.model")); + Assert.Equal(1234, activity.GetTagItem("gen_ai.request.embedding.dimensions")); Assert.Equal("value1", activity.GetTagItem("gen_ai.testservice.request.service_tier")); Assert.Equal("value2", activity.GetTagItem("gen_ai.testservice.request.something_else"));