|
51 | 51 | from sqlalchemy.orm.instrumentation import is_instrumented |
52 | 52 | from sqlalchemy.sql.schema import MetaData |
53 | 53 | from sqlalchemy.sql.sqltypes import LargeBinary, Time |
54 | | -from typing_extensions import Literal, deprecated, get_origin |
| 54 | +from typing_extensions import Literal, _AnnotatedAlias, deprecated, get_origin |
55 | 55 |
|
56 | 56 | from ._compat import ( # type: ignore[attr-defined] |
57 | 57 | IS_PYDANTIC_V2, |
@@ -561,48 +561,59 @@ def get_sqlalchemy_type(field: Any) -> Any: |
561 | 561 | return sa_type |
562 | 562 |
|
563 | 563 | type_ = get_type_from_field(field) |
564 | | - metadata = get_field_metadata(field) |
| 564 | + if isinstance(type_, _AnnotatedAlias): |
| 565 | + class_to_compare = type_.__origin__ |
| 566 | + if len(type_.__metadata__) == 1: |
| 567 | + metadata = get_field_metadata(type_.__metadata__[0]) |
| 568 | + else: |
| 569 | + # not sure if this is the right behavior |
| 570 | + raise ValueError( |
| 571 | + f"AnnotatedAlias with multiple metadata is not supported: {type_}" |
| 572 | + ) |
| 573 | + else: |
| 574 | + class_to_compare = type_ |
| 575 | + metadata = get_field_metadata(field) |
565 | 576 |
|
566 | 577 | # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI |
567 | | - if issubclass(type_, Enum): |
| 578 | + if issubclass(class_to_compare, Enum): |
568 | 579 | return sa_Enum(type_) |
569 | | - if issubclass(type_, str): |
| 580 | + if issubclass(class_to_compare, str): |
570 | 581 | max_length = getattr(metadata, "max_length", None) |
571 | 582 | if max_length: |
572 | 583 | return AutoString(length=max_length) |
573 | 584 | return AutoString |
574 | | - if issubclass(type_, float): |
| 585 | + if issubclass(class_to_compare, float): |
575 | 586 | return Float |
576 | | - if issubclass(type_, bool): |
| 587 | + if issubclass(class_to_compare, bool): |
577 | 588 | return Boolean |
578 | | - if issubclass(type_, int): |
| 589 | + if issubclass(class_to_compare, int): |
579 | 590 | return Integer |
580 | | - if issubclass(type_, datetime): |
| 591 | + if issubclass(class_to_compare, datetime): |
581 | 592 | return DateTime |
582 | | - if issubclass(type_, date): |
| 593 | + if issubclass(class_to_compare, date): |
583 | 594 | return Date |
584 | | - if issubclass(type_, timedelta): |
| 595 | + if issubclass(class_to_compare, timedelta): |
585 | 596 | return Interval |
586 | | - if issubclass(type_, time): |
| 597 | + if issubclass(class_to_compare, time): |
587 | 598 | return Time |
588 | | - if issubclass(type_, bytes): |
| 599 | + if issubclass(class_to_compare, bytes): |
589 | 600 | return LargeBinary |
590 | | - if issubclass(type_, Decimal): |
| 601 | + if issubclass(class_to_compare, Decimal): |
591 | 602 | return Numeric( |
592 | 603 | precision=getattr(metadata, "max_digits", None), |
593 | 604 | scale=getattr(metadata, "decimal_places", None), |
594 | 605 | ) |
595 | | - if issubclass(type_, ipaddress.IPv4Address): |
| 606 | + if issubclass(class_to_compare, ipaddress.IPv4Address): |
596 | 607 | return AutoString |
597 | | - if issubclass(type_, ipaddress.IPv4Network): |
| 608 | + if issubclass(class_to_compare, ipaddress.IPv4Network): |
598 | 609 | return AutoString |
599 | | - if issubclass(type_, ipaddress.IPv6Address): |
| 610 | + if issubclass(class_to_compare, ipaddress.IPv6Address): |
600 | 611 | return AutoString |
601 | | - if issubclass(type_, ipaddress.IPv6Network): |
| 612 | + if issubclass(class_to_compare, ipaddress.IPv6Network): |
602 | 613 | return AutoString |
603 | | - if issubclass(type_, Path): |
| 614 | + if issubclass(class_to_compare, Path): |
604 | 615 | return AutoString |
605 | | - if issubclass(type_, uuid.UUID): |
| 616 | + if issubclass(class_to_compare, uuid.UUID): |
606 | 617 | return GUID |
607 | 618 | raise ValueError(f"{type_} has no matching SQLAlchemy type") |
608 | 619 |
|
|
0 commit comments