diff --git a/bumble/keys.py b/bumble/keys.py index 3dc739d8..15d8aaed 100644 --- a/bumble/keys.py +++ b/bumble/keys.py @@ -27,6 +27,7 @@ import json import logging import os +import pathlib from typing import TYPE_CHECKING, Any from typing_extensions import Self @@ -248,29 +249,26 @@ class without a namespace. With the default namespace, reading from a file will DEFAULT_NAMESPACE = '__DEFAULT__' DEFAULT_BASE_NAME = "keys" - def __init__(self, namespace, filename=None): - self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE + def __init__( + self, namespace: str | None = None, filename: str | None = None + ) -> None: + self.namespace = namespace or self.DEFAULT_NAMESPACE - if filename is None: - # Use a default for the current user + if filename: + self.filename = pathlib.Path(filename).resolve() + self.directory_name = self.filename.parent + else: + import platformdirs # Deferred import - # Import here because this may not exist on all platforms - # pylint: disable=import-outside-toplevel - import appdirs + base_dir = platformdirs.user_data_path(self.APP_NAME, self.APP_AUTHOR) + self.directory_name = base_dir / self.KEYS_DIR - self.directory_name = os.path.join( - appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR - ) - base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace - json_filename = ( - f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p') - ) - self.filename = os.path.join(self.directory_name, json_filename) - else: - self.filename = filename - self.directory_name = os.path.dirname(os.path.abspath(self.filename)) + base_name = self.namespace if namespace else self.DEFAULT_BASE_NAME + safe_name = base_name.lower().replace(':', '-').replace('/', '-') + + self.filename = self.directory_name / f"{safe_name}.json" - logger.debug(f'JSON keystore: {self.filename}') + logger.debug('JSON keystore: %s', self.filename) @classmethod def from_device( @@ -293,7 +291,9 @@ def from_device( return cls(namespace, filename) - async def load(self): + async def load( + self, + ) -> tuple[dict[str, dict[str, dict[str, Any]]], dict[str, dict[str, Any]]]: # Try to open the file, without failing. If the file does not exist, it # will be created upon saving. try: @@ -312,17 +312,17 @@ async def load(self): return next(iter(db.items())) # Finally, just create an empty key map for the namespace - key_map = {} + key_map: dict[str, dict[str, Any]] = {} db[self.namespace] = key_map return (db, key_map) - async def save(self, db): + async def save(self, db: dict[str, dict[str, dict[str, Any]]]) -> None: # Create the directory if it doesn't exist - if not os.path.exists(self.directory_name): - os.makedirs(self.directory_name, exist_ok=True) + if not self.directory_name.exists(): + self.directory_name.mkdir(parents=True, exist_ok=True) # Save to a temporary file - temp_filename = self.filename + '.tmp' + temp_filename = self.filename.with_name(self.filename.name + ".tmp") with open(temp_filename, 'w', encoding='utf-8') as output: json.dump(db, output, sort_keys=True, indent=4) @@ -334,16 +334,16 @@ async def delete(self, name: str) -> None: del key_map[name] await self.save(db) - async def update(self, name, keys): + async def update(self, name: str, keys: PairingKeys) -> None: db, key_map = await self.load() key_map.setdefault(name, {}).update(keys.to_dict()) await self.save(db) - async def get_all(self): + async def get_all(self) -> list[tuple[str, PairingKeys]]: _, key_map = await self.load() return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()] - async def delete_all(self): + async def delete_all(self) -> None: db, key_map = await self.load() key_map.clear() await self.save(db) diff --git a/pyproject.toml b/pyproject.toml index 601c3367..b56e40e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ authors = [{ name = "Google", email = "bumble-dev@google.com" }] requires-python = ">=3.10" dependencies = [ "aiohttp ~= 3.8; platform_system!='Emscripten'", - "appdirs >= 1.4; platform_system!='Emscripten'", "click >= 8.1.3; platform_system!='Emscripten'", "cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'", # Pyodide bundles a version of cryptography that is built for wasm, which may not match the diff --git a/tests/keystore_test.py b/tests/keystore_test.py index 9cc9966a..9bc0cda2 100644 --- a/tests/keystore_test.py +++ b/tests/keystore_test.py @@ -21,6 +21,7 @@ import os import pathlib import tempfile +from unittest import mock import pytest @@ -179,11 +180,55 @@ async def test_default_namespace(temporary_file): assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1') +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_no_filename(tmp_path): + import platformdirs + + with mock.patch.object(platformdirs, 'user_data_path', return_value=tmp_path): + # Case 1: no namespace, no filename + keystore = JsonKeyStore(None, None) + expected_directory = tmp_path / 'Pairing' + expected_filename = expected_directory / 'keys.json' + assert keystore.directory_name == expected_directory + assert keystore.filename == expected_filename + + # Save some data + keys = PairingKeys() + ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) + keys.ltk = PairingKeys.Key(ltk) + await keystore.update('foo', keys) + assert expected_filename.exists() + + # Load back + keystore2 = JsonKeyStore(None, None) + foo = await keystore2.get('foo') + assert foo is not None + assert foo.ltk.value == ltk + + # Case 2: namespace, no filename + keystore3 = JsonKeyStore('my:namespace', None) + # safe_name = 'my-namespace' (lower is already 'my:namespace', then replace ':' with '-') + expected_filename3 = expected_directory / 'my-namespace.json' + assert keystore3.filename == expected_filename3 + + # Save some data + await keystore3.update('bar', keys) + assert expected_filename3.exists() + + # Load back + keystore4 = JsonKeyStore('my:namespace', None) + bar = await keystore4.get('bar') + assert bar is not None + assert bar.ltk.value == ltk + + # ----------------------------------------------------------------------------- async def run_tests(): await test_basic() await test_parsing() await test_default_namespace() + await test_no_filename() # -----------------------------------------------------------------------------