From fb4b2530665775ccd40d76709b3279a180a3a0d2 Mon Sep 17 00:00:00 2001 From: Jason Wallace Date: Fri, 8 Sep 2023 13:19:12 +0200 Subject: [PATCH] Refactor HID write timeout logic (#1621) Related https://github.com/tiny-pilot/tinypilot/issues/1026 This PR moves the multiprocessing timeout logic to a separate `execute` module, along with other related classes. This change simplifies the `hid.write` module by stripping out multiprocessing details not directly related to writing to the HID interface. ### Notes 1. We've moved the following classes to the new `execute` module: * `ProcessWithResult` * `ProcessResult` 1. We've moved the HID write timeout logic to a generic `execute.with_timeout` function 2. We've moved `hid.write_test.py` to `execute_test.py` 3. I'm not really sure why we need this function or if it is still required, but I kept it around anyway: https://github.com/tiny-pilot/tinypilot/blob/106e6448bd931da40f9e49c5fc97d5970fffa6d6/app/process.py#L84-L87 ### Peer testing You can test this build by running the following command on a device: ```bash curl \ --silent \ --show-error \ --location \ https://raw.githubusercontent.com/tiny-pilot/tinypilot/master/scripts/install-bundle | \ sudo bash -s -- \ https://output.circle-artifacts.com/output/job/5789b04e-8216-42fe-ad1e-227948464d03/artifacts/0/bundler/dist/tinypilot-community-20230907T1543Z-1.9.1-12+e1ba857.tgz ``` Review
on CodeApprove --- app/execute.py | 90 ++++++++++++++++++++++ app/{hid/write_test.py => execute_test.py} | 36 ++++++--- app/hid/write.py | 75 ++---------------- 3 files changed, 122 insertions(+), 79 deletions(-) create mode 100644 app/execute.py rename app/{hid/write_test.py => execute_test.py} (61%) diff --git a/app/execute.py b/app/execute.py new file mode 100644 index 000000000..0577a5bb7 --- /dev/null +++ b/app/execute.py @@ -0,0 +1,90 @@ +import dataclasses +import multiprocessing +import typing + + +@dataclasses.dataclass +class ProcessResult: + return_value: typing.Any = None + exception: Exception = None + + def was_successful(self) -> bool: + return self.exception is None + + +class ProcessWithResult(multiprocessing.Process): + """A multiprocessing.Process object that keeps track of the child process' + result (i.e., the return value and exception raised). + + Inspired by: + https://stackoverflow.com/a/33599967/3769045 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Create the Connection objects used for communication between the + # parent and child processes. + self.parent_conn, self.child_conn = multiprocessing.Pipe() + + def run(self): + """Method to be run in sub-process.""" + result = ProcessResult() + try: + if self._target: + result.return_value = self._target(*self._args, **self._kwargs) + except Exception as e: + result.exception = e + raise + finally: + self.child_conn.send(result) + + def result(self): + """Get the result from the child process. + + Returns: + If the child process has completed, a ProcessResult object. + Otherwise, a None object. + """ + return self.parent_conn.recv() if self.parent_conn.poll() else None + + +def with_timeout(function, *, args=None, timeout_in_seconds): + """Executes a function in a child process with a specified timeout. + + Usage example: + + with_timeout(save_contact, + args=(first_name, last_name), + timeout_in_seconds=0.5) + + Args: + function: The function to be executed in a child process. + args: Optional `function` arguments as a tuple. + timeout_in_seconds: The execution time limit in seconds. + + Returns: + The return value of the `function`. + + Raises: + TimeoutError: If the execution time of the `function` exceeds the + timeout `seconds`. + """ + process = ProcessWithResult(target=function, args=args or (), daemon=True) + process.start() + process.join(timeout=timeout_in_seconds) + if process.is_alive(): + process.kill() + _wait_for_process_exit(process) + result = process.result() + if result is None: + raise TimeoutError( + f'Process failed to complete in {timeout_in_seconds} seconds') + if not result.was_successful(): + raise result.exception + return result.return_value + + +def _wait_for_process_exit(target_process): + max_attempts = 3 + for _ in range(max_attempts): + target_process.join(timeout=0.1) diff --git a/app/hid/write_test.py b/app/execute_test.py similarity index 61% rename from app/hid/write_test.py rename to app/execute_test.py index f24b745bc..1940e6b23 100644 --- a/app/hid/write_test.py +++ b/app/execute_test.py @@ -3,7 +3,7 @@ import unittest from unittest import mock -import hid.write +import execute # Dummy functions to represent what can happen when a Human Interface Device # writes. @@ -35,20 +35,19 @@ def return_string(): return 'Done!' -class WriteTest(unittest.TestCase): +class ExecuteTest(unittest.TestCase): def test_process_with_result_child_completed(self): - process = hid.write.ProcessWithResult(target=do_nothing, daemon=True) + process = execute.ProcessWithResult(target=do_nothing, daemon=True) process.start() process.join() result = process.result() self.assertTrue(result.was_successful()) self.assertEqual( - hid.write.ProcessResult(return_value=None, exception=None), result) + execute.ProcessResult(return_value=None, exception=None), result) def test_process_with_result_child_not_completed(self): - process = hid.write.ProcessWithResult(target=sleep_1_second, - daemon=True) + process = execute.ProcessWithResult(target=sleep_1_second, daemon=True) process.start() # Get the result before the child process has completed. self.assertIsNone(process.result()) @@ -60,23 +59,36 @@ def test_process_with_result_child_exception(self): # Silence stderr while the child exception is being raised to avoid # polluting the terminal output. with mock.patch('sys.stderr', io.StringIO()): - process = hid.write.ProcessWithResult(target=raise_exception, - daemon=True) + process = execute.ProcessWithResult(target=raise_exception, + daemon=True) process.start() process.join() result = process.result() self.assertFalse(result.was_successful()) self.assertEqual( - hid.write.ProcessResult(return_value=None, exception=mock.ANY), + execute.ProcessResult(return_value=None, exception=mock.ANY), result) self.assertEqual('Child exception', str(result.exception)) def test_process_with_result_return_value(self): - process = hid.write.ProcessWithResult(target=return_string, daemon=True) + process = execute.ProcessWithResult(target=return_string, daemon=True) process.start() process.join() result = process.result() self.assertTrue(result.was_successful()) self.assertEqual( - hid.write.ProcessResult(return_value='Done!', exception=None), - result) + execute.ProcessResult(return_value='Done!', exception=None), result) + + def test_execute_with_timeout_and_timeout_reached(self): + with self.assertRaises(TimeoutError): + execute.with_timeout(sleep_1_second, timeout_in_seconds=0.5) + + def test_execute_with_timeout_return_value(self): + return_value = execute.with_timeout(return_string, + timeout_in_seconds=0.5) + self.assertEqual('Done!', return_value) + + def test_execute_with_timeout_child_exception(self): + with self.assertRaises(Exception) as ctx: + execute.with_timeout(raise_exception, timeout_in_seconds=0.5) + self.assertEqual('Child exception', str(ctx.exception)) diff --git a/app/hid/write.py b/app/hid/write.py index dacae470c..12a5d4928 100644 --- a/app/hid/write.py +++ b/app/hid/write.py @@ -1,7 +1,6 @@ -import dataclasses import logging -import multiprocessing -import typing + +import execute logger = logging.getLogger(__name__) @@ -14,51 +13,6 @@ class WriteError(Error): pass -@dataclasses.dataclass -class ProcessResult: - return_value: typing.Any = None - exception: Exception = None - - def was_successful(self) -> bool: - return self.exception is None - - -class ProcessWithResult(multiprocessing.Process): - """A multiprocessing.Process object that keeps track of the child process' - result (i.e., the return value and exception raised). - - Inspired by: - https://stackoverflow.com/a/33599967/3769045 - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Create the Connection objects used for communication between the - # parent and child processes. - self.parent_conn, self.child_conn = multiprocessing.Pipe() - - def run(self): - """Method to be run in sub-process.""" - result = ProcessResult() - try: - if self._target: - result.return_value = self._target(*self._args, **self._kwargs) - except Exception as e: - result.exception = e - raise - finally: - self.child_conn.send(result) - - def result(self): - """Get the result from the child process. - - Returns: - If the child process has completed, a ProcessResult object. - Otherwise, a None object. - """ - return self.parent_conn.recv() if self.parent_conn.poll() else None - - def _write_to_hid_interface_immediately(hid_path, buffer): try: with open(hid_path, 'ab+') as hid_handle: @@ -78,23 +32,10 @@ def write_to_hid_interface(hid_path, buffer): # Writes can hang, for example, when TinyPilot is attempting to write to the # mouse interface, but the target system has no GUI. To avoid locking up the # main server process, perform the HID interface I/O in a separate process. - write_process = ProcessWithResult( - target=_write_to_hid_interface_immediately, - args=(hid_path, buffer), - daemon=True) - write_process.start() - write_process.join(timeout=0.5) - if write_process.is_alive(): - write_process.kill() - _wait_for_process_exit(write_process) - result = write_process.result() - # If the result is None, it means the write failed to complete in time. - if result is None or not result.was_successful(): + try: + execute.with_timeout(_write_to_hid_interface_immediately, + args=(hid_path, buffer), + timeout_in_seconds=0.5) + except TimeoutError as e: raise WriteError(f'Failed to write to HID interface: {hid_path}. ' - 'Is USB cable connected?') - - -def _wait_for_process_exit(target_process): - max_attempts = 3 - for _ in range(max_attempts): - target_process.join(timeout=0.1) + 'Is USB cable connected?') from e