diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 9d835d28e..fd395ae77 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -11,7 +11,7 @@ from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional from typing_extensions import TypedDict, cast @@ -54,7 +54,7 @@ def process_tools(self, tools: List[Any]) -> List[str]: """ tool_names = [] - for tool in tools: + def add_tool(tool: Any) -> None: # Case 1: String file path if isinstance(tool, str): # Extract tool name from path @@ -97,9 +97,16 @@ def process_tools(self, tools: List[Any]) -> List[str]: elif isinstance(tool, AgentTool): self.register_tool(tool) tool_names.append(tool.tool_name) + # Case 6: Nested iterable (list, tuple, etc.) - add each sub-tool + elif isinstance(tool, Iterable) and not isinstance(tool, (str, bytes, bytearray)): + for t in tool: + add_tool(t) else: logger.warning("tool=<%s> | unrecognized tool specification", tool) + for a_tool in tools: + add_tool(a_tool) + return tool_names def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d6471a09a..4e310dace 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -231,6 +231,25 @@ def test_agent__init__with_string_model_id(): assert agent.model.config["model_id"] == "nonsense" +def test_agent__init__nested_tools_flattening(tool_decorated, tool_module, tool_imported, tool_registry): + _ = tool_registry + # Nested structure: [tool_decorated, [tool_module, [tool_imported]]] + agent = Agent(tools=[tool_decorated, [tool_module, [tool_imported]]]) + tru_tool_names = sorted(agent.tool_names) + exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] + assert tru_tool_names == exp_tool_names + + +def test_agent__init__deeply_nested_tools(tool_decorated, tool_module, tool_imported, tool_registry): + _ = tool_registry + # Deeply nested structure + nested_tools = [[[[tool_decorated]], [[tool_module]], tool_imported]] + agent = Agent(tools=nested_tools) + tru_tool_names = sorted(agent.tool_names) + exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] + assert tru_tool_names == exp_tool_names + + def test_agent__call__( mock_model, system_prompt, diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index ebcba3fb1..66494c987 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -93,3 +93,30 @@ def tool_function_4(d): assert len(tools) == 2 assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools) + + +def test_process_tools_flattens_lists_and_tuples_and_sets(): + def function() -> str: + return "done" + + tool_a = tool(name="tool_a")(function) + tool_b = tool(name="tool_b")(function) + tool_c = tool(name="tool_c")(function) + tool_d = tool(name="tool_d")(function) + tool_e = tool(name="tool_e")(function) + tool_f = tool(name="tool_f")(function) + + registry = ToolRegistry() + + all_tools = [tool_a, (tool_b, tool_c), [{tool_d, tool_e}, [tool_f]]] + + tru_tool_names = sorted(registry.process_tools(all_tools)) + exp_tool_names = [ + "tool_a", + "tool_b", + "tool_c", + "tool_d", + "tool_e", + "tool_f", + ] + assert tru_tool_names == exp_tool_names