1616from ..types .content import ContentBlock , Messages
1717from ..types .exceptions import ModelThrottledException
1818from ..types .models import Model
19- from ..types .streaming import StreamEvent
19+ from ..types .streaming import StreamEvent , Usage
2020from ..types .tools import ToolResult , ToolSpec , ToolUse
2121
2222logger = logging .getLogger (__name__ )
@@ -38,11 +38,11 @@ class LlamaConfig(TypedDict, total=False):
3838 """
3939
4040 model_id : str
41- repetition_penalty : float | None = None
42- temperature : float | None = None
43- top_p : float | None = None
44- max_completion_tokens : int | None = None
45- top_k : int | None = None
41+ repetition_penalty : Optional [ float ]
42+ temperature : Optional [ float ]
43+ top_p : Optional [ float ]
44+ max_completion_tokens : Optional [ int ]
45+ top_k : Optional [ int ]
4646
4747 def __init__ (
4848 self ,
@@ -169,12 +169,15 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s
169169 if "toolResult" in content
170170 ]
171171
172+ new_formatted_contents : list [dict [str , Any ]] | dict [str , Any ] | str = ""
172173 if message ["role" ] == "assistant" :
173- formatted_contents = formatted_contents [0 ] if len (formatted_contents ) > 0 else ""
174+ new_formatted_contents = formatted_contents [0 ] if formatted_contents else ""
175+ else :
176+ new_formatted_contents = formatted_contents
174177
175178 formatted_message = {
176179 "role" : message ["role" ],
177- "content" : formatted_contents if len (formatted_contents ) > 0 else "" ,
180+ "content" : new_formatted_contents if len (new_formatted_contents ) > 0 else "" ,
178181 ** ({"tool_calls" : formatted_tool_calls } if formatted_tool_calls else {}),
179182 }
180183 formatted_messages .append (formatted_message )
@@ -282,9 +285,14 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
282285 elif metrics .metric == "num_total_tokens" :
283286 usage ["totalTokens" ] = metrics .value
284287
288+ usage_type = Usage (
289+ inputTokens = usage ["inputTokens" ],
290+ outputTokens = usage ["outputTokens" ],
291+ totalTokens = usage ["totalTokens" ],
292+ )
285293 return {
286294 "metadata" : {
287- "usage" : usage ,
295+ "usage" : usage_type ,
288296 "metrics" : {
289297 "latencyMs" : 0 , # TODO
290298 },
@@ -315,7 +323,7 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
315323 yield {"chunk_type" : "message_start" }
316324
317325 stop_reason = None
318- tool_calls : dict [int , list [Any ]] = {}
326+ tool_calls : dict [Any , list [Any ]] = {}
319327 curr_tool_call_id = None
320328
321329 metrics_event = None
@@ -328,7 +336,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
328336 if chunk .event .delta .type == "tool_call" :
329337 if chunk .event .delta .id :
330338 curr_tool_call_id = chunk .event .delta .id
331- tool_calls .setdefault (curr_tool_call_id , []).append (chunk .event .delta )
339+
340+ if curr_tool_call_id not in tool_calls :
341+ tool_calls [curr_tool_call_id ] = []
342+ tool_calls [curr_tool_call_id ].append (chunk .event .delta )
332343 elif chunk .event .event_type == "metrics" :
333344 metrics_event = chunk .event .metrics
334345 else :
0 commit comments