diff --git a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py
index f5d9c305..47d2fc6c 100644
--- a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py
+++ b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py
@@ -140,6 +140,17 @@ def __init__(self, params: Params):
"avg_original_content_words": None,
"avg_updated_content_words": None,
"max_reasoning_words": None,
+ "max_original_content_words": None,
+ "max_updated_content_words": None,
+ "max_reasoning_tokens": None,
+ "avg_reasoning_tokens": None,
+ "max_updated_content_tokens": None,
+ "avg_updated_content_tokens": None,
+ "total_reasoning_words": 0,
+ "total_original_content_words": 0,
+ "total_updated_content_words": 0,
+ "total_reasoning_tokens": 0,
+ "total_updated_content_tokens": 0,
}
# Initialize cache if enabled
@@ -223,6 +234,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
reasoning_words = reasoning_info.get("reasoning_words", 0)
original_words = reasoning_info.get("original_content_words", 0)
updated_words = reasoning_info.get("updated_content_words", 0)
+ reasoning_tokens = reasoning_info.get("reasoning_tokens", "unknown")
+ updated_content_tokens = reasoning_info.get(
+ "updated_content_tokens", "unknown"
+ )
# Increment counters
if reasoning_words > 0:
@@ -237,8 +252,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
("avg_reasoning_words", reasoning_words),
("avg_original_content_words", original_words),
("avg_updated_content_words", updated_words),
+ ("avg_reasoning_tokens", reasoning_tokens),
+ ("avg_updated_content_tokens", updated_content_tokens),
]:
- if value > 0:
+ if value != "unknown":
if self._reasoning_stats[stat_key] is None:
self._reasoning_stats[stat_key] = value
else:
@@ -253,12 +270,34 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
)
# Update max reasoning words
- if reasoning_words > 0:
- if (
- self._reasoning_stats["max_reasoning_words"] is None
- or reasoning_words > self._reasoning_stats["max_reasoning_words"]
- ):
- self._reasoning_stats["max_reasoning_words"] = reasoning_words
+ for key in [
+ "reasoning_words",
+ "original_content_words",
+ "updated_content_words",
+ "reasoning_tokens",
+ "updated_content_tokens",
+ ]:
+ value = reasoning_info.get(key, None)
+ if value is not None and value != "unknown":
+ if (
+ self._reasoning_stats[f"max_{key}"] is None
+ or value > self._reasoning_stats[f"max_{key}"]
+ ):
+ self._reasoning_stats[f"max_{key}"] = value
+
+ # Update total statistics
+ if reasoning_words != "unknown":
+ self._reasoning_stats["total_reasoning_words"] += reasoning_words
+ if original_words != "unknown":
+ self._reasoning_stats["total_original_content_words"] += original_words
+ if updated_words != "unknown":
+ self._reasoning_stats["total_updated_content_words"] += updated_words
+ if reasoning_tokens != "unknown":
+ self._reasoning_stats["total_reasoning_tokens"] += reasoning_tokens
+ if updated_content_tokens != "unknown":
+ self._reasoning_stats["total_updated_content_tokens"] += (
+ updated_content_tokens
+ )
# Log aggregated stats at specified interval
if (
@@ -268,12 +307,15 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
):
self.logger.info(**self._reasoning_stats)
- def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
+ def _process_reasoning_message(
+ self, msg: dict, usage: dict = None
+ ) -> tuple[dict, dict]:
"""
Process reasoning in the message and return modified message with reasoning info.
Args:
msg: The message object containing content and potentially reasoning_content
+ usage: Optional usage data from the response for token tracking
Returns:
tuple: (modified_message, reasoning_info) where reasoning_info has keys:
@@ -281,20 +323,48 @@ def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
"""
modified_msg = msg.copy()
content = msg.get("content", "")
+ updated_content_tokens = "unknown"
+ reasoning_tokens = "unknown"
# Check if reasoning_content exists in the message and is not empty
- if (
- "reasoning_content" in msg
- and msg["reasoning_content"]
- and msg["reasoning_content"].strip()
- ):
+ if "reasoning_content" in msg:
reasoning_content = msg["reasoning_content"]
updated_message_content = content
- reasoning_started = True
+ reasoning_started = (
+ True
+ if reasoning_content is not None and reasoning_content.strip() != ""
+ else False
+ )
if content.strip() == "":
reasoning_finished = False
else:
reasoning_finished = True
+ if usage:
+ # First try to get reasoning_tokens directly from usage
+ reasoning_tokens = usage.get("reasoning_tokens", "unknown")
+ updated_content_tokens = usage.get("content_tokens", "unknown")
+
+ # If not found, check in completion_tokens_details and output_tokens_details
+ if reasoning_tokens == "unknown":
+ for key in ["completion_tokens_details", "output_tokens_details"]:
+ if key in usage:
+ details = usage[key]
+ if isinstance(details, dict):
+ reasoning_tokens = details.get(
+ "reasoning_tokens", "unknown"
+ )
+ if reasoning_tokens != "unknown":
+ self.logger.debug(
+ f"Found reasoning_tokens in {key}: {reasoning_tokens}"
+ )
+ break
+
+                # Log at debug level whether reasoning tokens were found or not
+ if reasoning_tokens != "unknown":
+ self.logger.debug(f"Reasoning tokens extracted: {reasoning_tokens}")
+ else:
+ self.logger.debug("No reasoning tokens found in usage data")
+
else:
reasoning_finished = False
if self.start_reasoning_token is not None:
@@ -343,6 +413,8 @@ def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
),
"reasoning_finished": reasoning_finished,
"reasoning_started": reasoning_started,
+ "reasoning_tokens": reasoning_tokens,
+ "updated_content_tokens": updated_content_tokens,
}
return modified_msg, reasoning_info
@@ -363,6 +435,7 @@ def _strip_reasoning(self, text: str) -> str:
return cleaned_content
def _migrate_reasoning_content(self, msg: dict):
+    """Merge reasoning_content into the content field, wrapped in the reasoning tokens."""
modified_msg = msg.copy()
if (
"reasoning_content" in msg
@@ -404,6 +477,9 @@ def intercept_response(
"Reasoning processing disabled, returning response as-is"
)
return resp
+ if resp.rctx.cache_hit:
+ self.logger.debug("Response was from cache, skipping reasoning processing")
+ return resp
try:
response_data = resp.r.json()
@@ -414,6 +490,9 @@ def intercept_response(
choices_count=len(response_data["choices"]),
)
+ # Extract usage data from response
+ usage_data = response_data.get("usage", {})
+
for choice in response_data["choices"]:
msg = choice.get("message")
if (
@@ -423,7 +502,7 @@ def intercept_response(
):
# Get modified message and reasoning information
modified_msg, reasoning_info = self._process_reasoning_message(
- msg
+ msg, usage_data
)
# Collect reasoning statistics
diff --git a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py
index 45fb3c39..81323b21 100644
--- a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py
+++ b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py
@@ -412,6 +412,11 @@ def intercept_response(
) -> AdapterResponse:
"""Collect aggregated statistics from the response."""
# Get status code once and reuse it
+ if resp.rctx.cache_hit:
+ self.logger.debug(
+ "Response was from cache, skipping response stats collection"
+ )
+ return resp
status_code = resp.r.status_code
# Update time tracking with current timestamp
diff --git a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py
index ba84e2c3..3f651fd5 100644
--- a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py
+++ b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py
@@ -904,6 +904,13 @@ def test_save_stats_to_file_creates_directory(tmp_path, nested_path):
"responses_with_reasoning": 85,
"avg_reasoning_words": 12.5,
},
+ {
+ "total_responses": 5,
+ "responses_with_reasoning": 3,
+ "total_reasoning_words": 25,
+ "total_original_content_words": 50,
+ "total_updated_content_words": 30,
+ },
],
)
def test_reasoning_stats_access(test_stats):
@@ -928,6 +935,134 @@ def test_reasoning_stats_access(test_stats):
assert stats is not interceptor._reasoning_stats
+@pytest.mark.parametrize(
+ "cache_hit,expected_total_responses,expected_responses_with_reasoning",
+ [
+ # Cached response - should NOT be counted (skipped by interceptor)
+ (True, 0, 0),
+ # Normal response (not cached)
+ (False, 1, 1),
+ ],
+)
+def test_cached_response_reasoning_behavior(
+ tmp_path,
+ cache_hit,
+ expected_total_responses,
+ expected_responses_with_reasoning,
+):
+ """Test that cached responses are properly skipped in reasoning stats counting."""
+ interceptor = ResponseReasoningInterceptor(
+ ResponseReasoningInterceptor.Params(
+ enable_reasoning_tracking=True,
+ enable_caching=False, # Disable caching to avoid state pollution
+ )
+ )
+
+ # Create mock response with reasoning content
+ response_data = {
+ "choices": [
+ {
+ "message": {
+ "role": "assistant",
+ "content": "This is reasoning content.Final answer.",
+ "reasoning_content": "This is reasoning content.",
+ }
+ }
+ ],
+ "usage": {"reasoning_tokens": 10, "content_tokens": 20},
+ }
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = response_data
+ mock_response._content = json.dumps(response_data).encode()
+
+ mock_rctx = Mock()
+ mock_rctx.cache_hit = cache_hit
+
+ adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
+ context = AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost")
+
+ interceptor.intercept_response(adapter_response, context)
+
+ stats = interceptor._reasoning_stats
+ assert stats["total_responses"] == expected_total_responses
+ assert stats["responses_with_reasoning"] == expected_responses_with_reasoning
+
+
+@pytest.mark.parametrize(
+ "usage_format,expected_reasoning_tokens,expected_content_tokens",
+ [
+ # Format 1: reasoning_tokens and content_tokens at top level
+ ({"reasoning_tokens": 15, "content_tokens": 30}, 15, 30),
+ # Format 2: reasoning_tokens in completion_tokens_details
+ (
+ {
+ "completion_tokens_details": {"reasoning_tokens": 20},
+ "content_tokens": 40,
+ },
+ 20,
+ 40,
+ ),
+ # Format 3: reasoning_tokens in output_tokens_details
+ (
+ {
+ "output_tokens_details": {"reasoning_tokens": 25},
+ "content_tokens": 50,
+ },
+ 25,
+ 50,
+ ),
+ ],
+)
+def test_reasoning_tokens_different_formats(
+ tmp_path,
+ usage_format,
+ expected_reasoning_tokens,
+ expected_content_tokens,
+):
+ """Test reasoning interceptor handles different usage data formats for reasoning tokens."""
+ interceptor = ResponseReasoningInterceptor(
+ ResponseReasoningInterceptor.Params(
+ enable_reasoning_tracking=True,
+ enable_caching=False,
+ )
+ )
+
+ # Create mock response with reasoning content
+ response_data = {
+ "choices": [
+ {
+ "message": {
+ "role": "assistant",
+ "content": "This is reasoning content.Final answer.",
+ "reasoning_content": "This is reasoning content.",
+ }
+ }
+ ],
+ "usage": usage_format,
+ }
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = response_data
+ mock_response._content = json.dumps(response_data).encode()
+
+ mock_rctx = Mock()
+ mock_rctx.cache_hit = False
+
+ adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
+ context = AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost")
+
+ interceptor.intercept_response(adapter_response, context)
+
+ stats = interceptor._reasoning_stats
+ assert stats["total_responses"] == 1
+ assert stats["responses_with_reasoning"] == 1
+ assert stats["total_reasoning_tokens"] == expected_reasoning_tokens
+ assert stats["total_updated_content_tokens"] == expected_content_tokens
+
+
def test_load_from_cache_during_initialization(tmp_path):
"""Test that cached stats are automatically loaded during initialization."""
# Given: Create cache directory and add some cached stats
diff --git a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py
index 92771e00..e3350ea1 100644
--- a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py
+++ b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py
@@ -453,6 +453,54 @@ def process_responses():
assert interceptor._stats["count"] >= 500
assert interceptor._stats["status_codes"][200] >= 500
+ @pytest.mark.parametrize(
+ "cache_hit,expected_total_responses,expected_successful_responses",
+ [
+ # Cached response - should NOT be counted (skipped by interceptor)
+ (True, 0, 0),
+ # Normal response (not cached)
+ (False, 1, 1),
+ ],
+ )
+ def test_cached_response_stats_behavior(
+ self,
+ tmp_path,
+ context,
+ cache_hit,
+ expected_total_responses,
+ expected_successful_responses,
+ ):
+ """Test that cached responses are properly skipped in stats counting."""
+ # Create a unique cache directory for each test to avoid state pollution
+ import uuid
+
+ unique_cache_dir = tmp_path / f"test_cache_{uuid.uuid4().hex[:8]}"
+
+ interceptor = ResponseStatsInterceptor(
+ ResponseStatsInterceptor.Params(
+ save_individuals=False, cache_dir=str(unique_cache_dir)
+ )
+ )
+
+ # Create a mock response
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"result": "test"}
+ mock_response._content = b'{"result": "test"}'
+
+ # Setup request context with cache hit status
+ mock_rctx = Mock()
+ mock_rctx.cache_hit = cache_hit
+
+ adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
+
+ # Process response
+ interceptor.intercept_response(adapter_response, context)
+
+ # Verify stats counting behavior
+ assert interceptor._stats["count"] == expected_total_responses
+ assert interceptor._stats["successful_count"] == expected_successful_responses
+
class TestResponseStatsInterceptorCache:
"""Test ResponseStatsInterceptor caching and aggregation functionality."""