diff --git a/fastapi_asyncpg/__init__.py b/fastapi_asyncpg/__init__.py index f569094..c19060e 100644 --- a/fastapi_asyncpg/__init__.py +++ b/fastapi_asyncpg/__init__.py @@ -1,5 +1,4 @@ from __future__ import annotations - from fastapi import FastAPI import asyncpg @@ -46,19 +45,20 @@ async def on_connect(self): the db""" # if the pool is comming from outside (tests), don't connect it if self._pool: - self.app.state.pool = self._pool + self._get_pool_manager_from_app().put(self.dsn, self._pool) return pool = await asyncpg.create_pool(dsn=self.dsn, **self.con_opts) async with pool.acquire() as db: await self.init_db(db) - self.app.state.pool = pool + + self._get_pool_manager_from_app().put(self.dsn, pool) async def on_disconnect(self): # if the pool is comming from outside, don't desconnect it # someone else will do (usualy a pytest fixture) if self._pool: return - await self.app.state.pool.close() + await self.pool.close() def on_init(self, func): self.init_db = func @@ -66,7 +66,10 @@ def on_init(self, func): @property def pool(self): - return self.app.state.pool + """Fetch the connection pool associated with our DSN from the + pool manager stashed within app.state. + """ + return self._get_pool_manager_from_app().get(self.dsn) async def connection(self): """ @@ -103,9 +106,35 @@ async def get_content(db = Depens(db.transaction)): else: await txn.commit() + def _get_pool_manager_from_app(self): + """Find or create singleton AppPoolManager instance within self.app.state""" + if not hasattr(self.app.state, "fastapi_asyncpg_pool_manager"): + self.app.state.fastapi_asyncpg_pool_manager = AppPoolManager() + + return self.app.state.fastapi_asyncpg_pool_manager + atomic = transaction +class AppPoolManager: + """Object placed into fastapi app.state to manage one or more + asyncpg.pool.Pool instances within the fastapi app. + + If the app uses more than one asyncpg database, then there + will be more than one pool. We separate them by the + connection DSN. + """ + + def __init__(self): + self._pool_by_dsn = {} + + def get(self, dsn): + return self._pool_by_dsn[dsn] + + def put(self, dsn, pool): + self._pool_by_dsn[dsn] = pool + + class SingleConnectionTestingPool: """A fake pool that simulates pooling, but runs on a single transaction that it's rolled back after