Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 125 additions & 2 deletions libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@
from langchain.tools import ToolNode

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence

from langchain_core.runnables import Runnable
from langgraph.cache.base import BaseCache
from langgraph.graph.state import CompiledStateGraph
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer

from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest, ToolCallResponse

STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

ResponseT = TypeVar("ResponseT")
Expand Down Expand Up @@ -192,6 +194,112 @@ def _handle_structured_output_error(
return False, ""


def _chain_tool_call_handlers(
    handlers: Sequence[ToolCallHandler],
) -> ToolCallHandler | None:
    """Compose multiple tool call handlers into a single middleware stack.

    Each handler is a generator function: it yields ``ToolCallRequest``s,
    receives ``ToolCallResponse``s via ``send()``, and must explicitly
    ``return`` a final ``ToolCallResponse``. Handlers are nested so that the
    first handler is the outermost layer: requests flow first -> last -> tool,
    responses flow tool -> last -> first.

    Args:
        handlers: Handlers in middleware order (first = outermost layer).

    Returns:
        Single composed handler, or None if handlers is empty.

    Raises:
        ValueError: If any handler finishes without explicitly returning a
            ``ToolCallResponse`` (its ``StopIteration.value`` is None).

    Example:
        Auth middleware (outer) wraps rate limit (inner):

        def auth(req, state, runtime):
            resp = yield req
            if "unauthorized" in str(resp.exception):
                refresh_token()
                resp = yield req  # Retry
            return resp

        def rate_limit(req, state, runtime):
            for attempt in range(3):
                resp = yield req
                if "rate limit" not in str(resp.exception):
                    return resp
                time.sleep(2**attempt)
            return resp

        handler = _chain_tool_call_handlers([auth, rate_limit])
        # Request: auth -> rate_limit -> tool
        # Response: tool -> rate_limit -> auth
    """
    if not handlers:
        return None

    # A single handler needs no composition wrapper.
    if len(handlers) == 1:
        return handlers[0]

    def _extract_return_value(stop_iteration: StopIteration) -> ToolCallResponse:
        """Extract return value from StopIteration, raising if None."""
        # A bare `return` (or falling off the end of the generator) sets
        # StopIteration.value to None -- treat that as a handler bug.
        if stop_iteration.value is None:
            msg = "on_tool_call handler must explicitly return a ToolCallResponse"
            raise ValueError(msg)
        return stop_iteration.value

    def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler:
        """Compose two handlers where outer wraps inner."""

        def composed(
            request: ToolCallRequest, state: Any, runtime: Any
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            outer_gen = outer(request, state, runtime)

            # Initialize outer generator: advance to its first yield to obtain
            # the (possibly modified) request it wants executed.
            try:
                outer_request = next(outer_gen)
            except StopIteration as e:
                # Outer returned without yielding (short-circuit) -- done.
                return _extract_return_value(e)

            # Outer retry loop: one iteration per request the outer handler
            # asks to execute. A fresh inner generator is created for each
            # outer request, so inner's retry state resets between retries.
            while True:
                inner_gen = inner(outer_request, state, runtime)

                # Initialize inner generator
                try:
                    inner_request = next(inner_gen)
                except StopIteration as e:
                    # Inner returned immediately - send to outer
                    inner_response = _extract_return_value(e)
                    try:
                        outer_request = outer_gen.send(inner_response)
                    except StopIteration as e2:
                        return _extract_return_value(e2)
                    continue

                # Inner retry loop: keep executing until inner returns.
                while True:
                    # Yield to actual tool execution
                    tool_response = yield inner_request

                    # Send response to inner
                    try:
                        inner_request = inner_gen.send(tool_response)
                    except StopIteration as e:
                        # Inner is done - send final response to outer
                        inner_response = _extract_return_value(e)
                        break

                # Send inner's final response to outer
                try:
                    outer_request = outer_gen.send(inner_response)
                except StopIteration as e:
                    # Outer is done
                    return _extract_return_value(e)

        return composed

    # Chain all handlers: first -> second -> ... -> last
    # Fold right-to-left so handlers[0] ends up as the outermost wrapper.
    result = handlers[-1]
    for handler in reversed(handlers[:-1]):
        result = compose_two(handler, result)

    return result


def create_agent( # noqa: PLR0915
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
Expand Down Expand Up @@ -319,6 +427,17 @@ def check_weather(location: str) -> str:
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]

# Collect middleware with on_tool_call hooks
middleware_w_on_tool_call = [
m for m in middleware if m.__class__.on_tool_call is not AgentMiddleware.on_tool_call
]

# Chain all on_tool_call handlers into a single composed handler
on_tool_call_handler = None
if middleware_w_on_tool_call:
handlers = [m.on_tool_call for m in middleware_w_on_tool_call]
on_tool_call_handler = _chain_tool_call_handlers(handlers)

# Setup tools
tool_node: ToolNode | None = None
# Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
Expand All @@ -329,7 +448,11 @@ def check_weather(location: str) -> str:
available_tools = middleware_tools + regular_tools

# Only create ToolNode if we have client-side tools
tool_node = ToolNode(tools=available_tools) if available_tools else None
tool_node = (
ToolNode(tools=available_tools, on_tool_call=on_tool_call_handler)
if available_tools
else None
)

# Default tools for ModelRequest initialization
# Use converted BaseTool instances from ToolNode (not raw callables)
Expand Down
46 changes: 45 additions & 1 deletion libs/langchain_v1/langchain/agents/middleware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import Callable
from collections.abc import Callable, Generator
from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from typing import (
Expand All @@ -21,6 +21,8 @@
if TYPE_CHECKING:
from collections.abc import Awaitable

from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse

# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AnyMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
Expand Down Expand Up @@ -236,6 +238,48 @@ async def aafter_agent(
) -> dict[str, Any] | None:
"""Async logic to run after the agent execution completes."""

def on_tool_call(
    self,
    request: ToolCallRequest,
    state: StateT,  # noqa: ARG002
    runtime: Runtime[ContextT],  # noqa: ARG002
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
    """Intercept tool execution for retries, monitoring, or request modification.

    Generator-based hook giving fine-grained control over tool execution.
    When several middleware define ``on_tool_call``, they compose
    automatically with the first one defined acting as the outermost layer.

    Args:
        request: Tool execution request with tool call dict and BaseTool instance.
        state: Current agent state.
        runtime: LangGraph runtime.

    Yields:
        ToolCallRequest to execute (may be modified from input).

    Receives:
        ToolCallResponse via .send() after execution.

    Returns:
        ToolCallResponse with action="return" and result, or action="raise"
        and exception.

    Example:
        Retry on rate limit errors:

        def on_tool_call(self, request, state, runtime):
            for attempt in range(3):
                response = yield request
                if response.action == "return":
                    return response
                if "rate limit" in str(response.exception):
                    time.sleep(2**attempt)
                    continue
                return response
    """
    # Default behavior: pass the request through unchanged and forward
    # the tool's response as the final result.
    result = yield request
    return result


class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with AgentState and Runtime as arguments."""
Expand Down
Loading