diff --git a/src/spake2/spake2.py b/src/spake2/spake2.py index 5071b5e..5758d99 100644 --- a/src/spake2/spake2.py +++ b/src/spake2/spake2.py @@ -25,6 +25,8 @@ class WrongGroupError(SPAKEError): pass class ReflectionThwarted(SPAKEError): """Someone tried to reflect our message back to us.""" +class KeyDegradationThwarted(SPAKEError): + """Someone tried to degrade our key material.""" SideA = b"A" SideB = b"B" @@ -113,6 +115,8 @@ def finish(self, inbound_side_and_message): # ) * self.xy_scalar pw_unblinding = self.my_unblinding().scalarmult(-self.pw_scalar) K_elem = inbound_elem.add(pw_unblinding).scalarmult(self.xy_scalar) + if K_elem is g.Zero: + raise KeyDegradationThwarted K_bytes = K_elem.to_bytes() key = self._finalize(K_bytes) return key diff --git a/src/spake2/test/test_spake2.py b/src/spake2/test/test_spake2.py index d0925fb..fc4736a 100644 --- a/src/spake2/test/test_spake2.py +++ b/src/spake2/test/test_spake2.py @@ -58,6 +58,15 @@ def test_reflect(self): reflected = b"B" + m1[1:] self.assertRaises(spake2.ReflectionThwarted, s1.finish, reflected) + def test_keydegradation(self): + pw = b"password" + s1 = SPAKE2_A(pw) + m1 = s1.start() + pw_scalar = s1.params.group.password_to_scalar(pw) + pw_blinding = s1.params.N.scalarmult(pw_scalar) + manipulatedmsg = b"B" + pw_blinding.to_bytes() + self.assertRaises(spake2.KeyDegradationThwarted, s1.finish, manipulatedmsg) + class OtherEntropy(unittest.TestCase): def test_entropy(self):