Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
Expand Down Expand Up @@ -73,18 +74,24 @@ public async Task<IEnumerable<ChatMessage>> ReduceAsync(IEnumerable<ChatMessage>
{
_ = Throw.IfNull(messages);

var summarizedConversion = SummarizedConversation.FromChatMessages(messages);
if (summarizedConversion.ShouldResummarize(_targetCount, _thresholdCount))
var summarizedConversation = SummarizedConversation.FromChatMessages(messages);
var indexOfFirstMessageToKeep = summarizedConversation.FindIndexOfFirstMessageToKeep(_targetCount, _thresholdCount);
if (indexOfFirstMessageToKeep > 0)
{
summarizedConversion = await summarizedConversion.ResummarizeAsync(
_chatClient, _targetCount, SummarizationPrompt, cancellationToken);
summarizedConversation = await summarizedConversation.ResummarizeAsync(
_chatClient,
indexOfFirstMessageToKeep,
SummarizationPrompt,
cancellationToken);
}

return summarizedConversion.ToChatMessages();
return summarizedConversation.ToChatMessages();
}

/// <summary>Represents a conversation with an optional summary.</summary>
private readonly struct SummarizedConversation(string? summary, ChatMessage? systemMessage, IList<ChatMessage> unsummarizedMessages)
{
/// <summary>Creates a <see cref="SummarizedConversation"/> from a list of chat messages.</summary>
public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> messages)
{
string? summary = null;
Expand All @@ -102,7 +109,7 @@ public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> m
unsummarizedMessages.Clear();
summary = summaryValue;
}
else if (!message.Contents.Any(m => m is FunctionCallContent or FunctionResultContent))
else
{
unsummarizedMessages.Add(message);
}
Expand All @@ -111,31 +118,68 @@ public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> m
return new(summary, systemMessage, unsummarizedMessages);
}

public bool ShouldResummarize(int targetCount, int thresholdCount)
=> unsummarizedMessages.Count > targetCount + thresholdCount;

public async Task<SummarizedConversation> ResummarizeAsync(
IChatClient chatClient, int targetCount, string summarizationPrompt, CancellationToken cancellationToken)
/// <summary>Performs summarization by calling the chat client and updating the conversation state.</summary>
public async ValueTask<SummarizedConversation> ResummarizeAsync(
IChatClient chatClient, int indexOfFirstMessageToKeep, string summarizationPrompt, CancellationToken cancellationToken)
{
var messagesToResummarize = unsummarizedMessages.Count - targetCount;
if (messagesToResummarize <= 0)
{
// We're at or below the target count - no need to resummarize.
return this;
}
Debug.Assert(indexOfFirstMessageToKeep > 0, "Expected positive index for first message to keep.");

var summarizerChatMessages = ToSummarizerChatMessages(messagesToResummarize, summarizationPrompt);
// Generate the summary by sending unsummarized messages to the chat client
var summarizerChatMessages = ToSummarizerChatMessages(indexOfFirstMessageToKeep, summarizationPrompt);
var response = await chatClient.GetResponseAsync(summarizerChatMessages, cancellationToken: cancellationToken);
var newSummary = response.Text;

var lastSummarizedMessage = unsummarizedMessages[messagesToResummarize - 1];
// Attach the summary metadata to the last message being summarized
// This is what allows us to build on previously-generated summaries
var lastSummarizedMessage = unsummarizedMessages[indexOfFirstMessageToKeep - 1];
var additionalProperties = lastSummarizedMessage.AdditionalProperties ??= [];
additionalProperties[SummaryKey] = newSummary;

var newUnsummarizedMessages = unsummarizedMessages.Skip(messagesToResummarize).ToList();
// Compute the new list of unsummarized messages
var newUnsummarizedMessages = unsummarizedMessages.Skip(indexOfFirstMessageToKeep).ToList();
return new SummarizedConversation(newSummary, systemMessage, newUnsummarizedMessages);
}

/// <summary>Determines the index of the first message to keep (not summarize) based on target and threshold counts.</summary>
public int FindIndexOfFirstMessageToKeep(int targetCount, int thresholdCount)
{
var earliestAllowedIndex = unsummarizedMessages.Count - thresholdCount - targetCount;
if (earliestAllowedIndex <= 0)
{
// Not enough messages to warrant summarization
return 0;
}

// Start at the ideal cut point (keeping exactly targetCount messages)
var indexOfFirstMessageToKeep = unsummarizedMessages.Count - targetCount;

// Move backward to skip over function call/result content at the boundary
// We want to keep complete function call sequences together with their responses
while (indexOfFirstMessageToKeep > 0)
{
if (!unsummarizedMessages[indexOfFirstMessageToKeep - 1].Contents.Any(IsToolRelatedContent))
{
break;
}

indexOfFirstMessageToKeep--;
}

// Search backward within the threshold window to find a User message
// If found, cut right before it to avoid orphaning user questions from responses
for (var i = indexOfFirstMessageToKeep; i >= earliestAllowedIndex; i--)
{
if (unsummarizedMessages[i].Role == ChatRole.User)
{
return i;
}
}

// No User message found within threshold - use the adjusted cut point
return indexOfFirstMessageToKeep;
}

