Skip to content

Commit 18a349d

Browse files
committed
Address feedback
1 parent abdcc83 commit 18a349d

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,6 @@ static bool IsAsyncMethod(MethodInfo method)
333333
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
334334
Type parameterType = parameter.ParameterType;
335335
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
336-
FromServiceProviderAttribute? fspAttr = parameter.GetCustomAttribute<FromServiceProviderAttribute>(inherit: true);
337336

338337
// For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync.
339338
if (parameterType == typeof(CancellationToken))
@@ -343,28 +342,40 @@ static bool IsAsyncMethod(MethodInfo method)
343342
cancellationToken;
344343
}
345344

346-
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
347-
return (arguments, _) =>
345+
// For DI-based parameters, try to resolve from the service provider.
346+
if (parameter.GetCustomAttribute<FromServiceProviderAttribute>(inherit: true) is FromServiceProviderAttribute fspAttr)
348347
{
349-
// If the parameter is [FromServiceProvider], try to satisfy it from the service provider
350-
// provided via arguments.
351-
if (fspAttr is not null &&
352-
(arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services)
348+
return (arguments, _) =>
353349
{
354-
if (fspAttr.ServiceKey is object serviceKey)
350+
if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services)
355351
{
356-
if (services is IKeyedServiceProvider ksp &&
357-
ksp.GetKeyedService(parameterType, serviceKey) is object keyedService)
352+
if (fspAttr.ServiceKey is object serviceKey)
353+
{
354+
if ((services as IKeyedServiceProvider)?.GetKeyedService(parameterType, serviceKey) is object keyedService)
355+
{
356+
return keyedService;
357+
}
358+
}
359+
else if (services.GetService(parameterType) is object service)
358360
{
359-
return keyedService;
361+
return service;
360362
}
361363
}
362-
else if (services.GetService(parameterType) is object service)
364+
365+
// No service could be resolved. Does it have a default value?
366+
if (parameter.HasDefaultValue)
363367
{
364-
return service;
368+
return parameter.DefaultValue;
365369
}
366-
}
367370

371+
// It's a required argument, and we couldn't resolve a service. Throw.
372+
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'.");
373+
};
374+
}
375+
376+
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
377+
return (arguments, _) =>
378+
{
368379
// If the parameter has an argument specified in the dictionary, return that argument.
369380
if (arguments.TryGetValue(parameter.Name, out object? value))
370381
{

0 commit comments

Comments
 (0)