1+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License").
4+ # You may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ from __future__ import annotations
16+
17+ from enum import Enum , auto
18+ from time import perf_counter_ns , sleep
19+ from typing import TYPE_CHECKING , Callable , Optional
20+
21+ from aws_advanced_python_wrapper .host_availability import HostAvailability
22+ from aws_advanced_python_wrapper .read_write_splitting_plugin import ReadWriteSplittingConnectionManager , ConnectionHandler
23+ from aws_advanced_python_wrapper .utils .rds_url_type import RdsUrlType
24+ from aws_advanced_python_wrapper .utils .rdsutils import RdsUtils
25+
26+ if TYPE_CHECKING :
27+ from aws_advanced_python_wrapper .driver_dialect import DriverDialect
28+ from aws_advanced_python_wrapper .host_list_provider import HostListProviderService
29+ from aws_advanced_python_wrapper .pep249 import Connection
30+ from aws_advanced_python_wrapper .plugin_service import PluginService
31+ from aws_advanced_python_wrapper .utils .properties import Properties
32+
33+ from aws_advanced_python_wrapper .errors import AwsWrapperError
34+ from aws_advanced_python_wrapper .hostinfo import HostInfo , HostRole
35+ from aws_advanced_python_wrapper .plugin import PluginFactory
36+ from aws_advanced_python_wrapper .utils .log import Logger
37+ from aws_advanced_python_wrapper .utils .messages import Messages
38+ from aws_advanced_python_wrapper .utils .properties import WrapperProperties
39+
40+ logger = Logger (__name__ )
41+
42+ class VerifyOpenedConnectionType (Enum ):
43+ READER = auto ()
44+ WRITER = auto ()
45+
46+ @staticmethod
47+ def parse_connection_type (phase_str : Optional [str ]) -> VerifyOpenedConnectionType :
48+ if not phase_str :
49+ return None
50+
51+ phase_upper = phase_str .lower ()
52+ if phase_upper == "reader" :
53+ return VerifyOpenedConnectionType .READER
54+ elif phase_upper == "writer" :
55+ return VerifyOpenedConnectionType .WRITER
56+ else :
57+ raise ValueError (Messages .get_formatted ("SimpleReadWriteSplittingPlugin.IncorrectConfiguration" , WrapperProperties .SRW_VERIFY_OPENED_CONNECTION_TYPE .name ))
58+
59+ class EndpointBasedConnectionHandler (ConnectionHandler ):
60+ """Endpoint based implementation of connection handling logic."""
61+
62+ def __init__ (self , plugin_service : PluginService , props : Properties ):
63+ srw_read_endpoint = WrapperProperties .SRW_READ_ENDPOINT .get (props )
64+ if srw_read_endpoint is None :
65+ raise AwsWrapperError (Messages .get_formatted ("SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter" , WrapperProperties .SRW_READ_ENDPOINT .name ))
66+ self .read_endpoint = srw_read_endpoint
67+
68+ srw_write_endpoint = WrapperProperties .SRW_WRITE_ENDPOINT .get (props )
69+ if srw_write_endpoint is None :
70+ raise AwsWrapperError (Messages .get_formatted ("SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter" , WrapperProperties .SRW_WRITE_ENDPOINT .name ))
71+ self .write_endpoint = srw_write_endpoint
72+
73+ self .verify_new_connections = WrapperProperties .SRW_VERIFY_NEW_CONNECTIONS .get_bool (props )
74+ if self .verify_new_connections is True :
75+ srw_connect_retry_timeout_ms = WrapperProperties .SRW_CONNECT_RETRY_TIMEOUT_MS .get_int (props )
76+ if srw_connect_retry_timeout_ms <= 0 :
77+ raise ValueError (Messages .get_formatted ("SimpleReadWriteSplittingPlugin.IncorrectConfiguration" , WrapperProperties .SRW_CONNECT_RETRY_TIMEOUT_MS .name ))
78+ self .connect_retry_timeout_ms = srw_connect_retry_timeout_ms
79+
80+ srw_connect_retry_interval_ms = WrapperProperties .SRW_CONNECT_RETRY_INTERVAL_MS .get_int (props )
81+ if srw_connect_retry_interval_ms <= 0 :
82+ raise ValueError (Messages .get_formatted ("SimpleReadWriteSplittingPlugin.IncorrectConfiguration" , WrapperProperties .SRW_CONNECT_RETRY_INTERVAL_MS .name ))
83+ self .connect_retry_interval_ms = srw_connect_retry_interval_ms
84+
85+ self .verify_opened_connection_type = VerifyOpenedConnectionType .parse_connection_type (WrapperProperties .SRW_VERIFY_OPENED_CONNECTION_TYPE .get (props ))
86+
87+ self ._plugin_service = plugin_service
88+ self ._properties = props
89+ self ._rds_utils = RdsUtils ()
90+ self ._host_list_provider_service : Optional [HostListProviderService ] = None
91+ self ._write_endpoint_host_info = None
92+ self ._read_endpoint_host_info = None
93+
94+
95+ @property
96+ def host_list_provider_service (self ) -> HostListProviderService :
97+ return self ._host_list_provider_service
98+
99+ @host_list_provider_service .setter
100+ def host_list_provider_service (self , value : HostListProviderService ):
101+ self ._host_list_provider_service = value
102+
103+ def get_new_writer_connection (self ) -> Optional [tuple [Connection , HostInfo ]]:
104+ if self ._write_endpoint_host_info is None :
105+ self ._write_endpoint_host_info = self .create_host_info (self .write_endpoint , HostRole .WRITER )
106+
107+ conn : Optional [Connection ] = None
108+ if self .verify_new_connections :
109+ conn = self ._get_verified_connection (self ._properties , self ._write_endpoint_host_info , HostRole .WRITER )
110+ else :
111+ conn = self ._plugin_service .connect (self ._write_endpoint_host_info , self ._properties , self )
112+
113+ return conn , self ._write_endpoint_host_info
114+
115+ def get_new_reader_connection (self ) -> Optional [tuple [Connection , HostInfo ]]:
116+ if self ._read_endpoint_host_info is None :
117+ self ._read_endpoint_host_info = self .create_host_info (self .read_endpoint , HostRole .READER )
118+
119+ conn : Optional [Connection ] = None
120+ if self .verify_new_connections :
121+ conn = self ._get_verified_connection (self ._properties , self ._read_endpoint_host_info , HostRole .READER )
122+ else :
123+ conn = self ._plugin_service .connect (self ._read_endpoint_host_info , self ._properties , self )
124+
125+ return conn , self ._read_endpoint_host_info
126+
127+ def get_verified_initial_connection (self , host_info : HostInfo , props : Properties , is_initial_connection : bool , connect_func : Callable ) -> Optional [Connection ]:
128+ if not is_initial_connection or not self .verify_new_connections :
129+ # No verification required, continue with normal workflow.
130+ return connect_func ()
131+
132+ url_type : RdsUrlType = self ._rds_utils .identify_rds_type (host_info .host )
133+
134+ if url_type == RdsUrlType .RDS_WRITER_CLUSTER or (self .verify_opened_connection_type is not None and self .verify_opened_connection_type == VerifyOpenedConnectionType .WRITER ):
135+ writer_candidate_conn : Optional [Connection ] = self ._get_verified_connection (props , host_info , HostRole .WRITER , connect_func )
136+ if writer_candidate_conn is None :
137+ # Can't get verified writer connection, continue with normal workflow.
138+ return connect_func ()
139+ self .set_initial_connection_host_info (writer_candidate_conn , host_info )
140+ return writer_candidate_conn
141+
142+ if url_type == RdsUrlType .RDS_READER_CLUSTER or (self .verify_opened_connection_type is not None and self .verify_opened_connection_type == VerifyOpenedConnectionType .READER ):
143+ reader_candidate_conn : Optional [Connection ] = self ._get_verified_connection (props , host_info , HostRole .READER , connect_func )
144+ if reader_candidate_conn is None :
145+ # Can't get verified reader connection, continue with normal workflow.
146+ return connect_func ()
147+ self .set_initial_connection_host_info (reader_candidate_conn , host_info )
148+ return reader_candidate_conn
149+
150+ # Continue with normal workflow
151+ return connect_func ()
152+
153+ def set_initial_connection_host_info (self , conn : Connection , host_info : HostInfo ):
154+ if host_info is None :
155+ try :
156+ host_info = self ._plugin_service .identify_connection (conn )
157+ except Exception :
158+ return
159+
160+ if host_info is not None :
161+ self ._host_list_provider_service .initial_connection_host_info = host_info
162+
163+ def _get_verified_connection (self , props : Properties , host_info : HostInfo , role : HostRole , connect_func : Callable = None ) -> Connection :
164+ end_time_nano = perf_counter_ns () + (self .connect_retry_timeout_ms * 1000000 )
165+
166+ candidate_conn : Optional [Connection ]
167+
168+ while perf_counter_ns () < end_time_nano :
169+ candidate_conn = None
170+
171+ try :
172+ if host_info is None :
173+ if connect_func is None :
174+ # Unable to connect to verify role.
175+ break
176+ # No host_info provided, still verify role.
177+ candidate_conn = connect_func ()
178+ else :
179+ candidate_conn = self ._plugin_service .connect (host_info , props , self )
180+
181+ if candidate_conn is None or self ._plugin_service .get_host_role (candidate_conn ) != role :
182+ ReadWriteSplittingConnectionManager ._close_connection (candidate_conn )
183+ self ._delay ()
184+ continue
185+
186+ # Connection valid and verified.
187+ return candidate_conn
188+ except Exception as e :
189+ ReadWriteSplittingConnectionManager ._close_connection (candidate_conn )
190+ self ._delay ()
191+
192+ return None
193+
194+ def old_reader_can_be_used (self , reader_host_info : HostInfo ) -> bool :
195+ # Assume that the old reader can always be used, no topology-based information to check.
196+ return True
197+
198+ def should_close_writer_after_switch_to_reader (self ) -> bool :
199+ # Endpoint based connections do not use pooled connection providers.
200+ return False
201+
202+ def should_close_reader_after_switch_to_writer (self ) -> bool :
203+ # Endpoint based connections do not use pooled connection providers.
204+ return False
205+
206+ def need_connect_to_writer (self ) -> bool :
207+ # SetReadOnly(true) will always connect to the read_endpoint, and not the writer.
208+ return False
209+
210+ def refresh_and_store_host_list (self , current_conn : Connection , driver_dialect : DriverDialect ):
211+ # Endpoint based connections do not require a host list.
212+ return
213+
214+ def should_update_writer_with_current_conn (self , current_conn : Connection , current_host : HostInfo , writer_conn : Connection ) -> bool :
215+ return self .is_writer_host (current_host ) and current_conn != writer_conn and (not self .verify_new_connections or self ._plugin_service .get_host_role (current_conn ) == HostRole .WRITER )
216+
217+ def should_update_reader_with_current_conn (self , current_conn : Connection , current_host : HostInfo , reader_conn : Connection ) -> bool :
218+ return self .is_reader_host (current_host ) and current_conn != reader_conn and (not self .verify_new_connections or self ._plugin_service .get_host_role (current_conn ) == HostRole .READER )
219+
220+ def is_writer_host (self , current_host : HostInfo ) -> bool :
221+ return current_host .host .casefold () == self .write_endpoint .casefold ()
222+
223+ def is_reader_host (self , current_host : HostInfo ) -> bool :
224+ return current_host .host .casefold () == self .read_endpoint .casefold ()
225+
226+ def create_host_info (self , endpoint , role : HostRole ) -> HostInfo :
227+ port = self ._plugin_service .database_dialect .default_port
228+ if self .host_list_provider_service is not None and self .host_list_provider_service .initial_connection_host_info is not None and self .host_list_provider_service .initial_connection_host_info .port != HostInfo .NO_PORT :
229+ port = self .host_list_provider_service .initial_connection_host_info .port
230+ return HostInfo (
231+ host = endpoint ,
232+ port = port ,
233+ role = role ,
234+ availability = HostAvailability .AVAILABLE )
235+
236+ def _delay (self ):
237+ sleep (self .connect_retry_interval_ms / 1000 )
238+
239+ class SimpleReadWriteSplittingPlugin (ReadWriteSplittingConnectionManager ):
240+ def __init__ (self , plugin_service , props : Properties ):
241+ # The simple read/write splitting plugin handles connections based on configuration parameter endpoints.
242+ connection_handler = EndpointBasedConnectionHandler (
243+ plugin_service ,
244+ props ,
245+ )
246+
247+ super ().__init__ (plugin_service , props , connection_handler )
248+
249+ class SimpleReadWriteSplittingPluginFactory (PluginFactory ):
250+ def get_instance (self , plugin_service , props : Properties ):
251+ return SimpleReadWriteSplittingPlugin (plugin_service , props )
0 commit comments