Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def should_column_be_set(cls, column: Any) -> bool:
if not cls.should_dataclass_init_field(column.name):
return False

if column.computed and (cls.__session__ is not None or cls.__async_session__ is not None):
return False

return bool(cls.__set_foreign_keys__ or not column.foreign_keys)

@classmethod
Expand Down
10 changes: 9 additions & 1 deletion tests/sqlalchemy_factory/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Optional

from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func, orm, text
from sqlalchemy import Boolean, Column, Computed, DateTime, ForeignKey, Integer, String, func, orm, text
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry
Expand Down Expand Up @@ -146,3 +146,11 @@ class Employee(Base):
name = Column(String)
company_id = Column(Integer, ForeignKey("companies.id"))
company = relationship(Company, back_populates="employees")


class Shape(Base):
__tablename__ = "shape"

id = Column(Integer, primary_key=True)
side: Any = Column(Integer(), nullable=False, default=10)
area: Any = Column(Integer, Computed("side * side"), nullable=False)
29 changes: 29 additions & 0 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
CollectionChildMixin,
CollectionParentMixin,
NonSQLAchemyClass,
Shape,
_registry,
)
from tests.sqlalchemy_factory.types import ListLike, SetLike
Expand Down Expand Up @@ -133,6 +134,34 @@ class ModelFactory(SQLAlchemyFactory[Model]): ...
assert instance.age * 3 == instance.triple_age


def test_computed_column_sync_persistence(engine: Engine) -> None:
Base.metadata.create_all(engine)

class ShapeFactory(SQLAlchemyFactory[Shape]):
__model__ = Shape
__session__ = Session(engine)

instance = ShapeFactory.create_sync()
assert instance.area == pow(instance.side, 2)


async def test_computed_column_async_persistence(engine: Engine, async_engine: AsyncEngine) -> None:
class ShapeFactory(SQLAlchemyFactory[Shape]):
__model__ = Shape
__async_session__ = AsyncSession(async_engine)

instance = await ShapeFactory.create_async()
assert instance.area == pow(instance.side, 2)


def test_computed_column_no_persistence() -> None:
class ShapeFactory(SQLAlchemyFactory[Shape]):
__model__ = Shape

fields = ShapeFactory.get_model_fields()
assert "area" in [field.name for field in fields]


@pytest.mark.parametrize(
"type_",
tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()),
Expand Down
Loading