Skip to content

Commit 2c6acd9

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat: Support adding prefix to tool names returned by toolset
This is to address the name conflict issue of tools returned by different toolset. Mainly it's to give each toolset a namespace. We have a flag `add_tool_name_prefix` to decide whether to apply this behavior We have a `tool_name_prefix` to let client specify a custom prefix, if not set , toolset name will be used as prefix. PiperOrigin-RevId: 791983427
1 parent e0a8355 commit 2c6acd9

File tree

3 files changed

+346
-4
lines changed

3 files changed

+346
-4
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ async def _convert_tool_union_to_tools(
113113
) -> list[BaseTool]:
114114
if isinstance(tool_union, BaseTool):
115115
return [tool_union]
116-
if isinstance(tool_union, Callable):
116+
if callable(tool_union):
117117
return [FunctionTool(func=tool_union)]
118118

119-
return await tool_union.get_tools(ctx)
119+
# At this point, tool_union must be a BaseToolset
120+
return await tool_union.get_prefixed_tools(ctx)
120121

121122

122123
class LlmAgent(BaseAgent):

src/google/adk/tools/base_toolset.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,37 @@ class BaseToolset(ABC):
5858
"""
5959

6060
def __init__(
61-
self, *, tool_filter: Optional[Union[ToolPredicate, List[str]]] = None
61+
self,
62+
*,
63+
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
64+
add_tool_name_prefix: bool = False,
65+
tool_name_prefix: Optional[str] = None,
6266
):
67+
"""Initialize the toolset.
68+
69+
Args:
70+
tool_filter: Filter to apply to tools.
71+
add_tool_name_prefix: Whether to add prefix to tool names. Defaults to False.
72+
tool_name_prefix: Custom prefix for tool names. If not provided and
73+
add_tool_name_prefix is True, uses the toolset class name (lowercased, without
74+
'toolset' suffix) as the default prefix.
75+
"""
6376
self.tool_filter = tool_filter
77+
self.add_tool_name_prefix = add_tool_name_prefix
78+
self._tool_name_prefix = tool_name_prefix
79+
80+
@property
81+
def tool_name_prefix(self) -> str:
82+
"""Get the prefix for tool names.
83+
84+
Returns:
85+
The custom prefix if provided, otherwise the toolset class name
86+
(lowercased).
87+
"""
88+
if self._tool_name_prefix is not None:
89+
return self._tool_name_prefix
90+
91+
return self.__class__.__name__.lower()
6492

6593
@abstractmethod
6694
async def get_tools(
@@ -77,6 +105,52 @@ async def get_tools(
77105
list[BaseTool]: A list of tools available under the specified context.
78106
"""
79107

