|
23 | 23 | MessageAddedEvent, |
24 | 24 | ) |
25 | 25 | from ..telemetry.metrics import Trace |
26 | | -from ..telemetry.tracer import get_tracer |
| 26 | +from ..telemetry.tracer import Tracer, get_tracer |
27 | 27 | from ..tools._validator import validate_and_prepare_tools |
28 | 28 | from ..types._events import ( |
29 | 29 | EventLoopStopEvent, |
|
33 | 33 | ModelStopReason, |
34 | 34 | StartEvent, |
35 | 35 | StartEventLoopEvent, |
| 36 | + ToolInterruptEvent, |
36 | 37 | ToolResultMessageEvent, |
37 | 38 | TypedEvent, |
38 | 39 | ) |
@@ -112,104 +113,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> |
112 | 113 | ) |
113 | 114 | invocation_state["event_loop_cycle_span"] = cycle_span |
114 | 115 |
|
115 | | - # Create a trace for the stream_messages call |
116 | | - stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) |
117 | | - cycle_trace.add_child(stream_trace) |
118 | | - |
119 | | - # Process messages with exponential backoff for throttling |
120 | | - message: Message |
121 | 116 | stop_reason: StopReason |
122 | | - usage: Any |
123 | | - metrics: Metrics |
124 | | - |
125 | | - # Retry loop for handling throttling exceptions |
126 | | - current_delay = INITIAL_DELAY |
127 | | - for attempt in range(MAX_ATTEMPTS): |
128 | | - model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None |
129 | | - model_invoke_span = tracer.start_model_invoke_span( |
130 | | - messages=agent.messages, |
131 | | - parent_span=cycle_span, |
132 | | - model_id=model_id, |
133 | | - ) |
134 | | - with trace_api.use_span(model_invoke_span): |
135 | | - agent.hooks.invoke_callbacks( |
136 | | - BeforeModelInvocationEvent( |
137 | | - agent=agent, |
138 | | - ) |
139 | | - ) |
140 | | - |
141 | | - tool_specs = agent.tool_registry.get_all_tool_specs() |
142 | 117 |
|
143 | | - try: |
144 | | - async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): |
145 | | - if not isinstance(event, ModelStopReason): |
146 | | - yield event |
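| | + # When resuming from an interrupt, skip the model call and go straight back to tool execution. |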
| 118 | + if agent.interrupted: |
| 119 | + stop_reason = "tool_use" |
| 120 | + else: |
| 121 | + events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) |
| 122 | + async for event in events: |
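| | + # Only the stop reason is needed here; _handle_model_execution records the message, usage, and metrics itself. |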
| 123 | + if isinstance(event, ModelStopReason): |
| 124 | + stop_reason = event["stop"][0] |
| 125 | + continue |
147 | 126 |
|
148 | | - stop_reason, message, usage, metrics = event["stop"] |
149 | | - invocation_state.setdefault("request_state", {}) |
| 127 | + yield event |
150 | 128 |
|
151 | | - agent.hooks.invoke_callbacks( |
152 | | - AfterModelInvocationEvent( |
153 | | - agent=agent, |
154 | | - stop_response=AfterModelInvocationEvent.ModelStopResponse( |
155 | | - stop_reason=stop_reason, |
156 | | - message=message, |
157 | | - ), |
158 | | - ) |
159 | | - ) |
160 | | - |
161 | | - if stop_reason == "max_tokens": |
162 | | - message = recover_message_on_max_tokens_reached(message) |
163 | | - |
164 | | - if model_invoke_span: |
165 | | - tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) |
166 | | - break # Success! Break out of retry loop |
167 | | - |
168 | | - except Exception as e: |
169 | | - if model_invoke_span: |
170 | | - tracer.end_span_with_error(model_invoke_span, str(e), e) |
171 | | - |
172 | | - agent.hooks.invoke_callbacks( |
173 | | - AfterModelInvocationEvent( |
174 | | - agent=agent, |
175 | | - exception=e, |
176 | | - ) |
177 | | - ) |
178 | | - |
179 | | - if isinstance(e, ModelThrottledException): |
180 | | - if attempt + 1 == MAX_ATTEMPTS: |
181 | | - yield ForceStopEvent(reason=e) |
182 | | - raise e |
183 | | - |
184 | | - logger.debug( |
185 | | - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " |
186 | | - "| throttling exception encountered " |
187 | | - "| delaying before next retry", |
188 | | - current_delay, |
189 | | - MAX_ATTEMPTS, |
190 | | - attempt + 1, |
191 | | - ) |
192 | | - await asyncio.sleep(current_delay) |
193 | | - current_delay = min(current_delay * 2, MAX_DELAY) |
194 | | - |
195 | | - yield EventLoopThrottleEvent(delay=current_delay) |
196 | | - else: |
197 | | - raise e |
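| | + # The model response (or, when resuming from an interrupt, the prior tool-use message) is already the last message in the conversation. |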
| 129 | + message = agent.messages[-1] |
198 | 130 |
|
199 | 131 | try: |
200 | | - # Add message in trace and mark the end of the stream messages trace |
201 | | - stream_trace.add_message(message) |
202 | | - stream_trace.end() |
203 | | - |
204 | | - # Add the response message to the conversation |
205 | | - agent.messages.append(message) |
206 | | - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) |
207 | | - yield ModelMessageEvent(message=message) |
208 | | - |
209 | | - # Update metrics |
210 | | - agent.event_loop_metrics.update_usage(usage) |
211 | | - agent.event_loop_metrics.update_metrics(metrics) |
212 | | - |
213 | 132 | if stop_reason == "max_tokens": |
214 | 133 | """ |
215 | 134 | Handle max_tokens limit reached by the model. |
@@ -307,6 +226,105 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - |
307 | 226 | recursive_trace.end() |
308 | 227 |
|
309 | 228 |
|
| 229 | +async def _handle_model_execution( |
| 230 | + agent: "Agent", |
| 231 | + cycle_span: Any, |
| 232 | + cycle_trace: Trace, |
| 233 | + invocation_state: dict[str, Any], |
| 234 | + tracer: Tracer, |
| 235 | +) -> AsyncGenerator[TypedEvent, None]: |
| 236 | + """Stream the model response, retrying with exponential backoff when the model throttles.""" |
| 237 | + # Create a trace for the stream_messages call |
| 238 | + stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) |
| 239 | + cycle_trace.add_child(stream_trace) |
| 240 | + |
| 241 | + # Retry loop for handling throttling exceptions |
| 242 | + current_delay = INITIAL_DELAY |
| 243 | + for attempt in range(MAX_ATTEMPTS): |
| 244 | + model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None |
| 245 | + model_invoke_span = tracer.start_model_invoke_span( |
| 246 | + messages=agent.messages, |
| 247 | + parent_span=cycle_span, |
| 248 | + model_id=model_id, |
| 249 | + ) |
| 250 | + with trace_api.use_span(model_invoke_span): |
| 251 | + agent.hooks.invoke_callbacks( |
| 252 | + BeforeModelInvocationEvent( |
| 253 | + agent=agent, |
| 254 | + ) |
| 255 | + ) |
| 256 | + |
| 257 | + tool_specs = agent.tool_registry.get_all_tool_specs() |
| 258 | + |
| 259 | + try: |
| 260 | + async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): |
| 261 | + yield event |
| 262 | + |
| 263 | + stop_reason, message, usage, metrics = event["stop"] |
| 264 | + invocation_state.setdefault("request_state", {}) |
| 265 | + |
| 266 | + agent.hooks.invoke_callbacks( |
| 267 | + AfterModelInvocationEvent( |
| 268 | + agent=agent, |
| 269 | + stop_response=AfterModelInvocationEvent.ModelStopResponse( |
| 270 | + stop_reason=stop_reason, |
| 271 | + message=message, |
| 272 | + ), |
| 273 | + ) |
| 274 | + ) |
| 275 | + |
| 276 | + if stop_reason == "max_tokens": |
| 277 | + message = recover_message_on_max_tokens_reached(message) |
| 278 | + |
| 279 | + if model_invoke_span: |
| 280 | + tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) |
| 281 | + break # Success! Break out of retry loop |
| 282 | + |
| 283 | + except Exception as e: |
| 284 | + if model_invoke_span: |
| 285 | + tracer.end_span_with_error(model_invoke_span, str(e), e) |
| 286 | + |
| 287 | + agent.hooks.invoke_callbacks( |
| 288 | + AfterModelInvocationEvent( |
| 289 | + agent=agent, |
| 290 | + exception=e, |
| 291 | + ) |
| 292 | + ) |
| 293 | + |
| 294 | + if isinstance(e, ModelThrottledException): |
| 295 | + if attempt + 1 == MAX_ATTEMPTS: |
| 296 | + yield ForceStopEvent(reason=e) |
| 297 | + raise e |
| 298 | + |
| 299 | + logger.debug( |
| 300 | + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " |
| 301 | + "| throttling exception encountered " |
| 302 | + "| delaying before next retry", |
| 303 | + current_delay, |
| 304 | + MAX_ATTEMPTS, |
| 305 | + attempt + 1, |
| 306 | + ) |
| 307 | + await asyncio.sleep(current_delay) |
| 308 | + current_delay = min(current_delay * 2, MAX_DELAY) |
| 309 | + |
| 310 | + yield EventLoopThrottleEvent(delay=current_delay) |
| 311 | + else: |
| 312 | + raise e |
| 313 | + |
| 314 | + # Add message in trace and mark the end of the stream messages trace |
| 315 | + stream_trace.add_message(message) |
| 316 | + stream_trace.end() |
| 317 | + |
| 318 | + # Update metrics |
| 319 | + agent.event_loop_metrics.update_usage(usage) |
| 320 | + agent.event_loop_metrics.update_metrics(metrics) |
| 321 | + |
| 322 | + # Add the response message to the conversation |
| 323 | + agent.messages.append(message) |
| 324 | + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) |
| 325 | + yield ModelMessageEvent(message=message) |
| 326 | + |
| 327 | + |
310 | 328 | async def _handle_tool_execution( |
311 | 329 | stop_reason: StopReason, |
312 | 330 | message: Message, |
@@ -345,15 +363,29 @@ async def _handle_tool_execution( |
345 | 363 | yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) |
346 | 364 | return |
347 | 365 |
|
| 366 | + tool_interrupts = [] |
348 | 367 | tool_events = agent.tool_executor._execute( |
349 | 368 | agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state |
350 | 369 | ) |
351 | 370 | async for tool_event in tool_events: |
352 | 371 | yield tool_event |
353 | 372 |
|
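| | + # Track any interrupts raised by tools so the loop can stop and surface them below. |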
| 373 | + if isinstance(tool_event, ToolInterruptEvent): |
| 374 | + tool_interrupts.append(tool_event["tool_interrupt_event"]["interrupt"]) |
| 375 | + |
354 | 376 | # Store parent cycle ID for the next cycle |
355 | 377 | invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] |
356 | 378 |
|
| 379 | + if tool_interrupts: |
| 380 | + # TODO: deal with metrics and traces |
| 381 | + yield EventLoopStopEvent( |
| 382 | + "interrupt", message, agent.event_loop_metrics, invocation_state["request_state"], tool_interrupts |
| 383 | + ) |
| 384 | + return |
| 385 | + |
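| | + # All tools ran to completion, so clear the agent's interrupt state. |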
| 386 | + agent.interrupted = False |
| 387 | + agent.interrupts = {} |
| 388 | + |
357 | 389 | tool_result_message: Message = { |
358 | 390 | "role": "user", |
359 | 391 | "content": [{"toolResult": result} for result in tool_results], |