Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ def __init__(self, params: Params):
"avg_original_content_words": None,
"avg_updated_content_words": None,
"max_reasoning_words": None,
"max_original_content_words": None,
"max_updated_content_words": None,
"max_reasoning_tokens": None,
"avg_reasoning_tokens": None,
"max_updated_content_tokens": None,
"avg_updated_content_tokens": None,
"total_reasoning_words": 0,
"total_original_content_words": 0,
"total_updated_content_words": 0,
"total_reasoning_tokens": 0,
"total_updated_content_tokens": 0,
}

# Initialize cache if enabled
Expand Down Expand Up @@ -223,6 +234,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
reasoning_words = reasoning_info.get("reasoning_words", 0)
original_words = reasoning_info.get("original_content_words", 0)
updated_words = reasoning_info.get("updated_content_words", 0)
reasoning_tokens = reasoning_info.get("reasoning_tokens", "unknown")
updated_content_tokens = reasoning_info.get(
"updated_content_tokens", "unknown"
)

# Increment counters
if reasoning_words > 0:
Expand All @@ -237,8 +252,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
("avg_reasoning_words", reasoning_words),
("avg_original_content_words", original_words),
("avg_updated_content_words", updated_words),
("avg_reasoning_tokens", reasoning_tokens),
("avg_updated_content_tokens", updated_content_tokens),
]:
if value > 0:
if value != "unknown":
if self._reasoning_stats[stat_key] is None:
self._reasoning_stats[stat_key] = value
else:
Expand All @@ -253,12 +270,34 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
)

# Update max reasoning words
if reasoning_words > 0:
if (
self._reasoning_stats["max_reasoning_words"] is None
or reasoning_words > self._reasoning_stats["max_reasoning_words"]
):
self._reasoning_stats["max_reasoning_words"] = reasoning_words
for key in [
"reasoning_words",
"original_content_words",
"updated_content_words",
"reasoning_tokens",
"updated_content_tokens",
]:
value = reasoning_info.get(key, None)
if value is not None and value != "unknown":
if (
self._reasoning_stats[f"max_{key}"] is None
or value > self._reasoning_stats[f"max_{key}"]
):
self._reasoning_stats[f"max_{key}"] = value

# Update total statistics
if reasoning_words != "unknown":
self._reasoning_stats["total_reasoning_words"] += reasoning_words
if original_words != "unknown":
self._reasoning_stats["total_original_content_words"] += original_words
if updated_words != "unknown":
self._reasoning_stats["total_updated_content_words"] += updated_words
if reasoning_tokens != "unknown":
self._reasoning_stats["total_reasoning_tokens"] += reasoning_tokens
if updated_content_tokens != "unknown":
self._reasoning_stats["total_updated_content_tokens"] += (
updated_content_tokens
)

# Log aggregated stats at specified interval
if (
Expand All @@ -268,33 +307,64 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
):
self.logger.info(**self._reasoning_stats)

def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
def _process_reasoning_message(
self, msg: dict, usage: dict = None
) -> tuple[dict, dict]:
"""
Process reasoning in the message and return modified message with reasoning info.

Args:
msg: The message object containing content and potentially reasoning_content
usage: Optional usage data from the response for token tracking

Returns:
tuple: (modified_message, reasoning_info) where reasoning_info has keys:
reasoning_words, original_content_words, updated_content_words, reasoning_finished, reasoning_started
"""
modified_msg = msg.copy()
content = msg.get("content", "")
updated_content_tokens = "unknown"
reasoning_tokens = "unknown"

# Check if reasoning_content exists in the message and is not empty
if (
"reasoning_content" in msg
and msg["reasoning_content"]
and msg["reasoning_content"].strip()
):
if "reasoning_content" in msg:
reasoning_content = msg["reasoning_content"]
updated_message_content = content
reasoning_started = True
reasoning_started = (
True
if reasoning_content is not None and reasoning_content.strip() != ""
else False
)
if content.strip() == "":
reasoning_finished = False
else:
reasoning_finished = True
if usage:
# First try to get reasoning_tokens directly from usage
reasoning_tokens = usage.get("reasoning_tokens", "unknown")
updated_content_tokens = usage.get("content_tokens", "unknown")

# If not found, check in completion_tokens_details and output_tokens_details
if reasoning_tokens == "unknown":
for key in ["completion_tokens_details", "output_tokens_details"]:
if key in usage:
details = usage[key]
if isinstance(details, dict):
reasoning_tokens = details.get(
"reasoning_tokens", "unknown"
)
if reasoning_tokens != "unknown":
self.logger.debug(
f"Found reasoning_tokens in {key}: {reasoning_tokens}"
)
break

# Log if reasoning tokens were found
if reasoning_tokens != "unknown":
self.logger.debug(f"Reasoning tokens extracted: {reasoning_tokens}")
else:
self.logger.debug("No reasoning tokens found in usage data")

else:
reasoning_finished = False
if self.start_reasoning_token is not None:
Expand Down Expand Up @@ -343,6 +413,8 @@ def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
),
"reasoning_finished": reasoning_finished,
"reasoning_started": reasoning_started,
"reasoning_tokens": reasoning_tokens,
"updated_content_tokens": updated_content_tokens,
}

return modified_msg, reasoning_info
Expand All @@ -363,6 +435,7 @@ def _strip_reasoning(self, text: str) -> str:
return cleaned_content

