From cabdc8a1815a18d042b2780decaf4dfe089caaac Mon Sep 17 00:00:00 2001 From: "Jamie C. Driver" Date: Wed, 20 Nov 2024 10:30:13 +0000 Subject: [PATCH] consistency: return http error 400 for bad data in request payload --- client.py | 4 +++- flaskserver.py | 29 +++++++++++++++++++++++------ pindb.py | 15 ++++++++++----- server.py | 17 +++++++++++++---- test/test_ecdh_v1.py | 11 ++++++----- test/test_ecdh_v2.py | 7 ++++--- test/test_pindb.py | 11 ++++++----- test/test_pinserver.py | 24 ++++++++++++------------ 8 files changed, 77 insertions(+), 41 deletions(-) diff --git a/client.py b/client.py index 421313a..516a78d 100644 --- a/client.py +++ b/client.py @@ -1,5 +1,6 @@ from .lib import E_ECDH, decrypt, encrypt from hmac import compare_digest +from werkzeug.exceptions import SecurityError from wallycore import ec_sig_verify, sha256, hmac_sha256, EC_FLAG_ECDSA, \ ec_public_key_bip341_tweak @@ -46,7 +47,8 @@ def decrypt_response_payload(self, encrypted, hmac): # Verify hmac received hmac_calculated = hmac_sha256(self.response_hmac_key, encrypted) - assert compare_digest(hmac, hmac_calculated) + if not compare_digest(hmac, hmac_calculated): + raise SecurityError() # Return decrypted data return decrypt(self.response_encryption_key, encrypted) diff --git a/flaskserver.py b/flaskserver.py index 94eafd5..a5cf7d1 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -2,9 +2,11 @@ import json import base64 import time +import collections from flask import Flask, request, jsonify from .server import PINServerECDH, PINServerECDHv1, PINServerECDHv2 from .pindb import PINDb +from werkzeug.exceptions import BadRequest from wallycore import AES_KEY_LEN_256, AES_BLOCK_LEN, HMAC_SHA256_LEN from dotenv import load_dotenv @@ -54,12 +56,17 @@ def start_handshake_route(): # NOTE: explicit fields in protocol v1 def _complete_server_call_v1(pin_func, udata): + if udata.keys() != {'cke', 'ske', 'encrypted_data', 'hmac_encrypted_data'}: + raise BadRequest() + ske = udata['ske'] - assert 'replay_counter' not in udata # Get associated session (ensuring not stale) _cleanup_expired_sessions() - e_ecdh_server = sessions[ske] + + e_ecdh_server = sessions.get(ske) + if not e_ecdh_server: + raise BadRequest() # get/set pin and get response data encrypted_key, hmac = e_ecdh_server.call_with_payload( @@ -82,9 +89,14 @@ def _complete_server_call_v1(pin_func, udata): # NOTE: v2 is one concatentated field, base64-encoded def _complete_server_call_v2(pin_func, udata): - assert 'data' in udata - data = base64.b64decode(udata['data'].encode()) - assert len(data) > 37 # cke and counter and some encrypted payload + if udata.keys() != {'data'}: + raise BadRequest() + + try: + data = base64.b64decode(udata['data'].encode()) + assert len(data) > 37 # cke and counter and some encrypted payload + except Exception as e: + raise BadRequest(e) cke = data[:33] replay_counter = data[33:37] @@ -104,7 +116,12 @@ def _complete_server_call_v2(pin_func, udata): def _complete_server_call(pin_func): try: # Get request data - udata = json.loads(request.data) + try: + udata = json.loads(request.data) + assert isinstance(udata, collections.abc.Mapping) + except Exception as e: + raise BadRequest(e) + if 'data' in udata: return _complete_server_call_v2(pin_func, udata) return _complete_server_call_v1(pin_func, udata) diff --git a/pindb.py b/pindb.py index e8c8bd7..4dce38c 100644 --- a/pindb.py +++ b/pindb.py @@ -5,6 +5,7 @@ from .lib import decrypt, encrypt from pathlib import Path from hmac import compare_digest +from werkzeug.exceptions import BadRequest from wallycore import ec_sig_to_public_key, sha256, hmac_sha256, \ AES_KEY_LEN_256, EC_SIGNATURE_RECOVERABLE_LEN, SHA256_LEN from dotenv import load_dotenv @@ -102,17 +103,19 @@ class PINDb(object): @classmethod def _extract_fields(cls, cke, data, replay_counter=None): - assert len(data) > SHA256_LEN + if len(data) <= SHA256_LEN: + raise BadRequest() # secret + (optional)entropy + sig pin_secret = data[:SHA256_LEN] if len(data) == SHA256_LEN + SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN: entropy = data[SHA256_LEN: SHA256_LEN + SHA256_LEN] sig = data[SHA256_LEN + SHA256_LEN:] - else: - assert len(data) == SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN + elif len(data) == SHA256_LEN + EC_SIGNATURE_RECOVERABLE_LEN: entropy = b'' sig = data[SHA256_LEN:] + else: + raise BadRequest() # The client_public_key also signs over any replay counter if replay_counter is not None: @@ -133,7 +136,8 @@ def _check_v2_anti_replay(cls, server_counter, client_counter): if server_counter is not None and client_counter is not None: server_counter = int.from_bytes(server_counter, byteorder='little', signed=False) client_counter = int.from_bytes(client_counter, byteorder='little', signed=False) - assert client_counter > server_counter + if client_counter <= server_counter: + raise BadRequest() @classmethod def _save_pin_fields(cls, pin_pubkey_hash, hash_pin_secret, aes_key, @@ -269,7 +273,8 @@ def set_pin(cls, cke, payload, aes_pin_data_key, replay_counter=None): # NOTE: we require client-passed entropy at this point pin_secret, entropy, pin_pubkey = cls._extract_fields(cke, payload, replay_counter) pin_pubkey_hash = bytes(sha256(pin_pubkey)) - assert entropy + if not entropy: + raise BadRequest() # Load any existing replay counter for the pubkey # and if found check the anti-replay counter diff --git a/server.py b/server.py index bb1a0a6..2e36354 100644 --- a/server.py +++ b/server.py @@ -2,6 +2,7 @@ from hmac import compare_digest import os from .lib import decrypt, encrypt, E_ECDH +from werkzeug.exceptions import BadRequest, SecurityError from wallycore import ec_private_key_verify, ec_sig_from_bytes, sha256, \ hmac_sha256, EC_FLAG_ECDSA, ec_private_key_bip341_tweak, ec_public_key_from_private_key @@ -72,7 +73,8 @@ def get_signed_public_key(self): def decrypt_request_payload(self, cke, encrypted, hmac): # Verify hmac received hmac_calculated = hmac_sha256(self.request_hmac_key, cke + encrypted) - assert compare_digest(hmac, hmac_calculated) + if not compare_digest(hmac, hmac_calculated): + raise SecurityError() # Return decrypted data return decrypt(self.request_encryption_key, encrypted) @@ -86,8 +88,11 @@ def encrypt_response_payload(self, payload): # Calls passed function with unwrapped payload, and wraps response before # returning. Separates payload handler func from wrapper encryption. def call_with_payload(self, cke, encrypted, hmac, func): - self.generate_shared_secrets(cke) - payload = self.decrypt_request_payload(cke, encrypted, hmac) + try: + self.generate_shared_secrets(cke) + payload = self.decrypt_request_payload(cke, encrypted, hmac) + except Exception as e: + raise BadRequest(e) # Call the passed function with the decrypted payload response = func(cke, payload, self._get_aes_pin_data_key()) @@ -126,6 +131,10 @@ def encrypt_response_payload(self, cke, payload): # Calls passed function with unwrapped payload, and wraps response before # returning. Separates payload handler func from wrapper encryption. def call_with_payload(self, cke, encrypted, func): - payload = self.decrypt_request_payload(cke, encrypted) + try: + payload = self.decrypt_request_payload(cke, encrypted) + except Exception as e: + raise BadRequest(e) + response = func(cke, payload, self._get_aes_pin_data_key(), self.replay_counter) return self.encrypt_response_payload(cke, response) diff --git a/test/test_ecdh_v1.py b/test/test_ecdh_v1.py index 31aa8c1..28b448c 100644 --- a/test/test_ecdh_v1.py +++ b/test/test_ecdh_v1.py @@ -1,6 +1,7 @@ import unittest import os +from werkzeug.exceptions import BadRequest from ..client import PINClientECDHv1 from ..server import PINServerECDHv1 @@ -162,14 +163,14 @@ def test_bad_request_cke_throws(self): server.decrypt_request_payload(cke, encrypted, hmac) # no error server.generate_shared_secrets(bad_cke) - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: server.decrypt_request_payload(bad_cke, encrypted, hmac) # error # Ensure call_with_payload() throws before it calls the handler fn def _func(client_key, payload, aes_pin_data_key): self.fail('should-never-get-here') - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: server.call_with_payload(bad_cke, encrypted, hmac, _func) def test_bad_request_hmac_throws(self): @@ -189,14 +190,14 @@ def test_bad_request_hmac_throws(self): # Ensure decrypt_request() throws server.generate_shared_secrets(cke) server.decrypt_request_payload(cke, encrypted, hmac) # no error - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: server.decrypt_request_payload(cke, encrypted, bad_hmac) # error # Ensure call_with_payload() throws before it calls the handler fn def _func(client_key, payload, aes_pin_data_key): self.fail('should-never-get-here') - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: server.call_with_payload(cke, encrypted, bad_hmac, _func) def test_bad_response_hmac_throws(self): @@ -221,7 +222,7 @@ def _func(client_key, payload, pin_data_aes_key): self.assertNotEqual(hmac, bad_hmac) client.decrypt_response_payload(encrypted, hmac) # No error - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: client.decrypt_response_payload(encrypted, bad_hmac) # error diff --git a/test/test_ecdh_v2.py b/test/test_ecdh_v2.py index bf39a88..cb6be28 100644 --- a/test/test_ecdh_v2.py +++ b/test/test_ecdh_v2.py @@ -1,6 +1,7 @@ import unittest import os +from werkzeug.exceptions import BadRequest from ..client import PINClientECDHv2 from ..server import PINServerECDHv2 @@ -134,7 +135,7 @@ def test_bad_request_cke_throws(self): def _func(client_key, payload, aes_pin_data_key): self.fail('should-never-get-here') - with self.assertRaises(ValueError) as cm: + with self.assertRaises(BadRequest) as cm: server.call_with_payload(bad_cke, encrypted, _func) def test_bad_request_counter_throws(self): @@ -159,7 +160,7 @@ def test_bad_request_counter_throws(self): def _func(client_key, payload, aes_pin_data_key): self.fail('should-never-get-here') - with self.assertRaises(ValueError) as cm: + with self.assertRaises(BadRequest) as cm: server.call_with_payload(cke, encrypted, _func) def test_bad_request_hmac_throws(self): @@ -186,7 +187,7 @@ def test_bad_request_hmac_throws(self): def _func(client_key, payload, aes_pin_data_key, replay_counter): self.fail('should-never-get-here') - with self.assertRaises(ValueError) as cm: + with self.assertRaises(BadRequest) as cm: server.call_with_payload(cke, bad_encrypted, _func) def test_bad_response_hmac_throws(self): diff --git a/test/test_pindb.py b/test/test_pindb.py index 2c6793b..e2f68d8 100644 --- a/test/test_pindb.py +++ b/test/test_pindb.py @@ -2,6 +2,7 @@ import os from hmac import compare_digest +from werkzeug.exceptions import BadRequest from ..pindb import PINDb from ..lib import E_ECDH @@ -428,12 +429,12 @@ def test_bad_v2_counter_breaks_set_pin(self): # Set-pin must also respect the counter v2_replay_counter = b'\x05\x00\x00\x00' payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) v2_replay_counter = b'\x00\x00\x00\x00' payload = self.make_payload(privkey, cke, pin_secret, entropy, v2_replay_counter) - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: aeskey_s = PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) # Key still present and readable as set failed @@ -513,7 +514,7 @@ def _test_client_entropy_impl(self, use_v2_protocol): payload = self.make_payload(sig_priv, cke, secret, b'', v2_replay_counter) # Verify trying to set-pin without entropy fails - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) # Get-pin should be fine without entropy @@ -525,10 +526,10 @@ def _test_client_entropy_impl(self, use_v2_protocol): for entropy in [self.new_entropy()[:-1], self.new_entropy() + b'\xab']: payload = self.make_payload(sig_priv, cke, secret, entropy, v2_replay_counter) - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: PINDb.set_pin(cke, payload, pin_aes_key, v2_replay_counter) - with self.assertRaises(AssertionError) as cm: + with self.assertRaises(BadRequest) as cm: PINDb.get_aes_key(cke, payload, pin_aes_key, v2_replay_counter) def test_client_entropy(self): diff --git a/test/test_pinserver.py b/test/test_pinserver.py index f3dd59b..cc62abf 100644 --- a/test/test_pinserver.py +++ b/test/test_pinserver.py @@ -405,11 +405,11 @@ def test_rejects_bad_payload_not_json(self): with self.assertRaises(ValueError) as cm: self.post('set_pin', urldata) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) with self.assertRaises(ValueError) as cm: self.post('get_pin', urldata) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) def _test_rejects_on_bad_json_impl(self, use_v2_protocol): # Make ourselves a static key pair for this logical client @@ -467,7 +467,7 @@ def _fn(d): self.make_server_call(priv_key, endpoint, pin_secret, self.new_entropy(), use_v2_protocol, mangler) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) def test_rejects_on_bad_json(self): for use_v2_protocol in [False, True]: @@ -483,7 +483,7 @@ def _test_client_entropy_impl(self, use_v2_protocol): with self.assertRaises(ValueError) as cm: self.set_pin(priv_key, pin_secret, b'', use_v2_protocol) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) # Set pin with client entropy - fine aeskey_s = self.set_pin(priv_key, pin_secret, self.new_entropy(), use_v2_protocol=False) @@ -494,7 +494,7 @@ def _test_client_entropy_impl(self, use_v2_protocol): aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) def test_client_entropy(self): for use_v2_protocol in [False, True]: @@ -513,14 +513,14 @@ def test_delayed_interaction_v1(self): aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol=False) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # If we delay in the server interaction it will fail with a 500 error + # If we delay in the server interaction it will fail with a 400 error client = self.new_client_v1() time.sleep(SESSION_LIFETIME + 1) # Sufficiently long delay with self.assertRaises(ValueError) as cm: self.server_call_v1(priv_key, client, 'get_pin', pin_secret, b'') - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) def test_cannot_reuse_client_session_v1(self): # Make ourselves a static key pair for this logical client @@ -538,12 +538,12 @@ def test_cannot_reuse_client_session_v1(self): self.new_entropy()) self.assertTrue(compare_digest(aeskey_g, aeskey_s)) - # Trying to reuse the session should fail with a 500 error + # Trying to reuse the session should fail with a 400 error # because the server has closed that ephemeral encryption session with self.assertRaises(ValueError) as cm: self.server_call_v1(priv_key, client, 'get_pin', pin_secret, b'') - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) # Not great, but we could reuse the client if we re-initiate handshake # (But that would use same cke which is not ideal/recommended.) @@ -575,7 +575,7 @@ def test_cannot_reuse_client_session_v2(self): with self.assertRaises(ValueError) as cm: aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), self.new_entropy()) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) def test_set_pin_counter_v2(self): # Make ourselves a static key pair for this logical client @@ -596,14 +596,14 @@ def test_set_pin_counter_v2(self): with self.assertRaises(ValueError) as cm: aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), self.new_entropy()) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) # Trying to set-pin with zero counter should fail client = self.new_client_v2(True) with self.assertRaises(ValueError) as cm: aeskey_g = self.server_call_v2(priv_key, client, 'set_pin', self.new_pin_secret(), self.new_entropy()) - self.assertEqual('500', str(cm.exception.args[0])) + self.assertEqual('400', str(cm.exception.args[0])) # Existing saved PIN undamaged as set attempt failed aeskey_g = self.get_pin(priv_key, pin_secret, b'', use_v2_protocol=True)