From 4d13aa3da572837b1203720b2d18879b74c6f604 Mon Sep 17 00:00:00 2001 From: seria <seria.ati@gmail.com> Date: Thu, 10 Apr 2025 09:49:29 +0900 Subject: [PATCH 1/3] Add _Executable overload for exec method in AsyncSession and Session --- sqlmodel/ext/asyncio/session.py | 15 ++++++++++++++- sqlmodel/orm/session.py | 15 ++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 467d0bd84e..c344898d3e 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -57,12 +57,25 @@ async def exec( _add_event: Optional[Any] = None, ) -> ScalarResult[_TSelectParam]: ... + @overload + async def exec( + self, + statement: _Executable, + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + async def exec( self, statement: Union[ Select[_TSelectParam], SelectOfScalar[_TSelectParam], Executable[_TSelectParam], + _Executable, ], *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, @@ -70,7 +83,7 @@ async def exec( bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]: + ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], Result[Any]]: if execution_options: execution_options = util.immutabledict(execution_options).union( _EXECUTE_OPTIONS diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index b60875095b..8b94a9ae85 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -49,12 +49,25 @@ def exec( _add_event: Optional[Any] = None, ) -> ScalarResult[_TSelectParam]: ... + @overload + def exec( + self, + statement: _Executable, + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + def exec( self, statement: Union[ Select[_TSelectParam], SelectOfScalar[_TSelectParam], Executable[_TSelectParam], + _Executable, ], *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, @@ -62,7 +75,7 @@ def exec( bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]: + ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], Result[Any]]: results = super().execute( statement, params=params, From 4ec8aac6c0438a32d9d44c1027bc1f835bd3eb26 Mon Sep 17 00:00:00 2001 From: seria <seria.ati@gmail.com> Date: Thu, 10 Apr 2025 10:29:33 +0900 Subject: [PATCH 2/3] Fix overload implementation --- sqlmodel/ext/asyncio/session.py | 10 ++++++---- sqlmodel/orm/session.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index c344898d3e..8b9f0970ce 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -12,6 +12,7 @@ ) from sqlalchemy import util +from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams from sqlalchemy.engine.result import Result, ScalarResult, TupleResult from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession @@ -19,6 +20,7 @@ from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS from sqlalchemy.orm._typing import OrmExecuteOptionsParameter from sqlalchemy.sql.base import Executable as _Executable +from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.util.concurrency import greenlet_spawn from typing_extensions import deprecated @@ -60,14 +62,14 @@ async def exec( @overload async def exec( self, - statement: _Executable, + statement: UpdateBase, *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: ... + ) -> CursorResult[Any]: ... async def exec( self, @@ -75,7 +77,7 @@ async def exec( Select[_TSelectParam], SelectOfScalar[_TSelectParam], Executable[_TSelectParam], - _Executable, + UpdateBase, ], *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, @@ -83,7 +85,7 @@ async def exec( bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], Result[Any]]: + ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any]]: if execution_options: execution_options = util.immutabledict(execution_options).union( _EXECUTE_OPTIONS diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 8b94a9ae85..708e7ef3a2 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -10,6 +10,7 @@ ) from sqlalchemy import util +from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams from sqlalchemy.engine.result import Result, ScalarResult, TupleResult from sqlalchemy.orm import Query as _Query @@ -17,6 +18,7 @@ from sqlalchemy.orm._typing import OrmExecuteOptionsParameter from sqlalchemy.sql._typing import _ColumnsClauseArgument from sqlalchemy.sql.base import Executable as _Executable +from sqlalchemy.sql.dml import UpdateBase from sqlmodel.sql.base import Executable from sqlmodel.sql.expression import Select, SelectOfScalar from typing_extensions import deprecated @@ -52,14 +54,14 @@ def exec( @overload def exec( self, - statement: _Executable, + statement: UpdateBase, *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: ... + ) -> CursorResult[Any]: ... def exec( self, @@ -67,7 +69,7 @@ def exec( Select[_TSelectParam], SelectOfScalar[_TSelectParam], Executable[_TSelectParam], - _Executable, + UpdateBase, ], *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, @@ -75,7 +77,7 @@ def exec( bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], Result[Any]]: + ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any]]: results = super().execute( statement, params=params, From 770a4fd80a57d463cbb040ff159fcc9d5620217a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Apr 2025 01:28:22 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/ext/asyncio/session.py | 4 +++- sqlmodel/orm/session.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 8b9f0970ce..54488357bb 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -85,7 +85,9 @@ async def exec( bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any]]: + ) -> Union[ + TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any] + ]: if execution_options: execution_options = util.immutabledict(execution_options).union( _EXECUTE_OPTIONS diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 708e7ef3a2..dca4733d61 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -77,7 +77,9 @@ def exec( bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any]]: + ) -> Union[ + TupleResult[_TSelectParam], ScalarResult[_TSelectParam], CursorResult[Any] + ]: results = super().execute( statement, params=params,