|
1 | 1 | import textwrap |
2 | | -from typing import List |
| 2 | +from typing import Any, Dict, List, Optional, Tuple, Union |
3 | 3 |
|
4 | 4 | import pydantic |
5 | 5 | import pytest |
@@ -279,3 +279,125 @@ class CustomSignature2(dspy.Signature): |
279 | 279 | assert CustomSignature2.instructions == "I am a malicious instruction." |
280 | 280 | assert CustomSignature2.fields["sentence"].json_schema_extra["desc"] == "I am an malicious input!" |
281 | 281 | assert CustomSignature2.fields["sentiment"].json_schema_extra["prefix"] == "Sentiment:" |
| 282 | + |
| 283 | + |
| 284 | +def test_typed_signatures_basic_types(): |
| 285 | + # Simple built-in types |
| 286 | + sig = Signature("input1: int, input2: str -> output: float") |
| 287 | + assert "input1" in sig.input_fields |
| 288 | + assert sig.input_fields["input1"].annotation == int |
| 289 | + assert "input2" in sig.input_fields |
| 290 | + assert sig.input_fields["input2"].annotation == str |
| 291 | + assert "output" in sig.output_fields |
| 292 | + assert sig.output_fields["output"].annotation == float |
| 293 | + |
| 294 | + |
| 295 | +def test_typed_signatures_generics(): |
| 296 | + # More complex generic types |
| 297 | + sig = Signature("input_list: List[int], input_dict: Dict[str, float] -> output_tuple: Tuple[str, int]") |
| 298 | + assert "input_list" in sig.input_fields |
| 299 | + assert sig.input_fields["input_list"].annotation == List[int] |
| 300 | + assert "input_dict" in sig.input_fields |
| 301 | + assert sig.input_fields["input_dict"].annotation == Dict[str, float] |
| 302 | + assert "output_tuple" in sig.output_fields |
| 303 | + assert sig.output_fields["output_tuple"].annotation == Tuple[str, int] |
| 304 | + |
| 305 | + |
| 306 | +def test_typed_signatures_unions_and_optionals(): |
| 307 | + sig = Signature("input_opt: Optional[str], input_union: Union[int, None] -> output_union: Union[int, str]") |
| 308 | + assert "input_opt" in sig.input_fields |
| 309 | + # Optional[str] is actually Union[str, None] |
| 310 | + # Depending on the environment, it might resolve to Union[str, None] or Optional[str], either is correct. |
| 311 | + # We'll just check for a Union containing str and NoneType: |
| 312 | + input_opt_annotation = sig.input_fields["input_opt"].annotation |
| 313 | + assert (input_opt_annotation == Optional[str] or |
| 314 | + (getattr(input_opt_annotation, '__origin__', None) is Union and str in input_opt_annotation.__args__ and type(None) in input_opt_annotation.__args__)) |
| 315 | + |
| 316 | + assert "input_union" in sig.input_fields |
| 317 | + input_union_annotation = sig.input_fields["input_union"].annotation |
| 318 | + assert (getattr(input_union_annotation, '__origin__', None) is Union and |
| 319 | + int in input_union_annotation.__args__ and type(None) in input_union_annotation.__args__) |
| 320 | + |
| 321 | + assert "output_union" in sig.output_fields |
| 322 | + output_union_annotation = sig.output_fields["output_union"].annotation |
| 323 | + assert (getattr(output_union_annotation, '__origin__', None) is Union and |
| 324 | + int in output_union_annotation.__args__ and str in output_union_annotation.__args__) |
| 325 | + |
| 326 | + |
| 327 | +def test_typed_signatures_any(): |
| 328 | + sig = Signature("input_any: Any -> output_any: Any") |
| 329 | + assert "input_any" in sig.input_fields |
| 330 | + assert sig.input_fields["input_any"].annotation == Any |
| 331 | + assert "output_any" in sig.output_fields |
| 332 | + assert sig.output_fields["output_any"].annotation == Any |
| 333 | + |
| 334 | + |
| 335 | +def test_typed_signatures_nested(): |
| 336 | + # Nested generics and unions |
| 337 | + sig = Signature("input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]") |
| 338 | + input_nested_ann = sig.input_fields["input_nested"].annotation |
| 339 | + assert getattr(input_nested_ann, '__origin__', None) is list |
| 340 | + assert len(input_nested_ann.__args__) == 1 |
| 341 | + union_arg = input_nested_ann.__args__[0] |
| 342 | + assert getattr(union_arg, '__origin__', None) is Union |
| 343 | + assert str in union_arg.__args__ and int in union_arg.__args__ |
| 344 | + |
| 345 | + output_nested_ann = sig.output_fields["output_nested"].annotation |
| 346 | + assert getattr(output_nested_ann, '__origin__', None) is tuple |
| 347 | + assert output_nested_ann.__args__[0] == int |
| 348 | + # The second arg is Optional[float], which is Union[float, None] |
| 349 | + second_arg = output_nested_ann.__args__[1] |
| 350 | + assert getattr(second_arg, '__origin__', None) is Union |
| 351 | + assert float in second_arg.__args__ and type(None) in second_arg.__args__ |
| 352 | + # The third arg is List[str] |
| 353 | + third_arg = output_nested_ann.__args__[2] |
| 354 | + assert getattr(third_arg, '__origin__', None) is list |
| 355 | + assert third_arg.__args__[0] == str |
| 356 | + |
| 357 | + |
| 358 | +def test_typed_signatures_from_dict(): |
| 359 | + # Creating a Signature directly from a dictionary with types |
| 360 | + fields = { |
| 361 | + "input_str_list": (List[str], InputField()), |
| 362 | + "input_dict_int": (Dict[str, int], InputField()), |
| 363 | + "output_tup": (Tuple[int, float], OutputField()), |
| 364 | + } |
| 365 | + sig = Signature(fields) |
| 366 | + assert "input_str_list" in sig.input_fields |
| 367 | + assert sig.input_fields["input_str_list"].annotation == List[str] |
| 368 | + assert "input_dict_int" in sig.input_fields |
| 369 | + assert sig.input_fields["input_dict_int"].annotation == Dict[str, int] |
| 370 | + assert "output_tup" in sig.output_fields |
| 371 | + assert sig.output_fields["output_tup"].annotation == Tuple[int, float] |
| 372 | + |
| 373 | + |
| 374 | +def test_typed_signatures_complex_combinations(): |
| 375 | + # Test a very complex signature with multiple nested constructs |
| 376 | + # input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]] |
| 377 | + sig = Signature("input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]") |
| 378 | + input_complex_ann = sig.input_fields["input_complex"].annotation |
| 379 | + assert getattr(input_complex_ann, '__origin__', None) is dict |
| 380 | + key_arg, value_arg = input_complex_ann.__args__ |
| 381 | + assert key_arg == str |
| 382 | + # value_arg: List[Optional[Tuple[int, str]]] |
| 383 | + assert getattr(value_arg, '__origin__', None) is list |
| 384 | + inner_union = value_arg.__args__[0] |
| 385 | + # inner_union should be Optional[Tuple[int, str]] |
| 386 | + # which is Union[Tuple[int, str], None] |
| 387 | + assert getattr(inner_union, '__origin__', None) is Union |
| 388 | + tuple_type = [t for t in inner_union.__args__ if t != type(None)][0] |
| 389 | + assert getattr(tuple_type, '__origin__', None) is tuple |
| 390 | + assert tuple_type.__args__ == (int, str) |
| 391 | + |
| 392 | + output_complex_ann = sig.output_fields["output_complex"].annotation |
| 393 | + assert getattr(output_complex_ann, '__origin__', None) is Union |
| 394 | + assert len(output_complex_ann.__args__) == 2 |
| 395 | + possible_args = set(output_complex_ann.__args__) |
| 396 | + # Expecting List[str] and Dict[str, Any] |
| 397 | + # Because sets don't preserve order, just check membership. |
| 398 | + # Find the List[str] arg |
| 399 | + list_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is list) |
| 400 | + dict_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is dict) |
| 401 | + assert list_arg.__args__ == (str,) |
| 402 | + k, v = dict_arg.__args__ |
| 403 | + assert k == str and v == Any |
0 commit comments