Skip to content

Commit 4699adf

Browse files
tkelloggisaacbmiller
authored andcommitted
Allow react.Tool to wrap methods (#1856)
The big reason for this is to pass parameters out-of-band, e.g. a user_id to ensure the LLM doesn't get the wrong data. The unit test includes a usage, you can't use it as a decorator this way, but it works. The alternative, of course, is to have a very long function and have all the tools be nested functions. It works, but can lead to some very long functions. I prefer long classes over long functions.
1 parent 5deb0b8 commit 4699adf

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

dspy/predict/react.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
class Tool:
1111
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):
12-
annotations_func = func if inspect.isfunction(func) else func.__call__
12+
annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
1313
self.func = func
1414
self.name = name or getattr(func, '__name__', type(func).__name__)
1515
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "")

tests/predict/test_react.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import dspy
44
from dspy.utils.dummies import DummyLM, dummy_rm
5+
from dspy.predict import react
56

67

78
# def test_example_no_tools():
@@ -121,4 +122,28 @@
121122
# react = dspy.ReAct(ExampleSignature)
122123

123124
# assert react.react[0].signature.instructions is not None
124-
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")
125+
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")
126+
127+
def test_tool_from_function():
128+
def foo(a: int, b: int) -> int:
129+
"""Add two numbers."""
130+
return a + b
131+
132+
tool = react.Tool(foo)
133+
assert tool.name == "foo"
134+
assert tool.desc == "Add two numbers."
135+
assert tool.args == {"a": "int", "b": "int"}
136+
137+
def test_tool_from_class():
138+
class Foo:
139+
def __init__(self, user_id: str):
140+
self.user_id = user_id
141+
142+
def foo(self, a: int, b: int) -> int:
143+
"""Add two numbers."""
144+
return a + b
145+
146+
tool = react.Tool(Foo("123").foo)
147+
assert tool.name == "foo"
148+
assert tool.desc == "Add two numbers."
149+
assert tool.args == {"a": "int", "b": "int"}

0 commit comments

Comments
 (0)