Skip to content

Commit 6064015

Browse files
thomasahle authored and isaacbmiller committed
Added support for more types (#1900)
* Added support for more types
1 parent 0cd8494 commit 6064015

File tree

2 files changed, with 208 additions and 32 deletions.

dspy/signatures/signature.py

Lines changed: 85 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,13 @@
77
from contextlib import ExitStack, contextmanager
88
from copy import deepcopy
99
from typing import Any, Dict, Tuple, Type, Union # noqa: UP035
10+
import importlib
1011

1112
from pydantic import BaseModel, Field, create_model
1213
from pydantic.fields import FieldInfo
1314

1415
import dsp
15-
from dspy.adapters.image_utils import Image
16+
from dspy.adapters.image_utils import Image # noqa: F401
1617
from dspy.signatures.field import InputField, OutputField, new_to_old_field
1718

1819

@@ -355,7 +356,7 @@ def make_signature(
355356
if type_ is None:
356357
type_ = str
357358
# if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type):
358-
if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias)):
359+
if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias, typing._SpecialForm)):
359360
raise ValueError(f"Field types must be types, not {type(type_)}")
360361
if not isinstance(field, FieldInfo):
361362
raise ValueError(f"Field values must be Field instances, not {type(field)}")
@@ -400,53 +401,106 @@ def _parse_arg_string(string: str, names=None) -> Dict[str, str]:
400401

401402

402403
def _parse_type_node(node, names=None) -> Any:
403-
"""Recursively parse an AST node representing a type annotation.
404-
405-
without using structural pattern matching introduced in Python 3.10.
406-
"""
404+
"""Recursively parse an AST node representing a type annotation."""
407405

408406
if names is None:
409-
names = typing.__dict__
407+
names = dict(typing.__dict__)
408+
names['NoneType'] = type(None)
409+
410+
def resolve_name(id_: str):
411+
# Check if it's a built-in known type or in the provided names
412+
if id_ in names:
413+
return names[id_]
414+
415+
# Common built-in types
416+
builtin_types = [int, str, float, bool, list, tuple, dict, set, frozenset, complex, bytes, bytearray]
417+
418+
# Try PIL Image if 'Image' encountered
419+
if 'Image' not in names:
420+
try:
421+
from PIL import Image
422+
names['Image'] = Image
423+
except ImportError:
424+
pass
425+
426+
# If we have PIL Image and id_ is 'Image', return it
427+
if 'Image' in names and id_ == 'Image':
428+
return names['Image']
429+
430+
# Check if it matches any known built-in type by name
431+
for t in builtin_types:
432+
if t.__name__ == id_:
433+
return t
434+
435+
# Attempt to import a module with this name dynamically
436+
# This allows handling of module-based annotations like `dspy.Image`.
437+
try:
438+
mod = importlib.import_module(id_)
439+
names[id_] = mod
440+
return mod
441+
except ImportError:
442+
pass
443+
444+
# If we don't know the type or module, raise an error
445+
raise ValueError(f"Unknown name: {id_}")
410446

411447
if isinstance(node, ast.Module):
412-
body = node.body
413-
if len(body) != 1:
414-
raise ValueError(f"Code is not syntactically valid: {node}")
415-
return _parse_type_node(body[0], names)
448+
if len(node.body) != 1:
449+
raise ValueError(f"Code is not syntactically valid: {ast.dump(node)}")
450+
return _parse_type_node(node.body[0], names)
416451

417452
if isinstance(node, ast.Expr):
418-
value = node.value
419-
return _parse_type_node(value, names)
453+
return _parse_type_node(node.value, names)
420454

421455
if isinstance(node, ast.Name):
422-
id_ = node.id
423-
if id_ in names:
424-
return names[id_]
456+
return resolve_name(node.id)
425457

426-
for type_ in [int, str, float, bool, list, tuple, dict, Image]:
427-
if type_.__name__ == id_:
428-
return type_
429-
raise ValueError(f"Unknown name: {id_}")
458+
if isinstance(node, ast.Attribute):
459+
base = _parse_type_node(node.value, names)
460+
attr_name = node.attr
461+
if hasattr(base, attr_name):
462+
return getattr(base, attr_name)
463+
else:
464+
raise ValueError(f"Unknown attribute: {attr_name} on {base}")
430465

