Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
cast,
overload,
)
from typing import Annotated as TypingAnnotated

from pydantic import BaseModel, EmailStr
from pydantic.fields import FieldInfo as PydanticFieldInfo
Expand Down Expand Up @@ -54,7 +55,18 @@
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import Literal, TypeAlias, deprecated, get_origin
from typing_extensions import (
Annotated as TEAnnotated,
)
from typing_extensions import (
Literal,
TypeAlias,
deprecated,
get_origin,
)
from typing_extensions import (
get_args as te_get_args,
)

from ._compat import ( # type: ignore[attr-defined]
IS_PYDANTIC_V2,
Expand Down Expand Up @@ -546,6 +558,26 @@ def __new__(
**new_cls.__annotations__,
}

# For Pydantic v2: If a field used Annotated[..., Field(sa_column=Column(...))]
# Pydantic might not lift our custom attribute onto the final FieldInfo.
# Recover it from the original annotations before creating SQLAlchemy Columns.
if IS_PYDANTIC_V2:
for field_name, ann in original_annotations.items():
try:
origin = get_origin(ann)
if origin in (TEAnnotated, TypingAnnotated):
for extra in te_get_args(ann)[1:]:
sa_col = getattr(extra, "sa_column", Undefined)
if isinstance(sa_col, Column):
# Attach found Column to the Pydantic field info
model_fields = get_model_fields(new_cls)
if field_name in model_fields:
model_fields[field_name].sa_column = sa_col
break
except Exception:
# Best-effort; fall back to default behavior
pass

def get_config(name: str) -> Any:
config_class_value = get_config_value(
model=new_cls, parameter=name, default=Undefined
Expand All @@ -562,6 +594,26 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
# Prefer a Column passed via Annotated[..., Field(sa_column=...)]
if IS_PYDANTIC_V2:
ann = original_annotations.get(k, None)
if ann is not None:
try:
origin = get_origin(ann)
if origin in (TEAnnotated, TypingAnnotated):
for extra in te_get_args(ann)[1:]:
sa_col = getattr(extra, "sa_column", Undefined)
if isinstance(sa_col, Column):
setattr(new_cls, k, sa_col)
break
else:
# no Column override found, build normally
col = get_column_from_field(v)
setattr(new_cls, k, col)
continue
except Exception:
# Fall back to normal column building
pass
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
Expand Down Expand Up @@ -709,6 +761,27 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
else:
field_info = field.field_info
sa_column = getattr(field_info, "sa_column", Undefined)
# In Pydantic v2, when using Annotated[T, Field(...)], the Field(...) object
# is stored in the field's metadata and some custom attributes (like
# sa_column) might not be lifted onto the main FieldInfo. Inspect metadata
# to honor a Column passed via Annotated Field(...).
if IS_PYDANTIC_V2 and not isinstance(sa_column, Column):
# Try to recover a Column passed via Annotated[..., Field(sa_column=...)]
raw_ann = getattr(field, "annotation", None)
origin = get_origin(raw_ann)
if origin in (TEAnnotated, TypingAnnotated):
for extra in te_get_args(raw_ann)[1:]:
meta_sa_column = getattr(extra, "sa_column", Undefined)
if isinstance(meta_sa_column, Column):
sa_column = meta_sa_column
break
# Also check field metadata in case custom FieldInfo leaked through
if not isinstance(sa_column, Column):
for meta in getattr(field, "metadata", ()):
meta_sa_column = getattr(meta, "sa_column", Undefined)
if isinstance(meta_sa_column, Column):
sa_column = meta_sa_column
break
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlalchemy_type(field)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_field_sa_column.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from datetime import datetime
from typing import Optional

import pytest
from sqlalchemy import Column, Integer, String
from sqlalchemy import Column, DateTime, Integer, String
from sqlmodel import Field, SQLModel
from typing_extensions import Annotated

from tests.conftest import needs_pydanticv2


def test_sa_column_takes_precedence() -> None:
Expand Down Expand Up @@ -119,3 +123,16 @@ class Item(SQLModel, table=True):
sa_column=Column(Integer, primary_key=True),
ondelete="CASCADE",
)


@needs_pydanticv2
def test_sa_column_in_annotated_is_respected() -> None:
class Item(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
available_at: Annotated[
datetime, Field(sa_column=Column(DateTime(timezone=True)))
]

# Should reflect timezone=True from the provided Column
assert isinstance(Item.available_at.type, DateTime) # type: ignore[attr-defined]
assert Item.available_at.type.timezone is True # type: ignore[attr-defined]
Loading