diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
index a519af483b2..4b72db7d579 100644
--- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
+++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java
@@ -617,16 +617,16 @@ private List<AnthropicMessage> buildMessages(Prompt prompt, CacheEligibilityResolver
 			List<ContentBlock> contentBlocks = new ArrayList<>();
 			String content = message.getText();
 			// For conversation history caching, apply cache control to the
-			// message immediately before the last user message.
-			boolean isPenultimateUserMessage = (lastUserIndex > 0) && (i == lastUserIndex - 1);
+			// last user message to cache the entire conversation up to that point.
+			boolean isLastUserMessage = (lastUserIndex >= 0) && (i == lastUserIndex);
 			ContentBlock contentBlock = new ContentBlock(content);
-			if (isPenultimateUserMessage && cacheEligibilityResolver.isCachingEnabled()) {
-				// Combine text from all user messages except the last one (current
-				// question)
-				// as the basis for cache eligibility checks
-				String combinedUserMessagesText = combineEligibleUserMessagesText(allMessages, lastUserIndex);
+			if (isLastUserMessage && cacheEligibilityResolver.isCachingEnabled()) {
+				// Combine text from all messages (user, assistant, tool) up to and
+				// including the last user message as the basis for cache eligibility
+				// checks
+				String combinedMessagesText = combineEligibleMessagesText(allMessages, lastUserIndex);
 				contentBlocks.add(cacheAwareContentBlock(contentBlock, messageType, cacheEligibilityResolver,
-						combinedUserMessagesText));
+						combinedMessagesText));
 			}
 			else {
 				contentBlocks.add(contentBlock);
@@ -675,19 +675,21 @@ else if (messageType == MessageType.TOOL) {
 		return result;
 	}
 
-	private String combineEligibleUserMessagesText(List<Message> userMessages, int lastUserIndex) {
-		List<Message> userMessagesForEligibility = new ArrayList<>();
 		// Only 20 content blocks are considered by anthropic, so limit the number of
+	private String combineEligibleMessagesText(List<Message> allMessages, int lastUserIndex) {
-		// message content to consider
-		int startIndex = Math.max(0, lastUserIndex - 20);
-		for (int i = startIndex; i < lastUserIndex; i++) {
-			Message message = userMessages.get(i);
-			if (message.getMessageType() == MessageType.USER) {
-				userMessagesForEligibility.add(message);
+		// message content to consider. We include all message types (user,
+		// assistant, tool) up to and including the last user message for
+		// aggregate eligibility checking.
+		int startIndex = Math.max(0, lastUserIndex - 19);
+		int endIndex = Math.min(allMessages.size(), lastUserIndex + 1);
+		StringBuilder sb = new StringBuilder();
+		for (int i = startIndex; i < endIndex; i++) {
+			Message message = allMessages.get(i);
+			String text = message.getText();
+			if (StringUtils.hasText(text)) {
+				sb.append(text);
 			}
 		}
-		StringBuilder sb = new StringBuilder();
-		userMessagesForEligibility.stream().map(Message::getText).filter(StringUtils::hasText).forEach(sb::append);
 		return sb.toString();
 	}
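
For reviewers, here is a standalone sketch of the aggregate eligibility window introduced above. It uses plain strings instead of Spring AI `Message` objects, and the final threshold comparison is an assumed simplification of what `CacheEligibilityResolver` does with the combined text (the resolver itself is not part of this diff):

[source,java]
----
import java.util.List;

// Standalone sketch of the aggregate eligibility window: combine at most 20
// entries ending at (and including) the last user message, then gate caching
// on the combined length. The window bounds mirror the patch; the threshold
// comparison is an assumed simplification of CacheEligibilityResolver.
class AggregateEligibilitySketch {

	static String combineWindow(List<String> texts, int lastUserIndex) {
		int start = Math.max(0, lastUserIndex - 19); // at most 20 entries
		int end = Math.min(texts.size(), lastUserIndex + 1); // include the last user message
		StringBuilder sb = new StringBuilder();
		for (int i = start; i < end; i++) {
			String text = texts.get(i);
			if (text != null && !text.isBlank()) {
				sb.append(text);
			}
		}
		return sb.toString();
	}

	public static void main(String[] args) {
		List<String> history = List.of("long user question ...", "assistant answer ...", "follow-up question");
		String combined = combineWindow(history, 2);
		int minContentLength = 10; // stand-in for messageTypeMinContentLength(USER, ...)
		System.out.println("eligible=" + (combined.length() >= minContentLength));
	}

}
----
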
diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingIT.java
index 3ed96e5a17d..0adf989c284 100644
--- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingIT.java
+++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingIT.java
@@ -280,17 +280,20 @@ void shouldRespectMinLengthForSystemCaching() {
 
 	@Test
 	void shouldRespectMinLengthForUserHistoryCaching() {
-		// Two-user-message prompt; only the first (history tail) is eligible.
+		// Two-user-message prompt; aggregate length check applies
 		String userMessage = loadPrompt("system-only-cache-prompt.txt");
-		List<Message> messages = List.of(new UserMessage(userMessage),
-				new UserMessage("Please answer this question succinctly"));
+		String secondUserMessage = "Please answer this question succinctly";
+		List<Message> messages = List.of(new UserMessage(userMessage), new UserMessage(secondUserMessage));
+
+		// Calculate combined length of both messages for aggregate checking
+		int combinedLength = userMessage.length() + secondUserMessage.length();
 
-		// Set USER min length high so caching should not apply
+		// Set USER min length higher than combined length so caching should not apply
 		AnthropicChatOptions noCacheOptions = AnthropicChatOptions.builder()
 			.model(AnthropicApi.ChatModel.CLAUDE_SONNET_4_0.getValue())
 			.cacheOptions(AnthropicCacheOptions.builder()
 				.strategy(AnthropicCacheStrategy.CONVERSATION_HISTORY)
-				.messageTypeMinContentLength(MessageType.USER, userMessage.length() + 1)
+				.messageTypeMinContentLength(MessageType.USER, combinedLength + 1)
 				.build())
 			.maxTokens(80)
 			.temperature(0.2)
@@ -303,12 +306,12 @@ void shouldRespectMinLengthForUserHistoryCaching() {
 		assertThat(noCacheUsage.cacheCreationInputTokens()).isEqualTo(0);
 		assertThat(noCacheUsage.cacheReadInputTokens()).isEqualTo(0);
 
-		// Now allow caching by lowering the USER min length
+		// Now allow caching by lowering the USER min length below combined length
 		AnthropicChatOptions cacheOptions = AnthropicChatOptions.builder()
 			.model(AnthropicApi.ChatModel.CLAUDE_SONNET_4_0.getValue())
 			.cacheOptions(AnthropicCacheOptions.builder()
 				.strategy(AnthropicCacheStrategy.CONVERSATION_HISTORY)
-				.messageTypeMinContentLength(MessageType.USER, userMessage.length() - 1)
+				.messageTypeMinContentLength(MessageType.USER, combinedLength - 1)
 				.build())
 			.maxTokens(80)
 			.temperature(0.2)
@@ -319,20 +322,20 @@ void shouldRespectMinLengthForUserHistoryCaching() {
 		AnthropicApi.Usage cacheUsage = getAnthropicUsage(cacheResponse);
 		assertThat(cacheUsage).isNotNull();
 		assertThat(cacheUsage.cacheCreationInputTokens())
-			.as("Expect some cache creation tokens when USER history tail is cached")
+			.as("Expect some cache creation tokens when aggregate content meets min length")
 			.isGreaterThan(0);
 	}
 
 	@Test
-	void shouldRespectAllButLastUserMessageForUserHistoryCaching() {
-		// Three-user-message prompt; only the first (history tail) is eligible.
+	void shouldApplyCacheControlToLastUserMessageForConversationHistory() {
+		// Three-user-message prompt; the last user message will have cache_control.
 		String userMessage = loadPrompt("system-only-cache-prompt.txt");
 		List<Message> messages = List.of(new UserMessage(userMessage),
 				new UserMessage("Additional content to exceed min length"),
 				new UserMessage("Please answer this question succinctly"));
 
-		// The combined length of the first two USER messages exceeds the min length,
-		// so caching should apply
+		// The combined length of all three USER messages (including the last) exceeds
+		// the min length, so caching should apply
 		AnthropicChatOptions cacheOptions = AnthropicChatOptions.builder()
 			.model(AnthropicApi.ChatModel.CLAUDE_SONNET_4_0.getValue())
 			.cacheOptions(AnthropicCacheOptions.builder()
@@ -450,4 +453,163 @@ void shouldHandleMultipleCacheStrategiesInSession() {
 		}
 	}
 
+	@Test
+	void shouldDemonstrateIncrementalCachingAcrossMultipleTurns() {
+		// This test demonstrates how caching grows incrementally with each turn.
+		// NOTE: Anthropic requires 1024+ tokens for caching to activate, so we
+		// use a large system message to ensure we cross this threshold.
+
+		// Large system prompt to ensure we exceed the 1024 token minimum for caching
+		String largeSystemPrompt = loadPrompt("system-only-cache-prompt.txt");
+
+		AnthropicChatOptions options = AnthropicChatOptions.builder()
+			.model(AnthropicApi.ChatModel.CLAUDE_SONNET_4_0.getValue())
+			.cacheOptions(AnthropicCacheOptions.builder()
+				.strategy(AnthropicCacheStrategy.CONVERSATION_HISTORY)
+				// Disable min content length since we're using the aggregate check
+				.messageTypeMinContentLength(MessageType.USER, 0)
+				.build())
+			.maxTokens(200)
+			.temperature(0.3)
+			.build();
+
+		List<Message> conversationHistory = new ArrayList<>();
+		// Add the system message to provide enough tokens for caching
+		conversationHistory.add(new SystemMessage(largeSystemPrompt));
+
+		// Turn 1: Initial question
+		logger.info("\n=== TURN 1: Initial Question ===");
+		conversationHistory.add(new UserMessage("What is quantum computing? Please explain the basics."));
+
+		ChatResponse turn1 = this.chatModel.call(new Prompt(conversationHistory, options));
+		assertThat(turn1).isNotNull();
+		String assistant1Response = turn1.getResult().getOutput().getText();
+		conversationHistory.add(turn1.getResult().getOutput());
+
+		AnthropicApi.Usage usage1 = getAnthropicUsage(turn1);
+		assertThat(usage1).isNotNull();
+		logger.info("Turn 1 - User: '{}'", conversationHistory.get(1).getText().substring(0, 50) + "...");
+		logger.info("Turn 1 - Assistant: '{}'",
+				assistant1Response.substring(0, Math.min(100, assistant1Response.length())) + "...");
+		logger.info("Turn 1 - Input tokens: {}", usage1.inputTokens());
+		logger.info("Turn 1 - Cache creation tokens: {}", usage1.cacheCreationInputTokens());
+		logger.info("Turn 1 - Cache read tokens: {}", usage1.cacheReadInputTokens());
+
+		// Note: The first turn may not create a cache if total tokens < 1024
+		// (Anthropic's minimum). We track whether caching starts in turn 1 or later.
+		boolean cachingStarted = usage1.cacheCreationInputTokens() > 0;
+		logger.info("Turn 1 - Caching started: {}", cachingStarted);
+		assertThat(usage1.cacheReadInputTokens()).as("Turn 1 should not read cache (no previous cache)").isEqualTo(0);
+
+		// Turn 2: Follow-up question
+		logger.info("\n=== TURN 2: Follow-up Question ===");
+		conversationHistory.add(new UserMessage("How does quantum entanglement work in this context?"));
+
+		ChatResponse turn2 = this.chatModel.call(new Prompt(conversationHistory, options));
+		assertThat(turn2).isNotNull();
+		String assistant2Response = turn2.getResult().getOutput().getText();
+		conversationHistory.add(turn2.getResult().getOutput());
+
+		AnthropicApi.Usage usage2 = getAnthropicUsage(turn2);
+		assertThat(usage2).isNotNull();
+		logger.info("Turn 2 - User: '{}'", conversationHistory.get(3).getText());
+		logger.info("Turn 2 - Assistant: '{}'",
+				assistant2Response.substring(0, Math.min(100, assistant2Response.length())) + "...");
+		logger.info("Turn 2 - Input tokens: {}", usage2.inputTokens());
+		logger.info("Turn 2 - Cache creation tokens: {}", usage2.cacheCreationInputTokens());
+		logger.info("Turn 2 - Cache read tokens: {}", usage2.cacheReadInputTokens());
+
+		// Second turn: if caching started in turn 1, we should see cache reads;
+		// otherwise caching might start here once enough tokens have accumulated
+		if (cachingStarted) {
+			assertThat(usage2.cacheReadInputTokens()).as("Turn 2 should read cache from Turn 1").isGreaterThan(0);
+		}
+		// Update caching status
+		cachingStarted = cachingStarted || usage2.cacheCreationInputTokens() > 0;
+
+		// Turn 3: Another follow-up
+		logger.info("\n=== TURN 3: Deeper Question ===");
+		conversationHistory
+			.add(new UserMessage("Can you give me a practical example of a quantum computing application?"));
+
+		ChatResponse turn3 = this.chatModel.call(new Prompt(conversationHistory, options));
+		assertThat(turn3).isNotNull();
+		String assistant3Response = turn3.getResult().getOutput().getText();
+		conversationHistory.add(turn3.getResult().getOutput());
+
+		AnthropicApi.Usage usage3 = getAnthropicUsage(turn3);
+		assertThat(usage3).isNotNull();
+		logger.info("Turn 3 - User: '{}'", conversationHistory.get(5).getText());
+		logger.info("Turn 3 - Assistant: '{}'",
+				assistant3Response.substring(0, Math.min(100, assistant3Response.length())) + "...");
+		logger.info("Turn 3 - Input tokens: {}", usage3.inputTokens());
+		logger.info("Turn 3 - Cache creation tokens: {}", usage3.cacheCreationInputTokens());
+		logger.info("Turn 3 - Cache read tokens: {}", usage3.cacheReadInputTokens());
+
+		// Third turn: should read cache if caching has started
+		if (cachingStarted) {
+			assertThat(usage3.cacheReadInputTokens()).as("Turn 3 should read cache if caching has started")
+				.isGreaterThan(0);
+		}
+		// Update caching status
+		cachingStarted = cachingStarted || usage3.cacheCreationInputTokens() > 0;
+
+		// Turn 4: Final question
+		logger.info("\n=== TURN 4: Final Question ===");
+		conversationHistory.add(new UserMessage("What are the limitations of current quantum computers?"));
+
+		ChatResponse turn4 = this.chatModel.call(new Prompt(conversationHistory, options));
+		assertThat(turn4).isNotNull();
+		String assistant4Response = turn4.getResult().getOutput().getText();
+		conversationHistory.add(turn4.getResult().getOutput());
+
+		AnthropicApi.Usage usage4 = getAnthropicUsage(turn4);
+		assertThat(usage4).isNotNull();
+		logger.info("Turn 4 - User: '{}'", conversationHistory.get(7).getText());
+		logger.info("Turn 4 - Assistant: '{}'",
+				assistant4Response.substring(0, Math.min(100, assistant4Response.length())) + "...");
+		logger.info("Turn 4 - Input tokens: {}", usage4.inputTokens());
+		logger.info("Turn 4 - Cache creation tokens: {}", usage4.cacheCreationInputTokens());
+		logger.info("Turn 4 - Cache read tokens: {}", usage4.cacheReadInputTokens());
+
+		// Fourth turn: by now caching should definitely be working
+		assertThat(cachingStarted).as("Caching should have started by turn 4").isTrue();
+		assertThat(usage4.cacheReadInputTokens()).as("Turn 4 should read cache").isGreaterThan(0);
+
+		// Summary logging
+		logger.info("\n=== CACHING SUMMARY ===");
+		logger.info("Turn 1 - Created: {}, Read: {}", usage1.cacheCreationInputTokens(), usage1.cacheReadInputTokens());
+		logger.info("Turn 2 - Created: {}, Read: {}", usage2.cacheCreationInputTokens(), usage2.cacheReadInputTokens());
+		logger.info("Turn 3 - Created: {}, Read: {}", usage3.cacheCreationInputTokens(), usage3.cacheReadInputTokens());
+		logger.info("Turn 4 - Created: {}, Read: {}", usage4.cacheCreationInputTokens(), usage4.cacheReadInputTokens());
+
+		// Demonstrate the incremental growth pattern
+		logger.info("\n=== CACHE GROWTH PATTERN ===");
+		logger.info("Cache read tokens grew from {} → {} → {} → {}", usage1.cacheReadInputTokens(),
+				usage2.cacheReadInputTokens(), usage3.cacheReadInputTokens(), usage4.cacheReadInputTokens());
+		logger.info("This demonstrates incremental prefix caching: each turn builds on the previous cache");
+
+		// Verify that once caching starts, cache reads continue to grow
+		List<Integer> cacheReads = List.of(usage1.cacheReadInputTokens(), usage2.cacheReadInputTokens(),
+				usage3.cacheReadInputTokens(), usage4.cacheReadInputTokens());
+		int firstNonZeroIndex = -1;
+		for (int i = 0; i < cacheReads.size(); i++) {
+			if (cacheReads.get(i) > 0) {
+				firstNonZeroIndex = i;
+				break;
+			}
+		}
+		if (firstNonZeroIndex >= 0 && firstNonZeroIndex < cacheReads.size() - 1) {
+			// Verify each subsequent turn has cache reads >= the previous turn
+			for (int i = firstNonZeroIndex + 1; i < cacheReads.size(); i++) {
+				assertThat(cacheReads.get(i))
+					.as("Cache reads should grow or stay the same once caching starts (turn %d vs turn %d)", i + 1, i)
+					.isGreaterThanOrEqualTo(cacheReads.get(i - 1));
+			}
+		}
+	}
+
 }
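
Stripped of logging and assertions, the new integration test boils down to an append-and-resend loop. A condensed sketch, assuming a `chatModel` and `options` configured as in the test:

[source,java]
----
import java.util.ArrayList;
import java.util.List;

import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

class IncrementalCachingSketch {

	// Each turn appends to the shared history and re-sends it; with
	// CONVERSATION_HISTORY the cache breakpoint rides on the last user message,
	// so every call re-reads the previously cached prefix.
	static void runConversation(ChatModel chatModel, AnthropicChatOptions options, String systemPrompt) {
		List<Message> history = new ArrayList<>();
		history.add(new SystemMessage(systemPrompt));
		for (String question : List.of("What is quantum computing?", "How does entanglement fit in?",
				"What are the current limitations?")) {
			history.add(new UserMessage(question));
			ChatResponse response = chatModel.call(new Prompt(history, options));
			// Append the assistant reply so the next turn extends the cached prefix
			history.add(response.getResult().getOutput());
		}
	}

}
----
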
diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingMockTest.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingMockTest.java
index a914a243085..1f250604644 100644
--- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingMockTest.java
+++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicPromptCachingMockTest.java
@@ -331,11 +331,11 @@ void testConversationHistoryCacheStrategy() throws Exception {
 		assertThat(messagesArray.isArray()).isTrue();
 		assertThat(messagesArray.size()).isGreaterThan(1);
 
-		// Verify the second-to-last message has cache control (conversation history)
-		if (messagesArray.size() >= 2) {
-			JsonNode secondToLastMessage = messagesArray.get(messagesArray.size() - 2);
-			assertThat(secondToLastMessage.has("content")).isTrue();
-			JsonNode contentArray = secondToLastMessage.get("content");
+		// Verify the last message has cache control (conversation history)
+		if (messagesArray.size() >= 1) {
+			JsonNode lastMessage = messagesArray.get(messagesArray.size() - 1);
+			assertThat(lastMessage.has("content")).isTrue();
+			JsonNode contentArray = lastMessage.get("content");
 			if (contentArray.isArray() && contentArray.size() > 0) {
 				JsonNode lastContentBlock = contentArray.get(contentArray.size() - 1);
 				assertThat(lastContentBlock.has("cache_control")).isTrue();
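
For reference, the request body this mock test asserts against looks roughly like the following (illustrative values; the `{"type": "ephemeral"}` form of `cache_control` is Anthropic's documented marker):

[source,java]
----
// Illustrative request shape (values elided): with CONVERSATION_HISTORY the
// LAST message's LAST content block carries the cache_control marker.
String expectedShape = """
		{
		  "messages": [
		    { "role": "user",      "content": [ { "type": "text", "text": "..." } ] },
		    { "role": "assistant", "content": [ { "type": "text", "text": "..." } ] },
		    { "role": "user",      "content": [
		        { "type": "text", "text": "...",
		          "cache_control": { "type": "ephemeral" } }
		    ] }
		  ]
		}
		""";
----
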
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc
index 06bb3eb32ef..d0261db78d7 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc
@@ -215,9 +215,9 @@ Different models have different minimum token thresholds for cache effectiveness
 Spring AI provides strategic cache placement through the `AnthropicCacheStrategy` enum:
 
 * `NONE`: Disables prompt caching completely
-* `SYSTEM_ONLY`: Caches only the system message content 
+* `SYSTEM_ONLY`: Caches only the system message content
 * `SYSTEM_AND_TOOLS`: Caches system message and the last tool definition
-* `CONVERSATION_HISTORY`: Caches conversation history in chat memory scenarios
+* `CONVERSATION_HISTORY`: Caches the entire conversation history by placing cache breakpoints on tools (if present), the system message, and the last user message. This enables incremental prefix caching for multi-turn conversations
 
 This strategic approach ensures optimal cache breakpoint placement while staying within Anthropic's 4-breakpoint limit.
 
@@ -274,7 +274,7 @@
 [source,java]
 ----
-// Cache conversation history with ChatClient and memory (latest user question is not cached)
+// Cache conversation history with ChatClient and memory (cache breakpoint on last user message)
 ChatClient chatClient = ChatClient.builder(chatModel)
 	.defaultSystem("You are a personalized career counselor...")
 	.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory)
@@ -622,6 +622,9 @@ Even small changes will require a new cache entry.
 The prompt caching implementation in Spring AI follows these key design principles:
 
 1. **Strategic Cache Placement**: Cache breakpoints are automatically placed at optimal locations based on the chosen strategy, ensuring compliance with Anthropic's 4-breakpoint limit.
+   - `CONVERSATION_HISTORY` places cache breakpoints on tools (if present), the system message, and the last user message
+   - This enables Anthropic's prefix matching to incrementally cache the growing conversation history
+   - Each turn builds on the previous cached prefix, maximizing cache reuse
 
 2. **Provider Portability**: Cache configuration is done through `AnthropicChatOptions` rather than individual messages, preserving compatibility when switching between different AI providers.
 
@@ -629,6 +632,8 @@ The prompt caching implementation in Spring AI follows these key design principles:
 
 4. **Automatic Content Ordering**: The implementation ensures proper on-the-wire ordering of JSON content blocks and cache controls according to Anthropic's API requirements.
 
+5. **Aggregate Eligibility Checking**: For `CONVERSATION_HISTORY`, the implementation considers all message types (user, assistant, tool) within the last ~20 content blocks when determining whether the combined content meets the configured minimum content length for caching.
+
 === Future Enhancements
 
 The current cache strategies are designed to handle **90% of common use cases** effectively. For applications requiring more granular control, future enhancements may include:
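
To tie the documentation changes back to code, here is a minimal configuration exercising `CONVERSATION_HISTORY`. The builder calls mirror those used in this patch's tests; the `1024` value is only illustrative:

[source,java]
----
// Minimal CONVERSATION_HISTORY configuration; builder API as used in the
// tests above. The 1024 threshold is illustrative, not a recommendation.
AnthropicChatOptions options = AnthropicChatOptions.builder()
	.model(AnthropicApi.ChatModel.CLAUDE_SONNET_4_0.getValue())
	.cacheOptions(AnthropicCacheOptions.builder()
		.strategy(AnthropicCacheStrategy.CONVERSATION_HISTORY)
		// Aggregate check: combined text of the windowed messages must reach this length
		.messageTypeMinContentLength(MessageType.USER, 1024)
		.build())
	.build();
----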