Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions aws_advanced_python_wrapper/mysql_driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
from inspect import signature

from mysql.connector import CMySQLConnection, MySQLConnection
from mysql.connector.cursor import MySQLCursor
from mysql.connector.cursor_cext import CMySQLCursor

from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
Expand All @@ -36,6 +32,21 @@
PropertiesUtils,
WrapperProperties)

CMYSQL_ENABLED = False

from mysql.connector import MySQLConnection # noqa: E402
from mysql.connector.cursor import MySQLCursor # noqa: E402

try:
from mysql.connector import CMySQLConnection # noqa: E402
from mysql.connector.cursor_cext import CMySQLCursor # noqa: E402

CMYSQL_ENABLED = True

except ImportError:
# Do nothing
pass


class MySQLDriverDialect(DriverDialect):
_driver_name = "MySQL Connector Python"
Expand All @@ -62,20 +73,28 @@ class MySQLDriverDialect(DriverDialect):
"Cursor.fetchall"
}

@staticmethod
def _is_mysql_connection(conn: Connection | object) -> bool:
return isinstance(conn, MySQLConnection) or (CMYSQL_ENABLED and isinstance(conn, CMySQLConnection))

@staticmethod
def _is_cmysql_cursor(obj: object) -> bool:
return CMYSQL_ENABLED and isinstance(obj, CMySQLCursor)

def is_dialect(self, connect_func: Callable) -> bool:
if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)):
return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower()
return True

def is_closed(self, conn: Connection) -> bool:
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
if MySQLDriverDialect._is_mysql_connection(conn):

# is_connected validates the connection using a ping().
# If there are any unread results from previous executions an error will be thrown.
if self.can_execute_query(conn):
socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props)
timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC
is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected)
is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) # type: ignore

try:
return not is_connected_with_timeout()
Expand All @@ -86,15 +105,15 @@ def is_closed(self, conn: Connection) -> bool:
raise UnsupportedOperationError(Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "is_connected"))

def get_autocommit(self, conn: Connection) -> bool:
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
return conn.autocommit
if MySQLDriverDialect._is_mysql_connection(conn):
return conn.autocommit # type: ignore

raise UnsupportedOperationError(
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))

def set_autocommit(self, conn: Connection, autocommit: bool):
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
conn.autocommit = autocommit
if MySQLDriverDialect._is_mysql_connection(conn):
conn.autocommit = autocommit # type: ignore
return

raise UnsupportedOperationError(
Expand All @@ -112,24 +131,24 @@ def abort_connection(self, conn: Connection):
"abort_connection"))

def can_execute_query(self, conn: Connection) -> bool:
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
if conn.unread_result:
return conn.can_consume_results
if MySQLDriverDialect._is_mysql_connection(conn):
if conn.unread_result: # type: ignore
return conn.can_consume_results # type: ignore
return True

def is_in_transaction(self, conn: Connection) -> bool:
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
return bool(conn.in_transaction)
if MySQLDriverDialect._is_mysql_connection(conn):
return bool(conn.in_transaction) # type: ignore

raise UnsupportedOperationError(
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name,
"in_transaction"))

def get_connection_from_obj(self, obj: object) -> Any:
if isinstance(obj, CMySQLConnection) or isinstance(obj, MySQLConnection):
if MySQLDriverDialect._is_mysql_connection(obj):
return obj

if isinstance(obj, CMySQLCursor):
if MySQLDriverDialect._is_cmysql_cursor(obj):
try:
conn = None

Expand All @@ -140,25 +159,24 @@ def get_connection_from_obj(self, obj: object) -> Any:
if conn is None:
return None

if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
if MySQLDriverDialect._is_mysql_connection(conn):
return conn

except ReferenceError:
return None

if isinstance(obj, MySQLCursor):
try:
if isinstance(obj._connection, CMySQLConnection) or isinstance(obj._connection, MySQLConnection):
if MySQLDriverDialect._is_mysql_connection(obj._connection):
return obj._connection
except ReferenceError:
return None

return None

def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
if (isinstance(from_conn, CMySQLConnection) or isinstance(from_conn, MySQLConnection)) and (
isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)):
to_conn.autocommit = from_conn.autocommit
if MySQLDriverDialect._is_mysql_connection(from_conn) and MySQLDriverDialect._is_mysql_connection(to_conn):
to_conn.autocommit = from_conn.autocommit # type: ignore

def ping(self, conn: Connection) -> bool:
return not self.is_closed(conn)
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/container/test_blue_green_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

import mysql.connector
import psycopg
from mysql.connector import CMySQLConnection, MySQLConnection

from aws_advanced_python_wrapper.mysql_driver_dialect import MySQLDriverDialect
from aws_advanced_python_wrapper.pg_driver_dialect import PgDriverDialect
Expand Down Expand Up @@ -458,7 +457,7 @@ def close_connection(self, conn: Optional[Connection]):
def is_closed(self, conn: Connection) -> bool:
if isinstance(conn, psycopg.Connection):
return self.pg_dialect.is_closed(conn)
elif isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
elif MySQLDriverDialect._is_mysql_connection(conn):
return self.mysql_dialect.is_closed(conn)
elif isinstance(conn, AwsWrapperConnection):
return conn.is_closed
Expand Down