Skip to content

Commit 252c047

Browse files
committed
✨Add foreign_key_args and foreign_key_kwargs arguments to Field(...) to let the user define additional sqlalchemy.orm.ForeignKey attributes, such as ondelete and onupdate, for foreign keys defined in a base model.
1 parent c75743d commit 252c047

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

sqlmodel/main.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
106106
sa_column = kwargs.pop("sa_column", Undefined)
107107
sa_column_args = kwargs.pop("sa_column_args", Undefined)
108108
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
109+
sa_foreign_key_args = kwargs.pop("sa_foreign_key_args", Undefined)
110+
sa_foreign_key_kwargs = kwargs.pop("sa_foreign_key_kwargs", Undefined)
109111
if sa_column is not Undefined:
110112
if sa_column_args is not Undefined:
111113
raise RuntimeError(
@@ -153,6 +155,8 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
153155
self.sa_column = sa_column
154156
self.sa_column_args = sa_column_args
155157
self.sa_column_kwargs = sa_column_kwargs
158+
self.sa_foreign_key_args = sa_foreign_key_args
159+
self.sa_foreign_key_kwargs = sa_foreign_key_kwargs
156160

157161

158162
class RelationshipInfo(Representation):
@@ -222,6 +226,8 @@ def Field(
222226
sa_type: Union[Type[Any], UndefinedType] = Undefined,
223227
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
224228
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
229+
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
230+
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
225231
schema_extra: Optional[Dict[str, Any]] = None,
226232
) -> Any:
227233
...
@@ -303,6 +309,8 @@ def Field(
303309
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
304310
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
305311
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
312+
sa_foreign_key_args: Union[Sequence[Any], UndefinedType] = Undefined,
313+
sa_foreign_key_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
306314
schema_extra: Optional[Dict[str, Any]] = None,
307315
) -> Any:
308316
current_schema_extra = schema_extra or {}
@@ -340,6 +348,8 @@ def Field(
340348
sa_column=sa_column,
341349
sa_column_args=sa_column_args,
342350
sa_column_kwargs=sa_column_kwargs,
351+
sa_foreign_key_args=sa_foreign_key_args,
352+
sa_foreign_key_kwargs=sa_foreign_key_kwargs,
343353
**current_schema_extra,
344354
)
345355
post_init_field_info(field_info)
@@ -638,7 +648,19 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
638648
unique = False
639649
if foreign_key:
640650
assert isinstance(foreign_key, str)
641-
args.append(ForeignKey(foreign_key))
651+
sa_foreign_key_args = getattr(field_info, "sa_foreign_key_args", Undefined)
652+
fk_args = (
653+
[]
654+
if sa_foreign_key_args is Undefined
655+
else list(cast(Sequence[Any], sa_foreign_key_args))
656+
)
657+
sa_foreign_key_kwargs = getattr(field_info, "sa_foreign_key_kwargs", Undefined)
658+
fk_kwargs = (
659+
{}
660+
if sa_foreign_key_kwargs is Undefined
661+
else cast(Dict[Any, Any], sa_foreign_key_kwargs)
662+
)
663+
args.append(ForeignKey(foreign_key, *fk_args, **fk_kwargs))
642664
kwargs = {
643665
"primary_key": primary_key,
644666
"nullable": nullable,

tests/test_foreign_key_args.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Optional
2+
3+
import pytest
4+
import sqlalchemy.event
5+
import sqlalchemy.exc
6+
from sqlalchemy import ForeignKey, create_engine, func
7+
from sqlmodel import Field, SQLModel, select
8+
from sqlmodel.orm.session import Session
9+
10+
11+
def test_fk_constructed_in_base_model_fails(clear_sqlmodel) -> None:
12+
class User(SQLModel, table=True):
13+
id: Optional[int] = Field(default=None, primary_key=True)
14+
15+
class Base(SQLModel):
16+
owner_id: Optional[int] = Field(
17+
default=None, sa_column_args=(ForeignKey("user.id", ondelete="SET NULL"),)
18+
)
19+
20+
class Asset(Base, table=True):
21+
id: Optional[int] = Field(default=None, primary_key=True)
22+
23+
with pytest.raises(sqlalchemy.exc.InvalidRequestError) as e:
24+
25+
class Document(Base, table=True):
26+
id: Optional[int] = Field(default=None, primary_key=True)
27+
28+
assert "This ForeignKey already has a parent" in str(e.errisinstance)
29+
30+
31+
def test_fk_args_in_base_model_work(clear_sqlmodel) -> None:
32+
class User(SQLModel, table=True):
33+
id: Optional[int] = Field(default=None, primary_key=True)
34+
35+
class Base(SQLModel):
36+
owner_id: Optional[int] = Field(
37+
default=None,
38+
foreign_key="user.id",
39+
sa_foreign_key_kwargs={"ondelete": "SET NULL"},
40+
)
41+
42+
class Asset(Base, table=True):
43+
id: Optional[int] = Field(default=None, primary_key=True)
44+
45+
class Document(Base, table=True):
46+
id: Optional[int] = Field(default=None, primary_key=True)
47+
48+
engine = create_engine("sqlite://")
49+
sqlalchemy.event.listen(
50+
engine, "connect", lambda conn, *args: conn.execute("pragma foreign_keys=ON")
51+
)
52+
53+
SQLModel.metadata.create_all(engine)
54+
55+
# Test that the ON DELETE SET NULL we assigned actually works
56+
with Session(engine) as session:
57+
user = User()
58+
session.add(user)
59+
session.commit()
60+
session.refresh(user)
61+
62+
asset = Asset(owner_id=user.id)
63+
session.add(asset)
64+
session.commit()
65+
session.refresh(asset)
66+
assert asset.owner_id == user.id
67+
68+
session.delete(user)
69+
session.commit()
70+
assert session.scalar(select(func.count()).select_from(User)) == 0
71+
72+
# Normally, one would also define a relationship (in the Asset class, `owner: Optional[User] = Relationship("User")`)
73+
# so that SQLAlchemy knows that Asset and User are related, marks the Asset as dirty and refreshes it when requested.
74+
# But Relationships are a separate complicated topic, which we don't want to touch here.
75+
asset = session.exec(select(Asset)).one()
76+
assert asset.owner_id is None

0 commit comments

Comments
 (0)