Skip to content

Commit ac20df5

Browse files
committed
fix: better handling for when c extension isn't available for mysql connector/python instead of erroring out
1 parent 74ae5de commit ac20df5

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

aws_advanced_python_wrapper/mysql_driver_dialect.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,20 @@
2323
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
2424
from inspect import signature
2525

26-
from mysql.connector import CMySQLConnection, MySQLConnection
26+
CMYSQL_ENABLED = False
27+
28+
from mysql.connector import MySQLConnection
2729
from mysql.connector.cursor import MySQLCursor
28-
from mysql.connector.cursor_cext import CMySQLCursor
30+
31+
try:
32+
from mysql.connector import CMySQLConnection
33+
from mysql.connector.cursor_cext import CMySQLCursor
34+
35+
CMYSQL_ENABLED = True
36+
37+
except ImportError as exc:
38+
# Do nothing
39+
pass
2940

3041
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
3142
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
@@ -62,13 +73,21 @@ class MySQLDriverDialect(DriverDialect):
6273
"Cursor.fetchall"
6374
}
6475

76+
@staticmethod
77+
def _is_mysql_connection(conn: Connection | object) -> bool:
78+
return isinstance(conn, MySQLConnection) or (CMYSQL_ENABLED and isinstance(conn, CMySQLConnection))
79+
80+
@staticmethod
81+
def _is_cmysql_cursor(obj: object) -> bool:
82+
return CMYSQL_ENABLED and isinstance(obj, CMySQLCursor)
83+
6584
def is_dialect(self, connect_func: Callable) -> bool:
6685
if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)):
6786
return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower()
6887
return True
6988

7089
def is_closed(self, conn: Connection) -> bool:
71-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
90+
if MySQLDriverDialect._is_mysql_connection(conn):
7291

7392
# is_connected validates the connection using a ping().
7493
# If there are any unread results from previous executions an error will be thrown.
@@ -86,14 +105,14 @@ def is_closed(self, conn: Connection) -> bool:
86105
raise UnsupportedOperationError(Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "is_connected"))
87106

88107
def get_autocommit(self, conn: Connection) -> bool:
89-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
108+
if MySQLDriverDialect._is_mysql_connection(conn):
90109
return conn.autocommit
91110

92111
raise UnsupportedOperationError(
93112
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))
94113

95114
def set_autocommit(self, conn: Connection, autocommit: bool):
96-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
115+
if MySQLDriverDialect._is_mysql_connection(conn):
97116
conn.autocommit = autocommit
98117
return
99118

@@ -112,24 +131,24 @@ def abort_connection(self, conn: Connection):
112131
"abort_connection"))
113132

114133
def can_execute_query(self, conn: Connection) -> bool:
115-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
134+
if MySQLDriverDialect._is_mysql_connection(conn):
116135
if conn.unread_result:
117136
return conn.can_consume_results
118137
return True
119138

120139
def is_in_transaction(self, conn: Connection) -> bool:
121-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
140+
if MySQLDriverDialect._is_mysql_connection(conn):
122141
return bool(conn.in_transaction)
123142

124143
raise UnsupportedOperationError(
125144
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name,
126145
"in_transaction"))
127146

128147
def get_connection_from_obj(self, obj: object) -> Any:
129-
if isinstance(obj, CMySQLConnection) or isinstance(obj, MySQLConnection):
148+
if MySQLDriverDialect._is_mysql_connection(obj):
130149
return obj
131150

132-
if isinstance(obj, CMySQLCursor):
151+
if MySQLDriverDialect._is_cmysql_cursor(obj):
133152
try:
134153
conn = None
135154

@@ -140,15 +159,15 @@ def get_connection_from_obj(self, obj: object) -> Any:
140159
if conn is None:
141160
return None
142161

143-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
162+
if MySQLDriverDialect._is_mysql_connection(conn):
144163
return conn
145164

146165
except ReferenceError:
147166
return None
148167

149168
if isinstance(obj, MySQLCursor):
150169
try:
151-
if isinstance(obj._connection, CMySQLConnection) or isinstance(obj._connection, MySQLConnection):
170+
if MySQLDriverDialect._is_mysql_connection(obj._connection):
152171
return obj._connection
153172
except ReferenceError:
154173
return None

0 commit comments

Comments
 (0)