Skip to content

Commit 8eab490

Browse files
feat: Add mutual exclusion for synchronized stream access in logging handlers and CLPLoglevelTimeout (fixes #55). (#59)
Co-authored-by: Lin Zhihao <[email protected]>
1 parent fedc9b0 commit 8eab490

File tree

1 file changed

+59
-28
lines changed

1 file changed

+59
-28
lines changed

src/clp_logging/handlers.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pathlib import Path
1111
from queue import Empty, Queue
1212
from signal import SIGINT, signal, SIGTERM
13-
from threading import Thread, Timer
13+
from threading import RLock, Thread, Timer
1414
from types import FrameType
1515
from typing import Any, Callable, ClassVar, Dict, IO, Optional, Tuple, Union
1616

@@ -175,6 +175,11 @@ class CLPLogLevelTimeout:
175175
the last timeout. Therefore, if we've seen a log level with a low
176176
delta, that delta will continue to be used to calculate the soft
177177
timer until a timeout occurs.
178+
179+
Thread safety:
180+
- This class locks any operations on the stream set by `set_ostream`.
181+
- Any logging handler with a timeout object should lock the stream
182+
operations using the lock return by `get_lock()`.
178183
"""
179184

180185
# delta times in milliseconds
@@ -226,29 +231,34 @@ def __init__(
226231
self.ostream: Optional[Union[ZstdCompressionWriter, IO[bytes]]] = None
227232
self.hard_timeout_thread: Optional[Timer] = None
228233
self.soft_timeout_thread: Optional[Timer] = None
234+
self.lock: RLock = RLock()
229235

230236
def set_ostream(self, ostream: Union[ZstdCompressionWriter, IO[bytes]]) -> None:
231237
self.ostream = ostream
232238

239+
def get_lock(self) -> RLock:
240+
return self.lock
241+
233242
def timeout(self) -> None:
234243
"""
235244
Wraps the call to the user supplied `timeout_fn` ensuring that any
236245
existing timeout threads are cancelled, `next_hard_timeout_ts` and
237246
`min_soft_timeout_delta` are reset, and the zstandard frame is flushed.
238247
"""
239-
if self.hard_timeout_thread:
240-
self.hard_timeout_thread.cancel()
241-
if self.soft_timeout_thread:
242-
self.soft_timeout_thread.cancel()
243-
self.next_hard_timeout_ts = ULONG_MAX
244-
self.min_soft_timeout_delta = ULONG_MAX
245-
246-
if self.ostream:
247-
if isinstance(self.ostream, ZstdCompressionWriter):
248-
self.ostream.flush(FLUSH_FRAME)
249-
else:
250-
self.ostream.flush()
251-
self.timeout_fn()
248+
with self.get_lock():
249+
if self.hard_timeout_thread:
250+
self.hard_timeout_thread.cancel()
251+
if self.soft_timeout_thread:
252+
self.soft_timeout_thread.cancel()
253+
self.next_hard_timeout_ts = ULONG_MAX
254+
self.min_soft_timeout_delta = ULONG_MAX
255+
256+
if self.ostream:
257+
if isinstance(self.ostream, ZstdCompressionWriter):
258+
self.ostream.flush(FLUSH_FRAME)
259+
else:
260+
self.ostream.flush()
261+
self.timeout_fn()
252262

253263
def update(self, loglevel: int, log_timestamp_ms: int, log_fn: Callable[[str], None]) -> None:
254264
"""
@@ -302,6 +312,21 @@ def update(self, loglevel: int, log_timestamp_ms: int, log_fn: Callable[[str], N
302312
self.soft_timeout_thread.start()
303313

304314

315+
def _get_mutex_context_from_loglevel_timeout(loglevel_timeout: Optional[CLPLogLevelTimeout]) -> Any:
316+
"""
317+
Gets a mutual exclusive context manager for IR stream access.
318+
319+
NOTE: The return type should be `AbstractContextManager[Optional[bool]]`,
320+
but it is annotated as `Any` to satisfy the linter in Python 3.7 and 3.8,
321+
as `AbstractContextManager` was introduced in Python 3.9 (#18239).
322+
323+
:param loglevel_timeout: An optional `CLPLogLevelTimeout` object.
324+
:return: A context manager that either provides the lock from
325+
`loglevel_timeout` or a `nullcontext` if `loglevel_timeout` is `None`.
326+
"""
327+
return loglevel_timeout.get_lock() if loglevel_timeout else nullcontext()
328+
329+
305330
class CLPSockListener:
306331
"""
307332
Server that listens to a named Unix domain socket for `CLPSockHandler`
@@ -451,18 +476,21 @@ def log_fn(msg: str) -> None:
451476
timestamp_ms - last_timestamp_ms
452477
)
453478
last_timestamp_ms = timestamp_ms
454-
if loglevel_timeout:
455-
loglevel_timeout.update(loglevel, last_timestamp_ms, log_fn)
456479
buf += timestamp_buf
457-
ostream.write(buf)
480+
with _get_mutex_context_from_loglevel_timeout(loglevel_timeout):
481+
if loglevel_timeout:
482+
loglevel_timeout.update(loglevel, last_timestamp_ms, log_fn)
483+
ostream.write(buf)
458484
if loglevel_timeout:
459485
loglevel_timeout.timeout()
460-
ostream.write(EOF_CHAR)
461486

462-
if enable_compression:
463-
# Since we are not using context manager, the ostream should be
464-
# explicitly closed.
465-
ostream.close()
487+
with _get_mutex_context_from_loglevel_timeout(loglevel_timeout):
488+
ostream.write(EOF_CHAR)
489+
490+
if enable_compression:
491+
# Since we are not using context manager, the ostream should be
492+
# explicitly closed.
493+
ostream.close()
466494
# tell _server to exit
467495
CLPSockListener._signaled = True
468496
return 0
@@ -740,17 +768,19 @@ def _direct_write(self, msg: str) -> None:
740768
raise RuntimeError("Stream already closed")
741769
clp_msg: bytearray
742770
clp_msg, self.last_timestamp_ms = _encode_log_event(msg, self.last_timestamp_ms)
743-
self.ostream.write(clp_msg)
771+
with _get_mutex_context_from_loglevel_timeout(self.loglevel_timeout):
772+
self.ostream.write(clp_msg)
744773

745774
# override
746775
def _write(self, loglevel: int, msg: str) -> None:
747776
if self.closed:
748777
raise RuntimeError("Stream already closed")
749778
clp_msg: bytearray
750779
clp_msg, self.last_timestamp_ms = _encode_log_event(msg, self.last_timestamp_ms)
751-
if self.loglevel_timeout:
752-
self.loglevel_timeout.update(loglevel, self.last_timestamp_ms, self._direct_write)
753-
self.ostream.write(clp_msg)
780+
with _get_mutex_context_from_loglevel_timeout(self.loglevel_timeout):
781+
if self.loglevel_timeout:
782+
self.loglevel_timeout.update(loglevel, self.last_timestamp_ms, self._direct_write)
783+
self.ostream.write(clp_msg)
754784

755785
# Added to logging.StreamHandler in python 3.7
756786
# override
@@ -775,8 +805,9 @@ def setStream(self, stream: IO[bytes]) -> Optional[IO[bytes]]:
775805
def close(self) -> None:
776806
if self.loglevel_timeout:
777807
self.loglevel_timeout.timeout()
778-
self.ostream.write(EOF_CHAR)
779-
self.ostream.close()
808+
with _get_mutex_context_from_loglevel_timeout(self.loglevel_timeout):
809+
self.ostream.write(EOF_CHAR)
810+
self.ostream.close()
780811
self.closed = True
781812
super().close()
782813

0 commit comments

Comments
 (0)