Skip to content

Commit bef1602

Browse files
committed
support passing fields as annotated types
1 parent 9d0b8b6 commit bef1602

File tree

1 file changed

+30
-19
lines changed

1 file changed

+30
-19
lines changed

sqlmodel/main.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
from sqlalchemy.orm.instrumentation import is_instrumented
5252
from sqlalchemy.sql.schema import MetaData
5353
from sqlalchemy.sql.sqltypes import LargeBinary, Time
54-
from typing_extensions import Literal, deprecated, get_origin
54+
from typing_extensions import Literal, _AnnotatedAlias, deprecated, get_origin
5555

5656
from ._compat import ( # type: ignore[attr-defined]
5757
IS_PYDANTIC_V2,
@@ -561,48 +561,59 @@ def get_sqlalchemy_type(field: Any) -> Any:
561561
return sa_type
562562

563563
type_ = get_type_from_field(field)
564-
metadata = get_field_metadata(field)
564+
if isinstance(type_, _AnnotatedAlias):
565+
class_to_compare = type_.__origin__
566+
if len(type_.__metadata__) == 1:
567+
metadata = get_field_metadata(type_.__metadata__[0])
568+
else:
569+
# not sure if this is the right behavior
570+
raise ValueError(
571+
f"AnnotatedAlias with multiple metadata is not supported: {type_}"
572+
)
573+
else:
574+
class_to_compare = type_
575+
metadata = get_field_metadata(field)
565576

566577
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
567-
if issubclass(type_, Enum):
578+
if issubclass(class_to_compare, Enum):
568579
return sa_Enum(type_)
569-
if issubclass(type_, str):
580+
if issubclass(class_to_compare, str):
570581
max_length = getattr(metadata, "max_length", None)
571582
if max_length:
572583
return AutoString(length=max_length)
573584
return AutoString
574-
if issubclass(type_, float):
585+
if issubclass(class_to_compare, float):
575586
return Float
576-
if issubclass(type_, bool):
587+
if issubclass(class_to_compare, bool):
577588
return Boolean
578-
if issubclass(type_, int):
589+
if issubclass(class_to_compare, int):
579590
return Integer
580-
if issubclass(type_, datetime):
591+
if issubclass(class_to_compare, datetime):
581592
return DateTime
582-
if issubclass(type_, date):
593+
if issubclass(class_to_compare, date):
583594
return Date
584-
if issubclass(type_, timedelta):
595+
if issubclass(class_to_compare, timedelta):
585596
return Interval
586-
if issubclass(type_, time):
597+
if issubclass(class_to_compare, time):
587598
return Time
588-
if issubclass(type_, bytes):
599+
if issubclass(class_to_compare, bytes):
589600
return LargeBinary
590-
if issubclass(type_, Decimal):
601+
if issubclass(class_to_compare, Decimal):
591602
return Numeric(
592603
precision=getattr(metadata, "max_digits", None),
593604
scale=getattr(metadata, "decimal_places", None),
594605
)
595-
if issubclass(type_, ipaddress.IPv4Address):
606+
if issubclass(class_to_compare, ipaddress.IPv4Address):
596607
return AutoString
597-
if issubclass(type_, ipaddress.IPv4Network):
608+
if issubclass(class_to_compare, ipaddress.IPv4Network):
598609
return AutoString
599-
if issubclass(type_, ipaddress.IPv6Address):
610+
if issubclass(class_to_compare, ipaddress.IPv6Address):
600611
return AutoString
601-
if issubclass(type_, ipaddress.IPv6Network):
612+
if issubclass(class_to_compare, ipaddress.IPv6Network):
602613
return AutoString
603-
if issubclass(type_, Path):
614+
if issubclass(class_to_compare, Path):
604615
return AutoString
605-
if issubclass(type_, uuid.UUID):
616+
if issubclass(class_to_compare, uuid.UUID):
606617
return GUID
607618
raise ValueError(f"{type_} has no matching SQLAlchemy type")
608619

0 commit comments

Comments
 (0)