11from __future__ import annotations
22
33from datetime import date , datetime
4- from typing import TYPE_CHECKING , Any , Callable , ClassVar , Generic , List , TypeVar , Union
4+ from typing import TYPE_CHECKING , Any , Callable , ClassVar , Generic , List , Protocol , TypeVar , Union
55
66from typing_extensions import Annotated
77
@@ -68,6 +68,14 @@ async def save_many(self, data: list[T]) -> list[T]:
6868 return data
6969
7070
71+ _T_co = TypeVar ("_T_co" , covariant = True )
72+
73+
74+ class _SessionMaker (Protocol [_T_co ]):
75+ @staticmethod
76+ def __call__ () -> _T_co : ...
77+
78+
7179class SQLAlchemyFactory (Generic [T ], BaseFactory [T ]):
7280 """Base factory for SQLAlchemy models."""
7381
@@ -82,8 +90,8 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
8290 __set_association_proxy__ : ClassVar [bool ] = False
8391 """Configuration to consider AssociationProxy property as a model field or not."""
8492
85- __session__ : ClassVar [Session | Callable [[], Session ] | None ] = None
86- __async_session__ : ClassVar [AsyncSession | Callable [[], AsyncSession ] | None ] = None
93+ __session__ : ClassVar [Session | _SessionMaker [ Session ] | None ] = None
94+ __async_session__ : ClassVar [AsyncSession | _SessionMaker [ AsyncSession ] | None ] = None
8795
8896 __config_keys__ = (
8997 * BaseFactory .__config_keys__ ,
0 commit comments