Skip to content

Commit 11cc55e

Browse files
committed
Handle special cases with Literal (all int and all bool)
1 parent 303f10f commit 11cc55e

File tree

3 files changed

+49
-14
lines changed

3 files changed

+49
-14
lines changed

sqlmodel/_compat.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,13 @@ def get_sa_type_from_type_annotation(annotation: Any) -> Any:
208208
# Optional unions are allowed
209209
use_type = bases[0] if bases[0] is not NoneType else bases[1]
210210
return get_sa_type_from_type_annotation(use_type)
211+
if origin is Literal:
212+
literal_args = get_args(annotation)
213+
if all(isinstance(arg, bool) for arg in literal_args): # all bools
214+
return bool
215+
if all(isinstance(arg, int) for arg in literal_args): # all ints
216+
return int
217+
return str
211218
return origin
212219

213220
def get_sa_type_from_field(field: Any) -> Any:
@@ -460,7 +467,13 @@ def is_field_noneable(field: "FieldInfo") -> bool:
460467

461468
def get_sa_type_from_field(field: Any) -> Any:
462469
if get_origin(field.type_) is Literal:
463-
return Literal
470+
literal_args = get_args(field.type_)
471+
if all(isinstance(arg, bool) for arg in literal_args): # all bools
472+
return bool
473+
if all(isinstance(arg, int) for arg in literal_args): # all ints
474+
return int
475+
return str
476+
464477
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
465478
return field.type_
466479
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")

sqlmodel/main.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -655,9 +655,6 @@ def get_sqlalchemy_type(field: Any) -> Any:
655655
type_ = get_sa_type_from_field(field)
656656
metadata = get_field_metadata(field)
657657

658-
# Checks for `Literal` type annotation
659-
if type_ is Literal:
660-
return AutoString
661658
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
662659
if issubclass(type_, Enum):
663660
return sa_Enum(type_)

tests/test_main.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,23 +128,48 @@ class Hero(SQLModel, table=True):
128128
assert hero_rusty_man.team.name == "Preventers"
129129

130130

131-
def test_literal_typehints_are_treated_as_strings(clear_sqlmodel):
131+
def test_literal_str(clear_sqlmodel, caplog):
132132
"""Test https://github.com/fastapi/sqlmodel/issues/57"""
133133

134-
class Hero(SQLModel, table=True):
134+
class Model(SQLModel, table=True):
135135
id: Optional[int] = Field(default=None, primary_key=True)
136-
name: str = Field(unique=True)
137-
weakness: Literal["Kryptonite", "Dehydration", "Munchies"]
138-
139-
superguy = Hero(name="Superguy", weakness="Kryptonite")
136+
all_str: Literal["a", "b", "c"]
137+
mixed: Literal["yes", "no", 1, 0]
138+
all_int: Literal[1, 2, 3]
139+
int_bool: Literal[0, 1, True, False]
140+
all_bool: Literal[True, False]
141+
142+
obj = Model(
143+
all_str="a",
144+
mixed="yes",
145+
all_int=1,
146+
int_bool=True,
147+
all_bool=False,
148+
)
140149

141150
engine = create_engine("sqlite://", echo=True)
142151

143152
SQLModel.metadata.create_all(engine)
144153

154+
# Check DDL
155+
assert "all_str VARCHAR NOT NULL" in caplog.text
156+
assert "mixed VARCHAR NOT NULL" in caplog.text
157+
assert "all_int INTEGER NOT NULL" in caplog.text
158+
assert "int_bool INTEGER NOT NULL" in caplog.text
159+
assert "all_bool BOOLEAN NOT NULL" in caplog.text
160+
161+
# Check query
145162
with Session(engine) as session:
146-
session.add(superguy)
163+
session.add(obj)
147164
session.commit()
148-
session.refresh(superguy)
149-
assert superguy.weakness == "Kryptonite"
150-
assert isinstance(superguy.weakness, str)
165+
session.refresh(obj)
166+
assert isinstance(obj.all_str, str)
167+
assert obj.all_str == "a"
168+
assert isinstance(obj.mixed, str)
169+
assert obj.mixed == "yes"
170+
assert isinstance(obj.all_int, int)
171+
assert obj.all_int == 1
172+
assert isinstance(obj.int_bool, int)
173+
assert obj.int_bool == 1
174+
assert isinstance(obj.all_bool, bool)
175+
assert obj.all_bool is False

0 commit comments

Comments
 (0)