diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 1605efe92d..3d862d403e 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -138,8 +138,10 @@ async def command( spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options) # Support CSOT + applied_csot = False if client: - conn.apply_timeout(client, spec) + res = conn.apply_timeout(client, spec) + applied_csot = bool(res) _csot.apply_write_concern(spec, write_concern) if use_op_msg: @@ -195,7 +197,7 @@ async def command( reply = None response_doc: _DocumentOut = {"ok": 1} else: - reply = await async_receive_message(conn, request_id) + reply = await async_receive_message(conn, request_id, enable_pending=applied_csot) conn.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response( codec_options=codec_options, user_fields=user_fields diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a67cc5f3c8..d4d042dd86 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -34,7 +34,7 @@ ) from bson import DEFAULT_CODEC_OPTIONS -from pymongo import _csot, helpers_shared +from pymongo import _csot, helpers_shared, network_layer from pymongo.asynchronous.client_session import _validate_session_write_concern from pymongo.asynchronous.helpers import _handle_reauth from pymongo.asynchronous.network import command @@ -188,6 +188,42 @@ def __init__( self.creation_time = time.monotonic() # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None + self.pending_response = False + self.pending_bytes = 0 + self.pending_deadline = 0.0 + + def mark_pending(self, nbytes: int) -> None: + """Mark this connection as having a pending response.""" + self.pending_response = True + self.pending_bytes = nbytes + self.pending_deadline = time.monotonic() + 3 # 3 seconds timeout for pending response + + async def complete_pending(self) -> None: + """Complete a pending response.""" + if not self.pending_response: + return + + if _csot.get_timeout(): + deadline = min(_csot.get_deadline(), self.pending_deadline) + else: + timeout = self.conn.gettimeout + if timeout is not None: + deadline = min(time.monotonic() + timeout, self.pending_deadline) + else: + deadline = self.pending_deadline + + if not _IS_SYNC: + # In async the reader task reads the whole message at once. + # TODO: respect deadline + await self.receive_message(None, True) + else: + try: + network_layer.receive_data(self, self.pending_bytes, deadline, True) # type:ignore[arg-type] + except BaseException as error: + await self._raise_connection_failure(error) + self.pending_response = False + self.pending_bytes = 0 + self.pending_deadline = 0.0 def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -454,13 +490,17 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None: except BaseException as error: await self._raise_connection_failure(error) - async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: + async def receive_message( + self, request_id: Optional[int], enable_pending: bool = False + ) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise ConnectionFailure. If any exception is raised, the socket is closed. """ try: - return await async_receive_message(self, request_id, self.max_message_size) + return await async_receive_message( + self, request_id, self.max_message_size, enable_pending + ) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: await self._raise_connection_failure(error) @@ -495,7 +535,9 @@ async def write_command( :param msg: bytes, the command message. """ await self.send_message(msg, 0) - reply = await self.receive_message(request_id) + reply = await self.receive_message( + request_id, enable_pending=(_csot.get_timeout() is not None) + ) result = reply.command_response(codec_options) # Raises NotPrimaryError or OperationFailure. @@ -635,7 +677,10 @@ async def _raise_connection_failure(self, error: BaseException) -> NoReturn: reason = None else: reason = ConnectionClosedReason.ERROR - await self.close_conn(reason) + + # Pending connections should be placed back in the pool. + if not self.pending_response: + await self.close_conn(reason) # SSLError from PyOpenSSL inherits directly from Exception. if isinstance(error, (IOError, OSError, SSLError)): details = _get_timeout_details(self.opts) @@ -1076,7 +1121,7 @@ async def checkout( This method should always be used in a with-statement:: - with pool.get_conn() as connection: + with pool.checkout() as connection: connection.send_message(msg) data = connection.receive_message(op_code, request_id) @@ -1388,6 +1433,7 @@ async def _perished(self, conn: AsyncConnection) -> bool: pool, to keep performance reasonable - we can't avoid AutoReconnects completely anyway. """ + await conn.complete_pending() idle_time_seconds = conn.idle_time_seconds() # If socket is idle, open a new one. if ( diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 0e0d53b96f..2df594774a 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -205,7 +205,7 @@ async def run_operation( reply = await conn.receive_message(None) else: await conn.send_message(data, max_doc_size) - reply = await conn.receive_message(request_id) + reply = await conn.receive_message(request_id, operation.pending_enabled()) # Unpack and check for command errors. if use_cmd: diff --git a/pymongo/message.py b/pymongo/message.py index d51c77a174..d9d1450838 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -1569,6 +1569,7 @@ class _Query: "allow_disk_use", "_as_command", "exhaust", + "_pending_enabled", ) # For compatibility with the _GetMore class. @@ -1612,6 +1613,10 @@ def __init__( self.name = "find" self._as_command: Optional[tuple[dict[str, Any], str]] = None self.exhaust = exhaust + self._pending_enabled = False + + def pending_enabled(self) -> bool: + return self._pending_enabled def reset(self) -> None: self._as_command = None @@ -1673,7 +1678,9 @@ def as_command( conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] # Support CSOT if apply_timeout: - conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type] + res = conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type] + if res is not None: + self._pending_enabled = True self._as_command = cmd, self.db return self._as_command @@ -1747,6 +1754,7 @@ class _GetMore: "_as_command", "exhaust", "comment", + "_pending_enabled", ) name = "getMore" @@ -1779,6 +1787,10 @@ def __init__( self._as_command: Optional[tuple[dict[str, Any], str]] = None self.exhaust = exhaust self.comment = comment + self._pending_enabled = False + + def pending_enabled(self) -> bool: + return self._pending_enabled def reset(self) -> None: self._as_command = None @@ -1822,7 +1834,9 @@ def as_command( conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] # Support CSOT if apply_timeout: - conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type] + res = conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type] + if res is not None: + self._pending_enabled = True self._as_command = cmd, self.db return self._as_command diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index e287655c61..07e87551d7 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -325,7 +325,9 @@ def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: raise socket.timeout("timed out") -def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: +def receive_data( + conn: Connection, length: int, deadline: Optional[float], enable_pending: bool = False +) -> memoryview: buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 @@ -336,7 +338,7 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me # When the timeout has expired we perform one final non-blocking recv. # This helps avoid spurious timeouts when the response is actually already # buffered on the client. - orig_timeout = conn.conn.gettimeout() + orig_timeout = conn.conn.gettimeout try: while bytes_read < length: try: @@ -357,12 +359,16 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me if conn.cancel_context.cancelled: raise _OperationCancelled("operation cancelled") from None # We reached the true deadline. + if enable_pending: + conn.mark_pending(length - bytes_read) raise socket.timeout("timed out") from None except socket.timeout: if conn.cancel_context.cancelled: raise _OperationCancelled("operation cancelled") from None if _PYPY: # We reached the true deadline. + if enable_pending: + conn.mark_pending(length - bytes_read) raise continue except OSError as exc: @@ -438,6 +444,7 @@ class NetworkingInterface(NetworkingInterfaceBase): def __init__(self, conn: Union[socket.socket, _sslConn]): super().__init__(conn) + @property def gettimeout(self) -> float | None: return self.conn.gettimeout() @@ -692,6 +699,7 @@ async def async_receive_message( conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE, + enable_pending: bool = False, ) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise socket.error.""" timeout: Optional[Union[float, int]] @@ -721,6 +729,8 @@ async def async_receive_message( if pending: await asyncio.wait(pending) if len(done) == 0: + if enable_pending: + conn.mark_pending(1) raise socket.timeout("timed out") if read_task in done: data, op_code = read_task.result() @@ -740,19 +750,24 @@ async def async_receive_message( def receive_message( - conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE + conn: Connection, + request_id: Optional[int], + max_message_size: int = MAX_MESSAGE_SIZE, + enable_pending: bool = False, ) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise socket.error.""" if _csot.get_timeout(): deadline = _csot.get_deadline() else: - timeout = conn.conn.gettimeout() + timeout = conn.conn.gettimeout if timeout: deadline = time.monotonic() + timeout else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline)) + length, _, response_to, op_code = _UNPACK_HEADER( + receive_data(conn, 16, deadline, enable_pending) + ) # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: @@ -767,10 +782,12 @@ def receive_message( f"message size ({max_message_size!r})" ) if op_code == 2012: - op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) - data = decompress(receive_data(conn, length - 25, deadline), compressor_id) + op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( + receive_data(conn, 9, deadline, enable_pending) + ) + data = decompress(receive_data(conn, length - 25, deadline, enable_pending), compressor_id) else: - data = receive_data(conn, length - 16, deadline) + data = receive_data(conn, length - 16, deadline, enable_pending) try: unpack_reply = _UNPACK_REPLY[op_code] diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 9559a5a542..b27541e5d8 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -138,8 +138,10 @@ def command( spec = orig = client._encrypter.encrypt(dbname, spec, codec_options) # Support CSOT + applied_csot = False if client: - conn.apply_timeout(client, spec) + res = conn.apply_timeout(client, spec) + applied_csot = bool(res) _csot.apply_write_concern(spec, write_concern) if use_op_msg: @@ -195,7 +197,7 @@ def command( reply = None response_doc: _DocumentOut = {"ok": 1} else: - reply = receive_message(conn, request_id) + reply = receive_message(conn, request_id, enable_pending=applied_csot) conn.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response( codec_options=codec_options, user_fields=user_fields diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 224834af31..19addc6336 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -34,7 +34,7 @@ ) from bson import DEFAULT_CODEC_OPTIONS -from pymongo import _csot, helpers_shared +from pymongo import _csot, helpers_shared, network_layer from pymongo.common import ( MAX_BSON_SIZE, MAX_MESSAGE_SIZE, @@ -188,6 +188,42 @@ def __init__( self.creation_time = time.monotonic() # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None + self.pending_response = False + self.pending_bytes = 0 + self.pending_deadline = 0.0 + + def mark_pending(self, nbytes: int) -> None: + """Mark this connection as having a pending response.""" + self.pending_response = True + self.pending_bytes = nbytes + self.pending_deadline = time.monotonic() + 3 # 3 seconds timeout for pending response + + def complete_pending(self) -> None: + """Complete a pending response.""" + if not self.pending_response: + return + + if _csot.get_timeout(): + deadline = min(_csot.get_deadline(), self.pending_deadline) + else: + timeout = self.conn.gettimeout + if timeout is not None: + deadline = min(time.monotonic() + timeout, self.pending_deadline) + else: + deadline = self.pending_deadline + + if not _IS_SYNC: + # In async the reader task reads the whole message at once. + # TODO: respect deadline + self.receive_message(None, True) + else: + try: + network_layer.receive_data(self, self.pending_bytes, deadline, True) # type:ignore[arg-type] + except BaseException as error: + self._raise_connection_failure(error) + self.pending_response = False + self.pending_bytes = 0 + self.pending_deadline = 0.0 def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -454,13 +490,15 @@ def send_message(self, message: bytes, max_doc_size: int) -> None: except BaseException as error: self._raise_connection_failure(error) - def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: + def receive_message( + self, request_id: Optional[int], enable_pending: bool = False + ) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise ConnectionFailure. If any exception is raised, the socket is closed. """ try: - return receive_message(self, request_id, self.max_message_size) + return receive_message(self, request_id, self.max_message_size, enable_pending) # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -495,7 +533,7 @@ def write_command( :param msg: bytes, the command message. """ self.send_message(msg, 0) - reply = self.receive_message(request_id) + reply = self.receive_message(request_id, enable_pending=(_csot.get_timeout() is not None)) result = reply.command_response(codec_options) # Raises NotPrimaryError or OperationFailure. @@ -633,7 +671,10 @@ def _raise_connection_failure(self, error: BaseException) -> NoReturn: reason = None else: reason = ConnectionClosedReason.ERROR - self.close_conn(reason) + + # Pending connections should be placed back in the pool. + if not self.pending_response: + self.close_conn(reason) # SSLError from PyOpenSSL inherits directly from Exception. if isinstance(error, (IOError, OSError, SSLError)): details = _get_timeout_details(self.opts) @@ -1072,7 +1113,7 @@ def checkout( This method should always be used in a with-statement:: - with pool.get_conn() as connection: + with pool.checkout() as connection: connection.send_message(msg) data = connection.receive_message(op_code, request_id) @@ -1384,6 +1425,7 @@ def _perished(self, conn: Connection) -> bool: pool, to keep performance reasonable - we can't avoid AutoReconnects completely anyway. """ + conn.complete_pending() idle_time_seconds = conn.idle_time_seconds() # If socket is idle, open a new one. if ( diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index c3643ba815..af3bff9c9d 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -205,7 +205,7 @@ def run_operation( reply = conn.receive_message(None) else: conn.send_message(data, max_doc_size) - reply = conn.receive_message(request_id) + reply = conn.receive_message(request_id, operation.pending_enabled()) # Unpack and check for command errors. if use_cmd: diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index c9cfca81fc..738a434b9c 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2222,7 +2222,7 @@ async def test_exhaust_getmore_server_error(self): await cursor.next() # Cause a server error on getmore. - async def receive_message(request_id): + async def receive_message(request_id, enable_pending=False): # Discard the actual server response. await AsyncConnection.receive_message(conn, request_id) diff --git a/test/test_client.py b/test/test_client.py index 038ba2241b..94e6ec3f3f 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -2179,7 +2179,7 @@ def test_exhaust_getmore_server_error(self): cursor.next() # Cause a server error on getmore. - def receive_message(request_id): + def receive_message(request_id, enable_pending=False): # Discard the actual server response. Connection.receive_message(conn, request_id)