|
23 | 23 | from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
24 | 24 | from google.adk.tools.base_tool import BaseTool
|
25 | 25 | from google.adk.tools.base_toolset import BaseToolset
|
| 26 | +from google.adk.tools.function_tool import FunctionTool |
26 | 27 | from google.adk.tools.tool_context import ToolContext
|
27 | 28 | import pytest
|
28 | 29 |
|
29 | 30 |
|
| 31 | +class _TestingTool(BaseTool): |
| 32 | + """A test implementation of BaseTool.""" |
| 33 | + |
| 34 | + async def run_async(self, *, args, tool_context): |
| 35 | + return 'test result' |
| 36 | + |
| 37 | + |
30 | 38 | class _TestingToolset(BaseToolset):
|
31 | 39 | """A test implementation of BaseToolset."""
|
32 | 40 |
|
| 41 | + def __init__(self, *args, tools: Optional[list[BaseTool]] = None, **kwargs): |
| 42 | + super().__init__(*args, **kwargs) |
| 43 | + self._tools = tools or [] |
| 44 | + |
33 | 45 | async def get_tools(
|
34 | 46 | self, readonly_context: Optional[ReadonlyContext] = None
|
35 | 47 | ) -> list[BaseTool]:
|
36 |
| - return [] |
| 48 | + return self._tools |
37 | 49 |
|
38 | 50 | async def close(self) -> None:
|
39 | 51 | pass
|
@@ -107,3 +119,258 @@ async def process_llm_request(
|
107 | 119 |
|
108 | 120 | # Verify the custom processing was applied
|
109 | 121 | assert llm_request.contents == ['Custom processing applied']
|
| 122 | + |
| 123 | + |
| 124 | +@pytest.mark.asyncio |
| 125 | +async def test_prefix_functionality_disabled_by_default(): |
| 126 | + """Test that prefix functionality is disabled by default.""" |
| 127 | + tool1 = _TestingTool(name='tool1', description='Test tool 1') |
| 128 | + tool2 = _TestingTool(name='tool2', description='Test tool 2') |
| 129 | + toolset = _TestingToolset(tools=[tool1, tool2]) |
| 130 | + |
| 131 | + # When add_name_prefix is False (default), get_prefixed_tools should return original tools |
| 132 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 133 | + |
| 134 | + assert len(prefixed_tools) == 2 |
| 135 | + assert prefixed_tools[0].name == 'tool1' |
| 136 | + assert prefixed_tools[1].name == 'tool2' |
| 137 | + assert not toolset.add_tool_name_prefix |
| 138 | + |
| 139 | + |
| 140 | +@pytest.mark.asyncio |
| 141 | +async def test_prefix_functionality_with_default_prefix(): |
| 142 | + """Test prefix functionality with default toolset name prefix.""" |
| 143 | + tool1 = _TestingTool(name='tool1', description='Test tool 1') |
| 144 | + tool2 = _TestingTool(name='tool2', description='Test tool 2') |
| 145 | + toolset = _TestingToolset(tools=[tool1, tool2], add_tool_name_prefix=True) |
| 146 | + |
| 147 | + # Should use '_testingtoolset' as default prefix (full class name lowercased) |
| 148 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 149 | + |
| 150 | + assert len(prefixed_tools) == 2 |
| 151 | + assert prefixed_tools[0].name == '_testingtoolset_tool1' |
| 152 | + assert prefixed_tools[1].name == '_testingtoolset_tool2' |
| 153 | + assert toolset.tool_name_prefix == '_testingtoolset' |
| 154 | + assert toolset.add_tool_name_prefix |
| 155 | + |
| 156 | + |
| 157 | +@pytest.mark.asyncio |
| 158 | +async def test_prefix_functionality_with_custom_prefix(): |
| 159 | + """Test prefix functionality with custom prefix.""" |
| 160 | + tool1 = _TestingTool(name='tool1', description='Test tool 1') |
| 161 | + tool2 = _TestingTool(name='tool2', description='Test tool 2') |
| 162 | + toolset = _TestingToolset( |
| 163 | + tools=[tool1, tool2], add_tool_name_prefix=True, tool_name_prefix='custom' |
| 164 | + ) |
| 165 | + |
| 166 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 167 | + |
| 168 | + assert len(prefixed_tools) == 2 |
| 169 | + assert prefixed_tools[0].name == 'custom_tool1' |
| 170 | + assert prefixed_tools[1].name == 'custom_tool2' |
| 171 | + assert toolset.tool_name_prefix == 'custom' |
| 172 | + |
| 173 | + |
| 174 | +@pytest.mark.asyncio |
| 175 | +async def test_prefix_property_with_different_toolset_names(): |
| 176 | + """Test prefix property with different toolset class names.""" |
| 177 | + |
| 178 | + class BigQueryToolset(_TestingToolset): |
| 179 | + pass |
| 180 | + |
| 181 | + class MyCustomClass(_TestingToolset): |
| 182 | + pass |
| 183 | + |
| 184 | + # Test with 'toolset' suffix |
| 185 | + bq_toolset = BigQueryToolset(add_tool_name_prefix=True) |
| 186 | + assert bq_toolset.tool_name_prefix == 'bigquerytoolset' |
| 187 | + |
| 188 | + # Test without 'toolset' suffix |
| 189 | + custom_toolset = MyCustomClass(add_tool_name_prefix=True) |
| 190 | + assert custom_toolset.tool_name_prefix == 'mycustomclass' |
| 191 | + |
| 192 | + |
| 193 | +@pytest.mark.asyncio |
| 194 | +async def test_prefix_property_with_explicit_prefix(): |
| 195 | + """Test prefix property when explicit prefix is provided.""" |
| 196 | + toolset = _TestingToolset( |
| 197 | + add_tool_name_prefix=True, tool_name_prefix='explicit' |
| 198 | + ) |
| 199 | + assert toolset.tool_name_prefix == 'explicit' |
| 200 | + |
| 201 | + |
| 202 | +@pytest.mark.asyncio |
| 203 | +async def test_prefix_modifies_tools_in_place(): |
| 204 | + """Test that prefixing modifies tool names in place.""" |
| 205 | + original_tool = _TestingTool( |
| 206 | + name='original', description='Original description' |
| 207 | + ) |
| 208 | + original_tool.is_long_running = True |
| 209 | + original_tool.custom_attribute = 'custom_value' |
| 210 | + |
| 211 | + toolset = _TestingToolset( |
| 212 | + tools=[original_tool], add_tool_name_prefix=True, tool_name_prefix='test' |
| 213 | + ) |
| 214 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 215 | + |
| 216 | + prefixed_tool = prefixed_tools[0] |
| 217 | + |
| 218 | + # Name should be prefixed |
| 219 | + assert prefixed_tool.name == 'test_original' |
| 220 | + |
| 221 | + # Other attributes should be preserved |
| 222 | + assert prefixed_tool.description == 'Original description' |
| 223 | + assert prefixed_tool.is_long_running == True |
| 224 | + assert prefixed_tool.custom_attribute == 'custom_value' |
| 225 | + |
| 226 | + # Since we modify in place, original tool should now have prefixed name |
| 227 | + assert original_tool.name == 'test_original' |
| 228 | + assert original_tool is prefixed_tool |
| 229 | + |
| 230 | + |
| 231 | +@pytest.mark.asyncio |
| 232 | +async def test_get_tools_vs_get_prefixed_tools(): |
| 233 | + """Test that get_tools returns tools without prefixing.""" |
| 234 | + tool1 = _TestingTool(name='test_tool1', description='Test tool 1') |
| 235 | + tool2 = _TestingTool(name='test_tool2', description='Test tool 2') |
| 236 | + toolset = _TestingToolset( |
| 237 | + tools=[tool1, tool2], add_tool_name_prefix=True, tool_name_prefix='prefix' |
| 238 | + ) |
| 239 | + |
| 240 | + # get_tools should return original tools (unmodified) |
| 241 | + original_tools = await toolset.get_tools() |
| 242 | + assert len(original_tools) == 2 |
| 243 | + assert original_tools[0].name == 'test_tool1' |
| 244 | + assert original_tools[1].name == 'test_tool2' |
| 245 | + |
| 246 | + # Now calling get_prefixed_tools should modify the tool names in place |
| 247 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 248 | + assert len(prefixed_tools) == 2 |
| 249 | + assert prefixed_tools[0].name == 'prefix_test_tool1' |
| 250 | + assert prefixed_tools[1].name == 'prefix_test_tool2' |
| 251 | + |
| 252 | + # Since we modify in place, the original tools now have prefixed names |
| 253 | + assert original_tools[0].name == 'prefix_test_tool1' |
| 254 | + assert original_tools[1].name == 'prefix_test_tool2' |
| 255 | + |
| 256 | + |
| 257 | +@pytest.mark.asyncio |
| 258 | +async def test_empty_toolset_with_prefix(): |
| 259 | + """Test prefix functionality with empty toolset.""" |
| 260 | + toolset = _TestingToolset( |
| 261 | + tools=[], add_tool_name_prefix=True, tool_name_prefix='test' |
| 262 | + ) |
| 263 | + |
| 264 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 265 | + assert len(prefixed_tools) == 0 |
| 266 | + |
| 267 | + |
| 268 | +@pytest.mark.asyncio |
| 269 | +async def test_function_declarations_are_prefixed(): |
| 270 | + """Test that function declarations have prefixed names.""" |
| 271 | + |
| 272 | + def test_function(param1: str, param2: int) -> str: |
| 273 | + """A test function for checking prefixes.""" |
| 274 | + return f'{param1}_{param2}' |
| 275 | + |
| 276 | + function_tool = FunctionTool(test_function) |
| 277 | + toolset = _TestingToolset( |
| 278 | + tools=[function_tool], |
| 279 | + add_tool_name_prefix=True, |
| 280 | + tool_name_prefix='prefix', |
| 281 | + ) |
| 282 | + |
| 283 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 284 | + prefixed_tool = prefixed_tools[0] |
| 285 | + |
| 286 | + # Tool name should be prefixed |
| 287 | + assert prefixed_tool.name == 'prefix_test_function' |
| 288 | + |
| 289 | + # Function declaration should also have prefixed name |
| 290 | + declaration = prefixed_tool._get_declaration() |
| 291 | + assert declaration is not None |
| 292 | + assert declaration.name == 'prefix_test_function' |
| 293 | + |
| 294 | + # Description should remain unchanged |
| 295 | + assert 'A test function for checking prefixes.' in declaration.description |
| 296 | + |
| 297 | + |
| 298 | +@pytest.mark.asyncio |
| 299 | +async def test_prefixed_tools_in_llm_request(): |
| 300 | + """Test that prefixed tools are properly added to LLM request.""" |
| 301 | + |
| 302 | + def test_function(param: str) -> str: |
| 303 | + """A test function.""" |
| 304 | + return f'result: {param}' |
| 305 | + |
| 306 | + function_tool = FunctionTool(test_function) |
| 307 | + toolset = _TestingToolset( |
| 308 | + tools=[function_tool], add_tool_name_prefix=True, tool_name_prefix='test' |
| 309 | + ) |
| 310 | + |
| 311 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 312 | + prefixed_tool = prefixed_tools[0] |
| 313 | + |
| 314 | + # Create LLM request and tool context |
| 315 | + session_service = InMemorySessionService() |
| 316 | + session = await session_service.create_session( |
| 317 | + app_name='test_app', user_id='test_user' |
| 318 | + ) |
| 319 | + agent = SequentialAgent(name='test_agent') |
| 320 | + invocation_context = InvocationContext( |
| 321 | + invocation_id='test_id', |
| 322 | + agent=agent, |
| 323 | + session=session, |
| 324 | + session_service=session_service, |
| 325 | + ) |
| 326 | + tool_context = ToolContext(invocation_context) |
| 327 | + llm_request = LlmRequest() |
| 328 | + |
| 329 | + # Process the LLM request with the prefixed tool |
| 330 | + await prefixed_tool.process_llm_request( |
| 331 | + tool_context=tool_context, llm_request=llm_request |
| 332 | + ) |
| 333 | + |
| 334 | + # Verify the tool is registered with prefixed name in tools_dict |
| 335 | + assert 'test_test_function' in llm_request.tools_dict |
| 336 | + assert llm_request.tools_dict['test_test_function'] == prefixed_tool |
| 337 | + |
| 338 | + # Verify the function declaration has prefixed name |
| 339 | + assert llm_request.config is not None |
| 340 | + assert llm_request.config.tools is not None |
| 341 | + assert len(llm_request.config.tools) == 1 |
| 342 | + tool_config = llm_request.config.tools[0] |
| 343 | + assert len(tool_config.function_declarations) == 1 |
| 344 | + func_decl = tool_config.function_declarations[0] |
| 345 | + assert func_decl.name == 'test_test_function' |
| 346 | + |
| 347 | + |
| 348 | +@pytest.mark.asyncio |
| 349 | +async def test_multiple_tools_have_correct_declarations(): |
| 350 | + """Test that each tool maintains its own function declaration after prefixing.""" |
| 351 | + |
| 352 | + def tool_one(param: str) -> str: |
| 353 | + """Function one.""" |
| 354 | + return f'one: {param}' |
| 355 | + |
| 356 | + def tool_two(param: int) -> str: |
| 357 | + """Function two.""" |
| 358 | + return f'two: {param}' |
| 359 | + |
| 360 | + tool1 = FunctionTool(tool_one) |
| 361 | + tool2 = FunctionTool(tool_two) |
| 362 | + toolset = _TestingToolset( |
| 363 | + tools=[tool1, tool2], add_tool_name_prefix=True, tool_name_prefix='test' |
| 364 | + ) |
| 365 | + |
| 366 | + prefixed_tools = await toolset.get_prefixed_tools() |
| 367 | + |
| 368 | + # Verify each tool has its own correct declaration |
| 369 | + decl1 = prefixed_tools[0]._get_declaration() |
| 370 | + decl2 = prefixed_tools[1]._get_declaration() |
| 371 | + |
| 372 | + assert decl1.name == 'test_tool_one' |
| 373 | + assert decl2.name == 'test_tool_two' |
| 374 | + |
| 375 | + assert 'Function one.' in decl1.description |
| 376 | + assert 'Function two.' in decl2.description |
0 commit comments