Skip to content

Commit 658ee07

Browse files
committed
feat: srw
1 parent 725a80a commit 658ee07

File tree

4 files changed

+129
-152
lines changed

4 files changed

+129
-152
lines changed

aws_advanced_python_wrapper/plugin_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@
8080
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
8181
from aws_advanced_python_wrapper.iam_plugin import IamAuthPluginFactory
8282
from aws_advanced_python_wrapper.plugin import CanReleaseResources
83-
from aws_advanced_python_wrapper.read_write_splitting_plugin import (
84-
ReadWriteSplittingPluginFactory, SimpleReadWriteSplittingPluginFactory)
83+
from aws_advanced_python_wrapper.read_write_splitting_plugin import ReadWriteSplittingPluginFactory
84+
from aws_advanced_python_wrapper.simple_read_write_splitting_plugin import SimpleReadWriteSplittingPluginFactory
8585
from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsPluginFactory
8686
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
8787
from aws_advanced_python_wrapper.utils.decorators import \

aws_advanced_python_wrapper/read_write_splitting_plugin.py

Lines changed: 62 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,24 @@
4040

4141
class ReadWriteSplittingConnectionManager(Plugin):
4242
"""Base class that manages connection switching logic."""
43+
_POOL_PROVIDER_CLASS_NAME = "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider"
44+
4345
_SUBSCRIBED_METHODS: Set[str] = {"init_host_provider",
4446
"connect",
4547
"notify_connection_changed",
4648
"Connection.set_read_only"}
4749
def __init__(self, plugin_service: PluginService, props, connection_handler: ConnectionHandler):
4850
self._plugin_service = plugin_service
51+
self._conn_provider_manager: ConnectionProviderManager = self._plugin_service.get_connection_provider_manager()
4952
self._properties = props
5053
self._connection_handler = connection_handler
5154
self._writer_connection: Optional[Connection] = None
5255
self._reader_connection: Optional[Connection] = None
5356
self._writer_host_info: Optional[HostInfo] = None
5457
self._reader_host_info: Optional[HostInfo] = None
5558
self._in_read_write_split: bool = False
59+
self._is_reader_conn_from_internal_pool: bool = False
60+
self._is_writer_conn_from_internal_pool: bool = False
5661

5762
@property
5863
def subscribed_methods(self) -> Set[str]:
@@ -63,7 +68,7 @@ def init_host_provider(
6368
props: Properties,
6469
host_list_provider_service: HostListProviderService,
6570
init_host_provider_func: Callable):
66-
self._connection_handler.host_list_provider_service = host_list_provider_service
71+
self._connection_handler.set_host_list_provider_service(host_list_provider_service)
6772
init_host_provider_func()
6873

