Skip to content

Commit c9b580a

Browse files
committed
Use AIJsonSchemaCreateOptions for excluding, and recognize FromKeyedServices
1 parent 18a349d commit c9b580a

File tree

5 files changed

+82
-65
lines changed

5 files changed

+82
-65
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Reflection;
56
using System.Text.Json.Nodes;
67

78
#pragma warning disable S1067 // Expressions should not be too complex
@@ -23,6 +24,17 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
2324
/// </summary>
2425
public Func<AIJsonSchemaCreateContext, JsonNode, JsonNode>? TransformSchemaNode { get; init; }
2526

27+
/// <summary>
28+
/// Gets a callback that is invoked for every parameter in the <see cref="MethodBase"/> provided to
29+
/// <see cref="AIJsonUtilities.CreateFunctionJsonSchema"/> in order to determine whether it should
30+
/// be included in the generated schema.
31+
/// </summary>
32+
/// <remarks>
33+
/// By default, when <see cref="IncludeParameter"/> is <see langword="null"/>,
34+
/// all parameters are included in the generated schema.
35+
/// </remarks>
36+
public Func<ParameterInfo, bool>? IncludeParameter { get; init; }
37+
2638
/// <summary>
2739
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
2840
/// </summary>
@@ -44,19 +56,24 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
4456
public bool RequireAllProperties { get; init; } = true;
4557

4658
/// <inheritdoc/>
47-
public bool Equals(AIJsonSchemaCreateOptions? other)
48-
{
49-
return other is not null &&
50-
TransformSchemaNode == other.TransformSchemaNode &&
51-
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
52-
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
53-
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
54-
RequireAllProperties == other.RequireAllProperties;
55-
}
59+
public bool Equals(AIJsonSchemaCreateOptions? other) =>
60+
other is not null &&
61+
TransformSchemaNode == other.TransformSchemaNode &&
62+
IncludeParameter == other.IncludeParameter &&
63+
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
64+
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
65+
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
66+
RequireAllProperties == other.RequireAllProperties;
5667

5768
/// <inheritdoc />
5869
public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other);
5970

6071
/// <inheritdoc />
61-
public override int GetHashCode() => (TransformSchemaNode, IncludeTypeInEnumSchemas, DisallowAdditionalProperties, IncludeSchemaKeyword, RequireAllProperties).GetHashCode();
72+
public override int GetHashCode() =>
73+
(TransformSchemaNode,
74+
IncludeParameter,
75+
IncludeTypeInEnumSchemas,
76+
DisallowAdditionalProperties,
77+
IncludeSchemaKeyword,
78+
RequireAllProperties).GetHashCode();
6279
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,11 @@ public static JsonElement CreateFunctionJsonSchema(
7777
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
7878
}
7979

80-
if (parameter.ParameterType == typeof(CancellationToken))
80+
if (inferenceOptions.IncludeParameter is { } includeParameter &&
81+
!includeParameter(parameter))
8182
{
82-
// CancellationToken is a special case that, by convention, we don't want to include in the schema.
83-
// Invocations of methods that include a CancellationToken argument should also special-case CancellationToken
84-
// to pass along what relevant token into the method's invocation.
85-
continue;
86-
}
87-
88-
if (parameter.GetCustomAttribute<SkipJsonFunctionSchemaParameterAttribute>(inherit: true) is not null)
89-
{
90-
// Skip anything explicitly requested to not be included in the schema.
83+
// Skip parameters that should not be included in the schema.
84+
// By default, all parameters are included.
9185
continue;
9286
}
9387

src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/SkipJsonFunctionSchemaParameterAttribute.cs

Lines changed: 0 additions & 18 deletions
This file was deleted.
Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,10 @@ namespace Microsoft.Extensions.AI;
77

