Skip to content

Commit cd51a4e

Browse files
Ratish1ratishzastrowm
authored andcommitted
fix(tools/loader): load and register all decorated @tool functions from file path (strands-agents#742)
- Collect all DecoratedFunctionTool objects when loading a .py file and return list when multiple exist - Normalize loader results and register each AgentTool separately in registry - Add normalize_loaded_tools helper and test for multiple decorated tools --------- Co-authored-by: ratish <[email protected]> Co-authored-by: Mackenzie Zastrow <[email protected]>
1 parent b648c62 commit cd51a4e

File tree

4 files changed

+160
-58
lines changed

4 files changed

+160
-58
lines changed

src/strands/tools/loader.py

Lines changed: 73 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import logging
55
import os
66
import sys
7+
import warnings
78
from pathlib import Path
8-
from typing import cast
9+
from typing import List, cast
910

1011
from ..types.tools import AgentTool
1112
from .decorator import DecoratedFunctionTool
@@ -18,60 +19,42 @@ class ToolLoader:
1819
"""Handles loading of tools from different sources."""
1920

2021
@staticmethod
21-
def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
22-
"""Load a Python tool module.
23-
24-
Args:
25-
tool_path: Path to the Python tool file.
26-
tool_name: Name of the tool.
22+
def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]:
23+
"""Load a Python tool module and return all discovered function-based tools as a list.
2724
28-
Returns:
29-
Tool instance.
30-
31-
Raises:
32-
AttributeError: If required attributes are missing from the tool module.
33-
ImportError: If there are issues importing the tool module.
34-
TypeError: If the tool function is not callable.
35-
ValueError: If function in module is not a valid tool.
36-
Exception: For other errors during tool loading.
25+
This method always returns a list of AgentTool (possibly length 1). It is the
26+
canonical API for retrieving multiple tools from a single Python file.
3727
"""
3828
try:
39-
# Check if tool_path is in the format "package.module:function"; but keep in mind windows whose file path
40-
# could have a colon so also ensure that it's not a file
29+
# Support module:function style (e.g. package.module:function)
4130
if not os.path.exists(tool_path) and ":" in tool_path:
4231
module_path, function_name = tool_path.rsplit(":", 1)
4332
logger.debug("tool_name=<%s>, module_path=<%s> | importing tool from path", function_name, module_path)
4433

4534
try:
46-
# Import the module
4735
module = __import__(module_path, fromlist=["*"])
48-
49-
# Get the function
50-
if not hasattr(module, function_name):
51-
raise AttributeError(f"Module {module_path} has no function named {function_name}")
52-
53-
func = getattr(module, function_name)
54-
55-
if isinstance(func, DecoratedFunctionTool):
56-
logger.debug(
57-
"tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path
58-
)
59-
# mypy has problems converting between DecoratedFunctionTool <-> AgentTool
60-
return cast(AgentTool, func)
61-
else:
62-
raise ValueError(
63-
f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)"
64-
)
65-
6636
except ImportError as e:
6737
raise ImportError(f"Failed to import module {module_path}: {str(e)}") from e
6838

39+
if not hasattr(module, function_name):
40+
raise AttributeError(f"Module {module_path} has no function named {function_name}")
41+
42+
func = getattr(module, function_name)
43+
if isinstance(func, DecoratedFunctionTool):
44+
logger.debug(
45+
"tool_name=<%s>, module_path=<%s> | found function-based tool", function_name, module_path
46+
)
47+
return [cast(AgentTool, func)]
48+
else:
49+
raise ValueError(
50+
f"Function {function_name} in {module_path} is not a valid tool (missing @tool decorator)"
51+
)
52+
6953
# Normal file-based tool loading
7054
abs_path = str(Path(tool_path).resolve())
71-
7255
logger.debug("tool_path=<%s> | loading python tool from path", abs_path)
7356

74-
# First load the module to get TOOL_SPEC and check for Lambda deployment
57+
# Load the module by spec
7558
spec = importlib.util.spec_from_file_location(tool_name, abs_path)
7659
if not spec:
7760
raise ImportError(f"Could not create spec for {tool_name}")
@@ -82,24 +65,26 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
8265
sys.modules[tool_name] = module
8366
spec.loader.exec_module(module)
8467

85-
# First, check for function-based tools with @tool decorator
68+
# Collect function-based tools decorated with @tool
69+
function_tools: List[AgentTool] = []
8670
for attr_name in dir(module):
8771
attr = getattr(module, attr_name)
8872
if isinstance(attr, DecoratedFunctionTool):
8973
logger.debug(
9074
"tool_name=<%s>, tool_path=<%s> | found function-based tool in path", attr_name, tool_path
9175
)
92-
# mypy has problems converting between DecoratedFunctionTool <-> AgentTool
93-
return cast(AgentTool, attr)
76+
function_tools.append(cast(AgentTool, attr))
77+
78+
if function_tools:
79+
return function_tools
9480

95-
# If no function-based tools found, fall back to traditional module-level tool
81+
# Fall back to module-level TOOL_SPEC + function
9682
tool_spec = getattr(module, "TOOL_SPEC", None)
9783
if not tool_spec:
9884
raise AttributeError(
9985
f"Tool {tool_name} missing TOOL_SPEC (neither at module level nor as a decorated function)"
10086
)
10187

102-
# Standard local tool loading
10388
tool_func_name = tool_name
10489
if not hasattr(module, tool_func_name):
10590
raise AttributeError(f"Tool {tool_name} missing function {tool_func_name}")
@@ -108,22 +93,61 @@ def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
10893
if not callable(tool_func):
10994
raise TypeError(f"Tool {tool_name} function is not callable")
11095

111-
return PythonAgentTool(tool_name, tool_spec, tool_func)
96+
return [PythonAgentTool(tool_name, tool_spec, tool_func)]
11297

11398
except Exception:
114-
logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool", tool_name, sys.path)
99+
logger.exception("tool_name=<%s>, sys_path=<%s> | failed to load python tool(s)", tool_name, sys.path)
115100
raise
116101

102+
@staticmethod
103+
def load_python_tool(tool_path: str, tool_name: str) -> AgentTool:
104+
"""DEPRECATED: Load a Python tool module and return a single AgentTool for backwards compatibility.
105+
106+
Use `load_python_tools` to retrieve all tools defined in a .py file (returns a list).
107+
This function will emit a `DeprecationWarning` and return the first discovered tool.
108+
"""
109+
warnings.warn(
110+
"ToolLoader.load_python_tool is deprecated and will be removed in Strands SDK 2.0. "
111+
"Use ToolLoader.load_python_tools(...) which always returns a list of AgentTool.",
112+
DeprecationWarning,
113+
stacklevel=2,
114+
)
115+
116+
tools = ToolLoader.load_python_tools(tool_path, tool_name)
117+
if not tools:
118+
raise RuntimeError(f"No tools found in {tool_path} for {tool_name}")
119+
return tools[0]
120+
117121
@classmethod
118122
def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:
119-
"""Load a tool based on its file extension.
123+
"""DEPRECATED: Load a single tool based on its file extension for backwards compatibility.
124+
125+
Use `load_tools` to retrieve all tools defined in a file (returns a list).
126+
This function will emit a `DeprecationWarning` and return the first discovered tool.
127+
"""
128+
warnings.warn(
129+
"ToolLoader.load_tool is deprecated and will be removed in Strands SDK 2.0. "
130+
"Use ToolLoader.load_tools(...) which always returns a list of AgentTool.",
131+
DeprecationWarning,
132+
stacklevel=2,
133+
)
134+
135+
tools = ToolLoader.load_tools(tool_path, tool_name)
136+
if not tools:
137+
raise RuntimeError(f"No tools found in {tool_path} for {tool_name}")
138+
139+
return tools[0]
140+
141+
@classmethod
142+
def load_tools(cls, tool_path: str, tool_name: str) -> list[AgentTool]:
143+
"""Load tools from a file based on its file extension.
120144
121145
Args:
122146
tool_path: Path to the tool file.
123147
tool_name: Name of the tool.
124148
125149
Returns:
126-
Tool instance.
150+
A single Tool instance.
127151
128152
Raises:
129153
FileNotFoundError: If the tool file does not exist.
@@ -138,7 +162,7 @@ def load_tool(cls, tool_path: str, tool_name: str) -> AgentTool:
138162

139163
try:
140164
if ext == ".py":
141-
return cls.load_python_tool(abs_path, tool_name)
165+
return cls.load_python_tools(abs_path, tool_name)
142166
else:
143167
raise ValueError(f"Unsupported tool file type: {ext}")
144168
except Exception:

src/strands/tools/mcp/mcp_client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,19 +318,22 @@ def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolRes
318318
"""
319319
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))
320320

