From a88ee75d811a40c533ea38eaa0573b96c180edc7 Mon Sep 17 00:00:00 2001 From: Evgeny Arshinov Date: Fri, 12 Apr 2024 14:20:06 +0200 Subject: [PATCH 1/3] =?UTF-8?q?=E2=9C=A8Properly=20support=20inheritance?= =?UTF-8?q?=20of=20Relationship=20attributes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 7 ++- tests/test_relationship_inheritance.py | 62 ++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/test_relationship_inheritance.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..881d0b05a0 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -503,6 +503,8 @@ def __new__( **kwargs: Any, ) -> Any: relationships: Dict[str, RelationshipInfo] = {} + for base in bases: + relationships.update(getattr(base, "__sqlmodel_relationships__", {})) dict_for_pydantic = {} original_annotations = get_annotations(class_dict) pydantic_annotations = {} @@ -559,8 +561,9 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): - col = get_column_from_field(v) - setattr(new_cls, k, col) + if k not in relationships: + col = get_column_from_field(v) + setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. # This could be done by reading new_cls.model_config['table'] in FastAPI, but diff --git a/tests/test_relationship_inheritance.py b/tests/test_relationship_inheritance.py new file mode 100644 index 0000000000..804c4bd741 --- /dev/null +++ b/tests/test_relationship_inheritance.py @@ -0,0 +1,62 @@ +from typing import Optional + +from sqlalchemy.orm import declared_attr, relationship +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select + + +def test_relationship_inheritance() -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class CreatedUpdatedMixin(SQLModel): + # With Pydantic V2, it is also possible to define `created_by` like this: + # + # ```python + # @declared_attr + # def _created_by(cls): + # return relationship(User, foreign_keys=cls.created_by_id) + # + # created_by: Optional[User] = Relationship(sa_relationship=_created_by)) + # ``` + # + # The difference from Pydantic V1 is that Pydantic V2 plucks attributes with names starting with '_' (but not '__') + # from class attributes and stores them separately as instances of `pydantic.ModelPrivateAttr` somewhere in depths of + # Pydantic internals. Under Pydantic V1 this doesn't happen, so SQLAlchemy ends up having two class attributes + # (`_created_by` and `created_by`) corresponding to one database attribute, causing a conflict and unreliable behavior. + # The approach with a lambda always works because it doesn't produce the second class attribute and thus eliminates + # the possibility of a conflict entirely. + # + created_by_id: Optional[int] = Field(default=None, foreign_key="user.id") + created_by: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.created_by_id) + ) + ) + + updated_by_id: Optional[int] = Field(default=None, foreign_key="user.id") + updated_by: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.updated_by_id) + ) + ) + + class Asset(CreatedUpdatedMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + john = User(name="John") + jane = User(name="Jane") + asset = Asset(created_by=john, updated_by=jane) + + with Session(engine) as session: + session.add(asset) + session.commit() + + with Session(engine) as session: + asset = session.exec(select(Asset)).one() + assert asset.created_by.name == "John" + assert asset.updated_by.name == "Jane" From 594aac393c94efada8dbd11549e7e80069ab4ab9 Mon Sep 17 00:00:00 2001 From: Evgeny Arshinov Date: Fri, 12 Apr 2024 14:20:06 +0200 Subject: [PATCH 2/3] For sake of demonstration, add created_at and updated_at fields to CreatedUpdatedMixin used in test --- tests/test_relationship_inheritance.py | 32 +++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_relationship_inheritance.py b/tests/test_relationship_inheritance.py index 804c4bd741..236c36cee7 100644 --- a/tests/test_relationship_inheritance.py +++ b/tests/test_relationship_inheritance.py @@ -1,15 +1,28 @@ +import datetime from typing import Optional +from sqlalchemy import DateTime, func from sqlalchemy.orm import declared_attr, relationship from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select def test_relationship_inheritance() -> None: + def now(): + return datetime.datetime.now(tz=datetime.timezone.utc) + class User(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str class CreatedUpdatedMixin(SQLModel): + # Fields in reusable base models must be defined using `sa_type` and `sa_column_kwargs` instead of `sa_column` + # https://github.com/tiangolo/sqlmodel/discussions/743 + # + # created_at: datetime.datetime = Field(default_factory=now, sa_column=DateTime(default=now)) + created_at: datetime.datetime = Field( + default_factory=now, sa_type=DateTime, sa_column_kwargs={"default": now} + ) + # With Pydantic V2, it is also possible to define `created_by` like this: # # ```python @@ -34,6 +47,9 @@ class CreatedUpdatedMixin(SQLModel): ) ) + updated_at: datetime.datetime = Field( + default_factory=now, sa_type=DateTime, sa_column_kwargs={"default": now} + ) updated_by_id: Optional[int] = Field(default=None, foreign_key="user.id") updated_by: Optional[User] = Relationship( sa_relationship=declared_attr( @@ -43,6 +59,12 @@ class CreatedUpdatedMixin(SQLModel): class Asset(CreatedUpdatedMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + name: str + + # Demonstrate that the mixin can be applied to more than 1 model + class Document(CreatedUpdatedMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str engine = create_engine("sqlite://") @@ -50,13 +72,21 @@ class Asset(CreatedUpdatedMixin, table=True): john = User(name="John") jane = User(name="Jane") - asset = Asset(created_by=john, updated_by=jane) + asset = Asset(name="Test", created_by=john, updated_by=jane) + doc = Document(name="Resume", created_by=jane, updated_by=john) with Session(engine) as session: session.add(asset) + session.add(doc) session.commit() with Session(engine) as session: + assert session.scalar(select(func.count()).select_from(User)) == 2 + asset = session.exec(select(Asset)).one() assert asset.created_by.name == "John" assert asset.updated_by.name == "Jane" + + doc = session.exec(select(Document)).one() + assert doc.created_by.name == "Jane" + assert doc.updated_by.name == "John" From 46c8d6250426d890522654e74fc4d0384ba188cb Mon Sep 17 00:00:00 2001 From: Evgeny Arshinov Date: Fri, 12 Apr 2024 14:21:00 +0200 Subject: [PATCH 3/3] =?UTF-8?q?=E2=9C=8F=EF=B8=8FFix=20`model=5Fvalidate`?= =?UTF-8?q?=20in=20presence=20of=20inherited=20Relationship=20fields,=20ad?= =?UTF-8?q?d=20unit=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 25 +++++++- ...itance.py => test_inherit_relationship.py} | 58 ++++++++++++++++++- 2 files changed, 81 insertions(+), 2 deletions(-) rename tests/{test_relationship_inheritance.py => test_inherit_relationship.py} (67%) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 881d0b05a0..1b3cc3030f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -503,8 +503,28 @@ def __new__( **kwargs: Any, ) -> Any: relationships: Dict[str, RelationshipInfo] = {} + backup_base_annotations: Dict[Type[Any], Dict[str, Any]] = {} for base in bases: - relationships.update(getattr(base, "__sqlmodel_relationships__", {})) + base_relationships = getattr(base, "__sqlmodel_relationships__", None) + if base_relationships: + relationships.update(base_relationships) + # + # Temporarily pluck out `__annotations__` corresponding to relationships from base classes, otherwise these annotations + # make their way into `cls.model_fields` as `FieldInfo(..., required=True)`, even when the relationships are declared + # optional. When a model instance is then constructed using `model_validate` and an optional relationship field is not + # passed, this leads to an incorrect `pydantic.ValidationError`. + # + # We can't just clean up `new_cls.model_fields` after `new_cls` is constructed because by this time + # Pydantic has created model schema and validation rules, so this won't fix the problem. + # + base_annotations = getattr(base, "__annotations__", None) + if base_annotations: + backup_base_annotations[base] = base_annotations + base.__annotations__ = { + name: typ + for name, typ in base_annotations.items() + if name not in base_relationships + } dict_for_pydantic = {} original_annotations = get_annotations(class_dict) pydantic_annotations = {} @@ -539,6 +559,9 @@ def __new__( key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + # Restore base annotations + for base, annotations in backup_base_annotations.items(): + base.__annotations__ = annotations new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, diff --git a/tests/test_relationship_inheritance.py b/tests/test_inherit_relationship.py similarity index 67% rename from tests/test_relationship_inheritance.py rename to tests/test_inherit_relationship.py index 236c36cee7..132c43432e 100644 --- a/tests/test_relationship_inheritance.py +++ b/tests/test_inherit_relationship.py @@ -1,12 +1,14 @@ import datetime from typing import Optional +import pydantic from sqlalchemy import DateTime, func from sqlalchemy.orm import declared_attr, relationship from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel._compat import IS_PYDANTIC_V2 -def test_relationship_inheritance() -> None: +def test_inherit_relationship(clear_sqlmodel) -> None: def now(): return datetime.datetime.now(tz=datetime.timezone.utc) @@ -90,3 +92,57 @@ class Document(CreatedUpdatedMixin, table=True): doc = session.exec(select(Document)).one() assert doc.created_by.name == "Jane" assert doc.updated_by.name == "John" + + +def test_inherit_relationship_model_validate(clear_sqlmodel) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Mixin(SQLModel): + owner_id: Optional[int] = Field(default=None, foreign_key="user.id") + owner: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.owner_id) + ) + ) + + class Asset(Mixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class AssetCreate(pydantic.BaseModel): + pass + + asset_create = AssetCreate() + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + user = User() + + # Owner must be optional + asset = Asset.model_validate(asset_create) + with Session(engine) as session: + session.add(asset) + session.commit() + session.refresh(asset) + assert asset.id is not None + assert asset.owner_id is None + assert asset.owner is None + + # When set, owner must be saved + # + # Under Pydantic V2, relationship fields set it `model_validate` are not saved, + # with or without inheritance. Consider it a known issue. + # + if IS_PYDANTIC_V2: + asset = Asset.model_validate(asset_create, update={"owner": user}) + with Session(engine) as session: + session.add(asset) + session.commit() + session.refresh(asset) + session.refresh(user) + assert asset.id is not None + assert user.id is not None + assert asset.owner_id == user.id + assert asset.owner.id == user.id