diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index 76a6ff11..dd7055c5 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -23,10 +23,6 @@ from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError from inspect import signature -from mysql.connector import CMySQLConnection, MySQLConnection -from mysql.connector.cursor import MySQLCursor -from mysql.connector.cursor_cext import CMySQLCursor - from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes from aws_advanced_python_wrapper.errors import UnsupportedOperationError @@ -36,6 +32,21 @@ PropertiesUtils, WrapperProperties) +CMYSQL_ENABLED = False + +from mysql.connector import MySQLConnection # noqa: E402 +from mysql.connector.cursor import MySQLCursor # noqa: E402 + +try: + from mysql.connector import CMySQLConnection # noqa: E402 + from mysql.connector.cursor_cext import CMySQLCursor # noqa: E402 + + CMYSQL_ENABLED = True + +except ImportError: + # Do nothing + pass + class MySQLDriverDialect(DriverDialect): _driver_name = "MySQL Connector Python" @@ -62,20 +73,28 @@ class MySQLDriverDialect(DriverDialect): "Cursor.fetchall" } + @staticmethod + def _is_mysql_connection(conn: Connection | object) -> bool: + return isinstance(conn, MySQLConnection) or (CMYSQL_ENABLED and isinstance(conn, CMySQLConnection)) + + @staticmethod + def _is_cmysql_cursor(obj: object) -> bool: + return CMYSQL_ENABLED and isinstance(obj, CMySQLCursor) + def is_dialect(self, connect_func: Callable) -> bool: if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)): return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower() return True def is_closed(self, conn: Connection) -> bool: - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): + if MySQLDriverDialect._is_mysql_connection(conn): # is_connected validates the connection using a ping(). # If there are any unread results from previous executions an error will be thrown. if self.can_execute_query(conn): socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props) timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC - is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) + is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) # type: ignore try: return not is_connected_with_timeout() @@ -86,15 +105,15 @@ def is_closed(self, conn: Connection) -> bool: raise UnsupportedOperationError(Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "is_connected")) def get_autocommit(self, conn: Connection) -> bool: - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): - return conn.autocommit + if MySQLDriverDialect._is_mysql_connection(conn): + return conn.autocommit # type: ignore raise UnsupportedOperationError( Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit")) def set_autocommit(self, conn: Connection, autocommit: bool): - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): - conn.autocommit = autocommit + if MySQLDriverDialect._is_mysql_connection(conn): + conn.autocommit = autocommit # type: ignore return raise UnsupportedOperationError( @@ -112,24 +131,24 @@ def abort_connection(self, conn: Connection): "abort_connection")) def can_execute_query(self, conn: Connection) -> bool: - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): - if conn.unread_result: - return conn.can_consume_results + if MySQLDriverDialect._is_mysql_connection(conn): + if conn.unread_result: # type: ignore + return conn.can_consume_results # type: ignore return True def is_in_transaction(self, conn: Connection) -> bool: - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): - return bool(conn.in_transaction) + if MySQLDriverDialect._is_mysql_connection(conn): + return bool(conn.in_transaction) # type: ignore raise UnsupportedOperationError( Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "in_transaction")) def get_connection_from_obj(self, obj: object) -> Any: - if isinstance(obj, CMySQLConnection) or isinstance(obj, MySQLConnection): + if MySQLDriverDialect._is_mysql_connection(obj): return obj - if isinstance(obj, CMySQLCursor): + if MySQLDriverDialect._is_cmysql_cursor(obj): try: conn = None @@ -140,7 +159,7 @@ def get_connection_from_obj(self, obj: object) -> Any: if conn is None: return None - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): + if MySQLDriverDialect._is_mysql_connection(conn): return conn except ReferenceError: @@ -148,7 +167,7 @@ def get_connection_from_obj(self, obj: object) -> Any: if isinstance(obj, MySQLCursor): try: - if isinstance(obj._connection, CMySQLConnection) or isinstance(obj._connection, MySQLConnection): + if MySQLDriverDialect._is_mysql_connection(obj._connection): return obj._connection except ReferenceError: return None @@ -156,9 +175,8 @@ def get_connection_from_obj(self, obj: object) -> Any: return None def transfer_session_state(self, from_conn: Connection, to_conn: Connection): - if (isinstance(from_conn, CMySQLConnection) or isinstance(from_conn, MySQLConnection)) and ( - isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)): - to_conn.autocommit = from_conn.autocommit + if MySQLDriverDialect._is_mysql_connection(from_conn) and MySQLDriverDialect._is_mysql_connection(to_conn): + to_conn.autocommit = from_conn.autocommit # type: ignore def ping(self, conn: Connection) -> bool: return not self.is_closed(conn) diff --git a/tests/integration/container/test_blue_green_deployment.py b/tests/integration/container/test_blue_green_deployment.py index 6be6c7f3..8a732642 100644 --- a/tests/integration/container/test_blue_green_deployment.py +++ b/tests/integration/container/test_blue_green_deployment.py @@ -30,7 +30,6 @@ import mysql.connector import psycopg -from mysql.connector import CMySQLConnection, MySQLConnection from aws_advanced_python_wrapper.mysql_driver_dialect import MySQLDriverDialect from aws_advanced_python_wrapper.pg_driver_dialect import PgDriverDialect @@ -458,7 +457,7 @@ def close_connection(self, conn: Optional[Connection]): def is_closed(self, conn: Connection) -> bool: if isinstance(conn, psycopg.Connection): return self.pg_dialect.is_closed(conn) - elif isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): + elif MySQLDriverDialect._is_mysql_connection(conn): return self.mysql_dialect.is_closed(conn) elif isinstance(conn, AwsWrapperConnection): return conn.is_closed