From 3186d880fc51a7e662c484b9a73794ca82f46bd4 Mon Sep 17 00:00:00 2001 From: MattHag <16444067+MattHag@users.noreply.github.com> Date: Sun, 29 Dec 2024 23:05:43 +0100 Subject: [PATCH] base: Refactor device filtering Related #2273 --- lib/logitech_receiver/base.py | 81 +++++++++++++++++----------- tests/logitech_receiver/test_base.py | 8 +-- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/lib/logitech_receiver/base.py b/lib/logitech_receiver/base.py index 1af147c3c..51355dc04 100644 --- a/lib/logitech_receiver/base.py +++ b/lib/logitech_receiver/base.py @@ -56,7 +56,7 @@ logger = logging.getLogger(__name__) -class HIDAPI(typing.Protocol): +class HIDProtocol(typing.Protocol): def find_paired_node_wpid(self, receiver_path: str, index: int): ... @@ -106,7 +106,7 @@ def close(self, device_handle) -> None: # when pinging, be extra patient (no longer) _PING_TIMEOUT = DEFAULT_TIMEOUT -hidapi = typing.cast(HIDAPI, hidapi) +hidapi = typing.cast(HIDProtocol, hidapi) request_lock = threading.Lock() # serialize all requests handles_lock = {} @@ -156,7 +156,7 @@ def product_information(usb_id: int) -> dict[str, Any]: def receivers(): """Enumerate all the receivers attached to the machine.""" - yield from hidapi.enumerate(_filter_receivers) + yield from hidapi.enumerate(get_known_receiver_info) def filter_products_of_interest( @@ -164,34 +164,53 @@ def filter_products_of_interest( ) -> dict[str, Any] | None: """Check that this product is of interest and if so return the device record for further checking""" - def _other_device_check(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None: - """Check whether product is a Logitech USB-connected or Bluetooth device based on bus, vendor, and product IDs - This allows Solaar to support receiverless HID++ 2.0 devices that it knows nothing about""" - if vendor_id != LOGITECH_VENDOR_ID: - return + recv = get_known_receiver_info(bus_id, vendor_id, product_id, hidpp_short, hidpp_long) + if recv: # known or unknown receiver + return recv - device_info = None - if bus_id == BusID.USB and (0xC07D <= product_id <= 0xC094 or 0xC32B <= product_id <= 0xC344): - device_info = _usb_device(product_id, 2) - elif bus_id == BusID.BLUETOOTH and (0xB012 <= product_id <= 0xB0FF or 0xB317 <= product_id <= 0xB3FF): - device_info = _bluetooth_device(product_id) - return device_info + device = get_known_device_info(bus_id, vendor_id, product_id) + if device: + return device - record = _filter_receivers(bus_id, vendor_id, product_id, hidpp_short, hidpp_long) - if record: # known or unknown receiver - return record + if hidpp_short or hidpp_long: + return get_unknown_hid_device_info(bus_id, vendor_id, product_id) + + if hidpp_short is None and hidpp_long is None: + return get_unknown_logitech_device_info(bus_id, vendor_id, product_id) + return None + + +def get_known_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any]: + for recv in KNOWN_DEVICE_IDS: + if _match_device(recv, bus_id, vendor_id, product_id): + return recv + + +def get_unknown_hid_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any]: + return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True} + + +def get_unknown_logitech_device_info(bus_id: int, vendor_id: int, product_id: int) -> dict[str, Any] | None: + """Get info from unknown device in Logitech product range. + + Check whether product is a Logitech USB-connected or Bluetooth + device based on bus, vendor, and product ID. This allows Solaar to + support receiverless HID++ 2.0 devices that it knows nothing about. + """ + if vendor_id != LOGITECH_VENDOR_ID: + return None + + if bus_id == BusID.USB.value and (0xC07D <= product_id <= 0xC094 or 0xC32B <= product_id <= 0xC344): + device_info = _usb_device(product_id, 2) + return device_info + elif bus_id == BusID.BLUETOOTH.value and (0xB012 <= product_id <= 0xB0FF or 0xB317 <= product_id <= 0xB3FF): + device_info = _bluetooth_device(product_id) + return device_info - for record in KNOWN_DEVICE_IDS: - if _match(record, bus_id, vendor_id, product_id): - return record - if hidpp_short or hidpp_long: # unknown devices that use HID++ - return {"vendor_id": vendor_id, "product_id": product_id, "bus_id": bus_id, "isDevice": True} - elif hidpp_short is None and hidpp_long is None: # unknown devices in correct range of IDs - return _other_device_check(bus_id, vendor_id, product_id) return None -def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int): +def _match_device(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int): return ( (record.get("bus_id") is None or record.get("bus_id") == bus_id) and (record.get("vendor_id") is None or record.get("vendor_id") == vendor_id) @@ -199,7 +218,7 @@ def _match(record: dict[str, Any], bus_id: int, vendor_id: int, product_id: int) ) -def _filter_receivers( +def get_known_receiver_info( bus_id: int, vendor_id: int, product_id: int, _hidpp_short: bool = False, _hidpp_long: bool = False ) -> dict[str, Any]: """Check that this product is a Logitech receiver and return it. @@ -210,7 +229,7 @@ def _filter_receivers( """ try: record = base_usb.get_receiver_info(product_id) - if _match(record, bus_id, vendor_id, product_id): + if _match_device(record, bus_id, vendor_id, product_id): return record except ValueError: pass @@ -507,7 +526,7 @@ def request( ihandle = int(handle) notifications_hook = getattr(handle, "notifications_hook", None) try: - _skip_incoming(handle, ihandle, notifications_hook) + _read_input_buffer(handle, ihandle, notifications_hook) except exceptions.NoReceiver: logger.warning("device or receiver disconnected") return None @@ -604,7 +623,7 @@ def ping(handle, devnumber, long_message: bool = False): with acquire_timeout(handle_lock(handle), handle, 10.0): notifications_hook = getattr(handle, "notifications_hook", None) try: - _skip_incoming(handle, int(handle), notifications_hook) + _read_input_buffer(handle, int(handle), notifications_hook) except exceptions.NoReceiver: logger.warning("device or receiver disconnected") return @@ -651,8 +670,8 @@ def ping(handle, devnumber, long_message: bool = False): logger.warning("(%s) timeout (%0.2f/%0.2f) on device %d ping", handle, delta, _PING_TIMEOUT, devnumber) -def _skip_incoming(handle, ihandle, notifications_hook): - """Read anything already in the input buffer. +def _read_input_buffer(handle, ihandle, notifications_hook): + """Consume anything already in the input buffer. Used by request() and ping() before their write. """ diff --git a/tests/logitech_receiver/test_base.py b/tests/logitech_receiver/test_base.py index 9b01a127e..b23f9e7f8 100644 --- a/tests/logitech_receiver/test_base.py +++ b/tests/logitech_receiver/test_base.py @@ -40,7 +40,7 @@ def test_filter_receivers_known(): bus_id = 2 product_id = 0xC548 - receiver_info = base._filter_receivers(bus_id, LOGITECH_VENDOR_ID, product_id) + receiver_info = base.get_known_receiver_info(bus_id, LOGITECH_VENDOR_ID, product_id) assert receiver_info["name"] == "Bolt Receiver" assert receiver_info["receiver_kind"] == "bolt" @@ -50,7 +50,7 @@ def test_filter_receivers_unknown(): bus_id = 1 product_id = 0xC500 - receiver_info = base._filter_receivers(bus_id, LOGITECH_VENDOR_ID, product_id) + receiver_info = base.get_known_receiver_info(bus_id, LOGITECH_VENDOR_ID, product_id) assert receiver_info["bus_id"] == bus_id assert receiver_info["product_id"] == product_id @@ -90,7 +90,7 @@ def test_filter_products_of_interest(product_id, bus, hidpp_short, hidpp_long, e def test_match(): record = {"vendor_id": LOGITECH_VENDOR_ID} - res = base._match(record, 0, LOGITECH_VENDOR_ID, 0) + res = base._match_device(record, 0, LOGITECH_VENDOR_ID, 0) assert res is True @@ -152,7 +152,7 @@ def test_request_errors( with mock.patch( "logitech_receiver.base._read", return_value=(HIDPP_SHORT_MESSAGE_ID, device_number, prefix + reply_data_sw_id + struct.pack("B", error_code)), - ), mock.patch("logitech_receiver.base._skip_incoming", return_value=None), mock.patch( + ), mock.patch("logitech_receiver.base._read_input_buffer"), mock.patch( "logitech_receiver.base.write", return_value=None ), mock.patch("logitech_receiver.base._get_next_sw_id", return_value=next_sw_id): if raise_exception: