diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 9e5ac9944115e..31477d293741d 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -21,7 +21,7 @@ from langgraph.constants import END, START from langgraph.graph.state import StateGraph from langgraph.runtime import Runtime # noqa: TC002 -from langgraph.types import Send +from langgraph.types import Command, Send from langgraph.typing import ContextT # noqa: TC002 from typing_extensions import NotRequired, Required, TypedDict, TypeVar @@ -56,6 +56,8 @@ from langgraph.store.base import BaseStore from langgraph.types import Checkpointer + from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest + STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." ResponseT = TypeVar("ResponseT") @@ -410,6 +412,98 @@ def _handle_structured_output_error( return False, "" +def _chain_tool_call_handlers( + handlers: Sequence[ToolCallHandler], +) -> ToolCallHandler | None: + """Compose handlers into middleware stack (first = outermost). + + Args: + handlers: Handlers in middleware order. + + Returns: + Composed handler, or None if empty. + + Example: + handler = _chain_tool_call_handlers([auth, cache, retry]) + # Request flows: auth -> cache -> retry -> tool + # Response flows: tool -> retry -> cache -> auth + """ + if not handlers: + return None + + if len(handlers) == 1: + return handlers[0] + + def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler: + """Compose two handlers where outer wraps inner.""" + + def composed( + request: ToolCallRequest, state: Any, runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + outer_gen = outer(request, state, runtime) + + # Initialize outer generator + try: + outer_request_or_result = next(outer_gen) + except StopIteration: + msg = "outer handler must yield at least once" + raise ValueError(msg) + + # Outer retry loop + while True: + # If outer yielded a ToolMessage or Command, bypass inner and yield directly + if isinstance(outer_request_or_result, (ToolMessage, Command)): + result = yield outer_request_or_result + try: + outer_request_or_result = outer_gen.send(result) + except StopIteration: + # Outer ended - final result is what we sent to it + return + continue + + inner_gen = inner(outer_request_or_result, state, runtime) + last_sent_to_inner: ToolMessage | Command | None = None + + # Initialize inner generator + try: + inner_request_or_result = next(inner_gen) + except StopIteration: + msg = "inner handler must yield at least once" + raise ValueError(msg) + + # Inner retry loop + while True: + # Yield to actual tool execution + result = yield inner_request_or_result + last_sent_to_inner = result + + # Send result to inner + try: + inner_request_or_result = inner_gen.send(result) + except StopIteration: + # Inner is done - final result from inner is last_sent_to_inner + break + + # Send inner's final result to outer + if last_sent_to_inner is None: + msg = "inner handler ended without receiving any result" + raise ValueError(msg) + try: + outer_request_or_result = outer_gen.send(last_sent_to_inner) + except StopIteration: + # Outer is done - final result is what we sent to it + return + + return composed + + # Chain all handlers: first -> second -> ... 
-> last + result = handlers[-1] + for handler in reversed(handlers[:-1]): + result = compose_two(handler, result) + + return result + + def create_agent( # noqa: PLR0915 model: str | BaseChatModel, tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None, @@ -537,6 +631,17 @@ def check_weather(location: str) -> str: structured_output_tools[structured_tool_info.tool.name] = structured_tool_info middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] + # Collect middleware with on_tool_call hooks + middleware_w_on_tool_call = [ + m for m in middleware if m.__class__.on_tool_call is not AgentMiddleware.on_tool_call + ] + + # Chain all on_tool_call handlers into a single composed handler + on_tool_call_handler = None + if middleware_w_on_tool_call: + handlers = [m.on_tool_call for m in middleware_w_on_tool_call] + on_tool_call_handler = _chain_tool_call_handlers(handlers) + # Setup tools tool_node: ToolNode | None = None # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables) @@ -547,7 +652,11 @@ def check_weather(location: str) -> str: available_tools = middleware_tools + regular_tools # Only create ToolNode if we have client-side tools - tool_node = ToolNode(tools=available_tools) if available_tools else None + tool_node = ( + ToolNode(tools=available_tools, on_tool_call=on_tool_call_handler) + if available_tools + else None + ) # Default tools for ModelRequest initialization # Use converted BaseTool instances from ToolNode (not raw callables) diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 6759bff2e1c46..32f6db965d416 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -21,11 +21,14 @@ if TYPE_CHECKING: from collections.abc import Awaitable + from langchain.tools.tool_node import ToolCallRequest + # needed as top level import for pydantic schema generation on AgentState -from langchain_core.messages import AIMessage, AnyMessage # noqa: TC002 +from langchain_core.messages import AIMessage, AnyMessage, ToolMessage # noqa: TC002 from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.channels.untracked_value import UntrackedValue from langgraph.graph.message import add_messages +from langgraph.types import Command # noqa: TC002 from langgraph.typing import ContextT from typing_extensions import NotRequired, Required, TypedDict, TypeVar @@ -33,7 +36,6 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.runtime import Runtime - from langgraph.types import Command from langchain.agents.structured_output import ResponseFormat @@ -261,6 +263,46 @@ async def aafter_agent( ) -> dict[str, Any] | None: """Async logic to run after the agent execution completes.""" + def on_tool_call( + self, + request: ToolCallRequest, + state: StateT, # noqa: ARG002 + runtime: Runtime[ContextT], # noqa: ARG002 + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Intercept tool execution for retries, monitoring, or modification. + + Multiple middleware compose automatically (first defined = outermost). + Exceptions propagate unless handle_tool_errors is configured on ToolNode. + + Args: + request: Tool call request with call dict and BaseTool instance. + state: Current agent state. + runtime: LangGraph runtime. 
+ + Yields: + ToolCallRequest (execute tool), ToolMessage (cached result), + or Command (control flow). + + Receives: + ToolMessage or Command via .send() after execution. + + Example: + Modify request: + + def on_tool_call(self, request, state, runtime): + request.tool_call["args"]["value"] *= 2 + yield request + + Retry on error: + + def on_tool_call(self, request, state, runtime): + for attempt in range(3): + response = yield request + if valid(response) or attempt == 2: + return + """ + yield request + class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): """Callable with AgentState and Runtime as arguments.""" @@ -402,7 +444,7 @@ def before_model( Returns: Either an AgentMiddleware instance (if func is provided directly) or a decorator function - that can be applied to a function its wrapping. + that can be applied to a function it is wrapping. The decorated function should return: - `dict[str, Any]` - State updates to merge into the agent state @@ -812,7 +854,7 @@ def before_agent( Returns: Either an AgentMiddleware instance (if func is provided directly) or a decorator function - that can be applied to a function its wrapping. + that can be applied to a function it is wrapping. The decorated function should return: - `dict[str, Any]` - State updates to merge into the agent state diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index f0bbb1754cd62..cb9b5bc1790b5 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -5,7 +5,7 @@ Tools are functions that models can call to interact with external systems, APIs, databases, or perform computations. -The module implements several key design patterns: +The module implements design patterns for: - Parallel execution of multiple tool calls for efficiency - Robust error handling with customizable error messages - State injection for tools that need access to graph state @@ -38,8 +38,9 @@ def my_tool(x: int) -> str: import asyncio import inspect import json +from collections.abc import Callable, Generator from copy import copy, deepcopy -from dataclasses import replace +from dataclasses import dataclass, replace from types import UnionType from typing import ( TYPE_CHECKING, @@ -76,11 +77,12 @@ def my_tool(x: int) -> str: from langgraph._internal._runnable import RunnableCallable from langgraph.errors import GraphBubbleUp from langgraph.graph.message import REMOVE_ALL_MESSAGES +from langgraph.runtime import get_runtime from langgraph.types import Command, Send from pydantic import BaseModel, ValidationError if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Sequence from langchain_core.runnables import RunnableConfig from langgraph.store.base import BaseStore @@ -101,6 +103,71 @@ def my_tool(x: int) -> str: ) +@dataclass() +class ToolCallRequest: + """Tool execution request passed to on_tool_call handlers. + + Attributes: + tool_call: Tool call dict with name, args, and id from model output. + tool: BaseTool instance to be invoked. + """ + + tool_call: ToolCall + tool: BaseTool + + +ToolCallHandler = Callable[ + [ToolCallRequest, Any, Any], + Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None], +] +"""Generator-based handler for intercepting tool execution. + +Receives (request, state, runtime), yields ToolCallRequest/ToolMessage/Command, +receives results via .send(). Returns None; the last value sent becomes the result. 
+ +Exceptions propagate unless handle_tool_errors is configured on ToolNode. + +Type Parameters: + The handler signature is ``(ToolCallRequest, Any, Any) -> Generator[...]``: + + - **First Any (state)**: Typed as ``Any`` because ToolNode supports multiple input + formats (list, dict, BaseModel, ToolCallWithContext). When used in ``create_agent``, + state will be the agent's StateT (dict with "messages" key). When used standalone, + state matches the input type passed to ToolNode. + + - **Second Any (runtime)**: Typed as ``Any`` because runtime is optional and only + available when ToolNode runs within a LangGraph graph. Will be ``None`` in + standalone usage or unit tests. + + Note: + When implementing middleware for ``create_agent``, use + ``AgentMiddleware.on_tool_call`` which has properly typed ``state: StateT`` + parameter for better type safety. + +Example: + Passthrough: + + def handler(request, state, runtime): + yield request + + Cache result: + + def handler(request, state, runtime): + if cached := get_cache(request): + yield ToolMessage(content=cached, tool_call_id=request.tool_call["id"]) + else: + yield request + + Retry with validation: + + def handler(request, state, runtime): + for attempt in range(3): + response = yield request + if valid(response) or attempt == 2: + return +""" + + class ToolCallWithContext(TypedDict): """ToolCall with additional context for graph state. @@ -125,23 +192,16 @@ class ToolCallWithContext(TypedDict): def msg_content_output(output: Any) -> str | list[dict]: - """Convert tool output to valid message content format. + """Convert tool output to ToolMessage content format. - LangChain ToolMessages accept either string content or a list of content blocks. - This function ensures tool outputs are properly formatted for message consumption - by attempting to preserve structured data when possible, falling back to JSON - serialization or string conversion. + Handles str, list[dict] (content blocks), and arbitrary objects by attempting + JSON serialization with fallback to str(). Args: - output: The raw output from a tool execution. Can be any type. + output: Tool execution output of any type. Returns: - Either a string representation of the output or a list of content blocks - if the output is already in the correct format for structured content. - - Note: - This function prioritizes backward compatibility by defaulting to JSON - serialization rather than supporting all possible message content formats. + String or list of content blocks suitable for ToolMessage.content. """ if isinstance(output, str) or ( isinstance(output, list) @@ -418,8 +478,9 @@ def __init__( | type[Exception] | tuple[type[Exception], ...] = _default_handle_tool_errors, messages_key: str = "messages", + on_tool_call: ToolCallHandler | None = None, ) -> None: - """Initialize the ToolNode with the provided tools and configuration. + """Initialize ToolNode with tools and configuration. Args: tools: Sequence of tools to make available for execution. @@ -427,6 +488,10 @@ def __init__( tags: Optional metadata tags. handle_tool_errors: Error handling configuration. messages_key: State key containing messages. + on_tool_call: Generator handler to intercept tool execution. Receives + ToolCallRequest, yields requests, messages, or Commands; receives + ToolMessage or Command via .send(). Final result is last value sent to + handler. Enables retries, caching, request modification, and control flow. 
""" super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False) self._tools_by_name: dict[str, BaseTool] = {} @@ -434,6 +499,7 @@ def __init__( self._tool_to_store_arg: dict[str, str | None] = {} self._handle_tool_errors = handle_tool_errors self._messages_key = messages_key + self._on_tool_call = on_tool_call for tool in tools: if not isinstance(tool, BaseTool): tool_ = create_tool(cast("type[BaseTool]", tool)) @@ -455,12 +521,23 @@ def _func( *, store: Optional[BaseStore], # noqa: UP045 ) -> Any: + try: + runtime = get_runtime() + except RuntimeError: + # Running outside of LangGraph runtime context (e.g., unit tests) + runtime = None + tool_calls, input_type = self._parse_input(input) tool_calls = [self._inject_tool_args(call, input, store) for call in tool_calls] + config_list = get_config_list(config, len(tool_calls)) input_types = [input_type] * len(tool_calls) + inputs = [input] * len(tool_calls) + runtimes = [runtime] * len(tool_calls) with get_executor_for_config(config) as executor: - outputs = [*executor.map(self._run_one, tool_calls, input_types, config_list)] + outputs = [ + *executor.map(self._run_one, tool_calls, input_types, config_list, inputs, runtimes) + ] return self._combine_tool_outputs(outputs, input_type) @@ -471,10 +548,16 @@ async def _afunc( *, store: Optional[BaseStore], # noqa: UP045 ) -> Any: + try: + runtime = get_runtime() + except RuntimeError: + # Running outside of LangGraph runtime context (e.g., unit tests) + runtime = None + tool_calls, input_type = self._parse_input(input) tool_calls = [self._inject_tool_args(call, input, store) for call in tool_calls] outputs = await asyncio.gather( - *(self._arun_one(call, input_type, config) for call in tool_calls) + *(self._arun_one(call, input_type, config, input, runtime) for call in tool_calls) ) return self._combine_tool_outputs(outputs, input_type) @@ -521,20 +604,30 @@ def _combine_tool_outputs( combined_outputs.append(parent_command) return combined_outputs - def _run_one( + def _execute_tool_sync( self, - call: ToolCall, + request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig, ) -> ToolMessage | Command: - """Run a single tool call synchronously.""" - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message + """Execute tool call with configured error handling. - try: - call_args = {**call, "type": "tool_call"} - tool = self.tools_by_name[call["name"]] + Args: + request: Tool execution request. + input_type: Input format. + config: Runnable configuration. + + Returns: + ToolMessage or Command. + Raises: + Exception: If tool fails and handle_tool_errors is False. + """ + call = request.tool_call + tool = request.tool + call_args = {**call, "type": "tool_call"} + + try: try: response = tool.invoke(call_args, config) except ValidationError as exc: @@ -552,6 +645,7 @@ def _run_one( except GraphBubbleUp: raise except Exception as e: + # Determine which exception types are handled handled_types: tuple[type[Exception], ...] 
if isinstance(self._handle_tool_errors, type) and issubclass( self._handle_tool_errors, Exception @@ -567,10 +661,11 @@ def _run_one( # default behavior is catching all exceptions handled_types = (Exception,) - # Unhandled + # Check if this error should be handled if not self._handle_tool_errors or not isinstance(e, handled_types): raise - # Handled + + # Error is handled - create error ToolMessage content = _handle_tool_error(e, flag=self._handle_tool_errors) return ToolMessage( content=content, @@ -579,28 +674,166 @@ def _run_one( status="error", ) + # Process successful response if isinstance(response, Command): - return self._validate_tool_command(response, call, input_type) + # Validate Command before returning to handler + return self._validate_tool_command(response, request.tool_call, input_type) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) return response + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) - async def _arun_one( + def _run_one( # noqa: PLR0912 self, call: ToolCall, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig, + input: list[AnyMessage] | dict[str, Any] | BaseModel, + runtime: Any, ) -> ToolMessage | Command: - """Run a single tool call asynchronously.""" + """Execute single tool call with on_tool_call handler if configured. + + Args: + call: Tool call dict. + input_type: Input format. + config: Runnable configuration. + input: Agent state. + runtime: LangGraph runtime or None. + + Returns: + ToolMessage or Command. + """ if invalid_tool_message := self._validate_tool_call(call): return invalid_tool_message + tool = self.tools_by_name[call["name"]] + + # Create the tool request + tool_request = ToolCallRequest( + tool_call=call, + tool=tool, + ) + + if self._on_tool_call is None: + return self._execute_tool_sync(tool_request, input_type, config) + + # Extract state from ToolCallWithContext if present + state = self._extract_state(input) + + # Generator protocol: start generator, send messages, receive requests/messages + gen = self._on_tool_call(tool_request, state, runtime) + last_sent_value: ToolMessage | Command | None = None + first_yield = True + short_circuited_immediately = False + try: - call_args = {**call, "type": "tool_call"} - tool = self.tools_by_name[call["name"]] + yielded = next(gen) + except StopIteration: + # Handler ended immediately without yielding + msg = ( + "on_tool_call handler must yield at least once. " + "The final result is the last value sent to the handler." + ) + raise ValueError(msg) + + # Handler yielded - check if short-circuit (ToolMessage/Command) or normal (ToolCallRequest) + while True: + if isinstance(yielded, (ToolMessage, Command)): + # Handler yielded ToolMessage or Command + if first_yield: + # First yield is ToolMessage/Command = immediate short-circuit + short_circuited_immediately = True + elif short_circuited_immediately: + # Already short-circuited immediately, cannot yield again + msg = ( + "on_tool_call handler yielded multiple values after short-circuit. " + "After yielding ToolMessage or Command as first yield, handler must " + "end or throw." 
+ ) + raise ValueError(msg) + # Otherwise: this is response modification after execution - allowed + + first_yield = False + + # Send it back to generator + last_sent_value = yielded + try: + yielded = gen.send(yielded) + # If generator yields again, continue the loop + except StopIteration: + # Handler ended - return the last value we sent to it + return last_sent_value + except Exception as e: + # Handler threw an exception + if not self._handle_tool_errors: + raise + # Convert to error message + content = _handle_tool_error(e, flag=self._handle_tool_errors) + return ToolMessage( + content=content, + name=tool_request.tool_call["name"], + tool_call_id=tool_request.tool_call["id"], + status="error", + ) + else: + # Normal flow: execute the tool with the request + if short_circuited_immediately: + msg = ( + "on_tool_call handler yielded ToolCallRequest after short-circuit. " + "After short-circuit, handler must end or throw." + ) + raise ValueError(msg) + + first_yield = False + + tool_message_or_command = self._execute_tool_sync(yielded, input_type, config) + + # Send result back to generator (ToolMessage or Command) + last_sent_value = tool_message_or_command + try: + yielded = gen.send(tool_message_or_command) + except StopIteration: + # Handler ended - return the last value we sent to it + return last_sent_value + except Exception as e: + # Handler threw an exception + if not self._handle_tool_errors: + raise + # Convert to error message + content = _handle_tool_error(e, flag=self._handle_tool_errors) + return ToolMessage( + content=content, + name=tool_request.tool_call["name"], + tool_call_id=tool_request.tool_call["id"], + status="error", + ) + + async def _execute_tool_async( + self, + request: ToolCallRequest, + input_type: Literal["list", "dict", "tool_calls"], + config: RunnableConfig, + ) -> ToolMessage | Command: + """Execute tool call asynchronously with configured error handling. + + Args: + request: Tool execution request. + input_type: Input format. + config: Runnable configuration. + + Returns: + ToolMessage or Command. + + Raises: + Exception: If tool fails and handle_tool_errors is False. + """ + call = request.tool_call + tool = request.tool + call_args = {**call, "type": "tool_call"} + try: try: response = await tool.ainvoke(call_args, config) except ValidationError as exc: @@ -618,6 +851,7 @@ async def _arun_one( except GraphBubbleUp: raise except Exception as e: + # Determine which exception types are handled handled_types: tuple[type[Exception], ...] 
if isinstance(self._handle_tool_errors, type) and issubclass( self._handle_tool_errors, Exception @@ -633,12 +867,12 @@ async def _arun_one( # default behavior is catching all exceptions handled_types = (Exception,) - # Unhandled + # Check if this error should be handled if not self._handle_tool_errors or not isinstance(e, handled_types): raise - # Handled - content = _handle_tool_error(e, flag=self._handle_tool_errors) + # Error is handled - create error ToolMessage + content = _handle_tool_error(e, flag=self._handle_tool_errors) return ToolMessage( content=content, name=call["name"], @@ -646,14 +880,144 @@ async def _arun_one( status="error", ) + # Process successful response if isinstance(response, Command): - return self._validate_tool_command(response, call, input_type) + # Validate Command before returning to handler + return self._validate_tool_command(response, request.tool_call, input_type) if isinstance(response, ToolMessage): response.content = cast("str | list", msg_content_output(response.content)) return response + msg = f"Tool {call['name']} returned unexpected type: {type(response)}" raise TypeError(msg) + async def _arun_one( # noqa: PLR0912 + self, + call: ToolCall, + input_type: Literal["list", "dict", "tool_calls"], + config: RunnableConfig, + input: list[AnyMessage] | dict[str, Any] | BaseModel, + runtime: Any, + ) -> ToolMessage | Command: + """Execute single tool call asynchronously with on_tool_call handler if configured. + + Args: + call: Tool call dict. + input_type: Input format. + config: Runnable configuration. + input: Agent state. + runtime: LangGraph runtime or None. + + Returns: + ToolMessage or Command. + """ + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + + tool = self.tools_by_name[call["name"]] + + # Create the tool request + tool_request = ToolCallRequest( + tool_call=call, + tool=tool, + ) + + if self._on_tool_call is None: + return await self._execute_tool_async(tool_request, input_type, config) + + # Extract state from ToolCallWithContext if present + state = self._extract_state(input) + + # Generator protocol: handler is sync generator, tool execution is async + gen = self._on_tool_call(tool_request, state, runtime) + last_sent_value: ToolMessage | Command | None = None + first_yield = True + short_circuited_immediately = False + + try: + yielded = next(gen) + except StopIteration: + # Handler ended immediately without yielding + msg = ( + "on_tool_call handler must yield at least once. " + "The final result is the last value sent to the handler." + ) + raise ValueError(msg) + + # Handler yielded - check if short-circuit (ToolMessage/Command) or normal (ToolCallRequest) + while True: + if isinstance(yielded, (ToolMessage, Command)): + # Handler yielded ToolMessage or Command + if first_yield: + # First yield is ToolMessage/Command = immediate short-circuit + short_circuited_immediately = True + elif short_circuited_immediately: + # Already short-circuited immediately, cannot yield again + msg = ( + "on_tool_call handler yielded multiple values after short-circuit. " + "After yielding ToolMessage or Command as first yield, handler must " + "end or throw." 
+ ) + raise ValueError(msg) + # Otherwise: this is response modification after execution - allowed + + first_yield = False + + # Send it back to generator + last_sent_value = yielded + try: + yielded = gen.send(yielded) + # If generator yields again, continue the loop + except StopIteration: + # Handler ended - return the last value we sent to it + return last_sent_value + except Exception as e: + # Handler threw an exception + if not self._handle_tool_errors: + raise + # Convert to error message + content = _handle_tool_error(e, flag=self._handle_tool_errors) + return ToolMessage( + content=content, + name=tool_request.tool_call["name"], + tool_call_id=tool_request.tool_call["id"], + status="error", + ) + else: + # Normal flow: execute the tool with the request + if short_circuited_immediately: + msg = ( + "on_tool_call handler yielded ToolCallRequest after short-circuit. " + "After short-circuit, handler must end or throw." + ) + raise ValueError(msg) + + first_yield = False + + tool_message_or_command = await self._execute_tool_async( + yielded, input_type, config + ) + + # Send result back to generator (ToolMessage or Command) + last_sent_value = tool_message_or_command + try: + yielded = gen.send(tool_message_or_command) + except StopIteration: + # Handler ended - return the last value we sent to it + return last_sent_value + except Exception as e: + # Handler threw an exception + if not self._handle_tool_errors: + raise + # Convert to error message + content = _handle_tool_error(e, flag=self._handle_tool_errors) + return ToolMessage( + content=content, + name=tool_request.tool_call["name"], + tool_call_id=tool_request.tool_call["id"], + status="error", + ) + def _parse_input( self, input: list[AnyMessage] | dict[str, Any] | BaseModel, @@ -705,6 +1069,21 @@ def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None: ) return None + def _extract_state( + self, input: list[AnyMessage] | dict[str, Any] | BaseModel + ) -> list[AnyMessage] | dict[str, Any] | BaseModel: + """Extract state from input, handling ToolCallWithContext if present. + + Args: + input: The input which may be raw state or ToolCallWithContext. + + Returns: + The actual state to pass to on_tool_call handlers. 
+ """ + if isinstance(input, dict) and input.get("__type") == "tool_call_with_context": + return input["state"] + return input + def _inject_state( self, tool_call: ToolCall, diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_on_tool_call_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/test_on_tool_call_middleware.py new file mode 100644 index 0000000000000..2c329985c7fe2 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_on_tool_call_middleware.py @@ -0,0 +1,769 @@ +"""Unit tests for on_tool_call middleware integration.""" + +from collections.abc import Generator + +import pytest +from langchain_core.messages import HumanMessage, ToolCall, ToolMessage +from langchain_core.tools import tool +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.types import Command + +from langchain.agents.factory import create_agent +from langchain.agents.middleware.types import AgentMiddleware +from langchain.tools.tool_node import ToolCallRequest +from tests.unit_tests.agents.test_middleware_agent import FakeToolCallingModel + + +@tool +def search(query: str) -> str: + """Search for information.""" + return f"Results for: {query}" + + +@tool +def calculator(expression: str) -> str: + """Calculate an expression.""" + return f"Result: {expression}" + + +@tool +def failing_tool(input: str) -> str: + """Tool that always fails.""" + msg = f"Failed: {input}" + raise ValueError(msg) + + +def test_simple_logging_middleware() -> None: + """Test middleware that logs tool calls.""" + call_log = [] + + class LoggingMiddleware(AgentMiddleware): + """Middleware that logs tool calls.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append(f"before_{request.tool.name}") + response = yield request + call_log.append(f"after_{request.tool.name}") + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[LoggingMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search for test")]}, + {"configurable": {"thread_id": "test"}}, + ) + + assert len(call_log) == 2 + assert call_log[0] == "before_search" + assert call_log[1] == "after_search" + assert len(result["messages"]) > 0 + + +def test_request_modification_middleware() -> None: + """Test middleware that modifies tool call arguments.""" + + class ModifyArgsMiddleware(AgentMiddleware): + """Middleware that modifies tool arguments.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + # Add prefix to query + if request.tool.name == "search": + original_query = request.tool_call["args"]["query"] + request.tool_call["args"]["query"] = f"modified: {original_query}" + response = yield request + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[ModifyArgsMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert "modified: test" in tool_messages[0].content + + +def test_response_inspection_middleware() -> 
None: + """Test middleware that inspects tool responses.""" + inspected_responses = [] + + class ResponseInspectionMiddleware(AgentMiddleware): + """Middleware that inspects responses.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + response = yield request + + # Record response details + if isinstance(response, ToolMessage): + inspected_responses.append( + { + "tool_name": request.tool.name, + "content": response.content, + } + ) + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[ResponseInspectionMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + # Middleware should have inspected the response + assert len(inspected_responses) == 1 + assert inspected_responses[0]["tool_name"] == "search" + + +def test_conditional_retry_middleware() -> None: + """Test middleware that retries tool calls based on response content.""" + call_count = 0 + + class ConditionalRetryMiddleware(AgentMiddleware): + """Middleware that retries based on response content.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + nonlocal call_count + max_retries = 2 + + for attempt in range(max_retries): + response = yield request + call_count += 1 + + # Check if we should retry based on content + if ( + isinstance(response, ToolMessage) + and "retry_marker" in response.content + and attempt < max_retries - 1 + ): + # Continue to retry + continue + + # Return on success or final attempt + + # Use search tool which always succeeds - we'll modify request to test retry logic + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[ConditionalRetryMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Middleware should have been called at least once + assert call_count >= 1 + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + + +def test_multiple_middleware_composition() -> None: + """Test that multiple middleware compose correctly.""" + call_log = [] + + class OuterMiddleware(AgentMiddleware): + """Outer middleware.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append("outer_before") + response = yield request + call_log.append("outer_after") + + class InnerMiddleware(AgentMiddleware): + """Inner middleware.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append("inner_before") + response = yield request + call_log.append("inner_after") + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + # First middleware is outermost + agent = create_agent( + model=model, + tools=[search], + middleware=[OuterMiddleware(), InnerMiddleware()], + 
checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify correct composition order + assert call_log == ["outer_before", "inner_before", "inner_after", "outer_after"] + assert len(result["messages"]) > 0 + + +def test_middleware_with_multiple_tool_calls() -> None: + """Test middleware handles multiple tool calls correctly.""" + call_log = [] + + class LoggingMiddleware(AgentMiddleware): + """Middleware that logs tool calls.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append(request.tool.name) + response = yield request + + model = FakeToolCallingModel( + tool_calls=[ + [ + ToolCall(name="search", args={"query": "test1"}, id="1"), + ToolCall(name="calculator", args={"expression": "1+1"}, id="2"), + ], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search, calculator], + middleware=[LoggingMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Use tools")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Each tool call should be logged + assert "search" in call_log + assert "calculator" in call_log + assert len(call_log) == 2 + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 2 + + +def test_middleware_access_to_state() -> None: + """Test middleware can access agent state.""" + state_seen = [] + + class StateInspectionMiddleware(AgentMiddleware): + """Middleware that inspects state.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + # Record state - state could be dict or list + if state is not None: + if isinstance(state, dict) and "messages" in state: + state_seen.append(("dict", len(state["messages"]))) + elif isinstance(state, list): + state_seen.append(("list", len(state))) + else: + state_seen.append(("other", type(state).__name__)) + response = yield request + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[StateInspectionMiddleware()], + checkpointer=InMemorySaver(), + ) + + agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Middleware should have seen state (state is passed to on_tool_call) + assert len(state_seen) >= 1 + + +def test_middleware_without_on_tool_call() -> None: + """Test that middleware without on_tool_call hook works normally.""" + + class NoOpMiddleware(AgentMiddleware): + """Middleware without on_tool_call.""" + + def before_model(self, state, runtime): + """Just a dummy hook.""" + return None + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[NoOpMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Should work normally + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + + +def test_generator_composition_immediate_outer_return() -> None: + """Test composition when outer generator returns after first yield.""" + call_log = [] + + 
class ImmediateReturnMiddleware(AgentMiddleware): + """Outer middleware that returns after first yield.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append("outer_yield") + # Yield once, receive response from inner + response = yield request + call_log.append("outer_got_response") + # Yield modified message to make it the final result + modified = ToolMessage( + content="Outer intercepted", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + yield modified + + class InnerMiddleware(AgentMiddleware): + """Inner middleware.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append("inner_called") + response = yield request + call_log.append("inner_got_response") + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[ImmediateReturnMiddleware(), InnerMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Both should be called, outer intercepts the response + assert call_log == ["outer_yield", "inner_called", "inner_got_response", "outer_got_response"] + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert "Outer intercepted" in tool_messages[0].content + + +def test_generator_composition_short_circuit() -> None: + """Test composition when inner generator short-circuits after first yield.""" + call_log = [] + + class OuterMiddleware(AgentMiddleware): + """Outer middleware.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + call_log.append("outer_before") + response = yield request + call_log.append("outer_after") + # Modify response from inner + if isinstance(response, ToolMessage): + modified = ToolMessage( + content=f"outer_wrapped: {response.content}", + tool_call_id=response.tool_call_id, + name=response.name, + ) + yield modified + + class InnerShortCircuitMiddleware(AgentMiddleware): + """Inner middleware that short-circuits without calling actual tool.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + call_log.append("inner_short_circuit") + # Yield request but return custom response instead of actual tool result + _ = yield request + # Return custom result without using actual tool response + yield ToolMessage( + content="inner_short_circuit_result", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[OuterMiddleware(), InnerShortCircuitMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify order: outer_before -> inner short circuits -> outer_after + assert call_log == ["outer_before", "inner_short_circuit", "outer_after"] + + tool_messages = [m for m in result["messages"] if isinstance(m, 
ToolMessage)] + assert len(tool_messages) == 1 + assert "outer_wrapped: inner_short_circuit_result" in tool_messages[0].content + + +def test_generator_composition_nested_retries() -> None: + """Test composition when both outer and inner generators retry.""" + call_log = [] + + class OuterRetryMiddleware(AgentMiddleware): + """Outer middleware with retry logic.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + for outer_attempt in range(2): + call_log.append(f"outer_{outer_attempt}") + response = yield request + + if isinstance(response, ToolMessage) and response.content == "inner_final_failure": + # Inner failed, retry once + continue + + class InnerRetryMiddleware(AgentMiddleware): + """Inner middleware with retry logic.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + for inner_attempt in range(2): + call_log.append(f"inner_{inner_attempt}") + response = yield request + + # Check for error in tool result + if isinstance(response, ToolMessage): + if inner_attempt == 0 and "Results for:" in response.content: + # First attempt succeeded, but let's pretend it's a soft failure + # to test inner retry + continue + + # Inner exhausted retries + yield ToolMessage( + content="inner_final_failure", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[OuterRetryMiddleware(), InnerRetryMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify nested retry pattern + assert "outer_0" in call_log + assert "inner_0" in call_log + assert "inner_1" in call_log + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + + +def test_generator_composition_three_levels() -> None: + """Test composition with three middleware levels.""" + call_log = [] + + class OuterMiddleware(AgentMiddleware): + """Outermost middleware.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append("outer_before") + response = yield request + call_log.append("outer_after") + + class MiddleMiddleware(AgentMiddleware): + """Middle middleware.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append("middle_before") + response = yield request + call_log.append("middle_after") + + class InnerMiddleware(AgentMiddleware): + """Innermost middleware.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + call_log.append("inner_before") + response = yield request + call_log.append("inner_after") + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[OuterMiddleware(), MiddleMiddleware(), InnerMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": 
{"thread_id": "test"}}, + ) + + # Verify correct nesting order + assert call_log == [ + "outer_before", + "middle_before", + "inner_before", + "inner_after", + "middle_after", + "outer_after", + ] + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + + +def test_generator_composition_return_value_extraction() -> None: + """Test that return values are properly extracted from StopIteration.""" + final_content = [] + + class ModifyingMiddleware(AgentMiddleware): + """Middleware that modifies the final result.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest, ToolMessage, None]: + response = yield request + + # Explicitly return a modified response + if isinstance(response, ToolMessage): + modified = ToolMessage( + content=f"modified: {response.content}", + tool_call_id=response.tool_call_id, + name=response.name, + ) + final_content.append(modified.content) + yield modified + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[ModifyingMiddleware()], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, + ) + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + # Verify the returned value was properly extracted + assert "modified:" in tool_messages[0].content + assert len(final_content) == 1 + assert "modified:" in final_content[0] + + +def test_generator_composition_with_mixed_passthrough_and_intercepting() -> None: + """Test composition with mix of pass-through and intercepting generators.""" + call_log = [] + + class FirstPassthroughMiddleware(AgentMiddleware): + """First middleware that passes through.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + call_log.append("first_before") + response = yield request + call_log.append("first_after") + + class SecondInterceptingMiddleware(AgentMiddleware): + """Second middleware that intercepts and returns custom result.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + call_log.append("second_intercept") + # Yield request but ignore the actual result + _ = yield request + # Return custom result + yield ToolMessage( + content="intercepted_result", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + + class ThirdPassthroughMiddleware(AgentMiddleware): + """Third middleware that passes through.""" + + def on_tool_call( + self, request: ToolCallRequest, state, runtime + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + call_log.append("third_called") + response = yield request + call_log.append("third_after") + + model = FakeToolCallingModel( + tool_calls=[ + [ToolCall(name="search", args={"query": "test"}, id="1")], + [], + ] + ) + + agent = create_agent( + model=model, + tools=[search], + middleware=[ + FirstPassthroughMiddleware(), + SecondInterceptingMiddleware(), + ThirdPassthroughMiddleware(), + ], + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("Search")]}, + {"configurable": {"thread_id": "test"}}, 
+ ) + + # All middleware are called, second intercepts and returns custom result + assert call_log == [ + "first_before", + "second_intercept", + "third_called", + "third_after", + "first_after", + ] + + tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)] + assert len(tool_messages) == 1 + assert "intercepted_result" in tool_messages[0].content diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py new file mode 100644 index 0000000000000..adc07fbee3494 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -0,0 +1,1226 @@ +"""Unit tests for on_tool_call handler in ToolNode.""" + +from collections.abc import Generator +from typing import Any + +import pytest +from langchain_core.messages import AIMessage, ToolCall, ToolMessage +from langchain_core.tools import tool +from langgraph.types import Command + +from langchain.tools.tool_node import ( + ToolCallRequest, + ToolNode, +) + +pytestmark = pytest.mark.anyio + + +@tool +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +@tool +def failing_tool(a: int) -> int: + """A tool that always fails.""" + msg = f"This tool always fails (input: {a})" + raise ValueError(msg) + + +@tool +def command_tool(goto: str) -> Command: + """A tool that returns a Command.""" + return Command(goto=goto) + + +def test_passthrough_handler() -> None: + """Test a simple passthrough handler that doesn't modify anything.""" + + def passthrough_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Simple passthrough handler.""" + yield request + + tool_node = ToolNode([add], on_tool_call=passthrough_handler) + + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_1", + } + ], + ) + ] + } + ) + + tool_message = result["messages"][-1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "3" + assert tool_message.tool_call_id == "call_1" + assert tool_message.status != "error" + + +async def test_passthrough_handler_async() -> None: + """Test passthrough handler with async tool.""" + + def passthrough_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Simple passthrough handler.""" + yield request + + tool_node = ToolNode([add], on_tool_call=passthrough_handler) + + result = await tool_node.ainvoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 2, "b": 3}, + "id": "call_2", + } + ], + ) + ] + } + ) + + tool_message = result["messages"][-1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "5" + assert tool_message.tool_call_id == "call_2" + + +def test_modify_arguments() -> None: + """Test handler that modifies tool arguments before execution.""" + + def modify_args_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that doubles the input arguments.""" + # Modify the arguments + request.tool_call["args"]["a"] *= 2 + request.tool_call["args"]["b"] *= 2 + + yield request + + tool_node = ToolNode([add], on_tool_call=modify_args_handler) + + result = tool_node.invoke( + { + "messages": [ + 
AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_3", + } + ], + ) + ] + } + ) + + tool_message = result["messages"][-1] + assert isinstance(tool_message, ToolMessage) + # Original args were (1, 2), doubled to (2, 4), so result is 6 + assert tool_message.content == "6" + + +def test_handler_validation_no_return() -> None: + """Test that handler with explicit None return works (returns last sent message).""" + + def handler_with_explicit_none( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that returns None explicitly - should still work.""" + yield request + # Explicit None return - protocol uses last sent message as result + return None + + tool_node = ToolNode([add], on_tool_call=handler_with_explicit_none) + + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_6", + } + ], + ) + ] + } + ) + + assert isinstance(result, dict) + messages = result["messages"] + assert len(messages) == 1 + assert isinstance(messages[0], ToolMessage) + assert messages[0].content == "3" + + +def test_handler_validation_no_yield() -> None: + """Test that handler must yield at least once.""" + + def bad_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that ends immediately without yielding.""" + # End immediately without yielding anything + # Need unreachable yield to make this a generator function + if False: + yield request + return + + tool_node = ToolNode([add], on_tool_call=bad_handler) + + with pytest.raises(ValueError, match="must yield at least once"): + tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_7", + } + ], + ) + ] + } + ) + + +def test_handler_with_handle_tool_errors_true() -> None: + """Test that handle_tool_errors=True works with on_tool_call handler.""" + + def passthrough_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Simple passthrough handler.""" + message = yield request + # When handle_tool_errors=True, errors should be converted to error messages + assert isinstance(message, ToolMessage) + assert message.status == "error" + + tool_node = ToolNode([failing_tool], on_tool_call=passthrough_handler, handle_tool_errors=True) + + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "failing", + tool_calls=[ + { + "name": "failing_tool", + "args": {"a": 1}, + "id": "call_9", + } + ], + ) + ] + } + ) + + tool_message = result["messages"][-1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.status == "error" + + +def test_multiple_tool_calls_with_handler() -> None: + """Test handler with multiple tool calls in one message.""" + call_count = 0 + + def counting_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that counts calls.""" + nonlocal call_count + call_count += 1 + yield request + + tool_node = ToolNode([add], on_tool_call=counting_handler) + + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding multiple", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 
2}, + "id": "call_10", + }, + { + "name": "add", + "args": {"a": 3, "b": 4}, + "id": "call_11", + }, + { + "name": "add", + "args": {"a": 5, "b": 6}, + "id": "call_12", + }, + ], + ) + ] + } + ) + + # Handler should be called once for each tool call + assert call_count == 3 + + # Verify all results + messages = result["messages"] + assert len(messages) == 3 + assert all(isinstance(m, ToolMessage) for m in messages) + assert messages[0].content == "3" + assert messages[1].content == "7" + assert messages[2].content == "11" + + +def test_tool_call_request_dataclass() -> None: + """Test ToolCallRequest dataclass.""" + tool_call: ToolCall = {"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1"} + + request = ToolCallRequest(tool_call=tool_call, tool=add) + + assert request.tool_call == tool_call + assert request.tool == add + assert request.tool_call["name"] == "add" + + +async def test_handler_with_async_execution() -> None: + """Test handler works correctly with async tool execution.""" + + @tool + async def async_add(a: int, b: int) -> int: + """Async add two numbers.""" + return a + b + + def modifying_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that modifies arguments.""" + # Add 10 to both arguments + request.tool_call["args"]["a"] += 10 + request.tool_call["args"]["b"] += 10 + yield request + + tool_node = ToolNode([async_add], on_tool_call=modifying_handler) + + result = await tool_node.ainvoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "async_add", + "args": {"a": 1, "b": 2}, + "id": "call_13", + } + ], + ) + ] + } + ) + + tool_message = result["messages"][-1] + assert isinstance(tool_message, ToolMessage) + # Original: 1 + 2 = 3, with modifications: 11 + 12 = 23 + assert tool_message.content == "23" + + +def test_short_circuit_with_tool_message() -> None: + """Test handler that yields ToolMessage to short-circuit tool execution.""" + + def short_circuit_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that returns cached result without executing tool.""" + # Yield a ToolMessage directly instead of a ToolCallRequest + cached_result = ToolMessage( + content="cached_result", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + message = yield cached_result + # Message should be our cached message sent back + assert message == cached_result + + tool_node = ToolNode([add], on_tool_call=short_circuit_handler) + + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_16", + } + ], + ) + ] + } + ) + + tool_message = result["messages"][-1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "cached_result" + assert tool_message.tool_call_id == "call_16" + assert tool_message.name == "add" + + +async def test_short_circuit_with_tool_message_async() -> None: + """Test async handler that yields ToolMessage to short-circuit tool execution.""" + + def short_circuit_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that returns cached result without executing tool.""" + cached_result = ToolMessage( + content="async_cached_result", + tool_call_id=request.tool_call["id"], 
+            name=request.tool_call["name"],
+        )
+        yield cached_result
+
+    tool_node = ToolNode([add], on_tool_call=short_circuit_handler)
+
+    result = await tool_node.ainvoke(
+        {
+            "messages": [
+                AIMessage(
+                    "adding",
+                    tool_calls=[
+                        {
+                            "name": "add",
+                            "args": {"a": 2, "b": 3},
+                            "id": "call_17",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    tool_message = result["messages"][-1]
+    assert isinstance(tool_message, ToolMessage)
+    assert tool_message.content == "async_cached_result"
+    assert tool_message.tool_call_id == "call_17"
+
+
+def test_conditional_short_circuit() -> None:
+    """Test handler that conditionally short-circuits based on request."""
+    call_count = {"count": 0}
+
+    def conditional_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that caches even numbers, executes odd."""
+        call_count["count"] += 1
+        a = request.tool_call["args"]["a"]
+
+        if a % 2 == 0:
+            # Even: use cached result
+            cached = ToolMessage(
+                content=f"cached_{a}",
+                tool_call_id=request.tool_call["id"],
+                name=request.tool_call["name"],
+            )
+            yield cached
+        else:
+            # Odd: execute normally
+            yield request
+
+    tool_node = ToolNode([add], on_tool_call=conditional_handler)
+
+    # Test with even number (should be cached)
+    result1 = tool_node.invoke(
+        {
+            "messages": [
+                AIMessage(
+                    "adding",
+                    tool_calls=[
+                        {
+                            "name": "add",
+                            "args": {"a": 2, "b": 3},
+                            "id": "call_18",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    tool_message1 = result1["messages"][-1]
+    assert tool_message1.content == "cached_2"
+
+    # Test with odd number (should execute)
+    result2 = tool_node.invoke(
+        {
+            "messages": [
+                AIMessage(
+                    "adding",
+                    tool_calls=[
+                        {
+                            "name": "add",
+                            "args": {"a": 3, "b": 4},
+                            "id": "call_19",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    tool_message2 = result2["messages"][-1]
+    assert tool_message2.content == "7"  # Actual execution: 3 + 4
+
+
+def test_direct_return_tool_message() -> None:
+    """Test handler that returns ToolMessage directly without yielding."""
+
+    def direct_return_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that returns ToolMessage directly."""
+        # Return ToolMessage directly
+        # Note: We still need this to be a generator, so we use return (not yield)
+        # The generator protocol will catch the StopIteration with the return value
+        if False:
+            yield  # Makes this a generator function
+        return ToolMessage(
+            content="direct_return",
+            tool_call_id=request.tool_call["id"],
+            name=request.tool_call["name"],
+        )
+
+    tool_node = ToolNode([add], on_tool_call=direct_return_handler)
+
+    result = tool_node.invoke(
+        {
+            "messages": [
+                AIMessage(
+                    "adding",
+                    tool_calls=[
+                        {
+                            "name": "add",
+                            "args": {"a": 1, "b": 2},
+                            "id": "call_21",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    tool_message = result["messages"][-1]
+    assert isinstance(tool_message, ToolMessage)
+    assert tool_message.content == "direct_return"
+    assert tool_message.tool_call_id == "call_21"
+    assert tool_message.name == "add"
+
+
+async def test_direct_return_tool_message_async() -> None:
+    """Test async handler that returns ToolMessage directly without yielding."""
+
+    def direct_return_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that returns ToolMessage directly."""
+        if False:
+            yield  # Makes this a generator function
+        return ToolMessage(
content="async_direct_return", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + + tool_node = ToolNode([add], on_tool_call=direct_return_handler) + + result = await tool_node.ainvoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 2, "b": 3}, + "id": "call_22", + } + ], + ) + ] + } + ) + + tool_message = result["messages"][-1] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "async_direct_return" + assert tool_message.tool_call_id == "call_22" + + +def test_conditional_direct_return() -> None: + """Test handler that conditionally returns ToolMessage directly or executes tool.""" + + def conditional_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that returns cached or executes based on condition.""" + a = request.tool_call["args"]["a"] + + if a == 0: + # Return ToolMessage directly for zero + if False: + yield # Makes this a generator + yield ToolMessage( + content="zero_cached", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + else: + # Execute tool normally + yield request + + tool_node = ToolNode([add], on_tool_call=conditional_handler) + + # Test with zero (should return directly) + result1 = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 0, "b": 5}, + "id": "call_23", + } + ], + ) + ] + } + ) + + tool_message1 = result1["messages"][-1] + assert tool_message1.content == "zero_cached" + + # Test with non-zero (should execute) + result2 = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 3, "b": 4}, + "id": "call_24", + } + ], + ) + ] + } + ) + + tool_message2 = result2["messages"][-1] + assert tool_message2.content == "7" # Actual execution: 3 + 4 + + +def test_handler_can_throw_exception() -> None: + """Test that a handler can throw an exception to signal error.""" + + def throwing_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that throws an exception after receiving response.""" + response = yield request + # Check response and throw if invalid + if isinstance(response, ToolMessage): + msg = "Handler rejected the response" + raise ValueError(msg) # noqa: TRY004 + + tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=True) + + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_exc_1", + } + ], + ) + ] + } + ) + + # Should get error message due to handle_tool_errors=True + messages = result["messages"] + assert len(messages) == 1 + assert isinstance(messages[0], ToolMessage) + assert messages[0].status == "error" + assert "Handler rejected the response" in messages[0].content + + +def test_handler_throw_without_handle_errors() -> None: + """Test that exception propagates when handle_tool_errors=False.""" + + def throwing_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that throws an exception.""" + yield request + msg = "Handler error" + raise ValueError(msg) + + tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=False) 
+
+    with pytest.raises(ValueError, match="Handler error"):
+        tool_node.invoke(
+            {
+                "messages": [
+                    AIMessage(
+                        "adding",
+                        tool_calls=[
+                            {
+                                "name": "add",
+                                "args": {"a": 1, "b": 2},
+                                "id": "call_exc_2",
+                            }
+                        ],
+                    )
+                ]
+            }
+        )
+
+
+def test_retry_middleware_with_exception() -> None:
+    """Test retry middleware pattern that can throw after exhausting retries."""
+    attempt_count = {"count": 0}
+
+    def retry_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that can retry up to 3 times, then throw."""
+        max_retries = 3
+
+        for attempt in range(max_retries):
+            attempt_count["count"] += 1
+            response = yield request
+
+            # Simulate checking for retriable errors
+            # In a real use case, would check response.status or content
+            if isinstance(response, ToolMessage) and attempt < max_retries - 1:
+                # Could retry based on some condition
+                # For this test, just succeed immediately
+                break
+
+        # If we exhausted retries, could throw
+        # For this test, we succeed on first try
+
+    tool_node = ToolNode([add], on_tool_call=retry_handler)
+
+    result = tool_node.invoke(
+        {
+            "messages": [
+                AIMessage(
+                    "adding",
+                    tool_calls=[
+                        {
+                            "name": "add",
+                            "args": {"a": 1, "b": 2},
+                            "id": "call_exc_3",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    # Should succeed after 1 attempt
+    assert attempt_count["count"] == 1
+    messages = result["messages"]
+    assert len(messages) == 1
+    assert isinstance(messages[0], ToolMessage)
+    assert messages[0].content == "3"
+
+
+async def test_async_handler_can_throw_exception() -> None:
+    """Test that async execution also supports exception throwing."""
+
+    def throwing_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that throws an exception after receiving response."""
+        response = yield request
+        if isinstance(response, ToolMessage):
+            msg = "Async handler rejected the response"
+            raise ValueError(msg)  # noqa: TRY004
+
+    tool_node = ToolNode([add], on_tool_call=throwing_handler, handle_tool_errors=True)
+
+    result = await tool_node.ainvoke(
+        {
+            "messages": [
+                AIMessage(
+                    "adding",
+                    tool_calls=[
+                        {
+                            "name": "add",
+                            "args": {"a": 1, "b": 2},
+                            "id": "call_exc_4",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    # Should get error message due to handle_tool_errors=True
+    messages = result["messages"]
+    assert len(messages) == 1
+    assert isinstance(messages[0], ToolMessage)
+    assert messages[0].status == "error"
+    assert "Async handler rejected the response" in messages[0].content
+
+
+def test_handler_cannot_yield_multiple_tool_messages() -> None:
+    """Test that yielding multiple ToolMessages is rejected."""
+
+    def multi_message_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that incorrectly yields multiple ToolMessages."""
+        # First short-circuit
+        yield ToolMessage("first", tool_call_id=request.tool_call["id"], name="add")
+        # Second short-circuit - should fail
+        yield ToolMessage("second", tool_call_id=request.tool_call["id"], name="add")
+
+    tool_node = ToolNode([add], on_tool_call=multi_message_handler)
+
+    with pytest.raises(
+        ValueError,
+        match="on_tool_call handler yielded multiple values after short-circuit",
+    ):
+        tool_node.invoke(
+            {
+                "messages": [
+                    AIMessage(
+                        "adding",
+                        tool_calls=[
+                            {
+                                "name": "add",
+                                "args": {"a": 1, "b": 2},
+                                "id": "call_multi_1",
"call_multi_1", + } + ], + ) + ] + } + ) + + +def test_handler_cannot_yield_request_after_tool_message() -> None: + """Test that yielding ToolCallRequest after ToolMessage is rejected.""" + + def confused_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that incorrectly switches from short-circuit to execution.""" + # First short-circuit with cached result + yield ToolMessage("cached", tool_call_id=request.tool_call["id"], name="add") + # Then try to execute - should fail + yield request + + tool_node = ToolNode([add], on_tool_call=confused_handler) + + with pytest.raises( + ValueError, + match="on_tool_call handler yielded ToolCallRequest after short-circuit", + ): + tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_confused_1", + } + ], + ) + ] + } + ) + + +def test_handler_can_short_circuit_with_command() -> None: + """Test that handler can short-circuit by yielding Command.""" + + def command_handler( + _request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that short-circuits with Command.""" + # Short-circuit with Command instead of executing tool + yield Command(goto="end") + + tool_node = ToolNode([add], on_tool_call=command_handler) + + result = tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_cmd_1", + } + ], + ) + ] + } + ) + + # Should get Command in result list + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], Command) + assert result[0].goto == "end" + + +def test_handler_cannot_yield_multiple_commands() -> None: + """Test that yielding multiple Commands is rejected.""" + + def multi_command_handler( + _request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that incorrectly yields multiple Commands.""" + # First short-circuit + yield Command(goto="step1") + # Second short-circuit - should fail + yield Command(goto="step2") + + tool_node = ToolNode([add], on_tool_call=multi_command_handler) + + with pytest.raises( + ValueError, + match="on_tool_call handler yielded multiple values after short-circuit", + ): + tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_multicmd_1", + } + ], + ) + ] + } + ) + + +def test_handler_cannot_yield_request_after_command() -> None: + """Test that yielding ToolCallRequest after Command is rejected.""" + + def command_then_request_handler( + request: ToolCallRequest, _state: Any, _runtime: Any + ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]: + """Handler that incorrectly yields request after Command.""" + # First short-circuit with Command + yield Command(goto="somewhere") + # Then try to execute - should fail + yield request + + tool_node = ToolNode([add], on_tool_call=command_then_request_handler) + + with pytest.raises( + ValueError, + match="on_tool_call handler yielded ToolCallRequest after short-circuit", + ): + tool_node.invoke( + { + "messages": [ + AIMessage( + "adding", + tool_calls=[ + { + "name": "add", + "args": {"a": 1, "b": 2}, + "id": "call_cmdreq_1", + } + ], + ) + ] + } + ) 
+
+
+def test_tool_returning_command_sent_to_handler() -> None:
+    """Test that when a tool returns a Command, it is sent to the handler."""
+    received_commands = []
+
+    def command_inspector_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that inspects Command returned by tool."""
+        result = yield request
+        # Should receive Command from tool
+        if isinstance(result, Command):
+            received_commands.append(result)
+        # Can end here, returning the Command
+
+    tool_node = ToolNode([command_tool], on_tool_call=command_inspector_handler)
+
+    result = tool_node.invoke(
+        {
+            "messages": [
+                AIMessage(
+                    "navigating",
+                    tool_calls=[
+                        {
+                            "name": "command_tool",
+                            "args": {"goto": "next_step"},
+                            "id": "call_cmdtool_1",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    # Handler should have received the Command
+    assert len(received_commands) == 1
+    assert received_commands[0].goto == "next_step"
+
+    # Final result should be the Command in result list
+    assert isinstance(result, list)
+    assert len(result) == 1
+    assert isinstance(result[0], Command)
+    assert result[0].goto == "next_step"
+
+
+def test_handler_can_modify_command_from_tool() -> None:
+    """Test that handler can inspect and modify Command from tool."""
+
+    def command_modifier_handler(
+        request: ToolCallRequest, _state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that modifies Command returned by tool."""
+        result = yield request
+        # Modify the Command
+        if isinstance(result, Command):
+            modified_cmd = Command(goto=f"modified_{result.goto}")
+            yield modified_cmd
+        # Otherwise pass through
+
+    tool_node = ToolNode([command_tool], on_tool_call=command_modifier_handler)
+
+    result = tool_node.invoke(
+        {
+            "messages": [
+                AIMessage(
+                    "navigating",
+                    tool_calls=[
+                        {
+                            "name": "command_tool",
+                            "args": {"goto": "original"},
+                            "id": "call_cmdmod_1",
+                        }
+                    ],
+                )
+            ]
+        }
+    )
+
+    # Final result should be the modified Command in result list
+    assert isinstance(result, list)
+    assert len(result) == 1
+    assert isinstance(result[0], Command)
+    assert result[0].goto == "modified_original"
+
+
+def test_state_extraction_with_dict_input() -> None:
+    """Test that state is correctly passed when input is a dict."""
+    state_seen = []
+
+    def state_inspector_handler(
+        request: ToolCallRequest, state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that records the state it receives."""
+        state_seen.append(state)
+        yield request
+
+    tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
+
+    input_state = {
+        "messages": [
+            AIMessage(
+                "test",
+                tool_calls=[{"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1"}],
+            )
+        ],
+        "other_field": "value",
+    }
+
+    tool_node.invoke(input_state)
+
+    # State should be the dict we passed in
+    assert len(state_seen) == 1
+    assert state_seen[0] == input_state
+    assert isinstance(state_seen[0], dict)
+    assert "messages" in state_seen[0]
+    assert "other_field" in state_seen[0]
+    assert "__type" not in state_seen[0]
+
+
+def test_state_extraction_with_list_input() -> None:
+    """Test that state is correctly passed when input is a list."""
+    state_seen = []
+
+    def state_inspector_handler(
+        request: ToolCallRequest, state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that records the state it receives."""
+        state_seen.append(state)
+        yield request
+
+    tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
+
+    input_state = [
+        AIMessage(
+            "test",
+            tool_calls=[{"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1"}],
+        )
+    ]
+
+    tool_node.invoke(input_state)
+
+    # State should be the list we passed in
+    assert len(state_seen) == 1
+    assert state_seen[0] == input_state
+    assert isinstance(state_seen[0], list)
+
+
+def test_state_extraction_with_tool_call_with_context() -> None:
+    """Test that state is correctly extracted from ToolCallWithContext.
+
+    This tests the scenario where ToolNode is invoked via the Send API in
+    create_agent, which wraps the tool call with additional context including
+    the graph state.
+    """
+    state_seen = []
+
+    def state_inspector_handler(
+        request: ToolCallRequest, state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that records the state it receives."""
+        state_seen.append(state)
+        yield request
+
+    tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
+
+    # Simulate ToolCallWithContext as used by create_agent with Send API
+    actual_state = {
+        "messages": [AIMessage("test")],
+        "thread_model_call_count": 1,
+        "run_model_call_count": 1,
+        "custom_field": "custom_value",
+    }
+
+    tool_call_with_context = {
+        "__type": "tool_call_with_context",
+        "tool_call": {"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1", "type": "tool_call"},
+        "state": actual_state,
+    }
+
+    tool_node.invoke(tool_call_with_context)
+
+    # State should be the extracted state from ToolCallWithContext, not the wrapper
+    assert len(state_seen) == 1
+    assert state_seen[0] == actual_state
+    assert isinstance(state_seen[0], dict)
+    assert "messages" in state_seen[0]
+    assert "thread_model_call_count" in state_seen[0]
+    assert "custom_field" in state_seen[0]
+    # Most importantly, __type should NOT be in the extracted state
+    assert "__type" not in state_seen[0]
+    # And tool_call should not be in the state
+    assert "tool_call" not in state_seen[0]
+
+
+async def test_state_extraction_with_tool_call_with_context_async() -> None:
+    """Test that state is correctly extracted from ToolCallWithContext in async mode."""
+    state_seen = []
+
+    def state_inspector_handler(
+        request: ToolCallRequest, state: Any, _runtime: Any
+    ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
+        """Handler that records the state it receives."""
+        state_seen.append(state)
+        yield request
+
+    tool_node = ToolNode([add], on_tool_call=state_inspector_handler)
+
+    # Simulate ToolCallWithContext as used by create_agent with Send API
+    actual_state = {
+        "messages": [AIMessage("test")],
+        "thread_model_call_count": 1,
+        "run_model_call_count": 1,
+    }
+
+    tool_call_with_context = {
+        "__type": "tool_call_with_context",
+        "tool_call": {"name": "add", "args": {"a": 1, "b": 2}, "id": "call_1", "type": "tool_call"},
+        "state": actual_state,
+    }
+
+    await tool_node.ainvoke(tool_call_with_context)
+
+    # State should be the extracted state from ToolCallWithContext
+    assert len(state_seen) == 1
+    assert state_seen[0] == actual_state
+    assert "__type" not in state_seen[0]
+    assert "tool_call" not in state_seen[0]