Skip to content

Commit 3edcd9e

Browse files
authored
Add on_tool_start/end callbacks (#1879)
Signed-off-by: B-Step62 <[email protected]>
1 parent b54557c commit 3edcd9e

File tree

4 files changed

+83
-1
lines changed

4 files changed

+83
-1
lines changed

dspy/predict/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
from .multi_chain_comparison import MultiChainComparison
66
from .predict import Predict
77
from .program_of_thought import ProgramOfThought
8-
from .react import ReAct
8+
from .react import ReAct, Tool
99
from .retry import Retry
1010
from .parallel import Parallel

dspy/predict/react.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dspy.primitives.program import Module
66
from dspy.signatures.signature import ensure_signature
77
from dspy.adapters.json_adapter import get_annotation_name
8+
from dspy.utils.callback import with_callbacks
89
from typing import Callable, Any, get_type_hints, get_origin, Literal
910

1011
class Tool:
@@ -19,6 +20,7 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic
1920
for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return'
2021
}
2122

23+
@with_callbacks
2224
def __call__(self, *args, **kwargs):
2325
return self.func(*args, **kwargs)
2426

dspy/utils/callback.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,38 @@ def on_adapter_parse_end(
190190
"""
191191
pass
192192

193+
def on_tool_start(
194+
self,
195+
call_id: str,
196+
instance: Any,
197+
inputs: Dict[str, Any],
198+
):
199+
"""A handler triggered when a tool is called.
200+
201+
Args:
202+
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
203+
instance: The Tool instance.
204+
inputs: The inputs to the Tool's __call__ method. Each arguments is stored as
205+
a key-value pair in a dictionary.
206+
"""
207+
pass
208+
209+
def on_tool_end(
210+
self,
211+
call_id: str,
212+
outputs: Optional[Dict[str, Any]],
213+
exception: Optional[Exception] = None,
214+
):
215+
"""A handler triggered after a tool is executed.
216+
217+
Args:
218+
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
219+
outputs: The outputs of the Tool's __call__ method. If the method is interrupted by
220+
an exception, this will be None.
221+
exception: If an exception is raised during the execution, it will be stored here.
222+
"""
223+
pass
224+
193225

194226
def with_callbacks(fn):
195227
@functools.wraps(fn)
@@ -256,6 +288,9 @@ def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -
256288
else:
257289
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")
258290

291+
if isinstance(instance, dspy.Tool):
292+
return callback.on_tool_start
293+
259294
# We treat everything else as a module.
260295
return callback.on_module_start
261296

@@ -272,5 +307,9 @@ def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) ->
272307
return callback.on_adapter_parse_end
273308
else:
274309
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")
310+
311+
if isinstance(instance, dspy.Tool):
312+
return callback.on_tool_end
313+
275314
# We treat everything else as a module.
276315
return callback.on_module_end

tests/callback/test_callback.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def on_adapter_parse_start(self, call_id, instance, inputs):
4747
def on_adapter_parse_end(self, call_id, outputs, exception):
4848
self.calls.append({"handler": "on_adapter_parse_end", "outputs": outputs, "exception": exception})
4949

50+
def on_tool_start(self, call_id, instance, inputs):
51+
self.calls.append({"handler": "on_tool_start", "instance": instance, "inputs": inputs})
52+
53+
def on_tool_end(self, call_id, outputs, exception):
54+
self.calls.append({"handler": "on_tool_end", "outputs": outputs, "exception": exception})
55+
5056

5157
@pytest.mark.parametrize(
5258
("args", "kwargs"),
@@ -181,6 +187,41 @@ def test_callback_complex_module():
181187
]
182188

183189

190+
def test_tool_calls():
191+
callback = MyCallback()
192+
dspy.settings.configure(callbacks=[callback])
193+
194+
def tool_1(query: str) -> str:
195+
"""A dummy tool function."""
196+
return "result 1"
197+
198+
def tool_2(query: str) -> str:
199+
"""Another dummy tool function."""
200+
return "result 2"
201+
202+
class MyModule(dspy.Module):
203+
def __init__(self):
204+
self.tools = [dspy.Tool(tool_1), dspy.Tool(tool_2)]
205+
206+
def forward(self, query: str) -> str:
207+
query = self.tools[0](query)
208+
return self.tools[1](query)
209+
210+
module = MyModule()
211+
result = module("query")
212+
213+
assert result == "result 2"
214+
assert len(callback.calls) == 6
215+
assert [call["handler"] for call in callback.calls] == [
216+
"on_module_start",
217+
"on_tool_start",
218+
"on_tool_end",
219+
"on_tool_start",
220+
"on_tool_end",
221+
"on_module_end",
222+
]
223+
224+
184225
def test_active_id():
185226
# Test the call ID is generated and handled properly
186227
class CustomCallback(BaseCallback):

0 commit comments

Comments
 (0)