Skip to content

Commit 8371308

Browse files
jozkeejeffhandley
authored andcommitted
Add ChatOptions.RawRepresentationFactory (#6319)
* Look for OpenAI.ChatCompletionOptions in top-level additional properties and stop looking for individually specific additional properties * Add RawRepresentation to ChatOptions and use it in OpenAI and AzureAIInference * Remove now unused locals * Add [JsonIgnore] and update roundtrip tests * Overwirte properties only if the underlying model don't specify it already * Clone RawRepresentation * Reflection workaround for ToolChoice not being cloned * Style changes * AI.Inference: Bring back propagation of additional properties * Don't use 0.1f, it doesn't roundtrip properly in .NET Framework * Add RawRepresentationFactory instead of object? property * Augment remarks to discourage returning shared instances * Documentation feedback * AI.Inference: keep passing TopK as AdditionalProperty if not already there
1 parent 41f31d6 commit 8371308

File tree

2 files changed

+100
-191
lines changed

2 files changed

+100
-191
lines changed

src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -281,66 +281,74 @@ private static ChatRole ToChatRole(global::Azure.AI.Inference.ChatRole role) =>
281281
finishReason == CompletionsFinishReason.ToolCalls ? ChatFinishReason.ToolCalls :
282282
new(s);
283283

284+
private ChatCompletionsOptions CreateAzureAIOptions(IEnumerable<ChatMessage> chatContents, ChatOptions? options) =>
285+
new(ToAzureAIInferenceChatMessages(chatContents))
286+
{
287+
Model = options?.ModelId ?? _metadata.DefaultModelId ??
288+
throw new InvalidOperationException("No model id was provided when either constructing the client or in the chat options.")
289+
};
290+
284291
/// <summary>Converts an extensions options instance to an AzureAI options instance.</summary>
285292
private ChatCompletionsOptions ToAzureAIOptions(IEnumerable<ChatMessage> chatContents, ChatOptions? options)
286293
{
287-
ChatCompletionsOptions result = new(ToAzureAIInferenceChatMessages(chatContents))
294+
if (options is null)
288295
{
289-
Model = options?.ModelId ?? _metadata.DefaultModelId ?? throw new InvalidOperationException("No model id was provided when either constructing the client or in the chat options.")
290-
};
296+
return CreateAzureAIOptions(chatContents, options);
297+
}
298+
299+
if (options.RawRepresentationFactory?.Invoke(this) is ChatCompletionsOptions result)
300+
{
301+
result.Messages = ToAzureAIInferenceChatMessages(chatContents).ToList();
302+
result.Model ??= options.ModelId ?? _metadata.DefaultModelId ??
303+
throw new InvalidOperationException("No model id was provided when either constructing the client or in the chat options.");
304+
}
305+
else
306+
{
307+
result = CreateAzureAIOptions(chatContents, options);
308+
}
291309

292-
if (options is not null)
310+
result.FrequencyPenalty ??= options.FrequencyPenalty;
311+
result.MaxTokens ??= options.MaxOutputTokens;
312+
result.NucleusSamplingFactor ??= options.TopP;
313+
result.PresencePenalty ??= options.PresencePenalty;
314+
result.Temperature ??= options.Temperature;
315+
result.Seed ??= options.Seed;
316+
317+
if (options.StopSequences is { Count: > 0 } stopSequences)
293318
{
294-
result.FrequencyPenalty = options.FrequencyPenalty;
295-
result.MaxTokens = options.MaxOutputTokens;
296-
result.NucleusSamplingFactor = options.TopP;
297-
result.PresencePenalty = options.PresencePenalty;
298-
result.Temperature = options.Temperature;
299-
result.Seed = options.Seed;
300-
301-
if (options.StopSequences is { Count: > 0 } stopSequences)
319+
foreach (string stopSequence in stopSequences)
302320
{
303-
foreach (string stopSequence in stopSequences)
304-
{
305-
result.StopSequences.Add(stopSequence);
306-
}
321+
result.StopSequences.Add(stopSequence);
307322
}
323+
}
324+
325+
// This property is strongly typed on ChatOptions but not on ChatCompletionsOptions.
326+
if (options.TopK is int topK && !result.AdditionalProperties.ContainsKey("top_k"))
327+
{
328+
result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(int))));
329+
}
308330

