Skip to content

Commit 77031a5

Browse files
AlexPetuladhtruong
andauthored
feat: support for custom collection_class in SQLAlchemy relationships (#776)
Co-authored-by: Andrew Truong <[email protected]>
1 parent c54880f commit 77031a5

File tree

4 files changed

+307
-17
lines changed

4 files changed

+307
-17
lines changed

polyfactory/factories/sqlalchemy_factory.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
from __future__ import annotations
22

3+
from collections.abc import Collection, Mapping
34
from datetime import date, datetime
4-
from typing import TYPE_CHECKING, Annotated, Any, Callable, ClassVar, Generic, Protocol, TypeVar, Union
5-
6-
from polyfactory.exceptions import MissingDependencyException, ParameterException
5+
from typing import (
6+
TYPE_CHECKING,
7+
Annotated,
8+
Any,
9+
Callable,
10+
ClassVar,
11+
Generic,
12+
Protocol,
13+
TypeVar,
14+
Union,
15+
)
16+
17+
from sqlalchemy.util.langhelpers import duck_type_collection
18+
19+
from polyfactory.exceptions import ConfigurationException, MissingDependencyException, ParameterException
720
from polyfactory.factories.base import BaseFactory
821
from polyfactory.field_meta import Constraints, FieldMeta
922
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
@@ -203,6 +216,30 @@ def get_type_from_column(cls, column: Column) -> type:
203216

204217
return annotation
205218

219+
@classmethod
220+
def get_type_from_collection_class(
221+
cls,
222+
collection_class: type[Collection[Any]] | Callable[[], Collection[Any]],
223+
entity_class: Any,
224+
) -> type[Any]:
225+
annotation: type[Any]
226+
227+
if isinstance(collection_class, type):
228+
if issubclass(collection_class, Mapping):
229+
annotation = dict[Any, entity_class]
230+
else:
231+
if not (duck_typed_as := duck_type_collection(collection_class)):
232+
msg = f"Cannot infer type from collection_class {collection_class}"
233+
raise ConfigurationException(
234+
msg,
235+
)
236+
237+
annotation = duck_typed_as[entity_class] # pyright: ignore[reportIndexIssue]
238+
else:
239+
annotation = dict[Any, entity_class]
240+
241+
return annotation
242+
206243
@classmethod
207244
def get_model_fields(cls) -> list[FieldMeta]:
208245
fields_meta: list[FieldMeta] = []
@@ -219,12 +256,22 @@ def get_model_fields(cls) -> list[FieldMeta]:
219256
if cls.__set_relationships__:
220257
for name, relationship in table.relationships.items():
221258
class_ = relationship.entity.class_
222-
annotation = class_ if not relationship.uselist else list[class_] # type: ignore[valid-type]
259+
annotation: Any
260+
261+
if relationship.uselist:
262+
collection_class = relationship.collection_class
263+
if collection_class is None:
264+
annotation = list[class_] # type: ignore[valid-type]
265+
else:
266+
annotation = cls.get_type_from_collection_class(collection_class, class_)
267+
else:
268+
annotation = class_
269+
223270
fields_meta.append(
224271
FieldMeta.from_type(
225272
name=name,
226273
annotation=annotation,
227-
),
274+
)
228275
)
229276
if cls.__set_association_proxy__:
230277
for name, attr in table.all_orm_descriptors.items():

tests/sqlalchemy_factory/models.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,7 @@
11
from dataclasses import dataclass
22
from typing import Any, Optional
33

4-
from sqlalchemy import (
5-
Boolean,
6-
Column,
7-
DateTime,
8-
ForeignKey,
9-
Integer,
10-
String,
11-
func,
12-
orm,
13-
text,
14-
)
4+
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, func, orm, text
155
from sqlalchemy.ext.associationproxy import association_proxy
166
from sqlalchemy.orm import relationship
177
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry
@@ -32,6 +22,20 @@ class Base(metaclass=DeclarativeMeta):
3222
metadata = _registry.metadata
3323

3424

