Skip to content

Commit 863d96c

Browse files
committed
Fix: Use correct field_meta for constrained union types when building field values for coverage
1 parent 7d67749 commit 863d96c

File tree

2 files changed

+40
-6
lines changed

2 files changed

+40
-6
lines changed

polyfactory/factories/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,9 @@ def get_field_value_coverage( # noqa: C901,PLR0912
902902
for unwrapped_annotation in flatten_annotation(field_meta.annotation):
903903
unwrapped_annotation = cls._resolve_forward_references(unwrapped_annotation) # noqa: PLW2901
904904

905+
unwrapped_annotation_meta = next(
906+
(meta for meta in (field_meta.children or []) if meta.annotation == unwrapped_annotation), field_meta
907+
)
905908
if unwrapped_annotation in (None, NoneType):
906909
yield None
907910

@@ -911,11 +914,11 @@ def get_field_value_coverage( # noqa: C901,PLR0912
911914
elif isinstance(unwrapped_annotation, EnumMeta):
912915
yield CoverageContainer(list(unwrapped_annotation))
913916

914-
elif field_meta.constraints:
917+
elif unwrapped_annotation_meta.constraints:
915918
yield CoverageContainerCallable(
916919
cls.get_constrained_field_value,
917920
annotation=unwrapped_annotation,
918-
field_meta=field_meta,
921+
field_meta=unwrapped_annotation_meta,
919922
)
920923

921924
elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):

tests/test_pydantic_factory.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from uuid import UUID
1313

1414
import pytest
15-
from annotated_types import Ge, Gt, Le, LowerCase, MinLen, UpperCase
15+
from annotated_types import Ge, Gt, Le, LowerCase, MaxLen, MinLen, UpperCase
1616
from typing_extensions import Annotated, TypeAlias
1717

1818
import pydantic
@@ -717,18 +717,49 @@ class C(BaseModel):
717717
assert CFactory.build()
718718

719719

720-
def test_constrained_union_types() -> None:
720+
@pytest.mark.skipif(IS_PYDANTIC_V2, reason="pydantic 1 only test")
721+
def test_constrained_union_types_pydantic_v1() -> None:
721722
class A(BaseModel):
722723
a: Union[Annotated[List[str], MinLen(100)], Annotated[int, Ge(1000)]]
723724
b: Union[List[Annotated[str, MinLen(100)]], int]
724725
c: Union[Annotated[List[int], MinLen(100)], None]
725-
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MinLen(100)]]
726-
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MinLen(10)]]]
726+
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MaxLen(99)]]
727+
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MaxLen(9)]]]
727728
f: Optional[Union[Annotated[List[int], MinLen(10)], List[str]]]
729+
g: Optional[
730+
Union[
731+
Annotated[List[int], MinLen(10)],
732+
Union[Annotated[List[str], MaxLen(9)], Annotated[Decimal, Field(max_digits=4, decimal_places=2)]],
733+
]
734+
]
735+
736+
AFactory = ModelFactory.create_factory(A, __allow_none_optionals__=False)
737+
738+
assert AFactory.build()
739+
740+
741+
@pytest.mark.skipif(IS_PYDANTIC_V1, reason="pydantic 2 only test")
742+
def test_constrained_union_types_pydantic_v2() -> None:
743+
class A(BaseModel):
744+
a: Union[Annotated[List[str], MinLen(100)], Annotated[int, Ge(1000)]]
745+
b: Union[List[Annotated[str, MinLen(100)]], int]
746+
c: Union[Annotated[List[int], MinLen(100)], None]
747+
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MaxLen(99)]]
748+
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MaxLen(9)]]]
749+
f: Optional[Union[Annotated[List[int], MinLen(10)], List[str]]]
750+
g: Optional[
751+
Union[
752+
Annotated[List[int], MinLen(10)],
753+
Union[Annotated[List[str], MaxLen(9)], Annotated[Decimal, Field(max_digits=4, decimal_places=2)]],
754+
]
755+
]
756+
# This annotation is not allowed in pydantic 1
757+
h: Annotated[Union[List[int], List[str]], MinLen(10)]
728758

729759
AFactory = ModelFactory.create_factory(A, __allow_none_optionals__=False)
730760

731761
assert AFactory.build()
762+
assert list(AFactory.coverage())
732763

733764

734765
@pytest.mark.parametrize("allow_none", (True, False))

0 commit comments

Comments
 (0)