Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Text.Json.Schema;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -167,7 +168,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
// the LLM backend is meant to do whatever's needed to explain the schema to the LLM.
options.ResponseFormat = ChatResponseFormat.ForJsonSchema(
schema,
schemaName: typeof(T).Name,
schemaName: SanitizeMetadataName(typeof(T).Name),
schemaDescription: typeof(T).GetCustomAttribute<DescriptionAttribute>()?.Description);
}
else
Expand Down Expand Up @@ -224,4 +225,21 @@ private static JsonSerializerOptions GetOrCreateDefaultJsonSerializerOptions()
[JsonSerializable(typeof(JsonNode))]
[JsonSourceGenerationOptions(WriteIndented = true)]
private sealed partial class JsonNodeContext : JsonSerializerContext;

/// <summary>
/// Remove characters from type name that are valid in metadata but shouldn't be used in a schema name.
/// Removes arrays and generic type parameters, and replaces invalid characters with underscores.
/// </summary>
private static string SanitizeMetadataName(string typeName) =>
InvalidNameCharsRegex().Replace(typeName, "_");

/// <summary>Regex that flags any character other than ASCII digits or letters or the underscore.</summary>
#if NET
[GeneratedRegex("[^0-9A-Za-z_]")]
private static partial Regex InvalidNameCharsRegex();
#else
private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex;
private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled);
#endif

}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,40 @@ public async Task CanUseNativeStructuredOutput()
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

[Fact]
public async Task CanUseNativeStructuredOutputWithSanitizedTypeName()
{
var expectedResult = new Data<Animal> { Value = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger } };
var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]);

using var client = new TestChatClient
{
CompleteAsyncCallback = (messages, options, cancellationToken) =>
{
var responseFormat = Assert.IsType<ChatResponseFormatJson>(options!.ResponseFormat);

Assert.Matches("^[a-zA-Z0-9_-]+$", responseFormat.SchemaName);

return Task.FromResult(expectedCompletion);
},
};

var chatHistory = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var response = await client.CompleteAsync<Data<Animal>>(chatHistory, useNativeJsonSchema: true);

// The completion contains the deserialized result and other completion properties
Assert.Equal(1, response.Result!.Value!.Id);
Assert.Equal("Tigger", response.Result.Value.FullName);
Assert.Equal(Species.Tiger, response.Result.Value.Species);

// TryGetResult returns the same value
Assert.True(response.TryGetResult(out var tryGetResultOutput));
Assert.Same(response.Result, tryGetResultOutput);

// History remains unmutated
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

[Fact]
public async Task CanSpecifyCustomJsonSerializationOptions()
{
Expand Down Expand Up @@ -247,6 +281,11 @@ private class Animal
public Species Species { get; set; }
}

private class Data<T>
{
public T? Value { get; set; }
}

private enum Species
{
Bear,
Expand Down
Loading