diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..540c2ed05c 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -26,6 +26,7 @@ cast, overload, ) +from typing import Annotated as TypingAnnotated from pydantic import BaseModel, EmailStr from pydantic.fields import FieldInfo as PydanticFieldInfo @@ -54,7 +55,18 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import ( + Annotated as TEAnnotated, +) +from typing_extensions import ( + Literal, + TypeAlias, + deprecated, + get_origin, +) +from typing_extensions import ( + get_args as te_get_args, +) from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -546,6 +558,26 @@ def __new__( **new_cls.__annotations__, } + # For Pydantic v2: If a field used Annotated[..., Field(sa_column=Column(...))] + # Pydantic might not lift our custom attribute onto the final FieldInfo. + # Recover it from the original annotations before creating SQLAlchemy Columns. + if IS_PYDANTIC_V2: + for field_name, ann in original_annotations.items(): + try: + origin = get_origin(ann) + if origin in (TEAnnotated, TypingAnnotated): + for extra in te_get_args(ann)[1:]: + sa_col = getattr(extra, "sa_column", Undefined) + if isinstance(sa_col, Column): + # Attach found Column to the Pydantic field info + model_fields = get_model_fields(new_cls) + if field_name in model_fields: + model_fields[field_name].sa_column = sa_col + break + except Exception: + # Best-effort; fall back to default behavior + pass + def get_config(name: str) -> Any: config_class_value = get_config_value( model=new_cls, parameter=name, default=Undefined @@ -562,6 +594,26 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): + # Prefer a Column passed via Annotated[..., Field(sa_column=...)] + if IS_PYDANTIC_V2: + ann = original_annotations.get(k, None) + if ann is not None: + try: + origin = get_origin(ann) + if origin in (TEAnnotated, TypingAnnotated): + for extra in te_get_args(ann)[1:]: + sa_col = getattr(extra, "sa_column", Undefined) + if isinstance(sa_col, Column): + setattr(new_cls, k, sa_col) + break + else: + # no Column override found, build normally + col = get_column_from_field(v) + setattr(new_cls, k, col) + continue + except Exception: + # Fall back to normal column building + pass col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -709,6 +761,27 @@ def get_column_from_field(field: Any) -> Column: # type: ignore else: field_info = field.field_info sa_column = getattr(field_info, "sa_column", Undefined) + # In Pydantic v2, when using Annotated[T, Field(...)], the Field(...) object + # is stored in the field's metadata and some custom attributes (like + # sa_column) might not be lifted onto the main FieldInfo. Inspect metadata + # to honor a Column passed via Annotated Field(...). + if IS_PYDANTIC_V2 and not isinstance(sa_column, Column): + # Try to recover a Column passed via Annotated[..., Field(sa_column=...)] + raw_ann = getattr(field, "annotation", None) + origin = get_origin(raw_ann) + if origin in (TEAnnotated, TypingAnnotated): + for extra in te_get_args(raw_ann)[1:]: + meta_sa_column = getattr(extra, "sa_column", Undefined) + if isinstance(meta_sa_column, Column): + sa_column = meta_sa_column + break + # Also check field metadata in case custom FieldInfo leaked through + if not isinstance(sa_column, Column): + for meta in getattr(field, "metadata", ()): + meta_sa_column = getattr(meta, "sa_column", Undefined) + if isinstance(meta_sa_column, Column): + sa_column = meta_sa_column + break if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index e2ccc6d7ef..7e739155dd 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -1,8 +1,12 @@ +from datetime import datetime from typing import Optional import pytest -from sqlalchemy import Column, Integer, String +from sqlalchemy import Column, DateTime, Integer, String from sqlmodel import Field, SQLModel +from typing_extensions import Annotated + +from tests.conftest import needs_pydanticv2 def test_sa_column_takes_precedence() -> None: @@ -119,3 +123,16 @@ class Item(SQLModel, table=True): sa_column=Column(Integer, primary_key=True), ondelete="CASCADE", ) + + +@needs_pydanticv2 +def test_sa_column_in_annotated_is_respected() -> None: + class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + available_at: Annotated[ + datetime, Field(sa_column=Column(DateTime(timezone=True))) + ] + + # Should reflect timezone=True from the provided Column + assert isinstance(Item.available_at.type, DateTime) # type: ignore[attr-defined] + assert Item.available_at.type.timezone is True # type: ignore[attr-defined]