diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 1b8bbea84e9..c317b88e44c 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -54,8 +54,10 @@ public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor { private final Scheduler scheduler; + private final boolean storeAllUserMessages; + private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order, - Scheduler scheduler) { + Scheduler scheduler, boolean storeAllUserMessages) { Assert.notNull(chatMemory, "chatMemory cannot be null"); Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty"); Assert.notNull(scheduler, "scheduler cannot be null"); @@ -63,6 +65,7 @@ private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversati this.defaultConversationId = defaultConversationId; this.order = order; this.scheduler = scheduler; + this.storeAllUserMessages = storeAllUserMessages; } @Override @@ -88,12 +91,19 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai // 3. Create a new request with the advised messages. ChatClientRequest processedChatClientRequest = chatClientRequest.mutate() - .prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()) - .build(); + .prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build()) + .build(); // 4. Add the new user message to the conversation memory. - UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); - this.chatMemory.add(conversationId, userMessage); + if (this.storeAllUserMessages) { + // Store all user messages: add the new message to the existing message list + List allUserMessages = processedChatClientRequest.prompt().getUserMessages(); + this.chatMemory.add(conversationId, allUserMessages); + } else { + // Store only the latest user message + UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage(); + this.chatMemory.add(conversationId, userMessage); + } return processedChatClientRequest; } @@ -103,10 +113,10 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh List assistantMessages = new ArrayList<>(); if (chatClientResponse.chatResponse() != null) { assistantMessages = chatClientResponse.chatResponse() - .getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); + .getResults() + .stream() + .map(g -> (Message) g.getOutput()) + .toList(); } this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId), assistantMessages); @@ -121,11 +131,11 @@ public Flux adviseStream(ChatClientRequest chatClientRequest // Process the request with the before method return Mono.just(chatClientRequest) - .publishOn(scheduler) - .map(request -> this.before(request, streamAdvisorChain)) - .flatMapMany(streamAdvisorChain::nextStream) - .transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux, - response -> this.after(response, streamAdvisorChain))); + .publishOn(scheduler) + .map(request -> this.before(request, streamAdvisorChain)) + .flatMapMany(streamAdvisorChain::nextStream) + .transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux, + response -> this.after(response, streamAdvisorChain))); } public static Builder builder(ChatMemory chatMemory) { @@ -142,6 +152,8 @@ public static final class Builder { private ChatMemory chatMemory; + private boolean storeAllUserMessages = false; + private Builder(ChatMemory chatMemory) { this.chatMemory = chatMemory; } @@ -171,12 +183,22 @@ public Builder scheduler(Scheduler scheduler) { return this; } + /** + * Configure whether to store all user messages or only the latest one. + * @param storeAllUserMessages true to store all user messages, false to store only the latest + * @return the builder + */ + public Builder storeAllUserMessages(boolean storeAllUserMessages) { + this.storeAllUserMessages = storeAllUserMessages; + return this; + } + /** * Build the advisor. * @return the advisor */ public MessageChatMemoryAdvisor build() { - return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler); + return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler, this.storeAllUserMessages); } } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java index 52ec1c00a98..0119d1a987a 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java @@ -108,4 +108,53 @@ void testDefaultValues() { assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); } + @Test + void whenStoreAllUserMessagesIsTrueThenPreserveAllMessages() { + // Create a chat memory + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + // Create advisor with storeAllUserMessages set to true + MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory) + .storeAllUserMessages(true) + .build(); + + // Verify the advisor was built successfully + assertThat(advisor).isNotNull(); + } + + @Test + void whenStoreAllUserMessagesIsFalseThenStoreOnlyLatest() { + // Create a chat memory + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + // Create advisor with storeAllUserMessages set to false (default) + MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory) + .storeAllUserMessages(false) + .build(); + + // Verify the advisor was built successfully + assertThat(advisor).isNotNull(); + } + + @Test + void testDefaultStoreAllUserMessagesValue() { + // Create a chat memory + ChatMemory chatMemory = MessageWindowChatMemory.builder() + .chatMemoryRepository(new InMemoryChatMemoryRepository()) + .build(); + + // Create advisor with default values + MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build(); + + // Verify the advisor was built successfully + assertThat(advisor).isNotNull(); + // Note: We cannot directly verify the default value of storeAllUserMessages + // since it's a private field, but the construction should succeed + } + + }