Skip to content

Commit 10a46c2

Browse files
Xi Yan authored and committed
fix lint
1 parent 8752eda commit 10a46c2

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

src/strands/models/llamaapi.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from ..types.content import ContentBlock, Messages
1717
from ..types.exceptions import ModelThrottledException
1818
from ..types.models import Model
19-
from ..types.streaming import StreamEvent
19+
from ..types.streaming import StreamEvent, Usage
2020
from ..types.tools import ToolResult, ToolSpec, ToolUse
2121

2222
logger = logging.getLogger(__name__)
@@ -38,11 +38,11 @@ class LlamaConfig(TypedDict, total=False):
3838
"""
3939

4040
model_id: str
41-
repetition_penalty: float | None = None
42-
temperature: float | None = None
43-
top_p: float | None = None
44-
max_completion_tokens: int | None = None
45-
top_k: int | None = None
41+
repetition_penalty: Optional[float]
42+
temperature: Optional[float]
43+
top_p: Optional[float]
44+
max_completion_tokens: Optional[int]
45+
top_k: Optional[int]
4646

4747
def __init__(
4848
self,
@@ -169,12 +169,15 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s
169169
if "toolResult" in content
170170
]
171171

172+
new_formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = ""
172173
if message["role"] == "assistant":
173-
formatted_contents = formatted_contents[0] if len(formatted_contents) > 0 else ""
174+
new_formatted_contents = formatted_contents[0] if formatted_contents else ""
175+
else:
176+
new_formatted_contents = formatted_contents
174177

175178
formatted_message = {
176179
"role": message["role"],
177-
"content": formatted_contents if len(formatted_contents) > 0 else "",
180+
"content": new_formatted_contents if len(new_formatted_contents) > 0 else "",
178181
**({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
179182
}
180183
formatted_messages.append(formatted_message)
@@ -282,9 +285,14 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
282285
elif metrics.metric == "num_total_tokens":
283286
usage["totalTokens"] = metrics.value
284287

288+
usage_type = Usage(
289+
inputTokens=usage["inputTokens"],
290+
outputTokens=usage["outputTokens"],
291+
totalTokens=usage["totalTokens"],
292+
)
285293
return {
286294
"metadata": {
287-
"usage": usage,
295+
"usage": usage_type,
288296
"metrics": {
289297
"latencyMs": 0, # TODO
290298
},
@@ -315,7 +323,7 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
315323
yield {"chunk_type": "message_start"}
316324

317325
stop_reason = None
318-
tool_calls: dict[int, list[Any]] = {}
326+
tool_calls: dict[Any, list[Any]] = {}
319327
curr_tool_call_id = None
320328

321329
metrics_event = None
@@ -328,7 +336,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
328336
if chunk.event.delta.type == "tool_call":
329337
if chunk.event.delta.id:
330338
curr_tool_call_id = chunk.event.delta.id
331-
tool_calls.setdefault(curr_tool_call_id, []).append(chunk.event.delta)
339+
340+
if curr_tool_call_id not in tool_calls:
341+
tool_calls[curr_tool_call_id] = []
342+
tool_calls[curr_tool_call_id].append(chunk.event.delta)
332343
elif chunk.event.event_type == "metrics":
333344
metrics_event = chunk.event.metrics
334345
else:

0 commit comments

Comments (0)