309-
// These properties are strongly typed on ChatOptions but not on ChatCompletionsOptions.
310-
if (options.TopK is int topK)
331+
if (options.AdditionalProperties is { } props)
332+
{
333+
foreach (var prop in props)
311334
{
312-
result.AdditionalProperties["top_k"] = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(topK, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(int))));
335+
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
336+
result.AdditionalProperties[prop.Key] = new BinaryData(data);
313337
}
338+
}
314339

315-
if (options.AdditionalProperties is { } props)
340+
if (options.Tools is { Count: > 0 } tools)
341+
{
342+
foreach (AITool tool in tools)
316343
{
317-
foreach (var prop in props)
344+
if (tool is AIFunction af)
318345
{
319-
switch (prop.Key)
320-
{
321-
// Propagate everything else to the ChatCompletionsOptions' AdditionalProperties.
322-
default:
323-
if (prop.Value is not null)
324-
{
325-
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
326-
result.AdditionalProperties[prop.Key] = new BinaryData(data);
327-
}
328-
329-
break;
330-
}
346+
result.Tools.Add(ToAzureAIChatTool(af));
331347
}
332348
}
333349

334-
if (options.Tools is { Count: > 0 } tools)
350+
if (result.ToolChoice is null && result.Tools.Count > 0)
335351
{
336-
foreach (AITool tool in tools)
337-
{
338-
if (tool is AIFunction af)
339-
{
340-
result.Tools.Add(ToAzureAIChatTool(af));
341-
}
342-
}
343-
344352
switch (options.ToolMode)
345353
{
346354
case NoneChatToolMode:
@@ -359,7 +367,10 @@ private ChatCompletionsOptions ToAzureAIOptions(IEnumerable<ChatMessage> chatCon
359367
break;
360368
}
361369
}
370+
}
362371

372+
if (result.ResponseFormat is null)
373+
{
363374
if (options.ResponseFormat is ChatResponseFormatText)
364375
{
365376
result.ResponseFormat = ChatCompletionsResponseFormat.CreateTextFormat();

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 47 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#pragma warning disable EA0011 // Consider removing unnecessary conditional access operator (?)
1919
#pragma warning disable S1067 // Expressions should not be too complex
2020
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
21+
#pragma warning disable SA1204 // Static elements should appear before instance elements
2122

2223
namespace Microsoft.Extensions.AI;
2324

@@ -259,7 +260,6 @@ private static async IAsyncEnumerable<ChatResponseUpdate> FromOpenAIStreamingCha
259260
string? responseId = null;
260261
DateTimeOffset? createdAt = null;
261262
string? modelId = null;
262-
string? fingerprint = null;
263263

264264
// Process each update as it arrives
265265
await foreach (StreamingChatCompletionUpdate update in updates.WithCancellation(cancellationToken).ConfigureAwait(false))
@@ -270,7 +270,6 @@ private static async IAsyncEnumerable<ChatResponseUpdate> FromOpenAIStreamingCha
270270
responseId ??= update.CompletionId;
271271
createdAt ??= update.CreatedAt;
272272
modelId ??= update.Model;
273-
fingerprint ??= update.SystemFingerprint;
274273

275274
// Create the response content object.
276275
ChatResponseUpdate responseUpdate = new()
@@ -284,22 +283,6 @@ private static async IAsyncEnumerable<ChatResponseUpdate> FromOpenAIStreamingCha
284283
Role = streamedRole,
285284
};
286285

287-
// Populate it with any additional metadata from the OpenAI object.
288-
if (update.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs)
289-
{
290-
(responseUpdate.AdditionalProperties ??= [])[nameof(update.ContentTokenLogProbabilities)] = contentTokenLogProbs;
291-
}
292-
293-
if (update.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs)
294-
{
295-
(responseUpdate.AdditionalProperties ??= [])[nameof(update.RefusalTokenLogProbabilities)] = refusalTokenLogProbs;
296-
}
297-
298-
if (fingerprint is not null)
299-
{
300-
(responseUpdate.AdditionalProperties ??= [])[nameof(update.SystemFingerprint)] = fingerprint;
301-
}
302-
303286
// Transfer over content update items.
304287
if (update.ContentUpdate is { Count: > 0 })
305288
{
@@ -382,12 +365,6 @@ private static async IAsyncEnumerable<ChatResponseUpdate> FromOpenAIStreamingCha
382365
responseUpdate.Contents.Add(new ErrorContent(refusal.ToString()) { ErrorCode = "Refusal" });
383366
}
384367

385-
// Propagate additional relevant metadata.
386-
if (fingerprint is not null)
387-
{
388-
(responseUpdate.AdditionalProperties ??= [])[nameof(ChatCompletion.SystemFingerprint)] = fingerprint;
389-
}
390-
391368
yield return responseUpdate;
392369
}
393370
}
@@ -426,20 +403,7 @@ private static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComple
426403
"mp3" or _ => "audio/mpeg",
427404
};
428405