321-
mapped_content = [
322-
mapped_content
321+
# Build a typed list of ToolResultContent. Use a clearer local name to avoid shadowing
322+
# and annotate the result for mypy so it knows the intended element type.
323+
mapped_contents: list[ToolResultContent] = [
324+
mc
323325
for content in call_tool_result.content
324-
if (mapped_content := self._map_mcp_content_to_tool_result_content(content)) is not None
326+
if (mc := self._map_mcp_content_to_tool_result_content(content)) is not None
325327
]
326328

327329
status: ToolResultStatus = "error" if call_tool_result.isError else "success"
328330
self._log_debug_with_thread("tool execution completed with status: %s", status)
329331
result = MCPToolResult(
330332
status=status,
331333
toolUseId=tool_use_id,
332-
content=mapped_content,
334+
content=mapped_contents,
333335
)
336+
334337
if call_tool_result.structuredContent:
335338
result["structuredContent"] = call_tool_result.structuredContent
336339

src/strands/tools/registry.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,11 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None:
127127
if not os.path.exists(tool_path):
128128
raise FileNotFoundError(f"Tool file not found: {tool_path}")
129129

130-
loaded_tool = ToolLoader.load_tool(tool_path, tool_name)
131-
loaded_tool.mark_dynamic()
132-
133-
# Because we're explicitly registering the tool we don't need an allowlist
134-
self.register_tool(loaded_tool)
130+
loaded_tools = ToolLoader.load_tools(tool_path, tool_name)
131+
for t in loaded_tools:
132+
t.mark_dynamic()
133+
# Because we're explicitly registering the tool we don't need an allowlist
134+
self.register_tool(t)
135135
except Exception as e:
136136
exception_str = str(e)
137137
logger.exception("tool_name=<%s> | failed to load tool", tool_name)

tests/strands/tools/test_loader.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,78 @@ def no_spec():
235235
def test_load_tool_no_spec(tool_path):
236236
with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"):
237237
ToolLoader.load_tool(tool_path, "no_spec")
238+
239+
with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"):
240+
ToolLoader.load_tools(tool_path, "no_spec")
241+
242+
with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"):
243+
ToolLoader.load_python_tool(tool_path, "no_spec")
244+
245+
with pytest.raises(AttributeError, match="Tool no_spec missing TOOL_SPEC"):
246+
ToolLoader.load_python_tools(tool_path, "no_spec")
247+
248+
249+
@pytest.mark.parametrize(
250+
"tool_path",
251+
[
252+
textwrap.dedent(
253+
"""
254+
import strands
255+
256+
@strands.tools.tool
257+
def alpha():
258+
return "alpha"
259+
260+
@strands.tools.tool
261+
def bravo():
262+
return "bravo"
263+
"""
264+
)
265+
],
266+
indirect=True,
267+
)
268+
def test_load_python_tool_path_multiple_function_based(tool_path):
269+
# load_python_tools, load_tools returns a list when multiple decorated tools are present
270+
loaded_python_tools = ToolLoader.load_python_tools(tool_path, "alpha")
271+
272+
assert isinstance(loaded_python_tools, list)
273+
assert len(loaded_python_tools) == 2
274+
assert all(isinstance(t, DecoratedFunctionTool) for t in loaded_python_tools)
275+
names = {t.tool_name for t in loaded_python_tools}
276+
assert names == {"alpha", "bravo"}
277+
278+
loaded_tools = ToolLoader.load_tools(tool_path, "alpha")
279+
280+
assert isinstance(loaded_tools, list)
281+
assert len(loaded_tools) == 2
282+
assert all(isinstance(t, DecoratedFunctionTool) for t in loaded_tools)
283+
names = {t.tool_name for t in loaded_tools}
284+
assert names == {"alpha", "bravo"}
285+
286+
287+
@pytest.mark.parametrize(
288+
"tool_path",
289+
[
290+
textwrap.dedent(
291+
"""
292+
import strands
293+
294+
@strands.tools.tool
295+
def alpha():
296+
return "alpha"
297+
298+
@strands.tools.tool
299+
def bravo():
300+
return "bravo"
301+
"""
302+
)
303+
],
304+
indirect=True,
305+
)
306+
def test_load_tool_path_returns_single_tool(tool_path):
307+
# loaded_python_tool and loaded_tool returns single item
308+
loaded_python_tool = ToolLoader.load_python_tool(tool_path, "alpha")
309+
loaded_tool = ToolLoader.load_tool(tool_path, "alpha")
310+
311+
assert loaded_python_tool.tool_name == "alpha"
312+
assert loaded_tool.tool_name == "alpha"

0 commit comments

Comments
 (0)