diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..1b3cc3030f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -503,6 +503,28 @@ def __new__( **kwargs: Any, ) -> Any: relationships: Dict[str, RelationshipInfo] = {} + backup_base_annotations: Dict[Type[Any], Dict[str, Any]] = {} + for base in bases: + base_relationships = getattr(base, "__sqlmodel_relationships__", None) + if base_relationships: + relationships.update(base_relationships) + # + # Temporarily pluck out `__annotations__` corresponding to relationships from base classes, otherwise these annotations + # make their way into `cls.model_fields` as `FieldInfo(..., required=True)`, even when the relationships are declared + # optional. When a model instance is then constructed using `model_validate` and an optional relationship field is not + # passed, this leads to an incorrect `pydantic.ValidationError`. + # + # We can't just clean up `new_cls.model_fields` after `new_cls` is constructed because by this time + # Pydantic has created model schema and validation rules, so this won't fix the problem. + # + base_annotations = getattr(base, "__annotations__", None) + if base_annotations: + backup_base_annotations[base] = base_annotations + base.__annotations__ = { + name: typ + for name, typ in base_annotations.items() + if name not in base_relationships + } dict_for_pydantic = {} original_annotations = get_annotations(class_dict) pydantic_annotations = {} @@ -537,6 +559,9 @@ def __new__( key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + # Restore base annotations + for base, annotations in backup_base_annotations.items(): + base.__annotations__ = annotations new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -559,8 +584,9 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): - col = get_column_from_field(v) - setattr(new_cls, k, col) + if k not in relationships: + col = get_column_from_field(v) + setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. # This could be done by reading new_cls.model_config['table'] in FastAPI, but diff --git a/tests/test_inherit_relationship.py b/tests/test_inherit_relationship.py new file mode 100644 index 0000000000..132c43432e --- /dev/null +++ b/tests/test_inherit_relationship.py @@ -0,0 +1,148 @@ +import datetime +from typing import Optional + +import pydantic +from sqlalchemy import DateTime, func +from sqlalchemy.orm import declared_attr, relationship +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel._compat import IS_PYDANTIC_V2 + + +def test_inherit_relationship(clear_sqlmodel) -> None: + def now(): + return datetime.datetime.now(tz=datetime.timezone.utc) + + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + class CreatedUpdatedMixin(SQLModel): + # Fields in reusable base models must be defined using `sa_type` and `sa_column_kwargs` instead of `sa_column` + # https://github.com/tiangolo/sqlmodel/discussions/743 + # + # created_at: datetime.datetime = Field(default_factory=now, sa_column=DateTime(default=now)) + created_at: datetime.datetime = Field( + default_factory=now, sa_type=DateTime, sa_column_kwargs={"default": now} + ) + + # With Pydantic V2, it is also possible to define `created_by` like this: + # + # ```python + # @declared_attr + # def _created_by(cls): + # return relationship(User, foreign_keys=cls.created_by_id) + # + # created_by: Optional[User] = Relationship(sa_relationship=_created_by)) + # ``` + # + # The difference from Pydantic V1 is that Pydantic V2 plucks attributes with names starting with '_' (but not '__') + # from class attributes and stores them separately as instances of `pydantic.ModelPrivateAttr` somewhere in depths of + # Pydantic internals. Under Pydantic V1 this doesn't happen, so SQLAlchemy ends up having two class attributes + # (`_created_by` and `created_by`) corresponding to one database attribute, causing a conflict and unreliable behavior. + # The approach with a lambda always works because it doesn't produce the second class attribute and thus eliminates + # the possibility of a conflict entirely. + # + created_by_id: Optional[int] = Field(default=None, foreign_key="user.id") + created_by: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.created_by_id) + ) + ) + + updated_at: datetime.datetime = Field( + default_factory=now, sa_type=DateTime, sa_column_kwargs={"default": now} + ) + updated_by_id: Optional[int] = Field(default=None, foreign_key="user.id") + updated_by: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.updated_by_id) + ) + ) + + class Asset(CreatedUpdatedMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + # Demonstrate that the mixin can be applied to more than 1 model + class Document(CreatedUpdatedMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + john = User(name="John") + jane = User(name="Jane") + asset = Asset(name="Test", created_by=john, updated_by=jane) + doc = Document(name="Resume", created_by=jane, updated_by=john) + + with Session(engine) as session: + session.add(asset) + session.add(doc) + session.commit() + + with Session(engine) as session: + assert session.scalar(select(func.count()).select_from(User)) == 2 + + asset = session.exec(select(Asset)).one() + assert asset.created_by.name == "John" + assert asset.updated_by.name == "Jane" + + doc = session.exec(select(Document)).one() + assert doc.created_by.name == "Jane" + assert doc.updated_by.name == "John" + + +def test_inherit_relationship_model_validate(clear_sqlmodel) -> None: + class User(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class Mixin(SQLModel): + owner_id: Optional[int] = Field(default=None, foreign_key="user.id") + owner: Optional[User] = Relationship( + sa_relationship=declared_attr( + lambda cls: relationship(User, foreign_keys=cls.owner_id) + ) + ) + + class Asset(Mixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + class AssetCreate(pydantic.BaseModel): + pass + + asset_create = AssetCreate() + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + user = User() + + # Owner must be optional + asset = Asset.model_validate(asset_create) + with Session(engine) as session: + session.add(asset) + session.commit() + session.refresh(asset) + assert asset.id is not None + assert asset.owner_id is None + assert asset.owner is None + + # When set, owner must be saved + # + # Under Pydantic V2, relationship fields set it `model_validate` are not saved, + # with or without inheritance. Consider it a known issue. + # + if IS_PYDANTIC_V2: + asset = Asset.model_validate(asset_create, update={"owner": user}) + with Session(engine) as session: + session.add(asset) + session.commit() + session.refresh(asset) + session.refresh(user) + assert asset.id is not None + assert user.id is not None + assert asset.owner_id == user.id + assert asset.owner.id == user.id