Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Reflection;
using System.Text.Json.Nodes;

#pragma warning disable S1067 // Expressions should not be too complex
Expand All @@ -23,6 +24,17 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
/// </summary>
public Func<AIJsonSchemaCreateContext, JsonNode, JsonNode>? TransformSchemaNode { get; init; }

/// <summary>
/// Gets a callback that is invoked for every parameter in the <see cref="MethodBase"/> provided to
/// <see cref="AIJsonUtilities.CreateFunctionJsonSchema"/> in order to determine whether it should
/// be included in the generated schema.
/// </summary>
/// <remarks>
/// By default, when <see cref="IncludeParameter"/> is <see langword="null"/>,
/// all parameters are included in the generated schema.
/// </remarks>
public Func<ParameterInfo, bool>? IncludeParameter { get; init; }

/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// </summary>
Expand All @@ -44,19 +56,24 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
public bool RequireAllProperties { get; init; } = true;

/// <inheritdoc/>
public bool Equals(AIJsonSchemaCreateOptions? other)
{
return other is not null &&
TransformSchemaNode == other.TransformSchemaNode &&
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
RequireAllProperties == other.RequireAllProperties;
}
public bool Equals(AIJsonSchemaCreateOptions? other) =>
other is not null &&
TransformSchemaNode == other.TransformSchemaNode &&
IncludeParameter == other.IncludeParameter &&
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
RequireAllProperties == other.RequireAllProperties;

/// <inheritdoc />
public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other);

/// <inheritdoc />
public override int GetHashCode() => (TransformSchemaNode, IncludeTypeInEnumSchemas, DisallowAdditionalProperties, IncludeSchemaKeyword, RequireAllProperties).GetHashCode();
public override int GetHashCode() =>
(TransformSchemaNode,
IncludeParameter,
IncludeTypeInEnumSchemas,
DisallowAdditionalProperties,
IncludeSchemaKeyword,
RequireAllProperties).GetHashCode();
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
using System.Text.Json.Serialization;
using System.Threading;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
Expand Down Expand Up @@ -77,11 +76,11 @@ public static JsonElement CreateFunctionJsonSchema(
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
}