25+
class CollectionParentMixin:
26+
__abstract__ = True
27+
__allow_unmapped__ = True
28+
29+
id = Column(Integer(), primary_key=True)
30+
31+
32+
class CollectionChildMixin:
33+
__abstract__ = True
34+
__allow_unmapped__ = True
35+
36+
id = Column(Integer(), primary_key=True)
37+
38+
3539
class Author(Base):
3640
__tablename__ = "authors"
3741

tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import warnings
2+
from collections.abc import Collection
23
from datetime import datetime
34
from decimal import Decimal
45
from enum import Enum
56
from typing import Any, Callable, get_args
6-
from uuid import UUID
7+
from uuid import UUID, uuid4
78

89
import pytest
910
from sqlalchemy import (
1011
Column,
12+
ForeignKey,
1113
Integer,
1214
Numeric,
1315
String,
@@ -20,6 +22,20 @@
2022
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
2123
from sqlalchemy.ext.hybrid import hybrid_property
2224
from sqlalchemy.orm import Session
25+
26+
try:
27+
from sqlalchemy.orm.collections import (
28+
attribute_keyed_dict,
29+
column_keyed_dict,
30+
keyfunc_mapping,
31+
)
32+
except ImportError: # SQLAlchemy < 2.0
33+
from sqlalchemy.orm import collections as _collections
34+
35+
attribute_keyed_dict = _collections.attribute_mapped_collection # type: ignore[attr-defined]
36+
column_keyed_dict = _collections.column_mapped_collection # type: ignore[attr-defined]
37+
keyfunc_mapping = _collections.mapped_collection # type: ignore[attr-defined]
38+
2339
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry
2440

2541
from polyfactory.exceptions import ConfigurationException, ParameterException
@@ -32,9 +48,12 @@
3248
Author,
3349
Base,
3450
Book,
51+
CollectionChildMixin,
52+
CollectionParentMixin,
3553
NonSQLAchemyClass,
3654
_registry,
3755
)
56+
from tests.sqlalchemy_factory.types import ListLike, SetLike
3857

3958

