From 794a4a3ef078e3b12130d74dbd3c44c0cc473e09 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Thu, 26 Sep 2024 17:26:50 -0700 Subject: [PATCH 1/3] add basic support for SCO packets over USB --- apps/controller_info.py | 12 + apps/controller_loopback.py | 177 +++++-- apps/usb_probe.py | 7 +- bumble/device.py | 27 +- bumble/hci.py | 121 ++++- bumble/hfp.py | 26 +- bumble/host.py | 28 +- bumble/rfcomm.py | 2 +- bumble/transport/usb.py | 921 ++++++++++++++++++++++------------ examples/run_hfp_handsfree.py | 354 +++++++++---- pyproject.toml | 2 +- 11 files changed, 1163 insertions(+), 514 deletions(-) diff --git a/apps/controller_info.py b/apps/controller_info.py index 0e49b6fb..12630525 100644 --- a/apps/controller_info.py +++ b/apps/controller_info.py @@ -45,8 +45,10 @@ HCI_Read_Local_Supported_Codecs_Command, HCI_Read_Local_Supported_Codecs_V2_Command, HCI_Read_Local_Version_Information_Command, + HCI_Read_Voice_Setting_Command, LeFeature, SpecificationVersion, + VoiceSetting, map_null_terminated_utf8_string, ) from bumble.host import Host @@ -214,6 +216,16 @@ async def get_codecs_info(host: Host) -> None: if not response2.vendor_specific_codec_ids: print(' No Vendor-specific codecs') + if host.supports_command(HCI_Read_Voice_Setting_Command.op_code): + response3 = await host.send_sync_command(HCI_Read_Voice_Setting_Command()) + voice_setting = VoiceSetting.from_int(response3.voice_setting) + print(color('Voice Setting:', 'yellow')) + print(f' Air Coding Format: {voice_setting.air_coding_format.name}') + print(f' Linear PCM Bit Position: {voice_setting.linear_pcm_bit_position}') + print(f' Input Sample Size: {voice_setting.input_sample_size.name}') + print(f' Input Data Format: {voice_setting.input_data_format.name}') + print(f' Input Coding Format: {voice_setting.input_coding_format.name}') + # ----------------------------------------------------------------------------- async def async_main( diff --git a/apps/controller_loopback.py b/apps/controller_loopback.py index 89a5c165..682dcb2e 100644 --- a/apps/controller_loopback.py +++ b/apps/controller_loopback.py @@ -16,6 +16,8 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import statistics +import struct import time import click @@ -25,7 +27,9 @@ from bumble.hci import ( HCI_READ_LOOPBACK_MODE_COMMAND, HCI_WRITE_LOOPBACK_MODE_COMMAND, + Address, HCI_Read_Loopback_Mode_Command, + HCI_SynchronousDataPacket, HCI_Write_Loopback_Mode_Command, LoopbackMode, ) @@ -36,34 +40,59 @@ class Loopback: """Send and receive ACL data packets in local loopback mode""" - def __init__(self, packet_size: int, packet_count: int, transport: str): + def __init__( + self, + packet_size: int, + packet_count: int, + connection_type: str, + mode: str, + interval: int, + transport: str, + ): self.transport = transport self.packet_size = packet_size self.packet_count = packet_count self.connection_handle: int | None = None + self.connection_type = connection_type self.connection_event = asyncio.Event() + self.mode = mode + self.interval = interval self.done = asyncio.Event() - self.expected_cid = 0 + self.expected_counter = 0 self.bytes_received = 0 self.start_timestamp = 0.0 self.last_timestamp = 0.0 + self.send_timestamps: list[float] = [] + self.rtts: list[float] = [] def on_connection(self, connection_handle: int, *args): """Retrieve connection handle from new connection event""" if not self.connection_event.is_set(): - # save first connection handle for ACL - # subsequent connections are SCO + # The first connection handle is of type ACL, + # subsequent connections are of type SCO + if self.connection_type == "sco" and self.connection_handle is None: + self.connection_handle = connection_handle + return + self.connection_handle = connection_handle self.connection_event.set() + def on_sco_connection( + self, address: Address, connection_handle: int, link_type: int + ): + self.on_connection(connection_handle) + def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes): """Calculate packet receive speed""" now = time.time() - print(f'<<< Received packet {cid}: {len(pdu)} bytes') + (counter,) = struct.unpack_from("H", pdu, 0) + rtt = now - self.send_timestamps[counter] + self.rtts.append(rtt) + print(f'<<< Received packet {counter}: {len(pdu)} bytes, RTT={rtt:.4f}') assert connection_handle == self.connection_handle - assert cid == self.expected_cid - self.expected_cid += 1 - if cid == 0: + assert counter == self.expected_counter + self.expected_counter += 1 + if counter == 0: self.start_timestamp = now else: elapsed_since_start = now - self.start_timestamp @@ -71,20 +100,52 @@ def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes): self.bytes_received += len(pdu) instant_rx_speed = len(pdu) / elapsed_since_last average_rx_speed = self.bytes_received / elapsed_since_start - print( - color( - f'@@@ RX speed: instant={instant_rx_speed:.4f},' - f' average={average_rx_speed:.4f}', - 'cyan', + if self.mode == 'throughput': + print( + color( + f'@@@ RX speed: instant={instant_rx_speed:.4f},' + f' average={average_rx_speed:.4f},', + 'cyan', + ) ) - ) self.last_timestamp = now - if self.expected_cid == self.packet_count: + if self.expected_counter == self.packet_count: print(color('@@@ Received last packet', 'green')) self.done.set() + def on_sco_packet(self, connection_handle: int, packet) -> None: + print("---", connection_handle, packet) + + async def send_acl_packet(self, host: Host, packet: bytes) -> None: + assert self.connection_handle + host.send_l2cap_pdu(self.connection_handle, 0, packet) + + async def send_sco_packet(self, host: Host, packet: bytes) -> None: + assert self.connection_handle + host.send_hci_packet( + HCI_SynchronousDataPacket( + connection_handle=self.connection_handle, + packet_status=HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA, + data_total_length=len(packet), + data=packet, + ) + ) + + async def send_loop(self, host: Host, sender) -> None: + for counter in range(0, self.packet_count): + print( + color( + f'>>> Sending {self.connection_type.upper()} ' + f'packet {counter}: {self.packet_size} bytes', + 'yellow', + ) + ) + self.send_timestamps.append(time.time()) + await sender(host, struct.pack("H", counter) + bytes(self.packet_size - 2)) + await asyncio.sleep(self.interval / 1000 if self.mode == "rtt" else 0) + async def run(self) -> None: """Run a loopback throughput test""" print(color('>>> Connecting to HCI...', 'green')) @@ -126,8 +187,11 @@ async def run(self) -> None: return # set event callbacks - host.on('connection', self.on_connection) + host.on('classic_connection', self.on_connection) + host.on('le_connection', self.on_connection) + host.on('sco_connection', self.on_sco_connection) host.on('l2cap_pdu', self.on_l2cap_pdu) + host.on('sco_packet', self.on_sco_packet) loopback_mode = LoopbackMode.LOCAL @@ -148,32 +212,37 @@ async def run(self) -> None: print(color('=== Start sending', 'magenta')) start_time = time.time() - bytes_sent = 0 - for cid in range(0, self.packet_count): - # using the cid as an incremental index - host.send_l2cap_pdu( - self.connection_handle, cid, bytes(self.packet_size) - ) - print( - color( - f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow' - ) - ) - bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes - await asyncio.sleep(0) # yield to allow packet receive + if self.connection_type == "acl": + sender = self.send_acl_packet + elif self.connection_type == "sco": + sender = self.send_sco_packet + else: + raise ValueError(f'Unknown connection type: {self.connection_type}') + await self.send_loop(host, sender) await self.done.wait() print(color('=== Done!', 'magenta')) + bytes_sent = self.packet_size * self.packet_count elapsed = time.time() - start_time average_tx_speed = bytes_sent / elapsed - print( - color( - f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes' - f' in {elapsed:.2f} seconds)', - 'green', + if self.mode == 'throughput': + print( + color( + f'@@@ TX speed: average={average_tx_speed:.4f} ' + f'({bytes_sent} bytes in {elapsed:.2f} seconds)', + 'green', + ) + ) + if self.mode == 'rtt': + print( + color( + f'RTTs: min={min(self.rtts):.4f}, ' + f'max={max(self.rtts):.4f}, ' + f'avg={statistics.mean(self.rtts):.4f}', + 'blue', + ) ) - ) # ----------------------------------------------------------------------------- @@ -194,11 +263,43 @@ async def run(self) -> None: default=10, help='Packet count', ) +@click.option( + '--connection-type', + '-t', + metavar='TYPE', + type=click.Choice(['acl', 'sco']), + default='acl', + help='Connection type', +) +@click.option( + '--mode', + '-m', + metavar='MODE', + type=click.Choice(['throughput', 'rtt']), + default='throughput', + help='Test mode', +) +@click.option( + '--interval', + type=int, + default=100, + help='Inter-packet interval (ms) [RTT mode only]', +) @click.argument('transport') -def main(packet_size, packet_count, transport): +def main(packet_size, packet_count, connection_type, mode, interval, transport): bumble.logging.setup_basic_logging() - loopback = Loopback(packet_size, packet_count, transport) - asyncio.run(loopback.run()) + + if connection_type == "sco" and packet_size > 255: + print("ERROR: the maximum packet size for SCO is 255") + return + + async def run(): + loopback = Loopback( + packet_size, packet_count, connection_type, mode, interval, transport + ) + await loopback.run() + + asyncio.run(run()) # ----------------------------------------------------------------------------- diff --git a/apps/usb_probe.py b/apps/usb_probe.py index 51ab91cc..81a266f2 100644 --- a/apps/usb_probe.py +++ b/apps/usb_probe.py @@ -111,9 +111,14 @@ def show_device_details(device): if (endpoint.getAddress() & USB_ENDPOINT_IN == 0) else 'IN' ) + endpoint_details = ( + f', Max Packet Size = {endpoint.getMaxPacketSize()}' + if endpoint_type == 'ISOCHRONOUS' + else '' + ) print( f' Endpoint 0x{endpoint.getAddress():02X}: ' - f'{endpoint_type} {endpoint_direction}' + f'{endpoint_type} {endpoint_direction}{endpoint_details}' ) diff --git a/bumble/device.py b/bumble/device.py index dbaeb52e..0da7b656 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -1423,6 +1423,9 @@ class ScoLink(utils.CompositeEventEmitter): acl_connection: Connection handle: int link_type: int + rx_packet_length: int + tx_packet_length: int + air_mode: hci.CodecID sink: Callable[[hci.HCI_SynchronousDataPacket], Any] | None = None EVENT_DISCONNECTION: ClassVar[str] = "disconnection" @@ -5968,7 +5971,7 @@ def on_connection_failure( def on_connection_request( self, bd_addr: hci.Address, class_of_device: int, link_type: int ): - logger.debug(f'*** Connection request: {bd_addr}') + logger.debug(f'*** Connection request: {bd_addr} link_type={link_type}') # Handle SCO request. if link_type in ( @@ -5978,6 +5981,7 @@ def on_connection_request( if connection := self.find_connection_by_bd_addr( bd_addr, transport=PhysicalTransport.BR_EDR ): + connection.emit(self.EVENT_SCO_REQUEST, link_type) self.emit(self.EVENT_SCO_REQUEST, connection, link_type) else: logger.error(f'SCO request from a non-connected device {bd_addr}') @@ -6337,8 +6341,7 @@ def on_remote_name( logger.warning('peer name is not valid UTF-8') if connection: connection.emit(connection.EVENT_REMOTE_NAME_FAILURE, error) - else: - self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error) + self.emit(self.EVENT_REMOTE_NAME_FAILURE, address, error) # [Classic only] @host_event_handler @@ -6355,7 +6358,13 @@ def on_remote_name_failure( @with_connection_from_address @utils.experimental('Only for testing.') def on_sco_connection( - self, acl_connection: Connection, sco_handle: int, link_type: int + self, + acl_connection: Connection, + sco_handle: int, + link_type: int, + rx_packet_length: int, + tx_packet_length: int, + air_mode: int, ) -> None: logger.debug( f'*** SCO connected: {acl_connection.peer_address}, ' @@ -6367,7 +6376,11 @@ def on_sco_connection( acl_connection=acl_connection, handle=sco_handle, link_type=link_type, + rx_packet_length=rx_packet_length, + tx_packet_length=tx_packet_length, + air_mode=hci.CodecID(air_mode), ) + acl_connection.emit(self.EVENT_SCO_CONNECTION, sco_link) self.emit(self.EVENT_SCO_CONNECTION, sco_link) # [Classic only] @@ -6378,7 +6391,8 @@ def on_sco_connection_failure( self, acl_connection: Connection, status: int ) -> None: logger.debug(f'*** SCO connection failure: {acl_connection.peer_address}***') - self.emit(self.EVENT_SCO_CONNECTION_FAILURE) + acl_connection.emit(self.EVENT_SCO_CONNECTION_FAILURE, status) + self.emit(self.EVENT_SCO_CONNECTION_FAILURE, status) # [Classic only] @host_event_handler @@ -6841,15 +6855,18 @@ def on_role_change_failure( @with_connection_from_address def on_classic_pairing(self, connection: Connection) -> None: connection.emit(connection.EVENT_CLASSIC_PAIRING) + self.emit(connection.EVENT_CLASSIC_PAIRING, connection) # [Classic only] @host_event_handler @with_connection_from_address def on_classic_pairing_failure(self, connection: Connection, status: int) -> None: connection.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, status) + self.emit(connection.EVENT_CLASSIC_PAIRING_FAILURE, connection, status) def on_pairing_start(self, connection: Connection) -> None: connection.emit(connection.EVENT_PAIRING_START) + self.emit(connection.EVENT_PAIRING_START, connection) def on_pairing( self, diff --git a/bumble/hci.py b/bumble/hci.py index d6696e0b..839cc9b7 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -1769,6 +1769,61 @@ def __bytes__(self) -> bytes: ) +@dataclasses.dataclass(frozen=True) +class VoiceSetting: + class AirCodingFormat(enum.IntEnum): + CVSD = 0 + U_LAW = 1 + A_LAW = 2 + TRANSPARENT_DATA = 3 + + class InputSampleSize(enum.IntEnum): + SIZE_8_BITS = 0 + SIZE_16_BITS = 1 + + class InputDataFormat(enum.IntEnum): + ONES_COMPLEMENT = 0 + TWOS_COMPLEMENT = 1 + SIGN_AND_MAGNITUDE = 2 + UNSIGNED = 3 + + class InputCodingFormat(enum.IntEnum): + LINEAR = 0 + U_LAW = 1 + A_LAW = 2 + RESERVED = 3 + + air_coding_format: AirCodingFormat = AirCodingFormat.CVSD + linear_pcm_bit_position: int = 0 + input_sample_size: InputSampleSize = InputSampleSize.SIZE_8_BITS + input_data_format: InputDataFormat = InputDataFormat.ONES_COMPLEMENT + input_coding_format: InputCodingFormat = InputCodingFormat.LINEAR + + @classmethod + def from_int(cls, value: int) -> VoiceSetting: + air_coding_format = cls.AirCodingFormat(value & 0b11) + linear_pcm_bit_position = (value >> 2) & 0b111 + input_sample_size = cls.InputSampleSize((value >> 5) & 0b1) + input_data_format = cls.InputDataFormat((value >> 6) & 0b11) + input_coding_format = cls.InputCodingFormat((value >> 8) & 0b11) + return cls( + air_coding_format=air_coding_format, + linear_pcm_bit_position=linear_pcm_bit_position, + input_sample_size=input_sample_size, + input_data_format=input_data_format, + input_coding_format=input_coding_format, + ) + + def __int__(self) -> int: + return ( + self.air_coding_format + | (self.linear_pcm_bit_position << 2) + | (self.input_sample_size << 5) + | (self.input_data_format << 6) + | (self.input_coding_format << 8) + ) + + # ----------------------------------------------------------------------------- class HCI_Constant: @staticmethod @@ -2907,6 +2962,23 @@ class HCI_Read_Clock_Offset_Command(HCI_AsyncCommand): connection_handle: int = field(metadata=metadata(2)) +# ----------------------------------------------------------------------------- +@HCI_Command.command +@dataclasses.dataclass +class HCI_Accept_Synchronous_Connection_Request_Command(HCI_AsyncCommand): + ''' + See Bluetooth spec @ 7.1.27 Accept Synchronous Connection Request Command + ''' + + bd_addr: Address = field(metadata=metadata(Address.parse_address)) + transmit_bandwidth: int = field(metadata=metadata(4)) + receive_bandwidth: int = field(metadata=metadata(4)) + max_latency: int = field(metadata=metadata(2)) + voice_setting: int = field(metadata=metadata(2)) + retransmission_effort: int = field(metadata=metadata(1)) + packet_type: int = field(metadata=metadata(2)) + + # ----------------------------------------------------------------------------- @HCI_Command.command @dataclasses.dataclass @@ -3965,6 +4037,23 @@ class HCI_Read_Local_OOB_Extended_Data_Command( ''' +# ----------------------------------------------------------------------------- +@HCI_SyncCommand.sync_command(HCI_StatusReturnParameters) +@dataclasses.dataclass +class HCI_Configure_Data_Path_Command(HCI_SyncCommand[HCI_StatusReturnParameters]): + ''' + See Bluetooth spec @ 7.3.101 Configure Data Path Command + ''' + + class DataPathDirection(SpecableEnum): + INPUT = 0x00 + OUTPUT = 0x01 + + data_path_direction: DataPathDirection = field(metadata=metadata(1)) + data_path_id: int = field(metadata=metadata(1)) + vendor_specific_config: bytes = field(metadata=metadata('*')) + + # ----------------------------------------------------------------------------- @dataclasses.dataclass class HCI_Read_Local_Version_Information_ReturnParameters(HCI_StatusReturnParameters): @@ -7355,7 +7444,7 @@ class LinkType(SpecableEnum): status: int = field(metadata=metadata(STATUS_SPEC)) connection_handle: int = field(metadata=metadata(2)) bd_addr: Address = field(metadata=metadata(Address.parse_address)) - link_type: int = field(metadata=LinkType.type_metadata(1)) + link_type: LinkType = field(metadata=LinkType.type_metadata(1)) encryption_enabled: int = field(metadata=metadata(1)) @@ -7751,12 +7840,6 @@ class LinkType(SpecableEnum): SCO = 0x00 ESCO = 0x02 - class AirMode(SpecableEnum): - U_LAW_LOG = 0x00 - A_LAW_LOG_AIR_MORE = 0x01 - CVSD = 0x02 - TRANSPARENT_DATA = 0x03 - status: int = field(metadata=metadata(STATUS_SPEC)) connection_handle: int = field(metadata=metadata(2)) bd_addr: Address = field(metadata=metadata(Address.parse_address)) @@ -7765,7 +7848,7 @@ class AirMode(SpecableEnum): retransmission_window: int = field(metadata=metadata(1)) rx_packet_length: int = field(metadata=metadata(2)) tx_packet_length: int = field(metadata=metadata(2)) - air_mode: int = field(metadata=AirMode.type_metadata(1)) + air_mode: int = field(metadata=CodecID.type_metadata(1)) # ----------------------------------------------------------------------------- @@ -7997,7 +8080,9 @@ def from_bytes(cls, packet: bytes) -> HCI_AclDataPacket: bc_flag = (h >> 14) & 3 data = packet[5:] if len(data) != data_total_length: - raise InvalidPacketError('invalid packet length') + raise InvalidPacketError( + f'invalid packet length {len(data)} != {data_total_length}' + ) return cls( connection_handle=connection_handle, pb_flag=pb_flag, @@ -8030,10 +8115,16 @@ class HCI_SynchronousDataPacket(HCI_Packet): See Bluetooth spec @ 5.4.3 HCI SCO Data Packets ''' + class Status(enum.IntEnum): + CORRECTLY_RECEIVED_DATA = 0b00 + POSSIBLY_INVALID_DATA = 0b01 + NO_DATA = 0b10 + DATA_PARTIALLY_LOST = 0b11 + hci_packet_type = HCI_SYNCHRONOUS_DATA_PACKET connection_handle: int - packet_status: int + packet_status: Status data_total_length: int data: bytes @@ -8042,7 +8133,7 @@ def from_bytes(cls, packet: bytes) -> HCI_SynchronousDataPacket: # Read the header h, data_total_length = struct.unpack_from('> 12) & 0b11 + packet_status = cls.Status((h >> 12) & 0b11) data = packet[4:] if len(data) != data_total_length: raise InvalidPacketError( @@ -8066,7 +8157,7 @@ def __str__(self) -> str: return ( f'{color("SCO", "blue")}: ' f'handle=0x{self.connection_handle:04x}, ' - f'ps={self.packet_status}, ' + f'ps={self.packet_status.name}, ' f'data_total_length={self.data_total_length}, ' f'data={self.data.hex()}' ) @@ -8094,8 +8185,8 @@ class HCI_IsoDataPacket(HCI_Packet): def __post_init__(self) -> None: self.ts_flag = self.time_stamp is not None - @staticmethod - def from_bytes(packet: bytes) -> HCI_IsoDataPacket: + @classmethod + def from_bytes(cls, packet: bytes) -> HCI_IsoDataPacket: time_stamp: int | None = None packet_sequence_number: int | None = None iso_sdu_length: int | None = None @@ -8124,7 +8215,7 @@ def from_bytes(packet: bytes) -> HCI_IsoDataPacket: pos += 4 iso_sdu_fragment = packet[pos:] - return HCI_IsoDataPacket( + return cls( connection_handle=connection_handle, pb_flag=pb_flag, ts_flag=ts_flag, diff --git a/bumble/hfp.py b/bumble/hfp.py index 3c623a93..d1443fcf 100644 --- a/bumble/hfp.py +++ b/bumble/hfp.py @@ -166,7 +166,7 @@ class AgFeature(enum.IntFlag): VOICE_RECOGNITION_TEXT = 0x2000 -class AudioCodec(enum.IntEnum): +class AudioCodec(utils.OpenIntEnum): """ Audio Codec IDs (normative). @@ -178,7 +178,7 @@ class AudioCodec(enum.IntEnum): LC3_SWB = 0x03 # Support for LC3-SWB audio codec -class HfIndicator(enum.IntEnum): +class HfIndicator(utils.OpenIntEnum): """ HF Indicators (normative). @@ -207,7 +207,7 @@ class CallHoldOperation(enum.Enum): ) -class ResponseHoldStatus(enum.IntEnum): +class ResponseHoldStatus(utils.OpenIntEnum): """ Response Hold status (normative). @@ -235,7 +235,7 @@ class AgIndicator(enum.Enum): BATTERY_CHARGE = 'battchg' -class CallSetupAgIndicator(enum.IntEnum): +class CallSetupAgIndicator(utils.OpenIntEnum): """ Values for the Call Setup AG indicator (normative). @@ -248,7 +248,7 @@ class CallSetupAgIndicator(enum.IntEnum): REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call -class CallHeldAgIndicator(enum.IntEnum): +class CallHeldAgIndicator(utils.OpenIntEnum): """ Values for the Call Held AG indicator (normative). @@ -262,7 +262,7 @@ class CallHeldAgIndicator(enum.IntEnum): CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call -class CallInfoDirection(enum.IntEnum): +class CallInfoDirection(utils.OpenIntEnum): """ Call Info direction (normative). @@ -273,7 +273,7 @@ class CallInfoDirection(enum.IntEnum): MOBILE_TERMINATED_CALL = 1 -class CallInfoStatus(enum.IntEnum): +class CallInfoStatus(utils.OpenIntEnum): """ Call Info status (normative). @@ -288,7 +288,7 @@ class CallInfoStatus(enum.IntEnum): WAITING = 5 -class CallInfoMode(enum.IntEnum): +class CallInfoMode(utils.OpenIntEnum): """ Call Info mode (normative). @@ -301,7 +301,7 @@ class CallInfoMode(enum.IntEnum): UNKNOWN = 9 -class CallInfoMultiParty(enum.IntEnum): +class CallInfoMultiParty(utils.OpenIntEnum): """ Call Info Multi-Party state (normative). @@ -388,7 +388,7 @@ def to_clip_string(self) -> str: ) -class VoiceRecognitionState(enum.IntEnum): +class VoiceRecognitionState(utils.OpenIntEnum): """ vrec values provided in AT+BVRA command. @@ -401,7 +401,7 @@ class VoiceRecognitionState(enum.IntEnum): ENHANCED_READY = 2 -class CmeError(enum.IntEnum): +class CmeError(utils.OpenIntEnum): """ CME ERROR codes (partial listed). @@ -1624,7 +1624,7 @@ def _on_vgm(self, level: bytes) -> None: # ----------------------------------------------------------------------------- -class ProfileVersion(enum.IntEnum): +class ProfileVersion(utils.OpenIntEnum): """ Profile version (normative). @@ -2076,6 +2076,7 @@ def asdict(self) -> dict[str, Any]: max_latency=0x0008, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 + | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5 @@ -2091,7 +2092,6 @@ def asdict(self) -> dict[str, Any]: max_latency=0x000D, packet_type=( HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 - | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5 diff --git a/bumble/host.py b/bumble/host.py index 98fd131a..32187af3 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -686,6 +686,8 @@ async def _send_command( self.pending_response, timeout=response_timeout ) return response + except asyncio.TimeoutError: + raise except Exception: logger.exception(color("!!! Exception while sending command:", "red")) raise @@ -866,7 +868,7 @@ def send_sco_sdu(self, connection_handle: int, sdu: bytes) -> None: self.send_hci_packet( hci.HCI_SynchronousDataPacket( connection_handle=connection_handle, - packet_status=0, + packet_status=hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA, data_total_length=len(sdu), data=sdu, ) @@ -1177,11 +1179,28 @@ def on_hci_le_enhanced_connection_complete_v2_event( def on_hci_connection_complete_event( self, event: hci.HCI_Connection_Complete_Event ): + if event.link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO: + # Pass this on to the synchronous connection handler + forwarded_event = hci.HCI_Synchronous_Connection_Complete_Event( + status=event.status, + connection_handle=event.connection_handle, + bd_addr=event.bd_addr, + link_type=event.link_type, + transmission_interval=0, + retransmission_window=0, + rx_packet_length=0, + tx_packet_length=0, + air_mode=0, + ) + self.on_hci_synchronous_connection_complete_event(forwarded_event) + return + if event.status == hci.HCI_SUCCESS: # Create/update the connection logger.debug( - f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] ' - f'{event.bd_addr}' + f'### BR/EDR ACL CONNECTION: [0x{event.connection_handle:04X}] ' + f'{event.bd_addr} ' + f'{event.link_type.name}' ) connection = self.connections.get(event.connection_handle) @@ -1581,6 +1600,9 @@ def on_hci_synchronous_connection_complete_event( event.bd_addr, event.connection_handle, event.link_type, + event.rx_packet_length, + event.tx_packet_length, + event.air_mode, ) else: logger.debug(f'### SCO CONNECTION FAILED: {event.status}') diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 5512a157..98f40f5e 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -110,7 +110,7 @@ class MccType(enum.IntEnum): RFCOMM_DEFAULT_INITIAL_CREDITS = 7 RFCOMM_DEFAULT_MAX_CREDITS = 32 RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2 -RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000 +RFCOMM_DEFAULT_MAX_FRAME_SIZE = 1000 RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1 RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 diff --git a/bumble/transport/usb.py b/bumble/transport/usb.py index 80fdcd3f..53cc7326 100644 --- a/bumble/transport/usb.py +++ b/bumble/transport/usb.py @@ -22,6 +22,8 @@ import logging import platform import threading +from collections.abc import Callable +from typing import Any import usb1 @@ -35,6 +37,28 @@ logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +# pylint: disable=invalid-name +USB_RECIPIENT_DEVICE = 0x00 +USB_REQUEST_TYPE_CLASS = 0x01 << 5 +USB_DEVICE_CLASS_DEVICE = 0x00 +USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 +USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 +USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01 +USB_ENDPOINT_TRANSFER_TYPE_ISOCHRONOUS = 0x01 +USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02 +USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03 +USB_ENDPOINT_IN = 0x80 + +USB_BT_HCI_CLASS_TUPLE = ( + USB_DEVICE_CLASS_WIRELESS_CONTROLLER, + USB_DEVICE_SUBCLASS_RF_CONTROLLER, + USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER, +) + + # ----------------------------------------------------------------------------- def load_libusb(): ''' @@ -60,318 +84,589 @@ def load_libusb(): usb1.loadLibrary(libusb_dll) -async def open_usb_transport(spec: str) -> Transport: - ''' - Open a USB transport. - The moniker string has this syntax: - either or - : or - :/] or - :# - With as the 0-based index to select amongst all the devices that appear - to be supporting Bluetooth HCI (0 being the first one), or - Where and are the vendor ID and product ID in hexadecimal. The - / suffix or # suffix max be specified when more than one - device with the same vendor and product identifiers are present. +def find_endpoints(device, forced_mode, sco_alternate=None): + '''Look for the interfaces with the right class and endpoints''' + # pylint: disable-next=too-many-nested-blocks + for configuration_index, configuration in enumerate(device): + # Select the interface and endpoints for ACL + acl_interface = None + bulk_in = None + bulk_out = None + interrupt_in = None + for interface in configuration: + for setting in interface: + if acl_interface is not None: + continue - In addition, if the moniker ends with the symbol "!", the device will be used in - "forced" mode: - the first USB interface of the device will be used, regardless of the interface - class/subclass. - This may be useful for some devices that use a custom class/subclass but may - nonetheless work as-is. + if ( + not forced_mode + and ( + setting.getClass(), + setting.getSubClass(), + setting.getProtocol(), + ) + != USB_BT_HCI_CLASS_TUPLE + ): + continue + + for endpoint in setting: + attributes = endpoint.getAttributes() + address = endpoint.getAddress() + if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK: + if address & USB_ENDPOINT_IN: + if bulk_in is None: + bulk_in = endpoint + else: + if bulk_out is None: + bulk_out = endpoint + elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT: + if address & USB_ENDPOINT_IN and interrupt_in is None: + interrupt_in = endpoint + + # Only keep complete sets (endpoints that should be under the + # same interface) + if ( + bulk_in is not None + and bulk_out is not None + and interrupt_in is not None + ): + acl_interface = setting - Examples: - 0 --> the first BT USB dongle - 04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901 - 04b4:f901#2 --> the third USB device with vendor=04b4 and product=f901 - 04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and - serial number 00E04C239987 - usb:0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode. - ''' + # Select the interface and endpoints for SCO + sco_interface = None + max_packet_size = (0, 0) + isochronous_in = None + isochronous_out = None - # pylint: disable=invalid-name - USB_RECIPIENT_DEVICE = 0x00 - USB_REQUEST_TYPE_CLASS = 0x01 << 5 - USB_DEVICE_CLASS_DEVICE = 0x00 - USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0 - USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01 - USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01 - USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02 - USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03 - USB_ENDPOINT_IN = 0x80 - - USB_BT_HCI_CLASS_TUPLE = ( - USB_DEVICE_CLASS_WIRELESS_CONTROLLER, - USB_DEVICE_SUBCLASS_RF_CONTROLLER, - USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER, - ) - - READ_SIZE = 4096 - - class UsbPacketSink: - def __init__(self, device, acl_out): - self.device = device - self.acl_out = acl_out - self.acl_out_transfer = device.getTransfer() - self.acl_out_transfer_ready = asyncio.Semaphore(1) - self.packets = asyncio.Queue[bytes]() # Queue of packets waiting to be sent - self.loop = asyncio.get_running_loop() - self.queue_task = None - self.cancel_done = self.loop.create_future() - self.closed = False - - def start(self): - self.queue_task = asyncio.create_task(self.process_queue()) - - def on_packet(self, packet): - # Ignore packets if we're closed - if self.closed: - return + for interface in configuration: + if sco_interface is not None: + continue - if len(packet) == 0: - logger.warning('packet too short') - return + if sco_alternate is None: + continue - # Queue the packet - self.packets.put_nowait(packet) + for setting in interface: + if ( + not forced_mode + and ( + setting.getClass(), + setting.getSubClass(), + setting.getProtocol(), + ) + != USB_BT_HCI_CLASS_TUPLE + ): + continue - def transfer_callback(self, transfer): - self.loop.call_soon_threadsafe(self.acl_out_transfer_ready.release) - status = transfer.getStatus() + if ( + sco_alternate != 0 + and setting.getAlternateSetting() != sco_alternate + ): + continue + + isochronous_in = None + isochronous_out = None + + for endpoint in setting: + if ( + endpoint.getAttributes() & 0x03 + == USB_ENDPOINT_TRANSFER_TYPE_ISOCHRONOUS + ): + if endpoint.getMaxPacketSize() > 0: + if endpoint.getAddress() & USB_ENDPOINT_IN: + if ( + isochronous_in is None + or endpoint.getMaxPacketSize() + > (isochronous_in.getMaxPacketSize()) + ): + isochronous_in = endpoint + else: + if ( + isochronous_out is None + or endpoint.getMaxPacketSize() + > (isochronous_out.getMaxPacketSize()) + ): + isochronous_out = endpoint + + if isochronous_in is not None and isochronous_out is not None: + if ( + sco_interface is None + or sco_alternate == 0 + and ( + isochronous_in.getMaxPacketSize(), + isochronous_out.getMaxPacketSize(), + ) + > max_packet_size + ): + sco_interface = setting + max_packet_size = ( + isochronous_in.getMaxPacketSize(), + isochronous_out.getMaxPacketSize(), + ) - # pylint: disable=no-member - if status == usb1.TRANSFER_CANCELLED: - self.loop.call_soon_threadsafe(self.cancel_done.set_result, None) - return + # Return if we found at least a compatible ACL interface + if acl_interface is not None: + return ( + configuration_index + 1, + acl_interface, + sco_interface, + interrupt_in, + bulk_in, + isochronous_in, + bulk_out, + isochronous_out, + ) - if status != usb1.TRANSFER_COMPLETED: - logger.warning( - color( - f'!!! OUT transfer not completed: status={status}', - 'red', - ) + logger.debug(f'skipping configuration {configuration_index + 1}') + + return None + + +class UsbPacketSink: + def __init__(self, device, bulk_out, isochronous_out): + self.device = device + self.packets = asyncio.Queue[bytes]() # Queue of packets waiting to be sent + self.bulk_out = bulk_out + self.isochronous_out = isochronous_out + self.bulk_or_control_out_transfer = device.getTransfer() + self.isochronous_out_transfer = device.getTransfer(iso_packets=1) + self.out_transfer_ready = asyncio.Semaphore(1) + self.packets: asyncio.Queue[bytes] = ( + asyncio.Queue() + ) # Queue of packets waiting to be sent + self.loop = asyncio.get_running_loop() + self.queue_task = None + self.closed = False + + def start(self): + self.queue_task = asyncio.create_task(self.process_queue()) + + def on_packet(self, packet): + # Ignore packets if we're closed + if self.closed: + return + + if len(packet) == 0: + logger.warning('packet too short') + return + + # Queue the packet + self.packets.put_nowait(packet) + + def transfer_callback(self, transfer): + self.loop.call_soon_threadsafe(self.out_transfer_ready.release) + status = transfer.getStatus() + + logger.debug(f"OUT CALLBACK: {status}") + + if status != usb1.TRANSFER_COMPLETED: + logger.warning( + color( + f'!!! OUT transfer not completed: status={status}', + 'red', ) + ) - async def process_queue(self): - while True: - # Wait for a packet to transfer. - packet = await self.packets.get() + async def process_queue(self): + while not self.closed: + # Wait for a packet to transfer. + packet = await self.packets.get() - # Wait until we can start a transfer. - await self.acl_out_transfer_ready.acquire() + # Wait until we can start a transfer. + await self.out_transfer_ready.acquire() - # Transfer the packet. - packet_type = packet[0] + # Transfer the packet. + packet_type = packet[0] + packet_payload = packet[1:] + submitted = False + try: if packet_type == hci.HCI_ACL_DATA_PACKET: - self.acl_out_transfer.setBulk( - self.acl_out, packet[1:], callback=self.transfer_callback + self.bulk_or_control_out_transfer.setBulk( + self.bulk_out.getAddress(), + packet_payload, + callback=self.transfer_callback, ) - self.acl_out_transfer.submit() + self.bulk_or_control_out_transfer.submit() + submitted = True elif packet_type == hci.HCI_COMMAND_PACKET: - self.acl_out_transfer.setControl( + self.bulk_or_control_out_transfer.setControl( USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, - packet[1:], + packet_payload, callback=self.transfer_callback, ) - self.acl_out_transfer.submit() + self.bulk_or_control_out_transfer.submit() + submitted = True + elif packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET: + if self.isochronous_out is None: + logger.warning( + color('isochronous packets not supported', 'red') + ) + self.out_transfer_ready.release() + continue + + self.isochronous_out_transfer.setIsochronous( + self.isochronous_out.getAddress(), + packet_payload, + callback=self.transfer_callback, + ) + self.isochronous_out_transfer.submit() + submitted = True else: logger.warning( color(f'unsupported packet type {packet_type}', 'red') ) + except Exception as error: + logger.warning(f'!!! exception while submitting transfer: {error}') - def close(self): - self.closed = True - if self.queue_task: - self.queue_task.cancel() + if not submitted: + self.out_transfer_ready.release() + + def close(self): + self.closed = True + + async def terminate(self): + self.close() - async def terminate(self): - if not self.closed: - self.close() + if self.queue_task: + self.queue_task.cancel() - # Empty the packet queue so that we don't send any more data - while not self.packets.empty(): - self.packets.get_nowait() + # Empty the packet queue so that we don't send any more data + while not self.packets.empty(): + self.packets.get_nowait() - # If we have a transfer in flight, cancel it - if self.acl_out_transfer.isSubmitted(): + # If we have transfers in flight, cancel them + for transfer in ( + self.bulk_or_control_out_transfer, + self.isochronous_out_transfer, + ): + if transfer.isSubmitted(): # Try to cancel the transfer, but that may fail because it may have # already completed try: - self.acl_out_transfer.cancel() + transfer.cancel() logger.debug('waiting for OUT transfer cancellation to be done...') - await self.cancel_done + await self.out_transfer_ready.acquire() logger.debug('OUT transfer cancellation done') - except usb1.USBError: - logger.debug('OUT transfer likely already completed') - - class UsbPacketSource(asyncio.Protocol, BaseSource): - def __init__(self, device, metadata, acl_in, events_in): - super().__init__() - self.device = device - self.metadata = metadata - self.acl_in = acl_in - self.acl_in_transfer = None - self.events_in = events_in - self.events_in_transfer = None - self.loop = asyncio.get_running_loop() - self.queue = asyncio.Queue() - self.dequeue_task = None - self.cancel_done = { - hci.HCI_EVENT_PACKET: self.loop.create_future(), - hci.HCI_ACL_DATA_PACKET: self.loop.create_future(), - } - self.closed = False - - def start(self): - # Set up transfer objects for input - self.events_in_transfer = device.getTransfer() - self.events_in_transfer.setInterrupt( - self.events_in, - READ_SIZE, - callback=self.transfer_callback, - user_data=hci.HCI_EVENT_PACKET, - ) - self.events_in_transfer.submit() + except usb1.USBError as error: + logger.debug(f'OUT transfer likely already completed ({error})') + + +READ_SIZE = 4096 + + +class ScoAccumulator: + def __init__(self, emit: Callable[[bytes], Any]) -> None: + self.emit = emit + self.packet = b'' + + def feed(self, data: bytes) -> None: + while data: + # Accumulate until we have a complete 3-byte header + if (bytes_needed := 3 - len(self.packet)) > 0: + self.packet += data[:bytes_needed] + data = data[bytes_needed:] + continue + + packet_length = 3 + self.packet[2] + bytes_needed = packet_length - len(self.packet) + self.packet += data[:bytes_needed] + data = data[bytes_needed:] + if len(self.packet) == packet_length: + # Packet complete + self.emit(self.packet) + self.packet = b'' + + +class UsbPacketSource(asyncio.Protocol, BaseSource): + def __init__(self, device, metadata, interrupt_in, bulk_in, isochronous_in): + super().__init__() + self.device = device + self.metadata = metadata + self.interrupt_in = interrupt_in + self.interrupt_in_transfer = None + self.bulk_in = bulk_in + self.bulk_in_transfer = None + self.isochronous_in = isochronous_in + self.isochronous_in_transfer = None + self.isochronous_accumulator = ScoAccumulator( + lambda packet: self.queue_packet(hci.HCI_SYNCHRONOUS_DATA_PACKET, packet) + ) + self.loop = asyncio.get_running_loop() + self.queue = asyncio.Queue() + self.dequeue_task = None + self.done = { + hci.HCI_EVENT_PACKET: asyncio.Event(), + hci.HCI_ACL_DATA_PACKET: asyncio.Event(), + hci.HCI_SYNCHRONOUS_DATA_PACKET: asyncio.Event(), + } + self.closed = False + self.lock = threading.Lock() + + def start(self): + # Set up transfer objects for input + self.interrupt_in_transfer = self.device.getTransfer() + self.interrupt_in_transfer.setInterrupt( + self.interrupt_in.getAddress(), + READ_SIZE, + callback=self.transfer_callback, + user_data=hci.HCI_EVENT_PACKET, + ) + self.interrupt_in_transfer.submit() + + self.bulk_in_transfer = self.device.getTransfer() + self.bulk_in_transfer.setBulk( + self.bulk_in.getAddress(), + READ_SIZE, + callback=self.transfer_callback, + user_data=hci.HCI_ACL_DATA_PACKET, + ) + self.bulk_in_transfer.submit() - self.acl_in_transfer = device.getTransfer() - self.acl_in_transfer.setBulk( - self.acl_in, - READ_SIZE, + if self.isochronous_in is not None: + self.isochronous_in_transfer = self.device.getTransfer(iso_packets=16) + self.isochronous_in_transfer.setIsochronous( + self.isochronous_in.getAddress(), + 16 * self.isochronous_in.getMaxPacketSize(), callback=self.transfer_callback, - user_data=hci.HCI_ACL_DATA_PACKET, + user_data=hci.HCI_SYNCHRONOUS_DATA_PACKET, ) - self.acl_in_transfer.submit() + self.isochronous_in_transfer.submit() - self.dequeue_task = self.loop.create_task(self.dequeue()) + self.dequeue_task = self.loop.create_task(self.dequeue()) - @property - def usb_transfer_submitted(self): - return ( - self.events_in_transfer.isSubmitted() - or self.acl_in_transfer.isSubmitted() - ) + def queue_packet(self, packet_type: int, packet_data: bytes) -> None: + self.loop.call_soon_threadsafe( + self.queue.put_nowait, bytes([packet_type]) + packet_data + ) - def transfer_callback(self, transfer): - packet_type = transfer.getUserData() - status = transfer.getStatus() + def transfer_callback(self, transfer): + packet_type = transfer.getUserData() + status = transfer.getStatus() - # pylint: disable=no-member - if status == usb1.TRANSFER_COMPLETED: - packet = ( - bytes([packet_type]) - + transfer.getBuffer()[: transfer.getActualLength()] - ) - self.loop.call_soon_threadsafe(self.queue.put_nowait, packet) + # pylint: disable=no-member + if ( + packet_type != hci.HCI_SYNCHRONOUS_DATA_PACKET + or transfer.getActualLength() + or status != usb1.TRANSFER_COMPLETED + ): + logger.debug( + f"IN[{packet_type}] CALLBACK: status={status}, length={transfer.getActualLength()}" + ) + if status == usb1.TRANSFER_COMPLETED: + with self.lock: + if self.closed: + logger.debug("packet source closed, discarding transfer") + else: + if packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET: + for iso_status, iso_buffer in transfer.iterISO(): + if not iso_buffer: + continue + if iso_status: + logger.warning(f"ISO packet status error: {iso_status}") + continue + logger.debug( + "### SCO packet: %d %s", + len(iso_buffer), + iso_buffer.hex(), + ) + self.isochronous_accumulator.feed(iso_buffer) + else: + self.queue_packet( + packet_type, + transfer.getBuffer()[: transfer.getActualLength()], + ) - # Re-submit the transfer so we can receive more data - transfer.submit() - elif status == usb1.TRANSFER_CANCELLED: - self.loop.call_soon_threadsafe( - self.cancel_done[packet_type].set_result, None - ) - else: - logger.warning( - color( - f'!!! IN[{packet_type}] transfer not completed: status={status}', - 'red', - ) - ) - self.loop.call_soon_threadsafe(self.on_transport_lost) + # Re-submit the transfer so we can receive more data + try: + transfer.submit() + except usb1.USBError as error: + logger.warning(f"Failed to re-submit transfer: {error}") + self.loop.call_soon_threadsafe(self.on_transport_lost) + elif status == usb1.TRANSFER_CANCELLED: + logger.debug(f"IN[{packet_type}] transfer canceled") + self.loop.call_soon_threadsafe(self.done[packet_type].set) + else: + logger.warning( + color(f'!!! IN[{packet_type}] transfer not completed', 'red') + ) + self.loop.call_soon_threadsafe(self.done[packet_type].set) + self.loop.call_soon_threadsafe(self.on_transport_lost) - async def dequeue(self): - while not self.closed: + async def dequeue(self): + while not self.closed: + try: + packet = await self.queue.get() + except asyncio.CancelledError: + return + if self.sink: try: - packet = await self.queue.get() - except asyncio.CancelledError: - return - if self.sink: - try: - self.sink.on_packet(packet) - except Exception: - logger.exception( - color('!!! Exception in sink.on_packet', 'red') - ) + self.sink.on_packet(packet) + except Exception: + logger.exception(color('!!! Exception in sink.on_packet', 'red')) - def close(self): + def close(self): + with self.lock: self.closed = True - async def terminate(self): - if not self.closed: - self.close() + async def terminate(self): + self.close() + if self.dequeue_task: self.dequeue_task.cancel() - # Cancel the transfers - for transfer in (self.events_in_transfer, self.acl_in_transfer): - if transfer.isSubmitted(): - # Try to cancel the transfer, but that may fail because it may have - # already completed - packet_type = transfer.getUserData() - try: - transfer.cancel() - logger.debug( - f'waiting for IN[{packet_type}] transfer cancellation ' - 'to be done...' - ) - await self.cancel_done[packet_type] - logger.debug(f'IN[{packet_type}] transfer cancellation done') - except usb1.USBError: - logger.debug( - f'IN[{packet_type}] transfer likely already completed' - ) - - class UsbTransport(Transport): - def __init__(self, context, device, interface, setting, source, sink): - super().__init__(source, sink) - self.context = context - self.device = device - self.interface = interface - self.loop = asyncio.get_running_loop() - self.event_loop_done = self.loop.create_future() - - # Get exclusive access - device.claimInterface(interface) - - # Set the alternate setting if not the default - if setting != 0: - device.setInterfaceAltSetting(interface, setting) - - # The source and sink can now start - source.start() - sink.start() - - # Create a thread to process events - self.event_thread = threading.Thread(target=self.run) - self.event_thread.start() - - def run(self): - logger.debug('starting USB event loop') - while self.source.usb_transfer_submitted: - # pylint: disable=no-member + # Cancel the transfers + for transfer in ( + self.interrupt_in_transfer, + self.bulk_in_transfer, + self.isochronous_in_transfer, + ): + if transfer is None: + continue + + if transfer.isSubmitted(): + # Try to cancel the transfer, but that may fail because it may + # have already completed + packet_type = transfer.getUserData() + assert isinstance(packet_type, int) try: - self.context.handleEvents() - except usb1.USBErrorInterrupted: - pass + transfer.cancel() + logger.debug( + f'waiting for IN[{packet_type}] transfer cancellation ' + 'to be done...' + ) + await self.done[packet_type].wait() + logger.debug(f'IN[{packet_type}] transfer cancellation done') + except usb1.USBError as error: + logger.debug( + f'IN[{packet_type}] transfer likely already completed ' + f'({error})' + ) + + +class UsbTransport(Transport): + def __init__(self, context, device, acl_interface, sco_interface, source, sink): + super().__init__(source, sink) + self.context = context + self.device = device + self.acl_interface = acl_interface + self.sco_interface = sco_interface + self.loop = asyncio.get_running_loop() + self.event_loop_done = self.loop.create_future() + self.event_loop_should_exit = False + self.lock = threading.Lock() + + # Get exclusive access + device.claimInterface(acl_interface.getNumber()) + if sco_interface is not None: + device.claimInterface(sco_interface.getNumber()) + + # Set the alternate setting if not the default + if acl_interface.getAlternateSetting() != 0: + logger.debug( + f'setting ACL interface {acl_interface.getNumber()} ' + f'altsetting {acl_interface.getAlternateSetting()}' + ) + device.setInterfaceAltSetting( + acl_interface.getNumber(), acl_interface.getAlternateSetting() + ) + if sco_interface is not None and sco_interface.getAlternateSetting() != 0: + logger.debug( + f'setting SCO interface {sco_interface.getNumber()} ' + f'altsetting {sco_interface.getAlternateSetting()}' + ) + device.setInterfaceAltSetting( + sco_interface.getNumber(), sco_interface.getAlternateSetting() + ) + + # The source and sink can now start + source.start() + sink.start() + + # Create a thread to process events + self.event_thread = threading.Thread(target=self.run) + self.event_thread.start() - logger.debug('USB event loop done') - self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None) + def run(self): + logger.debug('starting USB event loop') + while True: + with self.lock: + if self.event_loop_should_exit: + logger.debug("USB event loop exit requested") + break - async def close(self): - self.source.close() - self.sink.close() - await self.source.terminate() - await self.sink.terminate() - self.device.releaseInterface(self.interface) - self.device.close() - self.context.close() + # pylint: disable=no-member + try: + self.context.handleEvents() + except usb1.USBErrorInterrupted: + pass + except Exception as error: + logger.warning(f'!!! Exception while handling events: {error}') + + logger.debug('ending USB event loop') + self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None) + + async def close(self): + self.source.close() + self.sink.close() + await self.source.terminate() + await self.sink.terminate() + + # We no longer need the event loop to run + with self.lock: + self.event_loop_should_exit = True + self.context.interruptEventHandler() + + self.device.releaseInterface(self.acl_interface.getNumber()) + if self.sco_interface: + self.device.releaseInterface(self.sco_interface.getNumber()) + self.device.close() + self.context.close() + + # Wait for the thread to terminate + logger.debug("waiting for USB event loop to be done...") + await self.event_loop_done + logger.debug("USB event loop done") - # Wait for the thread to terminate - await self.event_loop_done + +async def open_usb_transport(spec: str) -> Transport: + ''' + Open a USB transport. + The moniker string has this syntax: + either or + : or + :/] or + :# + With as the 0-based index to select amongst all the devices that appear + to be supporting Bluetooth HCI (0 being the first one), or + Where and are the vendor ID and product ID in hexadecimal. The + / suffix or # suffix max be specified when more than one + device with the same vendor and product identifiers are present. + + Opotionally, the moniker may include a +sco= suffix to enable SCO/eSCO + and specify the alternate setting to use for SCO/eSCO transfers, with 0 meaning an + automatic selection. + + In addition, if the moniker ends with the symbol "!", the device will be used in + "forced" mode: + the first USB interface of the device will be used, regardless of the interface + class/subclass. + This may be useful for some devices that use a custom class/subclass but may + nonetheless work as-is. + + Examples: + 0 --> the first BT USB dongle + 04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901 + 04b4:f901#2 --> the third USB device with vendor=04b4 and product=f901 + 04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and + serial number 00E04C239987 + 0B05:17CB! --> the BT USB dongle with vendor=0B05 and product=17CB, in "forced" + mode. + 0+sco=0 --> the first BT USB dongle, with SCO enabled using auto-selection. + 0+sco=5 --> the first BT USB dongle, with SCO enabled using alternate setting 5. + ''' # Find the device according to the spec moniker load_libusb() @@ -379,6 +674,7 @@ async def close(self): context.open() try: found = None + device = None if spec.endswith('!'): spec = spec[:-1] @@ -386,6 +682,12 @@ async def close(self): else: forced_mode = False + if '+sco=' in spec: + spec, sco_alternate_str = spec.split('+sco=') + sco_alternate = int(sco_alternate_str) + else: + sco_alternate = None + if ':' in spec: vendor_id, product_id = spec.split(':') serial_number = None @@ -461,76 +763,41 @@ def device_is_bluetooth_hci(device): logger.debug(f'USB Device: {found}') - # Look for the first interface with the right class and endpoints - def find_endpoints(device): - # pylint: disable-next=too-many-nested-blocks - for configuration_index, configuration in enumerate(device): - interface = None - for interface in configuration: - setting = None - for setting in interface: - if ( - not forced_mode - and ( - setting.getClass(), - setting.getSubClass(), - setting.getProtocol(), - ) - != USB_BT_HCI_CLASS_TUPLE - ): - continue - - events_in = None - acl_in = None - acl_out = None - for endpoint in setting: - attributes = endpoint.getAttributes() - address = endpoint.getAddress() - if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK: - if address & USB_ENDPOINT_IN and acl_in is None: - acl_in = address - elif acl_out is None: - acl_out = address - elif ( - attributes & 0x03 - == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT - ): - if address & USB_ENDPOINT_IN and events_in is None: - events_in = address - - # Return if we found all 3 endpoints - if ( - acl_in is not None - and acl_out is not None - and events_in is not None - ): - return ( - configuration_index + 1, - setting.getNumber(), - setting.getAlternateSetting(), - acl_in, - acl_out, - events_in, - ) - - logger.debug( - f'skipping configuration {configuration_index + 1} / ' - f'interface {setting.getNumber()}' - ) - - return None - - endpoints = find_endpoints(found) + assert device is not None + endpoints = find_endpoints(device, forced_mode, sco_alternate) if endpoints is None: raise TransportInitError('no compatible interface found for device') - (configuration, interface, setting, acl_in, acl_out, events_in) = endpoints + ( + configuration, + acl_interface, + sco_interface, + interrupt_in, + bulk_in, + isochronous_in, + bulk_out, + isochronous_out, + ) = endpoints + acl_interface_info = ( + f'acl_interface={acl_interface.getNumber()}/' + f'{acl_interface.getAlternateSetting()}' + ) + sco_interface_info = ( + '' + if sco_interface is None + else ( + f'sco_interface={sco_interface.getNumber()}/' + f'{sco_interface.getAlternateSetting()}' + ) + ) logger.debug( f'selected endpoints: configuration={configuration}, ' - f'interface={interface}, ' - f'setting={setting}, ' - f'acl_in=0x{acl_in:02X}, ' - f'acl_out=0x{acl_out:02X}, ' - f'events_in=0x{events_in:02X}, ' + f'acl_interface={acl_interface_info}, ' + f'sco_interface={sco_interface_info}, ' + f'interrupt_in=0x{interrupt_in.getAddress():02X}, ' + f'bulk_in=0x{bulk_in.getAddress():02X}, ' + f'bulk_out=0x{bulk_out.getAddress():02X}, ' + f'isochronous_in=0x{isochronous_in.getAddress() if isochronous_in else 0:02X}, ' + f'isochronous_out=0x{isochronous_out.getAddress() if isochronous_out else 0:02X}' ) device_metadata = { @@ -562,9 +829,11 @@ def find_endpoints(device): except usb1.USBError: logger.warning('failed to set configuration') - source = UsbPacketSource(device, device_metadata, acl_in, events_in) - sink = UsbPacketSink(device, acl_out) - return UsbTransport(context, device, interface, setting, source, sink) + source = UsbPacketSource( + device, device_metadata, interrupt_in, bulk_in, isochronous_in + ) + sink = UsbPacketSink(device, bulk_out, isochronous_out) + return UsbTransport(context, device, acl_interface, sco_interface, source, sink) except usb1.USBError as error: logger.warning(color(f'!!! failed to open USB device: {error}', 'red')) context.close() diff --git a/examples/run_hfp_handsfree.py b/examples/run_hfp_handsfree.py index 2caf67b8..3815fa4b 100644 --- a/examples/run_hfp_handsfree.py +++ b/examples/run_hfp_handsfree.py @@ -20,17 +20,110 @@ import functools import json import sys +import wave import websockets.asyncio.server import bumble.logging from bumble import hci, hfp, rfcomm -from bumble.device import Connection, Device +from bumble.device import Connection, Device, ScoLink from bumble.hfp import HfProtocol from bumble.transport import open_transport +# ----------------------------------------------------------------------------- ws: websockets.asyncio.server.ServerConnection | None = None hf_protocol: HfProtocol | None = None +input_wav: wave.Wave_read | None = None +output_wav: wave.Wave_write | None = None + + +# ----------------------------------------------------------------------------- +def on_audio_packet(packet: hci.HCI_SynchronousDataPacket) -> None: + if ( + packet.packet_status + == hci.HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA + ): + if output_wav: + # Save the PCM audio to the output + output_wav.writeframes(packet.data) + else: + print('!!! discarding packet with status ', packet.packet_status.name) + + if input_wav and hf_protocol: + # Send PCM audio from the input + frame_count = len(packet.data) // 2 + while frame_count: + # NOTE: we use a fixed number of frames here, this should likely be adjusted + # based on the transport parameters (like the USB max packet size) + chunk_size = min(frame_count, 16) + if not (pcm_data := input_wav.readframes(chunk_size)): + return + frame_count -= chunk_size + hf_protocol.dlc.multiplexer.l2cap_channel.connection.device.host.send_sco_sdu( + connection_handle=packet.connection_handle, + sdu=pcm_data, + ) + + +# ----------------------------------------------------------------------------- +def on_sco_connection(link: ScoLink) -> None: + print('### SCO connection established:', link) + if link.air_mode == hci.CodecID.TRANSPARENT: + print("@@@ The controller does not encode/decode voice") + return + + link.sink = on_audio_packet + + +# ----------------------------------------------------------------------------- +def on_sco_request( + link_type: int, connection: Connection, protocol: HfProtocol +) -> None: + if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO: + esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.SCO_CVSD_D1] + elif protocol.active_codec == hfp.AudioCodec.MSBC: + esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_MSBC_T2] + elif protocol.active_codec == hfp.AudioCodec.CVSD: + esco_parameters = hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S4] + else: + raise RuntimeError("unknown active codec") + + if connection.device.host.supports_command( + hci.HCI_ENHANCED_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND + ): + connection.cancel_on_disconnection( + connection.device.send_async_command( + hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command( + bd_addr=connection.peer_address, **esco_parameters.asdict() + ) + ) + ) + elif connection.device.host.supports_command( + hci.HCI_ACCEPT_SYNCHRONOUS_CONNECTION_REQUEST_COMMAND + ): + connection.cancel_on_disconnection( + connection.device.send_async_command( + hci.HCI_Accept_Synchronous_Connection_Request_Command( + bd_addr=connection.peer_address, + transmit_bandwidth=esco_parameters.transmit_bandwidth, + receive_bandwidth=esco_parameters.receive_bandwidth, + max_latency=esco_parameters.max_latency, + voice_setting=int( + hci.VoiceSetting( + input_sample_size=hci.VoiceSetting.InputSampleSize.SIZE_16_BITS, + input_data_format=hci.VoiceSetting.InputDataFormat.TWOS_COMPLEMENT, + ) + ), + retransmission_effort=esco_parameters.retransmission_effort, + packet_type=esco_parameters.packet_type, + ) + ) + ) + else: + print('!!! no supported command for SCO connection request') + return + + connection.on('sco_connection', on_sco_connection) # ----------------------------------------------------------------------------- @@ -40,134 +133,173 @@ def on_dlc(dlc: rfcomm.DLC, configuration: hfp.HfConfiguration): hf_protocol = HfProtocol(dlc, configuration) asyncio.create_task(hf_protocol.run()) - def on_sco_request(connection: Connection, link_type: int, protocol: HfProtocol): - if connection == protocol.dlc.multiplexer.l2cap_channel.connection: - if link_type == hci.HCI_Connection_Complete_Event.LinkType.SCO: - esco_parameters = hfp.ESCO_PARAMETERS[ - hfp.DefaultCodecParameters.SCO_CVSD_D1 - ] - elif protocol.active_codec == hfp.AudioCodec.MSBC: - esco_parameters = hfp.ESCO_PARAMETERS[ - hfp.DefaultCodecParameters.ESCO_MSBC_T2 - ] - elif protocol.active_codec == hfp.AudioCodec.CVSD: - esco_parameters = hfp.ESCO_PARAMETERS[ - hfp.DefaultCodecParameters.ESCO_CVSD_S4 - ] - else: - raise RuntimeError("unknown active codec") - - connection.cancel_on_disconnection( - connection.device.send_command( - hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command( - bd_addr=connection.peer_address, **esco_parameters.asdict() - ) - ) - ) - - handler = functools.partial(on_sco_request, protocol=hf_protocol) - dlc.multiplexer.l2cap_channel.connection.device.on('sco_request', handler) + connection = dlc.multiplexer.l2cap_channel.connection + handler = functools.partial( + on_sco_request, + connection=connection, + protocol=hf_protocol, + ) + connection.on('sco_request', handler) dlc.multiplexer.l2cap_channel.once( 'close', - lambda: dlc.multiplexer.l2cap_channel.connection.device.remove_listener( - 'sco_request', handler - ), + lambda: connection.remove_listener('sco_request', handler), + ) + + hf_protocol.on('ag_indicator', on_ag_indicator) + hf_protocol.on('codec_negotiation', on_codec_negotiation) + + +# ----------------------------------------------------------------------------- +def on_ag_indicator(indicator): + global ws + if ws: + asyncio.create_task(ws.send(str(indicator))) + + +# ----------------------------------------------------------------------------- +def on_codec_negotiation(codec: hfp.AudioCodec): + print(f'### Negotiated codec: {codec.name}') + global output_wav + if output_wav: + output_wav.setnchannels(1) + output_wav.setsampwidth(2) + match codec: + case hfp.AudioCodec.CVSD: + output_wav.setframerate(8000) + case hfp.AudioCodec.MSBC: + output_wav.setframerate(16000) + + +# ----------------------------------------------------------------------------- +async def run(device: Device, codec: str | None) -> None: + if codec is None: + supported_audio_codecs = [hfp.AudioCodec.CVSD, hfp.AudioCodec.MSBC] + else: + if codec == 'cvsd': + supported_audio_codecs = [hfp.AudioCodec.CVSD] + elif codec == 'msbc': + supported_audio_codecs = [hfp.AudioCodec.MSBC] + else: + print('Unknown codec: ', codec) + return + + # Hands-Free profile configuration. + # TODO: load configuration from file. + configuration = hfp.HfConfiguration( + supported_hf_features=[ + hfp.HfFeature.THREE_WAY_CALLING, + hfp.HfFeature.REMOTE_VOLUME_CONTROL, + hfp.HfFeature.ENHANCED_CALL_STATUS, + hfp.HfFeature.ENHANCED_CALL_CONTROL, + hfp.HfFeature.CODEC_NEGOTIATION, + hfp.HfFeature.HF_INDICATORS, + hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED, + ], + supported_hf_indicators=[ + hfp.HfIndicator.BATTERY_LEVEL, + ], + supported_audio_codecs=supported_audio_codecs, ) - def on_ag_indicator(indicator): + # Create and register a server + rfcomm_server = rfcomm.Server(device) + + # Listen for incoming DLC connections + channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration)) + print(f'### Listening for connection on channel {channel_number}') + + # Advertise the HFP RFComm channel in the SDP + device.sdp_service_records = { + 0x00010001: hfp.make_hf_sdp_records(0x00010001, channel_number, configuration) + } + + # Let's go! + await device.power_on() + + # Start being discoverable and connectable + await device.set_discoverable(True) + await device.set_connectable(True) + + # Start the UI websocket server to offer a few buttons and input boxes + async def serve(websocket: websockets.asyncio.server.ServerConnection): global ws - if ws: - asyncio.create_task(ws.send(str(indicator))) + ws = websocket + async for message in websocket: + with contextlib.suppress(websockets.exceptions.ConnectionClosedOK): + print('Received: ', str(message)) - hf_protocol.on('ag_indicator', on_ag_indicator) + parsed = json.loads(message) + message_type = parsed['type'] + if message_type == 'at_command': + if hf_protocol is not None: + response = str( + await hf_protocol.execute_command( + parsed['command'], + response_type=hfp.AtResponseType.MULTIPLE, + ) + ) + await websocket.send(response) + elif message_type == 'query_call': + if hf_protocol: + response = str(await hf_protocol.query_current_calls()) + await websocket.send(response) + + await websockets.asyncio.server.serve(serve, 'localhost', 8989) + + await asyncio.get_running_loop().create_future() # run forever # ----------------------------------------------------------------------------- async def main() -> None: if len(sys.argv) < 3: - print('Usage: run_classic_hfp.py ') - print('example: run_classic_hfp.py classic2.json usb:04b4:f901') + print( + 'Usage: run_hfp_handsfree.py ' + '[codec] [input] [output]' + ) + print('example: run_hfp_handsfree.py classic2.json usb:0') return - print('<<< connecting to HCI...') - async with await open_transport(sys.argv[2]) as hci_transport: - print('<<< connected') - - # Hands-Free profile configuration. - # TODO: load configuration from file. - configuration = hfp.HfConfiguration( - supported_hf_features=[ - hfp.HfFeature.THREE_WAY_CALLING, - hfp.HfFeature.REMOTE_VOLUME_CONTROL, - hfp.HfFeature.ENHANCED_CALL_STATUS, - hfp.HfFeature.ENHANCED_CALL_CONTROL, - hfp.HfFeature.CODEC_NEGOTIATION, - hfp.HfFeature.HF_INDICATORS, - hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED, - ], - supported_hf_indicators=[ - hfp.HfIndicator.BATTERY_LEVEL, - ], - supported_audio_codecs=[ - hfp.AudioCodec.CVSD, - hfp.AudioCodec.MSBC, - ], - ) + device_config = sys.argv[1] + transport_spec = sys.argv[2] - # Create a device - device = Device.from_config_file_with_hci( - sys.argv[1], hci_transport.source, hci_transport.sink - ) - device.classic_enabled = True + codec: str | None = None + if len(sys.argv) >= 4: + codec = sys.argv[3] - # Create and register a server - rfcomm_server = rfcomm.Server(device) + input_file_name: str | None = None + if len(sys.argv) >= 5: + input_file_name = sys.argv[4] - # Listen for incoming DLC connections - channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration)) - print(f'### Listening for connection on channel {channel_number}') + output_file_name: str | None = None + if len(sys.argv) >= 6: + output_file_name = sys.argv[5] - # Advertise the HFP RFComm channel in the SDP - device.sdp_service_records = { - 0x00010001: hfp.make_hf_sdp_records( - 0x00010001, channel_number, configuration - ) - } - - # Let's go! - await device.power_on() - - # Start being discoverable and connectable - await device.set_discoverable(True) - await device.set_connectable(True) - - # Start the UI websocket server to offer a few buttons and input boxes - async def serve(websocket: websockets.asyncio.server.ServerConnection): - global ws - ws = websocket - async for message in websocket: - with contextlib.suppress(websockets.exceptions.ConnectionClosedOK): - print('Received: ', str(message)) - - parsed = json.loads(message) - message_type = parsed['type'] - if message_type == 'at_command': - if hf_protocol is not None: - response = str( - await hf_protocol.execute_command( - parsed['command'], - response_type=hfp.AtResponseType.MULTIPLE, - ) - ) - await websocket.send(response) - elif message_type == 'query_call': - if hf_protocol: - response = str(await hf_protocol.query_current_calls()) - await websocket.send(response) - - await websockets.asyncio.server.serve(serve, 'localhost', 8989) + global input_wav, output_wav + with ( + ( + wave.open(input_file_name, "rb") + if input_file_name + else contextlib.nullcontext(None) + ) as input_wav, + ( + wave.open(output_file_name, "wb") + if output_file_name + else contextlib.nullcontext(None) + ) as output_wav, + ): + if input_wav and input_wav.getnchannels() != 1: + print("Mono input required") + return + if input_wav and input_wav.getsampwidth() != 2: + print("16-bit input required") + return - await hci_transport.source.terminated + async with await open_transport(transport_spec) as transport: + device = Device.from_config_file_with_hci( + device_config, transport.source, transport.sink + ) + device.classic_enabled = True + await run(device, codec) # ----------------------------------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 60c7f12a..e1d174ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "grpcio >= 1.62.1; platform_system!='Emscripten'", "humanize >= 4.6.0; platform_system!='Emscripten'", "libusb1 >= 2.0.1; platform_system!='Emscripten'", - "libusb-package == 1.0.26.1; platform_system!='Emscripten' and platform_system!='Android'", + "libusb-package == 1.0.26.3; platform_system!='Emscripten' and platform_system!='Android'", "platformdirs >= 3.10.0; platform_system!='Emscripten'", "prompt_toolkit >= 3.0.16; platform_system!='Emscripten'", "prettytable >= 3.6.0; platform_system!='Emscripten'", From 90560cdea1c21a76ebaeee7b48a0f9f3869e43ae Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 6 Mar 2026 17:54:03 -0800 Subject: [PATCH 2/3] revert libusb-package version change --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e1d174ea..60c7f12a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "grpcio >= 1.62.1; platform_system!='Emscripten'", "humanize >= 4.6.0; platform_system!='Emscripten'", "libusb1 >= 2.0.1; platform_system!='Emscripten'", - "libusb-package == 1.0.26.3; platform_system!='Emscripten' and platform_system!='Android'", + "libusb-package == 1.0.26.1; platform_system!='Emscripten' and platform_system!='Android'", "platformdirs >= 3.10.0; platform_system!='Emscripten'", "prompt_toolkit >= 3.0.16; platform_system!='Emscripten'", "prettytable >= 3.6.0; platform_system!='Emscripten'", From b2893f26b6d89acae57c73d29f81c3e22561c7d1 Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Fri, 6 Mar 2026 18:23:20 -0800 Subject: [PATCH 3/3] fix types --- bumble/transport/usb.py | 3 +-- examples/run_hfp_handsfree.py | 23 +++++++++++------------ tasks.py | 2 ++ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/bumble/transport/usb.py b/bumble/transport/usb.py index 53cc7326..81e37502 100644 --- a/bumble/transport/usb.py +++ b/bumble/transport/usb.py @@ -222,9 +222,8 @@ def find_endpoints(device, forced_mode, sco_alternate=None): class UsbPacketSink: - def __init__(self, device, bulk_out, isochronous_out): + def __init__(self, device, bulk_out, isochronous_out) -> None: self.device = device - self.packets = asyncio.Queue[bytes]() # Queue of packets waiting to be sent self.bulk_out = bulk_out self.isochronous_out = isochronous_out self.bulk_or_control_out_transfer = device.getTransfer() diff --git a/examples/run_hfp_handsfree.py b/examples/run_hfp_handsfree.py index 3815fa4b..efce3a23 100644 --- a/examples/run_hfp_handsfree.py +++ b/examples/run_hfp_handsfree.py @@ -275,18 +275,17 @@ async def main() -> None: output_file_name = sys.argv[5] global input_wav, output_wav - with ( - ( - wave.open(input_file_name, "rb") - if input_file_name - else contextlib.nullcontext(None) - ) as input_wav, - ( - wave.open(output_file_name, "wb") - if output_file_name - else contextlib.nullcontext(None) - ) as output_wav, - ): + input_cm: contextlib.AbstractContextManager[wave.Wave_read | None] = ( + wave.open(input_file_name, "rb") + if input_file_name + else contextlib.nullcontext(None) + ) + output_cm: contextlib.AbstractContextManager[wave.Wave_write | None] = ( + wave.open(output_file_name, "wb") + if output_file_name + else contextlib.nullcontext(None) + ) + with input_cm as input_wav, output_cm as output_wav: if input_wav and input_wav.getnchannels() != 1: print("Mono input required") return diff --git a/tasks.py b/tasks.py index 7d928cf9..f0490f09 100644 --- a/tasks.py +++ b/tasks.py @@ -170,7 +170,9 @@ def format_code(ctx, check=False, diff=False): @task def check_types(ctx): checklist = ["apps", "bumble", "examples", "tests", "tasks.py"] + print(">>> Running the type checker...") try: + print("+++ Checking with mypy...") ctx.run(f"mypy {' '.join(checklist)}") except UnexpectedExit as exc: print("Please check your code against the mypy messages.")