Skip to content

Commit 9367ef0

Browse files
authored
Awarno/reasoning tokens (#211)
1. Add total stats. 2. Add reasoning-token stats when provided (see https://platform.openai.com/docs/guides/reasoning; "reasoning_tokens" may appear at the top level of "usage" or nested under "completion_tokens_details" / "output_tokens_details"). 3. Make stats cache-resistant — do not include stats if the response is served from cache. --------- Signed-off-by: Anna Warno <[email protected]>
1 parent 69c43b0 commit 9367ef0

File tree

4 files changed

+282
-15
lines changed

4 files changed

+282
-15
lines changed

packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/reasoning_interceptor.py

Lines changed: 94 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,17 @@ def __init__(self, params: Params):
140140
"avg_original_content_words": None,
141141
"avg_updated_content_words": None,
142142
"max_reasoning_words": None,
143+
"max_original_content_words": None,
144+
"max_updated_content_words": None,
145+
"max_reasoning_tokens": None,
146+
"avg_reasoning_tokens": None,
147+
"max_updated_content_tokens": None,
148+
"avg_updated_content_tokens": None,
149+
"total_reasoning_words": 0,
150+
"total_original_content_words": 0,
151+
"total_updated_content_words": 0,
152+
"total_reasoning_tokens": 0,
153+
"total_updated_content_tokens": 0,
143154
}
144155

145156
# Initialize cache if enabled
@@ -223,6 +234,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
223234
reasoning_words = reasoning_info.get("reasoning_words", 0)
224235
original_words = reasoning_info.get("original_content_words", 0)
225236
updated_words = reasoning_info.get("updated_content_words", 0)
237+
reasoning_tokens = reasoning_info.get("reasoning_tokens", "unknown")
238+
updated_content_tokens = reasoning_info.get(
239+
"updated_content_tokens", "unknown"
240+
)
226241

227242
# Increment counters
228243
if reasoning_words > 0:
@@ -237,8 +252,10 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
237252
("avg_reasoning_words", reasoning_words),
238253
("avg_original_content_words", original_words),
239254
("avg_updated_content_words", updated_words),
255+
("avg_reasoning_tokens", reasoning_tokens),
256+
("avg_updated_content_tokens", updated_content_tokens),
240257
]:
241-
if value > 0:
258+
if value != "unknown":
242259
if self._reasoning_stats[stat_key] is None:
243260
self._reasoning_stats[stat_key] = value
244261
else:
@@ -253,12 +270,34 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
253270
)
254271

255272
# Update max reasoning words
256-
if reasoning_words > 0:
257-
if (
258-
self._reasoning_stats["max_reasoning_words"] is None
259-
or reasoning_words > self._reasoning_stats["max_reasoning_words"]
260-
):
261-
self._reasoning_stats["max_reasoning_words"] = reasoning_words
273+
for key in [
274+
"reasoning_words",
275+
"original_content_words",
276+
"updated_content_words",
277+
"reasoning_tokens",
278+
"updated_content_tokens",
279+
]:
280+
value = reasoning_info.get(key, None)
281+
if value is not None and value != "unknown":
282+
if (
283+
self._reasoning_stats[f"max_{key}"] is None
284+
or value > self._reasoning_stats[f"max_{key}"]
285+
):
286+
self._reasoning_stats[f"max_{key}"] = value
287+
288+
# Update total statistics
289+
if reasoning_words != "unknown":
290+
self._reasoning_stats["total_reasoning_words"] += reasoning_words
291+
if original_words != "unknown":
292+
self._reasoning_stats["total_original_content_words"] += original_words
293+
if updated_words != "unknown":
294+
self._reasoning_stats["total_updated_content_words"] += updated_words
295+
if reasoning_tokens != "unknown":
296+
self._reasoning_stats["total_reasoning_tokens"] += reasoning_tokens
297+
if updated_content_tokens != "unknown":
298+
self._reasoning_stats["total_updated_content_tokens"] += (
299+
updated_content_tokens
300+
)
262301

263302
# Log aggregated stats at specified interval
264303
if (
@@ -268,33 +307,64 @@ def _update_reasoning_stats(self, reasoning_info: dict) -> None:
268307
):
269308
self.logger.info(**self._reasoning_stats)
270309

271-
def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
310+
def _process_reasoning_message(
311+
self, msg: dict, usage: dict = None
312+
) -> tuple[dict, dict]:
272313
"""
273314
Process reasoning in the message and return modified message with reasoning info.
274315
275316
Args:
276317
msg: The message object containing content and potentially reasoning_content
318+
usage: Optional usage data from the response for token tracking
277319
278320
Returns:
279321
tuple: (modified_message, reasoning_info) where reasoning_info has keys:
280322
reasoning_words, original_content_words, updated_content_words, reasoning_finished, reasoning_started
281323
"""
282324
modified_msg = msg.copy()
283325
content = msg.get("content", "")
326+
updated_content_tokens = "unknown"
327+
reasoning_tokens = "unknown"
284328

