Skip to content

Commit

Permalink
Prevent AssertionError in the recv_events thread.
Browse files Browse the repository at this point in the history
close_socket() was interacting with the protocol, namely calling
protocol.receive_of(), without locking the mutex. This created the
possibility of a race condition.

If two threads called receive_eof() concurrently, the second one
could return before the first one finished running it. This led to
receive_eof() returning (in the second thread) before the connection
state was CLOSED, breaking an invariant.

This race condition could be triggered reliably by shutting down the
network (e.g., turning wifi off), closing the connection, and waiting
for the timeout. Then, close() calls close_socket() — this happens in
the `raise_close_exc` branch of send_context(). This unblocks the read
in recv_events() which calls close_socket() in the `finally:` branch.

Fix #1558.
  • Loading branch information
aaugustin committed Jan 12, 2025
1 parent 031ec31 commit e7a098e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,8 +923,9 @@ def close_socket(self) -> None:

# Calling protocol.receive_eof() is safe because it's idempotent.
# This guarantees that the protocol state becomes CLOSED.
self.protocol.receive_eof()
assert self.protocol.state is CLOSED
with self.protocol_mutex:
self.protocol.receive_eof()
assert self.protocol.state is CLOSED

# Abort recv() with a ConnectionClosed exception.
self.recv_messages.close()
3 changes: 2 additions & 1 deletion tests/sync/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def test_connection_closed_during_handshake(self):
"""Client reads EOF before receiving handshake response from server."""

def close_connection(self, request):
self.close_socket()
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()

with run_server(process_request=close_connection) as server:
with self.assertRaises(InvalidMessage) as raised:
Expand Down

0 comments on commit e7a098e

Please sign in to comment.