Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/strands/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from importlib import import_module, util
from os.path import expanduser
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterable, List, Optional

from typing_extensions import TypedDict, cast

Expand Down Expand Up @@ -54,7 +54,7 @@ def process_tools(self, tools: List[Any]) -> List[str]:
"""
tool_names = []

for tool in tools:
def add_tool(tool: Any) -> None:
# Case 1: String file path
if isinstance(tool, str):
# Extract tool name from path
Expand Down Expand Up @@ -97,9 +97,16 @@ def process_tools(self, tools: List[Any]) -> List[str]:
elif isinstance(tool, AgentTool):
self.register_tool(tool)
tool_names.append(tool.tool_name)
# Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool
elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)):
for t in tool:
add_tool(t)
else:
logger.warning("tool=<%s> | unrecognized tool specification", tool)

for a_tool in tools:
add_tool(a_tool)

return tool_names

def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None:
Expand Down
19 changes: 19 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,25 @@ def test_agent__init__with_string_model_id():
assert agent.model.config["model_id"] == "nonsense"


def test_agent__init__nested_tools_flattening(tool_decorated, tool_module, tool_imported, tool_registry):
_ = tool_registry
# Nested structure: [tool_decorated, [tool_module, [tool_imported]]]
agent = Agent(tools=[tool_decorated, [tool_module, [tool_imported]]])
tru_tool_names = sorted(agent.tool_names)
exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
assert tru_tool_names == exp_tool_names


def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_imported, tool_registry):
_ = tool_registry
# Deeply nested structure
nested_tools = [[[[tool_decorated]], [[tool_module]], tool_imported]]
agent = Agent(tools=nested_tools)
tru_tool_names = sorted(agent.tool_names)
exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
assert tru_tool_names == exp_tool_names


def test_agent__call__(
mock_model,
system_prompt,
Expand Down
27 changes: 27 additions & 0 deletions tests/strands/tools/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,30 @@ def tool_function_4(d):

assert len(tools) == 2
assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)


def test_process_tools_flattens_lists_and_tuples_and_sets():
def function() -> str:
return "done"

tool_a = tool(name="tool_a")(function)
tool_b = tool(name="tool_b")(function)
tool_c = tool(name="tool_c")(function)
tool_d = tool(name="tool_d")(function)
tool_e = tool(name="tool_e")(function)
tool_f = tool(name="tool_f")(function)

registry = ToolRegistry()

all_tools = [tool_a, (tool_b, tool_c), [{tool_d, tool_e}, [tool_f]]]

tru_tool_names = sorted(registry.process_tools(all_tools))
exp_tool_names = [
"tool_a",
"tool_b",
"tool_c",
"tool_d",
"tool_e",
"tool_f",
]
assert tru_tool_names == exp_tool_names
Loading