diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index bbe86ee2ae..835a952b19 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -129,6 +129,7 @@ # Extensions and modifications of SQLAlchemy in SQLModel from .engine.create import create_engine as create_engine from .orm.session import Session as Session +from .orm.session import sessionmaker as sessionmaker from .sql.expression import select as select from .sql.expression import col as col from .sql.sqltypes import AutoString as AutoString diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index a5a63e2c69..b9e9933d5c 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -3,6 +3,7 @@ from sqlalchemy import util from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session +from sqlalchemy.orm import sessionmaker as _sessionmaker from sqlalchemy.sql.base import Executable as _Executable from sqlmodel.sql.expression import Select, SelectOfScalar from typing_extensions import Literal @@ -137,3 +138,10 @@ def get( with_for_update=with_for_update, identity_token=identity_token, ) + + +class sessionmaker(_sessionmaker): + def __init__(self, *args, **kwargs): + if 'class_' not in kwargs: + kwargs['class_'] = Session + super().__init__(*args, **kwargs)