if (parameter.ParameterType == typeof(CancellationToken))
if (inferenceOptions.IncludeParameter is { } includeParameter &&
!includeParameter(parameter))
{
// CancellationToken is a special case that, by convention, we don't want to include in the schema.
// Invocations of methods that include a CancellationToken argument should also special-case CancellationToken
// to pass along what relevant token into the method's invocation.
// Skip parameters that should not be included in the schema.
// By default, all parameters are included.
continue;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using Microsoft.Shared.Diagnostics;

#pragma warning disable CA1710 // Identifiers should have correct suffix

namespace Microsoft.Extensions.AI;

/// <summary>Represents arguments to be used with <see cref="AIFunction.InvokeAsync"/>.</summary>
/// <remarks>
/// <see cref="AIFunction.InvokeAsync"/> may be invoked with arbitary <see cref="IEnumerable{T}"/>
/// implementations. However, some <see cref="AIFunction"/> implementations may dynamically check
/// the type of the arguments, and if it's an <see cref="AIFunctionArguments"/>, use it to access
/// an <see cref="IServiceProvider"/> that's passed in separately from the arguments enumeration.
/// </remarks>
public class AIFunctionArguments : IEnumerable<KeyValuePair<string, object?>>
{
/// <summary>The arguments represented by this instance.</summary>
private readonly IEnumerable<KeyValuePair<string, object?>> _arguments;

/// <summary>Initializes a new instance of the <see cref="AIFunctionArguments"/> class.</summary>
/// <param name="arguments">The arguments represented by this instance.</param>
/// <param name="serviceProvider">Options services associated with these arguments.</param>
public AIFunctionArguments(IEnumerable<KeyValuePair<string, object?>>? arguments, IServiceProvider? serviceProvider = null)
{
_arguments = Throw.IfNull(arguments);
ServiceProvider = serviceProvider;
}

/// <summary>Gets the services associated with these arguments.</summary>
public IServiceProvider? ServiceProvider { get; }

/// <inheritdoc />
public IEnumerator<KeyValuePair<string, object?>> GetEnumerator() => _arguments.GetEnumerator();

/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_arguments).GetEnumerator();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;

namespace Microsoft.Extensions.AI;

/// <summary>Indicates that a parameter to an <see cref="AIFunction"/> should be sourced from an associated <see cref="IServiceProvider"/>.</summary>
[AttributeUsage(AttributeTargets.Parameter)]
public sealed class FromServicesAttribute : Attribute
{
/// <summary>Initializes a new instance of the <see cref="FromServicesAttribute"/> class.</summary>
public FromServicesAttribute()
{
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's a challenge with .NET in general, but we do already have a type with this name and of course the same meaning. This could lead to some pretty awkward debugging for people who believe they are using [FromServices] correctly but it's not working (either breaking MVC code or breaking MEAI code).

Possible solutions:

  • Use a different name here
  • Try to push MVC's FromServicesAttribute up into System.ComponentModel or something
  • Use a different pattern here

If we wanted a different pattern here, we could embrace the whole "AIFunctionCallContext" parameter notion. It could deal with services and cancellation tokens in one, and would make it easy for us to add any other per-call context in the future without having to special-case any other parameter types. I know we considered that before, but maybe adding further per-call context starts to tip the balance another way.

What do you think? Is this enough of a problem to warrant a different approach?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now read the other comments on this PR and see there's already been talk of this. Looks like one blocker is that we would need AIFF etc to move from M.E.AI.Abstractions to M.E.AI. But is there a plan to do that regardless?

If we could avoid creating a duplicate-named attribute that certainly sounds preferable to me. I'm surprised this didn't seem to bother @halter73 so maybe I'm overblowing the issue. What do you think, @halter73?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But is there a plan to do that regardless?

No, that would require M.E.AI.Abstractions depending on DI, which I want to avoid.

Copy link
Member Author

@stephentoub stephentoub Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name

I'd started with a different name and was urged by Stephen and others to use the same one. Happy to go back to a different one if folks can agree.

Try to push MVC's FromServicesAttribute up into System.ComponentModel or something

It'd be really weird to depend on something in the Microsoft.AspNetCore.Mvc name space, even if we could figure out how to move it into a different assembly on an appropriate time line.

Just to confirm, you're factoring into your concerns that it's in an MVC namespace?

we could embrace the whole "AIFunctionCallContext" parameter

Got a lot of push back on that during API review, especially around CT... we won't be putting the CT into such a thing even if we add it back, it's too hard to discover and doesn't match ..NET conventions... we had multiple bugs filled spot not supporting CT even when we did via that mechanism But I'm also not clear how that helps, unless you're suggesting it would just carry an IServiceProvider and the ap code would query directly? We could also do that just by special-casing IServiceProvider if we wanted to.

Copy link
Member

@SteveSandersonMS SteveSandersonMS Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, you're factoring into your concerns that it's in an MVC namespace?

Yes, I do know that. It means we're relying on the developer understanding that, in a typical ASP.NET Core application, there are two different [FromServices] attributes and that they have to get the right one in each case, with the compiler/IDE not helping identify the correct one (presumably, the IDE will offer to auto-add both namespaces and leaves the developer to select which one). Picking the wrong one means it just won't work at runtime.

But I'm also not clear how that helps, unless you're suggesting it would just carry an IServiceProvider and the ap code would query directly?

Yes, that's what I meant.

We could also do that just by special-casing IServiceProvider if we wanted to.

That's a perfectly fine solution as long as we consider this a relatively niche case.

If we thought that, say, 50%+ of AIFunctions would need to obtain DI services directly (and couldn't acquire them from context, e.g., by being instance methods on some type that already has access to the desired services, which I think is also common), then it would make sense to try to make it more first-class.

One option is, for now, to support IServiceProvider parameters and thus bypass the other issues. It would still leave open the option to add other ways to resolve params-as-services in the future.

Altogether this isn't something that worries me deeply. If you and others are keen to have [FromServices] and don't think the duplicate type name will be an issue, that's OK.

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -58,11 +59,13 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
/// </summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
/// <param name="logger">An <see cref="ILogger"/> to use for logging information about function invocation.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
/// <param name="services">An optional <see cref="IServiceProvider"/> to use for resolving services required by the <see cref="AIFunction"/> instances being invoked.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null, IServiceProvider? services = null)
: base(innerClient)
{
_logger = logger ?? NullLogger.Instance;
_logger = logger ?? (ILogger?)services?.GetService<ILogger<FunctionInvokingChatClient>>() ?? NullLogger.Instance;
_activitySource = innerClient.GetService<ActivitySource>();
Services = services;
}

/// <summary>
Expand All @@ -77,6 +80,9 @@ public static FunctionInvocationContext? CurrentContext
protected set => _currentContext.Value = value;
}

/// <summary>Gets the <see cref="IServiceProvider"/> used for resolving services required by the <see cref="AIFunction"/> instances being invoked.</summary>
public IServiceProvider? Services { get; }

/// <summary>
/// Gets or sets a value indicating whether to handle exceptions that occur during function calls.
/// </summary>
Expand Down Expand Up @@ -687,8 +693,14 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
object? result = null;
try
{
IEnumerable<KeyValuePair<string, object?>>? arguments = context.CallContent.Arguments;
if (Services is not null)
{
arguments = new AIFunctionArguments(arguments, Services);
}

CurrentContext = context;
result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
result = await context.Function.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public static ChatClientBuilder UseFunctionInvocation(
{
loggerFactory ??= services.GetService<ILoggerFactory>();

var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)), services);
configure?.Invoke(chatClient);
return chatClient;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;

#pragma warning disable CA1031 // Do not catch general exception types
#pragma warning disable S2302 // "nameof" should be used
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields

namespace Microsoft.Extensions.AI;

/// <summary>Provides factory methods for creating commonly used implementations of <see cref="AIFunction"/>.</summary>
Expand Down Expand Up @@ -244,6 +249,32 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu

private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions)
{
AIJsonSchemaCreateOptions schemaOptions = new()
{
// This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions.
TransformSchemaNode = key.SchemaOptions.TransformSchemaNode,
IncludeParameter = parameterInfo =>
{
// Explicitly exclude from the schema CancellationToken parameters as well
// as those annotated as [FromServices] or [FromKeyedServices]. These will be satisfied
// from sources other than arguments to InvokeAsync.
if (parameterInfo.ParameterType == typeof(CancellationToken) ||
parameterInfo.GetCustomAttribute<FromServicesAttribute>(inherit: true) is not null ||
parameterInfo.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is not null)
{
return false;
}

// For all other parameters, delegate to whatever behavior is specified in the options.
// If none is specified, include the parameter.
return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true;
},
IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas,
DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties,
IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword,
RequireAllProperties = key.SchemaOptions.RequireAllProperties,
};

// Get marshaling delegates for parameters.
ParameterInfo[] parameters = key.Method.GetParameters();
ParameterMarshallers = new Func<IReadOnlyDictionary<string, object?>, CancellationToken, object?>[parameters.Length];
Expand All @@ -264,7 +295,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
Name,
Description,
serializerOptions,
key.SchemaOptions);
schemaOptions);
}

