diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 13758f1a751..27695a42112 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -273,6 +273,10 @@ Prompt buildRequestPrompt(Prompt prompt) { if (prompt.getOptions() instanceof BedrockChatOptions bedrockChatOptions) { runtimeOptions = bedrockChatOptions.copy(); } + else if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + BedrockChatOptions.class); + } else { runtimeOptions = from(prompt.getOptions()); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java index bd3e07c1d77..2b2361cba03 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java @@ -46,6 +46,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -279,6 +280,29 @@ void functionCallTest() { assertThat(generation.getOutput().getText()).contains("30", "10", "15"); } + @Test + void functionCallTestWithToolCallingOptions() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = ToolCallingChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location. Return in 36°C format") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + Generation generation = response.getResult(); + assertThat(generation.getOutput().getText()).contains("30", "10", "15"); + } + @Test void streamFunctionCallTest() {