fix: Ensure Postgres queries are committed or autocommit is used #5039

Merged
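For context on what the title refers to (background, not text from the PR): psycopg 3 connections run with autocommit disabled by default, so every statement, including a plain SELECT, joins an implicit transaction that stays open until commit() or rollback(). Read-only queries therefore leave the session "idle in transaction" (and a pooled connection is handed back to the pool with a transaction still open), while writes and DDL that are never committed are rolled back when the connection goes away. The two standard remedies are the ones named in the title: switch the connection to autocommit where that fits, or commit explicitly. A minimal sketch of the difference, assuming a reachable Postgres at a placeholder dsn; none of the names below come from the Feast code:

import psycopg

dsn = "postgresql://user:pass@localhost:5432/feast"

# Default connection: autocommit is off, so even the SELECT opens an implicit
# transaction that stays open until commit() or rollback().
conn = psycopg.connect(dsn)
with conn.cursor() as cur:
    cur.execute("SELECT 1")
print(conn.info.transaction_status)  # INTRANS ("idle in transaction")
conn.rollback()
conn.close()

# Autocommit connection: each statement is committed as soon as it runs, so
# reads leave no transaction behind and writes need no explicit commit().
conn = psycopg.connect(dsn, autocommit=True)
with conn.cursor() as cur:
    cur.execute("SELECT 1")
print(conn.info.transaction_status)  # IDLE
conn.close()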
@@ -58,24 +58,28 @@ class PostgreSQLOnlineStore(OnlineStore):
     _conn_pool_async: Optional[AsyncConnectionPool] = None
 
     @contextlib.contextmanager
-    def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
+    def _get_conn(
+        self, config: RepoConfig, autocommit: bool = False
+    ) -> Generator[Connection, Any, Any]:
         assert config.online_store.type == "postgres"
 
         if config.online_store.conn_type == ConnectionType.pool:
             if not self._conn_pool:
                 self._conn_pool = _get_connection_pool(config.online_store)
                 self._conn_pool.open()
             connection = self._conn_pool.getconn()
+            connection.set_autocommit(autocommit)
             yield connection
             self._conn_pool.putconn(connection)
         else:
             if not self._conn:
                 self._conn = _get_conn(config.online_store)
+            self._conn.set_autocommit(autocommit)
             yield self._conn
 
     @contextlib.asynccontextmanager
     async def _get_conn_async(
-        self, config: RepoConfig
+        self, config: RepoConfig, autocommit: bool = False
     ) -> AsyncGenerator[AsyncConnection, Any]:
         if config.online_store.conn_type == ConnectionType.pool:
             if not self._conn_pool_async:
@@ -84,11 +88,13 @@ async def _get_conn_async(
                 )
                 await self._conn_pool_async.open()
             connection = await self._conn_pool_async.getconn()
+            await connection.set_autocommit(autocommit)
             yield connection
             await self._conn_pool_async.putconn(connection)
         else:
             if not self._conn_async:
                 self._conn_async = await _get_conn_async(config.online_store)
+            await self._conn_async.set_autocommit(autocommit)
             yield self._conn_async
 
     def online_write_batch(
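With the autocommit parameter threaded through both context managers above, each call site chooses the behaviour it needs. A rough usage sketch, where store stands in for a PostgreSQLOnlineStore instance and config for a loaded RepoConfig (neither line is taken from the diff):

# Read path: autocommit, so the SELECT leaves no open transaction behind
# when the connection is returned to the pool.
with store._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
    cur.execute("SELECT 1")
    rows = cur.fetchall()

# Write/DDL path: keep the default (autocommit=False) and commit explicitly,
# so the statements are grouped into one transaction and made durable.
with store._get_conn(config) as conn, conn.cursor() as cur:
    cur.execute("CREATE TABLE IF NOT EXISTS demo_table (id int)")
    conn.commit()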
@@ -161,7 +167,7 @@ def online_read(
             config, table, keys, requested_features
         )
 
-        with self._get_conn(config) as conn, conn.cursor() as cur:
+        with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
             cur.execute(query, params)
             rows = cur.fetchall()
 
@@ -179,7 +185,7 @@ async def online_read_async(
             config, table, keys, requested_features
         )
 
-        async with self._get_conn_async(config) as conn:
+        async with self._get_conn_async(config, autocommit=True) as conn:
             async with conn.cursor() as cur:
                 await cur.execute(query, params)
                 rows = await cur.fetchall()
@@ -339,6 +345,7 @@ def teardown(
                 for table in tables:
                     table_name = _table_id(project, table)
                     cur.execute(_drop_table_and_index(table_name))
+                conn.commit()
         except Exception:
             logging.exception("Teardown failed")
             raise
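The teardown hunk covers the other half of the title: this path evidently keeps a non-autocommit connection, so the DROP TABLE statements only take effect once the added conn.commit() runs. A standalone sketch of that behaviour, using a plain psycopg connection and a hypothetical table name rather than the store's helpers:

import psycopg

conn = psycopg.connect("postgresql://user:pass@localhost:5432/feast")
with conn.cursor() as cur:
    cur.execute("DROP TABLE IF EXISTS demo_feature_view")
conn.commit()  # without this, closing the connection rolls the DROP back
conn.close()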
@@ -398,7 +405,7 @@ def retrieve_online_documents(
                 Optional[ValueProto],
             ]
         ] = []
-        with self._get_conn(config) as conn, conn.cursor() as cur:
+        with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
             table_name = _table_id(project, table)
 
             # Search query template to find the top k items that are closest to the given embedding