108+
async def get_prefixed_tools(
109+
self,
110+
readonly_context: Optional[ReadonlyContext] = None,
111+
) -> list[BaseTool]:
112+
"""Return all tools with optional prefix applied to tool names.
113+
114+
This method calls get_tools() and applies prefixing if add_tool_name_prefix is True.
115+
116+
Args:
117+
readonly_context (ReadonlyContext, optional): Context used to filter tools
118+
available to the agent. If None, all tools in the toolset are returned.
119+
120+
Returns:
121+
list[BaseTool]: A list of tools with prefixed names if add_tool_name_prefix is True.
122+
"""
123+
tools = await self.get_tools(readonly_context)
124+
125+
if not self.add_tool_name_prefix:
126+
return tools
127+
128+
prefix = self.tool_name_prefix
129+
130+
for tool in tools:
131+
132+
prefixed_name = f"{prefix}_{tool.name}"
133+
tool.name = prefixed_name
134+
135+
# Also update the function declaration name if the tool has one
136+
# Use default parameters to capture the current values in the closure
137+
def _create_prefixed_declaration(
138+
original_get_declaration=tool._get_declaration,
139+
prefixed_name=prefixed_name,
140+
):
141+
def _get_prefixed_declaration():
142+
declaration = original_get_declaration()
143+
if declaration is not None:
144+
declaration.name = prefixed_name
145+
return declaration
146+
return None
147+
148+
return _get_prefixed_declaration
149+
150+
tool._get_declaration = _create_prefixed_declaration()
151+
152+
return tools
153+
80154
@abstractmethod
81155
async def close(self) -> None:
82156
"""Performs cleanup and releases resources held by the toolset.

tests/unittests/tools/test_base_toolset.py

Lines changed: 268 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,29 @@
2323
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2424
from google.adk.tools.base_tool import BaseTool
2525
from google.adk.tools.base_toolset import BaseToolset
26+
from google.adk.tools.function_tool import FunctionTool
2627
from google.adk.tools.tool_context import ToolContext
2728
import pytest
2829

2930

31+
class _TestingTool(BaseTool):
32+
"""A test implementation of BaseTool."""
33+
34+
async def run_async(self, *, args, tool_context):
35+
return 'test result'
36+
37+
3038
class _TestingToolset(BaseToolset):
3139
"""A test implementation of BaseToolset."""
3240

41+
def __init__(self, *args, tools: Optional[list[BaseTool]] = None, **kwargs):
42+
super().__init__(*args, **kwargs)
43+
self._tools = tools or []
44+
3345
async def get_tools(
3446
self, readonly_context: Optional[ReadonlyContext] = None
3547
) -> list[BaseTool]:
36-
return []
48+
return self._tools
3749

3850
async def close(self) -> None:
3951
pass
@@ -107,3 +119,258 @@ async def process_llm_request(
107119

108120
# Verify the custom processing was applied
109121
assert llm_request.contents == ['Custom processing applied']
122+
123+
124+
@pytest.mark.asyncio
125+
async def test_prefix_functionality_disabled_by_default():
126+
"""Test that prefix functionality is disabled by default."""
127+
tool1 = _TestingTool(name='tool1', description='Test tool 1')
128+
tool2 = _TestingTool(name='tool2', description='Test tool 2')
129+
toolset = _TestingToolset(tools=[tool1, tool2])
130+
131+
# When add_name_prefix is False (default), get_prefixed_tools should return original tools
132+
prefixed_tools = await toolset.get_prefixed_tools()
133+
134+
assert len(prefixed_tools) == 2
135+
assert prefixed_tools[0].name == 'tool1'
136+
assert prefixed_tools[1].name == 'tool2'
137+
assert not toolset.add_tool_name_prefix
138+
139+
140+
@pytest.mark.asyncio
141+
async def test_prefix_functionality_with_default_prefix():
142+
"""Test prefix functionality with default toolset name prefix."""
143+
tool1 = _TestingTool(name='tool1', description='Test tool 1')
144+
tool2 = _TestingTool(name='tool2', description='Test tool 2')
145+
toolset = _TestingToolset(tools=[tool1, tool2], add_tool_name_prefix=True)
146+
147+
# Should use '_testingtoolset' as default prefix (full class name lowercased)
148+
prefixed_tools = await toolset.get_prefixed_tools()
149+
150+
assert len(prefixed_tools) == 2
151+
assert prefixed_tools[0].name == '_testingtoolset_tool1'
152+
assert prefixed_tools[1].name == '_testingtoolset_tool2'
153+
assert toolset.tool_name_prefix == '_testingtoolset'
154+
assert toolset.add_tool_name_prefix
155+
156+
157+
@pytest.mark.asyncio
158+
async def test_prefix_functionality_with_custom_prefix():
159+
"""Test prefix functionality with custom prefix."""
160+
tool1 = _TestingTool(name='tool1', description='Test tool 1')
161+
tool2 = _TestingTool(name='tool2', description='Test tool 2')
162+
toolset = _TestingToolset(
163+
tools=[tool1, tool2], add_tool_name_prefix=True, tool_name_prefix='custom'
164+
)
165+
166+
prefixed_tools = await toolset.get_prefixed_tools()
167+
168+
assert len(prefixed_tools) == 2
169+
assert prefixed_tools[0].name == 'custom_tool1'
170+
assert prefixed_tools[1].name == 'custom_tool2'
171+
assert toolset.tool_name_prefix == 'custom'
172+
173+
174+
@pytest.mark.asyncio
175+
async def test_prefix_property_with_different_toolset_names():
176+
"""Test prefix property with different toolset class names."""
177+
178+
class BigQueryToolset(_TestingToolset):
179+
pass
180+
181+
class MyCustomClass(_TestingToolset):
182+
pass
183+
184+
# Test with 'toolset' suffix
185+
bq_toolset = BigQueryToolset(add_tool_name_prefix=True)
186+
assert bq_toolset.tool_name_prefix == 'bigquerytoolset'
187+
188+
# Test without 'toolset' suffix
189+
custom_toolset = MyCustomClass(add_tool_name_prefix=True)
190+
assert custom_toolset.tool_name_prefix == 'mycustomclass'
191+
192+
193+
@pytest.mark.asyncio
194+
async def test_prefix_property_with_explicit_prefix():
195+
"""Test prefix property when explicit prefix is provided."""
196+
toolset = _TestingToolset(
197+
add_tool_name_prefix=True, tool_name_prefix='explicit'
198+
)
199+
assert toolset.tool_name_prefix == 'explicit'
200+
201+
202+
@pytest.mark.asyncio
203+
async def test_prefix_modifies_tools_in_place():
204+
"""Test that prefixing modifies tool names in place."""
205+
original_tool = _TestingTool(
206+
name='original', description='Original description'
207+
)
208+
original_tool.is_long_running = True
209+
original_tool.custom_attribute = 'custom_value'
210+
211+
toolset = _TestingToolset(
212+
tools=[original_tool], add_tool_name_prefix=True, tool_name_prefix='test'
213+
)
214+
prefixed_tools = await toolset.get_prefixed_tools()
215+
216+
prefixed_tool = prefixed_tools[0]
217+
218+
# Name should be prefixed
219+
assert prefixed_tool.name == 'test_original'
220+
221+
# Other attributes should be preserved
222+
assert prefixed_tool.description == 'Original description'
223+
assert prefixed_tool.is_long_running == True
224+
assert prefixed_tool.custom_attribute == 'custom_value'
225+
226+
# Since we modify in place, original tool should now have prefixed name
227+
assert original_tool.name == 'test_original'
228+
assert original_tool is prefixed_tool
229+
230+
231+
@pytest.mark.asyncio
232+
async def test_get_tools_vs_get_prefixed_tools():
233+
"""Test that get_tools returns tools without prefixing."""
234+
tool1 = _TestingTool(name='test_tool1', description='Test tool 1')
235+
tool2 = _TestingTool(name='test_tool2', description='Test tool 2')
236+
toolset = _TestingToolset(
237+
tools=[tool1, tool2], add_tool_name_prefix=True, tool_name_prefix='prefix'
238+
)
239+
240+
# get_tools should return original tools (unmodified)
241+
original_tools = await toolset.get_tools()
242+
assert len(original_tools) == 2
243+
assert original_tools[0].name == 'test_tool1'
244+
assert original_tools[1].name == 'test_tool2'
245+
246+
# Now calling get_prefixed_tools should modify the tool names in place
247+
prefixed_tools = await toolset.get_prefixed_tools()
248+
assert len(prefixed_tools) == 2
249+
assert prefixed_tools[0].name == 'prefix_test_tool1'
250+
assert prefixed_tools[1].name == 'prefix_test_tool2'
251+
252+
# Since we modify in place, the original tools now have prefixed names
253+
assert original_tools[0].name == 'prefix_test_tool1'
254+
assert original_tools[1].name == 'prefix_test_tool2'
255+
256+
257+
@pytest.mark.asyncio
258+
async def test_empty_toolset_with_prefix():
259+
"""Test prefix functionality with empty toolset."""
260+
toolset = _TestingToolset(
261+
tools=[], add_tool_name_prefix=True, tool_name_prefix='test'
262+
)
263+
264+
prefixed_tools = await toolset.get_prefixed_tools()
265+
assert len(prefixed_tools) == 0
266+
267+
268+
@pytest.mark.asyncio
269+
async def test_function_declarations_are_prefixed():
270+
"""Test that function declarations have prefixed names."""
271+
272+
def test_function(param1: str, param2: int) -> str:
273+
"""A test function for checking prefixes."""
274+
return f'{param1}_{param2}'
275+
276+
function_tool = FunctionTool(test_function)
277+
toolset = _TestingToolset(
278+
tools=[function_tool],
279+
add_tool_name_prefix=True,
280+
tool_name_prefix='prefix',
281+
)
282+
283+
prefixed_tools = await toolset.get_prefixed_tools()
284+
prefixed_tool = prefixed_tools[0]
285+
286+
# Tool name should be prefixed
287+
assert prefixed_tool.name == 'prefix_test_function'
288+
289+
# Function declaration should also have prefixed name
290+
declaration = prefixed_tool._get_declaration()
291+
assert declaration is not None
292+
assert declaration.name == 'prefix_test_function'
293+
294+
# Description should remain unchanged
295+
assert 'A test function for checking prefixes.' in declaration.description
296+
297+
298+
@pytest.mark.asyncio
299+
async def test_prefixed_tools_in_llm_request():
300+
"""Test that prefixed tools are properly added to LLM request."""
301+
302+
def test_function(param: str) -> str:
303+
"""A test function."""
304+
return f'result: {param}'
305+
306+
function_tool = FunctionTool(test_function)
307+
toolset = _TestingToolset(
308+
tools=[function_tool], add_tool_name_prefix=True, tool_name_prefix='test'
309+
)
310+
311+
prefixed_tools = await toolset.get_prefixed_tools()
312+
prefixed_tool = prefixed_tools[0]
313+
314+
# Create LLM request and tool context
315+
session_service = InMemorySessionService()
316+
session = await session_service.create_session(
317+
app_name='test_app', user_id='test_user'
318+
)
319+
agent = SequentialAgent(name='test_agent')
320+
invocation_context = InvocationContext(
321+
invocation_id='test_id',
322+
agent=agent,
323+
session=session,
324+
session_service=session_service,
325+
)
326+
tool_context = ToolContext(invocation_context)
327+
llm_request = LlmRequest()
328+
329+
# Process the LLM request with the prefixed tool
330+
await prefixed_tool.process_llm_request(
331+
tool_context=tool_context, llm_request=llm_request
332+
)
333+
334+
# Verify the tool is registered with prefixed name in tools_dict
335+
assert 'test_test_function' in llm_request.tools_dict
336+
assert llm_request.tools_dict['test_test_function'] == prefixed_tool
337+
338+
# Verify the function declaration has prefixed name
339+
assert llm_request.config is not None
340+
assert llm_request.config.tools is not None
341+
assert len(llm_request.config.tools) == 1
342+
tool_config = llm_request.config.tools[0]
343+
assert len(tool_config.function_declarations) == 1
344+
func_decl = tool_config.function_declarations[0]
345+
assert func_decl.name == 'test_test_function'
346+
347+
348+
@pytest.mark.asyncio
349+
async def test_multiple_tools_have_correct_declarations():
350+
"""Test that each tool maintains its own function declaration after prefixing."""
351+
352+
def tool_one(param: str) -> str:
353+
"""Function one."""
354+
return f'one: {param}'
355+
356+
def tool_two(param: int) -> str:
357+
"""Function two."""
358+
return f'two: {param}'
359+
360+
tool1 = FunctionTool(tool_one)
361+
tool2 = FunctionTool(tool_two)
362+
toolset = _TestingToolset(
363+
tools=[tool1, tool2], add_tool_name_prefix=True, tool_name_prefix='test'
364+
)
365+
366+
prefixed_tools = await toolset.get_prefixed_tools()
367+
368+
# Verify each tool has its own correct declaration
369+
decl1 = prefixed_tools[0]._get_declaration()
370+
decl2 = prefixed_tools[1]._get_declaration()
371+
372+
assert decl1.name == 'test_tool_one'
373+
assert decl2.name == 'test_tool_two'
374+
375+
assert 'Function one.' in decl1.description
376+
assert 'Function two.' in decl2.description

0 commit comments

Comments
 (0)