@@ -70,48 +70,73 @@ def _capture_exception(exc):
     sentry_sdk.capture_event(event, hint=hint)


-def _calculate_chat_completion_usage(
+def _get_usage(usage, names):
+    # type: (Any, List[str]) -> int
+    for name in names:
+        if hasattr(usage, name) and isinstance(getattr(usage, name), int):
+            return getattr(usage, name)
+    return 0
+
+
+def _calculate_token_usage(
     messages, response, span, streaming_message_responses, count_tokens
 ):
     # type: (Iterable[ChatCompletionMessageParam], Any, Span, Optional[List[str]], Callable[..., Any]) -> None
-    completion_tokens = 0  # type: Optional[int]
-    prompt_tokens = 0  # type: Optional[int]
+    input_tokens = 0  # type: Optional[int]
+    input_tokens_cached = 0  # type: Optional[int]
+    output_tokens = 0  # type: Optional[int]
+    output_tokens_reasoning = 0  # type: Optional[int]
     total_tokens = 0  # type: Optional[int]
+
     if hasattr(response, "usage"):
-        if hasattr(response.usage, "completion_tokens") and isinstance(
-            response.usage.completion_tokens, int
-        ):
-            completion_tokens = response.usage.completion_tokens
-        if hasattr(response.usage, "prompt_tokens") and isinstance(
-            response.usage.prompt_tokens, int
-        ):
-            prompt_tokens = response.usage.prompt_tokens
-        if hasattr(response.usage, "total_tokens") and isinstance(
-            response.usage.total_tokens, int
-        ):
-            total_tokens = response.usage.total_tokens
+        input_tokens = _get_usage(response.usage, ["input_tokens", "prompt_tokens"])
+        if hasattr(response.usage, "input_tokens_details"):
+            input_tokens_cached = _get_usage(
+                response.usage.input_tokens_details, ["cached_tokens"]
+            )

-    if prompt_tokens == 0:
+        output_tokens = _get_usage(
+            response.usage, ["output_tokens", "completion_tokens"]
+        )
+        if hasattr(response.usage, "output_tokens_details"):
+            output_tokens_reasoning = _get_usage(
+                response.usage.output_tokens_details, ["reasoning_tokens"]
+            )
+
+        total_tokens = _get_usage(response.usage, ["total_tokens"])
+
+    # Manually count tokens
+    # TODO: when implementing responses API, check for responses API
+    if input_tokens == 0:
         for message in messages:
             if "content" in message:
-                prompt_tokens += count_tokens(message["content"])
+                input_tokens += count_tokens(message["content"])

-    if completion_tokens == 0:
+    # TODO: when implementing responses API, check for responses API
+    if output_tokens == 0:
         if streaming_message_responses is not None:
             for message in streaming_message_responses:
-                completion_tokens += count_tokens(message)
+                output_tokens += count_tokens(message)
         elif hasattr(response, "choices"):
             for choice in response.choices:
                 if hasattr(choice, "message"):
-                    completion_tokens += count_tokens(choice.message)
-
-    if prompt_tokens == 0:
-        prompt_tokens = None
-    if completion_tokens == 0:
-        completion_tokens = None
-    if total_tokens == 0:
-        total_tokens = None
-    record_token_usage(span, prompt_tokens, completion_tokens, total_tokens)
+                    output_tokens += count_tokens(choice.message)
+
+    # Do not set token data if it is 0
+    input_tokens = input_tokens or None
+    input_tokens_cached = input_tokens_cached or None
+    output_tokens = output_tokens or None
+    output_tokens_reasoning = output_tokens_reasoning or None
+    total_tokens = total_tokens or None
+
+    record_token_usage(
+        span,
+        input_tokens=input_tokens,
+        input_tokens_cached=input_tokens_cached,
+        output_tokens=output_tokens,
+        output_tokens_reasoning=output_tokens_reasoning,
+        total_tokens=total_tokens,
+    )


 def _new_chat_completion_common(f, *args, **kwargs):
@@ -158,9 +183,7 @@ def _new_chat_completion_common(f, *args, **kwargs):
                     SPANDATA.AI_RESPONSES,
                     list(map(lambda x: x.message, res.choices)),
                 )
-            _calculate_chat_completion_usage(
-                messages, res, span, None, integration.count_tokens
-            )
+            _calculate_token_usage(messages, res, span, None, integration.count_tokens)
             span.__exit__(None, None, None)
         elif hasattr(res, "_iterator"):
             data_buf: list[list[str]] = []  # one for each choice
@@ -191,7 +214,7 @@ def new_iterator():
                             set_data_normalized(
                                 span, SPANDATA.AI_RESPONSES, all_responses
                             )
-                        _calculate_chat_completion_usage(
+                        _calculate_token_usage(
                             messages,
                             res,
                             span,
@@ -224,7 +247,7 @@ async def new_iterator_async():
                             set_data_normalized(
                                 span, SPANDATA.AI_RESPONSES, all_responses
                             )
-                        _calculate_chat_completion_usage(
+                        _calculate_token_usage(
                             messages,
                             res,
                             span,
@@ -341,22 +364,26 @@ def _new_embeddings_create_common(f, *args, **kwargs):

         response = yield f, args, kwargs

-        prompt_tokens = 0
+        input_tokens = 0
         total_tokens = 0
         if hasattr(response, "usage"):
             if hasattr(response.usage, "prompt_tokens") and isinstance(
                 response.usage.prompt_tokens, int
             ):
-                prompt_tokens = response.usage.prompt_tokens
+                input_tokens = response.usage.prompt_tokens
             if hasattr(response.usage, "total_tokens") and isinstance(
                 response.usage.total_tokens, int
             ):
                 total_tokens = response.usage.total_tokens

-        if prompt_tokens == 0:
-            prompt_tokens = integration.count_tokens(kwargs["input"] or "")
+        if input_tokens == 0:
+            input_tokens = integration.count_tokens(kwargs["input"] or "")

-        record_token_usage(span, prompt_tokens, None, total_tokens or prompt_tokens)
+        record_token_usage(
+            span,
+            input_tokens=input_tokens,
+            total_tokens=total_tokens or input_tokens,
+        )

         return response
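
For reference, a minimal, self-contained sketch of what the new _get_usage helper does (the SimpleNamespace stand-ins below are illustrative and not part of the SDK): it returns the first attribute in names that exists on the usage object and holds an int, so a single call covers both the legacy prompt_tokens/completion_tokens fields and the newer input_tokens/output_tokens fields, falling back to 0 when none match.

from types import SimpleNamespace

def _get_usage(usage, names):
    # Return the first attribute in `names` that exists and is an int, else 0.
    for name in names:
        if hasattr(usage, name) and isinstance(getattr(usage, name), int):
            return getattr(usage, name)
    return 0

# Legacy chat-completions-style usage object.
legacy = SimpleNamespace(prompt_tokens=12, completion_tokens=34, total_tokens=46)
# Newer usage object shape with input/output naming.
modern = SimpleNamespace(input_tokens=12, output_tokens=34, total_tokens=46)

assert _get_usage(legacy, ["input_tokens", "prompt_tokens"]) == 12
assert _get_usage(modern, ["input_tokens", "prompt_tokens"]) == 12
assert _get_usage(modern, ["missing_field"]) == 0  # absent fields fall back to 0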