Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,18 @@ public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {

private final Scheduler scheduler;

private final boolean storeAllUserMessages;

private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order,
Scheduler scheduler) {
Scheduler scheduler, boolean storeAllUserMessages) {
Assert.notNull(chatMemory, "chatMemory cannot be null");
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
Assert.notNull(scheduler, "scheduler cannot be null");
this.chatMemory = chatMemory;
this.defaultConversationId = defaultConversationId;
this.order = order;
this.scheduler = scheduler;
this.storeAllUserMessages = storeAllUserMessages;
}

@Override
Expand All @@ -88,12 +91,19 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai

// 3. Create a new request with the advised messages.
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
.prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
.build();
.prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
.build();

// 4. Add the new user message to the conversation memory.
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
this.chatMemory.add(conversationId, userMessage);
if (this.storeAllUserMessages) {
// Store all user messages: add the new message to the existing message list
List allUserMessages = processedChatClientRequest.prompt().getUserMessages();
this.chatMemory.add(conversationId, allUserMessages);
} else {
// Store only the latest user message
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
this.chatMemory.add(conversationId, userMessage);
}

return processedChatClientRequest;
}
Expand All @@ -103,10 +113,10 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
List<Message> assistantMessages = new ArrayList<>();
if (chatClientResponse.chatResponse() != null) {
assistantMessages = chatClientResponse.chatResponse()
.getResults()
.stream()
.map(g -> (Message) g.getOutput())
.toList();
.getResults()
.stream()
.map(g -> (Message) g.getOutput())
.toList();
}
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
assistantMessages);
Expand All @@ -121,11 +131,11 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest

// Process the request with the before method
return Mono.just(chatClientRequest)
.publishOn(scheduler)
.map(request -> this.before(request, streamAdvisorChain))
.flatMapMany(streamAdvisorChain::nextStream)
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
response -> this.after(response, streamAdvisorChain)));
.publishOn(scheduler)
.map(request -> this.before(request, streamAdvisorChain))
.flatMapMany(streamAdvisorChain::nextStream)
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
response -> this.after(response, streamAdvisorChain)));
}

public static Builder builder(ChatMemory chatMemory) {
Expand All @@ -142,6 +152,8 @@ public static final class Builder {

private ChatMemory chatMemory;

private boolean storeAllUserMessages = false;

private Builder(ChatMemory chatMemory) {
this.chatMemory = chatMemory;
}
Expand Down Expand Up @@ -171,12 +183,22 @@ public Builder scheduler(Scheduler scheduler) {
return this;
}

/**
* Configure whether to store all user messages or only the latest one.
* @param storeAllUserMessages true to store all user messages, false to store only the latest
* @return the builder
*/
public Builder storeAllUserMessages(boolean storeAllUserMessages) {
this.storeAllUserMessages = storeAllUserMessages;
return this;
}

/**
* Build the advisor.
* @return the advisor
*/
public MessageChatMemoryAdvisor build() {
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler);
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler, this.storeAllUserMessages);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,53 @@ void testDefaultValues() {
assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
}

@Test
void whenStoreAllUserMessagesIsTrueThenPreserveAllMessages() {
// Create a chat memory
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

// Create advisor with storeAllUserMessages set to true
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory)
.storeAllUserMessages(true)
.build();

// Verify the advisor was built successfully
assertThat(advisor).isNotNull();
}

@Test
void whenStoreAllUserMessagesIsFalseThenStoreOnlyLatest() {
// Create a chat memory
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

// Create advisor with storeAllUserMessages set to false (default)
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory)
.storeAllUserMessages(false)
.build();

// Verify the advisor was built successfully
assertThat(advisor).isNotNull();
}

@Test
void testDefaultStoreAllUserMessagesValue() {
// Create a chat memory
ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();

// Create advisor with default values
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();

// Verify the advisor was built successfully
assertThat(advisor).isNotNull();
// Note: We cannot directly verify the default value of storeAllUserMessages
// since it's a private field, but the construction should succeed
}


}