1717package org .springframework .ai .model .function ;
1818
1919import java .util .function .BiFunction ;
20+ import java .util .function .Consumer ;
2021import java .util .function .Function ;
22+ import java .util .function .Supplier ;
2123
2224import com .fasterxml .jackson .annotation .JsonClassDescription ;
25+ import kotlin .jvm .functions .Function0 ;
2326import kotlin .jvm .functions .Function1 ;
2427import kotlin .jvm .functions .Function2 ;
2528
3033import org .springframework .context .annotation .Description ;
3134import org .springframework .context .support .GenericApplicationContext ;
3235import org .springframework .core .KotlinDetector ;
36+ import org .springframework .core .ParameterizedTypeReference ;
3337import org .springframework .core .ResolvableType ;
3438import org .springframework .lang .NonNull ;
3539import org .springframework .lang .Nullable ;
3842/**
3943 * A Spring {@link ApplicationContextAware} implementation that provides a way to retrieve
4044 * a {@link Function} from the Spring context and wrap it into a {@link FunctionCallback}.
41- *
45+ * <p>
4246 * The name of the function is determined by the bean name.
43- *
47+ * <p>
4448 * The description of the function is determined by the following rules:
4549 * <ul>
4650 * <li>Provided as a default description</li>
@@ -69,24 +73,28 @@ public void setApplicationContext(@NonNull ApplicationContext applicationContext
6973
7074 @ SuppressWarnings ({ "unchecked" })
7175 public FunctionCallback getFunctionCallback (@ NonNull String beanName , @ Nullable String defaultDescription ) {
72-
7376 ResolvableType functionType = TypeResolverHelper .resolveBeanType (this .applicationContext , beanName );
74- ResolvableType functionInputType = TypeResolverHelper .getFunctionArgumentType (functionType , 0 );
77+ ResolvableType functionInputType = (ResolvableType .forType (Supplier .class ).isAssignableFrom (functionType ))
78+ ? ResolvableType .forType (Void .class ) : TypeResolverHelper .getFunctionArgumentType (functionType , 0 );
79+
80+ String functionDescription = resolveFunctionDescription (beanName , defaultDescription ,
81+ functionInputType .toClass ());
82+ Object bean = this .applicationContext .getBean (beanName );
83+
84+ return buildFunctionCallback (beanName , functionType , functionInputType , functionDescription , bean );
85+ }
7586
76- Class <?> functionInputClass = functionInputType . toClass ();
87+ private String resolveFunctionDescription ( String beanName , String defaultDescription , Class <?> functionInputClass ) {
7788 String functionDescription = defaultDescription ;
7889
7990 if (!StringUtils .hasText (functionDescription )) {
80- // Look for a Description annotation on the bean
8191 Description descriptionAnnotation = this .applicationContext .findAnnotationOnBean (beanName ,
8292 Description .class );
83-
8493 if (descriptionAnnotation != null ) {
8594 functionDescription = descriptionAnnotation .value ();
8695 }
8796
8897 if (!StringUtils .hasText (functionDescription )) {
89- // Look for a JsonClassDescription annotation on the input class
9098 JsonClassDescription jsonClassDescriptionAnnotation = functionInputClass
9199 .getAnnotation (JsonClassDescription .class );
92100 if (jsonClassDescriptionAnnotation != null ) {
@@ -95,51 +103,79 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
95103 }
96104
97105 if (!StringUtils .hasText (functionDescription )) {
98- throw new IllegalStateException ("Could not determine function description."
106+ throw new IllegalStateException ("Could not determine function description. "
99107 + "Please provide a description either as a default parameter, via @Description annotation on the bean "
100108 + "or @JsonClassDescription annotation on the input class." );
101109 }
102110 }
103111
104- Object bean = this .applicationContext .getBean (beanName );
112+ return functionDescription ;
113+ }
114+
115+ private FunctionCallback buildFunctionCallback (String beanName , ResolvableType functionType ,
116+ ResolvableType functionInputType , String functionDescription , Object bean ) {
105117
106118 if (KotlinDetector .isKotlinPresent ()) {
107119 if (KotlinDelegate .isKotlinFunction (functionType .toClass ())) {
108120 return FunctionCallback .builder ()
109121 .schemaType (this .schemaType )
110122 .description (functionDescription )
111123 .function (beanName , KotlinDelegate .wrapKotlinFunction (bean ))
112- .inputType (functionInputClass )
124+ .inputType (ParameterizedTypeReference . forType ( functionInputType . getType ()) )
113125 .build ();
114126 }
115- else if (KotlinDelegate .isKotlinBiFunction (functionType .toClass ())) {
127+ if (KotlinDelegate .isKotlinBiFunction (functionType .toClass ())) {
116128 return FunctionCallback .builder ()
117129 .description (functionDescription )
118130 .schemaType (this .schemaType )
119131 .function (beanName , KotlinDelegate .wrapKotlinBiFunction (bean ))
120- .inputType (functionInputClass )
132+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
133+ .build ();
134+ }
135+ if (KotlinDelegate .isKotlinSupplier (functionType .toClass ())) {
136+ return FunctionCallback .builder ()
137+ .description (functionDescription )
138+ .schemaType (this .schemaType )
139+ .function (beanName , KotlinDelegate .wrapKotlinSupplier (bean ))
140+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
121141 .build ();
122142 }
123143 }
144+
124145 if (bean instanceof Function <?, ?> function ) {
125146 return FunctionCallback .builder ()
126147 .schemaType (this .schemaType )
127148 .description (functionDescription )
128149 .function (beanName , function )
129- .inputType (functionInputClass )
150+ .inputType (ParameterizedTypeReference . forType ( functionInputType . getType ()) )
130151 .build ();
131152 }
132- else if (bean instanceof BiFunction <?, ?, ?>) {
153+ if (bean instanceof BiFunction <?, ?, ?>) {
133154 return FunctionCallback .builder ()
134155 .description (functionDescription )
135156 .schemaType (this .schemaType )
136157 .function (beanName , (BiFunction <?, ToolContext , ?>) bean )
137- .inputType (functionInputClass )
158+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
159+ .build ();
160+ }
161+ if (bean instanceof Supplier <?> supplier ) {
162+ return FunctionCallback .builder ()
163+ .description (functionDescription )
164+ .schemaType (this .schemaType )
165+ .function (beanName , supplier )
166+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
138167 .build ();
139168 }
140- else {
141- throw new IllegalStateException ();
169+ if (bean instanceof Consumer <?> consumer ) {
170+ return FunctionCallback .builder ()
171+ .description (functionDescription )
172+ .schemaType (this .schemaType )
173+ .function (beanName , consumer )
174+ .inputType (ParameterizedTypeReference .forType (functionInputType .getType ()))
175+ .build ();
142176 }
177+
178+ throw new IllegalStateException ("Unsupported function type" );
143179 }
144180
145181 public enum SchemaType {
@@ -148,7 +184,16 @@ public enum SchemaType {
148184
149185 }
150186
151- private static class KotlinDelegate {
187+ private static final class KotlinDelegate {
188+
189+ public static boolean isKotlinSupplier (Class <?> clazz ) {
190+ return Function0 .class .isAssignableFrom (clazz );
191+ }
192+
193+ @ SuppressWarnings ("unchecked" )
194+ public static Supplier <?> wrapKotlinSupplier (Object function ) {
195+ return () -> ((Function0 <Object >) function ).invoke ();
196+ }
152197
153198 public static boolean isKotlinFunction (Class <?> clazz ) {
154199 return Function1 .class .isAssignableFrom (clazz );
0 commit comments