431466
if isinstance(node, ast.Subscript):
432467
base_type = _parse_type_node(node.value, names)
433-
arg_type = _parse_type_node(node.slice, names)
434-
return base_type[arg_type]
468+
slice_node = node.slice
469+
if isinstance(slice_node, ast.Index): # For older Python versions
470+
slice_node = slice_node.value
471+
472+
if isinstance(slice_node, ast.Tuple):
473+
arg_types = tuple(_parse_type_node(elt, names) for elt in slice_node.elts)
474+
else:
475+
arg_types = (_parse_type_node(slice_node, names),)
476+
477+
# Special handling for Union, Optional
478+
if base_type is typing.Union:
479+
return typing.Union[arg_types]
480+
if base_type is typing.Optional:
481+
if len(arg_types) != 1:
482+
raise ValueError("Optional must have exactly one type argument")
483+
return typing.Optional[arg_types[0]]
484+
485+
return base_type[arg_types]
435486

436487
if isinstance(node, ast.Tuple):
437-
elts = node.elts
438-
return tuple(_parse_type_node(elt, names) for elt in elts)
488+
return tuple(_parse_type_node(elt, names) for elt in node.elts)
489+
490+
if isinstance(node, ast.Constant):
491+
return node.value
439492

440-
if isinstance(node, ast.Call) and node.func.id == "Field":
493+
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "Field":
441494
keys = [kw.arg for kw in node.keywords]
442-
values = [kw.value.value for kw in node.keywords]
495+
values = []
496+
for kw in node.keywords:
497+
if isinstance(kw.value, ast.Constant):
498+
values.append(kw.value.value)
499+
else:
500+
values.append(_parse_type_node(kw.value, names))
443501
return Field(**dict(zip(keys, values)))
444502

445-
if isinstance(node, ast.Attribute) and node.attr == "Image":
446-
return Image
447-
448-
raise ValueError(f"Code is not syntactically valid: {node}")
449-
503+
raise ValueError(f"Unhandled AST node type in annotation: {ast.dump(node)}")
450504

451505
def infer_prefix(attribute_name: str) -> str:
452506
"""Infer a prefix from an attribute name."""

tests/signatures/test_signature.py

Lines changed: 123 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import textwrap
2-
from typing import List
2+
from typing import Any, Dict, List, Optional, Tuple, Union
33

