Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .lib import E_ECDH, decrypt, encrypt
from hmac import compare_digest
from werkzeug.exceptions import SecurityError
from wallycore import ec_sig_verify, sha256, hmac_sha256, EC_FLAG_ECDSA, \
ec_public_key_bip341_tweak

Expand Down Expand Up @@ -46,7 +47,8 @@ def decrypt_response_payload(self, encrypted, hmac):

# Verify hmac received
hmac_calculated = hmac_sha256(self.response_hmac_key, encrypted)
assert compare_digest(hmac, hmac_calculated)
if not compare_digest(hmac, hmac_calculated):
raise SecurityError()

# Return decrypted data
return decrypt(self.response_encryption_key, encrypted)
Expand Down
29 changes: 23 additions & 6 deletions flaskserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import json
import base64
import time
import collections
from flask import Flask, request, jsonify
from .server import PINServerECDH, PINServerECDHv1, PINServerECDHv2
from .pindb import PINDb
from werkzeug.exceptions import BadRequest
from wallycore import AES_KEY_LEN_256, AES_BLOCK_LEN, HMAC_SHA256_LEN
from dotenv import load_dotenv

Expand Down Expand Up @@ -54,12 +56,17 @@ def start_handshake_route():

# NOTE: explicit fields in protocol v1
def _complete_server_call_v1(pin_func, udata):
if udata.keys() != {'cke', 'ske', 'encrypted_data', 'hmac_encrypted_data'}:
raise BadRequest()

ske = udata['ske']
assert 'replay_counter' not in udata

# Get associated session (ensuring not stale)
_cleanup_expired_sessions()
e_ecdh_server = sessions[ske]

e_ecdh_server = sessions.get(ske)
if not e_ecdh_server:
raise BadRequest()