429-
var dc = new DataContent(audio.AudioBytes.ToMemory(), mimeType)
430-
{
431-
AdditionalProperties = new() { [nameof(audio.ExpiresAt)] = audio.ExpiresAt },
432-
};
433-
434-
if (audio.Id is string id)
435-
{
436-
dc.AdditionalProperties[nameof(audio.Id)] = id;
437-
}
438-
439-
if (audio.Transcript is string transcript)
440-
{
441-
dc.AdditionalProperties[nameof(audio.Transcript)] = transcript;
442-
}
406+
var dc = new DataContent(audio.AudioBytes.ToMemory(), mimeType);
443407

444408
returnMessage.Contents.Add(dc);
445409
}
@@ -480,140 +444,74 @@ private static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComple
480444
response.Usage = FromOpenAIUsage(tokenUsage);
481445
}
482446

483-
if (openAICompletion.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs)
484-
{
485-
(response.AdditionalProperties ??= [])[nameof(openAICompletion.ContentTokenLogProbabilities)] = contentTokenLogProbs;
486-
}
487-
488-
if (openAICompletion.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs)
489-
{
490-
(response.AdditionalProperties ??= [])[nameof(openAICompletion.RefusalTokenLogProbabilities)] = refusalTokenLogProbs;
491-
}
492-
493-
if (openAICompletion.SystemFingerprint is string systemFingerprint)
494-
{
495-
(response.AdditionalProperties ??= [])[nameof(openAICompletion.SystemFingerprint)] = systemFingerprint;
496-
}
497-
498447
return response;
499448
}
500449

