diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
index 667a956a2f7..55517a7e2b0 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs
@@ -1,12 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using System.Collections.Generic;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
-using Microsoft.Shared.Collections;
namespace Microsoft.Extensions.AI;
@@ -56,19 +54,15 @@ public abstract class AIFunction : AITool
/// The to monitor for cancellation requests. The default is .
/// The result of the function's execution.
public Task InvokeAsync(
- IEnumerable>? arguments = null,
- CancellationToken cancellationToken = default)
- {
- arguments ??= EmptyReadOnlyDictionary.Instance;
-
- return InvokeCoreAsync(arguments, cancellationToken);
- }
+ AIFunctionArguments? arguments = null,
+ CancellationToken cancellationToken = default) =>
+ InvokeCoreAsync(arguments ?? [], cancellationToken);
/// Invokes the and returns its result.
/// The arguments to pass to the function's invocation.
/// The to monitor for cancellation requests.
/// The result of the function's execution.
protected abstract Task InvokeCoreAsync(
- IEnumerable> arguments,
+ AIFunctionArguments arguments,
CancellationToken cancellationToken);
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionArguments.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionArguments.cs
new file mode 100644
index 00000000000..ac77bde4d21
--- /dev/null
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionArguments.cs
@@ -0,0 +1,119 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+
+#pragma warning disable SA1111 // Closing parenthesis should be on line of last parameter
+#pragma warning disable SA1112 // Closing parenthesis should be on line of opening parenthesis
+#pragma warning disable SA1114 // Parameter list should follow declaration
+#pragma warning disable CA1710 // Identifiers should have correct suffix
+
+namespace Microsoft.Extensions.AI;
+
+/// Represents arguments to be used with .
+///
+/// is a dictionary of name/value pairs that are used
+/// as inputs to an . However, an instance carries additional non-nominal
+/// information, such as an optional that can be used by
+/// an if it needs to resolve any services from a dependency injection
+/// container.
+///
+public sealed class AIFunctionArguments : IDictionary, IReadOnlyDictionary
+{
+ /// The nominal arguments.
+ private readonly Dictionary _arguments;
+
+ /// Initializes a new instance of the class.
+ public AIFunctionArguments()
+ {
+ _arguments = [];
+ }
+
+ ///
+ /// Initializes a new instance of the class containing
+ /// the specified .
+ ///
+ /// The arguments represented by this instance.
+ ///
+ /// The reference will be stored if the instance is
+ /// already a , in which case all dictionary
+ /// operations on this instance will be routed directly to that instance. If
+ /// is not a dictionary, a shallow clone of its data will be used to populate this
+ /// instance. A is treated as an
+ /// empty dictionary.
+ ///
+ public AIFunctionArguments(IDictionary? arguments)
+ {
+ _arguments =
+ arguments is null ? [] :
+ arguments as Dictionary ??
+ new Dictionary(arguments);
+ }
+
+ /// Gets or sets services optionally associated with these arguments.
+ public IServiceProvider? Services { get; set; }
+
+ ///
+ public object? this[string key]
+ {
+ get => _arguments[key];
+ set => _arguments[key] = value;
+ }
+
+ ///
+ public ICollection Keys => _arguments.Keys;
+
+ ///
+ public ICollection Values => _arguments.Values;
+
+ ///
+ public int Count => _arguments.Count;
+
+ ///
+ bool ICollection>.IsReadOnly => false;
+
+ ///
+ IEnumerable IReadOnlyDictionary.Keys => Keys;
+
+ ///
+ IEnumerable IReadOnlyDictionary.Values => Values;
+
+ ///
+ public void Add(string key, object? value) => _arguments.Add(key, value);
+
+ ///
+ void ICollection>.Add(KeyValuePair item) =>
+ ((ICollection>)_arguments).Add(item);
+
+ ///
+ public void Clear() => _arguments.Clear();
+
+ ///
+ bool ICollection>.Contains(KeyValuePair item) =>
+ ((ICollection>)_arguments).Contains(item);
+
+ ///
+ public bool ContainsKey(string key) => _arguments.ContainsKey(key);
+
+ ///
+ public void CopyTo(KeyValuePair[] array, int arrayIndex) =>
+ ((ICollection>)_arguments).CopyTo(array, arrayIndex);
+
+ ///
+ public IEnumerator> GetEnumerator() => _arguments.GetEnumerator();
+
+ ///
+ public bool Remove(string key) => _arguments.Remove(key);
+
+ ///
+ bool ICollection>.Remove(KeyValuePair item) =>
+ ((ICollection>)_arguments).Remove(item);
+
+ ///
+ public bool TryGetValue(string key, out object? value) => _arguments.TryGetValue(key, out value);
+
+ ///
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+}
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs
index 53e233d0979..69ff3144582 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonSchemaCreateOptions.cs
@@ -13,7 +13,7 @@ namespace Microsoft.Extensions.AI;
///
/// Provides options for configuring the behavior of JSON schema creation functionality.
///
-public sealed class AIJsonSchemaCreateOptions : IEquatable
+public sealed record class AIJsonSchemaCreateOptions
{
///
/// Gets the default options instance.
@@ -56,26 +56,4 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable
public bool RequireAllProperties { get; init; } = true;
-
- ///
- public bool Equals(AIJsonSchemaCreateOptions? other) =>
- other is not null &&
- TransformSchemaNode == other.TransformSchemaNode &&
- IncludeParameter == other.IncludeParameter &&
- IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
- DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
- IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
- RequireAllProperties == other.RequireAllProperties;
-
- ///
- public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other);
-
- ///
- public override int GetHashCode() =>
- (TransformSchemaNode,
- IncludeParameter,
- IncludeTypeInEnumSchemas,
- DisallowAdditionalProperties,
- IncludeSchemaKeyword,
- RequireAllProperties).GetHashCode();
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs
index f7d1c4bf036..c85d7791cb6 100644
--- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Utilities/AIJsonUtilities.Defaults.cs
@@ -108,6 +108,7 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(Embedding))]
[JsonSerializable(typeof(Embedding))]
[JsonSerializable(typeof(AIContent))]
+ [JsonSerializable(typeof(AIFunctionArguments))]
[EditorBrowsable(EditorBrowsableState.Never)] // Never use JsonContext directly, use DefaultOptions instead.
private sealed partial class JsonContext : JsonSerializerContext;
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs
index 9edcac55c5e..1e27f3a081f 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs
@@ -437,7 +437,7 @@ private sealed class MetadataOnlyAIFunction(string name, string description, Jso
public override string Description => description;
public override JsonElement JsonSchema => schema;
public override IReadOnlyDictionary AdditionalProperties => additionalProps;
- protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken) =>
+ protected override Task InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) =>
throw new InvalidOperationException($"The AI function '{Name}' does not support being invoked.");
}
diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs
index d74505e64f8..7d7c12087b4 100644
--- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIRealtimeExtensions.cs
@@ -52,6 +52,7 @@ public static ConversationFunctionTool ToConversationFunctionTool(this AIFunctio
/// The available tools.
/// An optional flag specifying whether to disclose detailed exception information to the model. The default value is .
/// An optional that controls JSON handling.
+ /// An optional to use for resolving services required by instances being invoked.
/// An optional .
/// A that represents the completion of processing, including invoking any asynchronous tools.
/// is .
@@ -63,6 +64,7 @@ public static async Task HandleToolCallsAsync(
IReadOnlyList tools,
bool? detailedErrors = false,
JsonSerializerOptions? jsonSerializerOptions = null,
+ IServiceProvider? functionInvocationServices = null,
CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(session);
@@ -73,7 +75,7 @@ public static async Task HandleToolCallsAsync(
{
// If we need to call a tool to update the model, do so
if (!string.IsNullOrEmpty(itemFinished.FunctionName)
- && await itemFinished.GetFunctionCallOutputAsync(tools, detailedErrors, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) is { } output)
+ && await itemFinished.GetFunctionCallOutputAsync(tools, detailedErrors, jsonSerializerOptions, functionInvocationServices, cancellationToken).ConfigureAwait(false) is { } output)
{
await session.AddItemAsync(output, cancellationToken).ConfigureAwait(false);
}
@@ -93,6 +95,7 @@ public static async Task HandleToolCallsAsync(
IReadOnlyList tools,
bool? detailedErrors = false,
JsonSerializerOptions? jsonSerializerOptions = null,
+ IServiceProvider? functionInvocationServices = null,
CancellationToken cancellationToken = default)
{
if (!string.IsNullOrEmpty(update.FunctionName)
@@ -107,7 +110,7 @@ public static async Task HandleToolCallsAsync(
try
{
- var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false);
+ var result = await aiFunction.InvokeAsync(new(functionCallContent.Arguments) { Services = functionInvocationServices }, cancellationToken).ConfigureAwait(false);
var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object)));
return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson);
}
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
index 8dca904ccd0..81998ab52ef 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
@@ -24,17 +24,34 @@ public sealed class FunctionInvocationContext
private AIFunction _function = _nopFunction;
/// The function call content information associated with this invocation.
- private FunctionCallContent _callContent = new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary.Instance);
+ private FunctionCallContent? _callContent;
+
+ /// The arguments used with the function.
+ private AIFunctionArguments? _arguments;
/// Initializes a new instance of the class.
public FunctionInvocationContext()
{
}
+ /// Gets or sets the AI function to be invoked.
+ public AIFunction Function
+ {
+ get => _function;
+ set => _function = Throw.IfNull(value);
+ }
+
+ /// Gets or sets the arguments associated with this invocation.
+ public AIFunctionArguments Arguments
+ {
+ get => _arguments ??= [];
+ set => _arguments = Throw.IfNull(value);
+ }
+
/// Gets or sets the function call content information associated with this invocation.
public FunctionCallContent CallContent
{
- get => _callContent;
+ get => _callContent ??= new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary.Instance);
set => _callContent = Throw.IfNull(value);
}
@@ -48,13 +65,6 @@ public IList Messages
/// Gets or sets the chat options associated with the operation that initiated this function call request.
public ChatOptions? Options { get; set; }
- /// Gets or sets the AI function to be invoked.
- public AIFunction Function
- {
- get => _function;
- set => _function = Throw.IfNull(value);
- }
-
/// Gets or sets the number of this iteration with the underlying client.
///
/// The initial request to the client that passes along the chat contents provided to the
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
index cf0d25b3f17..ccc00264344 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
@@ -48,6 +48,9 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
/// The for the current function invocation.
private static readonly AsyncLocal _currentContext = new();
+ /// Optional services used for function invocation.
+ private readonly IServiceProvider? _functionInvocationServices;
+
/// The logger to use for logging information about function invocation.
private readonly ILogger _logger;
@@ -62,12 +65,14 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
/// Initializes a new instance of the class.
///
/// The underlying , or the next instance in a chain of clients.
- /// An to use for logging information about function invocation.
- public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
+ /// An to use for logging information about function invocation.
+ /// An optional to use for resolving services required by the instances being invoked.
+ public FunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null)
: base(innerClient)
{
- _logger = logger ?? NullLogger.Instance;
+ _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
_activitySource = innerClient.GetService();
+ _functionInvocationServices = functionInvocationServices;
}
///
@@ -601,10 +606,13 @@ private async Task ProcessFunctionCallAsync(
FunctionInvocationContext context = new()
{
+ Function = function,
+ Arguments = new(callContent.Arguments) { Services = _functionInvocationServices },
+
Messages = messages,
Options = options,
+
CallContent = callContent,
- Function = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
FunctionCount = callContents.Count,
@@ -710,7 +718,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
startingTimestamp = Stopwatch.GetTimestamp();
if (_logger.IsEnabled(LogLevel.Trace))
{
- LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.CallContent.Arguments, context.Function.JsonSerializerOptions));
+ LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions));
}
else
{
@@ -721,8 +729,8 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
object? result = null;
try
{
- CurrentContext = context;
- result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
+ CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
+ result = await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs
index f2a60718ea9..a4c0c2b589d 100644
--- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs
@@ -33,7 +33,7 @@ public static ChatClientBuilder UseFunctionInvocation(
{
loggerFactory ??= services.GetService();
- var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
+ var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory, services);
configure?.Invoke(chatClient);
return chatClient;
});
diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
index f81ee89fb6d..b6f7e5335a4 100644
--- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
+++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
@@ -6,7 +6,9 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
+#if !NET
using System.Linq;
+#endif
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text.Json;
@@ -17,6 +19,9 @@
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;
+#pragma warning disable CA1031 // Do not catch general exception types
+#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
+
namespace Microsoft.Extensions.AI;
/// Provides factory methods for creating commonly used implementations of .
@@ -188,25 +193,17 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor,
public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method;
public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema;
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;
+
protected override Task InvokeCoreAsync(
- IEnumerable>? arguments,
+ AIFunctionArguments arguments,
CancellationToken cancellationToken)
{
var paramMarshallers = FunctionDescriptor.ParameterMarshallers;
object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : [];
- IReadOnlyDictionary argDict =
- arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance :
- arguments as IReadOnlyDictionary ??
- arguments.
-#if NET8_0_OR_GREATER
- ToDictionary();
-#else
- ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
-#endif
for (int i = 0; i < args.Length; i++)
{
- args[i] = paramMarshallers[i](argDict, cancellationToken);
+ args[i] = paramMarshallers[i](arguments, cancellationToken);
}
return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken);
@@ -248,9 +245,36 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu
private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions)
{
+ // Augment the schema options to exclude AIFunctionArguments/IServiceProvider from the schema,
+ // as it'll be satisfied from AIFunctionArguments.
+ static bool IncludeNonAIFunctionArgumentParameter(ParameterInfo parameterInfo) =>
+ parameterInfo.ParameterType != typeof(AIFunctionArguments) &&
+ parameterInfo.ParameterType != typeof(IServiceProvider);
+
+ AIJsonSchemaCreateOptions schemaOptions;
+ if (key.SchemaOptions.IncludeParameter is not null)
+ {
+ // There's an existing filter, so delegate to it after filtering out IServiceProvider.
+ var existingIncludeParameter = key.SchemaOptions.IncludeParameter;
+ schemaOptions = key.SchemaOptions with
+ {
+ IncludeParameter = parameterInfo =>
+ IncludeNonAIFunctionArgumentParameter(parameterInfo) &&
+ existingIncludeParameter(parameterInfo),
+ };
+ }
+ else
+ {
+ // There's no existing parameter filter, so only exclude IServiceProvider.
+ schemaOptions = key.SchemaOptions with
+ {
+ IncludeParameter = IncludeNonAIFunctionArgumentParameter,
+ };
+ }
+
// Get marshaling delegates for parameters.
ParameterInfo[] parameters = key.Method.GetParameters();
- ParameterMarshallers = new Func, CancellationToken, object?>[parameters.Length];
+ ParameterMarshallers = new Func[parameters.Length];
for (int i = 0; i < parameters.Length; i++)
{
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i]);
@@ -268,7 +292,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
Name,
Description,
serializerOptions,
- key.SchemaOptions);
+ schemaOptions);
}
public string Name { get; }
@@ -276,7 +300,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
public MethodInfo Method { get; }
public JsonSerializerOptions JsonSerializerOptions { get; }
public JsonElement JsonSchema { get; }
- public Func, CancellationToken, object?>[] ParameterMarshallers { get; }
+ public Func[] ParameterMarshallers { get; }
public Func> ReturnParameterMarshaller { get; }
public ReflectionAIFunction? CachedDefaultInstance { get; set; }
@@ -320,7 +344,7 @@ static bool IsAsyncMethod(MethodInfo method)
///
/// Gets a delegate for handling the marshaling of a parameter.
///
- private static Func, CancellationToken, object?> GetParameterMarshaller(
+ private static Func GetParameterMarshaller(
JsonSerializerOptions serializerOptions,
ParameterInfo parameter)
{
@@ -341,6 +365,28 @@ static bool IsAsyncMethod(MethodInfo method)
cancellationToken;
}
+ // For AIFunctionArgument parameters, we always bind to the arguments passed directly to InvokeAsync.
+ if (parameterType == typeof(AIFunctionArguments))
+ {
+ return static (arguments, _) => arguments;
+ }
+
+ // For IServiceProvider parameters, we always bind to the services passed directly to InvokeAsync via AIFunctionArguments.
+ // However, those Services are not required, so we throw if they're not available and are required.
+ if (parameterType == typeof(IServiceProvider))
+ {
+ return (arguments, _) =>
+ {
+ IServiceProvider? services = arguments.Services;
+ if (services is null && !parameter.HasDefaultValue)
+ {
+ Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter.");
+ }
+
+ return services;
+ };
+ }
+
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
return (arguments, _) =>
{
@@ -359,7 +405,6 @@ static bool IsAsyncMethod(MethodInfo method)
object? MarshallViaJsonRoundtrip(object value)
{
-#pragma warning disable CA1031 // Do not catch general exception types
try
{
string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType()));
@@ -370,7 +415,6 @@ static bool IsAsyncMethod(MethodInfo method)
// Eat any exceptions and fall back to the original value to force a cast exception later on.
return value;
}
-#pragma warning restore CA1031
}
}
@@ -482,9 +526,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
#if NET
return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition);
#else
-#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
-#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
#endif
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs
index 103bc884022..be52d069936 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs
@@ -96,7 +96,7 @@ public void ItShouldBeSerializableAndDeserializableWithException()
[Fact]
public async Task AIFunctionFactory_ObjectValues_Converted()
{
- Dictionary arguments = new()
+ AIFunctionArguments arguments = new()
{
["a"] = new DayOfWeek[] { DayOfWeek.Monday, DayOfWeek.Tuesday, DayOfWeek.Wednesday },
["b"] = 123.4M,
@@ -116,7 +116,7 @@ public async Task AIFunctionFactory_ObjectValues_Converted()
[Fact]
public async Task AIFunctionFactory_JsonElementValues_ValuesDeserialized()
{
- Dictionary arguments = JsonSerializer.Deserialize>("""
+ AIFunctionArguments arguments = JsonSerializer.Deserialize("""
{
"a": ["Monday", "Tuesday", "Wednesday"],
"b": 123.4,
@@ -164,7 +164,7 @@ public async Task AIFunctionFactory_JsonDocumentValues_ValuesDeserialized()
""", TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value);
AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, serializerOptions: TestJsonSerializerContext.Default.Options);
- var result = await function.InvokeAsync(arguments);
+ var result = await function.InvokeAsync(new(arguments));
AssertExtensions.EqualFunctionCallResults(123.4, result);
}
@@ -185,14 +185,14 @@ public async Task AIFunctionFactory_JsonNodeValues_ValuesDeserialized()
""", TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value);
AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, serializerOptions: TestJsonSerializerContext.Default.Options);
- var result = await function.InvokeAsync(arguments);
+ var result = await function.InvokeAsync(new(arguments));
AssertExtensions.EqualFunctionCallResults(123.4, result);
}
[Fact]
public async Task TypelessAIFunction_JsonDocumentValues_AcceptsArguments()
{
- var arguments = JsonSerializer.Deserialize>("""
+ AIFunctionArguments arguments = new(JsonSerializer.Deserialize>("""
{
"a": "string",
"b": 123.4,
@@ -201,7 +201,7 @@ public async Task TypelessAIFunction_JsonDocumentValues_AcceptsArguments()
"e": ["Monday", "Tuesday", "Wednesday"],
"f": null
}
- """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value);
+ """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value));
var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments);
Assert.Same(result, arguments);
@@ -210,7 +210,7 @@ public async Task TypelessAIFunction_JsonDocumentValues_AcceptsArguments()
[Fact]
public async Task TypelessAIFunction_JsonElementValues_AcceptsArguments()
{
- Dictionary arguments = JsonSerializer.Deserialize>("""
+ AIFunctionArguments arguments = new(JsonSerializer.Deserialize>("""
{
"a": "string",
"b": 123.4,
@@ -219,7 +219,7 @@ public async Task TypelessAIFunction_JsonElementValues_AcceptsArguments()
"e": ["Monday", "Tuesday", "Wednesday"],
"f": null
}
- """, TestJsonSerializerContext.Default.Options)!;
+ """, TestJsonSerializerContext.Default.Options)!);
var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments);
Assert.Same(result, arguments);
@@ -228,7 +228,7 @@ public async Task TypelessAIFunction_JsonElementValues_AcceptsArguments()
[Fact]
public async Task TypelessAIFunction_JsonNodeValues_AcceptsArguments()
{
- var arguments = JsonSerializer.Deserialize>("""
+ AIFunctionArguments arguments = new(JsonSerializer.Deserialize>("""
{
"a": "string",
"b": 123.4,
@@ -237,7 +237,7 @@ public async Task TypelessAIFunction_JsonNodeValues_AcceptsArguments()
"e": ["Monday", "Tuesday", "Wednesday"],
"f": null
}
- """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value);
+ """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value));
var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments);
Assert.Same(result, arguments);
@@ -251,7 +251,7 @@ private sealed class NetTypelessAIFunction : AIFunction
public override string Name => "NetTypeless";
public override string Description => "AIFunction with parameters that lack .NET types";
- protected override Task InvokeCoreAsync(IEnumerable>? arguments, CancellationToken cancellationToken) =>
+ protected override Task InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) =>
Task.FromResult(arguments);
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionArgumentsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionArgumentsTests.cs
new file mode 100644
index 00000000000..64e8163b4ad
--- /dev/null
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionArgumentsTests.cs
@@ -0,0 +1,171 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.Extensions.DependencyInjection;
+using Xunit;
+
+namespace Microsoft.Extensions.AI;
+
+public class AIFunctionArgumentsTests
+{
+ [Fact]
+ public void NullArg_RoundtripsAsEmpty()
+ {
+ var args = new AIFunctionArguments();
+ Assert.Null(args.Services);
+ Assert.Empty(args);
+
+ args.Add("key", "value");
+ Assert.Single(args);
+ }
+
+ [Fact]
+ public void EmptyArg_RoundtripsAsEmpty()
+ {
+ var args = new AIFunctionArguments(new Dictionary());
+ Assert.Null(args.Services);
+ Assert.Empty(args);
+
+ args.Add("key", "value");
+ Assert.Single(args);
+ }
+
+ [Fact]
+ public void NonEmptyArg_RoundtripsAsEmpty()
+ {
+ var args = new AIFunctionArguments(new Dictionary
+ {
+ ["key"] = "value"
+ });
+ Assert.Null(args.Services);
+ Assert.Single(args);
+ }
+
+ [Fact]
+ public void Services_Roundtrips()
+ {
+ ServiceCollection sc = new();
+ IServiceProvider sp = sc.BuildServiceProvider();
+
+ var args = new AIFunctionArguments
+ {
+ Services = sp
+ };
+
+ Assert.Same(sp, args.Services);
+ Assert.Empty(args);
+
+ args.Add("key", "value");
+ Assert.Single(args);
+ }
+
+ [Fact]
+ public void IReadOnlyDictionary_ImplementsInterface()
+ {
+ ServiceCollection sc = new();
+ IServiceProvider sp = sc.BuildServiceProvider();
+
+ IReadOnlyDictionary args = new AIFunctionArguments(new Dictionary
+ {
+ ["key1"] = "value1",
+ ["key2"] = "value2",
+ });
+
+ Assert.Equal(2, args.Count);
+
+ Assert.True(args.ContainsKey("key1"));
+ Assert.True(args.ContainsKey("key2"));
+ Assert.False(args.ContainsKey("KEY1"));
+
+ Assert.Equal(["key1", "key2"], args.Keys);
+ Assert.Equal(["value1", "value2"], args.Values);
+
+ Assert.Equal("value1", args["key1"]);
+ Assert.Equal("value2", args["key2"]);
+
+ Assert.Equal(new[] { "key1", "key2" }, args.Keys);
+ Assert.Equal(new[] { "value1", "value2" }, args.Values);
+
+ Assert.True(args.TryGetValue("key1", out var value));
+ Assert.Equal("value1", value);
+ Assert.False(args.TryGetValue("key3", out value));
+ Assert.Null(value);
+
+ Assert.Equal([
+ new KeyValuePair("key1", "value1"),
+ new KeyValuePair("key2", "value2"),
+ ], args.ToArray());
+ }
+
+ [Fact]
+ public void IDictionary_ImplementsInterface()
+ {
+ ServiceCollection sc = new();
+ IServiceProvider sp = sc.BuildServiceProvider();
+
+ IDictionary args = new AIFunctionArguments(new Dictionary
+ {
+ ["key1"] = "value1",
+ ["key2"] = "value2",
+ });
+
+ Assert.Equal(2, args.Count);
+ Assert.False(args.IsReadOnly);
+
+ Assert.True(args.ContainsKey("key1"));
+ Assert.True(args.ContainsKey("key2"));
+ Assert.False(args.ContainsKey("KEY1"));
+
+ Assert.Equal("value1", args["key1"]);
+ Assert.Equal("value2", args["key2"]);
+
+ Assert.Equal(new[] { "key1", "key2" }, args.Keys);
+ Assert.Equal(new[] { "value1", "value2" }, args.Values);
+
+ Assert.True(args.TryGetValue("key1", out var value));
+ Assert.Equal("value1", value);
+ Assert.False(args.TryGetValue("key3", out value));
+ Assert.Null(value);
+
+ Assert.Equal([
+ new KeyValuePair("key1", "value1"),
+ new KeyValuePair("key2", "value2"),
+ ], args.ToArray());
+
+ args.Add("key3", "value3");
+ Assert.Equal(3, args.Count);
+ Assert.True(args.ContainsKey("key3"));
+ Assert.Equal("value3", args["key3"]);
+
+ args["key4"] = "value4";
+ Assert.Equal(4, args.Count);
+ Assert.True(args.ContainsKey("key4"));
+ Assert.Equal("value4", args["key4"]);
+
+ args.Remove("key1");
+ Assert.Equal(3, args.Count);
+ Assert.False(args.ContainsKey("key1"));
+ Assert.Equal("value2", args["key2"]);
+ Assert.Equal("value3", args["key3"]);
+ Assert.Equal("value4", args["key4"]);
+
+ args.Clear();
+ Assert.Empty(args);
+
+ args.Add(new KeyValuePair("key1", "value1"));
+ Assert.Single(args);
+ Assert.True(args.ContainsKey("key1"));
+ Assert.Equal("value1", args["key1"]);
+
+ args.Add(new KeyValuePair("key2", "value2"));
+ Assert.Equal(2, args.Count);
+ Assert.True(args.ContainsKey("key2"));
+ Assert.Equal("value2", args["key2"]);
+
+ Assert.Equal(["key1", "key2"], args.Keys);
+ Assert.Equal(["value1", "value2"], args.Values);
+ }
+}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs
index 1ced6ae3185..f084dda4367 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
@@ -10,22 +9,6 @@ namespace Microsoft.Extensions.AI;
public class AIFunctionTests
{
- [Fact]
- public async Task InvokeAsync_UsesDefaultEmptyCollectionForNullArgsAsync()
- {
- DerivedAIFunction f = new();
-
- using CancellationTokenSource cts = new();
- var result1 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!;
-
- Assert.NotNull(result1.Item1);
- Assert.Empty(result1.Item1);
- Assert.Equal(cts.Token, result1.Item2);
-
- var result2 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!;
- Assert.Same(result1.Item1, result2.Item1);
- }
-
[Fact]
public void ToString_ReturnsName()
{
@@ -38,7 +21,7 @@ private sealed class DerivedAIFunction : AIFunction
public override string Name => "name";
public override string Description => "";
- protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken)
+ protected override Task InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken)
{
Assert.NotNull(arguments);
return Task.FromResult((arguments, cancellationToken));
diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs
index 4af54d6cfd9..0362be74821 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs
@@ -19,6 +19,7 @@ namespace Microsoft.Extensions.AI;
[JsonSerializable(typeof(ChatOptions))]
[JsonSerializable(typeof(EmbeddingGenerationOptions))]
[JsonSerializable(typeof(Dictionary))]
+[JsonSerializable(typeof(AIFunctionArguments))]
[JsonSerializable(typeof(int[]))] // Used in ChatMessageContentTests
[JsonSerializable(typeof(Embedding))] // Used in EmbeddingTests
[JsonSerializable(typeof(Dictionary))] // Used in Content tests
diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs
index 3cc42ff0473..977c7608917 100644
--- a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAISerializationTests.cs
@@ -382,7 +382,7 @@ public static async Task RequestDeserialization_ToolCall()
Assert.Equal("The person whose age is being requested", (string)parameterSchema["description"]!);
Assert.Equal("string", (string)parameterSchema["type"]!);
- Dictionary functionArgs = new() { ["personName"] = "John" };
+ AIFunctionArguments functionArgs = new() { ["personName"] = "John" };
var ex = await Assert.ThrowsAsync(() => function.InvokeAsync(functionArgs));
Assert.Contains("does not support being invoked.", ex.Message);
}
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
index 8d069034e15..99fb2a4fdf0 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs
@@ -287,7 +287,7 @@ public async Task FunctionInvocationsLogged(LogLevel level)
};
Func configure = b =>
- b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>()));
+ b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService()));
await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services));
@@ -606,6 +606,32 @@ public async Task PropagatesResponseChatThreadIdToOptions()
Assert.Equal("done!", (await service.GetStreamingResponseAsync("hey", options).ToChatResponseAsync()).ToString());
}
+ [Fact]
+ public async Task FunctionInvocations_PassesServices()
+ {
+ List plan =
+ [
+ new ChatMessage(ChatRole.User, "hello"),
+ new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]),
+ new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
+ new ChatMessage(ChatRole.Assistant, "world"),
+ ];
+
+ ServiceCollection c = new();
+ IServiceProvider expected = c.BuildServiceProvider();
+
+ var options = new ChatOptions
+ {
+ Tools = [AIFunctionFactory.Create((IServiceProvider actual) =>
+ {
+ Assert.Same(expected, actual);
+ return "Result 1";
+ }, "Func1")]
+ };
+
+ await InvokeAndAssertAsync(options, plan, services: expected);
+ }
+
private static async Task> InvokeAndAssertAsync(
ChatOptions options,
List plan,
diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
index dc104ea6be6..9373d66a804 100644
--- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
+++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs
@@ -8,8 +8,11 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
using Xunit;
+#pragma warning disable S107 // Methods should not have too many parameters
+
namespace Microsoft.Extensions.AI;
public class AIFunctionFactoryTest
@@ -30,13 +33,13 @@ public async Task Parameters_MappedByName_Async()
AIFunction func;
func = AIFunctionFactory.Create((string a) => a + " " + a);
- AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync([new KeyValuePair("a", "test")]));
+ AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" }));
func = AIFunctionFactory.Create((string a, string b) => b + " " + a);
- AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync([new KeyValuePair("b", "hello"), new KeyValuePair("a", "world")]));
+ AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" }));
func = AIFunctionFactory.Create((int a, long b) => a + b);
- AssertExtensions.EqualFunctionCallResults(3L, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)]));
+ AssertExtensions.EqualFunctionCallResults(3L, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
}
[Fact]
@@ -44,7 +47,7 @@ public async Task Parameters_DefaultValuesAreUsedButOverridable_Async()
{
AIFunction func = AIFunctionFactory.Create((string a = "test") => a + " " + a);
AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync());
- AssertExtensions.EqualFunctionCallResults("hello hello", await func.InvokeAsync([new KeyValuePair("a", "hello")]));
+ AssertExtensions.EqualFunctionCallResults("hello hello", await func.InvokeAsync(new() { ["a"] = "hello" }));
}
[Fact]
@@ -90,23 +93,23 @@ public async Task Returns_AsyncReturnTypesSupported_Async()
AIFunction func;
func = AIFunctionFactory.Create(Task (string a) => Task.FromResult(a + " " + a));
- AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync([new KeyValuePair("a", "test")]));
+ AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync(new() { ["a"] = "test" }));
func = AIFunctionFactory.Create(ValueTask (string a, string b) => new ValueTask(b + " " + a));
- AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync([new KeyValuePair("b", "hello"), new KeyValuePair("a", "world")]));
+ AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync(new() { ["b"] = "hello", ["a"] = "world" }));
long result = 0;
func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); });
- AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)]));
+ AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
Assert.Equal(3, result);
result = 0;
func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); });
- AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)]));
+ AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new() { ["a"] = 1, ["b"] = 2L }));
Assert.Equal(3, result);
func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count));
- AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync([new("count", 5)]));
+ AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync(new() { ["count"] = 5 }));
static async IAsyncEnumerable SimpleIAsyncEnumerable(int count)
{
@@ -208,7 +211,7 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
Assert.DoesNotContain("firstParameter", func.JsonSchema.ToString());
Assert.Contains("secondParameter", func.JsonSchema.ToString());
- JsonElement? result = (JsonElement?)await func.InvokeAsync(new Dictionary
+ JsonElement? result = (JsonElement?)await func.InvokeAsync(new()
{
["firstParameter"] = "test",
["secondParameter"] = 42
@@ -216,4 +219,75 @@ public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
Assert.NotNull(result);
Assert.Contains("test42", result.ToString());
}
+
+ [Fact]
+ public async Task AIFunctionArguments_SatisfiesParameters()
+ {
+ ServiceCollection sc = new();
+ IServiceProvider sp = sc.BuildServiceProvider();
+
+ AIFunctionArguments arguments = new() { ["myInteger"] = 42 };
+
+ AIFunction func = AIFunctionFactory.Create((
+ int myInteger,
+ IServiceProvider services1,
+ IServiceProvider services2,
+ AIFunctionArguments arguments1,
+ AIFunctionArguments arguments2,
+ IServiceProvider? services3,
+ AIFunctionArguments? arguments3,
+ IServiceProvider? services4 = null,
+ AIFunctionArguments? arguments4 = null) =>
+ {
+ Assert.Same(sp, services1);
+ Assert.Same(sp, services2);
+ Assert.Same(sp, services3);
+ Assert.Same(sp, services4);
+
+ Assert.Same(arguments, arguments1);
+ Assert.Same(arguments, arguments2);
+ Assert.Same(arguments, arguments3);
+ Assert.Same(arguments, arguments4);
+
+ return myInteger;
+ });
+
+ Assert.Contains("myInteger", func.JsonSchema.ToString());
+ Assert.DoesNotContain("services", func.JsonSchema.ToString());
+ Assert.DoesNotContain("arguments", func.JsonSchema.ToString());
+
+ await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(arguments));
+
+ arguments.Services = sp;
+ var result = await func.InvokeAsync(arguments);
+
+ Assert.Contains("42", result?.ToString());
+ }
+
+ [Fact]
+ public async Task AIFunctionArguments_MissingServicesMayBeOptional()
+ {
+ ServiceCollection sc = new();
+ IServiceProvider sp = sc.BuildServiceProvider();
+
+ AIFunction func = AIFunctionFactory.Create((
+ int? myInteger = null,
+ AIFunctionArguments? arguments = null,
+ IServiceProvider? services = null) =>
+ {
+ Assert.NotNull(arguments);
+ Assert.Null(services);
+ return myInteger;
+ });
+
+ Assert.Contains("myInteger", func.JsonSchema.ToString());
+ Assert.DoesNotContain("services", func.JsonSchema.ToString());
+ Assert.DoesNotContain("arguments", func.JsonSchema.ToString());
+
+ var result = await func.InvokeAsync(new() { ["myInteger"] = 42 });
+ Assert.Contains("42", result?.ToString());
+
+ result = await func.InvokeAsync();
+ Assert.Equal("", result?.ToString());
+ }
}