Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Collections;

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -56,19 +54,15 @@ public abstract class AIFunction : AITool
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The result of the function's execution.</returns>
public Task<object?> InvokeAsync(
IEnumerable<KeyValuePair<string, object?>>? arguments = null,
CancellationToken cancellationToken = default)
{
arguments ??= EmptyReadOnlyDictionary<string, object?>.Instance;

return InvokeCoreAsync(arguments, cancellationToken);
}
AIFunctionArguments? arguments = null,
CancellationToken cancellationToken = default) =>
InvokeCoreAsync(arguments ?? [], cancellationToken);

/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
/// <returns>The result of the function's execution.</returns>
protected abstract Task<object?> InvokeCoreAsync(
IEnumerable<KeyValuePair<string, object?>> arguments,
AIFunctionArguments arguments,
CancellationToken cancellationToken);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;

#pragma warning disable SA1111 // Closing parenthesis should be on line of last parameter
#pragma warning disable SA1112 // Closing parenthesis should be on line of opening parenthesis
#pragma warning disable SA1114 // Parameter list should follow declaration
#pragma warning disable CA1710 // Identifiers should have correct suffix

namespace Microsoft.Extensions.AI;

/// <summary>Represents arguments to be used with <see cref="AIFunction.InvokeAsync"/>.</summary>
/// <remarks>
/// <see cref="AIFunctionArguments"/> is a dictionary of name/value pairs that are used
/// as inputs to an <see cref="AIFunction"/>. However, an instance carries additional non-nominal
/// information, such as an optional <see cref="IServiceProvider"/> that can be used by
/// an <see cref="AIFunction"/> if it needs to resolve any services from a dependency injection
/// container.
/// </remarks>
public sealed class AIFunctionArguments : IDictionary<string, object?>, IReadOnlyDictionary<string, object?>
{
/// <summary>The nominal arguments.</summary>
private readonly Dictionary<string, object?> _arguments;

/// <summary>Initializes a new instance of the <see cref="AIFunctionArguments"/> class.</summary>
public AIFunctionArguments()
{
_arguments = [];
}

/// <summary>
/// Initializes a new instance of the <see cref="AIFunctionArguments"/> class containing
/// the specified <paramref name="arguments"/>.
/// </summary>
/// <param name="arguments">The arguments represented by this instance.</param>
/// <remarks>
/// The <paramref name="arguments"/> reference will be stored if the instance is
/// already a <see cref="Dictionary{TKey, TValue}"/>, in which case all dictionary
/// operations on this instance will be routed directly to that instance. If <paramref name="arguments"/>
/// is not a dictionary, a shallow clone of its data will be used to populate this
/// instance. A <see langword="null"/> <paramref name="arguments"/> is treated as an
/// empty dictionary.
/// </remarks>
public AIFunctionArguments(IDictionary<string, object?>? arguments)
{
_arguments =
arguments is null ? [] :
arguments as Dictionary<string, object?> ??
new Dictionary<string, object?>(arguments);
}

/// <summary>Gets or sets services optionally associated with these arguments.</summary>
public IServiceProvider? Services { get; set; }

/// <inheritdoc />
public object? this[string key]
{
get => _arguments[key];
set => _arguments[key] = value;
}

/// <inheritdoc />
public ICollection<string> Keys => _arguments.Keys;

/// <inheritdoc />
public ICollection<object?> Values => _arguments.Values;

/// <inheritdoc />
public int Count => _arguments.Count;

/// <inheritdoc />
bool ICollection<KeyValuePair<string, object?>>.IsReadOnly => false;

/// <inheritdoc />
IEnumerable<string> IReadOnlyDictionary<string, object?>.Keys => Keys;

/// <inheritdoc />
IEnumerable<object?> IReadOnlyDictionary<string, object?>.Values => Values;

/// <inheritdoc />
public void Add(string key, object? value) => _arguments.Add(key, value);

/// <inheritdoc />
void ICollection<KeyValuePair<string, object?>>.Add(KeyValuePair<string, object?> item) =>
((ICollection<KeyValuePair<string, object?>>)_arguments).Add(item);

/// <inheritdoc />
public void Clear() => _arguments.Clear();

/// <inheritdoc />
bool ICollection<KeyValuePair<string, object?>>.Contains(KeyValuePair<string, object?> item) =>
((ICollection<KeyValuePair<string, object?>>)_arguments).Contains(item);

/// <inheritdoc />
public bool ContainsKey(string key) => _arguments.ContainsKey(key);

/// <inheritdoc />
public void CopyTo(KeyValuePair<string, object?>[] array, int arrayIndex) =>
((ICollection<KeyValuePair<string, object?>>)_arguments).CopyTo(array, arrayIndex);

/// <inheritdoc />
public IEnumerator<KeyValuePair<string, object?>> GetEnumerator() => _arguments.GetEnumerator();

/// <inheritdoc />
public bool Remove(string key) => _arguments.Remove(key);

/// <inheritdoc />
bool ICollection<KeyValuePair<string, object?>>.Remove(KeyValuePair<string, object?> item) =>
((ICollection<KeyValuePair<string, object?>>)_arguments).Remove(item);

/// <inheritdoc />
public bool TryGetValue(string key, out object? value) => _arguments.TryGetValue(key, out value);

/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Microsoft.Extensions.AI;
/// <summary>
/// Provides options for configuring the behavior of <see cref="AIJsonUtilities"/> JSON schema creation functionality.
/// </summary>
public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOptions>
public sealed record class AIJsonSchemaCreateOptions
{
/// <summary>
/// Gets the default options instance.
Expand Down Expand Up @@ -56,26 +56,4 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
/// Gets a value indicating whether to mark all properties as required in the schema.
/// </summary>
public bool RequireAllProperties { get; init; } = true;

/// <inheritdoc/>
public bool Equals(AIJsonSchemaCreateOptions? other) =>
other is not null &&
TransformSchemaNode == other.TransformSchemaNode &&
IncludeParameter == other.IncludeParameter &&
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
RequireAllProperties == other.RequireAllProperties;

/// <inheritdoc />
public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other);

/// <inheritdoc />
public override int GetHashCode() =>
(TransformSchemaNode,
IncludeParameter,
IncludeTypeInEnumSchemas,
DisallowAdditionalProperties,
IncludeSchemaKeyword,
RequireAllProperties).GetHashCode();
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(Embedding<float>))]
[JsonSerializable(typeof(Embedding<double>))]
[JsonSerializable(typeof(AIContent))]
[JsonSerializable(typeof(AIFunctionArguments))]
[EditorBrowsable(EditorBrowsableState.Never)] // Never use JsonContext directly, use DefaultOptions instead.
private sealed partial class JsonContext : JsonSerializerContext;
}
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ private sealed class MetadataOnlyAIFunction(string name, string description, Jso
public override string Description => description;
public override JsonElement JsonSchema => schema;
public override IReadOnlyDictionary<string, object?> AdditionalProperties => additionalProps;
protected override Task<object?> InvokeCoreAsync(IEnumerable<KeyValuePair<string, object?>> arguments, CancellationToken cancellationToken) =>
protected override Task<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) =>
throw new InvalidOperationException($"The AI function '{Name}' does not support being invoked.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public static ConversationFunctionTool ToConversationFunctionTool(this AIFunctio
/// <param name="tools">The available tools.</param>
/// <param name="detailedErrors">An optional flag specifying whether to disclose detailed exception information to the model. The default value is <see langword="false"/>.</param>
/// <param name="jsonSerializerOptions">An optional <see cref="JsonSerializerOptions"/> that controls JSON handling.</param>
/// <param name="functionInvocationServices">An optional <see cref="IServiceProvider"/> to use for resolving services required by <see cref="AIFunction"/> instances being invoked.</param>
/// <param name="cancellationToken">An optional <see cref="CancellationToken"/>.</param>
/// <returns>A <see cref="Task"/> that represents the completion of processing, including invoking any asynchronous tools.</returns>
/// <exception cref="ArgumentNullException"><paramref name="session"/> is <see langword="null"/>.</exception>
Expand All @@ -63,6 +64,7 @@ public static async Task HandleToolCallsAsync(
IReadOnlyList<AIFunction> tools,
bool? detailedErrors = false,
JsonSerializerOptions? jsonSerializerOptions = null,
IServiceProvider? functionInvocationServices = null,
CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(session);
Expand All @@ -73,7 +75,7 @@ public static async Task HandleToolCallsAsync(
{
// If we need to call a tool to update the model, do so
if (!string.IsNullOrEmpty(itemFinished.FunctionName)
&& await itemFinished.GetFunctionCallOutputAsync(tools, detailedErrors, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) is { } output)
&& await itemFinished.GetFunctionCallOutputAsync(tools, detailedErrors, jsonSerializerOptions, functionInvocationServices, cancellationToken).ConfigureAwait(false) is { } output)
{
await session.AddItemAsync(output, cancellationToken).ConfigureAwait(false);
}
Expand All @@ -93,6 +95,7 @@ public static async Task HandleToolCallsAsync(
IReadOnlyList<AIFunction> tools,
bool? detailedErrors = false,
JsonSerializerOptions? jsonSerializerOptions = null,
IServiceProvider? functionInvocationServices = null,
CancellationToken cancellationToken = default)
{
if (!string.IsNullOrEmpty(update.FunctionName)
Expand All @@ -107,7 +110,7 @@ public static async Task HandleToolCallsAsync(

try
{
var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false);
var result = await aiFunction.InvokeAsync(new(functionCallContent.Arguments) { Services = functionInvocationServices }, cancellationToken).ConfigureAwait(false);
var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object)));
return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,34 @@ public sealed class FunctionInvocationContext
private AIFunction _function = _nopFunction;

/// <summary>The function call content information associated with this invocation.</summary>
private FunctionCallContent _callContent = new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary<string, object?>.Instance);
private FunctionCallContent? _callContent;

/// <summary>The arguments used with the function.</summary>
private AIFunctionArguments? _arguments;

/// <summary>Initializes a new instance of the <see cref="FunctionInvocationContext"/> class.</summary>
public FunctionInvocationContext()
{
}

/// <summary>Gets or sets the AI function to be invoked.</summary>
public AIFunction Function
{
get => _function;
set => _function = Throw.IfNull(value);
}

/// <summary>Gets or sets the arguments associated with this invocation.</summary>
public AIFunctionArguments Arguments
{
get => _arguments ??= [];
set => _arguments = Throw.IfNull(value);
}

/// <summary>Gets or sets the function call content information associated with this invocation.</summary>
public FunctionCallContent CallContent
{
get => _callContent;
get => _callContent ??= new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary<string, object?>.Instance);
set => _callContent = Throw.IfNull(value);
}

Expand All @@ -48,13 +65,6 @@ public IList<ChatMessage> Messages
/// <summary>Gets or sets the chat options associated with the operation that initiated this function call request.</summary>
public ChatOptions? Options { get; set; }

/// <summary>Gets or sets the AI function to be invoked.</summary>
public AIFunction Function
{
get => _function;
set => _function = Throw.IfNull(value);
}

/// <summary>Gets or sets the number of this iteration with the underlying client.</summary>
/// <remarks>
/// The initial request to the client that passes along the chat contents provided to the <see cref="FunctionInvokingChatClient"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
/// <summary>The <see cref="FunctionInvocationContext"/> for the current function invocation.</summary>
private static readonly AsyncLocal<FunctionInvocationContext?> _currentContext = new();

/// <summary>Optional services used for function invocation.</summary>
private readonly IServiceProvider? _functionInvocationServices;

/// <summary>The logger to use for logging information about function invocation.</summary>
private readonly ILogger _logger;

Expand All @@ -62,12 +65,14 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
/// Initializes a new instance of the <see cref="FunctionInvokingChatClient"/> class.
/// </summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
/// <param name="logger">An <see cref="ILogger"/> to use for logging information about function invocation.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
/// <param name="loggerFactory">An <see cref="ILoggerFactory"/> to use for logging information about function invocation.</param>
/// <param name="functionInvocationServices">An optional <see cref="IServiceProvider"/> to use for resolving services required by the <see cref="AIFunction"/> instances being invoked.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null)
: base(innerClient)
{
_logger = logger ?? NullLogger.Instance;
_logger = (ILogger?)loggerFactory?.CreateLogger<FunctionInvokingChatClient>() ?? NullLogger.Instance;
_activitySource = innerClient.GetService<ActivitySource>();
_functionInvocationServices = functionInvocationServices;
}

/// <summary>
Expand Down Expand Up @@ -601,10 +606,13 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(

FunctionInvocationContext context = new()
{
Function = function,
Arguments = new(callContent.Arguments) { Services = _functionInvocationServices },

Messages = messages,
Options = options,

CallContent = callContent,
Function = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
FunctionCount = callContents.Count,
Expand Down Expand Up @@ -710,7 +718,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
startingTimestamp = Stopwatch.GetTimestamp();
if (_logger.IsEnabled(LogLevel.Trace))
{
LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.CallContent.Arguments, context.Function.JsonSerializerOptions));
LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions));
}
else
{
Expand All @@ -721,8 +729,8 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
object? result = null;
try
{
CurrentContext = context;
result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
result = await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public static ChatClientBuilder UseFunctionInvocation(
{
loggerFactory ??= services.GetService<ILoggerFactory>();

var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory, services);
configure?.Invoke(chatClient);
return chatClient;
});
Expand Down
Loading
Loading