@@ -31,6 +31,15 @@ def tracer():
3131 yield mock_get_tracer .return_value
3232
3333
34+ @pytest .fixture
35+ def cancel_hook (agent ):
36+ def callback (event ):
37+ event .cancel = "Tool execution cancelled by user"
38+ return event
39+
40+ return callback
41+
42+
3443@pytest .mark .asyncio
3544async def test_executor_stream_yields_result (
3645 executor , agent , tool_results , invocation_state , hook_events , weather_tool , alist
@@ -215,3 +224,27 @@ async def test_executor_stream_with_trace(
215224
216225 cycle_trace .add_child .assert_called_once ()
217226 assert isinstance (cycle_trace .add_child .call_args [0 ][0 ], Trace )
227+
228+
229+ @pytest .mark .asyncio
230+ async def test_executor_stream_cancel (executor , agent , cancel_hook , tool_results , invocation_state , alist ):
231+ agent .hooks .add_callback (BeforeToolCallEvent , cancel_hook )
232+ tool_use : ToolUse = {"name" : "weather_tool" , "toolUseId" : "1" , "input" : {}}
233+
234+ stream = executor ._stream (agent , tool_use , tool_results , invocation_state )
235+
236+ tru_events = await alist (stream )
237+ exp_events = [
238+ ToolResultEvent (
239+ {
240+ "toolUseId" : "1" ,
241+ "status" : "error" ,
242+ "content" : [{"text" : "Tool execution cancelled by user" }],
243+ },
244+ ),
245+ ]
246+ assert tru_events == exp_events
247+
248+ tru_results = tool_results
249+ exp_results = [exp_events [- 1 ].tool_result ]
250+ assert tru_results == exp_results
0 commit comments