Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions bumble/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import json
import logging
import os
import pathlib
from typing import TYPE_CHECKING, Any

from typing_extensions import Self
Expand Down Expand Up @@ -248,29 +249,26 @@ class without a namespace. With the default namespace, reading from a file will
DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"

def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
def __init__(
self, namespace: str | None = None, filename: str | None = None
) -> None:
self.namespace = namespace or self.DEFAULT_NAMESPACE

if filename is None:
# Use a default for the current user
if filename:
self.filename = pathlib.Path(filename).resolve()
self.directory_name = self.filename.parent
else:
import platformdirs # Deferred import

# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
base_dir = platformdirs.user_data_path(self.APP_NAME, self.APP_AUTHOR)
self.directory_name = base_dir / self.KEYS_DIR

self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
else:
self.filename = filename
self.directory_name = os.path.dirname(os.path.abspath(self.filename))
base_name = self.namespace if namespace else self.DEFAULT_BASE_NAME
safe_name = base_name.lower().replace(':', '-').replace('/', '-')

self.filename = self.directory_name / f"{safe_name}.json"

logger.debug(f'JSON keystore: {self.filename}')
logger.debug('JSON keystore: %s', self.filename)

@classmethod
def from_device(
Expand All @@ -293,7 +291,9 @@ def from_device(

return cls(namespace, filename)

async def load(self):
async def load(
self,
) -> tuple[dict[str, dict[str, dict[str, Any]]], dict[str, dict[str, Any]]]:
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
Expand All @@ -312,17 +312,17 @@ async def load(self):
return next(iter(db.items()))

# Finally, just create an empty key map for the namespace
key_map = {}
key_map: dict[str, dict[str, Any]] = {}
db[self.namespace] = key_map
return (db, key_map)

async def save(self, db):
async def save(self, db: dict[str, dict[str, dict[str, Any]]]) -> None:
# Create the directory if it doesn't exist
if not os.path.exists(self.directory_name):
os.makedirs(self.directory_name, exist_ok=True)
if not self.directory_name.exists():
self.directory_name.mkdir(parents=True, exist_ok=True)

# Save to a temporary file
temp_filename = self.filename + '.tmp'
temp_filename = self.filename.with_name(self.filename.name + ".tmp")
with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4)

Expand All @@ -334,16 +334,16 @@ async def delete(self, name: str) -> None:
del key_map[name]
await self.save(db)

async def update(self, name, keys):
async def update(self, name: str, keys: PairingKeys) -> None:
db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db)

async def get_all(self):
async def get_all(self) -> list[tuple[str, PairingKeys]]:
_, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]

async def delete_all(self):
async def delete_all(self) -> None:
db, key_map = await self.load()
key_map.clear()
await self.save(db)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ authors = [{ name = "Google", email = "bumble-dev@google.com" }]
requires-python = ">=3.10"
dependencies = [
"aiohttp ~= 3.8; platform_system!='Emscripten'",
"appdirs >= 1.4; platform_system!='Emscripten'",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be replaced with platformdirs, because if we don't mark it as a dependency, most users won't have it by default.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have it below at line#31

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, yes, sorry I missed that.

"click >= 8.1.3; platform_system!='Emscripten'",
"cryptography >= 44.0.3; platform_system!='Emscripten' and platform_system!='Android'",
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
Expand Down
45 changes: 45 additions & 0 deletions tests/keystore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import pathlib
import tempfile
from unittest import mock

import pytest

Expand Down Expand Up @@ -179,11 +180,55 @@ async def test_default_namespace(temporary_file):
assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_no_filename(tmp_path):
import platformdirs

with mock.patch.object(platformdirs, 'user_data_path', return_value=tmp_path):
# Case 1: no namespace, no filename
keystore = JsonKeyStore(None, None)
expected_directory = tmp_path / 'Pairing'
expected_filename = expected_directory / 'keys.json'
assert keystore.directory_name == expected_directory
assert keystore.filename == expected_filename

# Save some data
keys = PairingKeys()
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
assert expected_filename.exists()

# Load back
keystore2 = JsonKeyStore(None, None)
foo = await keystore2.get('foo')
assert foo is not None
assert foo.ltk.value == ltk

# Case 2: namespace, no filename
keystore3 = JsonKeyStore('my:namespace', None)
# safe_name = 'my-namespace' (lower is already 'my:namespace', then replace ':' with '-')
expected_filename3 = expected_directory / 'my-namespace.json'
assert keystore3.filename == expected_filename3

# Save some data
await keystore3.update('bar', keys)
assert expected_filename3.exists()

# Load back
keystore4 = JsonKeyStore('my:namespace', None)
bar = await keystore4.get('bar')
assert bar is not None
assert bar.ltk.value == ltk


# -----------------------------------------------------------------------------
async def run_tests():
await test_basic()
await test_parsing()
await test_default_namespace()
await test_no_filename()


# -----------------------------------------------------------------------------
Expand Down
Loading