|
1 | | -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
2 | | -# |
3 | | -# Licensed under the Apache License, Version 2.0 (the "License"). |
4 | | -# You may not use this file except in compliance with the License. |
5 | | -# You may obtain a copy of the License at |
6 | | -# |
7 | | -# http://www.apache.org/licenses/LICENSE-2.0 |
8 | | -# |
9 | | -# Unless required by applicable law or agreed to in writing, software |
10 | | -# distributed under the License is distributed on an "AS IS" BASIS, |
11 | | -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | | -# See the License for the specific language governing permissions and |
13 | | -# limitations under the License. |
14 | | - |
15 | | -from __future__ import annotations |
16 | | - |
17 | | -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Set |
18 | | - |
19 | | -if TYPE_CHECKING: |
20 | | - from aws_advanced_python_wrapper.hostinfo import HostInfo |
21 | | - from aws_advanced_python_wrapper.pep249 import Connection |
22 | | - |
23 | | -from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError |
24 | | -from inspect import signature |
25 | | - |
26 | | -from mysql.connector import CMySQLConnection, MySQLConnection |
27 | | -from mysql.connector.cursor import MySQLCursor |
28 | | -from mysql.connector.cursor_cext import CMySQLCursor |
29 | | - |
30 | | -from aws_advanced_python_wrapper.driver_dialect import DriverDialect |
31 | | -from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes |
32 | | -from aws_advanced_python_wrapper.errors import UnsupportedOperationError |
33 | | -from aws_advanced_python_wrapper.utils.decorators import timeout |
34 | | -from aws_advanced_python_wrapper.utils.messages import Messages |
35 | | -from aws_advanced_python_wrapper.utils.properties import (Properties, |
36 | | - PropertiesUtils, |
37 | | - WrapperProperties) |
38 | | - |
39 | | - |
40 | | -class MySQLDriverDialect(DriverDialect): |
41 | | - _driver_name = "MySQL Connector Python" |
42 | | - TARGET_DRIVER_CODE = "MySQL" |
43 | | - AUTH_PLUGIN_PARAM = "auth_plugin" |
44 | | - AUTH_METHOD = "mysql_clear_password" |
45 | | - IS_CLOSED_TIMEOUT_SEC = 3 |
46 | | - |
47 | | - _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MySQLDriverDialectExecutor") |
48 | | - |
49 | | - _dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON |
50 | | - _network_bound_methods: Set[str] = { |
51 | | - "Connection.commit", |
52 | | - "Connection.autocommit", |
53 | | - "Connection.autocommit_setter", |
54 | | - "Connection.is_read_only", |
55 | | - "Connection.set_read_only", |
56 | | - "Connection.rollback", |
57 | | - "Connection.cursor", |
58 | | - "Cursor.close", |
59 | | - "Cursor.execute", |
60 | | - "Cursor.fetchone", |
61 | | - "Cursor.fetchmany", |
62 | | - "Cursor.fetchall" |
63 | | - } |
64 | | - |
65 | | - def is_dialect(self, connect_func: Callable) -> bool: |
66 | | - if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)): |
67 | | - return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower() |
68 | | - return True |
69 | | - |
70 | | - def is_closed(self, conn: Connection) -> bool: |
71 | | - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): |
72 | | - |
73 | | - # is_connected validates the connection using a ping(). |
74 | | - # If there are any unread results from previous executions an error will be thrown. |
75 | | - if self.can_execute_query(conn): |
76 | | - socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props) |
77 | | - timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC |
78 | | - is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) |
79 | | - |
80 | | - try: |
81 | | - return not is_connected_with_timeout() |
82 | | - except TimeoutError: |
83 | | - return False |
84 | | - return False |
85 | | - |
86 | | - raise UnsupportedOperationError(Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "is_connected")) |
87 | | - |
88 | | - def get_autocommit(self, conn: Connection) -> bool: |
89 | | - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): |
90 | | - return conn.autocommit |
91 | | - |
92 | | - raise UnsupportedOperationError( |
93 | | - Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit")) |
94 | | - |
95 | | - def set_autocommit(self, conn: Connection, autocommit: bool): |
96 | | - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): |
97 | | - conn.autocommit = autocommit |
98 | | - return |
99 | | - |
100 | | - raise UnsupportedOperationError( |
101 | | - Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit")) |
102 | | - |
103 | | - def set_password(self, props: Properties, pwd: str): |
104 | | - WrapperProperties.PASSWORD.set(props, pwd) |
105 | | - props[MySQLDriverDialect.AUTH_PLUGIN_PARAM] = MySQLDriverDialect.AUTH_METHOD |
106 | | - |
107 | | - def abort_connection(self, conn: Connection): |
108 | | - raise UnsupportedOperationError( |
109 | | - Messages.get_formatted( |
110 | | - "DriverDialect.UnsupportedOperationError", |
111 | | - self._driver_name, |
112 | | - "abort_connection")) |
113 | | - |
114 | | - def can_execute_query(self, conn: Connection) -> bool: |
115 | | - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): |
116 | | - if conn.unread_result: |
117 | | - return conn.can_consume_results |
118 | | - return True |
119 | | - |
120 | | - def is_in_transaction(self, conn: Connection) -> bool: |
121 | | - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): |
122 | | - return bool(conn.in_transaction) |
123 | | - |
124 | | - raise UnsupportedOperationError( |
125 | | - Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, |
126 | | - "in_transaction")) |
127 | | - |
128 | | - def get_connection_from_obj(self, obj: object) -> Any: |
129 | | - if isinstance(obj, CMySQLConnection) or isinstance(obj, MySQLConnection): |
130 | | - return obj |
131 | | - |
132 | | - if isinstance(obj, CMySQLCursor): |
133 | | - try: |
134 | | - conn = None |
135 | | - |
136 | | - if hasattr(obj, '_cnx'): |
137 | | - conn = obj._cnx |
138 | | - elif hasattr(obj, '_connection'): |
139 | | - conn = obj._connection |
140 | | - if conn is None: |
141 | | - return None |
142 | | - |
143 | | - if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): |
144 | | - return conn |
145 | | - |
146 | | - except ReferenceError: |
147 | | - return None |
148 | | - |
149 | | - if isinstance(obj, MySQLCursor): |
150 | | - try: |
151 | | - if isinstance(obj._connection, CMySQLConnection) or isinstance(obj._connection, MySQLConnection): |
152 | | - return obj._connection |
153 | | - except ReferenceError: |
154 | | - return None |
155 | | - |
156 | | - return None |
157 | | - |
158 | | - def transfer_session_state(self, from_conn: Connection, to_conn: Connection): |
159 | | - if (isinstance(from_conn, CMySQLConnection) or isinstance(from_conn, MySQLConnection)) and ( |
160 | | - isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)): |
161 | | - to_conn.autocommit = from_conn.autocommit |
162 | | - |
163 | | - def ping(self, conn: Connection) -> bool: |
164 | | - return not self.is_closed(conn) |
165 | | - |
166 | | - def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) -> Properties: |
167 | | - driver_props: Properties = Properties(original_props.copy()) |
168 | | - PropertiesUtils.remove_wrapper_props(driver_props) |
169 | | - |
170 | | - driver_props["host"] = host_info.host |
171 | | - if host_info.is_port_specified(): |
172 | | - driver_props["port"] = str(host_info.port) |
173 | | - |
174 | | - db = WrapperProperties.DATABASE.get(original_props) |
175 | | - if db is not None: |
176 | | - driver_props["database"] = db |
177 | | - |
178 | | - connect_timeout = WrapperProperties.CONNECT_TIMEOUT_SEC.get(original_props) |
179 | | - if connect_timeout is not None: |
180 | | - driver_props["connect_timeout"] = connect_timeout |
181 | | - |
182 | | - return driver_props |
183 | | - |
184 | | - def supports_connect_timeout(self) -> bool: |
185 | | - return True |
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). |
| 4 | +# You may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Set |
| 18 | + |
| 19 | +if TYPE_CHECKING: |
| 20 | + from aws_advanced_python_wrapper.hostinfo import HostInfo |
| 21 | + from aws_advanced_python_wrapper.pep249 import Connection |
| 22 | + |
| 23 | +from concurrent.futures import Executor, ThreadPoolExecutor, TimeoutError |
| 24 | +from inspect import signature |
| 25 | + |
| 26 | +CMYSQL_ENABLED = False |
| 27 | + |
| 28 | +from mysql.connector import MySQLConnection |
| 29 | +from mysql.connector.cursor import MySQLCursor |
| 30 | + |
| 31 | +try: |
| 32 | + from mysql.connector import CMySQLConnection |
| 33 | + from mysql.connector.cursor_cext import CMySQLCursor |
| 34 | + |
| 35 | + CMYSQL_ENABLED = True |
| 36 | + |
| 37 | +except ImportError as exc: |
| 38 | + # Do nothing |
| 39 | + pass |
| 40 | + |
| 41 | +from aws_advanced_python_wrapper.driver_dialect import DriverDialect |
| 42 | +from aws_advanced_python_wrapper.driver_dialect_codes import DriverDialectCodes |
| 43 | +from aws_advanced_python_wrapper.errors import UnsupportedOperationError |
| 44 | +from aws_advanced_python_wrapper.utils.decorators import timeout |
| 45 | +from aws_advanced_python_wrapper.utils.messages import Messages |
| 46 | +from aws_advanced_python_wrapper.utils.properties import (Properties, |
| 47 | + PropertiesUtils, |
| 48 | + WrapperProperties) |
| 49 | + |
| 50 | + |
| 51 | +class MySQLDriverDialect(DriverDialect): |
| 52 | + _driver_name = "MySQL Connector Python" |
| 53 | + TARGET_DRIVER_CODE = "MySQL" |
| 54 | + AUTH_PLUGIN_PARAM = "auth_plugin" |
| 55 | + AUTH_METHOD = "mysql_clear_password" |
| 56 | + IS_CLOSED_TIMEOUT_SEC = 3 |
| 57 | + |
| 58 | + _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="MySQLDriverDialectExecutor") |
| 59 | + |
| 60 | + _dialect_code: str = DriverDialectCodes.MYSQL_CONNECTOR_PYTHON |
| 61 | + _network_bound_methods: Set[str] = { |
| 62 | + "Connection.commit", |
| 63 | + "Connection.autocommit", |
| 64 | + "Connection.autocommit_setter", |
| 65 | + "Connection.is_read_only", |
| 66 | + "Connection.set_read_only", |
| 67 | + "Connection.rollback", |
| 68 | + "Connection.cursor", |
| 69 | + "Cursor.close", |
| 70 | + "Cursor.execute", |
| 71 | + "Cursor.fetchone", |
| 72 | + "Cursor.fetchmany", |
| 73 | + "Cursor.fetchall" |
| 74 | + } |
| 75 | + |
| 76 | + @staticmethod |
| 77 | + def _is_mysql_connection(conn: Connection | object) -> bool: |
| 78 | + return isinstance(conn, MySQLConnection) or (CMYSQL_ENABLED and isinstance(conn, CMySQLConnection)) |
| 79 | + |
| 80 | + @staticmethod |
| 81 | + def _is_cmysql_cursor(obj: object) -> bool: |
| 82 | + return CMYSQL_ENABLED and isinstance(obj, CMySQLCursor) |
| 83 | + |
| 84 | + def is_dialect(self, connect_func: Callable) -> bool: |
| 85 | + if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)): |
| 86 | + return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower() |
| 87 | + return True |
| 88 | + |
| 89 | + def is_closed(self, conn: Connection) -> bool: |
| 90 | + if MySQLDriverDialect._is_mysql_connection(conn): |
| 91 | + |
| 92 | + # is_connected validates the connection using a ping(). |
| 93 | + # If there are any unread results from previous executions an error will be thrown. |
| 94 | + if self.can_execute_query(conn): |
| 95 | + socket_timeout = WrapperProperties.SOCKET_TIMEOUT_SEC.get_float(self._props) |
| 96 | + timeout_sec = socket_timeout if socket_timeout > 0 else MySQLDriverDialect.IS_CLOSED_TIMEOUT_SEC |
| 97 | + is_connected_with_timeout = timeout(MySQLDriverDialect._executor, timeout_sec)(conn.is_connected) |
| 98 | + |
| 99 | + try: |
| 100 | + return not is_connected_with_timeout() |
| 101 | + except TimeoutError: |
| 102 | + return False |
| 103 | + return False |
| 104 | + |
| 105 | + raise UnsupportedOperationError(Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "is_connected")) |
| 106 | + |
| 107 | + def get_autocommit(self, conn: Connection) -> bool: |
| 108 | + if MySQLDriverDialect._is_mysql_connection(conn): |
| 109 | + return conn.autocommit |
| 110 | + |
| 111 | + raise UnsupportedOperationError( |
| 112 | + Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit")) |
| 113 | + |
| 114 | + def set_autocommit(self, conn: Connection, autocommit: bool): |
| 115 | + if MySQLDriverDialect._is_mysql_connection(conn): |
| 116 | + conn.autocommit = autocommit |
| 117 | + return |
| 118 | + |
| 119 | + raise UnsupportedOperationError( |
| 120 | + Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, "autocommit")) |
| 121 | + |
| 122 | + def set_password(self, props: Properties, pwd: str): |
| 123 | + WrapperProperties.PASSWORD.set(props, pwd) |
| 124 | + props[MySQLDriverDialect.AUTH_PLUGIN_PARAM] = MySQLDriverDialect.AUTH_METHOD |
| 125 | + |
| 126 | + def abort_connection(self, conn: Connection): |
| 127 | + raise UnsupportedOperationError( |
| 128 | + Messages.get_formatted( |
| 129 | + "DriverDialect.UnsupportedOperationError", |
| 130 | + self._driver_name, |
| 131 | + "abort_connection")) |
| 132 | + |
| 133 | + def can_execute_query(self, conn: Connection) -> bool: |
| 134 | + if MySQLDriverDialect._is_mysql_connection(conn): |
| 135 | + if conn.unread_result: |
| 136 | + return conn.can_consume_results |
| 137 | + return True |
| 138 | + |
| 139 | + def is_in_transaction(self, conn: Connection) -> bool: |
| 140 | + if MySQLDriverDialect._is_mysql_connection(conn): |
| 141 | + return bool(conn.in_transaction) |
| 142 | + |
| 143 | + raise UnsupportedOperationError( |
| 144 | + Messages.get_formatted("DriverDialect.UnsupportedOperationError", self._driver_name, |
| 145 | + "in_transaction")) |
| 146 | + |
| 147 | + def get_connection_from_obj(self, obj: object) -> Any: |
| 148 | + if MySQLDriverDialect._is_mysql_connection(obj): |
| 149 | + return obj |
| 150 | + |
| 151 | + if MySQLDriverDialect._is_cmysql_cursor(obj): |
| 152 | + try: |
| 153 | + conn = None |
| 154 | + |
| 155 | + if hasattr(obj, '_cnx'): |
| 156 | + conn = obj._cnx |
| 157 | + elif hasattr(obj, '_connection'): |
| 158 | + conn = obj._connection |
| 159 | + if conn is None: |
| 160 | + return None |
| 161 | + |
| 162 | + if MySQLDriverDialect._is_mysql_connection(conn): |
| 163 | + return conn |
| 164 | + |
| 165 | + except ReferenceError: |
| 166 | + return None |
| 167 | + |
| 168 | + if isinstance(obj, MySQLCursor): |
| 169 | + try: |
| 170 | + if MySQLDriverDialect._is_mysql_connection(obj._connection): |
| 171 | + return obj._connection |
| 172 | + except ReferenceError: |
| 173 | + return None |
| 174 | + |
| 175 | + return None |
| 176 | + |
| 177 | + def transfer_session_state(self, from_conn: Connection, to_conn: Connection): |
| 178 | + if (isinstance(from_conn, CMySQLConnection) or isinstance(from_conn, MySQLConnection)) and ( |
| 179 | + isinstance(to_conn, CMySQLConnection) or isinstance(to_conn, MySQLConnection)): |
| 180 | + to_conn.autocommit = from_conn.autocommit |
| 181 | + |
| 182 | + def ping(self, conn: Connection) -> bool: |
| 183 | + return not self.is_closed(conn) |
| 184 | + |
| 185 | + def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) -> Properties: |
| 186 | + driver_props: Properties = Properties(original_props.copy()) |
| 187 | + PropertiesUtils.remove_wrapper_props(driver_props) |
| 188 | + |
| 189 | + driver_props["host"] = host_info.host |
| 190 | + if host_info.is_port_specified(): |
| 191 | + driver_props["port"] = str(host_info.port) |
| 192 | + |
| 193 | + db = WrapperProperties.DATABASE.get(original_props) |
| 194 | + if db is not None: |
| 195 | + driver_props["database"] = db |
| 196 | + |
| 197 | + connect_timeout = WrapperProperties.CONNECT_TIMEOUT_SEC.get(original_props) |
| 198 | + if connect_timeout is not None: |
| 199 | + driver_props["connect_timeout"] = connect_timeout |
| 200 | + |
| 201 | + return driver_props |
| 202 | + |
| 203 | + def supports_connect_timeout(self) -> bool: |
| 204 | + return True |
0 commit comments