diff --git a/aiosqlite/core.py b/aiosqlite/core.py index 61bab67..e2485ab 100644 --- a/aiosqlite/core.py +++ b/aiosqlite/core.py @@ -62,10 +62,15 @@ def __init__( DeprecationWarning, ) - def _stop_running(self): + async def _stop_running(self): self._running = False - # PEP 661 is not accepted yet, so we cannot type a sentinel - self._tx.put_nowait(_STOP_RUNNING_SENTINEL) # type: ignore[arg-type] + + function = partial(lambda: _STOP_RUNNING_SENTINEL) + future = asyncio.get_event_loop().create_future() + + self._tx.put_nowait((future, function)) + + return await future @property def _conn(self) -> sqlite3.Connection: @@ -95,9 +100,6 @@ def run(self) -> None: # futures) tx_item = self._tx.get() - if tx_item is _STOP_RUNNING_SENTINEL: - break - future, function = tx_item try: @@ -105,6 +107,9 @@ def run(self) -> None: result = function() LOG.debug("operation %s completed", function) future.get_loop().call_soon_threadsafe(set_result, future, result) + + if result is _STOP_RUNNING_SENTINEL: + break except BaseException as e: # noqa B036 LOG.debug("returning exception %s", e) future.get_loop().call_soon_threadsafe(set_exception, future, e) @@ -129,7 +134,7 @@ async def _connect(self) -> "Connection": self._tx.put_nowait((future, self._connector)) self._connection = await future except BaseException: - self._stop_running() + await self._stop_running() self._connection = None raise @@ -170,7 +175,7 @@ async def close(self) -> None: LOG.info("exception occurred while closing connection") raise finally: - self._stop_running() + await self._stop_running() self._connection = None @contextmanager diff --git a/aiosqlite/tests/smoke.py b/aiosqlite/tests/smoke.py index 57c1beb..cbbe75a 100644 --- a/aiosqlite/tests/smoke.py +++ b/aiosqlite/tests/smoke.py @@ -413,6 +413,17 @@ async def test_cursor_on_closed_connection_loop(self): except sqlite3.ProgrammingError: pass + async def test_close_blocking_until_transaction_queue_empty(self): + db = await aiosqlite.connect(self.db) + # Insert transactions into the + # transaction queue '_tx' + for i in range(1000): + await db.execute(f"select 1, {i}") + # Wait for all transactions to complete + await db.close() + # Check no more transaction pending + self.assertEqual(db._tx.empty(), True) + async def test_close_twice(self): db = await aiosqlite.connect(self.db)