def _migrate_reasoning_content(self, msg: dict):
"""Migrate reasoning content to the content field with reasoning tokens."""
modified_msg = msg.copy()
if (
"reasoning_content" in msg
Expand Down Expand Up @@ -404,6 +477,9 @@ def intercept_response(
"Reasoning processing disabled, returning response as-is"
)
return resp
if resp.rctx.cache_hit:
self.logger.debug("Response was from cache, skipping reasoning processing")
return resp

try:
response_data = resp.r.json()
Expand All @@ -414,6 +490,9 @@ def intercept_response(
choices_count=len(response_data["choices"]),
)

# Extract usage data from response
usage_data = response_data.get("usage", {})

for choice in response_data["choices"]:
msg = choice.get("message")
if (
Expand All @@ -423,7 +502,7 @@ def intercept_response(
):
# Get modified message and reasoning information
modified_msg, reasoning_info = self._process_reasoning_message(
msg
msg, usage_data
)

# Collect reasoning statistics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,11 @@ def intercept_response(
) -> AdapterResponse:
"""Collect aggregated statistics from the response."""
# Get status code once and reuse it
if resp.rctx.cache_hit:
self.logger.debug(
"Response was from cache, skipping response stats collection"
)
return resp
status_code = resp.r.status_code

# Update time tracking with current timestamp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,13 @@ def test_save_stats_to_file_creates_directory(tmp_path, nested_path):
"responses_with_reasoning": 85,
"avg_reasoning_words": 12.5,
},
{
"total_responses": 5,
"responses_with_reasoning": 3,
"total_reasoning_words": 25,
"total_original_content_words": 50,
"total_updated_content_words": 30,
},
],
)
def test_reasoning_stats_access(test_stats):
Expand All @@ -928,6 +935,134 @@ def test_reasoning_stats_access(test_stats):
assert stats is not interceptor._reasoning_stats


@pytest.mark.parametrize(
    "cache_hit,expected_total_responses,expected_responses_with_reasoning",
    [
        # A cache hit must be skipped entirely, so nothing is counted.
        (True, 0, 0),
        # A fresh (non-cached) response is processed and counted normally.
        (False, 1, 1),
    ],
)
def test_cached_response_reasoning_behavior(
    tmp_path,
    cache_hit,
    expected_total_responses,
    expected_responses_with_reasoning,
):
    """Verify that reasoning statistics ignore responses served from cache."""
    params = ResponseReasoningInterceptor.Params(
        enable_reasoning_tracking=True,
        # Caching stays off so earlier runs cannot leak state into this test.
        enable_caching=False,
    )
    interceptor = ResponseReasoningInterceptor(params)

    # Minimal chat-completion payload carrying reasoning content plus usage.
    payload = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "<think>This is reasoning content.</think>Final answer.",
                    "reasoning_content": "This is reasoning content.",
                }
            }
        ],
        "usage": {"reasoning_tokens": 10, "content_tokens": 20},
    }

    http_response = Mock()
    http_response.status_code = 200
    http_response.json.return_value = payload
    http_response._content = json.dumps(payload).encode()

    request_ctx = Mock()
    request_ctx.cache_hit = cache_hit

    interceptor.intercept_response(
        AdapterResponse(r=http_response, rctx=request_ctx),
        AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost"),
    )

    collected = interceptor._reasoning_stats
    assert collected["total_responses"] == expected_total_responses
    assert collected["responses_with_reasoning"] == expected_responses_with_reasoning


@pytest.mark.parametrize(
    "usage_format,expected_reasoning_tokens,expected_content_tokens",
    [
        # Provider variant 1: token counts sit at the top level of "usage".
        ({"reasoning_tokens": 15, "content_tokens": 30}, 15, 30),
        # Provider variant 2: reasoning count nested in completion_tokens_details.
        (
            {
                "completion_tokens_details": {"reasoning_tokens": 20},
                "content_tokens": 40,
            },
            20,
            40,
        ),
        # Provider variant 3: reasoning count nested in output_tokens_details.
        (
            {
                "output_tokens_details": {"reasoning_tokens": 25},
                "content_tokens": 50,
            },
            25,
            50,
        ),
    ],
)
def test_reasoning_tokens_different_formats(
    tmp_path,
    usage_format,
    expected_reasoning_tokens,
    expected_content_tokens,
):
    """Verify token extraction works across the known shapes of usage data."""
    interceptor = ResponseReasoningInterceptor(
        ResponseReasoningInterceptor.Params(
            enable_reasoning_tracking=True,
            enable_caching=False,
        )
    )

    # One reasoning-bearing choice; only the "usage" shape varies per case.
    payload = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "<think>This is reasoning content.</think>Final answer.",
                    "reasoning_content": "This is reasoning content.",
                }
            }
        ],
        "usage": usage_format,
    }

    http_response = Mock()
    http_response.status_code = 200
    http_response.json.return_value = payload
    http_response._content = json.dumps(payload).encode()

    request_ctx = Mock()
    request_ctx.cache_hit = False

    interceptor.intercept_response(
        AdapterResponse(r=http_response, rctx=request_ctx),
        AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost"),
    )

    collected = interceptor._reasoning_stats
    assert collected["total_responses"] == 1
    assert collected["responses_with_reasoning"] == 1
    assert collected["total_reasoning_tokens"] == expected_reasoning_tokens
    assert collected["total_updated_content_tokens"] == expected_content_tokens


def test_load_from_cache_during_initialization(tmp_path):
"""Test that cached stats are automatically loaded during initialization."""
# Given: Create cache directory and add some cached stats
Expand Down
Loading
Loading