diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8607a2601..f963f14e7 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -55,6 +55,7 @@ from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException +from ..types.interrupt import InterruptResponseContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -62,6 +63,7 @@ ConversationManager, SlidingWindowConversationManager, ) +from .interrupt import InterruptState from .state import AgentState logger = logging.getLogger(__name__) @@ -143,6 +145,9 @@ def caller( Raises: AttributeError: If the tool doesn't exist. """ + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + normalized_name = self._find_normalized_tool_name(name) # Create unique tool ID and set up the tool request @@ -338,6 +343,8 @@ def __init__( self.hooks = HookRegistry() + self._interrupt_state = InterruptState() + # Initialize session management functionality self._session_manager = session_manager if self._session_manager: @@ -491,6 +498,9 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu Raises: ValueError: If no conversation history or prompt is provided. 
""" + if self._interrupt_state.activated: + raise RuntimeError("cannot call structured output during interrupt") + self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT @@ -573,6 +583,8 @@ async def stream_async( yield event["data"] ``` """ + self._resume_interrupt(prompt) + merged_state = {} if kwargs: warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2) @@ -614,6 +626,38 @@ async def stream_async( self._end_agent_trace_span(error=e) raise + def _resume_interrupt(self, prompt: AgentInput) -> None: + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ + if not self._interrupt_state.activated: + return + + if not isinstance(prompt, list): + raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") + + invalid_types = [ + content_type for content in prompt for content_type in content if content_type != "interruptResponse" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" + ) + + for content in cast(list[InterruptResponseContent], prompt): + interrupt_id = content["interruptResponse"]["interruptId"] + interrupt_response = content["interruptResponse"]["response"] + + if interrupt_id not in self._interrupt_state.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + self._interrupt_state.interrupts[interrupt_id].response = interrupt_response + async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. 
@@ -689,6 +733,9 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A yield event def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: + if self._interrupt_state.activated: + return [] + messages: Messages | None = None if prompt is not None: if isinstance(prompt, str): diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index f3758c8d2..eb9bc4dd9 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,8 +4,9 @@ """ from dataclasses import dataclass -from typing import Any +from typing import Any, Sequence +from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message from ..types.streaming import StopReason @@ -20,12 +21,14 @@ class AgentResult: message: The last message generated by the agent. metrics: Performance metrics collected during processing. state: Additional state information from the event loop. + interrupts: List of interrupts if raised by user. """ stop_reason: StopReason message: Message metrics: EventLoopMetrics state: Any + interrupts: Sequence[Interrupt] | None = None def __str__(self) -> str: """Get the agent's last message as a string. diff --git a/src/strands/agent/interrupt.py b/src/strands/agent/interrupt.py new file mode 100644 index 000000000..3cec1541b --- /dev/null +++ b/src/strands/agent/interrupt.py @@ -0,0 +1,59 @@ +"""Track the state of interrupt events raised by the user for human-in-the-loop workflows.""" + +from dataclasses import asdict, dataclass, field +from typing import Any + +from ..interrupt import Interrupt + + +@dataclass +class InterruptState: + """Track the state of interrupt events raised by the user. + + Note, interrupt state is cleared after resuming. + + Attributes: + interrupts: Interrupts raised by the user. + context: Additional context associated with an interrupt event. + activated: True if agent is in an interrupt state, False otherwise. 
+ """ + + interrupts: dict[str, Interrupt] = field(default_factory=dict) + context: dict[str, Any] = field(default_factory=dict) + activated: bool = False + + def activate(self, context: dict[str, Any] | None = None) -> None: + """Activate the interrupt state. + + Args: + context: Context associated with the interrupt event. + """ + self.context = context or {} + self.activated = True + + def deactivate(self) -> None: + """Deactivate the interrupt state. + + Interrupts and context are cleared. + """ + self.interrupts = {} + self.context = {} + self.activated = False + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InterruptState": + """Initialize interrupt state from serialized interrupt state. + + Interrupt state can be serialized with the `to_dict` method. + """ + return cls( + interrupts={ + interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() + }, + context=data["context"], + activated=data["activated"], + ) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index feb6ac339..7a9c60c3b 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -27,6 +27,7 @@ ModelStopReason, StartEvent, StartEventLoopEvent, + ToolInterruptEvent, ToolResultMessageEvent, TypedEvent, ) @@ -106,13 +107,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) invocation_state["event_loop_cycle_span"] = cycle_span - model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls. 
+ if agent._interrupt_state.activated: + stop_reason: StopReason = "tool_use" + message = agent._interrupt_state.context["tool_use_message"] - stop_reason, message, *_ = model_event["stop"] - yield ModelMessageEvent(message=message) + else: + model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) try: if stop_reason == "max_tokens": @@ -142,6 +149,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> cycle_span=cycle_span, cycle_start_time=cycle_start_time, invocation_state=invocation_state, + tracer=tracer, ) async for tool_event in tool_events: yield tool_event @@ -345,6 +353,7 @@ async def _handle_tool_execution( cycle_span: Any, cycle_start_time: float, invocation_state: dict[str, Any], + tracer: Tracer, ) -> AsyncGenerator[TypedEvent, None]: """Handles the execution of tools requested by the model during an event loop cycle. @@ -356,6 +365,7 @@ async def _handle_tool_execution( cycle_span: Span object for tracing the cycle (type may vary). cycle_start_time: Start time of the current cycle. invocation_state: Additional keyword arguments, including request state. + tracer: Tracer instance for span management. Yields: Tool stream events along with events yielded from a recursive call to the event loop. 
The last event is a tuple @@ -375,15 +385,45 @@ async def _handle_tool_execution( yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return + if agent._interrupt_state.activated: + tool_results.extend(agent._interrupt_state.context["tool_results"]) + + # Filter to only the interrupted tools when resuming from interrupt (tool uses without results) + tool_use_ids = {tool_result["toolUseId"] for tool_result in tool_results} + tool_uses = [tool_use for tool_use in tool_uses if tool_use["toolUseId"] not in tool_use_ids] + + interrupts = [] tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state ) async for tool_event in tool_events: + if isinstance(tool_event, ToolInterruptEvent): + interrupts.extend(tool_event["tool_interrupt_event"]["interrupts"]) + yield tool_event # Store parent cycle ID for the next cycle invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] + if interrupts: + # Session state stored on AfterInvocationEvent. 
+ agent._interrupt_state.activate(context={"tool_use_message": message, "tool_results": tool_results}) + + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + yield EventLoopStopEvent( + "interrupt", + message, + agent.event_loop_metrics, + invocation_state["request_state"], + interrupts, + ) + if cycle_span: + tracer.end_event_loop_cycle_span(span=cycle_span, message=message) + + return + + agent._interrupt_state.deactivate() + tool_result_message: Message = { "role": "user", "content": [{"toolResult": result} for result in tool_results], @@ -394,7 +434,6 @@ async def _handle_tool_execution( yield ToolResultMessageEvent(message=tool_result_message) if cycle_span: - tracer = get_tracer() tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) if invocation_state["request_state"].get("stop_event_loop", False): diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8f611e4e2..de07002c5 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -3,10 +3,14 @@ This module defines the events that are emitted as Agents run through the lifecycle of a request. """ +import uuid from dataclasses import dataclass from typing import Any, Optional +from typing_extensions import override + from ..types.content import Message +from ..types.interrupt import InterruptHookEvent from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import HookEvent @@ -84,7 +88,7 @@ class MessageAddedEvent(HookEvent): @dataclass -class BeforeToolCallEvent(HookEvent): +class BeforeToolCallEvent(HookEvent, InterruptHookEvent): """Event triggered before a tool is invoked. 
This event is fired just before the agent executes a tool, allowing hook @@ -110,6 +114,18 @@ class BeforeToolCallEvent(HookEvent): def _can_write(self, name: str) -> bool: return name in ["cancel_tool", "selected_tool", "tool_use"] + @override + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + + Returns: + Interrupt id. + """ + return f"v1:{self.tool_use['toolUseId']}:{uuid.uuid5(uuid.NAMESPACE_OID, name)}" + @dataclass class AfterToolCallEvent(HookEvent): diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index b8e7f82ab..1cfd5c63e 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -10,6 +10,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from ..interrupt import Interrupt, InterruptException + if TYPE_CHECKING: from ..agent import Agent @@ -184,7 +186,7 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) - def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: + def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: """Invoke all registered callbacks for the given event. This method finds all callbacks registered for the event's type and @@ -192,11 +194,16 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: callbacks are invoked in reverse registration order. Any exceptions raised by callback functions will propagate to the caller. + Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows. + Args: event: The event to dispatch to registered callbacks. Returns: - The event dispatched to registered callbacks. + The event dispatched to registered callbacks and any interrupts raised by the user. + + Raises: + ValueError: If interrupt name is used more than once. 
Example: ```python @@ -204,10 +211,22 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: registry.invoke_callbacks(event) ``` """ + interrupts: dict[str, Interrupt] = {} + for callback in self.get_callbacks_for(event): - callback(event) + try: + callback(event) + except InterruptException as exception: + interrupt = exception.interrupt + if interrupt.name in interrupts: + raise ValueError( + f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" + ) from exception + + # Each callback is allowed to raise their own interrupt. + interrupts[interrupt.name] = interrupt - return event + return event, list(interrupts.values()) def has_callbacks(self) -> bool: """Check if the registry has any registered callbacks. diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py new file mode 100644 index 000000000..f0ed52389 --- /dev/null +++ b/src/strands/interrupt.py @@ -0,0 +1,33 @@ +"""Human-in-the-loop interrupt system for agent workflows.""" + +from dataclasses import asdict, dataclass +from typing import Any + + +@dataclass +class Interrupt: + """Represents an interrupt that can pause agent execution for human-in-the-loop workflows. + + Attributes: + id: Unique identifier. + name: User defined name. + reason: User provided reason for raising the interrupt. + response: Human response provided when resuming the agent after an interrupt. 
+ """ + + id: str + name: str + reason: Any = None + response: Any = None + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + +class InterruptException(Exception): + """Exception raised when human input is required.""" + + def __init__(self, interrupt: Interrupt) -> None: + """Set the interrupt.""" + self.interrupt = interrupt diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 75058b251..e5075de93 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -132,6 +132,8 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: ) agent.state = AgentState(session_agent.state) + session_agent.initialize_internal_state(agent) + # Restore the conversation manager to its previous state, and get the optional prepend messages prepend_messages = agent.conversation_manager.restore_from_session(session_agent.conversation_manager_state) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 6c1bd4eb4..a4f43b149 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -14,7 +14,7 @@ from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer, serialize -from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent +from ...types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse @@ -43,6 +43,7 @@ async def _stream( - Before/after hook execution - Tracing and metrics collection - Error handling and recovery + - Interrupt handling for human-in-the-loop workflows Args: agent: The agent for which the tool is 
being executed. @@ -80,7 +81,7 @@ async def _stream( } ) - before_event = agent.hooks.invoke_callbacks( + before_event, interrupts = agent.hooks.invoke_callbacks( BeforeToolCallEvent( agent=agent, selected_tool=tool_func, @@ -89,6 +90,10 @@ async def _stream( ) ) + if interrupts: + yield ToolInterruptEvent(tool_use, interrupts) + return + if before_event.cancel_tool: cancel_message = ( before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" @@ -100,7 +105,7 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, tool_use=tool_use, @@ -138,7 +143,7 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -169,7 +174,7 @@ async def _stream( result = cast(ToolResult, event) - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -189,7 +194,7 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event = agent.hooks.invoke_callbacks( + after_event, _ = agent.hooks.invoke_callbacks( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -238,6 +243,10 @@ async def _stream_with_trace( async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs): yield event + if isinstance(event, ToolInterruptEvent): + tracer.end_tool_call_span(tool_call_span, tool_result=None) + return + result_event = cast(ToolResultEvent, event) result = result_event.tool_result diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 60e5c7fa7..adbd5a5d3 100644 --- 
a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -5,7 +5,7 @@ from typing_extensions import override from ...telemetry.metrics import Trace -from ...types._events import TypedEvent +from ...types._events import ToolInterruptEvent, TypedEvent from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor @@ -28,6 +28,8 @@ async def _execute( ) -> AsyncGenerator[TypedEvent, None]: """Execute tools sequentially. + Breaks early if an interrupt is raised by the user. + Args: agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. @@ -39,9 +41,17 @@ async def _execute( Yields: Events from the tool execution stream. """ + interrupted = False + for tool_use in tool_uses: events = ToolExecutor._stream_with_trace( agent, tool_use, tool_results, cycle_trace, cycle_span, invocation_state ) async for event in events: + if isinstance(event, ToolInterruptEvent): + interrupted = True + yield event + + if interrupted: + break diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index e20bf658a..13d4a98f9 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,10 +5,11 @@ agent lifecycle. """ -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Sequence, cast from typing_extensions import override +from ..interrupt import Interrupt from ..telemetry import EventLoopMetrics from .citations import Citation from .content import Message @@ -220,6 +221,7 @@ def __init__( message: Message, metrics: "EventLoopMetrics", request_state: Any, + interrupts: Sequence[Interrupt] | None = None, ) -> None: """Initialize with the final execution results. @@ -228,8 +230,9 @@ def __init__( message: Final message from the model metrics: Execution metrics and performance data request_state: Final state of the agent execution + interrupts: Interrupts raised by user during agent execution. 
""" - super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts)}) @property @override @@ -313,12 +316,30 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancelled_event")).get("tool_use")).get("toolUseId")) + return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId")) @property def message(self) -> str: """The tool cancellation message.""" - return cast(str, self["message"]) + return cast(str, self["tool_cancel_event"]["message"]) + + +class ToolInterruptEvent(TypedEvent): + """Event emitted when a tool is interrupted.""" + + def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None: + """Set interrupt in the event payload.""" + super().__init__({"tool_interrupt_event": {"tool_use": tool_use, "interrupts": interrupts}}) + + @property + def tool_use_id(self) -> str: + """The id of the tool interrupted.""" + return cast(str, cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId")) + + @property + def interrupts(self) -> list[Interrupt]: + """The interrupt instances.""" + return cast(list[Interrupt], self["tool_interrupt_event"]["interrupts"]) class ModelMessageEvent(TypedEvent): diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index 151c88f89..a2a4c7dce 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -6,5 +6,6 @@ from typing import TypeAlias from .content import ContentBlock, Messages +from .interrupt import InterruptResponse -AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None +AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponse] | Messages | None diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 
f184f5e59..2a7ad344e 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -40,6 +40,7 @@ class Metrics(TypedDict, total=False): "content_filtered", "end_turn", "guardrail_intervened", + "interrupt", "max_tokens", "stop_sequence", "tool_use", @@ -49,6 +50,7 @@ class Metrics(TypedDict, total=False): - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened +- "interrupt": Agent was interrupted for human input - "max_tokens": Maximum token limit reached - "stop_sequence": Stop sequence encountered - "tool_use": Model requested to use a tool diff --git a/src/strands/types/interrupt.py b/src/strands/types/interrupt.py new file mode 100644 index 000000000..4e9584a70 --- /dev/null +++ b/src/strands/types/interrupt.py @@ -0,0 +1,181 @@ +"""Interrupt related type definitions for human-in-the-loop workflows. + +Interrupt Flow: + ┌─────────────────┐ + │ Agent Invoke │ + └────────┬────────┘ + │ + ▼ + ┌─────────────────┐ + │ Hook Calls │ + | on Event | + └────────┬────────┘ + │ + ▼ + ┌─────────────────┐ No ┌─────────────────┐ + │ Interrupts │ ────────► │ Continue │ + │ Raised? │ │ Execution │ + └────────┬────────┘ └─────────────────┘ + │ Yes + ▼ + ┌─────────────────┐ + │ Stop Event Loop │◄───────────────────┐ + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ | + │ Return | | + | Interrupts │ | + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ | + │ Agent Invoke │ | + │ with Responses │ | + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ | + │ Hook Calls │ | + | on Event | | + | with Responses | | + └────────┬────────┘ | + │ | + ▼ | + ┌─────────────────┐ Yes ┌────────┴────────┐ + │ New Interrupts │ ────────► │ Store State │ + │ Raised? 
│ │ │ + └────────┬────────┘ └─────────────────┘ + │ No + ▼ + ┌─────────────────┐ + │ Continue │ + │ Execution │ + └─────────────────┘ + +Example: + ``` + from typing import Any + + from strands import Agent, tool + from strands.hooks import BeforeToolCallEvent, HookProvider, HookRegistry + + + @tool + def delete_tool(key: str) -> bool: + print("DELETE_TOOL | deleting") + return True + + + class ToolInterruptHook(HookProvider): + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + registry.add_callback(BeforeToolCallEvent, self.approve) + + def approve(self, event: BeforeToolCallEvent) -> None: + if event.tool_use["name"] != "delete_tool": + return + + approval = event.interrupt("for_delete_tool", reason="APPROVAL") + if approval != "A": + event.cancel_tool = "approval was not granted" + + agent = Agent( + hooks=[ToolInterruptHook()], + tools=[delete_tool], + system_prompt="You delete objects given their keys.", + callback_handler=None, + ) + result = agent(f"delete object with key 'X'") + + if result.stop_reason == "interrupt": + responses = [] + for interrupt in result.interrupts: + if interrupt.name == "for_delete_tool": + responses.append({"interruptResponse": {"interruptId": interrupt.id, "response": "A"}}) + + result = agent(responses) + + ... + ``` + +Details: + + - User raises interrupt on their hook event by calling `event.interrupt()`. + - User can raise one interrupt per hook callback. + - Interrupts stop the agent event loop. + - Interrupts are returned to the user in AgentResult. + - User resumes by invoking agent with interrupt responses. + - Second call to `event.interrupt()` returns user response. + - Process repeats if user raises additional interrupts. + - Interrupts are session managed in-between return and user response. 
+""" + +from typing import TYPE_CHECKING, Any, Protocol, TypedDict + +from ..interrupt import Interrupt, InterruptException + +if TYPE_CHECKING: + from ..agent import Agent + + +class InterruptHookEvent(Protocol): + """Interface that adds interrupt support to hook events.""" + + agent: "Agent" + + def interrupt(self, name: str, reason: Any = None, response: Any = None) -> Any: + """Trigger the interrupt with a reason. + + Args: name: User defined name for the interrupt. + Must be unique across hook callbacks. + reason: User provided reason for the interrupt. + response: Preemptive response from user if available. + + Returns: + The response from a human user when resuming from an interrupt state. + + Raises: + InterruptException: If human input is required. + """ + id = self._interrupt_id(name) + state = self.agent._interrupt_state + + interrupt_ = state.interrupts.setdefault(id, Interrupt(id, name, reason, response)) + if interrupt_.response: + return interrupt_.response + + raise InterruptException(interrupt_) + + def _interrupt_id(self, name: str) -> str: + """Unique id for the interrupt. + + Args: + name: User defined name for the interrupt. + reason: User provided reason for the interrupt. + + Returns: + Interrupt id. + """ + ... + + +class InterruptResponse(TypedDict): + """User response to an interrupt. + + Attributes: + interruptId: Unique identifier for the interrupt. + response: User response to the interrupt. + """ + + interruptId: str + response: Any + + +class InterruptResponseContent(TypedDict): + """Content block containing a user response to an interrupt. + + Attributes: + interruptResponse: User response to an interrupt event. 
+ """ + + interruptResponse: InterruptResponse diff --git a/src/strands/types/session.py b/src/strands/types/session.py index e51816f74..926480f2c 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,8 +5,9 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional +from ..agent.interrupt import InterruptState from .content import Message if TYPE_CHECKING: @@ -104,11 +105,20 @@ def to_dict(self) -> dict[str, Any]: @dataclass class SessionAgent: - """Agent that belongs to a Session.""" + """Agent that belongs to a Session. + + Attributes: + agent_id: Unique id for the agent. + state: User managed state. + conversation_manager_state: State for conversation management. + created_at: Created at time. + updated_at: Updated at time. + """ agent_id: str - state: Dict[str, Any] - conversation_manager_state: Dict[str, Any] + state: dict[str, Any] + conversation_manager_state: dict[str, Any] + _internal_state: dict[str, Any] = field(default_factory=dict) # Strands managed state created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) @@ -121,6 +131,9 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": agent_id=agent.agent_id, conversation_manager_state=agent.conversation_manager.get_state(), state=agent.state.get(), + _internal_state={ + "interrupt_state": agent._interrupt_state.to_dict(), + }, ) @classmethod @@ -132,6 +145,11 @@ def to_dict(self) -> dict[str, Any]: """Convert the SessionAgent to a dictionary representation.""" return asdict(self) + def initialize_internal_state(self, agent: "Agent") -> None: + """Initialize internal state of agent.""" + if "interrupt_state" in self._internal_state: + agent._interrupt_state = 
InterruptState.from_dict(self._internal_state["interrupt_state"]) + @dataclass class Session: diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 200584115..ae2d8c7b5 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -17,6 +17,8 @@ from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize @@ -1933,3 +1935,129 @@ async def check_invocation_state(**kwargs): agent("hello!", invocation_state={"my": "state"}) assert len(captured_warnings) == 0 + + +def test_agent__call__resume_interrupt(mock_model, tool_decorated, agenerator): + tool_use_message = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_decorated", + "input": {"random_string": "test input"}, + } + }, + ], + } + agent = Agent( + messages=[tool_use_message], + model=mock_model, + tools=[tool_decorated], + ) + + interrupt = Interrupt( + id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": []}) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + mock_model.mock_stream.return_value = agenerator( + [ + {"contentBlockStart": {"start": {"text": ""}}}, + 
{"contentBlockDelta": {"delta": {"text": "resumed"}}}, + {"contentBlockStop": {}}, + ] + ) + + prompt = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test response", + } + } + ] + agent(prompt) + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "test input"}], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state + + +def test_agent__call__resume_interrupt_invalid_prompt(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"prompt_type=<class 'str'> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + agent("invalid") + + +def test_agent__call__resume_interrupt_invalid_content(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"content_types=<\['text'\]> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + agent([{"text": "invalid"}]) + + +def test_agent__call__resume_interrupt_invalid_id(): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"interrupt_id=<invalid> \| no interrupt found" + with pytest.raises(KeyError, match=exp_message): + agent([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) + + +def test_agent_structured_output_interrupt(user): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"cannot call structured output during interrupt" + with pytest.raises(RuntimeError, match=exp_message): + agent.structured_output(type(user), "invalid") + + +def 
test_agent_tool_caller_interrupt(user): + agent = Agent() + agent._interrupt_state.activated = True + + exp_message = r"cannot directly call tool during interrupt" + with pytest.raises(RuntimeError, match=exp_message): + agent.tool.test_tool() diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 6c5625e0b..32266c3eb 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -124,7 +124,10 @@ def test_agent_tool_call(agent, hook_provider, agent_tool): assert length == 6 assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, @@ -170,7 +173,10 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, @@ -231,7 +237,10 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[1]) assert next(events) == BeforeToolCallEvent( - agent=agent, selected_tool=agent_tool, tool_use=tool_use, invocation_state=ANY + agent=agent, + selected_tool=agent_tool, + tool_use=tool_use, + invocation_state=ANY, ) assert next(events) == AfterToolCallEvent( agent=agent, diff --git a/tests/strands/agent/test_interrupt.py b/tests/strands/agent/test_interrupt.py new file mode 100644 index 000000000..e248c29a6 --- /dev/null +++ b/tests/strands/agent/test_interrupt.py @@ -0,0 +1,61 @@ +import pytest + +from 
strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt + + +@pytest.fixture +def interrupt(): + return Interrupt(id="test_id", name="test_name", reason="test reason") + + +def test_interrupt_activate(): + interrupt_state = InterruptState() + + interrupt_state.activate(context={"test": "context"}) + + assert interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {"test": "context"} + assert tru_context == exp_context + + +def test_interrupt_deactivate(): + interrupt_state = InterruptState(context={"test": "context"}, activated=True) + + interrupt_state.deactivate() + + assert not interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {} + assert tru_context == exp_context + + +def test_interrupt_state_to_dict(interrupt): + interrupt_state = InterruptState(interrupts={"test_id": interrupt}, context={"test": "context"}, activated=True) + + tru_data = interrupt_state.to_dict() + exp_data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + assert tru_data == exp_data + + +def test_interrupt_state_from_dict(): + data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + + tru_state = InterruptState.from_dict(data) + exp_state = InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + assert tru_state == exp_state diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 2b71f3502..89ef477fa 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,12 +6,15 @@ import strands import strands.telemetry +from strands.agent.interrupt 
import InterruptState from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, + BeforeToolCallEvent, HookRegistry, MessageAddedEvent, ) +from strands.interrupt import Interrupt from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry @@ -138,6 +141,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor + mock._interrupt_state = InterruptState() return mock @@ -169,7 +173,7 @@ async def test_event_loop_cycle_text_response( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -201,7 +205,7 @@ async def test_event_loop_cycle_text_response_throttling( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -239,7 +243,7 @@ async def test_event_loop_cycle_exponential_backoff( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] # Verify the final response assert tru_stop_reason == "end_turn" @@ -330,7 +334,7 @@ async def test_event_loop_cycle_tool_result( invocation_state={}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = 
"end_turn" exp_message = {"role": "assistant", "content": [{"text": "test text"}]} @@ -445,7 +449,7 @@ async def test_event_loop_cycle_stop( invocation_state={"request_state": {"stop_event_loop": True}}, ) events = await alist(stream) - tru_stop_reason, tru_message, _, tru_request_state = events[-1]["stop"] + tru_stop_reason, tru_message, _, tru_request_state, _ = events[-1]["stop"] exp_stop_reason = "tool_use" exp_message = { @@ -747,7 +751,7 @@ async def test_request_state_initialization(alist): invocation_state={}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _ = events[-1]["stop"] # Verify request_state was initialized to empty dict assert tru_request_state == {} @@ -759,7 +763,7 @@ async def test_request_state_initialization(alist): invocation_state={"request_state": initial_request_state}, ) events = await alist(stream) - _, _, _, tru_request_state = events[-1]["stop"] + _, _, _, tru_request_state, _ = events[-1]["stop"] # Verify existing request_state was preserved assert tru_request_state == initial_request_state @@ -862,3 +866,147 @@ async def test_event_loop_cycle_exception_model_hooks(mock_sleep, agent, model, assert next(events) == MessageAddedEvent( agent=agent, message={"content": [{"text": "test text"}], "role": "assistant"} ) + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt(agent, model, tool_stream, agenerator, alist): + def interrupt_callback(event): + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator(tool_stream)] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, tru_interrupts = events[-1]["stop"] + exp_stop_reason = "interrupt" + exp_interrupts = [ + Interrupt( + id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + 
), + ] + + assert tru_stop_reason == exp_stop_reason and tru_interrupts == exp_interrupts + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": True, + "context": { + "tool_results": [], + "tool_use_message": { + "content": [ + { + "toolUse": { + "input": {"random_string": "abcdEfghI123"}, + "name": "tool_for_testing", + "toolUseId": "t1", + }, + }, + ], + "role": "assistant", + }, + }, + "interrupts": { + "v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9": { + "id": "v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + "name": "test_name", + "reason": "test reason", + "response": None, + }, + }, + } + assert tru_state == exp_state + + +@pytest.mark.asyncio +async def test_event_loop_cycle_interrupt_resume(agent, model, tool, tool_times_2, agenerator, alist): + interrupt = Interrupt( + id="v1:t1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + + tool_use_message = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "t1", + "name": "tool_for_testing", + "input": {"random_string": "test input"}, + } + }, + { + "toolUse": { + "toolUseId": "t2", + "name": "tool_times_2", + "input": {}, + } + }, + ], + } + tool_results = [ + { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + ] + + agent._interrupt_state.activate(context={"tool_use_message": tool_use_message, "tool_results": tool_results}) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + model.stream.side_effect = [agenerator([{"contentBlockStop": {}}])] + + stream = strands.event_loop.event_loop.event_loop_cycle(agent, invocation_state={}) + events = await alist(stream) + + tru_stop_reason, _, _, _, _ = events[-1]["stop"] + exp_stop_reason = "end_turn" + 
assert tru_stop_reason == exp_stop_reason + + tru_result_message = agent.messages[-2] + exp_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t2", + "status": "success", + "content": [{"text": "t2 result"}], + }, + }, + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "test input"}], + }, + }, + ], + } + assert tru_result_message == exp_result_message + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response + + tru_state = agent._interrupt_state.to_dict() + exp_state = { + "activated": False, + "context": {}, + "interrupts": {}, + } + assert tru_state == exp_state diff --git a/tests/strands/hooks/__init__.py b/tests/strands/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py new file mode 100644 index 000000000..807011869 --- /dev/null +++ b/tests/strands/hooks/test_registry.py @@ -0,0 +1,73 @@ +import unittest.mock + +import pytest + +from strands.agent.interrupt import InterruptState +from strands.hooks import BeforeToolCallEvent, HookRegistry +from strands.interrupt import Interrupt + + +@pytest.fixture +def registry(): + return HookRegistry() + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupt_state = InterruptState() + return instance + + +def test_hook_registry_invoke_callbacks_interrupt(registry, agent): + event = BeforeToolCallEvent( + agent=agent, + selected_tool=None, + tool_use={"toolUseId": "test_tool_id", "name": "test_tool_name", "input": {}}, + invocation_state={}, + ) + + callback1 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name_1", "test reason 1")) + callback2 = unittest.mock.Mock() + callback3 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name_2", "test reason 2")) + + registry.add_callback(BeforeToolCallEvent, 
callback1) + registry.add_callback(BeforeToolCallEvent, callback2) + registry.add_callback(BeforeToolCallEvent, callback3) + + _, tru_interrupts = registry.invoke_callbacks(event) + exp_interrupts = [ + Interrupt( + id="v1:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", + name="test_name_1", + reason="test reason 1", + ), + Interrupt( + id="v1:test_tool_id:0f5a8068-d1ba-5a48-bf67-c9d33786d8d4", + name="test_name_2", + reason="test reason 2", + ), + ] + assert tru_interrupts == exp_interrupts + + callback1.assert_called_once_with(event) + callback2.assert_called_once_with(event) + callback3.assert_called_once_with(event) + + +def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): + event = BeforeToolCallEvent( + agent=agent, + selected_tool=None, + tool_use={"toolUseId": "test_tool_id", "name": "test_tool_name", "input": {}}, + invocation_state={}, + ) + + callback1 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name", "test reason 1")) + callback2 = unittest.mock.Mock(side_effect=lambda event: event.interrupt("test_name", "test reason 2")) + + registry.add_callback(BeforeToolCallEvent, callback1) + registry.add_callback(BeforeToolCallEvent, callback2) + + with pytest.raises(ValueError, match="interrupt_name=<test_name> | interrupt name used more than once"): + registry.invoke_callbacks(event) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 2c25fcc38..923b13daa 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -5,6 +5,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.agent.interrupt import InterruptState from 
strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException @@ -95,6 +96,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): agent_id="existing-agent", state={"key": "value"}, conversation_manager_state=SlidingWindowConversationManager().get_state(), + _internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, ) session_manager.session_repository.create_agent("test-session", session_agent) @@ -116,6 +118,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert len(agent.messages) == 1 assert agent.messages[0]["role"] == "user" assert agent.messages[0]["content"][0]["text"] == "Hello" + assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False) def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py new file mode 100644 index 000000000..8ce972103 --- /dev/null +++ b/tests/strands/test_interrupt.py @@ -0,0 +1,24 @@ +import pytest + +from strands.interrupt import Interrupt + + +@pytest.fixture +def interrupt(): + return Interrupt( + id="test_id:test_name", + name="test_name", + reason={"reason": "test"}, + response={"response": "test"}, + ) + + +def test_interrupt_to_dict(interrupt): + tru_dict = interrupt.to_dict() + exp_dict = { + "id": "test_id:test_name", + "name": "test_name", + "reason": {"reason": "test"}, + "response": {"response": "test"}, + } + assert tru_dict == exp_dict diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index be90226f6..fa8ce10af 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,6 +4,7 @@ import pytest import strands +from strands.agent.interrupt import 
InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry from strands.tools.registry import ToolRegistry @@ -92,6 +93,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry + mock_agent._interrupt_state = InterruptState() return mock_agent diff --git a/tests/strands/tools/executors/test_concurrent.py b/tests/strands/tools/executors/test_concurrent.py index f7fc64b25..7264c8e58 100644 --- a/tests/strands/tools/executors/test_concurrent.py +++ b/tests/strands/tools/executors/test_concurrent.py @@ -1,8 +1,9 @@ import pytest +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.tools.executors import ConcurrentToolExecutor -from strands.types._events import ToolResultEvent -from strands.types.tools import ToolUse +from strands.types._events import ToolInterruptEvent, ToolResultEvent @pytest.fixture @@ -14,7 +15,7 @@ def executor(): async def test_concurrent_executor_execute( executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist ): - tool_uses: list[ToolUse] = [ + tool_uses = [ {"name": "weather_tool", "toolUseId": "1", "input": {}}, {"name": "temperature_tool", "toolUseId": "2", "input": {}}, ] @@ -30,3 +31,38 @@ async def test_concurrent_executor_execute( tru_results = sorted(tool_results, key=lambda result: result.get("toolUseId")) exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_concurrent_executor_interrupt( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + interrupt = Interrupt( + id="v1:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + if event.tool_use["name"] == "weather_tool": + event.interrupt("test_name", "test reason") + + 
agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + tool_uses = [ + {"name": "weather_tool", "toolUseId": "test_tool_id_1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, + ] + + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = sorted(await alist(stream), key=lambda event: event.tool_use_id) + exp_events = [ + ToolInterruptEvent(tool_uses[0], [interrupt]), + ToolResultEvent({"toolUseId": "test_tool_id_2", "status": "success", "content": [{"text": "75F"}]}), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[1].tool_result] + assert tru_results == exp_results diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 81be34969..fd15c9747 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -5,9 +5,10 @@ import strands from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.telemetry.metrics import Trace from strands.tools.executors._executor import ToolExecutor -from strands.types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent +from strands.types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import ToolUse @@ -36,6 +37,7 @@ async def test_executor_stream_yields_result( executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist ): tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) tru_events = await alist(stream) @@ -337,3 +339,71 @@ async def test_executor_stream_no_span_attributes_when_no_tool_spec( # Verify set_attribute was not called since tool_spec is None mock_span.set_attribute.assert_not_called() + + 
+@pytest.mark.asyncio +async def test_executor_stream_interrupt(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + event.interrupt("test_name", reason="test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_use, [interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_executor_stream_interrupt_resume(executor, agent, tool_results, invocation_state, alist): + tool_use = {"name": "weather_tool", "toolUseId": "test_tool_id", "input": {}} + + interrupt = Interrupt( + id="v1:test_tool_id:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + response="test response", + ) + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + interrupt_response = {} + + def interrupt_callback(event): + interrupt_response["response"] = event.interrupt("test_name", reason="test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent( + { + "toolUseId": "test_tool_id", + "status": "success", + "content": [{"text": "sunny"}], + }, + ), + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results + + tru_response = interrupt_response["response"] + exp_response = "test response" + assert tru_response == exp_response diff --git 
a/tests/strands/tools/executors/test_sequential.py b/tests/strands/tools/executors/test_sequential.py index 37e098142..c1db3cd55 100644 --- a/tests/strands/tools/executors/test_sequential.py +++ b/tests/strands/tools/executors/test_sequential.py @@ -1,7 +1,9 @@ import pytest +from strands.hooks import BeforeToolCallEvent +from strands.interrupt import Interrupt from strands.tools.executors import SequentialToolExecutor -from strands.types._events import ToolResultEvent +from strands.types._events import ToolInterruptEvent, ToolResultEvent @pytest.fixture @@ -29,3 +31,34 @@ async def test_sequential_executor_execute( tru_results = tool_results exp_results = [exp_events[0].tool_result, exp_events[1].tool_result] assert tru_results == exp_results + + +@pytest.mark.asyncio +async def test_sequential_executor_interrupt( + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist +): + interrupt = Interrupt( + id="v1:test_tool_id_1:78714d6c-613c-5cf4-bf25-7037569941f9", + name="test_name", + reason="test reason", + ) + + def interrupt_callback(event): + event.interrupt("test_name", "test reason") + + agent.hooks.add_callback(BeforeToolCallEvent, interrupt_callback) + + tool_uses = [ + {"name": "weather_tool", "toolUseId": "test_tool_id_1", "input": {}}, + {"name": "temperature_tool", "toolUseId": "test_tool_id_2", "input": {}}, + ] + + stream = executor._execute(agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state) + + tru_events = await alist(stream) + exp_events = [ToolInterruptEvent(tool_uses[0], [interrupt])] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [] + assert tru_results == exp_results diff --git a/tests/strands/types/__init__.py b/tests/strands/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py new file mode 100644 index 000000000..3b970a00a --- /dev/null +++ 
b/tests/strands/types/test_interrupt.py @@ -0,0 +1,80 @@ +import unittest.mock + +import pytest + +from strands.agent.interrupt import InterruptState +from strands.interrupt import Interrupt, InterruptException +from strands.types.interrupt import InterruptHookEvent + + +@pytest.fixture +def interrupt(): + return Interrupt( + id="test_id:test_name", + name="test_name", + reason={"reason": "test"}, + response={"response": "test"}, + ) + + +@pytest.fixture +def agent(): + instance = unittest.mock.Mock() + instance._interrupt_state = InterruptState() + return instance + + +@pytest.fixture +def interrupt_hook_event(agent): + class Event(InterruptHookEvent): + def __init__(self): + self.agent = agent + + def _interrupt_id(self, name): + return f"test_id:{name}" + + return Event() + + +def test_interrupt_hook_event_interrupt(interrupt_hook_event): + with pytest.raises(InterruptException) as exception: + interrupt_hook_event.interrupt("custom_test_name", "custom test reason") + + tru_interrupt = exception.value.interrupt + exp_interrupt = Interrupt( + id="test_id:custom_test_name", + name="custom_test_name", + reason="custom test reason", + ) + assert tru_interrupt == exp_interrupt + + +def test_interrupt_hook_event_interrupt_state(agent, interrupt_hook_event): + with pytest.raises(InterruptException): + interrupt_hook_event.interrupt("custom_test_name", "custom test reason") + + exp_interrupt = Interrupt( + id="test_id:custom_test_name", + name="custom_test_name", + reason="custom test reason", + ) + assert exp_interrupt.id in agent._interrupt_state.interrupts + + tru_interrupt = agent._interrupt_state.interrupts[exp_interrupt.id] + assert tru_interrupt == exp_interrupt + + +def test_interrupt_hook_event_interrupt_response(interrupt, agent, interrupt_hook_event): + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + tru_response = interrupt_hook_event.interrupt("test_name") + exp_response = {"response": "test"} + assert tru_response == exp_response + + +def 
test_interrupt_hook_event_interrupt_response_empty(interrupt, agent, interrupt_hook_event): + interrupt.response = None + agent._interrupt_state.interrupts[interrupt.id] = interrupt + + with pytest.raises(InterruptException): + interrupt_hook_event.interrupt("test_name") diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index c39615c32..26d4062e4 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -1,7 +1,10 @@ import json +import unittest.mock from uuid import uuid4 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager +from strands.agent.interrupt import InterruptState +from strands.agent.state import AgentState from strands.types.session import ( Session, SessionAgent, @@ -91,3 +94,38 @@ def test_session_message_with_bytes(): assert original_message["role"] == message["role"] assert original_message["content"][0]["text"] == message["content"][0]["text"] assert original_message["content"][1]["binary_data"] == message["content"][1]["binary_data"] + + +def test_session_agent_from_agent(): + agent = unittest.mock.Mock() + agent.agent_id = "a1" + agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) + agent.state = AgentState({"test": "state"}) + agent._interrupt_state = InterruptState(interrupts={}, context={}, activated=False) + + tru_session_agent = SessionAgent.from_agent(agent) + exp_session_agent = SessionAgent( + agent_id="a1", + conversation_manager_state={"test": "conversation"}, + state={"test": "state"}, + _internal_state={"interrupt_state": {"interrupts": {}, "context": {}, "activated": False}}, + created_at=unittest.mock.ANY, + updated_at=unittest.mock.ANY, + ) + assert tru_session_agent == exp_session_agent + + +def test_session_agent_initialize_internal_state(): + agent = unittest.mock.Mock() + session_agent = SessionAgent( + agent_id="a1", + conversation_manager_state={}, + state={}, + 
_internal_state={"interrupt_state": {"interrupts": {}, "context": {"test": "init"}, "activated": False}}, + ) + + session_agent.initialize_internal_state(agent) + + tru_interrupt_state = agent._interrupt_state + exp_interrupt_state = InterruptState(interrupts={}, context={"test": "init"}, activated=False) + assert tru_interrupt_state == exp_interrupt_state diff --git a/tests_integ/test_interrupt.py b/tests_integ/test_interrupt.py new file mode 100644 index 000000000..164dfdede --- /dev/null +++ b/tests_integ/test_interrupt.py @@ -0,0 +1,192 @@ +import json +from unittest.mock import ANY + +import pytest + +from strands import Agent, tool +from strands.hooks import BeforeToolCallEvent, HookProvider +from strands.interrupt import Interrupt +from strands.session import FileSessionManager + + +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeToolCallEvent, self.interrupt) + + def interrupt(self, event): + if event.tool_use["name"] == "weather_tool": + return + + response = event.interrupt("test_interrupt", "need approval") + if response != "APPROVE": + event.cancel_tool = "tool rejected" + + return Hook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:00" + + return func + + +@pytest.fixture +def weather_tool(): + @tool(name="weather_tool") + def func(): + return "sunny" + + return func + + +@pytest.fixture +def agent(interrupt_hook, time_tool, weather_tool): + return Agent(hooks=[interrupt_hook], tools=[time_tool, weather_tool]) + + +@pytest.mark.asyncio +def test_interrupt(agent): + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + tru_interrupts = result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + 
interrupt = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:00", "sunny"]) + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "sunny"}, + ], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [ + {"text": "12:00"}, + ], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message + + +@pytest.mark.asyncio +def test_interrupt_reject(agent): + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + interrupt = result.interrupts[0] + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "REJECT", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + tru_tool_result_message = agent.messages[-2] + exp_tool_result_message = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": ANY, + "status": "success", + "content": [{"text": "sunny"}], + }, + }, + { + "toolResult": { + "toolUseId": ANY, + "status": "error", + "content": [{"text": "tool rejected"}], + }, + }, + ], + } + assert tru_tool_result_message == exp_tool_result_message + + +@pytest.mark.asyncio +def test_interrupt_session(interrupt_hook, time_tool, weather_tool, tmpdir): + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = 
Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + result = agent("What is the time and weather?") + + tru_stop_reason = result.stop_reason + exp_stop_reason = "interrupt" + assert tru_stop_reason == exp_stop_reason + + interrupt = result.interrupts[0] + + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + agent = Agent(hooks=[interrupt_hook], session_manager=session_manager, tools=[time_tool, weather_tool]) + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + result = agent(responses) + + tru_stop_reason = result.stop_reason + exp_stop_reason = "end_turn" + assert tru_stop_reason == exp_stop_reason + + result_message = json.dumps(result.message).lower() + assert all(string in result_message for string in ["12:00", "sunny"])