Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public static ChatResponseFormatJson ForJsonSchema(

return ForJsonSchema(
schema,
schemaName ?? InvalidNameCharsRegex().Replace(schemaType.Name, "_"),
schemaName ?? schemaType.GetCustomAttribute<DisplayNameAttribute>()?.DisplayName ?? InvalidNameCharsRegex().Replace(schemaType.Name, "_"),
schemaDescription ?? schemaType.GetCustomAttribute<DescriptionAttribute>()?.Description);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio
/// <param name="method">The method to be represented via the created <see cref="AIFunction"/>.</param>
/// <param name="name">
/// The name to use for the <see cref="AIFunction"/>. If <see langword="null"/>, the name will be derived from
/// the name of <paramref name="method"/>.
/// any <see cref="DisplayNameAttribute"/> on <paramref name="method"/>, if available, or else from the name of <paramref name="method"/>.
/// </param>
/// <param name="description">
/// The description to use for the <see cref="AIFunction"/>. If <see langword="null"/>, a description will be derived from
Expand Down Expand Up @@ -297,7 +297,7 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
/// </param>
/// <param name="name">
/// The name to use for the <see cref="AIFunction"/>. If <see langword="null"/>, the name will be derived from
/// the name of <paramref name="method"/>.
/// any <see cref="DisplayNameAttribute"/> on <paramref name="method"/>, if available, or else from the name of <paramref name="method"/>.
/// </param>
/// <param name="description">
/// The description to use for the <see cref="AIFunction"/>. If <see langword="null"/>, a description will be derived from
Expand Down Expand Up @@ -729,7 +729,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions

ReturnParameterMarshaller = GetReturnParameterMarshaller(key, serializerOptions, out Type? returnType);
Method = key.Method;
Name = key.Name ?? GetFunctionName(key.Method);
Name = key.Name ?? key.Method.GetCustomAttribute<DisplayNameAttribute>(inherit: true)?.DisplayName ?? GetFunctionName(key.Method);
Description = key.Description ?? key.Method.GetCustomAttribute<DescriptionAttribute>(inherit: true)?.Description ?? string.Empty;
JsonSerializerOptions = serializerOptions;
ReturnJsonSchema = returnType is null || key.ExcludeResultSchema ? null : AIJsonUtilities.CreateJsonSchema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public AIFunctionFactoryOptions()

/// <summary>Gets or sets the name to use for the function.</summary>
/// <value>
/// The name to use for the function. The default value is a name derived from the method represented by the passed <see cref="Delegate"/> or <see cref="MethodInfo"/>.
/// The name to use for the function. The default value is a name derived from the passed <see cref="Delegate"/> or <see cref="MethodInfo"/> (for example, via a <see cref="DisplayNameAttribute"/> on the method).
/// </value>
public string? Name { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static JsonElement CreateFunctionJsonSchema(

serializerOptions ??= DefaultOptions;
inferenceOptions ??= AIJsonSchemaCreateOptions.Default;
title ??= method.Name;
title ??= method.GetCustomAttribute<DisplayNameAttribute>()?.DisplayName ?? method.Name;
description ??= method.GetCustomAttribute<DescriptionAttribute>()?.Description;

JsonObject parameterSchemas = new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,34 @@ public void ForJsonSchema_ComplexType_Succeeds(bool generic, string? name, strin
Assert.Equal(description ?? "abcd", format.SchemaDescription);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void ForJsonSchema_DisplayNameAttribute_UsedForSchemaName(bool generic)
{
ChatResponseFormatJson format = generic ?
ChatResponseFormat.ForJsonSchema<TypeWithDisplayName>(TestJsonSerializerContext.Default.Options) :
ChatResponseFormat.ForJsonSchema(typeof(TypeWithDisplayName), TestJsonSerializerContext.Default.Options);

Assert.NotNull(format);
Assert.NotNull(format.Schema);
Assert.Equal("custom_type_name", format.SchemaName);
Assert.Equal("Type description", format.SchemaDescription);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void ForJsonSchema_DisplayNameAttribute_CanBeOverridden(bool generic)
{
ChatResponseFormatJson format = generic ?
ChatResponseFormat.ForJsonSchema<TypeWithDisplayName>(TestJsonSerializerContext.Default.Options, schemaName: "override_name") :
ChatResponseFormat.ForJsonSchema(typeof(TypeWithDisplayName), TestJsonSerializerContext.Default.Options, schemaName: "override_name");

Assert.NotNull(format);
Assert.Equal("override_name", format.SchemaName);
}

[Description("abcd")]
public class SomeType
{
Expand All @@ -178,4 +206,11 @@ public class SomeType
[Description("hijk")]
public string? SomeString { get; set; }
}

[DisplayName("custom_type_name")]
[Description("Type description")]
public class TypeWithDisplayName
{
public int Value { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(decimal))] // Used in Content tests
[JsonSerializable(typeof(HostedMcpServerToolApprovalMode))]
[JsonSerializable(typeof(ChatResponseFormatTests.SomeType))]
[JsonSerializable(typeof(ChatResponseFormatTests.TypeWithDisplayName))]
[JsonSerializable(typeof(ResponseContinuationToken))]
internal sealed partial class TestJsonSerializerContext : JsonSerializerContext;
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,43 @@ public static void CreateFunctionJsonSchema_ReadsParameterDataAnnotationAttribut
AssertDeepEquals(expectedSchema.RootElement, func.JsonSchema);
}

[Fact]
public static void CreateFunctionJsonSchema_DisplayNameAttribute_UsedForTitle()
{
[DisplayName("custom_method_name")]
[Description("Method description")]
static void TestMethod(int x, int y)
{
// Test method for schema generation
}

var method = ((Action<int, int>)TestMethod).Method;
JsonElement schema = AIJsonUtilities.CreateFunctionJsonSchema(method);

using JsonDocument doc = JsonDocument.Parse(schema.GetRawText());
Assert.True(doc.RootElement.TryGetProperty("title", out JsonElement titleElement));
Assert.Equal("custom_method_name", titleElement.GetString());
Assert.True(doc.RootElement.TryGetProperty("description", out JsonElement descElement));
Assert.Equal("Method description", descElement.GetString());
}

[Fact]
public static void CreateFunctionJsonSchema_DisplayNameAttribute_CanBeOverridden()
{
[DisplayName("custom_method_name")]
static void TestMethod()
{
// Test method for schema generation
}

var method = ((Action)TestMethod).Method;
JsonElement schema = AIJsonUtilities.CreateFunctionJsonSchema(method, title: "override_title");

using JsonDocument doc = JsonDocument.Parse(schema.GetRawText());
Assert.True(doc.RootElement.TryGetProperty("title", out JsonElement titleElement));
Assert.Equal("override_title", titleElement.GetString());
}

[Fact]
public static void CreateJsonSchema_CanBeBoolean()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,39 @@ public void Metadata_DerivedFromLambda()
p => Assert.Equal("This is B", p.GetCustomAttribute<DescriptionAttribute>()?.Description));
}

[Fact]
public void Metadata_DisplayNameAttribute()
{
// Test DisplayNameAttribute on a delegate method
Func<string> funcWithDisplayName = [DisplayName("get_user_id")] () => "test";
AIFunction func = AIFunctionFactory.Create(funcWithDisplayName);
Assert.Equal("get_user_id", func.Name);
Assert.Empty(func.Description);

// Test DisplayNameAttribute with DescriptionAttribute
Func<string> funcWithBoth = [DisplayName("my_function")][Description("A test function")] () => "test";
func = AIFunctionFactory.Create(funcWithBoth);
Assert.Equal("my_function", func.Name);
Assert.Equal("A test function", func.Description);

// Test that explicit name parameter takes precedence over DisplayNameAttribute
func = AIFunctionFactory.Create(funcWithDisplayName, name: "explicit_name");
Assert.Equal("explicit_name", func.Name);

// Test DisplayNameAttribute with options
func = AIFunctionFactory.Create(funcWithDisplayName, new AIFunctionFactoryOptions());
Assert.Equal("get_user_id", func.Name);

// Test that options.Name takes precedence over DisplayNameAttribute
func = AIFunctionFactory.Create(funcWithDisplayName, new AIFunctionFactoryOptions { Name = "options_name" });
Assert.Equal("options_name", func.Name);

// Test function without DisplayNameAttribute falls back to method name
Func<string> funcWithoutDisplayName = () => "test";
func = AIFunctionFactory.Create(funcWithoutDisplayName);
Assert.Contains("Metadata_DisplayNameAttribute", func.Name); // Will contain the lambda method name
}

[Fact]
public void AIFunctionFactoryCreateOptions_ValuesPropagateToAIFunction()
{
Expand Down
Loading