diff --git a/aiven_db_migrate/migrate/pgmigrate.py b/aiven_db_migrate/migrate/pgmigrate.py index 49222d6..9a7037c 100644 --- a/aiven_db_migrate/migrate/pgmigrate.py +++ b/aiven_db_migrate/migrate/pgmigrate.py @@ -4,7 +4,7 @@ PGDataDumpFailedError, PGDataNotFoundError, PGMigrateValidationFailedError, PGSchemaDumpFailedError, PGTooMuchDataError ) from aiven_db_migrate.migrate.pgutils import ( - create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length + create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length, wait_select ) from aiven_db_migrate.migrate.version import __version__ from concurrent import futures @@ -31,8 +31,6 @@ import threading import time -# https://www.psycopg.org/docs/faq.html#faq-interrupt-query -psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) MAX_CLI_LEN = 2097152 # getconf ARG_MAX @@ -136,6 +134,12 @@ def conn_str(self, *, dbname: str = None) -> str: conn_info["application_name"] = conn_info["application_name"] + "/" + self.mangle_db_name(conn_info["dbname"]) return create_connection_string(conn_info) + def connect_timeout(self): + try: + return int(self.conn_info.get("connect_timeout", os.environ.get("PGCONNECT_TIMEOUT", "")), 10) + except ValueError: + return None + @contextmanager def _cursor(self, *, dbname: str = None) -> RealDictCursor: conn: psycopg2.extensions.connection = None @@ -146,8 +150,8 @@ def _cursor(self, *, dbname: str = None) -> RealDictCursor: # from multiple threads; allow only one connection at time self.conn_lock.acquire() try: - conn = psycopg2.connect(**conn_info) - conn.autocommit = True + conn = psycopg2.connect(**conn_info, async_=True) + wait_select(conn, self.connect_timeout()) yield conn.cursor(cursor_factory=RealDictCursor) finally: if conn is not None: @@ -165,7 +169,15 @@ def c( ) -> List[Dict[str, Any]]: results: List[Dict[str, Any]] = [] with self._cursor(dbname=dbname) as cur: - cur.execute(query, args) + try: + 
cur.execute(query, args) + wait_select(cur.connection) + except KeyboardInterrupt: + # We wrap the whole execute+wait block to make sure we cancel + # the query in all cases, which we couldn't if KeyboardInterrupt + # was only handled inside wait_select. + cur.connection.cancel() + raise if return_rows: results = cur.fetchall() if return_rows > 0 and len(results) != return_rows: diff --git a/aiven_db_migrate/migrate/pgutils.py b/aiven_db_migrate/migrate/pgutils.py index a49db50..71a16f0 100644 --- a/aiven_db_migrate/migrate/pgutils.py +++ b/aiven_db_migrate/migrate/pgutils.py @@ -4,6 +4,10 @@ from typing import Any, Dict from urllib.parse import parse_qs, urlparse +import psycopg2 +import select +import time + def find_pgbin_dir(pgversion: str) -> Path: def _pgbin_paths(): @@ -105,3 +109,37 @@ def parse_connection_string_url(url: str) -> Dict[str, str]: for k, v in parse_qs(p.query).items(): fields[k] = v[-1] return fields + + +# This enables interruptible queries with an approach similar to +# https://www.psycopg.org/docs/faq.html#faq-interrupt-query +# However, to handle timeouts we can't use psycopg2.extensions.set_wait_callback : +# https://github.com/psycopg/psycopg2/issues/944 +# Instead we rely on manually calling wait_select after connection and queries. +# Since it's not a wait callback, we do not capture and transform KeyboardInterrupt here. 
+def wait_select(conn, timeout=None): + start_time = time.monotonic() + poll = select.poll() + while True: + if timeout is not None and timeout > 0: + time_left = start_time + timeout - time.monotonic() + if time_left <= 0: + raise TimeoutError("wait_select: timeout after {} seconds".format(timeout)) + else: + time_left = 1 + state = conn.poll() + if state == psycopg2.extensions.POLL_OK: + return + elif state == psycopg2.extensions.POLL_READ: + poll.register(conn.fileno(), select.POLLIN) + elif state == psycopg2.extensions.POLL_WRITE: + poll.register(conn.fileno(), select.POLLOUT) + else: + raise conn.OperationalError("wait_select: invalid poll state") + try: + # When the remote address does not exist at all, poll.poll() waits its full timeout without any event. + # However, in the same conditions, conn.poll() raises a psycopg2 exception almost immediately. + # It is better to fail quickly instead of waiting the full timeout, so we keep our poll.poll() below 1sec. + poll.poll(min(1.0, time_left) * 1000) + finally: + poll.unregister(conn.fileno()) diff --git a/test/conftest.py b/test/conftest.py index 4c8cb5f..df5a168 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -402,6 +402,7 @@ def inject_pg_fixture(*, name: str, pgversion: str, scope="module"): SUPPORTED_PG_VERSIONS = ["9.5", "9.6", "10", "11", "12"] +pg_cluster_for_tests: List[str] = list() pg_source_and_target_for_tests: List[Tuple[str, str]] = list() pg_source_and_target_for_replication_tests: List[Tuple[str, str]] = list() @@ -437,6 +438,10 @@ def generate_fixtures(): pg_source_and_target_for_tests.append((source_name, target_name)) if LooseVersion(source) >= "10": pg_source_and_target_for_replication_tests.append((source_name, target_name)) + for version in set(pg_source_versions).union(pg_target_versions): + fixture_name = "pg{}".format(version.replace(".", "")) + inject_pg_fixture(name=fixture_name, pgversion=version) + pg_cluster_for_tests.append(fixture_name) generate_fixtures() @@ -450,6 
+455,17 @@ def test_pg_source_and_target_for_replication_tests(): print(pg_source_and_target_for_replication_tests) +@pytest.fixture(name="pg_cluster", params=pg_cluster_for_tests, scope="function") +def fixture_pg_cluster(request): + """Returns a fixture parametrized on the union of all source and target pg versions.""" + cluster_runner = request.getfixturevalue(request.param) + yield cluster_runner + for cleanup in cluster_runner.cleanups: + cleanup() + cluster_runner.cleanups.clear() + cluster_runner.drop_dbs() + + @pytest.fixture(name="pg_source_and_target", params=pg_source_and_target_for_tests, scope="function") def fixture_pg_source_and_target(request): source, target = request.param diff --git a/test/test_pg_cluster.py b/test/test_pg_cluster.py new file mode 100644 index 0000000..1f6f080 --- /dev/null +++ b/test/test_pg_cluster.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 Aiven, Helsinki, Finland. https://aiven.io/ +from aiven_db_migrate.migrate.pgmigrate import PGCluster +from multiprocessing import Process +from test.conftest import PGRunner +from typing import Tuple + +import os +import pytest +import signal +import time + + +def test_interruptible_queries(pg_cluster: PGRunner): + def wait_and_interrupt(): + time.sleep(1) + os.kill(os.getppid(), signal.SIGINT) + + cluster = PGCluster(conn_info=pg_cluster.conn_info()) + interuptor = Process(target=wait_and_interrupt) + interuptor.start() + start_time = time.monotonic() + with pytest.raises(KeyboardInterrupt): + cluster.c("select pg_sleep(100)") + assert time.monotonic() - start_time < 2 + interuptor.join() diff --git a/test/test_pg_migrate.py b/test/test_pg_migrate.py index b1d3420..a243d13 100644 --- a/test/test_pg_migrate.py +++ b/test/test_pg_migrate.py @@ -6,8 +6,10 @@ from test.utils import random_string, Timer from typing import Any, Dict, Optional +import os import psycopg2 import pytest +import time class PGMigrateTest: @@ -154,6 +156,27 @@ def test_migrate_invalid_conn_str(self): 
PGMigrate(source_conn_info=source_conn_info, target_conn_info=target_conn_info).migrate() assert str(err.value) == "Invalid source or target connection string" + def test_migrate_connect_timeout_parameter(self): + for source_conn_info in ("host=example.org connect_timeout=1", "postgresql://example.org?connect_timeout=1"): + start_time = time.monotonic() + with pytest.raises(TimeoutError): + PGMigrate(source_conn_info=source_conn_info, target_conn_info=self.target.conn_info()).migrate() + end_time = time.monotonic() + assert end_time - start_time < 2 + + def test_migrate_connect_timeout_environment(self): + start_time = time.monotonic() + original_timeout = os.environ.get("PGCONNECT_TIMEOUT") + try: + os.environ["PGCONNECT_TIMEOUT"] = "1" + with pytest.raises(TimeoutError): + PGMigrate(source_conn_info="host=example.org", target_conn_info=self.target.conn_info()).migrate() + end_time = time.monotonic() + assert end_time - start_time < 2 + finally: + if original_timeout is not None: + os.environ["PGCONNECT_TIMEOUT"] = original_timeout + def test_migrate_same_server(self): source_conn_info = target_conn_info = self.target.conn_info() with pytest.raises(PGMigrateValidationFailedError) as err: