diff --git a/src/google/adk/flows/llm_flows/agent_transfer.py b/src/google/adk/flows/llm_flows/agent_transfer.py index 86128706f..3a8f7e757 100644 --- a/src/google/adk/flows/llm_flows/agent_transfer.py +++ b/src/google/adk/flows/llm_flows/agent_transfer.py @@ -24,9 +24,8 @@ from ...agents.invocation_context import InvocationContext from ...events.event import Event from ...models.llm_request import LlmRequest -from ...tools.function_tool import FunctionTool from ...tools.tool_context import ToolContext -from ...tools.transfer_to_agent_tool import transfer_to_agent +from ...tools.transfer_to_agent_tool import TransferToAgentTool from ._base_llm_processor import BaseLlmRequestProcessor if typing.TYPE_CHECKING: @@ -50,15 +49,18 @@ async def run_async( if not transfer_targets: return + transfer_to_agent = TransferToAgentTool( + agent_names=[agent.name for agent in transfer_targets] + ) + llm_request.append_instructions([ _build_target_agents_instructions( - invocation_context.agent, transfer_targets + transfer_to_agent.name, invocation_context.agent, transfer_targets ) ]) - transfer_to_agent_tool = FunctionTool(func=transfer_to_agent) tool_context = ToolContext(invocation_context) - await transfer_to_agent_tool.process_llm_request( + await transfer_to_agent.process_llm_request( tool_context=tool_context, llm_request=llm_request ) @@ -80,7 +82,7 @@ def _build_target_agents_info(target_agent: BaseAgent) -> str: def _build_target_agents_instructions( - agent: LlmAgent, target_agents: list[BaseAgent] + tool_name: str, agent: LlmAgent, target_agents: list[BaseAgent] ) -> str: si = f""" You have a list of other agents to transfer to: @@ -93,7 +95,7 @@ def _build_target_agents_instructions( can answer it. If another agent is better for answering the question according to its -description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the +description, call `{tool_name}` function to transfer the question to that agent. When transferring, do not generate any text other than the function call. """ @@ -107,9 +109,6 @@ def _build_target_agents_instructions( return si -_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__ - - def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]: from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index bb26d4941..aa526cbac 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -29,7 +29,7 @@ from .long_running_tool import LongRunningFunctionTool from .preload_memory_tool import preload_memory_tool as preload_memory from .tool_context import ToolContext -from .transfer_to_agent_tool import transfer_to_agent +from .transfer_to_agent_tool import TransferToAgentTool from .url_context_tool import url_context from .vertex_ai_search_tool import VertexAiSearchTool @@ -51,7 +51,7 @@ 'LongRunningFunctionTool', 'preload_memory', 'ToolContext', - 'transfer_to_agent', + 'TransferToAgentTool', ] diff --git a/src/google/adk/tools/transfer_to_agent_tool.py b/src/google/adk/tools/transfer_to_agent_tool.py index 99ee234b3..5c98bfd59 100644 --- a/src/google/adk/tools/transfer_to_agent_tool.py +++ b/src/google/adk/tools/transfer_to_agent_tool.py @@ -14,6 +14,12 @@ from __future__ import annotations +from typing import Optional + +from google.genai import types +from typing_extensions import override + +from .function_tool import FunctionTool from .tool_context import ToolContext @@ -27,3 +33,24 @@ def transfer_to_agent(agent_name: str, tool_context: ToolContext) -> None: agent_name: the agent name to transfer to. """ tool_context.actions.transfer_to_agent = agent_name + + +class TransferToAgentTool(FunctionTool): + """A specialized FunctionTool for agent transfer.""" + + def __init__( + self, + agent_names: list[str], + ): + super().__init__(func=transfer_to_agent) + self._agent_names = agent_names + + @override + def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + """Add enum constraint to the agent_name.""" + function_decl = super()._get_declaration() + if function_decl and function_decl.parameters: + agent_name_schema = function_decl.parameters.properties.get("agent_name") + if agent_name_schema: + agent_name_schema.enum = self._agent_names + return function_decl