Skip to content

Commit a6d0622

Browse files
committed
hooks - before tool call event - cancel tool
1 parent 2493545 commit a6d0622

File tree

6 files changed

+95
-2
lines changed

6 files changed

+95
-2
lines changed

src/strands/hooks/events.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,17 @@ class BeforeToolCallEvent(HookEvent):
9797
to change which tool gets executed. This may be None if tool lookup failed.
9898
tool_use: The tool parameters that will be passed to selected_tool.
9999
invocation_state: Keyword arguments that will be passed to the tool.
100+
cancel: A user defined message that when set, will cancel the tool call.
101+
The message will be placed into a tool result with an error status.
100102
"""
101103

102104
selected_tool: Optional[AgentTool]
103105
tool_use: ToolUse
104106
invocation_state: dict[str, Any]
107+
cancel: Optional[str] = None
105108

106109
def _can_write(self, name: str) -> bool:
107-
return name in ["selected_tool", "tool_use"]
110+
return name in ["cancel", "selected_tool", "tool_use"]
108111

109112

110113
@dataclass

src/strands/tools/executors/_executor.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,24 @@ async def _stream(
8181
)
8282
)
8383

84+
if before_event.cancel:
85+
after_event = agent.hooks.invoke_callbacks(
86+
AfterToolCallEvent(
87+
agent=agent,
88+
tool_use=tool_use,
89+
invocation_state=invocation_state,
90+
result={
91+
"toolUseId": str(tool_use.get("toolUseId")),
92+
"status": "error",
93+
"content": [{"text": before_event.cancel}],
94+
},
95+
selected_tool=None,
96+
)
97+
)
98+
yield ToolResultEvent(after_event.result)
99+
tool_results.append(after_event.result)
100+
return
101+
84102
try:
85103
selected_tool = before_event.selected_tool
86104
tool_use = before_event.tool_use
@@ -123,7 +141,7 @@ async def _stream(
123141
# so that we don't needlessly yield ToolStreamEvents for non-generator callbacks.
124142
# In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent
125143
# we yield it directly; all other cases (non-sdk AgentTools), we wrap events in
126-
# ToolStreamEvent and the last even is just the result
144+
# ToolStreamEvent and the last event is just the result.
127145

128146
if isinstance(event, ToolResultEvent):
129147
# below the last "event" must point to the tool_result

tests/strands/tools/executors/test_executor.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ def tracer():
3131
yield mock_get_tracer.return_value
3232

3333

34+
@pytest.fixture
35+
def cancel_hook(agent):
36+
def callback(event):
37+
event.cancel = "Tool execution cancelled by user"
38+
return event
39+
40+
return callback
41+
42+
3443
@pytest.mark.asyncio
3544
async def test_executor_stream_yields_result(
3645
executor, agent, tool_results, invocation_state, hook_events, weather_tool, alist
@@ -215,3 +224,27 @@ async def test_executor_stream_with_trace(
215224

216225
cycle_trace.add_child.assert_called_once()
217226
assert isinstance(cycle_trace.add_child.call_args[0][0], Trace)
227+
228+
229+
@pytest.mark.asyncio
230+
async def test_executor_stream_cancel(executor, agent, cancel_hook, tool_results, invocation_state, alist):
231+
agent.hooks.add_callback(BeforeToolCallEvent, cancel_hook)
232+
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
233+
234+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
235+
236+
tru_events = await alist(stream)
237+
exp_events = [
238+
ToolResultEvent(
239+
{
240+
"toolUseId": "1",
241+
"status": "error",
242+
"content": [{"text": "Tool execution cancelled by user"}],
243+
},
244+
),
245+
]
246+
assert tru_events == exp_events
247+
248+
tru_results = tool_results
249+
exp_results = [exp_events[-1].tool_result]
250+
assert tru_results == exp_results
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from strands.hooks import BeforeToolCallEvent, HookProvider
4+
5+
6+
@pytest.fixture
7+
def cancel_hook():
8+
class Hook(HookProvider):
9+
def register_hooks(self, registry):
10+
registry.add_callback(BeforeToolCallEvent, self.cancel)
11+
12+
def cancel(self, event):
13+
event.cancel = "cancelled tool call"
14+
15+
return Hook()

tests_integ/tools/executors/test_concurrent.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23

34
import pytest
45

@@ -59,3 +60,14 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
5960
{"name": "time_tool", "event": "end"},
6061
]
6162
assert tru_events == exp_events
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_agent_invoke_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
67+
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])
68+
69+
await agent.invoke_async("What is the time in New York?")
70+
messages = json.dumps(agent.messages)
71+
72+
assert len(tool_events) == 0
73+
assert "cancelled tool call" in messages

tests_integ/tools/executors/test_sequential.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23

34
import pytest
45

@@ -59,3 +60,14 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
5960
{"name": "weather_tool", "event": "end"},
6061
]
6162
assert tru_events == exp_events
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_agent_invoke_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
67+
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])
68+
69+
await agent.invoke_async("What is the time in New York?")
70+
messages = json.dumps(agent.messages)
71+
72+
assert len(tool_events) == 0
73+
assert "cancelled tool call" in messages

0 commit comments

Comments
 (0)