501450
/// <summary>Converts an extensions options instance to an OpenAI options instance.</summary>
502-
private static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
451+
private ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
503452
{
504-
ChatCompletionOptions result = new();
453+
if (options is null)
454+
{
455+
return new ChatCompletionOptions();
456+
}
505457

506-
if (options is not null)
458+
if (options.RawRepresentationFactory?.Invoke(this) is not ChatCompletionOptions result)
507459
{
508-
result.FrequencyPenalty = options.FrequencyPenalty;
509-
result.MaxOutputTokenCount = options.MaxOutputTokens;
510-
result.TopP = options.TopP;
511-
result.PresencePenalty = options.PresencePenalty;
512-
result.Temperature = options.Temperature;
513-
result.AllowParallelToolCalls = options.AllowMultipleToolCalls;
460+
result = new ChatCompletionOptions();
461+
}
462+
463+
result.FrequencyPenalty ??= options.FrequencyPenalty;
464+
result.MaxOutputTokenCount ??= options.MaxOutputTokens;
465+
result.TopP ??= options.TopP;
466+
result.PresencePenalty ??= options.PresencePenalty;
467+
result.Temperature ??= options.Temperature;
468+
result.AllowParallelToolCalls ??= options.AllowMultipleToolCalls;
514469
#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates.
515-
result.Seed = options.Seed;
470+
result.Seed ??= options.Seed;
516471
#pragma warning restore OPENAI001
517472

518-
if (options.StopSequences is { Count: > 0 } stopSequences)
473+
if (options.StopSequences is { Count: > 0 } stopSequences)
474+
{
475+
foreach (string stopSequence in stopSequences)
519476
{
520-
foreach (string stopSequence in stopSequences)
521-
{
522-
result.StopSequences.Add(stopSequence);
523-
}
477+
result.StopSequences.Add(stopSequence);
524478
}
479+
}
525480

526-
if (options.AdditionalProperties is { Count: > 0 } additionalProperties)
481+
if (options.Tools is { Count: > 0 } tools)
482+
{
483+
foreach (AITool tool in tools)
527484
{
528-
if (additionalProperties.TryGetValue(nameof(result.AudioOptions), out ChatAudioOptions? audioOptions))
485+
if (tool is AIFunction af)
529486
{
530-
result.AudioOptions = audioOptions;
531-
}
532-
533-
if (additionalProperties.TryGetValue(nameof(result.EndUserId), out string? endUserId))
534-
{
535-
result.EndUserId = endUserId;
536-
}
537-
538-
if (additionalProperties.TryGetValue(nameof(result.IncludeLogProbabilities), out bool includeLogProbabilities))
539-
{
540-
result.IncludeLogProbabilities = includeLogProbabilities;
541-
}
542-
543-
if (additionalProperties.TryGetValue(nameof(result.LogitBiases), out IDictionary<int, int>? logitBiases))
544-
{
545-
foreach (KeyValuePair<int, int> kvp in logitBiases!)
546-
{
547-
result.LogitBiases[kvp.Key] = kvp.Value;
548-
}
549-
}
550-
551-
if (additionalProperties.TryGetValue(nameof(result.Metadata), out IDictionary<string, string>? metadata))
552-
{
553-
foreach (KeyValuePair<string, string> kvp in metadata)
554-
{
555-
result.Metadata[kvp.Key] = kvp.Value;
556-
}
557-
}
558-
559-
if (additionalProperties.TryGetValue(nameof(result.OutputPrediction), out ChatOutputPrediction? outputPrediction))
560-
{
561-
result.OutputPrediction = outputPrediction;
562-
}
563-
564-
if (additionalProperties.TryGetValue(nameof(result.ReasoningEffortLevel), out ChatReasoningEffortLevel reasoningEffortLevel))
565-
{
566-
result.ReasoningEffortLevel = reasoningEffortLevel;
567-
}
568-
569-
if (additionalProperties.TryGetValue(nameof(result.ResponseModalities), out ChatResponseModalities responseModalities))
570-
{
571-
result.ResponseModalities = responseModalities;
572-
}
573-
574-
if (additionalProperties.TryGetValue(nameof(result.StoredOutputEnabled), out bool storeOutputEnabled))
575-
{
576-
result.StoredOutputEnabled = storeOutputEnabled;
577-
}
578-
579-
if (additionalProperties.TryGetValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt))
580-
{
581-
result.TopLogProbabilityCount = topLogProbabilityCountInt;
487+
result.Tools.Add(ToOpenAIChatTool(af));
582488
}
583489
}
584490

585-
if (options.Tools is { Count: > 0 } tools)
491+
if (result.ToolChoice is null && result.Tools.Count > 0)
586492
{
587-
foreach (AITool tool in tools)
588-
{
589-
if (tool is AIFunction af)
590-
{
591-
result.Tools.Add(ToOpenAIChatTool(af));
592-
}
593-
}
594-
595-
if (result.Tools.Count > 0)
493+
switch (options.ToolMode)
596494
{
597-
switch (options.ToolMode)
598-
{
599-
case NoneChatToolMode:
600-
result.ToolChoice = ChatToolChoice.CreateNoneChoice();
601-
break;
602-
603-
case AutoChatToolMode:
604-
case null:
605-
result.ToolChoice = ChatToolChoice.CreateAutoChoice();
606-
break;
607-
608-
case RequiredChatToolMode required:
609-
result.ToolChoice = required.RequiredFunctionName is null ?
610-
ChatToolChoice.CreateRequiredChoice() :
611-
ChatToolChoice.CreateFunctionChoice(required.RequiredFunctionName);
612-
break;
613-
}
495+
case NoneChatToolMode:
496+
result.ToolChoice = ChatToolChoice.CreateNoneChoice();
497+
break;
498+
499+
case AutoChatToolMode:
500+
case null:
501+
result.ToolChoice = ChatToolChoice.CreateAutoChoice();
502+
break;
503+
504+
case RequiredChatToolMode required:
505+
result.ToolChoice = required.RequiredFunctionName is null ?
506+
ChatToolChoice.CreateRequiredChoice() :
507+
ChatToolChoice.CreateFunctionChoice(required.RequiredFunctionName);
508+
break;
614509
}
615510
}
511+
}
616512

513+
if (result.ResponseFormat is null)
514+
{
617515
if (options.ResponseFormat is ChatResponseFormatText)
618516
{
619517
result.ResponseFormat = OpenAI.Chat.ChatResponseFormat.CreateTextFormat();

0 commit comments

Comments
 (0)