diff --git a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py
index f5d9c305..47d2fc6c 100644
--- a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py
+++ b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py
@@ -140,6 +140,17 @@ def __init__(self, params: Params):
             "avg_original_content_words": None,
             "avg_updated_content_words": None,
             "max_reasoning_words": None,
+            "max_original_content_words": None,
+            "max_updated_content_words": None,
+            "max_reasoning_tokens": None,
+            "avg_reasoning_tokens": None,
+            "max_updated_content_tokens": None,
+            "avg_updated_content_tokens": None,
+            "total_reasoning_words": 0,
+            "total_original_content_words": 0,
+            "total_updated_content_words": 0,
+            "total_reasoning_tokens": 0,
+            "total_updated_content_tokens": 0,
         }
 
         # Initialize cache if enabled
@@ -223,6 +234,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
         reasoning_words = reasoning_info.get("reasoning_words", 0)
         original_words = reasoning_info.get("original_content_words", 0)
         updated_words = reasoning_info.get("updated_content_words", 0)
+        reasoning_tokens = reasoning_info.get("reasoning_tokens", "unknown")
+        updated_content_tokens = reasoning_info.get(
+            "updated_content_tokens", "unknown"
+        )
 
         # Increment counters
         if reasoning_words > 0:
@@ -237,8 +252,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
             ("avg_reasoning_words", reasoning_words),
             ("avg_original_content_words", original_words),
             ("avg_updated_content_words", updated_words),
+            ("avg_reasoning_tokens", reasoning_tokens),
+            ("avg_updated_content_tokens", updated_content_tokens),
         ]:
-            if value > 0:
+            if value is not None and value != "unknown":
                 if self._reasoning_stats[stat_key] is None:
                     self._reasoning_stats[stat_key] = value
                 else:
@@ -253,12 +270,34 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
                     )
 
-        # Update max reasoning words
-        if reasoning_words > 0:
-            if (
-                self._reasoning_stats["max_reasoning_words"] is None
-                or reasoning_words > self._reasoning_stats["max_reasoning_words"]
-            ):
-                self._reasoning_stats["max_reasoning_words"] = reasoning_words
+        # Update max statistics for every tracked word/token metric
+        for key in [
+            "reasoning_words",
+            "original_content_words",
+            "updated_content_words",
+            "reasoning_tokens",
+            "updated_content_tokens",
+        ]:
+            value = reasoning_info.get(key, None)
+            if value is not None and value != "unknown":
+                if (
+                    self._reasoning_stats[f"max_{key}"] is None
+                    or value > self._reasoning_stats[f"max_{key}"]
+                ):
+                    self._reasoning_stats[f"max_{key}"] = value
+
+        # Update total statistics
+        if reasoning_words != "unknown":
+            self._reasoning_stats["total_reasoning_words"] += reasoning_words
+        if original_words != "unknown":
+            self._reasoning_stats["total_original_content_words"] += original_words
+        if updated_words != "unknown":
+            self._reasoning_stats["total_updated_content_words"] += updated_words
+        if reasoning_tokens not in (None, "unknown"):
+            self._reasoning_stats["total_reasoning_tokens"] += reasoning_tokens
+        if updated_content_tokens not in (None, "unknown"):
+            self._reasoning_stats["total_updated_content_tokens"] += (
+                updated_content_tokens
+            )
 
         # Log aggregated stats at specified interval
         if (
@@ -268,12 +307,15 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
         ):
             self.logger.info(**self._reasoning_stats)
 
-    def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
+    def _process_reasoning_message(
+        self, msg: dict, usage: dict = None
+    ) -> tuple[dict, dict]:
         """
         Process reasoning in the message and return modified message with reasoning info.
 
         Args:
             msg: The message object containing content and potentially reasoning_content
+            usage: Optional usage data from the response for token tracking
 
         Returns:
             tuple: (modified_message, reasoning_info) where reasoning_info has keys:
@@ -281,20 +323,48 @@ def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
         """
         modified_msg = msg.copy()
         content = msg.get("content", "")
+        updated_content_tokens = "unknown"
+        reasoning_tokens = "unknown"
 
-        # Check if reasoning_content exists in the message and is not empty
-        if (
-            "reasoning_content" in msg
-            and msg["reasoning_content"]
-            and msg["reasoning_content"].strip()
-        ):
+        # Check if reasoning_content exists in the message (it may be empty or None)
+        if "reasoning_content" in msg:
             reasoning_content = msg["reasoning_content"]
             updated_message_content = content
-            reasoning_started = True
+            reasoning_started = (
+                True
+                if reasoning_content is not None and reasoning_content.strip() != ""
+                else False
+            )
             if content.strip() == "":
                 reasoning_finished = False
             else:
                 reasoning_finished = True