/// <summary>Converts the summarized conversation back into a collection of chat messages.</summary>
public IEnumerable<ChatMessage> ToChatMessages()
{
if (systemMessage is not null)
Expand All @@ -154,16 +198,33 @@ public IEnumerable<ChatMessage> ToChatMessages()
}
}

private IEnumerable<ChatMessage> ToSummarizerChatMessages(int messagesToResummarize, string summarizationPrompt)
/// <summary>Returns whether the given <see cref="AIContent"/> relates to tool calling capabilities.</summary>
/// <remarks>
/// This method returns <see langword="true"/> for content types whose meaning depends on other related <see cref="AIContent"/>
/// instances in the conversation, such as function calls that require corresponding results, or other tool interactions that span
/// multiple messages. Such content should be kept together during summarization.
/// </remarks>
private static bool IsToolRelatedContent(AIContent content) => content
is FunctionCallContent
or FunctionResultContent
or UserInputRequestContent
or UserInputResponseContent;

/// <summary>Builds the list of messages to send to the chat client for summarization.</summary>
private IEnumerable<ChatMessage> ToSummarizerChatMessages(int indexOfFirstMessageToKeep, string summarizationPrompt)
{
if (summary is not null)
{
yield return new ChatMessage(ChatRole.Assistant, summary);
}

for (var i = 0; i < messagesToResummarize; i++)
for (var i = 0; i < indexOfFirstMessageToKeep; i++)
{
yield return unsummarizedMessages[i];
var message = unsummarizedMessages[i];
if (!message.Contents.Any(IsToolRelatedContent))
{
yield return message;
}
}

yield return new ChatMessage(ChatRole.System, summarizationPrompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,27 +84,145 @@ public async Task ReduceAsync_PreservesSystemMessage()
}

[Fact]
public async Task ReduceAsync_IgnoresFunctionCallsAndResults()
public async Task ReduceAsync_PreservesCompleteToolCallSequence()
{
using var chatClient = new TestChatClient();
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 0);

// Target 2 messages, but this would split a function call sequence
var reducer = new SummarizingChatReducer(chatClient, targetCount: 2, threshold: 0);

List<ChatMessage> messages =
[
new ChatMessage(ChatRole.User, "What's the time?"),
new ChatMessage(ChatRole.Assistant, "Let me check"),
new ChatMessage(ChatRole.User, "What's the weather?"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather", new Dictionary<string, object?> { ["location"] = "Seattle" })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny, 72°F")]),
new ChatMessage(ChatRole.Assistant, "The weather in Seattle is sunny and 72°F."),
new ChatMessage(ChatRole.User, "Thanks!"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather"), new TestUserInputRequestContent("uir1")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny")]),
new ChatMessage(ChatRole.User, [new TestUserInputResponseContent("uir1")]),
new ChatMessage(ChatRole.Assistant, "It's sunny"),
];

chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
{
Assert.DoesNotContain(msgs, m => m.Contents.Any(c => c is FunctionCallContent or FunctionResultContent or TestUserInputRequestContent or TestUserInputResponseContent));
return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Asked about time")));
};

var result = await reducer.ReduceAsync(messages, CancellationToken.None);
var resultList = result.ToList();

// Function calls/results should be ignored, which means there aren't enough messages to generate a summary.
// Should have: summary + function call + function result + user input response + last reply
Assert.Equal(5, resultList.Count);

// Verify the complete sequence is preserved
Assert.Collection(resultList,
m => Assert.Contains("Asked about time", m.Text),
m =>
{
Assert.Contains(m.Contents, c => c is FunctionCallContent);
Assert.Contains(m.Contents, c => c is TestUserInputRequestContent);
},
m => Assert.Contains(m.Contents, c => c is FunctionResultContent),
m => Assert.Contains(m.Contents, c => c is TestUserInputResponseContent),
m => Assert.Contains("sunny", m.Text));
}

[Fact]
public async Task ReduceAsync_PreservesUserMessageWhenWithinThreshold()
{
using var chatClient = new TestChatClient();

// Target 3 messages with threshold of 2
// This allows us to keep anywhere from 3 to 5 messages
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 2);

List<ChatMessage> messages =
[
new ChatMessage(ChatRole.User, "First question"),
new ChatMessage(ChatRole.Assistant, "First answer"),
new ChatMessage(ChatRole.User, "Second question"),
new ChatMessage(ChatRole.Assistant, "Second answer"),
new ChatMessage(ChatRole.User, "Third question"),
new ChatMessage(ChatRole.Assistant, "Third answer"),
];

chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
{
var msgList = msgs.ToList();

// Should summarize messages 0-1 (First question and answer)
// The reducer should find the User message at index 2 within the threshold
Assert.Equal(3, msgList.Count); // 2 messages to summarize + system prompt
return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Summary of first exchange")));
};

var result = await reducer.ReduceAsync(messages, CancellationToken.None);
var resultList = result.ToList();
Assert.Equal(3, resultList.Count); // Function calls get removed in the summarized chat.
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionCallContent));
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionResultContent));

// Should have: summary + 4 kept messages (from "Second question" onward)
Assert.Equal(5, resultList.Count);

// Verify the summary is first
Assert.Contains("Summary", resultList[0].Text);

// Verify we kept the User message at index 2 and everything after
Assert.Collection(resultList.Skip(1),
m => Assert.Contains("Second question", m.Text),
m => Assert.Contains("Second answer", m.Text),
m => Assert.Contains("Third question", m.Text),
m => Assert.Contains("Third answer", m.Text));
}

[Fact]
public async Task ReduceAsync_ExcludesToolCallsFromSummarizedPortion()
{
using var chatClient = new TestChatClient();

// Target 3 messages - this will cause function calls in older messages to be summarized (excluded)
// while function calls in recent messages are kept
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 0);

List<ChatMessage> messages =
[
new ChatMessage(ChatRole.User, "What's the weather in Seattle?"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather", new Dictionary<string, object?> { ["location"] = "Seattle" }), new TestUserInputRequestContent("uir2")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny, 72°F")]),
new ChatMessage(ChatRole.User, [new TestUserInputResponseContent("uir2")]),
new ChatMessage(ChatRole.Assistant, "It's sunny and 72°F in Seattle."),
new ChatMessage(ChatRole.User, "What about New York?"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call2", "get_weather", new Dictionary<string, object?> { ["location"] = "New York" })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call2", "Rainy, 65°F")]),
new ChatMessage(ChatRole.Assistant, "It's rainy and 65°F in New York."),
];

chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
{
var msgList = msgs.ToList();

Assert.Equal(4, msgList.Count); // 3 non-function messages + system prompt
Assert.DoesNotContain(msgList, m => m.Contents.Any(c => c is FunctionCallContent or FunctionResultContent or TestUserInputRequestContent or TestUserInputResponseContent));
Assert.Contains(msgList, m => m.Text.Contains("What's the weather in Seattle?"));
Assert.Contains(msgList, m => m.Text.Contains("sunny and 72°F in Seattle"));
Assert.Contains(msgList, m => m.Text.Contains("What about New York?"));
Assert.Contains(msgList, m => m.Role == ChatRole.System);

return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "User asked about weather in Seattle and New York.")));
};

var result = await reducer.ReduceAsync(messages, CancellationToken.None);
var resultList = result.ToList();

// Should have: summary + 3 kept messages (the last 3 messages with function calls)
Assert.Equal(4, resultList.Count);

Assert.Contains("User asked about weather", resultList[0].Text);
Assert.Contains(resultList, m => m.Contents.Any(c => c is FunctionCallContent fc && fc.CallId == "call2"));
Assert.Contains(resultList, m => m.Contents.Any(c => c is FunctionResultContent fr && fr.CallId == "call2"));
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionCallContent fc && fc.CallId == "call1"));
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionResultContent fr && fr.CallId == "call1"));
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is TestUserInputRequestContent));
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is TestUserInputResponseContent));
Assert.DoesNotContain(resultList, m => m.Text.Contains("sunny and 72°F in Seattle"));
}

[Theory]
Expand All @@ -121,7 +239,7 @@ public async Task ReduceAsync_RespectsTargetAndThresholdCounts(int targetCount,
var messages = new List<ChatMessage>();
for (int i = 0; i < messageCount; i++)
{
messages.Add(new ChatMessage(i % 2 == 0 ? ChatRole.User : ChatRole.Assistant, $"Message {i}"));
messages.Add(new ChatMessage(ChatRole.Assistant, $"Message {i}"));
}

var summarizationCalled = false;
Expand Down Expand Up @@ -266,4 +384,20 @@ need frequent exercise. The user then asked about whether they're good around ki
m => Assert.StartsWith("Golden retrievers get along", m.Text, StringComparison.Ordinal),
m => Assert.StartsWith("Do they make good lap dogs", m.Text, StringComparison.Ordinal));
}

private sealed class TestUserInputRequestContent : UserInputRequestContent
{
public TestUserInputRequestContent(string id)
: base(id)
{
}
}

private sealed class TestUserInputResponseContent : UserInputResponseContent
{
public TestUserInputResponseContent(string id)
: base(id)
{
}
}
}
Loading