6974
def connect(
@@ -113,10 +118,10 @@ def _update_internal_connection_info(self):
113118
return
114119

115120
if self._connection_handler.should_update_writer_with_current_conn(current_conn, current_host, self._writer_connection):
116-
self._close_connection(self._writer_connection)
121+
self.close_connection(self._writer_connection)
117122
self._set_writer_connection(current_conn, current_host)
118123
elif self._connection_handler.should_update_reader_with_current_conn(current_conn, current_host, self._reader_connection):
119-
self._close_connection(self._reader_connection)
124+
self.close_connection(self._reader_connection)
120125
self._set_reader_connection(current_conn, current_host)
121126

122127
def _set_writer_connection(self, writer_conn: Connection, writer_host_info: HostInfo):
@@ -129,12 +134,15 @@ def _set_reader_connection(self, reader_conn: Connection, reader_host_info: Host
129134
self._reader_host_info = reader_host_info
130135
logger.debug("ReadWriteSplittingPlugin.SetReaderConnection", reader_host_info.url)
131136

132-
def _get_new_writer_connection(self):
133-
conn, writer_host = self._connection_handler.get_new_writer_connection()
137+
def _initialize_writer_connection(self):
138+
conn, writer_host = self._connection_handler.open_new_writer_connection()
134139

135140
if conn is None:
136-
self._log_and_raise_exception("ReadWriteSplittingPlugin.WriterUnavailable")
141+
self.log_and_raise_exception("ReadWriteSplittingPlugin.WriterUnavailable")
137142
return
143+
144+
provider = self._conn_provider_manager.get_connection_provider(writer_host, self._properties)
145+
self._is_writer_conn_from_internal_pool = (ReadWriteSplittingConnectionManager._POOL_PROVIDER_CLASS_NAME in str(type(provider)))
138146

139147
self._set_writer_connection(conn, writer_host)
140148
self._switch_current_connection_to(conn, writer_host)
@@ -145,12 +153,12 @@ def _switch_connection_if_required(self, read_only: bool):
145153

146154
if (current_conn is not None and
147155
driver_dialect is not None and driver_dialect.is_closed(current_conn)):
148-
self._log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyOnClosedConnection")
156+
self.log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyOnClosedConnection")
149157

150158
self._connection_handler.refresh_and_store_host_list(current_conn, driver_dialect)
151159
current_host = self._plugin_service.current_host_info
152160
if current_host is None:
153-
self._log_and_raise_exception("ReadWriteSplittingPlugin.UnavailableHostInfo")
161+
self.log_and_raise_exception("ReadWriteSplittingPlugin.UnavailableHostInfo")
154162
return
155163

156164
if read_only:
@@ -162,18 +170,18 @@ def _switch_connection_if_required(self, read_only: bool):
162170
# do this
163171
ex = None
164172
if not self._is_connection_usable(current_conn, driver_dialect):
165-
self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToReader")
173+
self.log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToReader")
166174
return
167175

168176
logger.warning("ReadWriteSplittingPlugin.FallbackToWriter", current_host.url)
169177
elif not self._connection_handler.is_writer_host(current_host):
170178
if self._plugin_service.is_in_transaction:
171-
self._log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyFalseInTransaction")
179+
self.log_and_raise_exception("ReadWriteSplittingPlugin.SetReadOnlyFalseInTransaction")
172180

173181
try:
174182
self._switch_to_writer_connection()
175183
except Exception:
176-
self._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToWriter")
184+
self.log_and_raise_exception("ReadWriteSplittingPlugin.ErrorSwitchingToWriter")
177185

178186
def _switch_current_connection_to(self, new_conn: Connection, new_conn_host: HostInfo):
179187
current_conn = self._plugin_service.current_connection
@@ -195,11 +203,11 @@ def _switch_to_writer_connection(self):
195203

196204
self._in_read_write_split = True
197205
if not self._is_connection_usable(self._writer_connection, driver_dialect):
198-
self._get_new_writer_connection()
206+
self._initialize_writer_connection()
199207
elif self._writer_connection is not None and self._writer_host_info is not None:
200208
self._switch_current_connection_to(self._writer_connection, self._writer_host_info)
201209

202-
if self._connection_handler.should_close_reader_after_switch_to_writer():
210+
if self._is_reader_conn_from_internal_pool:
203211
self._close_connection_if_idle(self._reader_connection)
204212

205213
logger.debug("ReadWriteSplittingPlugin.SwitchedFromReaderToWriter", self._writer_host_info.url)
@@ -230,53 +238,68 @@ def _switch_to_reader_connection(self):
230238
self._close_connection_if_idle(self._reader_connection)
231239
self._initialize_reader_connection()
232240

233-
if self._connection_handler.should_close_writer_after_switch_to_reader():
241+
if self._is_writer_conn_from_internal_pool:
234242
self._close_connection_if_idle(self._writer_connection)
235243

236244
def _initialize_reader_connection(self):
237245
if self._connection_handler.need_connect_to_writer():
238246
if not self._is_connection_usable(self._writer_connection, self._plugin_service.driver_dialect):
239-
self._get_new_writer_connection()
247+
self._initialize_writer_connection()
240248
logger.warning("ReadWriteSplittingPlugin.NoReadersFound", self._writer_host_info.url)
241249
return
242250

243-
conn, reader_host = self._connection_handler.get_new_reader_connection()
251+
conn, reader_host = self._connection_handler.open_new_reader_connection()
244252

245253
if conn is None or reader_host is None:
246-
self._log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable")
254+
self.log_and_raise_exception("ReadWriteSplittingPlugin.NoReadersAvailable")
247255
return
248256

249257
logger.debug("ReadWriteSplittingPlugin.SuccessfullyConnectedToReader", reader_host.url)
250258

259+
provider = self._conn_provider_manager.get_connection_provider(reader_host, self._properties)
260+
self._is_reader_conn_from_internal_pool = (ReadWriteSplittingConnectionManager._POOL_PROVIDER_CLASS_NAME in str(type(provider)))
261+
251262
self._set_reader_connection(conn, reader_host)
252263
self._switch_current_connection_to(conn, reader_host)
253264

254265
logger.debug("ReadWriteSplittingPlugin.SwitchedFromWriterToReader", reader_host.url)
255266

256267
def _close_connection_if_idle(self, internal_conn: Optional[Connection]):
268+
if internal_conn is None:
269+
return
270+
257271
current_conn = self._plugin_service.current_connection
258272
driver_dialect = self._plugin_service.driver_dialect
259-
260-
if (internal_conn is not None and internal_conn != current_conn and self._is_connection_usable(internal_conn, driver_dialect)):
261-
self._close_connection(internal_conn)
273+
274+
try:
275+
if (internal_conn != current_conn and
276+
self._is_connection_usable(internal_conn, driver_dialect)):
277+
internal_conn.close()
278+
except Exception:
279+
# Ignore exceptions during cleanup - connection might already be dead
280+
pass
281+
finally:
282+
# Always clear cached references to prevent reuse of dead connections
262283
if internal_conn == self._writer_connection:
263284
self._writer_connection = None
285+
self._writer_host_info = None
264286
if internal_conn == self._reader_connection:
265287
self._reader_connection = None
288+
self._reader_host_info = None
266289

267290
def _close_idle_connections(self):
268291
logger.debug("ReadWriteSplittingPlugin.ClosingInternalConnections")
269292
self._close_connection_if_idle(self._reader_connection)
270293
self._close_connection_if_idle(self._writer_connection)
271294

272295
@staticmethod
273-
def _log_and_raise_exception(log_msg: str):
296+
def log_and_raise_exception(log_msg: str):
274297
logger.error(log_msg)
275298
raise ReadWriteSplittingError(Messages.get(log_msg))
276299

277300
@staticmethod
278301
def _is_connection_usable(conn: Optional[Connection], driver_dialect: Optional[DriverDialect]):
279-
if conn is not None or driver_dialect is None:
302+
if conn is None or driver_dialect is None:
280303
return False
281304
try:
282305
return not driver_dialect.is_closed(conn)
@@ -285,7 +308,7 @@ def _is_connection_usable(conn: Optional[Connection], driver_dialect: Optional[D
285308
return False
286309

287310
@staticmethod
288-
def _close_connection(connection: Connection):
311+
def close_connection(connection: Connection):
289312
if connection is not None:
290313
try:
291314
connection.close()
@@ -295,17 +318,15 @@ def _close_connection(connection: Connection):
295318

296319
class ConnectionHandler(Protocol):
297320
"""Protocol for handling writer/reader connection logic."""
298-
@property
299-
@abstractmethod
300-
def host_list_provider_service(self) -> HostListProviderService:
301-
...
321+
def set_host_list_provider_service(self, value: HostListProviderService):
322+
self._host_list_provider_service = value
302323

303-
def get_new_writer_connection(self) -> Optional[tuple[Connection, HostInfo]]:
304-
"""Get or create a writer connection."""
324+
def open_new_writer_connection(self) -> Optional[tuple[Connection, HostInfo]]:
325+
"""Open a writer connection."""
305326
...
306327

307-
def get_new_reader_connection(self) -> Optional[tuple[Connection, HostInfo]]:
308-
"""Get or create a reader connection."""
328+
def open_new_reader_connection(self) -> Optional[tuple[Connection, HostInfo]]:
329+
"""Open a reader connection."""
309330
...
310331

311332
def get_verified_initial_connection(self, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]:
@@ -332,14 +353,6 @@ def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
332353
"""Return true if the current host can be used to switch connection to."""
333354
...
334355

335-
def should_close_writer_after_switch_to_reader(self) -> bool:
336-
"""Return true if the cached writer should be closed upon switch to reader."""
337-
...
338-
339-
def should_close_reader_after_switch_to_writer(self) -> bool:
340-
"""Return true if the cached reader should be closed upon switch to writer."""
341-
...
342-
343356
def need_connect_to_writer(self) -> bool:
344357
"""Return true if switching to reader should instead connect to writer."""
345358
...
@@ -349,16 +362,11 @@ def refresh_and_store_host_list(self, current_conn: Connection, driver_dialect:
349362
...
350363

351364
class TopologyBasedConnectionHandler(ConnectionHandler):
352-
"""Topology based implementation of connection handling logic."""
353-
_POOL_PROVIDER_CLASS_NAME = "aws_advanced_python_wrapper.sql_alchemy_connection_provider.SqlAlchemyPooledConnectionProvider"
354-
365+
"""Topology based implementation of connection handling logic."""
355366
def __init__(self, plugin_service: PluginService, props: Properties):
356367
self._plugin_service = plugin_service
357368
self._properties = props
358369
self._host_list_provider_service: Optional[HostListProviderService] = None
359-
self._conn_provider_manager: ConnectionProviderManager = self._plugin_service.get_connection_provider_manager()
360-
self._is_reader_conn_from_internal_pool: bool = False
361-
self._is_writer_conn_from_internal_pool: bool = False
362370
strategy = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(self._properties)
363371
if strategy is not None:
364372
self._reader_selector_strategy = strategy
@@ -368,26 +376,16 @@ def __init__(self, plugin_service: PluginService, props: Properties):
368376
self._reader_selector_strategy = default_strategy
369377
self._hosts = None
370378

371-
@property
372-
def host_list_provider_service(self) -> HostListProviderService:
373-
return self._host_list_provider_service
374-
375-
@host_list_provider_service.setter
376-
def host_list_provider_service(self, value: HostListProviderService):
377-
self._host_list_provider_service = value
378-
379-
def get_new_writer_connection(self) -> Optional[tuple[Connection, HostInfo]]:
380-
writer_host = self._get_writer(self._hosts)
379+
def open_new_writer_connection(self) -> Optional[tuple[Connection, HostInfo]]:
380+
writer_host = self._get_writer()
381381
if writer_host is None:
382382
return
383383

384384
conn = self._plugin_service.connect(writer_host, self._properties, self)
385-
provider = self._conn_provider_manager.get_connection_provider(writer_host, self._properties)
386-
self._is_writer_conn_from_internal_pool = (TopologyBasedConnectionHandler._POOL_PROVIDER_CLASS_NAME in str(type(provider)))
387385

388386
return conn, writer_host
389387

390-
def get_new_reader_connection(self) -> Optional[tuple[Connection, HostInfo]]:
388+
def open_new_reader_connection(self) -> Optional[tuple[Connection, HostInfo]]:
391389
conn: Optional[Connection] = None
392390
reader_host: Optional[HostInfo] = None
393391

@@ -398,8 +396,6 @@ def get_new_reader_connection(self) -> Optional[tuple[Connection, HostInfo]]:
398396
try:
399397
conn = self._plugin_service.connect(host, self._properties, self)
400398
reader_host = host
401-
provider = self._conn_provider_manager.get_connection_provider(host, self._properties)
402-
self._is_reader_conn_from_internal_pool = (TopologyBasedConnectionHandler._POOL_PROVIDER_CLASS_NAME in str(type(provider)))
403399
break
404400
except Exception:
405401
logger.warning("ReadWriteSplittingPlugin.FailedToConnectToReader", host.url)
@@ -419,7 +415,7 @@ def get_verified_initial_connection(self, host_info: HostInfo, props: Properties
419415

420416
current_role = self._plugin_service.get_host_role(current_conn)
421417
if current_role is None or current_role == HostRole.UNKNOWN:
422-
ReadWriteSplittingConnectionManager._log_and_raise_exception("ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole")
418+
ReadWriteSplittingConnectionManager.log_and_raise_exception("ReadWriteSplittingPlugin.ErrorVerifyingInitialHostSpecRole")
423419

424420
current_host = self._plugin_service.initial_connection_host_info
425421
if current_host is not None:
@@ -437,15 +433,9 @@ def old_reader_can_be_used(self, reader_host_info: HostInfo) -> bool:
437433
hostnames = [host_info.host for host_info in self._hosts]
438434
return reader_host_info is not None and reader_host_info.host in hostnames
439435

440-
def should_close_writer_after_switch_to_reader(self) -> bool:
441-
return self._is_writer_conn_from_internal_pool
442-
443-
def should_close_reader_after_switch_to_writer(self) -> bool:
444-
return self._is_reader_conn_from_internal_pool
445-
446436
def need_connect_to_writer(self) -> bool:
447437
if self._hosts is not None and len(self._hosts) == 1:
448-
return self._get_writer(self._hosts) is not None
438+
return self._get_writer() is not None
449439
return False
450440

451441
def refresh_and_store_host_list(self, current_conn: Connection, driver_dialect: DriverDialect):
@@ -457,7 +447,7 @@ def refresh_and_store_host_list(self, current_conn: Connection, driver_dialect:
457447

458448
hosts = self._plugin_service.hosts
459449
if hosts is None or len(hosts) == 0:
460-
ReadWriteSplittingConnectionManager._log_and_raise_exception("ReadWriteSplittingPlugin.EmptyHostList")
450+
ReadWriteSplittingConnectionManager.log_and_raise_exception("ReadWriteSplittingPlugin.EmptyHostList")
461451

462452
self._hosts = hosts
463453

@@ -473,12 +463,12 @@ def is_writer_host(self, current_host: HostInfo) -> bool:
473463
def is_reader_host(self, current_host) -> bool:
474464
return current_host.role == HostRole.READER
475465

476-
def _get_writer(hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
477-
for host in hosts:
466+
def _get_writer(self) -> Optional[HostInfo]:
467+
for host in self._hosts:
478468
if host.role == HostRole.WRITER:
479469
return host
480470

481-
ReadWriteSplittingConnectionManager._log_and_raise_exception("ReadWriteSplittingPlugin.NoWriterFound")
471+
ReadWriteSplittingConnectionManager.log_and_raise_exception("ReadWriteSplittingPlugin.NoWriterFound")
482472

483473
return None
484474

0 commit comments

Comments
 (0)