Skip to content

Commit ec157f0

Browse files
committed
interrupt
1 parent 1f25512 commit ec157f0

File tree

13 files changed

+279
-105
lines changed

13 files changed

+279
-105
lines changed

src/strands/agent/agent.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,9 @@ def __init__(
349349
self.hooks.add_hook(hook)
350350
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
351351

352+
self.interrupted = False
353+
self.interrupts = {}
354+
352355
@property
353356
def tool(self) -> ToolCaller:
354357
"""Call tool as a function.
@@ -540,6 +543,7 @@ async def stream_async(
540543
Args:
541544
prompt: User input in various formats:
542545
- str: Simple text input
546+
- ContentBlock: Multi-modal content block
543547
- list[ContentBlock]: Multi-modal content blocks
544548
- list[Message]: Complete messages with roles
545549
- None: Use existing conversation history
@@ -564,6 +568,8 @@ async def stream_async(
564568
yield event["data"]
565569
```
566570
"""
571+
self._resume(prompt)
572+
567573
callback_handler = kwargs.get("callback_handler", self.callback_handler)
568574

569575
# Process input and get message to add (if any)
@@ -585,6 +591,11 @@ async def stream_async(
585591

586592
result = AgentResult(*event["stop"])
587593
callback_handler(result=result)
594+
595+
if result.stop_reason == "interrupt":
596+
self.interrupted = True
597+
self.interrupts = {interrupt.name: interrupt for interrupt in result.interrupts}
598+
588599
yield AgentResultEvent(result=result).as_dict()
589600

590601
self._end_agent_trace_span(response=result)
@@ -593,6 +604,16 @@ async def stream_async(
593604
self._end_agent_trace_span(error=e)
594605
raise
595606

607+
def _resume(self, prompt: AgentInput) -> None:
608+
if not self.interrupted:
609+
return
610+
611+
if not isinstance(prompt, dict) or "resume" not in prompt:
612+
raise ValueError("Agent was interrupted; to resume, pass a dict prompt with a 'resume' key mapping interrupt names to resume values.")
613+
614+
for interrupt in self.interrupts.values():
615+
interrupt.resume = prompt["resume"][interrupt.name]
616+
596617
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
597618
"""Execute the agent's event loop with the given message and parameters.
598619
@@ -673,6 +694,8 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
673694
if isinstance(prompt, str):
674695
# String input - convert to user message
675696
messages = [{"role": "user", "content": [{"text": prompt}]}]
697+
elif isinstance(prompt, dict):
698+
messages = [{"role": "user", "content": prompt}] if "resume" not in prompt else []
676699
elif isinstance(prompt, list):
677700
if len(prompt) == 0:
678701
# Empty list
@@ -692,7 +715,9 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
692715
else:
693716
messages = []
694717
if messages is None:
695-
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
718+
raise ValueError(
719+
"Input prompt must be of type: `str | ContentBlock | list[ContentBlock] | Messages | None`."
720+
)
696721
return messages
697722

698723
def _record_tool_execution(

src/strands/agent/agent_result.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Any
7+
from typing import Any, Optional
88

9+
from ..hooks.interrupt import Interrupt
910
from ..telemetry.metrics import EventLoopMetrics
1011
from ..types.content import Message
1112
from ..types.streaming import StopReason
@@ -26,6 +27,7 @@ class AgentResult:
2627
message: Message
2728
metrics: EventLoopMetrics
2829
state: Any
30+
interrupts: Optional[list[Interrupt]] = None
2931

3032
def __str__(self) -> str:
3133
"""Get the agent's last message as a string.

src/strands/event_loop/event_loop.py

Lines changed: 125 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
MessageAddedEvent,
2424
)
2525
from ..telemetry.metrics import Trace
26-
from ..telemetry.tracer import get_tracer
26+
from ..telemetry.tracer import Tracer, get_tracer
2727
from ..tools._validator import validate_and_prepare_tools
2828
from ..types._events import (
2929
EventLoopStopEvent,
@@ -33,6 +33,7 @@
3333
ModelStopReason,
3434
StartEvent,
3535
StartEventLoopEvent,
36+
ToolInterruptEvent,
3637
ToolResultMessageEvent,
3738
TypedEvent,
3839
)
@@ -112,104 +113,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
112113
)
113114
invocation_state["event_loop_cycle_span"] = cycle_span
114115

115-
# Create a trace for the stream_messages call
116-
stream_trace = Trace("stream_messages", parent_id=cycle_trace.id)
117-
cycle_trace.add_child(stream_trace)
118-
119-
# Process messages with exponential backoff for throttling
120-
message: Message
121116
stop_reason: StopReason
122-
usage: Any
123-
metrics: Metrics
124-
125-
# Retry loop for handling throttling exceptions
126-
current_delay = INITIAL_DELAY
127-
for attempt in range(MAX_ATTEMPTS):
128-
model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None
129-
model_invoke_span = tracer.start_model_invoke_span(
130-
messages=agent.messages,
131-
parent_span=cycle_span,
132-
model_id=model_id,
133-
)
134-
with trace_api.use_span(model_invoke_span):
135-
agent.hooks.invoke_callbacks(
136-
BeforeModelInvocationEvent(
137-
agent=agent,
138-
)
139-
)
140-
141-
tool_specs = agent.tool_registry.get_all_tool_specs()
142117

143-
try:
144-
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
145-
if not isinstance(event, ModelStopReason):
146-
yield event
118+
if agent.interrupted:
119+
stop_reason = "tool_use"
120+
else:
121+
events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer)
122+
async for event in events:
123+
if isinstance(event, ModelStopReason):
124+
stop_reason = event["stop"][0]
125+
continue
147126