# get/set pin and get response data
encrypted_key, hmac = e_ecdh_server.call_with_payload(
Expand All @@ -82,9 +89,14 @@ def _complete_server_call_v1(pin_func, udata):

# NOTE: v2 is one concatentated field, base64-encoded
def _complete_server_call_v2(pin_func, udata):
assert 'data' in udata
data = base64.b64decode(udata['data'].encode())
assert len(data) > 37 # cke and counter and some encrypted payload
if udata.keys() != {'data'}:
raise BadRequest()

try:
data = base64.b64decode(udata['data'].encode())
assert len(data) > 37 # cke and counter and some encrypted payload
except Exception as e:
raise BadRequest(e)

cke = data[:33]
replay_counter = data[33:37]
Expand All @@ -104,7 +116,12 @@ def _complete_server_call_v2(pin_func, udata):
def _complete_server_call(pin_func):
try:
# Get request data
udata = json.loads(request.data)
try:
udata = json.loads(request.data)
assert isinstance(udata, collections.abc.Mapping)
except Exception as e:
raise BadRequest(e)

if 'data' in udata:
return _complete_server_call_v2(pin_func, udata)
return _complete_server_call_v1(pin_func, udata)
Expand Down
15 changes: 10 additions & 5 deletions pindb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .lib import decrypt, encrypt
from pathlib import Path
from hmac import compare_digest
from werkzeug.exceptions import BadRequest
from wallycore import ec_sig_to_public_key, sha256, hmac_sha256, \
AES_KEY_LEN_256, EC_SIGNATURE_RECOVERABLE_LEN, SHA256_LEN
from dotenv import load_dotenv
Expand Down Expand Up @@ -102,17 +103,19 @@ class PINDb(object):

@classmethod
def _extract_fields(cls, cke, data, replay_counter=None):
assert len(data) > SHA256_LEN
if len(data) <= SHA256_LEN:
raise BadRequest()

# secret + (optional)entropy + sig
pin_secret = data[:SHA256_LEN]
if len(data) == SHA256_LEN + SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN:
entropy = data[SHA256_LEN: SHA256_LEN + SHA256_LEN]
sig = data[SHA256_LEN + SHA256_LEN:]
else:
assert len(data) == SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN
elif len(data) == SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN:
entropy = b''
sig = data[SHA256_LEN:]
else:
raise BadRequest()

# The client_public_key also signs over any replay counter
if replay_counter is not None:
Expand All @@ -133,7 +136,8 @@ def _check_v2_anti_replay(cls, server_counter, client_counter):
if server_counter is not None and client_counter is not None:
server_counter = int.from_bytes(server_counter, byteorder='little', signed=False)
client_counter = int.from_bytes(client_counter, byteorder='little', signed=False)
assert client_counter > server_counter
if client_counter <= server_counter:
raise BadRequest()

@classmethod
def _save_pin_fields(cls, pin_pubkey_hash, hash_pin_secret, aes_key,
Expand Down Expand Up @@ -269,7 +273,8 @@ def set_pin(cls, cke, payload, aes_pin_data_key, replay_counter=None):
# NOTE: we require client-passed entropy at this point
pin_secret, entropy, pin_pubkey = cls._extract_fields(cke, payload, replay_counter)
pin_pubkey_hash = bytes(sha256(pin_pubkey))
assert entropy
if not entropy:
raise BadRequest()

# Load any existing replay counter for the pubkey
# and if found check the anti-replay counter
Expand Down
17 changes: 13 additions & 4 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from hmac import compare_digest
import os
from .lib import decrypt, encrypt, E_ECDH
from werkzeug.exceptions import BadRequest, SecurityError
from wallycore import ec_private_key_verify, ec_sig_from_bytes, sha256, \
hmac_sha256, EC_FLAG_ECDSA, ec_private_key_bip341_tweak, ec_public_key_from_private_key

Expand Down Expand Up @@ -72,7 +73,8 @@ def get_signed_public_key(self):
def decrypt_request_payload(self, cke, encrypted, hmac):
# Verify hmac received
hmac_calculated = hmac_sha256(self.request_hmac_key, cke + encrypted)
assert compare_digest(hmac, hmac_calculated)
if not compare_digest(hmac, hmac_calculated):
raise SecurityError()

# Return decrypted data
return decrypt(self.request_encryption_key, encrypted)
Expand All @@ -86,8 +88,11 @@ def encrypt_response_payload(self, payload):
# Calls passed function with unwrapped payload, and wraps response before
# returning. Separates payload handler func from wrapper encryption.
def call_with_payload(self, cke, encrypted, hmac, func):
self.generate_shared_secrets(cke)
payload = self.decrypt_request_payload(cke, encrypted, hmac)
try:
self.generate_shared_secrets(cke)
payload = self.decrypt_request_payload(cke, encrypted, hmac)
except Exception as e:
raise BadRequest(e)

# Call the passed function with the decrypted payload
response = func(cke, payload, self._get_aes_pin_data_key())
Expand Down Expand Up @@ -126,6 +131,10 @@ def encrypt_response_payload(self, cke, payload):
# Calls passed function with unwrapped payload, and wraps response before
# returning. Separates payload handler func from wrapper encryption.
def call_with_payload(self, cke, encrypted, func):
payload = self.decrypt_request_payload(cke, encrypted)
try:
payload = self.decrypt_request_payload(cke, encrypted)
except Exception as e:
raise BadRequest(e)

response = func(cke, payload, self._get_aes_pin_data_key(), self.replay_counter)
return self.encrypt_response_payload(cke, response)
11 changes: 6 additions & 5 deletions test/test_ecdh_v1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import os
from werkzeug.exceptions import BadRequest

from ..client import PINClientECDHv1
from ..server import PINServerECDHv1
Expand Down Expand Up @@ -162,14 +163,14 @@ def test_bad_request_cke_throws(self):
server.decrypt_request_payload(cke, encrypted, hmac) # no error

server.generate_shared_secrets(bad_cke)
with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
server.decrypt_request_payload(bad_cke, encrypted, hmac) # error

# Ensure call_with_payload() throws before it calls the handler fn
def _func(client_key, payload, aes_pin_data_key):
self.fail('should-never-get-here')

with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
server.call_with_payload(bad_cke, encrypted, hmac, _func)

def test_bad_request_hmac_throws(self):
Expand All @@ -189,14 +190,14 @@ def test_bad_request_hmac_throws(self):
# Ensure decrypt_request() throws
server.generate_shared_secrets(cke)
server.decrypt_request_payload(cke, encrypted, hmac) # no error
with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
server.decrypt_request_payload(cke, encrypted, bad_hmac) # error

# Ensure call_with_payload() throws before it calls the handler fn
def _func(client_key, payload, aes_pin_data_key):
self.fail('should-never-get-here')

with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
server.call_with_payload(cke, encrypted, bad_hmac, _func)

def test_bad_response_hmac_throws(self):
Expand All @@ -221,7 +222,7 @@ def _func(client_key, payload, pin_data_aes_key):
self.assertNotEqual(hmac, bad_hmac)

client.decrypt_response_payload(encrypted, hmac) # No error
with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
client.decrypt_response_payload(encrypted, bad_hmac) # error


Expand Down
7 changes: 4 additions & 3 deletions test/test_ecdh_v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import os
from werkzeug.exceptions import BadRequest

from ..client import PINClientECDHv2
from ..server import PINServerECDHv2
Expand Down Expand Up @@ -134,7 +135,7 @@ def test_bad_request_cke_throws(self):
def _func(client_key, payload, aes_pin_data_key):
self.fail('should-never-get-here')

with self.assertRaises(ValueError) as cm:
with self.assertRaises(BadRequest) as cm:
server.call_with_payload(bad_cke, encrypted, _func)

def test_bad_request_counter_throws(self):
Expand All @@ -159,7 +160,7 @@ def test_bad_request_counter_throws(self):
def _func(client_key, payload, aes_pin_data_key):
self.fail('should-never-get-here')

with self.assertRaises(ValueError) as cm:
with self.assertRaises(BadRequest) as cm:
server.call_with_payload(cke, encrypted, _func)

def test_bad_request_hmac_throws(self):
Expand All @@ -186,7 +187,7 @@ def test_bad_request_hmac_throws(self):
def _func(client_key, payload, aes_pin_data_key, replay_counter):
self.fail('should-never-get-here')

with self.assertRaises(ValueError) as cm:
with self.assertRaises(BadRequest) as cm:
server.call_with_payload(cke, bad_encrypted, _func)

def test_bad_response_hmac_throws(self):
Expand Down
11 changes: 6 additions & 5 deletions test/test_pindb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
from hmac import compare_digest
from werkzeug.exceptions import BadRequest

from ..pindb import PINDb
from ..lib import E_ECDH
Expand Down Expand Up @@ -428,12 +429,12 @@ def test_bad_v2_counter_breaks_set_pin(self):
# Set-pin must also respect the counter
v2_replay_counter = b'\x05\x00\x00\x00'
payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter)
with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter)

v2_replay_counter = b'\x00\x00\x00\x00'
payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter)
with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter)

# Key still present and readable as set failed
Expand Down Expand Up @@ -513,7 +514,7 @@ def _test_client_entropy_impl(self, use_v2_protocol):
payload = self.make_payload(sig_priv, cke, secret, b'', v2_replay_counter)

# Verify trying to set-pin without entropy fails
with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter)

# Get-pin should be fine without entropy
Expand All @@ -525,10 +526,10 @@ def _test_client_entropy_impl(self, use_v2_protocol):
for entropy in [self.new_entropy()[:-1], self.new_entropy() + b'\xab']:
payload = self.make_payload(sig_priv, cke, secret, entropy, v2_replay_counter)

with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter)

with self.assertRaises(AssertionError) as cm:
with self.assertRaises(BadRequest) as cm:
PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter)

def test_client_entropy(self):
Expand Down
Loading
Loading