Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 111 additions & 2 deletions libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.runtime import Runtime # noqa: TC002
from langgraph.types import Send
from langgraph.types import Command, Send
from langgraph.typing import ContextT # noqa: TC002
from typing_extensions import NotRequired, Required, TypedDict, TypeVar

Expand Down Expand Up @@ -56,6 +56,8 @@
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer

from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest

STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

ResponseT = TypeVar("ResponseT")
Expand Down Expand Up @@ -410,6 +412,98 @@ def _handle_structured_output_error(
return False, ""


def _chain_tool_call_handlers(
    handlers: Sequence[ToolCallHandler],
) -> ToolCallHandler | None:
    """Compose ``on_tool_call`` handlers into a middleware stack (first = outermost).

    Each handler is a generator that yields either a ``ToolCallRequest`` (execute
    the tool), a ``ToolMessage`` (short-circuit with a ready-made result), or a
    ``Command`` (control flow), and receives the downstream result back via
    ``.send()``. The composed handler threads one handler's yields through the
    next, so requests flow outer -> inner -> tool and responses flow back
    tool -> inner -> outer.

    Args:
        handlers: Handlers in middleware order (first handler is outermost).

    Returns:
        A single composed handler, or ``None`` if ``handlers`` is empty.

    Raises:
        ValueError: If a handler generator terminates without yielding, or an
            inner handler finishes without ever receiving a result.

    Example:
        handler = _chain_tool_call_handlers([auth, cache, retry])
        # Request flows: auth -> cache -> retry -> tool
        # Response flows: tool -> retry -> cache -> auth
    """
    if not handlers:
        return None

    if len(handlers) == 1:
        return handlers[0]

    def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler:
        """Compose two handlers where `outer` wraps `inner`."""

        def composed(
            request: ToolCallRequest, state: Any, runtime: Any
        ) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
            outer_gen = outer(request, state, runtime)

            # Prime the outer generator to obtain its first request/result.
            # `from None` suppresses the misleading StopIteration context.
            try:
                outer_request_or_result = next(outer_gen)
            except StopIteration:
                msg = "outer handler must yield at least once"
                raise ValueError(msg) from None

            # Outer retry loop: each iteration forwards one outer request
            # through the inner handler (or short-circuits past it).
            while True:
                # A ToolMessage/Command from outer bypasses inner entirely and
                # is yielded straight to the caller as a ready-made result.
                if isinstance(outer_request_or_result, (ToolMessage, Command)):
                    result = yield outer_request_or_result
                    try:
                        outer_request_or_result = outer_gen.send(result)
                    except StopIteration:
                        # Outer ended - the result we just forwarded stands.
                        return
                    continue

                inner_gen = inner(outer_request_or_result, state, runtime)
                last_sent_to_inner: ToolMessage | Command | None = None

                # Prime the inner generator.
                try:
                    inner_request_or_result = next(inner_gen)
                except StopIteration:
                    msg = "inner handler must yield at least once"
                    raise ValueError(msg) from None

                # Inner retry loop: drive inner until it is exhausted.
                while True:
                    # Yield to actual tool execution (or the next layer out).
                    result = yield inner_request_or_result
                    last_sent_to_inner = result

                    # Hand the result back to inner; it may retry or finish.
                    try:
                        inner_request_or_result = inner_gen.send(result)
                    except StopIteration:
                        # Inner is done - its final result is last_sent_to_inner.
                        break

                # Defensive: unreachable unless the caller sent None as a result.
                if last_sent_to_inner is None:
                    msg = "inner handler ended without receiving any result"
                    raise ValueError(msg)

                # Give inner's final result to outer so it can retry or finish.
                try:
                    outer_request_or_result = outer_gen.send(last_sent_to_inner)
                except StopIteration:
                    # Outer is done - the result already delivered stands.
                    return

        return composed

    # Fold right-to-left so handlers[0] ends up outermost:
    # first -> second -> ... -> last
    composed_handler: ToolCallHandler = handlers[-1]
    for handler in reversed(handlers[:-1]):
        composed_handler = compose_two(handler, composed_handler)

    return composed_handler


def create_agent( # noqa: PLR0915
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
Expand Down Expand Up @@ -537,6 +631,17 @@ def check_weather(location: str) -> str:
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]

# Collect middleware with on_tool_call hooks
middleware_w_on_tool_call = [
m for m in middleware if m.__class__.on_tool_call is not AgentMiddleware.on_tool_call
]

# Chain all on_tool_call handlers into a single composed handler
on_tool_call_handler = None
if middleware_w_on_tool_call:
handlers = [m.on_tool_call for m in middleware_w_on_tool_call]
on_tool_call_handler = _chain_tool_call_handlers(handlers)

# Setup tools
tool_node: ToolNode | None = None
# Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
Expand All @@ -547,7 +652,11 @@ def check_weather(location: str) -> str:
available_tools = middleware_tools + regular_tools

# Only create ToolNode if we have client-side tools
tool_node = ToolNode(tools=available_tools) if available_tools else None
tool_node = (
ToolNode(tools=available_tools, on_tool_call=on_tool_call_handler)
if available_tools
else None
)

# Default tools for ModelRequest initialization
# Use converted BaseTool instances from ToolNode (not raw callables)
Expand Down
50 changes: 46 additions & 4 deletions libs/langchain_v1/langchain/agents/middleware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@
if TYPE_CHECKING:
from collections.abc import Awaitable

from langchain.tools.tool_node import ToolCallRequest

# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AIMessage, AnyMessage # noqa: TC002
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.graph.message import add_messages
from langgraph.types import Command # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar

if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from langgraph.types import Command

from langchain.agents.structured_output import ResponseFormat

Expand Down Expand Up @@ -261,6 +263,46 @@ async def aafter_agent(
) -> dict[str, Any] | None:
"""Async logic to run after the agent execution completes."""

def on_tool_call(
    self,
    request: ToolCallRequest,
    state: StateT,  # noqa: ARG002
    runtime: Runtime[ContextT],  # noqa: ARG002
) -> Generator[ToolCallRequest | ToolMessage | Command, ToolMessage | Command, None]:
    """Intercept tool execution for retries, monitoring, or modification.

    Generator-based hook: yield a value to hand control downstream, and
    receive the tool's result back via ``.send()``. Multiple middleware
    compose automatically (first defined = outermost). Exceptions propagate
    unless handle_tool_errors is configured on ToolNode.

    Args:
        request: Tool call request with call dict and BaseTool instance.
        state: Current agent state.
        runtime: LangGraph runtime.

    Yields:
        ToolCallRequest (execute tool), ToolMessage (cached result),
        or Command (control flow).

    Receives:
        ToolMessage or Command via .send() after execution.

    Example:
        Modify request:

        def on_tool_call(self, request, state, runtime):
            request.tool_call["args"]["value"] *= 2
            yield request

        Retry on error:

        def on_tool_call(self, request, state, runtime):
            for attempt in range(3):
                response = yield request
                if valid(response) or attempt == 2:
                    return
    """
    # Default implementation: pass the request through unchanged so the tool
    # executes normally; the result sent back is intentionally ignored.
    yield request


class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with AgentState and Runtime as arguments."""
Expand Down Expand Up @@ -402,7 +444,7 @@ def before_model(

Returns:
Either an AgentMiddleware instance (if func is provided directly) or a decorator function
that can be applied to a function its wrapping.
that can be applied to a function it is wrapping.

The decorated function should return:
- `dict[str, Any]` - State updates to merge into the agent state
Expand Down Expand Up @@ -812,7 +854,7 @@ def before_agent(

Returns:
Either an AgentMiddleware instance (if func is provided directly) or a decorator function
that can be applied to a function its wrapping.
that can be applied to a function it is wrapping.

The decorated function should return:
- `dict[str, Any]` - State updates to merge into the agent state
Expand Down
Loading