+            if usage:
+                # First try to get reasoning_tokens directly from usage
+                reasoning_tokens = usage.get("reasoning_tokens", "unknown")
+                updated_content_tokens = usage.get("content_tokens", "unknown")
+
+                # If not found, check in completion_tokens_details and output_tokens_details
+                if reasoning_tokens == "unknown":
+                    for key in ["completion_tokens_details", "output_tokens_details"]:
+                        if key in usage:
+                            details = usage[key]
+                            if isinstance(details, dict):
+                                reasoning_tokens = details.get(
+                                    "reasoning_tokens", "unknown"
+                                )
+                                if reasoning_tokens != "unknown":
+                                    self.logger.debug(
+                                        f"Found reasoning_tokens in {key}: {reasoning_tokens}"
+                                    )
+                                    break
+
+                # Log if reasoning tokens were found
+                if reasoning_tokens != "unknown":
+                    self.logger.debug(f"Reasoning tokens extracted: {reasoning_tokens}")
+                else:
+                    self.logger.debug("No reasoning tokens found in usage data")
+
         else:
             reasoning_finished = False
             if self.start_reasoning_token is not None:
@@ -343,6 +413,8 @@ def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
             ),
             "reasoning_finished": reasoning_finished,
             "reasoning_started": reasoning_started,
+            "reasoning_tokens": reasoning_tokens,
+            "updated_content_tokens": updated_content_tokens,
         }
 
         return modified_msg, reasoning_info
@@ -363,6 +435,7 @@ def _strip_reasoning(self, text: str) -> str:
         return cleaned_content
 
     def _migrate_reasoning_content(self, msg: dict):
+        """Migrate reasoning content to the content field with reasoning tokens."""
         modified_msg = msg.copy()
         if (
             "reasoning_content" in msg
@@ -404,6 +477,9 @@ def intercept_response(
                 "Reasoning processing disabled, returning response as-is"
             )
             return resp
+        if resp.rctx.cache_hit:
+            self.logger.debug("Response was from cache, skipping reasoning processing")
+            return resp
 
         try:
             response_data = resp.r.json()
@@ -414,6 +490,9 @@ def intercept_response(
                 choices_count=len(response_data["choices"]),
             )
 
+            # Extract usage data from response
+            usage_data = response_data.get("usage", {})
+
             for choice in response_data["choices"]:
                 msg = choice.get("message")
                 if (
@@ -423,7 +502,7 @@ def intercept_response(
                 ):
                     # Get modified message and reasoning information
                     modified_msg, reasoning_info = self._process_reasoning_message(
-                        msg
+                        msg, usage_data
                     )
 
                     # Collect reasoning statistics
diff --git a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py
index 45fb3c39..81323b21 100644
--- a/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py
+++ b/packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py
@@ -412,6 +412,11 @@ def intercept_response(
     ) -> AdapterResponse:
         """Collect aggregated statistics from the response."""
+        if resp.rctx.cache_hit:
+            self.logger.debug(
+                "Response was from cache, skipping response stats collection"
+            )
+            return resp
         # Get status code once and reuse it
         status_code = resp.r.status_code
 
         # Update time tracking with current timestamp
diff --git a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py
index ba84e2c3..3f651fd5 100644
--- a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py
+++ b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py
@@ -904,6 +904,13 @@ def test_save_stats_to_file_creates_directory(tmp_path, nested_path):
             "responses_with_reasoning": 85,
             "avg_reasoning_words": 12.5,
         },
+        {
+            "total_responses": 5,
+            "responses_with_reasoning": 3,
+            "total_reasoning_words": 25,
+            "total_original_content_words": 50,
+            "total_updated_content_words": 30,
+        },
     ],
 )
 def test_reasoning_stats_access(test_stats):
@@ -928,6 +935,134 @@ def test_reasoning_stats_access(test_stats):
     assert stats is not interceptor._reasoning_stats
 
 