4059
@pytest.mark.parametrize(
@@ -195,6 +214,184 @@ class AuthorFactory(SQLAlchemyFactory[Author]):
195214
assert isinstance(result.books[0], Book)
196215

197216

217+
@pytest.mark.parametrize(
218+
"collection_class_type",
219+
(set, list, ListLike, SetLike),
220+
)
221+
def test_relationship_collection_class_sequence(collection_class_type: type[Collection]) -> None:
222+
table_suffix = uuid4().hex
223+
_registry = registry()
224+
225+
class Base(metaclass=DeclarativeMeta):
226+
__abstract__ = True
227+
__allow_unmapped__ = True
228+
229+
registry = _registry
230+
metadata = _registry.metadata
231+
232+
class Parent(CollectionParentMixin, Base):
233+
__tablename__ = f"parent_{table_suffix}"
234+
235+
children: Any = orm.relationship("Child", collection_class=collection_class_type)
236+
237+
class Child(CollectionChildMixin, Base):
238+
__tablename__ = f"child_{table_suffix}"
239+
240+
parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False)
241+
242+
class ParentFactory(SQLAlchemyFactory[Parent]):
243+
__model__ = Parent
244+
__set_relationships__ = True
245+
246+
result = ParentFactory.build()
247+
assert result.children is not None
248+
249+
assert isinstance(result.children, collection_class_type)
250+
first_item = next(iter(result.children))
251+
assert isinstance(first_item, Child)
252+
253+
254+
def test_relationship_collection_class_attribute_keyed_dict() -> None:
255+
table_suffix = uuid4().hex
256+
_registry = registry()
257+
258+
class Base(metaclass=DeclarativeMeta):
259+
__abstract__ = True
260+
__allow_unmapped__ = True
261+
262+
registry = _registry
263+
metadata = _registry.metadata
264+
265+
class Parent(CollectionParentMixin, Base):
266+
__tablename__ = f"parent_{table_suffix}"
267+
268+
children = orm.relationship("Child", collection_class=attribute_keyed_dict("id"))
269+
270+
class Child(CollectionChildMixin, Base):
271+
__tablename__ = f"child_{table_suffix}"
272+
273+
parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False)
274+
275+
class ParentFactory(SQLAlchemyFactory[Parent]):
276+
__model__ = Parent
277+
__set_relationships__ = True
278+
279+
result = ParentFactory.build()
280+
assert result.children is not None
281+
282+
assert isinstance(result.children, dict)
283+
child = next(iter(result.children.values()))
284+
assert isinstance(child, Child)
285+
assert child.id in result.children
286+
287+
288+
def test_relationship_collection_class_column_keyed_dict() -> None:
289+
table_suffix = uuid4().hex
290+
_registry = registry()
291+
292+
class Base(metaclass=DeclarativeMeta):
293+
__abstract__ = True
294+
__allow_unmapped__ = True
295+
296+
registry = _registry
297+
metadata = _registry.metadata
298+
299+
class Parent(CollectionParentMixin, Base):
300+
__tablename__ = f"parent_{table_suffix}"
301+
302+
children: Any
303+
304+
class Child(CollectionChildMixin, Base):
305+
__tablename__ = f"child_{table_suffix}"
306+
307+
parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False)
308+
309+
Parent.children = orm.relationship("Child", collection_class=column_keyed_dict(Child.__table__.c.id)) # type: ignore[attr-defined]
310+
311+
class ParentFactory(SQLAlchemyFactory[Parent]):
312+
__model__ = Parent
313+
__set_relationships__ = True
314+
315+
result = ParentFactory.build()
316+
assert result.children is not None
317+
318+
assert isinstance(result.children, dict)
319+
child = next(iter(result.children.values()))
320+
assert isinstance(child, Child)
321+
assert child.id in result.children
322+
323+
324+
def test_relationship_collection_class_arbitrary_keying() -> None:
325+
table_suffix = uuid4().hex
326+
_registry = registry()
327+
328+
class Base(metaclass=DeclarativeMeta):
329+
__abstract__ = True
330+
__allow_unmapped__ = True
331+
332+
registry = _registry
333+
metadata = _registry.metadata
334+
335+
class Parent(CollectionParentMixin, Base):
336+
__tablename__ = f"parent_{table_suffix}"
337+
338+
children = orm.relationship("Child", collection_class=keyfunc_mapping(lambda c: c.id))
339+
340+
class Child(CollectionChildMixin, Base):
341+
__tablename__ = f"child_{table_suffix}"
342+
343+
parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False)
344+
345+
class ParentFactory(SQLAlchemyFactory[Parent]):
346+
__model__ = Parent
347+
__set_relationships__ = True
348+
349+
result = ParentFactory.build()
350+
assert result.children is not None
351+
352+
assert isinstance(result.children, dict)
353+
child = next(iter(result.children.values()))
354+
assert isinstance(child, Child)
355+
assert child.id in result.children
356+
357+
358+
def test_relationship_collection_class_arbitrary_keyfunc() -> None:
359+
def make_mapping() -> Any:
360+
return keyfunc_mapping(lambda c: c.id)() # type: ignore[call-arg]
361+
362+
table_suffix = uuid4().hex
363+
_registry = registry()
364+
365+
class Base(metaclass=DeclarativeMeta):
366+
__abstract__ = True
367+
__allow_unmapped__ = True
368+
369+
registry = _registry
370+
metadata = _registry.metadata
371+
372+
class Parent(CollectionParentMixin, Base):
373+
__tablename__ = f"parent_{table_suffix}"
374+
375+
children = orm.relationship("Child", collection_class=make_mapping)
376+
377+
class Child(CollectionChildMixin, Base):
378+
__tablename__ = f"child_{table_suffix}"
379+
380+
parent_id = Column(Integer(), ForeignKey(f"{Parent.__tablename__}.id"), nullable=False)
381+
382+
class ParentFactory(SQLAlchemyFactory[Parent]):
383+
__model__ = Parent
384+
__set_relationships__ = True
385+
386+
result = ParentFactory.build()
387+
assert result.children is not None
388+
389+
assert isinstance(result.children, dict)
390+
child = next(iter(result.children.values()))
391+
assert isinstance(child, Child)
392+
assert child.id in result.children
393+
394+
198395
def test_sqla_factory_create(engine: Engine) -> None:
199396
Base.metadata.create_all(engine)
200397

0 commit comments

Comments
 (0)