|
1 | 1 | import warnings |
| 2 | +from collections.abc import Collection |
2 | 3 | from datetime import datetime |
3 | 4 | from decimal import Decimal |
4 | 5 | from enum import Enum |
5 | 6 | from typing import Any, Callable, get_args |
6 | | -from uuid import UUID |
| 7 | +from uuid import UUID, uuid4 |
7 | 8 |
|
8 | 9 | import pytest |
9 | 10 | from sqlalchemy import ( |
10 | 11 | Column, |
| 12 | + ForeignKey, |
11 | 13 | Integer, |
12 | 14 | Numeric, |
13 | 15 | String, |
|
20 | 22 | from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession |
21 | 23 | from sqlalchemy.ext.hybrid import hybrid_property |
22 | 24 | from sqlalchemy.orm import Session |
| 25 | + |
| 26 | +try: |
| 27 | + from sqlalchemy.orm.collections import ( |
| 28 | + attribute_keyed_dict, |
| 29 | + column_keyed_dict, |
| 30 | + keyfunc_mapping, |
| 31 | + ) |
| 32 | +except ImportError: # SQLAlchemy < 2.0 |
| 33 | + from sqlalchemy.orm import collections as _collections |
| 34 | + |
| 35 | + attribute_keyed_dict = _collections.attribute_mapped_collection # type: ignore[attr-defined] |
| 36 | + column_keyed_dict = _collections.column_mapped_collection # type: ignore[attr-defined] |
| 37 | + keyfunc_mapping = _collections.mapped_collection # type: ignore[attr-defined] |
| 38 | + |
23 | 39 | from sqlalchemy.orm.decl_api import DeclarativeMeta, registry |
24 | 40 |
|
25 | 41 | from polyfactory.exceptions import ConfigurationException, ParameterException |
|
32 | 48 | Author, |
33 | 49 | Base, |
34 | 50 | Book, |
| 51 | + CollectionChildMixin, |
| 52 | + CollectionParentMixin, |
35 | 53 | NonSQLAchemyClass, |
36 | 54 | _registry, |
37 | 55 | ) |
| 56 | +from tests.sqlalchemy_factory.types import ListLike, SetLike |
38 | 57 |
|
39 | 58 |
|
40 | 59 | @pytest.mark.parametrize( |
@@ -195,6 +214,184 @@ class AuthorFactory(SQLAlchemyFactory[Author]): |
195 | 214 | assert isinstance(result.books[0], Book) |
196 | 215 |
|
197 | 216 |
|
| 217 | +@pytest.mark.parametrize( |
| 218 | + "collection_class_type", |
| 219 | + (set, list, ListLike, SetLike), |
| 220 | +) |
| 221 | +def test_relationship_collection_class_sequence(collection_class_type: type[Collection]) -> None: |
| 222 | + table_suffix = uuid4().hex |
| 223 | + _registry = registry() |
| 224 | + |
| 225 | + class Base(metaclass=DeclarativeMeta): |
| 226 | + __abstract__ = True |
| 227 | + __allow_unmapped__ = True |
| 228 | + |
| 229 | + registry = _registry |
| 230 | + metadata = _registry.metadata |
| 231 | + |
| 232 | + class Parent(CollectionParentMixin, Base): |
| 233 | + __tablename__ = f"parent_{table_suffix}" |
| 234 | + |
| 235 | + children: Any = orm.relationship("Child", collection_class=collection_class_type) |
| 236 | + |
| 237 | + class Child(CollectionChildMixin, Base): |
| 238 | + __tablename__ = f"child_{table_suffix}" |
| 239 | + |
| 240 | + parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False) |
| 241 | + |
| 242 | + class ParentFactory(SQLAlchemyFactory[Parent]): |
| 243 | + __model__ = Parent |
| 244 | + __set_relationships__ = True |
| 245 | + |
| 246 | + result = ParentFactory.build() |
| 247 | + assert result.children is not None |
| 248 | + |
| 249 | + assert isinstance(result.children, collection_class_type) |
| 250 | + first_item = next(iter(result.children)) |
| 251 | + assert isinstance(first_item, Child) |
| 252 | + |
| 253 | + |
| 254 | +def test_relationship_collection_class_attribute_keyed_dict() -> None: |
| 255 | + table_suffix = uuid4().hex |
| 256 | + _registry = registry() |
| 257 | + |
| 258 | + class Base(metaclass=DeclarativeMeta): |
| 259 | + __abstract__ = True |
| 260 | + __allow_unmapped__ = True |
| 261 | + |
| 262 | + registry = _registry |
| 263 | + metadata = _registry.metadata |
| 264 | + |
| 265 | + class Parent(CollectionParentMixin, Base): |
| 266 | + __tablename__ = f"parent_{table_suffix}" |
| 267 | + |
| 268 | + children = orm.relationship("Child", collection_class=attribute_keyed_dict("id")) |
| 269 | + |
| 270 | + class Child(CollectionChildMixin, Base): |
| 271 | + __tablename__ = f"child_{table_suffix}" |
| 272 | + |
| 273 | + parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False) |
| 274 | + |
| 275 | + class ParentFactory(SQLAlchemyFactory[Parent]): |
| 276 | + __model__ = Parent |
| 277 | + __set_relationships__ = True |
| 278 | + |
| 279 | + result = ParentFactory.build() |
| 280 | + assert result.children is not None |
| 281 | + |
| 282 | + assert isinstance(result.children, dict) |
| 283 | + child = next(iter(result.children.values())) |
| 284 | + assert isinstance(child, Child) |
| 285 | + assert child.id in result.children |
| 286 | + |
| 287 | + |
| 288 | +def test_relationship_collection_class_column_keyed_dict() -> None: |
| 289 | + table_suffix = uuid4().hex |
| 290 | + _registry = registry() |
| 291 | + |
| 292 | + class Base(metaclass=DeclarativeMeta): |
| 293 | + __abstract__ = True |
| 294 | + __allow_unmapped__ = True |
| 295 | + |
| 296 | + registry = _registry |
| 297 | + metadata = _registry.metadata |
| 298 | + |
| 299 | + class Parent(CollectionParentMixin, Base): |
| 300 | + __tablename__ = f"parent_{table_suffix}" |
| 301 | + |
| 302 | + children: Any |
| 303 | + |
| 304 | + class Child(CollectionChildMixin, Base): |
| 305 | + __tablename__ = f"child_{table_suffix}" |
| 306 | + |
| 307 | + parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False) |
| 308 | + |
| 309 | + Parent.children = orm.relationship("Child", collection_class=column_keyed_dict(Child.__table__.c.id)) # type: ignore[attr-defined] |
| 310 | + |
| 311 | + class ParentFactory(SQLAlchemyFactory[Parent]): |
| 312 | + __model__ = Parent |
| 313 | + __set_relationships__ = True |
| 314 | + |
| 315 | + result = ParentFactory.build() |
| 316 | + assert result.children is not None |
| 317 | + |
| 318 | + assert isinstance(result.children, dict) |
| 319 | + child = next(iter(result.children.values())) |
| 320 | + assert isinstance(child, Child) |
| 321 | + assert child.id in result.children |
| 322 | + |
| 323 | + |
| 324 | +def test_relationship_collection_class_arbitrary_keying() -> None: |
| 325 | + table_suffix = uuid4().hex |
| 326 | + _registry = registry() |
| 327 | + |
| 328 | + class Base(metaclass=DeclarativeMeta): |
| 329 | + __abstract__ = True |
| 330 | + __allow_unmapped__ = True |
| 331 | + |
| 332 | + registry = _registry |
| 333 | + metadata = _registry.metadata |
| 334 | + |
| 335 | + class Parent(CollectionParentMixin, Base): |
| 336 | + __tablename__ = f"parent_{table_suffix}" |
| 337 | + |
| 338 | + children = orm.relationship("Child", collection_class=keyfunc_mapping(lambda c: c.id)) |
| 339 | + |
| 340 | + class Child(CollectionChildMixin, Base): |
| 341 | + __tablename__ = f"child_{table_suffix}" |
| 342 | + |
| 343 | + parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False) |
| 344 | + |
| 345 | + class ParentFactory(SQLAlchemyFactory[Parent]): |
| 346 | + __model__ = Parent |
| 347 | + __set_relationships__ = True |
| 348 | + |
| 349 | + result = ParentFactory.build() |
| 350 | + assert result.children is not None |
| 351 | + |
| 352 | + assert isinstance(result.children, dict) |
| 353 | + child = next(iter(result.children.values())) |
| 354 | + assert isinstance(child, Child) |
| 355 | + assert child.id in result.children |
| 356 | + |
| 357 | + |
| 358 | +def test_relationship_collection_class_arbitrary_keyfunc() -> None: |
| 359 | + def make_mapping() -> Any: |
| 360 | + return keyfunc_mapping(lambda c: c.id)() # type: ignore[call-arg] |
| 361 | + |
| 362 | + table_suffix = uuid4().hex |
| 363 | + _registry = registry() |
| 364 | + |
| 365 | + class Base(metaclass=DeclarativeMeta): |
| 366 | + __abstract__ = True |
| 367 | + __allow_unmapped__ = True |
| 368 | + |
| 369 | + registry = _registry |
| 370 | + metadata = _registry.metadata |
| 371 | + |
| 372 | + class Parent(CollectionParentMixin, Base): |
| 373 | + __tablename__ = f"parent_{table_suffix}" |
| 374 | + |
| 375 | + children = orm.relationship("Child", collection_class=make_mapping) |
| 376 | + |
| 377 | + class Child(CollectionChildMixin, Base): |
| 378 | + __tablename__ = f"child_{table_suffix}" |
| 379 | + |
| 380 | + parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False) |
| 381 | + |
| 382 | + class ParentFactory(SQLAlchemyFactory[Parent]): |
| 383 | + __model__ = Parent |
| 384 | + __set_relationships__ = True |
| 385 | + |
| 386 | + result = ParentFactory.build() |
| 387 | + assert result.children is not None |
| 388 | + |
| 389 | + assert isinstance(result.children, dict) |
| 390 | + child = next(iter(result.children.values())) |
| 391 | + assert isinstance(child, Child) |
| 392 | + assert child.id in result.children |
| 393 | + |
| 394 | + |
198 | 395 | def test_sqla_factory_create(engine: Engine) -> None: |
199 | 396 | Base.metadata.create_all(engine) |
200 | 397 |
|
|
0 commit comments