Skip to content

Commit 0ea34f6

Browse files
committed
fix: better handling for when c extension isn't available for mysql connector/python instead of erroring out
1 parent 74ae5de commit 0ea34f6

File tree

1 file changed

+204
-185
lines changed

1 file changed

+204
-185
lines changed
Lines changed: 204 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -1,185 +1,204 @@
1-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License").
4-
# You may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
15-
from __future__ import annotations
16-
17-
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Set
18-
19-
if TYPE_CHECKING:
20-
from aws_advanced_python_wrapper.hostinfo import HostInfo
21-
from aws_advanced_python_wrapper.pep249 import Connection
22-
23-
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
24-
from inspect import signature
25-
26-
from mysql.connector import CMySQLConnection, MySQLConnection
27-
from mysql.connector.cursor import MySQLCursor
28-
from mysql.connector.cursor_cext import CMySQLCursor
29-
30-
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
31-
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
32-
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
33-
from aws_advanced_python_wrapper.utils.decorators import timeout
34-
from aws_advanced_python_wrapper.utils.messages import Messages
35-
from aws_advanced_python_wrapper.utils.properties import (Properties,
36-
PropertiesUtils,
37-
WrapperProperties)
38-
39-
40-
class MySQLDriverDialect(DriverDialect):
41-
_driver_name = "MySQL Connector Python"
42-
TARGET_DRIVER_CODE = "MySQL"
43-
AUTH_PLUGIN_PARAM = "auth_plugin"
44-
AUTH_METHOD = "mysql_clear_password"
45-
IS_CLOSED_TIMEOUT_SEC = 3
46-
47-
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MySQLDriverDialectExecutor")
48-
49-
_dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON
50-
_network_bound_methods: Set[str] = {
51-
"Connection.commit",
52-
"Connection.autocommit",
53-
"Connection.autocommit_setter",
54-
"Connection.is_read_only",
55-
"Connection.set_read_only",
56-
"Connection.rollback",
57-
"Connection.cursor",
58-
"Cursor.close",
59-
"Cursor.execute",
60-
"Cursor.fetchone",
61-
"Cursor.fetchmany",
62-
"Cursor.fetchall"
63-
}
64-
65-
def is_dialect(self, connect_func: Callable) -> bool:
66-
if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)):
67-
return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower()
68-
return True
69-
70-
def is_closed(self, conn: Connection) -> bool:
71-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
72-
73-
# is_connected validates the connection using a ping().
74-
# If there are any unread results from previous executions an error will be thrown.
75-
if self.can_execute_query(conn):
76-
socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props)
77-
timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC
78-
is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected)
79-
80-
try:
81-
return not is_connected_with_timeout()
82-
except TimeoutError:
83-
return False
84-
return False
85-
86-
raise UnsupportedOperationError(Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "is_connected"))
87-
88-
def get_autocommit(self, conn: Connection) -> bool:
89-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
90-
return conn.autocommit
91-
92-
raise UnsupportedOperationError(
93-
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))
94-
95-
def set_autocommit(self, conn: Connection, autocommit: bool):
96-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
97-
conn.autocommit = autocommit
98-
return
99-
100-
raise UnsupportedOperationError(
101-
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))
102-
103-
def set_password(self, props: Properties, pwd: str):
104-
WrapperProperties.PASSWORD.set(props, pwd)
105-
props[MySQLDriverDialect.AUTH_PLUGIN_PARAM] = MySQLDriverDialect.AUTH_METHOD
106-
107-
def abort_connection(self, conn: Connection):
108-
raise UnsupportedOperationError(
109-
Messages.get_formatted(
110-
"DriverDialect.UnsupportedOperationError",
111-
self._driver_name,
112-
"abort_connection"))
113-
114-
def can_execute_query(self, conn: Connection) -> bool:
115-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
116-
if conn.unread_result:
117-
return conn.can_consume_results
118-
return True
119-
120-
def is_in_transaction(self, conn: Connection) -> bool:
121-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
122-
return bool(conn.in_transaction)
123-
124-
raise UnsupportedOperationError(
125-
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name,
126-
"in_transaction"))
127-
128-
def get_connection_from_obj(self, obj: object) -> Any:
129-
if isinstance(obj, CMySQLConnection) or isinstance(obj, MySQLConnection):
130-
return obj
131-
132-
if isinstance(obj, CMySQLCursor):
133-
try:
134-
conn = None
135-
136-
if hasattr(obj, '_cnx'):
137-
conn = obj._cnx
138-
elif hasattr(obj, '_connection'):
139-
conn = obj._connection
140-
if conn is None:
141-
return None
142-
143-
if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection):
144-
return conn
145-
146-
except ReferenceError:
147-
return None
148-
149-
if isinstance(obj, MySQLCursor):
150-
try:
151-
if isinstance(obj._connection, CMySQLConnection) or isinstance(obj._connection, MySQLConnection):
152-
return obj._connection
153-
except ReferenceError:
154-
return None
155-
156-
return None
157-
158-
def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
159-
if (isinstance(from_conn, CMySQLConnection) or isinstance(from_conn, MySQLConnection)) and (
160-
isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)):
161-
to_conn.autocommit = from_conn.autocommit
162-
163-
def ping(self, conn: Connection) -> bool:
164-
return not self.is_closed(conn)
165-
166-
def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) -> Properties:
167-
driver_props: Properties = Properties(original_props.copy())
168-
PropertiesUtils.remove_wrapper_props(driver_props)
169-
170-
driver_props["host"] = host_info.host
171-
if host_info.is_port_specified():
172-
driver_props["port"] = str(host_info.port)
173-
174-
db = WrapperProperties.DATABASE.get(original_props)
175-
if db is not None:
176-
driver_props["database"] = db
177-
178-
connect_timeout = WrapperProperties.CONNECT_TIMEOUT_SEC.get(original_props)
179-
if connect_timeout is not None:
180-
driver_props["connect_timeout"] = connect_timeout
181-
182-
return driver_props
183-
184-
def supports_connect_timeout(self) -> bool:
185-
return True
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Set
18+
19+
if TYPE_CHECKING:
20+
from aws_advanced_python_wrapper.hostinfo import HostInfo
21+
from aws_advanced_python_wrapper.pep249 import Connection
22+
23+
from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError
24+
from inspect import signature
25+
26+
CMYSQL_ENABLED = False
27+
28+
from mysql.connector import MySQLConnection
29+
from mysql.connector.cursor import MySQLCursor
30+
31+
try:
32+
from mysql.connector import CMySQLConnection
33+
from mysql.connector.cursor_cext import CMySQLCursor
34+
35+
CMYSQL_ENABLED = True
36+
37+
except ImportError as exc:
38+
# Do nothing
39+
pass
40+
41+
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
42+
from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes
43+
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
44+
from aws_advanced_python_wrapper.utils.decorators import timeout
45+
from aws_advanced_python_wrapper.utils.messages import Messages
46+
from aws_advanced_python_wrapper.utils.properties import (Properties,
47+
PropertiesUtils,
48+
WrapperProperties)
49+
50+
51+
class MySQLDriverDialect(DriverDialect):
52+
_driver_name = "MySQL Connector Python"
53+
TARGET_DRIVER_CODE = "MySQL"
54+
AUTH_PLUGIN_PARAM = "auth_plugin"
55+
AUTH_METHOD = "mysql_clear_password"
56+
IS_CLOSED_TIMEOUT_SEC = 3
57+
58+
_executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MySQLDriverDialectExecutor")
59+
60+
_dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON
61+
_network_bound_methods: Set[str] = {
62+
"Connection.commit",
63+
"Connection.autocommit",
64+
"Connection.autocommit_setter",
65+
"Connection.is_read_only",
66+
"Connection.set_read_only",
67+
"Connection.rollback",
68+
"Connection.cursor",
69+
"Cursor.close",
70+
"Cursor.execute",
71+
"Cursor.fetchone",
72+
"Cursor.fetchmany",
73+
"Cursor.fetchall"
74+
}
75+
76+
@staticmethod
77+
def _is_mysql_connection(conn: Connection | object) -> bool:
78+
return isinstance(conn, MySQLConnection) or (CMYSQL_ENABLED and isinstance(conn, CMySQLConnection))
79+
80+
@staticmethod
81+
def _is_cmysql_cursor(obj: object) -> bool:
82+
return CMYSQL_ENABLED and isinstance(obj, CMySQLCursor)
83+
84+
def is_dialect(self, connect_func: Callable) -> bool:
85+
if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)):
86+
return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower()
87+
return True
88+
89+
def is_closed(self, conn: Connection) -> bool:
90+
if MySQLDriverDialect._is_mysql_connection(conn):
91+
92+
# is_connected validates the connection using a ping().
93+
# If there are any unread results from previous executions an error will be thrown.
94+
if self.can_execute_query(conn):
95+
socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props)
96+
timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC
97+
is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected)
98+
99+
try:
100+
return not is_connected_with_timeout()
101+
except TimeoutError:
102+
return False
103+
return False
104+
105+
raise UnsupportedOperationError(Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "is_connected"))
106+
107+
def get_autocommit(self, conn: Connection) -> bool:
108+
if MySQLDriverDialect._is_mysql_connection(conn):
109+
return conn.autocommit
110+
111+
raise UnsupportedOperationError(
112+
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))
113+
114+
def set_autocommit(self, conn: Connection, autocommit: bool):
115+
if MySQLDriverDialect._is_mysql_connection(conn):
116+
conn.autocommit = autocommit
117+
return
118+
119+
raise UnsupportedOperationError(
120+
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit"))
121+
122+
def set_password(self, props: Properties, pwd: str):
123+
WrapperProperties.PASSWORD.set(props, pwd)
124+
props[MySQLDriverDialect.AUTH_PLUGIN_PARAM] = MySQLDriverDialect.AUTH_METHOD
125+
126+
def abort_connection(self, conn: Connection):
127+
raise UnsupportedOperationError(
128+
Messages.get_formatted(
129+
"DriverDialect.UnsupportedOperationError",
130+
self._driver_name,
131+
"abort_connection"))
132+
133+
def can_execute_query(self, conn: Connection) -> bool:
134+
if MySQLDriverDialect._is_mysql_connection(conn):
135+
if conn.unread_result:
136+
return conn.can_consume_results
137+
return True
138+
139+
def is_in_transaction(self, conn: Connection) -> bool:
140+
if MySQLDriverDialect._is_mysql_connection(conn):
141+
return bool(conn.in_transaction)
142+
143+
raise UnsupportedOperationError(
144+
Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name,
145+
"in_transaction"))
146+
147+
def get_connection_from_obj(self, obj: object) -> Any:
148+
if MySQLDriverDialect._is_mysql_connection(obj):
149+
return obj
150+
151+
if MySQLDriverDialect._is_cmysql_cursor(obj):
152+
try:
153+
conn = None
154+
155+
if hasattr(obj, '_cnx'):
156+
conn = obj._cnx
157+
elif hasattr(obj, '_connection'):
158+
conn = obj._connection
159+
if conn is None:
160+
return None
161+
162+
if MySQLDriverDialect._is_mysql_connection(conn):
163+
return conn
164+
165+
except ReferenceError:
166+
return None
167+
168+
if isinstance(obj, MySQLCursor):
169+
try:
170+
if MySQLDriverDialect._is_mysql_connection(obj._connection):
171+
return obj._connection
172+
except ReferenceError:
173+
return None
174+
175+
return None
176+
177+
def transfer_session_state(self, from_conn: Connection, to_conn: Connection):
178+
if (isinstance(from_conn, CMySQLConnection) or isinstance(from_conn, MySQLConnection)) and (
179+
isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)):
180+
to_conn.autocommit = from_conn.autocommit
181+
182+
def ping(self, conn: Connection) -> bool:
183+
return not self.is_closed(conn)
184+
185+
def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) -> Properties:
186+
driver_props: Properties = Properties(original_props.copy())
187+
PropertiesUtils.remove_wrapper_props(driver_props)
188+
189+
driver_props["host"] = host_info.host
190+
if host_info.is_port_specified():
191+
driver_props["port"] = str(host_info.port)
192+
193+
db = WrapperProperties.DATABASE.get(original_props)
194+
if db is not None:
195+
driver_props["database"] = db
196+
197+
connect_timeout = WrapperProperties.CONNECT_TIMEOUT_SEC.get(original_props)
198+
if connect_timeout is not None:
199+
driver_props["connect_timeout"] = connect_timeout
200+
201+
return driver_props
202+
203+
def supports_connect_timeout(self) -> bool:
204+
return True

0 commit comments

Comments
 (0)