Skip to content

Commit b11c2f2

Browse files
authored
fix(ai): correct size calculation, rename internal property for message truncation & add test (#4949)
1 parent 814cd5a commit b11c2f2

File tree

4 files changed

+95
-43
lines changed

4 files changed

+95
-43
lines changed

sentry_sdk/ai/utils.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,6 @@ def truncate_and_annotate_messages(
139139

140140
truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
141141
if removed_count > 0:
142-
scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len(
143-
truncated_messages
144-
)
142+
scope._gen_ai_original_message_count[span.span_id] = len(messages)
145143

146144
return truncated_messages

sentry_sdk/client.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -598,23 +598,20 @@ def _prepare_event(
598598
if event_scrubber:
599599
event_scrubber.scrub_event(event)
600600

601-
if scope is not None and scope._gen_ai_messages_truncated:
601+
if scope is not None and scope._gen_ai_original_message_count:
602602
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
603603
if isinstance(spans, list):
604604
for span in spans:
605605
span_id = span.get("span_id", None)
606606
span_data = span.get("data", {})
607607
if (
608608
span_id
609-
and span_id in scope._gen_ai_messages_truncated
609+
and span_id in scope._gen_ai_original_message_count
610610
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
611611
):
612612
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
613613
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
614-
{
615-
"len": scope._gen_ai_messages_truncated[span_id]
616-
+ len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
617-
},
614+
{"len": scope._gen_ai_original_message_count[span_id]},
618615
)
619616
if previous_total_spans is not None:
620617
event["spans"] = AnnotatedValue(

sentry_sdk/scope.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -188,7 +188,7 @@ class Scope:
188188
"_extras",
189189
"_breadcrumbs",
190190
"_n_breadcrumbs_truncated",
191-
"_gen_ai_messages_truncated",
191+
"_gen_ai_original_message_count",
192192
"_event_processors",
193193
"_error_processors",
194194
"_should_capture",
@@ -214,7 +214,7 @@ def __init__(self, ty=None, client=None):
214214
self._name = None # type: Optional[str]
215215
self._propagation_context = None # type: Optional[PropagationContext]
216216
self._n_breadcrumbs_truncated = 0 # type: int
217-
self._gen_ai_messages_truncated = {} # type: Dict[str, int]
217+
self._gen_ai_original_message_count = {} # type: Dict[str, int]
218218

219219
self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient
220220

@@ -249,7 +249,7 @@ def __copy__(self):
249249

250250
rv._breadcrumbs = copy(self._breadcrumbs)
251251
rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
252-
rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
252+
rv._gen_ai_original_message_count = self._gen_ai_original_message_count.copy()
253253
rv._event_processors = self._event_processors.copy()
254254
rv._error_processors = self._error_processors.copy()
255255
rv._propagation_context = self._propagation_context
@@ -1586,8 +1586,10 @@ def update_from_scope(self, scope):
15861586
self._n_breadcrumbs_truncated = (
15871587
self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
15881588
)
1589-
if scope._gen_ai_messages_truncated:
1590-
self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
1589+
if scope._gen_ai_original_message_count:
1590+
self._gen_ai_original_message_count.update(
1591+
scope._gen_ai_original_message_count
1592+
)
15911593
if scope._span:
15921594
self._span = scope._span
15931595
if scope._attachments:

tests/test_ai_monitoring.py

Lines changed: 84 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import json
2+
import uuid
23

34
import pytest
45

@@ -210,32 +211,32 @@ def large_messages():
210211
class TestTruncateMessagesBySize:
211212
def test_no_truncation_needed(self, sample_messages):
212213
"""Test that messages under the limit are not truncated"""
213-
result, removed_count = truncate_messages_by_size(
214+
result, truncation_index = truncate_messages_by_size(
214215
sample_messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
215216
)
216217
assert len(result) == len(sample_messages)
217218
assert result == sample_messages
218-
assert removed_count == 0
219+
assert truncation_index == 0
219220

220221
def test_truncation_removes_oldest_first(self, large_messages):
221222
"""Test that oldest messages are removed first during truncation"""
222223
small_limit = 3000
223-
result, removed_count = truncate_messages_by_size(
224+
result, truncation_index = truncate_messages_by_size(
224225
large_messages, max_bytes=small_limit
225226
)
226227
assert len(result) < len(large_messages)
227228

228229
if result:
229230
assert result[-1] == large_messages[-1]
230-
assert removed_count == len(large_messages) - len(result)
231+
assert truncation_index == len(large_messages) - len(result)
231232

232233
def test_empty_messages_list(self):
233234
"""Test handling of empty messages list"""
234-
result, removed_count = truncate_messages_by_size(
235+
result, truncation_index = truncate_messages_by_size(
235236
[], max_bytes=MAX_GEN_AI_MESSAGE_BYTES // 500
236237
)
237238
assert result == []
238-
assert removed_count == 0
239+
assert truncation_index == 0
239240

240241
def test_find_truncation_index(
241242
self,
@@ -290,7 +291,7 @@ def set_data(self, key, value):
290291

291292
class MockScope:
292293
def __init__(self):
293-
self._gen_ai_messages_truncated = {}
294+
self._gen_ai_original_message_count = {}
294295

295296
span = MockSpan()
296297
scope = MockScope()
@@ -300,7 +301,7 @@ def __init__(self):
300301
assert not isinstance(result, AnnotatedValue)
301302
assert len(result) == len(sample_messages)
302303
assert result == sample_messages
303-
assert span.span_id not in scope._gen_ai_messages_truncated
304+
assert span.span_id not in scope._gen_ai_original_message_count
304305

305306
def test_truncation_sets_metadata_on_scope(self, large_messages):
306307
class MockSpan:
@@ -313,9 +314,9 @@ def set_data(self, key, value):
313314

314315
class MockScope:
315316
def __init__(self):
316-
self._gen_ai_messages_truncated = {}
317+
self._gen_ai_original_message_count = {}
317318

318-
small_limit = 1000
319+
small_limit = 3000
319320
span = MockSpan()
320321
scope = MockScope()
321322
original_count = len(large_messages)
@@ -326,10 +327,9 @@ def __init__(self):
326327
assert isinstance(result, list)
327328
assert not isinstance(result, AnnotatedValue)
328329
assert len(result) < len(large_messages)
329-
n_removed = original_count - len(result)
330-
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
330+
assert scope._gen_ai_original_message_count[span.span_id] == original_count
331331

332-
def test_scope_tracks_removed_messages(self, large_messages):
332+
def test_scope_tracks_original_message_count(self, large_messages):
333333
class MockSpan:
334334
def __init__(self):
335335
self.span_id = "test_span_id"
@@ -340,9 +340,9 @@ def set_data(self, key, value):
340340

341341
class MockScope:
342342
def __init__(self):
343-
self._gen_ai_messages_truncated = {}
343+
self._gen_ai_original_message_count = {}
344344

345-
small_limit = 1000
345+
small_limit = 3000
346346
original_count = len(large_messages)
347347
span = MockSpan()
348348
scope = MockScope()
@@ -351,9 +351,8 @@ def __init__(self):
351351
large_messages, span, scope, max_bytes=small_limit
352352
)
353353

354-
n_removed = original_count - len(result)
355-
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
356-
assert len(result) + n_removed == original_count
354+
assert scope._gen_ai_original_message_count[span.span_id] == original_count
355+
assert len(result) == 1
357356

358357
def test_empty_messages_returns_none(self):
359358
class MockSpan:
@@ -366,7 +365,7 @@ def set_data(self, key, value):
366365

367366
class MockScope:
368367
def __init__(self):
369-
self._gen_ai_messages_truncated = {}
368+
self._gen_ai_original_message_count = {}
370369

371370
span = MockSpan()
372371
scope = MockScope()
@@ -387,7 +386,7 @@ def set_data(self, key, value):
387386

388387
class MockScope:
389388
def __init__(self):
390-
self._gen_ai_messages_truncated = {}
389+
self._gen_ai_original_message_count = {}
391390

392391
small_limit = 3000
393392
span = MockSpan()
@@ -416,7 +415,7 @@ def set_data(self, key, value):
416415

417416
class MockScope:
418417
def __init__(self):
419-
self._gen_ai_messages_truncated = {}
418+
self._gen_ai_original_message_count = {}
420419

421420
small_limit = 3000
422421
span = MockSpan()
@@ -430,33 +429,89 @@ def __init__(self):
430429
span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated_messages)
431430

432431
# Verify metadata was set on scope
433-
assert span.span_id in scope._gen_ai_messages_truncated
434-
assert scope._gen_ai_messages_truncated[span.span_id] > 0
432+
assert span.span_id in scope._gen_ai_original_message_count
433+
assert scope._gen_ai_original_message_count[span.span_id] > 0
435434

436435
# Simulate what client.py does
437436
event = {"spans": [{"span_id": span.span_id, "data": span.data.copy()}]}
438437

439-
# Mimic client.py logic - using scope to get the removed count
438+
# Mimic client.py logic - using scope to get the original length
440439
for event_span in event["spans"]:
441440
span_id = event_span.get("span_id")
442441
span_data = event_span.get("data", {})
443442
if (
444443
span_id
445-
and span_id in scope._gen_ai_messages_truncated
444+
and span_id in scope._gen_ai_original_message_count
446445
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
447446
):
448447
messages = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
449-
n_removed = scope._gen_ai_messages_truncated[span_id]
450-
n_remaining = len(messages) if isinstance(messages, list) else 0
451-
original_count_calculated = n_removed + n_remaining
448+
n_original_count = scope._gen_ai_original_message_count[span_id]
452449

453450
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
454451
safe_serialize(messages),
455-
{"len": original_count_calculated},
452+
{"len": n_original_count},
456453
)
457454

458455
# Verify the annotation happened
459456
messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
460457
assert isinstance(messages_value, AnnotatedValue)
461458
assert messages_value.metadata["len"] == original_count
462459
assert isinstance(messages_value.value, str)
460+
461+
def test_annotated_value_shows_correct_original_length(self, large_messages):
462+
"""Test that the annotated value correctly shows the original message count before truncation"""
463+
from sentry_sdk.consts import SPANDATA
464+
465+
class MockSpan:
466+
def __init__(self):
467+
self.span_id = "test_span_456"
468+
self.data = {}
469+
470+
def set_data(self, key, value):
471+
self.data[key] = value
472+
473+
class MockScope:
474+
def __init__(self):
475+
self._gen_ai_original_message_count = {}
476+
477+
small_limit = 3000
478+
span = MockSpan()
479+
scope = MockScope()
480+
original_message_count = len(large_messages)
481+
482+
truncated_messages = truncate_and_annotate_messages(
483+
large_messages, span, scope, max_bytes=small_limit
484+
)
485+
486+
assert len(truncated_messages) < original_message_count
487+
488+
assert span.span_id in scope._gen_ai_original_message_count
489+
stored_original_length = scope._gen_ai_original_message_count[span.span_id]
490+
assert stored_original_length == original_message_count
491+
492+
event = {
493+
"spans": [
494+
{
495+
"span_id": span.span_id,
496+
"data": {SPANDATA.GEN_AI_REQUEST_MESSAGES: truncated_messages},
497+
}
498+
]
499+
}
500+
501+
for event_span in event["spans"]:
502+
span_id = event_span.get("span_id")
503+
span_data = event_span.get("data", {})
504+
if (
505+
span_id
506+
and span_id in scope._gen_ai_original_message_count
507+
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
508+
):
509+
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
510+
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
511+
{"len": scope._gen_ai_original_message_count[span_id]},
512+
)
513+
514+
messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
515+
assert isinstance(messages_value, AnnotatedValue)
516+
assert messages_value.metadata["len"] == stored_original_length
517+
assert len(messages_value.value) == len(truncated_messages)

0 commit comments

Comments
 (0)