Skip to content

PYTHON-4324 CSOT avoid connection churn when operations time out #2269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pymongo/asynchronous/network.py
Original file line number Diff line number Diff line change
@@ -138,8 +138,10 @@ async def command(
spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)

# Support CSOT
applied_csot = False
if client:
conn.apply_timeout(client, spec)
res = conn.apply_timeout(client, spec)
applied_csot = bool(res)
_csot.apply_write_concern(spec, write_concern)

if use_op_msg:
@@ -195,7 +197,7 @@ async def command(
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = await async_receive_message(conn, request_id)
reply = await async_receive_message(conn, request_id, enable_pending=applied_csot)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
58 changes: 52 additions & 6 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@
)

from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo import _csot, helpers_shared, network_layer
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.network import command
@@ -188,6 +188,42 @@ def __init__(
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
self.pending_response = False
self.pending_bytes = 0
self.pending_deadline = 0.0

def mark_pending(self, nbytes: int, grace_period: float = 3.0) -> None:
    """Mark this connection as having a pending response.

    Records how much of a reply is still unread and until when the rest of
    that reply may be read.

    :param nbytes: number of bytes of the response that were not read.
    :param grace_period: seconds from now (monotonic clock) allowed for
        reading the pending response. Defaults to 3 seconds, the value that
        was previously hard-coded.
    """
    self.pending_response = True
    self.pending_bytes = nbytes
    self.pending_deadline = time.monotonic() + grace_period

async def complete_pending(self) -> None:
    """Complete a pending response.

    Does nothing when no response is pending. Otherwise the leftover reply is
    read and, once the read succeeds, the pending bookkeeping is reset.

    A failed read propagates whatever ``receive_message`` /
    ``_raise_connection_failure`` raise.
    """
    if not self.pending_response:
        return

    # The read is bounded by whichever ends first: the caller's own timeout
    # or the window recorded by mark_pending().
    if _csot.get_timeout():
        deadline = min(_csot.get_deadline(), self.pending_deadline)
    else:
        timeout = self.conn.gettimeout
        if timeout is not None:
            deadline = min(time.monotonic() + timeout, self.pending_deadline)
        else:
            deadline = self.pending_deadline

    # Clear the flag *before* reading. _raise_connection_failure() skips
    # close_conn() while pending_response is set, so leaving it set would keep
    # a connection open after a non-timeout failure (reset, cancellation, ...).
    # A fresh timeout during the read sets the flag again via mark_pending().
    self.pending_response = False

    if not _IS_SYNC:
        # In async the reader task reads the whole message at once.
        # TODO: respect deadline
        await self.receive_message(None, True)
    else:
        try:
            network_layer.receive_data(self, self.pending_bytes, deadline, True)  # type:ignore[arg-type]
        except BaseException as error:
            await self._raise_connection_failure(error)
    self.pending_bytes = 0
    self.pending_deadline = 0.0

def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
@@ -454,13 +490,17 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
except BaseException as error:
await self._raise_connection_failure(error)

async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
async def receive_message(
self, request_id: Optional[int], enable_pending: bool = False
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise ConnectionFailure.
If any exception is raised, the socket is closed.
"""
try:
return await async_receive_message(self, request_id, self.max_message_size)
return await async_receive_message(
self, request_id, self.max_message_size, enable_pending
)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
await self._raise_connection_failure(error)
@@ -495,7 +535,9 @@ async def write_command(
:param msg: bytes, the command message.
"""
await self.send_message(msg, 0)
reply = await self.receive_message(request_id)
reply = await self.receive_message(
request_id, enable_pending=(_csot.get_timeout() is not None)
)
result = reply.command_response(codec_options)

# Raises NotPrimaryError or OperationFailure.
@@ -635,7 +677,10 @@ async def _raise_connection_failure(self, error: BaseException) -> NoReturn:
reason = None
else:
reason = ConnectionClosedReason.ERROR
await self.close_conn(reason)

# Pending connections should be placed back in the pool.
if not self.pending_response:
await self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
details = _get_timeout_details(self.opts)
@@ -1076,7 +1121,7 @@ async def checkout(
This method should always be used in a with-statement::
with pool.get_conn() as connection:
with pool.checkout() as connection:
connection.send_message(msg)
data = connection.receive_message(op_code, request_id)
@@ -1388,6 +1433,7 @@ async def _perished(self, conn: AsyncConnection) -> bool:
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
"""
await conn.complete_pending()
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
if (
2 changes: 1 addition & 1 deletion pymongo/asynchronous/server.py
Original file line number Diff line number Diff line change
@@ -205,7 +205,7 @@ async def run_operation(
reply = await conn.receive_message(None)
else:
await conn.send_message(data, max_doc_size)
reply = await conn.receive_message(request_id)
reply = await conn.receive_message(request_id, operation.pending_enabled())

# Unpack and check for command errors.
if use_cmd:
18 changes: 16 additions & 2 deletions pymongo/message.py
Original file line number Diff line number Diff line change
@@ -1569,6 +1569,7 @@ class _Query:
"allow_disk_use",
"_as_command",
"exhaust",
"_pending_enabled",
)

# For compatibility with the _GetMore class.
@@ -1612,6 +1613,10 @@ def __init__(
self.name = "find"
self._as_command: Optional[tuple[dict[str, Any], str]] = None
self.exhaust = exhaust
self._pending_enabled = False

def pending_enabled(self) -> bool:
"""Return the ``_pending_enabled`` flag.

The flag is initialised to ``False`` and is set to ``True`` by
``as_command`` when ``conn.apply_timeout`` returns a value other than
``None``.
"""
return self._pending_enabled

def reset(self) -> None:
self._as_command = None
@@ -1673,7 +1678,9 @@ def as_command(
conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type]
# Support CSOT
if apply_timeout:
conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type]
res = conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type]
if res is not None:
self._pending_enabled = True
self._as_command = cmd, self.db
return self._as_command

@@ -1747,6 +1754,7 @@ class _GetMore:
"_as_command",
"exhaust",
"comment",
"_pending_enabled",
)

name = "getMore"
@@ -1779,6 +1787,10 @@ def __init__(
self._as_command: Optional[tuple[dict[str, Any], str]] = None
self.exhaust = exhaust
self.comment = comment
self._pending_enabled = False

def pending_enabled(self) -> bool:
"""Return the ``_pending_enabled`` flag.

The flag is initialised to ``False`` and is set to ``True`` by
``as_command`` when ``conn.apply_timeout`` returns a value other than
``None``.
"""
return self._pending_enabled

def reset(self) -> None:
self._as_command = None
@@ -1822,7 +1834,9 @@ def as_command(
conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type]
# Support CSOT
if apply_timeout:
conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type]
res = conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type]
if res is not None:
self._pending_enabled = True
self._as_command = cmd, self.db
return self._as_command

33 changes: 25 additions & 8 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
@@ -325,7 +325,9 @@ def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
raise socket.timeout("timed out")


def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
def receive_data(
conn: Connection, length: int, deadline: Optional[float], enable_pending: bool = False
) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
@@ -336,7 +338,7 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
# When the timeout has expired we perform one final non-blocking recv.
# This helps avoid spurious timeouts when the response is actually already
# buffered on the client.
orig_timeout = conn.conn.gettimeout()
orig_timeout = conn.conn.gettimeout
try:
while bytes_read < length:
try:
@@ -357,12 +359,16 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
# We reached the true deadline.
if enable_pending:
conn.mark_pending(length - bytes_read)
raise socket.timeout("timed out") from None
except socket.timeout:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
if _PYPY:
# We reached the true deadline.
if enable_pending:
conn.mark_pending(length - bytes_read)
raise
continue
except OSError as exc:
@@ -438,6 +444,7 @@ class NetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: Union[socket.socket, _sslConn]):
"""Wrap ``conn`` by handing it to the base-class initializer."""
super().__init__(conn)

@property
def gettimeout(self) -> float | None:
"""Return the timeout reported by the wrapped connection's ``gettimeout()``.

This is a property, so callers read it without call parentheses.
"""
return self.conn.gettimeout()

@@ -692,6 +699,7 @@ async def async_receive_message(
conn: AsyncConnection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
enable_pending: bool = False,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
timeout: Optional[Union[float, int]]
@@ -721,6 +729,8 @@ async def async_receive_message(
if pending:
await asyncio.wait(pending)
if len(done) == 0:
if enable_pending:
conn.mark_pending(1)
raise socket.timeout("timed out")
if read_task in done:
data, op_code = read_task.result()
@@ -740,19 +750,24 @@ async def async_receive_message(


def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
conn: Connection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
enable_pending: bool = False,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
timeout = conn.conn.gettimeout
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
length, _, response_to, op_code = _UNPACK_HEADER(
receive_data(conn, 16, deadline, enable_pending)
)
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
@@ -767,10 +782,12 @@ def receive_message(
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
receive_data(conn, 9, deadline, enable_pending)
)
data = decompress(receive_data(conn, length - 25, deadline, enable_pending), compressor_id)
else:
data = receive_data(conn, length - 16, deadline)
data = receive_data(conn, length - 16, deadline, enable_pending)

try:
unpack_reply = _UNPACK_REPLY[op_code]
6 changes: 4 additions & 2 deletions pymongo/synchronous/network.py
Original file line number Diff line number Diff line change
@@ -138,8 +138,10 @@ def command(
spec = orig = client._encrypter.encrypt(dbname, spec, codec_options)

# Support CSOT
applied_csot = False
if client:
conn.apply_timeout(client, spec)
res = conn.apply_timeout(client, spec)
applied_csot = bool(res)
_csot.apply_write_concern(spec, write_concern)

if use_op_msg:
@@ -195,7 +197,7 @@ def command(
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = receive_message(conn, request_id)
reply = receive_message(conn, request_id, enable_pending=applied_csot)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
54 changes: 48 additions & 6 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@
)

from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo import _csot, helpers_shared, network_layer
from pymongo.common import (
MAX_BSON_SIZE,
MAX_MESSAGE_SIZE,
@@ -188,6 +188,42 @@ def __init__(
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
self.pending_response = False
self.pending_bytes = 0
self.pending_deadline = 0.0

def mark_pending(self, nbytes: int, grace_period: float = 3.0) -> None:
    """Mark this connection as having a pending response.

    Records how much of a reply is still unread and until when the rest of
    that reply may be read.

    :param nbytes: number of bytes of the response that were not read.
    :param grace_period: seconds from now (monotonic clock) allowed for
        reading the pending response. Defaults to 3 seconds, the value that
        was previously hard-coded.
    """
    self.pending_response = True
    self.pending_bytes = nbytes
    self.pending_deadline = time.monotonic() + grace_period

def complete_pending(self) -> None:
    """Complete a pending response.

    Does nothing when no response is pending. Otherwise the leftover reply is
    read and, once the read succeeds, the pending bookkeeping is reset.

    A failed read propagates whatever ``receive_message`` /
    ``_raise_connection_failure`` raise.
    """
    if not self.pending_response:
        return

    # The read is bounded by whichever ends first: the caller's own timeout
    # or the window recorded by mark_pending().
    if _csot.get_timeout():
        deadline = min(_csot.get_deadline(), self.pending_deadline)
    else:
        timeout = self.conn.gettimeout
        if timeout is not None:
            deadline = min(time.monotonic() + timeout, self.pending_deadline)
        else:
            deadline = self.pending_deadline

    # Clear the flag *before* reading. _raise_connection_failure() skips
    # close_conn() while pending_response is set, so leaving it set would keep
    # a connection open after a non-timeout failure (reset, interrupt, ...).
    # A fresh timeout during the read sets the flag again via mark_pending().
    self.pending_response = False

    if not _IS_SYNC:
        # In async the reader task reads the whole message at once.
        # TODO: respect deadline
        self.receive_message(None, True)
    else:
        try:
            network_layer.receive_data(self, self.pending_bytes, deadline, True)  # type:ignore[arg-type]
        except BaseException as error:
            self._raise_connection_failure(error)
    self.pending_bytes = 0
    self.pending_deadline = 0.0

def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
@@ -454,13 +490,15 @@ def send_message(self, message: bytes, max_doc_size: int) -> None:
except BaseException as error:
self._raise_connection_failure(error)

def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
def receive_message(
self, request_id: Optional[int], enable_pending: bool = False
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise ConnectionFailure.
If any exception is raised, the socket is closed.
"""
try:
return receive_message(self, request_id, self.max_message_size)
return receive_message(self, request_id, self.max_message_size, enable_pending)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)
@@ -495,7 +533,7 @@ def write_command(
:param msg: bytes, the command message.
"""
self.send_message(msg, 0)
reply = self.receive_message(request_id)
reply = self.receive_message(request_id, enable_pending=(_csot.get_timeout() is not None))
result = reply.command_response(codec_options)

# Raises NotPrimaryError or OperationFailure.
@@ -633,7 +671,10 @@ def _raise_connection_failure(self, error: BaseException) -> NoReturn:
reason = None
else:
reason = ConnectionClosedReason.ERROR
self.close_conn(reason)

# Pending connections should be placed back in the pool.
if not self.pending_response:
self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
details = _get_timeout_details(self.opts)
@@ -1072,7 +1113,7 @@ def checkout(
This method should always be used in a with-statement::
with pool.get_conn() as connection:
with pool.checkout() as connection:
connection.send_message(msg)
data = connection.receive_message(op_code, request_id)
@@ -1384,6 +1425,7 @@ def _perished(self, conn: Connection) -> bool:
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
"""
conn.complete_pending()
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
if (
2 changes: 1 addition & 1 deletion pymongo/synchronous/server.py
Original file line number Diff line number Diff line change
@@ -205,7 +205,7 @@ def run_operation(
reply = conn.receive_message(None)
else:
conn.send_message(data, max_doc_size)
reply = conn.receive_message(request_id)
reply = conn.receive_message(request_id, operation.pending_enabled())

# Unpack and check for command errors.
if use_cmd:
2 changes: 1 addition & 1 deletion test/asynchronous/test_client.py
Original file line number Diff line number Diff line change
@@ -2222,7 +2222,7 @@ async def test_exhaust_getmore_server_error(self):
await cursor.next()

# Cause a server error on getmore.
async def receive_message(request_id):
async def receive_message(request_id, enable_pending=False):
# Discard the actual server response.
await AsyncConnection.receive_message(conn, request_id)

2 changes: 1 addition & 1 deletion test/test_client.py
Original file line number Diff line number Diff line change
@@ -2179,7 +2179,7 @@ def test_exhaust_getmore_server_error(self):
cursor.next()

# Cause a server error on getmore.
def receive_message(request_id):
def receive_message(request_id, enable_pending=False):
# Discard the actual server response.
Connection.receive_message(conn, request_id)