Skip to content

Commit d41e010

Browse files
Evgeny Arshinovearshinov
authored andcommitted
✏️Fix model_validate in presence of inherited Relationship fields, add unit test
1 parent 0773e1f commit d41e010

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

sqlmodel/main.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,28 @@ def __new__(
415415
**kwargs: Any,
416416
) -> Any:
417417
relationships: Dict[str, RelationshipInfo] = {}
418+
backup_base_annotations: Dict[Type[Any], Dict[str, Any]] = {}
418419
for base in bases:
419-
relationships.update(getattr(base, "__sqlmodel_relationships__", {}))
420+
base_relationships = getattr(base, "__sqlmodel_relationships__", None)
421+
if base_relationships:
422+
relationships.update(base_relationships)
423+
#
424+
# Temporarily pluck out `__annotations__` corresponding to relationships from base classes, otherwise these annotations
425+
# make their way into `cls.model_fields` as `FieldInfo(..., required=True)`, even when the relationships are declared
426+
# optional. When a model instance is then constructed using `model_validate` and an optional relationship field is not
427+
# passed, this leads to an incorrect `pydantic.ValidationError`.
428+
#
429+
# We can't just clean up `new_cls.model_fields` after `new_cls` is constructed because by this time
430+
# Pydantic has created model schema and validation rules, so this won't fix the problem.
431+
#
432+
base_annotations = getattr(base, "__annotations__", None)
433+
if base_annotations:
434+
backup_base_annotations[base] = base_annotations
435+
base.__annotations__ = {
436+
name: typ
437+
for name, typ in base_annotations.items()
438+
if name not in base_relationships
439+
}
420440
dict_for_pydantic = {}
421441
original_annotations = get_annotations(class_dict)
422442
pydantic_annotations = {}
@@ -451,6 +471,9 @@ def __new__(
451471
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
452472
}
453473
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
474+
# Restore base annotations
475+
for base, annotations in backup_base_annotations.items():
476+
base.__annotations__ = annotations
454477
new_cls.__annotations__ = {
455478
**relationship_annotations,
456479
**pydantic_annotations,
Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import datetime
22
from typing import Optional
33

4+
import pydantic
45
from sqlalchemy import DateTime, func
56
from sqlalchemy.orm import declared_attr, relationship
67
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
78

89

9-
def test_relationship_inheritance() -> None:
10+
def test_inherit_relationship(clear_sqlmodel) -> None:
1011
def now():
1112
return datetime.datetime.now(tz=datetime.timezone.utc)
1213

@@ -90,3 +91,50 @@ class Document(CreatedUpdatedMixin, table=True):
9091
doc = session.exec(select(Document)).one()
9192
assert doc.created_by.name == "Jane"
9293
assert doc.updated_by.name == "John"
94+
95+
96+
def test_inherit_relationship_model_validate(clear_sqlmodel) -> None:
97+
class User(SQLModel, table=True):
98+
id: Optional[int] = Field(default=None, primary_key=True)
99+
100+
class Mixin(SQLModel):
101+
owner_id: Optional[int] = Field(default=None, foreign_key="user.id")
102+
owner: Optional[User] = Relationship(
103+
sa_relationship=declared_attr(
104+
lambda cls: relationship(User, foreign_keys=cls.owner_id)
105+
)
106+
)
107+
108+
class Asset(Mixin, table=True):
109+
id: Optional[int] = Field(default=None, primary_key=True)
110+
111+
class AssetCreate(pydantic.BaseModel):
112+
pass
113+
114+
asset_create = AssetCreate()
115+
116+
engine = create_engine("sqlite://")
117+
118+
SQLModel.metadata.create_all(engine)
119+
120+
user = User()
121+
122+
asset = Asset.model_validate(asset_create)
123+
with Session(engine) as session:
124+
session.add(asset)
125+
session.commit()
126+
session.refresh(asset)
127+
assert asset.id is not None
128+
assert asset.owner_id is None
129+
assert asset.owner is None
130+
131+
asset = Asset.model_validate(asset_create, update={"owner": user})
132+
with Session(engine) as session:
133+
session.add(asset)
134+
session.commit()
135+
session.refresh(asset)
136+
session.refresh(user)
137+
assert asset.id is not None
138+
assert user.id is not None
139+
assert asset.owner_id == user.id
140+
assert asset.owner.id == user.id

0 commit comments

Comments
 (0)