88
/// <summary>Indicates that a parameter to an <see cref="AIFunction"/> should be sourced from an associated <see cref="IServiceProvider"/>.</summary>
99
[AttributeUsage(AttributeTargets.Parameter)]
10-
public sealed class FromServiceProviderAttribute : SkipJsonFunctionSchemaParameterAttribute
10+
public sealed class FromServicesAttribute : Attribute
1111
{
12-
/// <summary>Initializes a new instance of the <see cref="FromServiceProviderAttribute"/> class.</summary>
13-
/// <param name="serviceKey">Optional key to use when resolving the service.</param>
14-
public FromServiceProviderAttribute(object? serviceKey = null)
12+
/// <summary>Initializes a new instance of the <see cref="FromServicesAttribute"/> class.</summary>
13+
public FromServicesAttribute()
1514
{
16-
ServiceKey = serviceKey;
1715
}
18-
19-
/// <summary>Gets the key to use when resolving the service.</summary>
20-
public object? ServiceKey { get; }
2116
}

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,32 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu
249249

250250
private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions)
251251
{
252+
AIJsonSchemaCreateOptions schemaOptions = new()
253+
{
254+
// This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions.
255+
TransformSchemaNode = key.SchemaOptions.TransformSchemaNode,
256+
IncludeParameter = parameterInfo =>
257+
{
258+
// Explicitly exclude from the schema CancellationToken parameters as well
259+
// as those annotated as [FromServices] or [FromKeyedServices]. These will be satisfied
260+
// from sources other than arguments to InvokeAsync.
261+
if (parameterInfo.ParameterType == typeof(CancellationToken) ||
262+
parameterInfo.GetCustomAttribute<FromServicesAttribute>(inherit: true) is not null ||
263+
parameterInfo.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is not null)
264+
{
265+
return false;
266+
}
267+
268+
// For all other parameters, delegate to whatever behavior is specified in the options.
269+
// If none is specified, include the parameter.
270+
return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true;
271+
},
272+
IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas,
273+
DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties,
274+
IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword,
275+
RequireAllProperties = key.SchemaOptions.RequireAllProperties,
276+
};
277+
252278
// Get marshaling delegates for parameters.
253279
ParameterInfo[] parameters = key.Method.GetParameters();
254280
ParameterMarshallers = new Func<IReadOnlyDictionary<string, object?>, CancellationToken, object?>[parameters.Length];
@@ -269,7 +295,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
269295
Name,
270296
Description,
271297
serializerOptions,
272-
key.SchemaOptions);
298+
schemaOptions);
273299
}
274300

275301
public string Name { get; }
@@ -343,33 +369,36 @@ static bool IsAsyncMethod(MethodInfo method)
343369
}
344370

345371
// For DI-based parameters, try to resolve from the service provider.
346-
if (parameter.GetCustomAttribute<FromServiceProviderAttribute>(inherit: true) is FromServiceProviderAttribute fspAttr)
372+
if (parameter.GetCustomAttribute<FromServicesAttribute>(inherit: true) is { } fsAttr)
347373
{
348374
return (arguments, _) =>
349375
{
350-
if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services)
376+
if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services &&
377+
services.GetService(parameterType) is object service)
351378
{
352-
if (fspAttr.ServiceKey is object serviceKey)
353-
{
354-
if ((services as IKeyedServiceProvider)?.GetKeyedService(parameterType, serviceKey) is object keyedService)
355-
{
356-
return keyedService;
357-
}
358-
}
359-
else if (services.GetService(parameterType) is object service)
360-
{
361-
return service;
362-
}
379+
return service;
363380
}
364381

365-
// No service could be resolved. Does it have a default value?
366-
if (parameter.HasDefaultValue)
382+
// No service could be resolved. Return a default value if it's optional, otherwise throw.
383+
return parameter.HasDefaultValue ?
384+
parameter.DefaultValue :
385+
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'.");
386+
};
387+
}
388+
else if (parameter.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is { } fksAttr)
389+
{
390+
return (arguments, _) =>
391+
{
392+
if ((arguments as AIFunctionArguments)?.ServiceProvider is IKeyedServiceProvider services &&
393+
services.GetKeyedService(parameterType, fksAttr.Key) is object service)
367394
{
368-
return parameter.DefaultValue;
395+
return service;
369396
}
370397

371-
// It's a required argument, and we couldn't resolve a service. Throw.
372-
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'.");
398+
// No service could be resolved. Return a default value if it's optional, otherwise throw.
399+
return parameter.HasDefaultValue ?
400+
parameter.DefaultValue :
401+
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' with key '{fksAttr.Key}' for parameter '{parameter.Name}'.");
373402
};
374403
}
375404

0 commit comments

Comments
 (0)