diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs index 3a9c99c2e72..25bf7c6c0fa 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Reflection; using System.Text.Json.Nodes; #pragma warning disable S1067 // Expressions should not be too complex @@ -23,6 +24,17 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable public Func? TransformSchemaNode { get; init; } + /// + /// Gets a callback that is invoked for every parameter in the provided to + /// in order to determine whether it should + /// be included in the generated schema. + /// + /// + /// By default, when is , + /// all parameters are included in the generated schema. + /// + public Func? IncludeParameter { get; init; } + /// /// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums. /// @@ -44,19 +56,24 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable - public bool Equals(AIJsonSchemaCreateOptions? other) - { - return other is not null && - TransformSchemaNode == other.TransformSchemaNode && - IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas && - DisallowAdditionalProperties == other.DisallowAdditionalProperties && - IncludeSchemaKeyword == other.IncludeSchemaKeyword && - RequireAllProperties == other.RequireAllProperties; - } + public bool Equals(AIJsonSchemaCreateOptions? other) => + other is not null && + TransformSchemaNode == other.TransformSchemaNode && + IncludeParameter == other.IncludeParameter && + IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas && + DisallowAdditionalProperties == other.DisallowAdditionalProperties && + IncludeSchemaKeyword == other.IncludeSchemaKeyword && + RequireAllProperties == other.RequireAllProperties; /// public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other); /// - public override int GetHashCode() => (TransformSchemaNode, IncludeTypeInEnumSchemas, DisallowAdditionalProperties, IncludeSchemaKeyword, RequireAllProperties).GetHashCode(); + public override int GetHashCode() => + (TransformSchemaNode, + IncludeParameter, + IncludeTypeInEnumSchemas, + DisallowAdditionalProperties, + IncludeSchemaKeyword, + RequireAllProperties).GetHashCode(); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs index e8a962a0be8..52504cf239a 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Schema.cs @@ -12,7 +12,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Schema; using System.Text.Json.Serialization; -using System.Threading; using Microsoft.Shared.Diagnostics; #pragma warning disable S1121 // Assignments should not be made from within sub-expressions @@ -77,11 +76,11 @@ public static JsonElement CreateFunctionJsonSchema( Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); } - if (parameter.ParameterType == typeof(CancellationToken)) + if (inferenceOptions.IncludeParameter is { } includeParameter && + !includeParameter(parameter)) { - // CancellationToken is a special case that, by convention, we don't want to include in the schema. - // Invocations of methods that include a CancellationToken argument should also special-case CancellationToken - // to pass along what relevant token into the method's invocation. + // Skip parameters that should not be included in the schema. + // By default, all parameters are included. continue; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AIFunctionArguments.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AIFunctionArguments.cs new file mode 100644 index 00000000000..951ea1c6c35 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AIFunctionArguments.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable CA1710 // Identifiers should have correct suffix + +namespace Microsoft.Extensions.AI; + +/// Represents arguments to be used with . +/// +/// may be invoked with arbitary +/// implementations. However, some implementations may dynamically check +/// the type of the arguments, and if it's an , use it to access +/// an that's passed in separately from the arguments enumeration. +/// +public class AIFunctionArguments : IEnumerable> +{ + /// The arguments represented by this instance. + private readonly IEnumerable> _arguments; + + /// Initializes a new instance of the class. + /// The arguments represented by this instance. + /// Options services associated with these arguments. + public AIFunctionArguments(IEnumerable>? arguments, IServiceProvider? serviceProvider = null) + { + _arguments = Throw.IfNull(arguments); + ServiceProvider = serviceProvider; + } + + /// Gets the services associated with these arguments. + public IServiceProvider? ServiceProvider { get; } + + /// + public IEnumerator> GetEnumerator() => _arguments.GetEnumerator(); + + /// + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_arguments).GetEnumerator(); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FromServicesAttribute.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FromServicesAttribute.cs new file mode 100644 index 00000000000..936f583a152 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FromServicesAttribute.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +/// Indicates that a parameter to an should be sourced from an associated . +[AttributeUsage(AttributeTargets.Parameter)] +public sealed class FromServicesAttribute : Attribute +{ + /// Initializes a new instance of the class. + public FromServicesAttribute() + { + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index a64ebf7d61d..f43523cbbfd 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -8,6 +8,7 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; @@ -58,11 +59,13 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient /// /// The underlying , or the next instance in a chain of clients. /// An to use for logging information about function invocation. - public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null) + /// An optional to use for resolving services required by the instances being invoked. + public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null, IServiceProvider? services = null) : base(innerClient) { - _logger = logger ?? NullLogger.Instance; + _logger = logger ?? (ILogger?)services?.GetService>() ?? NullLogger.Instance; _activitySource = innerClient.GetService(); + Services = services; } /// @@ -77,6 +80,9 @@ public static FunctionInvocationContext? CurrentContext protected set => _currentContext.Value = value; } + /// Gets the used for resolving services required by the instances being invoked. + public IServiceProvider? Services { get; } + /// /// Gets or sets a value indicating whether to handle exceptions that occur during function calls. /// @@ -687,8 +693,14 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul object? result = null; try { + IEnumerable>? arguments = context.CallContent.Arguments; + if (Services is not null) + { + arguments = new AIFunctionArguments(arguments, Services); + } + CurrentContext = context; - result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false); + result = await context.Function.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); } catch (Exception e) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs index 0d2d6f8bc9b..9e4c6631e45 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs @@ -32,7 +32,7 @@ public static ChatClientBuilder UseFunctionInvocation( { loggerFactory ??= services.GetService(); - var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient))); + var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)), services); configure?.Invoke(chatClient); return chatClient; }); diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index d8be8e9f128..33f594477ea 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -14,9 +14,14 @@ using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Shared.Collections; using Microsoft.Shared.Diagnostics; +#pragma warning disable CA1031 // Do not catch general exception types +#pragma warning disable S2302 // "nameof" should be used +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + namespace Microsoft.Extensions.AI; /// Provides factory methods for creating commonly used implementations of . @@ -244,6 +249,32 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) { + AIJsonSchemaCreateOptions schemaOptions = new() + { + // This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions. + TransformSchemaNode = key.SchemaOptions.TransformSchemaNode, + IncludeParameter = parameterInfo => + { + // Explicitly exclude from the schema CancellationToken parameters as well + // as those annotated as [FromServices] or [FromKeyedServices]. These will be satisfied + // from sources other than arguments to InvokeAsync. + if (parameterInfo.ParameterType == typeof(CancellationToken) || + parameterInfo.GetCustomAttribute(inherit: true) is not null || + parameterInfo.GetCustomAttribute(inherit: true) is not null) + { + return false; + } + + // For all other parameters, delegate to whatever behavior is specified in the options. + // If none is specified, include the parameter. + return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true; + }, + IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas, + DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties, + IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword, + RequireAllProperties = key.SchemaOptions.RequireAllProperties, + }; + // Get marshaling delegates for parameters. ParameterInfo[] parameters = key.Method.GetParameters(); ParameterMarshallers = new Func, CancellationToken, object?>[parameters.Length]; @@ -264,7 +295,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions Name, Description, serializerOptions, - key.SchemaOptions); + schemaOptions); } public string Name { get; } @@ -337,6 +368,40 @@ static bool IsAsyncMethod(MethodInfo method) cancellationToken; } + // For DI-based parameters, try to resolve from the service provider. + if (parameter.GetCustomAttribute(inherit: true) is { } fsAttr) + { + return (arguments, _) => + { + if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services && + services.GetService(parameterType) is object service) + { + return service; + } + + // No service could be resolved. Return a default value if it's optional, otherwise throw. + return parameter.HasDefaultValue ? + parameter.DefaultValue : + throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'."); + }; + } + else if (parameter.GetCustomAttribute(inherit: true) is { } fksAttr) + { + return (arguments, _) => + { + if ((arguments as AIFunctionArguments)?.ServiceProvider is IKeyedServiceProvider services && + services.GetKeyedService(parameterType, fksAttr.Key) is object service) + { + return service; + } + + // No service could be resolved. Return a default value if it's optional, otherwise throw. + return parameter.HasDefaultValue ? + parameter.DefaultValue : + throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' with key '{fksAttr.Key}' for parameter '{parameter.Name}'."); + }; + } + // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. return (arguments, _) => { @@ -355,7 +420,6 @@ static bool IsAsyncMethod(MethodInfo method) object? MarshallViaJsonRoundtrip(object value) { -#pragma warning disable CA1031 // Do not catch general exception types try { string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType())); @@ -366,7 +430,6 @@ static bool IsAsyncMethod(MethodInfo method) // Eat any exceptions and fall back to the original value to force a cast exception later on. return value; } -#pragma warning restore CA1031 } } @@ -479,9 +542,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT #if NET return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition); #else -#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; -#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); #endif } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs index 05084c102ab..0d1a2508d02 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Utilities/AIJsonUtilitiesTests.cs @@ -86,6 +86,12 @@ public static void AIJsonSchemaCreateOptions_UsesStructuralEquality() property.SetValue(options2, transformer); break; + case null when property.PropertyType == typeof(Func): + Func includeParameter = static (parameter) => true; + property.SetValue(options1, includeParameter); + property.SetValue(options2, includeParameter); + break; + default: Assert.Fail($"Unexpected property type: {property.PropertyType}"); break;