Skip to content

Commit 99b8455

Browse files
authored
ENH: Order _generate_notebooks output by datetime (#33)
* ENH: Order `generate_*` output by timestamp * TST: Make `TestGenerateNotebooks` strict on yielded order * ENH: Incorporate memoization into `single_password_crypto_factory` * TST: Add `memoized_single_arg` test
1 parent 495ab66 commit 99b8455

File tree

4 files changed

+174
-143
lines changed

4 files changed

+174
-143
lines changed

pgcontents/crypto.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88
import sys
99
import base64
10+
from functools import wraps
1011

1112
from cryptography.fernet import Fernet
1213
from cryptography.hazmat.backends import default_backend
@@ -221,8 +222,25 @@ def single_password_crypto_factory(password):
221222
The factory here returns a ``FernetEncryption`` that uses a key derived
222223
from ``password`` and salted with the supplied user_id.
223224
"""
225+
@memoize_single_arg
224226
def factory(user_id):
225227
return FernetEncryption(
226228
Fernet(derive_single_fernet_key(password, user_id))
227229
)
228230
return factory
231+
232+
233+
def memoize_single_arg(f):
234+
"""
235+
Decorator memoizing a single-argument function
236+
"""
237+
memo = {}
238+
239+
@wraps(f)
240+
def memoized_f(arg):
241+
try:
242+
return memo[arg]
243+
except KeyError:
244+
result = memo[arg] = f(arg)
245+
return result
246+
return memoized_f

pgcontents/query.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ def generate_files(engine, crypto_factory, min_dt=None, max_dt=None):
553553
"""
554554
Create a generator of decrypted files.
555555
556+
Files are yielded in ascending order of their timestamp.
557+
556558
This function selects all current notebooks (optionally, falling within a
557559
datetime range), decrypts them, and returns a generator yielding dicts,
558560
each containing a decoded notebook and metadata including the user,
@@ -571,12 +573,8 @@ def generate_files(engine, crypto_factory, min_dt=None, max_dt=None):
571573
max_dt : datetime.datetime, optional
572574
Last modified datetime at and after which a file will be excluded.
573575
"""
574-
where_conds = []
575-
if min_dt is not None:
576-
where_conds.append(files.c.created_at >= min_dt)
577-
if max_dt is not None:
578-
where_conds.append(files.c.created_at < max_dt)
579-
return _generate_notebooks(files, engine, where_conds, crypto_factory)
576+
return _generate_notebooks(files, files.c.created_at,
577+
engine, crypto_factory, min_dt, max_dt)
580578

581579

582580
# =======================================
@@ -736,6 +734,8 @@ def generate_checkpoints(engine, crypto_factory, min_dt=None, max_dt=None):
736734
"""
737735
Create a generator of decrypted remote checkpoints.
738736
737+
Checkpoints are yielded in ascending order of their timestamp.
738+
739739
This function selects all notebook checkpoints (optionally, falling within
740740
a datetime range), decrypts them, and returns a generator yielding dicts,
741741
each containing a decoded notebook and metadata including the user,
@@ -754,38 +754,53 @@ def generate_checkpoints(engine, crypto_factory, min_dt=None, max_dt=None):
754754
max_dt : datetime.datetime, optional
755755
Last modified datetime at and after which a file will be excluded.
756756
"""
757-
where_conds = []
758-
if min_dt is not None:
759-
where_conds.append(remote_checkpoints.c.last_modified >= min_dt)
760-
if max_dt is not None:
761-
where_conds.append(remote_checkpoints.c.last_modified < max_dt)
762757
return _generate_notebooks(remote_checkpoints,
763-
engine, where_conds, crypto_factory)
758+
remote_checkpoints.c.last_modified,
759+
engine, crypto_factory, min_dt, max_dt)
764760

765761

766762
# ====================
767763
# Files or Checkpoints
768764
# ====================
769-
def _generate_notebooks(table, engine, where_conds, crypto_factory):
765+
def _generate_notebooks(table, timestamp_column,
766+
engine, crypto_factory, min_dt, max_dt):
770767
"""
771768
See docstrings for `generate_files` and `generate_checkpoints`.
772-
`where_conds` should be a list of SQLAlchemy expressions, which are used as
773-
the conditions for WHERE clauses on the SELECT queries to the database.
769+
770+
Parameters
771+
----------
772+
table : SQLAlchemy.Table
773+
Table to fetch notebooks from, `files` or `remote_checkpoints.
774+
timestamp_column : SQLAlchemy.Column
775+
`table`'s column storing timestamps, `created_at` or `last_modified`.
776+
engine : SQLAlchemy.engine
777+
Engine encapsulating database connections.
778+
crypto_factory : function[str -> Any]
779+
A function from user_id to an object providing the interface required
780+
by PostgresContentsManager.crypto. Results of this will be used for
781+
decryption of the selected notebooks.
782+
min_dt : datetime.datetime, optional
783+
Minimum last modified datetime at which a file will be included.
784+
max_dt : datetime.datetime, optional
785+
Last modified datetime at and after which a file will be excluded.
774786
"""
787+
where_conds = []
788+
if min_dt is not None:
789+
where_conds.append(timestamp_column >= min_dt)
790+
if max_dt is not None:
791+
where_conds.append(timestamp_column < max_dt)
792+
775793
# Query for notebooks satisfying the conditions.
776-
query = select([table]).order_by(table.c.user_id)
794+
query = select([table]).order_by(timestamp_column)
777795
for cond in where_conds:
778796
query = query.where(cond)
779797
result = engine.execute(query)
780798

781799
# Decrypt each notebook and yield the result.
782-
last_user_id = None
783800
for nb_row in result:
784-
# The decrypt function depends on the user, so if the user is the same
785-
# then the decrypt function carries over.
786-
if nb_row['user_id'] != last_user_id:
787-
decrypt_func = crypto_factory(nb_row['user_id']).decrypt
788-
last_user_id = nb_row['user_id']
801+
# The decrypt function depends on the user
802+
user_id = nb_row['user_id']
803+
decrypt_func = crypto_factory(user_id).decrypt
789804

790805
nb_dict = to_dict_with_content(table.c, nb_row, decrypt_func)
791806
if table is files:
@@ -798,7 +813,7 @@ def _generate_notebooks(table, engine, where_conds, crypto_factory):
798813
# here as well.
799814
yield {
800815
'id': nb_dict['id'],
801-
'user_id': nb_dict['user_id'],
816+
'user_id': user_id,
802817
'path': to_api_path(nb_dict['path']),
803818
'last_modified': nb_dict['last_modified'],
804819
'content': reads_base64(nb_dict['content']),

pgcontents/tests/test_encryption.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,83 @@
11
"""
22
Tests for notebook encryption utilities.
33
"""
4+
from unittest import TestCase
5+
46
from cryptography.fernet import Fernet
57

68
from ..crypto import (
79
derive_fallback_fernet_keys,
810
FallbackCrypto,
911
FernetEncryption,
12+
memoize_single_arg,
1013
NoEncryption,
1114
single_password_crypto_factory,
1215
)
1316

1417

15-
def test_fernet_derivation():
16-
pws = [u'currentpassword', u'oldpassword', None]
18+
class TestEncryption(TestCase):
19+
20+
def test_fernet_derivation(self):
21+
pws = [u'currentpassword', u'oldpassword', None]
22+
23+
# This must be Unicode, so we use the `u` prefix to support py2.
24+
user_id = u'4e322fa200fffd0001000001'
25+
26+
current_crypto = single_password_crypto_factory(pws[0])(user_id)
27+
old_crypto = single_password_crypto_factory(pws[1])(user_id)
28+
29+
def make_single_key_crypto(key):
30+
if key is None:
31+
return NoEncryption()
32+
return FernetEncryption(Fernet(key.encode('ascii')))
33+
34+
multi_fernet_crypto = FallbackCrypto(
35+
[make_single_key_crypto(k)
36+
for k in derive_fallback_fernet_keys(pws, user_id)]
37+
)
1738

18-
# This must be Unicode, so we use the `u` prefix to support py2.
19-
user_id = u'4e322fa200fffd0001000001'
39+
data = b'ayy lmao'
2040

21-
current_crypto = single_password_crypto_factory(pws[0])(user_id)
22-
old_crypto = single_password_crypto_factory(pws[1])(user_id)
41+
# Data encrypted with the current key.
42+
encrypted_data_current = current_crypto.encrypt(data)
43+
self.assertNotEqual(encrypted_data_current, data)
44+
self.assertEqual(current_crypto.decrypt(encrypted_data_current), data)
2345

24-
def make_single_key_crypto(key):
25-
if key is None:
26-
return NoEncryption()
27-
return FernetEncryption(Fernet(key.encode('ascii')))
46+
# Data encrypted with the old key.
47+
encrypted_data_old = old_crypto.encrypt(data)
48+
self.assertNotEqual(encrypted_data_current, data)
49+
self.assertEqual(old_crypto.decrypt(encrypted_data_old), data)
2850

29-
multi_fernet_crypto = FallbackCrypto(
30-
[make_single_key_crypto(k)
31-
for k in derive_fallback_fernet_keys(pws, user_id)]
32-
)
51+
# The single fernet with the first key should be able to decrypt the
52+
# multi-fernet's encrypted data.
53+
self.assertEqual(
54+
current_crypto.decrypt(multi_fernet_crypto.encrypt(data)),
55+
data
56+
)
3357

34-
data = b'ayy lmao'
58+
# Multi should be able decrypt anything encrypted with either key.
59+
self.assertEqual(multi_fernet_crypto.decrypt(encrypted_data_current),
60+
data)
61+
self.assertEqual(multi_fernet_crypto.decrypt(encrypted_data_old), data)
3562

36-
# Data encrypted with the current key.
37-
encrypted_data_current = current_crypto.encrypt(data)
38-
assert encrypted_data_current != data
39-
assert current_crypto.decrypt(encrypted_data_current) == data
63+
# Unencrypted data should be returned unchanged.
64+
self.assertEqual(multi_fernet_crypto.decrypt(data), data)
4065

41-
# Data encrypted with the old key.
42-
encrypted_data_old = old_crypto.encrypt(data)
43-
assert encrypted_data_current != data
44-
assert old_crypto.decrypt(encrypted_data_old) == data
66+
def test_memoize_single_arg(self):
67+
full_calls = []
4568

46-
# The single fernet with the first key should be able to decrypt the
47-
# multi-fernet's encrypted data.
69+
@memoize_single_arg
70+
def mock_factory(user_id):
71+
full_calls.append(user_id)
72+
return u'crypto' + user_id
4873

49-
assert current_crypto.decrypt(multi_fernet_crypto.encrypt(data)) == data
74+
calls_to_make = [u'1', u'2', u'3', u'2', u'1']
75+
expected_results = [u'crypto' + user_id for user_id in calls_to_make]
76+
expected_full_calls = [u'1', u'2', u'3']
5077

51-
# Multi should be able decrypt anything encrypted with either key.
52-
assert multi_fernet_crypto.decrypt(encrypted_data_current) == data
53-
assert multi_fernet_crypto.decrypt(encrypted_data_old) == data
78+
results = []
79+
for user_id in calls_to_make:
80+
results.append(mock_factory(user_id))
5481

55-
# Unencrypted data should be returned unchanged.
56-
assert multi_fernet_crypto.decrypt(data) == data
82+
self.assertEqual(results, expected_results)
83+
self.assertEqual(full_calls, expected_full_calls)

0 commit comments

Comments
 (0)