+@pytest.mark.parametrize(
+    "cache_hit,expected_total_responses,expected_responses_with_reasoning",
+    [
+        # Cached response - should NOT be counted (skipped by interceptor)
+        (True, 0, 0),
+        # Normal response (not cached)
+        (False, 1, 1),
+    ],
+)
+def test_cached_response_reasoning_behavior(
+    tmp_path,
+    cache_hit,
+    expected_total_responses,
+    expected_responses_with_reasoning,
+):
+    """Test that cached responses are properly skipped in reasoning stats counting."""
+    interceptor = ResponseReasoningInterceptor(
+        ResponseReasoningInterceptor.Params(
+            enable_reasoning_tracking=True,
+            enable_caching=False,  # Disable caching to avoid state pollution
+        )
+    )
+
+    # Create mock response with reasoning content
+    response_data = {
+        "choices": [
+            {
+                "message": {
+                    "role": "assistant",
+                    "content": "This is reasoning content.Final answer.",
+                    "reasoning_content": "This is reasoning content.",
+                }
+            }
+        ],
+        "usage": {"reasoning_tokens": 10, "content_tokens": 20},
+    }
+
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = response_data
+    mock_response._content = json.dumps(response_data).encode()
+
+    mock_rctx = Mock()
+    mock_rctx.cache_hit = cache_hit
+
+    adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
+    context = AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost")
+
+    interceptor.intercept_response(adapter_response, context)
+
+    stats = interceptor._reasoning_stats
+    assert stats["total_responses"] == expected_total_responses
+    assert stats["responses_with_reasoning"] == expected_responses_with_reasoning
+
+
+@pytest.mark.parametrize(
+    "usage_format,expected_reasoning_tokens,expected_content_tokens",
+    [
+        # Format 1: reasoning_tokens and content_tokens at top level
+        ({"reasoning_tokens": 15, "content_tokens": 30}, 15, 30),
+        # Format 2: reasoning_tokens in completion_tokens_details
+        (
+            {
+                "completion_tokens_details": {"reasoning_tokens": 20},
+                "content_tokens": 40,
+            },
+            20,
+            40,
+        ),
+        # Format 3: reasoning_tokens in output_tokens_details
+        (
+            {
+                "output_tokens_details": {"reasoning_tokens": 25},
+                "content_tokens": 50,
+            },
+            25,
+            50,
+        ),
+    ],
+)
+def test_reasoning_tokens_different_formats(
+    tmp_path,
+    usage_format,
+    expected_reasoning_tokens,
+    expected_content_tokens,
+):
+    """Test reasoning interceptor handles different usage data formats for reasoning tokens."""
+    interceptor = ResponseReasoningInterceptor(
+        ResponseReasoningInterceptor.Params(
+            enable_reasoning_tracking=True,
+            enable_caching=False,
+        )
+    )
+
+    # Create mock response with reasoning content
+    response_data = {
+        "choices": [
+            {
+                "message": {
+                    "role": "assistant",
+                    "content": "This is reasoning content.Final answer.",
+                    "reasoning_content": "This is reasoning content.",
+                }
+            }
+        ],
+        "usage": usage_format,
+    }
+
+    mock_response = Mock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = response_data
+    mock_response._content = json.dumps(response_data).encode()
+
+    mock_rctx = Mock()
+    mock_rctx.cache_hit = False
+
+    adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
+    context = AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost")
+
+    interceptor.intercept_response(adapter_response, context)
+
+    stats = interceptor._reasoning_stats
+    assert stats["total_responses"] == 1
+    assert stats["responses_with_reasoning"] == 1
+    assert stats["total_reasoning_tokens"] == expected_reasoning_tokens
+    assert stats["total_updated_content_tokens"] == expected_content_tokens
+
+
 def test_load_from_cache_during_initialization(tmp_path):
     """Test that cached stats are automatically loaded during initialization."""
     # Given: Create cache directory and add some cached stats
diff --git a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py
index 92771e00..e3350ea1 100644
--- a/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py
+++ b/packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_response_stats_interceptor.py
@@ -453,6 +453,54 @@ def process_responses():
         assert interceptor._stats["count"] >= 500
         assert interceptor._stats["status_codes"][200] >= 500
 
+    @pytest.mark.parametrize(
+        "cache_hit,expected_total_responses,expected_successful_responses",
+        [
+            # Cached response - should NOT be counted (skipped by interceptor)
+            (True, 0, 0),
+            # Normal response (not cached)
+            (False, 1, 1),
+        ],
+    )
+    def test_cached_response_stats_behavior(
+        self,
+        tmp_path,
+        context,
+        cache_hit,
+        expected_total_responses,
+        expected_successful_responses,
+    ):
+        """Test that cached responses are properly skipped in stats counting."""
+        # Create a unique cache directory for each test to avoid state pollution
+        import uuid
+
+        unique_cache_dir = tmp_path / f"test_cache_{uuid.uuid4().hex[:8]}"
+
+        interceptor = ResponseStatsInterceptor(
+            ResponseStatsInterceptor.Params(
+                save_individuals=False, cache_dir=str(unique_cache_dir)
+            )
+        )
+
+        # Create a mock response
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {"result": "test"}
+        mock_response._content = b'{"result": "test"}'
+
+        # Setup request context with cache hit status
+        mock_rctx = Mock()
+        mock_rctx.cache_hit = cache_hit
+
+        adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
+
+        # Process response
+        interceptor.intercept_response(adapter_response, context)
+
+        # Verify stats counting behavior
+        assert interceptor._stats["count"] == expected_total_responses
+        assert interceptor._stats["successful_count"] == expected_successful_responses
+
 
 class TestResponseStatsInterceptorCache:
     """Test ResponseStatsInterceptor caching and aggregation functionality."""