148-
stop_reason, message, usage, metrics = event["stop"]
149-
invocation_state.setdefault("request_state", {})
127+
yield event
150128

151-
agent.hooks.invoke_callbacks(
152-
AfterModelInvocationEvent(
153-
agent=agent,
154-
stop_response=AfterModelInvocationEvent.ModelStopResponse(
155-
stop_reason=stop_reason,
156-
message=message,
157-
),
158-
)
159-
)
160-
161-
if stop_reason == "max_tokens":
162-
message = recover_message_on_max_tokens_reached(message)
163-
164-
if model_invoke_span:
165-
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
166-
break # Success! Break out of retry loop
167-
168-
except Exception as e:
169-
if model_invoke_span:
170-
tracer.end_span_with_error(model_invoke_span, str(e), e)
171-
172-
agent.hooks.invoke_callbacks(
173-
AfterModelInvocationEvent(
174-
agent=agent,
175-
exception=e,
176-
)
177-
)
178-
179-
if isinstance(e, ModelThrottledException):
180-
if attempt + 1 == MAX_ATTEMPTS:
181-
yield ForceStopEvent(reason=e)
182-
raise e
183-
184-
logger.debug(
185-
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
186-
"| throttling exception encountered "
187-
"| delaying before next retry",
188-
current_delay,
189-
MAX_ATTEMPTS,
190-
attempt + 1,
191-
)
192-
await asyncio.sleep(current_delay)
193-
current_delay = min(current_delay * 2, MAX_DELAY)
194-
195-
yield EventLoopThrottleEvent(delay=current_delay)
196-
else:
197-
raise e
129+
message = agent.messages[-1]
198130