public string Name { get; }
Expand Down Expand Up @@ -337,6 +368,40 @@ static bool IsAsyncMethod(MethodInfo method)
cancellationToken;
}

// For DI-based parameters, try to resolve from the service provider.
if (parameter.GetCustomAttribute<FromServicesAttribute>(inherit: true) is { } fsAttr)
{
return (arguments, _) =>
{
if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services &&
services.GetService(parameterType) is object service)
{
return service;
}

// No service could be resolved. Return a default value if it's optional, otherwise throw.
return parameter.HasDefaultValue ?
parameter.DefaultValue :
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'.");
};
}
else if (parameter.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is { } fksAttr)
{
return (arguments, _) =>
{
if ((arguments as AIFunctionArguments)?.ServiceProvider is IKeyedServiceProvider services &&
services.GetKeyedService(parameterType, fksAttr.Key) is object service)
{
return service;
}

// No service could be resolved. Return a default value if it's optional, otherwise throw.
return parameter.HasDefaultValue ?
parameter.DefaultValue :
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' with key '{fksAttr.Key}' for parameter '{parameter.Name}'.");
};
}

// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
return (arguments, _) =>
{
Expand All @@ -355,7 +420,6 @@ static bool IsAsyncMethod(MethodInfo method)

object? MarshallViaJsonRoundtrip(object value)
{
#pragma warning disable CA1031 // Do not catch general exception types
try
{
string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType()));
Expand All @@ -366,7 +430,6 @@ static bool IsAsyncMethod(MethodInfo method)
// Eat any exceptions and fall back to the original value to force a cast exception later on.
return value;
}
#pragma warning restore CA1031
}
}

Expand Down Expand Up @@ -479,9 +542,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
#if NET
return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition);
#else
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ public static void AIJsonSchemaCreateOptions_UsesStructuralEquality()
property.SetValue(options2, transformer);
break;

case null when property.PropertyType == typeof(Func<ParameterInfo, bool>):
Func<ParameterInfo, bool> includeParameter = static (parameter) => true;
property.SetValue(options1, includeParameter);
property.SetValue(options2, includeParameter);
break;

default:
Assert.Fail($"Unexpected property type: {property.PropertyType}");
break;
Expand Down
Loading