285329
# Check if reasoning_content exists in the message and is not empty
286-
if (
287-
"reasoning_content" in msg
288-
and msg["reasoning_content"]
289-
and msg["reasoning_content"].strip()
290-
):
330+
if "reasoning_content" in msg:
291331
reasoning_content = msg["reasoning_content"]
292332
updated_message_content = content
293-
reasoning_started = True
333+
reasoning_started = (
334+
True
335+
if reasoning_content is not None and reasoning_content.strip() != ""
336+
else False
337+
)
294338
if content.strip() == "":
295339
reasoning_finished = False
296340
else:
297341
reasoning_finished = True
342+
if usage:
343+
# First try to get reasoning_tokens directly from usage
344+
reasoning_tokens = usage.get("reasoning_tokens", "unknown")
345+
updated_content_tokens = usage.get("content_tokens", "unknown")
346+
347+
# If not found, check in completion_tokens_details and output_tokens_details
348+
if reasoning_tokens == "unknown":
349+
for key in ["completion_tokens_details", "output_tokens_details"]:
350+
if key in usage:
351+
details = usage[key]
352+
if isinstance(details, dict):
353+
reasoning_tokens = details.get(
354+
"reasoning_tokens", "unknown"
355+
)
356+
if reasoning_tokens != "unknown":
357+
self.logger.debug(
358+
f"Found reasoning_tokens in {key}: {reasoning_tokens}"
359+
)
360+
break
361+
362+
# Log if reasoning tokens were found
363+
if reasoning_tokens != "unknown":
364+
self.logger.debug(f"Reasoning tokens extracted: {reasoning_tokens}")
365+
else:
366+
self.logger.debug("No reasoning tokens found in usage data")
367+
298368
else:
299369
reasoning_finished = False
300370
if self.start_reasoning_token is not None:
@@ -343,6 +413,8 @@ def _process_reasoning_message(self, msg: dict) -> tuple[dict, dict]:
343413
),
344414
"reasoning_finished": reasoning_finished,
345415
"reasoning_started": reasoning_started,
416+
"reasoning_tokens": reasoning_tokens,
417+
"updated_content_tokens": updated_content_tokens,
346418
}
347419

348420
return modified_msg, reasoning_info
@@ -363,6 +435,7 @@ def _strip_reasoning(self, text: str) -> str:
363435
return cleaned_content
364436

365437
def _migrate_reasoning_content(self, msg: dict):
438+
"""Migrate reasoning content to the content field with reasoning tokens."""
366439
modified_msg = msg.copy()
367440
if (
368441
"reasoning_content" in msg
@@ -404,6 +477,9 @@ def intercept_response(
404477
"Reasoning processing disabled, returning response as-is"
405478
)
406479
return resp
480+
if resp.rctx.cache_hit:
481+
self.logger.debug("Response was from cache, skipping reasoning processing")
482+
return resp
407483

408484
try:
409485
response_data = resp.r.json()
@@ -414,6 +490,9 @@ def intercept_response(
414490
choices_count=len(response_data["choices"]),
415491
)
416492

493+
# Extract usage data from response
494+
usage_data = response_data.get("usage", {})
495+
417496
for choice in response_data["choices"]:
418497
msg = choice.get("message")
419498
if (
@@ -423,7 +502,7 @@ def intercept_response(
423502
):
424503
# Get modified message and reasoning information
425504
modified_msg, reasoning_info = self._process_reasoning_message(
426-
msg
505+
msg, usage_data
427506
)
428507

429508
# Collect reasoning statistics

packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/response_stats_interceptor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,11 @@ def intercept_response(
412412
) -> AdapterResponse:
413413
"""Collect aggregated statistics from the response."""
414414
# Get status code once and reuse it
415+
if resp.rctx.cache_hit:
416+
self.logger.debug(
417+
"Response was from cache, skipping response stats collection"
418+
)
419+
return resp
415420
status_code = resp.r.status_code
416421

417422
# Update time tracking with current timestamp

packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,13 @@ def test_save_stats_to_file_creates_directory(tmp_path, nested_path):
904904
"responses_with_reasoning": 85,
905905
"avg_reasoning_words": 12.5,
906906
},
907+
{
908+
"total_responses": 5,
909+
"responses_with_reasoning": 3,
910+
"total_reasoning_words": 25,
911+
"total_original_content_words": 50,
912+
"total_updated_content_words": 30,
913+
},
907914
],
908915
)
909916
def test_reasoning_stats_access(test_stats):
@@ -928,6 +935,134 @@ def test_reasoning_stats_access(test_stats):
928935
assert stats is not interceptor._reasoning_stats
929936