199131
try:
200-
# Add message in trace and mark the end of the stream messages trace
201-
stream_trace.add_message(message)
202-
stream_trace.end()
203-
204-
# Add the response message to the conversation
205-
agent.messages.append(message)
206-
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
207-
yield ModelMessageEvent(message=message)
208-
209-
# Update metrics
210-
agent.event_loop_metrics.update_usage(usage)
211-
agent.event_loop_metrics.update_metrics(metrics)
212-
213132
if stop_reason == "max_tokens":
214133
"""
215134
Handle max_tokens limit reached by the model.
@@ -307,6 +226,105 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
307226
recursive_trace.end()
308227

309228

229+
async def _handle_model_execution(
230+
agent: "Agent",
231+
cycle_span: Any,
232+
cycle_trace: Trace,
233+
invocation_state: dict[str, Any],
234+
tracer: Tracer,
235+
) -> AsyncGenerator[TypedEvent, None]:
236+
"""Stream model output with throttling retries, emitting trace spans and model-invocation hook events."""
237+
# Create a trace for the stream_messages call
238+
stream_trace = Trace("stream_messages", parent_id=cycle_trace.id)
239+
cycle_trace.add_child(stream_trace)
240+
241+
# Retry loop for handling throttling exceptions
242+
current_delay = INITIAL_DELAY
243+
for attempt in range(MAX_ATTEMPTS):
244+
model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None
245+
model_invoke_span = tracer.start_model_invoke_span(
246+
messages=agent.messages,
247+
parent_span=cycle_span,
248+
model_id=model_id,
249+
)
250+
with trace_api.use_span(model_invoke_span):
251+
agent.hooks.invoke_callbacks(
252+
BeforeModelInvocationEvent(
253+
agent=agent,
254+
)
255+
)
256+
257+
tool_specs = agent.tool_registry.get_all_tool_specs()
258+
259+
try:
260+
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
261+
yield event
262+
263+
stop_reason, message, usage, metrics = event["stop"]
264+
invocation_state.setdefault("request_state", {})
265+
266+
agent.hooks.invoke_callbacks(
267+
AfterModelInvocationEvent(
268+
agent=agent,
269+
stop_response=AfterModelInvocationEvent.ModelStopResponse(
270+
stop_reason=stop_reason,
271+
message=message,
272+
),
273+
)
274+
)
275+
276+
if stop_reason == "max_tokens":
277+
message = recover_message_on_max_tokens_reached(message)
278+
279+
if model_invoke_span:
280+
tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
281+
break # Success! Break out of retry loop
282+
283+
except Exception as e:
284+
if model_invoke_span:
285+
tracer.end_span_with_error(model_invoke_span, str(e), e)
286+
287+
agent.hooks.invoke_callbacks(
288+
AfterModelInvocationEvent(
289+
agent=agent,
290+
exception=e,
291+
)
292+
)
293+
294+
if isinstance(e, ModelThrottledException):
295+
if attempt + 1 == MAX_ATTEMPTS:
296+
yield ForceStopEvent(reason=e)
297+
raise e
298+
299+
logger.debug(
300+
"retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
301+
"| throttling exception encountered "
302+
"| delaying before next retry",
303+
current_delay,
304+
MAX_ATTEMPTS,
305+
attempt + 1,
306+
)
307+
await asyncio.sleep(current_delay)
308+
current_delay = min(current_delay * 2, MAX_DELAY)
309+
310+
yield EventLoopThrottleEvent(delay=current_delay)
311+
else:
312+
raise e
313+
314+
# Add message in trace and mark the end of the stream messages trace
315+
stream_trace.add_message(message)
316+
stream_trace.end()
317+
318+
# Update metrics
319+
agent.event_loop_metrics.update_usage(usage)
320+
agent.event_loop_metrics.update_metrics(metrics)
321+
322+
# Add the response message to the conversation
323+
agent.messages.append(message)
324+
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
325+
yield ModelMessageEvent(message=message)
326+
327+
310328
async def _handle_tool_execution(
311329
stop_reason: StopReason,
312330
message: Message,
@@ -345,15 +363,29 @@ async def _handle_tool_execution(
345363
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
346364
return
347365

366+
tool_interrupts = []
348367
tool_events = agent.tool_executor._execute(
349368
agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state
350369
)
351370
async for tool_event in tool_events:
352371
yield tool_event
353372

373+
if isinstance(tool_event, ToolInterruptEvent):
374+
tool_interrupts.append(tool_event["tool_interrupt_event"]["interrupt"])
375+
354376
# Store parent cycle ID for the next cycle
355377
invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"]
356378

379+
if tool_interrupts:
380+
# TODO: deal with metrics and traces
381+
yield EventLoopStopEvent(
382+
"interrupt", message, agent.event_loop_metrics, invocation_state["request_state"], tool_interrupts
383+
)
384+
return
385+
386+
agent.interrupted = False
387+
agent.interrupts = {}
388+
357389
tool_result_message: Message = {
358390
"role": "user",
359391
"content": [{"toolResult": result} for result in tool_results],

0 commit comments

Comments
 (0)