44
import pydantic
55
import pytest
@@ -279,3 +279,125 @@ class CustomSignature2(dspy.Signature):
279279
assert CustomSignature2.instructions == "I am a malicious instruction."
280280
assert CustomSignature2.fields["sentence"].json_schema_extra["desc"] == "I am an malicious input!"
281281
assert CustomSignature2.fields["sentiment"].json_schema_extra["prefix"] == "Sentiment:"
282+
283+
284+
def test_typed_signatures_basic_types():
285+
# Simple built-in types
286+
sig = Signature("input1: int, input2: str -> output: float")
287+
assert "input1" in sig.input_fields
288+
assert sig.input_fields["input1"].annotation == int
289+
assert "input2" in sig.input_fields
290+
assert sig.input_fields["input2"].annotation == str
291+
assert "output" in sig.output_fields
292+
assert sig.output_fields["output"].annotation == float
293+
294+
295+
def test_typed_signatures_generics():
296+
# More complex generic types
297+
sig = Signature("input_list: List[int], input_dict: Dict[str, float] -> output_tuple: Tuple[str, int]")
298+
assert "input_list" in sig.input_fields
299+
assert sig.input_fields["input_list"].annotation == List[int]
300+
assert "input_dict" in sig.input_fields
301+
assert sig.input_fields["input_dict"].annotation == Dict[str, float]
302+
assert "output_tuple" in sig.output_fields
303+
assert sig.output_fields["output_tuple"].annotation == Tuple[str, int]
304+
305+
306+
def test_typed_signatures_unions_and_optionals():
307+
sig = Signature("input_opt: Optional[str], input_union: Union[int, None] -> output_union: Union[int, str]")
308+
assert "input_opt" in sig.input_fields
309+
# Optional[str] is actually Union[str, None]
310+
# Depending on the environment, it might resolve to Union[str, None] or Optional[str], either is correct.
311+
# We'll just check for a Union containing str and NoneType:
312+
input_opt_annotation = sig.input_fields["input_opt"].annotation
313+
assert (input_opt_annotation == Optional[str] or
314+
(getattr(input_opt_annotation, '__origin__', None) is Union and str in input_opt_annotation.__args__ and type(None) in input_opt_annotation.__args__))
315+
316+
assert "input_union" in sig.input_fields
317+
input_union_annotation = sig.input_fields["input_union"].annotation
318+
assert (getattr(input_union_annotation, '__origin__', None) is Union and
319+
int in input_union_annotation.__args__ and type(None) in input_union_annotation.__args__)
320+
321+
assert "output_union" in sig.output_fields
322+
output_union_annotation = sig.output_fields["output_union"].annotation
323+
assert (getattr(output_union_annotation, '__origin__', None) is Union and
324+
int in output_union_annotation.__args__ and str in output_union_annotation.__args__)
325+
326+
327+
def test_typed_signatures_any():
328+
sig = Signature("input_any: Any -> output_any: Any")
329+
assert "input_any" in sig.input_fields
330+
assert sig.input_fields["input_any"].annotation == Any
331+
assert "output_any" in sig.output_fields
332+
assert sig.output_fields["output_any"].annotation == Any
333+
334+
335+
def test_typed_signatures_nested():
336+
# Nested generics and unions
337+
sig = Signature("input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]")
338+
input_nested_ann = sig.input_fields["input_nested"].annotation
339+
assert getattr(input_nested_ann, '__origin__', None) is list
340+
assert len(input_nested_ann.__args__) == 1
341+
union_arg = input_nested_ann.__args__[0]
342+
assert getattr(union_arg, '__origin__', None) is Union
343+
assert str in union_arg.__args__ and int in union_arg.__args__
344+
345+
output_nested_ann = sig.output_fields["output_nested"].annotation
346+
assert getattr(output_nested_ann, '__origin__', None) is tuple
347+
assert output_nested_ann.__args__[0] == int
348+
# The second arg is Optional[float], which is Union[float, None]
349+
second_arg = output_nested_ann.__args__[1]
350+
assert getattr(second_arg, '__origin__', None) is Union
351+
assert float in second_arg.__args__ and type(None) in second_arg.__args__
352+
# The third arg is List[str]
353+
third_arg = output_nested_ann.__args__[2]
354+
assert getattr(third_arg, '__origin__', None) is list
355+
assert third_arg.__args__[0] == str
356+
357+
358+
def test_typed_signatures_from_dict():
359+
# Creating a Signature directly from a dictionary with types
360+
fields = {
361+
"input_str_list": (List[str], InputField()),
362+
"input_dict_int": (Dict[str, int], InputField()),
363+
"output_tup": (Tuple[int, float], OutputField()),
364+
}
365+
sig = Signature(fields)
366+
assert "input_str_list" in sig.input_fields
367+
assert sig.input_fields["input_str_list"].annotation == List[str]
368+
assert "input_dict_int" in sig.input_fields
369+
assert sig.input_fields["input_dict_int"].annotation == Dict[str, int]
370+
assert "output_tup" in sig.output_fields
371+
assert sig.output_fields["output_tup"].annotation == Tuple[int, float]
372+
373+
374+
def test_typed_signatures_complex_combinations():
375+
# Test a very complex signature with multiple nested constructs
376+
# input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]
377+
sig = Signature("input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]")
378+
input_complex_ann = sig.input_fields["input_complex"].annotation
379+
assert getattr(input_complex_ann, '__origin__', None) is dict
380+
key_arg, value_arg = input_complex_ann.__args__
381+
assert key_arg == str
382+
# value_arg: List[Optional[Tuple[int, str]]]
383+
assert getattr(value_arg, '__origin__', None) is list
384+
inner_union = value_arg.__args__[0]
385+
# inner_union should be Optional[Tuple[int, str]]
386+
# which is Union[Tuple[int, str], None]
387+
assert getattr(inner_union, '__origin__', None) is Union
388+
tuple_type = [t for t in inner_union.__args__ if t != type(None)][0]
389+
assert getattr(tuple_type, '__origin__', None) is tuple
390+
assert tuple_type.__args__ == (int, str)
391+
392+
output_complex_ann = sig.output_fields["output_complex"].annotation
393+
assert getattr(output_complex_ann, '__origin__', None) is Union
394+
assert len(output_complex_ann.__args__) == 2
395+
possible_args = set(output_complex_ann.__args__)
396+
# Expecting List[str] and Dict[str, Any]
397+
# Because sets don't preserve order, just check membership.
398+
# Find the List[str] arg
399+
list_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is list)
400+
dict_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is dict)
401+
assert list_arg.__args__ == (str,)
402+
k, v = dict_arg.__args__
403+
assert k == str and v == Any

0 commit comments

Comments
 (0)