930937

938+
@pytest.mark.parametrize(
939+
"cache_hit,expected_total_responses,expected_responses_with_reasoning",
940+
[
941+
# Cached response - should NOT be counted (skipped by interceptor)
942+
(True, 0, 0),
943+
# Normal response (not cached)
944+
(False, 1, 1),
945+
],
946+
)
947+
def test_cached_response_reasoning_behavior(
948+
tmp_path,
949+
cache_hit,
950+
expected_total_responses,
951+
expected_responses_with_reasoning,
952+
):
953+
"""Test that cached responses are properly skipped in reasoning stats counting."""
954+
interceptor = ResponseReasoningInterceptor(
955+
ResponseReasoningInterceptor.Params(
956+
enable_reasoning_tracking=True,
957+
enable_caching=False, # Disable caching to avoid state pollution
958+
)
959+
)
960+
961+
# Create mock response with reasoning content
962+
response_data = {
963+
"choices": [
964+
{
965+
"message": {
966+
"role": "assistant",
967+
"content": "<think>This is reasoning content.</think>Final answer.",
968+
"reasoning_content": "This is reasoning content.",
969+
}
970+
}
971+
],
972+
"usage": {"reasoning_tokens": 10, "content_tokens": 20},
973+
}
974+
975+
mock_response = Mock()
976+
mock_response.status_code = 200
977+
mock_response.json.return_value = response_data
978+
mock_response._content = json.dumps(response_data).encode()
979+
980+
mock_rctx = Mock()
981+
mock_rctx.cache_hit = cache_hit
982+
983+
adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
984+
context = AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost")
985+
986+
interceptor.intercept_response(adapter_response, context)
987+
988+
stats = interceptor._reasoning_stats
989+
assert stats["total_responses"] == expected_total_responses
990+
assert stats["responses_with_reasoning"] == expected_responses_with_reasoning
991+
992+
993+
@pytest.mark.parametrize(
994+
"usage_format,expected_reasoning_tokens,expected_content_tokens",
995+
[
996+
# Format 1: reasoning_tokens and content_tokens at top level
997+
({"reasoning_tokens": 15, "content_tokens": 30}, 15, 30),
998+
# Format 2: reasoning_tokens in completion_tokens_details
999+
(
1000+
{
1001+
"completion_tokens_details": {"reasoning_tokens": 20},
1002+
"content_tokens": 40,
1003+
},
1004+
20,
1005+
40,
1006+
),
1007+
# Format 3: reasoning_tokens in output_tokens_details
1008+
(
1009+
{
1010+
"output_tokens_details": {"reasoning_tokens": 25},
1011+
"content_tokens": 50,
1012+
},
1013+
25,
1014+
50,
1015+
),
1016+
],
1017+
)
1018+
def test_reasoning_tokens_different_formats(
1019+
tmp_path,
1020+
usage_format,
1021+
expected_reasoning_tokens,
1022+
expected_content_tokens,
1023+
):
1024+
"""Test reasoning interceptor handles different usage data formats for reasoning tokens."""
1025+
interceptor = ResponseReasoningInterceptor(
1026+
ResponseReasoningInterceptor.Params(
1027+
enable_reasoning_tracking=True,
1028+
enable_caching=False,
1029+
)
1030+
)
1031+
1032+
# Create mock response with reasoning content
1033+
response_data = {
1034+
"choices": [
1035+
{
1036+
"message": {
1037+
"role": "assistant",
1038+
"content": "<think>This is reasoning content.</think>Final answer.",
1039+
"reasoning_content": "This is reasoning content.",
1040+
}
1041+
}
1042+
],
1043+
"usage": usage_format,
1044+
}
1045+
1046+
mock_response = Mock()
1047+
mock_response.status_code = 200
1048+
mock_response.json.return_value = response_data
1049+
mock_response._content = json.dumps(response_data).encode()
1050+
1051+
mock_rctx = Mock()
1052+
mock_rctx.cache_hit = False
1053+
1054+
adapter_response = AdapterResponse(r=mock_response, rctx=mock_rctx)
1055+
context = AdapterGlobalContext(output_dir=str(tmp_path), url="http://localhost")
1056+
1057+
interceptor.intercept_response(adapter_response, context)
1058+
1059+
stats = interceptor._reasoning_stats
1060+
assert stats["total_responses"] == 1
1061+
assert stats["responses_with_reasoning"] == 1
1062+
assert stats["total_reasoning_tokens"] == expected_reasoning_tokens
1063+
assert stats["total_updated_content_tokens"] == expected_content_tokens
1064+
1065+
9311066
def test_load_from_cache_during_initialization(tmp_path):
9321067
"""Test that cached stats are automatically loaded during initialization."""
9331068
# Given: Create cache directory and add some cached stats

0 commit comments

Comments
 (0)