From 7a84bec13d27a89306b7b52cdf1cfbca4df2df8a Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 3 Nov 2024 20:31:59 +0700 Subject: [PATCH 01/13] Implement bulletproofs IPA and range proof --- Cargo.lock | 35 ++- Cargo.toml | 3 +- examples/example_range_proof.py | 18 ++ examples/example_rsa.py | 42 --- python/zksnake/bulletproofs/__init__.py | 0 python/zksnake/bulletproofs/ipa.py | 204 +++++++++++++++ python/zksnake/bulletproofs/range_proof.py | 286 +++++++++++++++++++++ python/zksnake/ecc.py | 11 +- python/zksnake/groth16/prover.py | 13 +- python/zksnake/{ => groth16}/qap.py | 24 +- python/zksnake/groth16/setup.py | 10 +- python/zksnake/parser.py | 18 +- python/zksnake/r1cs.py | 62 +++-- python/zksnake/transcript.py | 61 +++++ python/zksnake/utils.py | 5 +- src/bls12_381/curve.rs | 112 +++++--- src/bn254/curve.rs | 97 ++++--- tests/test_bulletproofs.py | 58 +++++ tests/test_r1cs_qap.py | 19 +- 19 files changed, 907 insertions(+), 171 deletions(-) create mode 100644 examples/example_range_proof.py delete mode 100644 examples/example_rsa.py create mode 100644 python/zksnake/bulletproofs/__init__.py create mode 100644 python/zksnake/bulletproofs/ipa.py create mode 100644 python/zksnake/bulletproofs/range_proof.py rename python/zksnake/{ => groth16}/qap.py (78%) create mode 100644 python/zksnake/transcript.py create mode 100644 tests/test_bulletproofs.py diff --git a/Cargo.lock b/Cargo.lock index f6cb124..466535c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -168,12 +168,30 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cpufeatures" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +dependencies = [ + "libc", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -226,6 +244,7 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ + "block-buffer", "crypto-common", ] @@ -277,9 +296,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.154" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "lock_api" @@ -553,6 +572,17 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "smallvec" version = "1.13.2" @@ -731,4 +761,5 @@ dependencies = [ "pyo3", "rayon", "serde", + "sha2", ] diff --git a/Cargo.toml b/Cargo.toml index 6e421a2..e60e85b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ rayon = "1.10.0" serde = {version="1.0.200", features = ["derive"]} ark-serialize = { version = "0.4", features = ["derive"] } ark-bls12-381 = "0.4.0" +sha2 = "0.10.8" [features] -parallel = ["ark-ff/parallel", "ark-poly/parallel", "ark-ec/parallel", "ark-std/parallel"] \ No newline at end of file +parallel = ["ark-ff/parallel", "ark-poly/parallel", "ark-ec/parallel", "ark-std/parallel"] diff --git a/examples/example_range_proof.py b/examples/example_range_proof.py new file mode 100644 index 0000000..2f7e6a9 --- /dev/null +++ b/examples/example_range_proof.py @@ -0,0 +1,18 @@ +""" +Prove that v is in range of [0, 2^32-1] without revealing the value of v itself +using Inner Product Argument (Bulletproofs) +""" +from zksnake.bulletproofs.range_proof import Prover, Verifier + +bitsize = 32 +prover = Prover(bitsize, 'BN254') + +# secret value v +value = 133337 + +proof, commitment = prover.prove(value) +print("Proof:", proof.to_bytes().hex()) + +verifier = Verifier(bitsize, 'BN254') +assert verifier.verify(proof, commitment) +print("Proof is valid!") diff --git a/examples/example_rsa.py b/examples/example_rsa.py deleted file mode 100644 index f2dc880..0000000 --- a/examples/example_rsa.py +++ /dev/null @@ -1,42 +0,0 @@ -from zksnake.symbolic import Symbol -from zksnake.r1cs import ConstraintSystem - -from zksnake.groth16 import Setup, Prover, Verifier - -p = Symbol("p") -q = Symbol("q") -v0 = Symbol("v0") -v1 = Symbol("v1") -n = Symbol("n") - -# prove that we know p and q such that n == p*q -cs = ConstraintSystem(["p", "q"], ["n"]) -cs.add_constraint(v0 == 1 / (p - 1)) # make sure p != 1 -cs.add_constraint(v1 == 1 / (q - 1)) # make sure q != 1 -cs.add_constraint(n == p * q) - -cs.set_public(n) # value of n is public knowledge - -qap = cs.compile() - -pval = 64135289477071580278790190170577389084825014742943447208116859632024532344630238623598752668347708737661925585694639798853367 -qval = 33372027594978156556226010605355114227940760344767554666784520987023841729210037080257448673296881877565718986258036932062711 -nval = 2140324650240744961264423072839333563008614715144755017797754920881418023447140136643345519095804679610992851872470914587687396261921557363047454770520805119056493106687691590019759405693457452230589325976697471681738069364894699871578494975937497937 - -assert pval * qval == nval - -public_witness, private_witness = cs.solve({"p": pval, "q": qval}, {"n": nval}) - -setup = Setup(qap) - -pkey, vkey = setup.generate() - -prover = Prover(qap, pkey) -verifier = Verifier(vkey) - -proof = prover.prove(public_witness, private_witness) - -print("Proof:", proof.to_hex()) - -assert verifier.verify(proof, public_witness) -print("Proof is valid!") diff --git a/python/zksnake/bulletproofs/__init__.py b/python/zksnake/bulletproofs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/zksnake/bulletproofs/ipa.py b/python/zksnake/bulletproofs/ipa.py new file mode 100644 index 0000000..4bb0afe --- /dev/null +++ b/python/zksnake/bulletproofs/ipa.py @@ -0,0 +1,204 @@ +from ..utils import next_power_of_two, split_list +from ..transcript import FiatShamirTranscript, hash_to_curve, hash_to_scalar +from ..ecc import CurvePointSize, EllipticCurve + + +class InnerProductProof: + + def __init__(self, a: int, b: int, L: list, R: list): + self.a = a + self.b = b + self.L = L + self.R = R + + def to_bytes(self) -> bytes: + s = b"" + for _, (L, R) in enumerate(zip(self.L, self.R)): + s += bytes(L.to_bytes()) + s += bytes(R.to_bytes()) + + s += self.a.to_bytes(32, 'little') + s += self.b.to_bytes(32, 'little') + + return bytes(s) + + @classmethod + def from_bytes(cls, s: bytes, crv="BN254"): + + E = EllipticCurve(crv) + n = CurvePointSize[crv].value // 2 + + assert len(s) % n == 0, "Invalid proof length" + + Ls = [] + Rs = [] + + field_s = split_list(s[-64:], 32) + s = split_list(s[:-64], n) + + for i in range(0, len(s), 2): + Ls.append(E.from_hex(s[i].hex())) + Rs.append(E.from_hex(s[i+1].hex())) + + a = int.from_bytes(field_s[0], 'little') + b = int.from_bytes(field_s[1], 'little') + + return InnerProductProof(a, b, Ls, Rs) + + +class Prover: + + def __init__(self, size, curve, transcript: FiatShamirTranscript = None, seed=b'InnerProductProof', Q=None): + self.n = next_power_of_two(size) + self.E = EllipticCurve(curve) + self.G = hash_to_curve(seed, b'G', curve, self.n) + self.H = hash_to_curve(seed, b'H', curve, self.n) + self.Q = Q or hash_to_curve(seed, b'Q', curve, 1) + + self.transcript = transcript or FiatShamirTranscript( + self.n.to_bytes(32, 'big')) + + def __inner_product(self, a, b): + return sum(a * b for a, b in zip(a, b)) % self.E.order + + def __split_half(self, data: list): + if len(data) > 2: + mid_index = len(data) // 2 + return data[:mid_index], data[mid_index:] + elif len(data) == 2: + return [data[0]], [data[1]] + else: + return [data[0]], [] + + def prove(self, a: list, b: list): + + self.transcript.reset() + + # pad a and b to the size + a = a + [0 for _ in range(self.n - len(a))] + b = b + [0 for _ in range(self.n - len(b))] + + for g in self.G: + self.transcript.append(g.to_bytes()) + for h in self.H: + self.transcript.append(h.to_bytes()) + + ab = self.__inner_product(a, b) + + # vector commitment of Cp = + + * Q + Cp = self.E.multiexp(self.G, a) + \ + self.E.multiexp(self.H, b) + ab * self.Q + + L_list = [] + R_list = [] + u_list = [] + + n = self.n + G = self.G + H = self.H + + while n != 1: + n //= 2 + + a_low, a_hi = self.__split_half(a) + b_low, b_hi = self.__split_half(b) + G_low, G_hi = self.__split_half(G) + H_low, H_hi = self.__split_half(H) + + L = self.E.multiexp(G_hi, a_low) + \ + self.E.multiexp(H_low, b_hi) + \ + self.__inner_product(a_low, b_hi) * self.Q + R = self.E.multiexp(G_low, a_hi) + \ + self.E.multiexp(H_hi, b_low) + \ + self.__inner_product(a_hi, b_low) * self.Q + + L_list.append(L) + R_list.append(R) + + self.transcript.append(L.to_bytes()) + self.transcript.append(R.to_bytes()) + + u = hash_to_scalar( + self.transcript.get_challenge(), b'u', self.E.order) + u_inv = pow(u, -1, self.E.order) + u_list.append(u) + + for i in range(n): + a_low[i] = (a_low[i] * u + a_hi[i] * u_inv) % self.E.order + b_low[i] = (b_low[i] * u_inv + b_hi[i] * u) % self.E.order + + G_low[i] = self.E.multiexp([G_low[i], G_hi[i]], [u_inv, u]) + H_low[i] = self.E.multiexp([H_low[i], H_hi[i]], [u, u_inv]) + + a = a_low + b = b_low + + G = G_low + H = H_low + + a = a[0] + b = b[0] + + return InnerProductProof(a, b, L_list, R_list), Cp + + +class Verifier: + + def __init__(self, size, curve, transcript: FiatShamirTranscript = None, seed=b'InnerProductProof'): + self.n = next_power_of_two(size) + self.E = EllipticCurve(curve) + self.G = hash_to_curve(seed, b'G', curve, self.n) + self.H = hash_to_curve(seed, b'H', curve, self.n) + self.Q = hash_to_curve(seed, b'Q', curve, 1) + + self.transcript = transcript or FiatShamirTranscript( + self.n.to_bytes(32, 'big')) + + def verify(self, proof: InnerProductProof, commitment): + + self.transcript.reset() + assert len(proof.L) < 32, "Argument size is too big" + + for g in self.G: + self.transcript.append(g.to_bytes()) + for h in self.H: + self.transcript.append(h.to_bytes()) + + k = len(proof.L) + challenges = [] + challenges_inv = [] + + all_inv = 1 + for i in range(k): + self.transcript.append(proof.L[i].to_bytes()) + self.transcript.append(proof.R[i].to_bytes()) + + u = hash_to_scalar( + self.transcript.get_challenge(), b'u', self.E.order) + + challenges.append(pow(u, 2, self.E.order)) + challenges_inv.append(pow(u, -2, self.E.order)) + all_inv *= pow(u, -1, self.E.order) + + s = [all_inv] + for i in range(1, self.n): + lg_i = (32 - 1 - (32 - i.bit_length())) + l = 1 << lg_i + + u_lg_i_sq = challenges[(k - 1) - lg_i] + s.append(s[i - l] * u_lg_i_sq) + + a_s = [proof.a * x % self.E.order for x in s] + b_s_inv = [proof.b * pow(x, -1, self.E.order) % + self.E.order for x in s] + + sum_LR = self.E.curve.PointG1.identity() + for j in range(k): + sum_LR += proof.L[j] * challenges[j] + \ + proof.R[j] * challenges_inv[j] + + rhs = self.E.multiexp(self.G, a_s) + \ + self.E.multiexp(self.H, b_s_inv) + \ + proof.a * proof.b * self.Q - sum_LR + + return commitment == rhs diff --git a/python/zksnake/bulletproofs/range_proof.py b/python/zksnake/bulletproofs/range_proof.py new file mode 100644 index 0000000..1533236 --- /dev/null +++ b/python/zksnake/bulletproofs/range_proof.py @@ -0,0 +1,286 @@ +from ..utils import get_random_int, next_power_of_two, split_list +from ..polynomial import PolynomialRing +from ..ecc import CurvePointSize, EllipticCurve +from ..transcript import FiatShamirTranscript, hash_to_curve, hash_to_scalar +from . import ipa + +class RangeProof: + + def __init__(self, A, S, T1, T2, t, t_blinding, e_blinding, ipa_proof: ipa.InnerProductProof): + self.A = A + self.S = S + self.T1 = T1 + self.T2 = T2 + self.t = t + self.t_blinding = t_blinding + self.e_blinding = e_blinding + self.ipa_proof = ipa_proof + + def to_bytes(self) -> bytes: + s = b"" + s += bytes(self.A.to_bytes()) + s += bytes(self.S.to_bytes()) + s += bytes(self.T1.to_bytes()) + s += bytes(self.T2.to_bytes()) + s += bytes(self.t.to_bytes(32, 'little')) + s += bytes(self.t_blinding.to_bytes(32, 'little')) + s += bytes(self.e_blinding.to_bytes(32, 'little')) + s += self.ipa_proof.to_bytes() + + return s + + @classmethod + def from_bytes(cls, s: bytes, crv="BN254"): + + E = EllipticCurve(crv) + n = CurvePointSize[crv].value // 2 + + assert len(s) % n == 0, "Invalid proof length" + + point_s = split_list(s[:4*n], n) + field_s = split_list(s[4*n:4*n+32*3], 32) + ipa_s = s[4*n+32*3:] + + assert len(point_s) == 4 and len(field_s) == 3, "Malformed proof structure" + + A = E.from_hex(point_s[0].hex()) + S = E.from_hex(point_s[1].hex()) + T1 = E.from_hex(point_s[2].hex()) + T2 = E.from_hex(point_s[3].hex()) + t = int.from_bytes(field_s[0], 'little') + t_blinding = int.from_bytes(field_s[1], 'little') + e_blinding = int.from_bytes(field_s[2], 'little') + ipa_proof = ipa.InnerProductProof.from_bytes(ipa_s) + + return RangeProof(A, S, T1, T2, t, t_blinding, e_blinding, ipa_proof) + +class Prover: + + def __init__(self, bitsize: int, curve, transcript: FiatShamirTranscript = None, seed=b'RangeProof'): + assert bitsize < 2**32 + self.n = next_power_of_two(bitsize) + self.E = EllipticCurve(curve) + self.G = hash_to_curve(seed, b'G', curve, self.n) + self.H = hash_to_curve(seed, b'H', curve, self.n) + self.B = hash_to_curve(seed, b'B', curve, 1) + self.B_blinding = hash_to_curve(seed, b'Blinding', curve, 1) + + self.transcript = transcript or FiatShamirTranscript( + self.n.to_bytes(32, 'big')) + + def __inner_product(self, a, b): + return sum(a * b for a, b in zip(a, b)) % self.E.order + + def __split_lr(self, data: list): + l = [] + r = [] + for v in data: + l += [v] + r += [(v - 1) % self.E.order] + + return l, r + + def prove(self, v: int): + + # bit vectors of v + a = [(v >> i) & 1 for i in range(self.n)] + a_L, a_R = self.__split_lr(a) + + s_L = [get_random_int(self.E.order) for _ in range(self.n)] + s_R = [get_random_int(self.E.order) for _ in range(self.n)] + + a_blinding = get_random_int(self.E.order) + v_blinding = get_random_int(self.E.order) + s_blinding = get_random_int(self.E.order) + + V = v * self.B + v_blinding * self.B_blinding + A = self.E.multiexp(self.G, a_L) + self.E.multiexp(self.H, a_R) + a_blinding * self.B_blinding + S = self.E.multiexp(self.G, s_L) + self.E.multiexp(self.H, s_R) + s_blinding * self.B_blinding + + self.transcript.append(V.to_bytes()) + self.transcript.append(A.to_bytes()) + self.transcript.append(S.to_bytes()) + + y = hash_to_scalar(self.transcript.get_challenge(), b'y', self.E.order) + z = hash_to_scalar(self.transcript.get_challenge(), b'z', self.E.order) + + l_0 = [] + l_1 = [] + r_0 = [] + r_1 = [] + exp_2 = 1 + exp_y = 1 + for i in range(self.n): + l_0.append((a_L[i] - z) % self.E.order) + l_1.append(s_L[i]) + + r_0.append((exp_y * (a_R[i] + z) + z*z * exp_2) % self.E.order) + r_1.append(exp_y * s_R[i] % self.E.order) + + exp_y *= y + exp_2 += exp_2 + + l_vecpoly = [] + r_vecpoly = [] + + p = self.E.order + for i in range(self.n): + l_vecpoly += [PolynomialRing([l_0[i], l_1[i]], p)] + r_vecpoly += [PolynomialRing([r_0[i], r_1[i]], p)] + + t0 = self.__inner_product(l_0, r_0) + t2 = self.__inner_product(l_1, r_1) + + l0_plus_l1 = [(a + b) % p for a,b in zip(l_0, l_1)] + r0_plus_r1 = [(a + b) % p for a,b in zip(r_0, r_1)] + + t1 = (self.__inner_product(l0_plus_l1, r0_plus_r1) - t0 - t2) % p + + t_poly = PolynomialRing([t0, t1, t2], p) + + t1_blinding = get_random_int(p) + t2_blinding = get_random_int(p) + T1 = t1 * self.B + t1_blinding * self.B_blinding + T2 = t2 * self.B + t2_blinding * self.B_blinding + + self.transcript.append(T1.to_bytes()) + self.transcript.append(T2.to_bytes()) + + x = hash_to_scalar(self.transcript.get_challenge(), b'x', self.E.order) + + l_list = [poly(x) for poly in l_vecpoly] + r_list = [poly(x) for poly in r_vecpoly] + t = t_poly(x) + + t_blinding_poly = PolynomialRing([z*z * v_blinding, t1_blinding, t2_blinding], p) + t_blinding = t_blinding_poly(x) + e_blinding = (a_blinding + x * s_blinding) % p + + self.transcript.append(t) + self.transcript.append(t_blinding) + self.transcript.append(e_blinding) + + w = hash_to_scalar(self.transcript.get_challenge(), b'w', self.E.order) + + Q = w * self.B + + ipa_prover = ipa.Prover(self.n, self.E.name, self.transcript) + + ipa_prover.G = self.G + ipa_prover.H = [pow(y, -i, p) * self.H[i] for i in range(self.n)] + ipa_prover.Q = Q + + ipa_proof, _ = ipa_prover.prove(l_list, r_list) + + return RangeProof(A, S, T1, T2, t, t_blinding, e_blinding, ipa_proof), V + + +class Verifier: + + def __init__(self, bitsize: int, curve, transcript: FiatShamirTranscript = None, seed=b'RangeProof'): + assert bitsize < 2**32 + self.n = bitsize + self.E = EllipticCurve(curve) + self.G = hash_to_curve(seed, b'G', curve, self.n) + self.H = hash_to_curve(seed, b'H', curve, self.n) + self.B = hash_to_curve(seed, b'B', curve, 1) + self.B_blinding = hash_to_curve(seed, b'Blinding', curve, 1) + + self.transcript = transcript or FiatShamirTranscript( + self.n.to_bytes(32, 'big')) + + def __delta(self, y, z): + sum_pow_2_y = sum([pow(y, i, self.E.order) for i in range(self.n)]) % self.E.order + z_pow_3 = pow(z, 3, self.E.order) + sum_2 = sum([pow(2, i, self.E.order) for i in range(self.n)]) % self.E.order + return (((z - pow(z, 2, self.E.order)) * sum_pow_2_y) - (z_pow_3 * sum_2)) % self.E.order + + def verify(self, proof: RangeProof, commitment): + + self.transcript.append(commitment.to_bytes()) + self.transcript.append(proof.A.to_bytes()) + self.transcript.append(proof.S.to_bytes()) + + y = hash_to_scalar(self.transcript.get_challenge(), b'y', self.E.order) + z = hash_to_scalar(self.transcript.get_challenge(), b'z', self.E.order) + + self.transcript.append(proof.T1.to_bytes()) + self.transcript.append(proof.T2.to_bytes()) + + x = hash_to_scalar(self.transcript.get_challenge(), b'x', self.E.order) + + self.transcript.append(proof.t) + self.transcript.append(proof.t_blinding) + self.transcript.append(proof.e_blinding) + + w = hash_to_scalar(self.transcript.get_challenge(), b'w', self.E.order) + + self.transcript.reset() + + for g in self.G: + self.transcript.append(g.to_bytes()) + for i, h in enumerate(self.H): + hprime = pow(y, -i, self.E.order) * h + self.transcript.append(hprime.to_bytes()) + + c = get_random_int(self.E.order) + + k = len(proof.ipa_proof.L) + challenges = [] + challenges_inv = [] + + all_inv = 1 + for i in range(k): + self.transcript.append(proof.ipa_proof.L[i].to_bytes()) + self.transcript.append(proof.ipa_proof.R[i].to_bytes()) + + u = hash_to_scalar( + self.transcript.get_challenge(), b'u', self.E.order) + + challenges.append(pow(u, 2, self.E.order)) + challenges_inv.append(pow(u, -2, self.E.order)) + all_inv *= pow(u, -1, self.E.order) + + s = [all_inv] + for i in range(1, self.n): + lg_i = (32 - 1 - (32 - i.bit_length())) + l = 1 << lg_i + + u_lg_i_sq = challenges[(k - 1) - lg_i] + s.append(s[i - l] * u_lg_i_sq) + + a = proof.ipa_proof.a + b = proof.ipa_proof.b + + scalar_mul_g = [(-z - a*s[i]) % self.E.order for i in range(self.n)] + scalar_mul_h = [] + + for i in range(self.n): + s_inv = pow(s[i], -1, self.E.order) + rhs = z*z * pow(2, i, self.E.order) - b * s_inv + + scalar_mul_h += [(z + pow(y, -i, self.E.order) * rhs) % self.E.order] + + points = [ + proof.A, + proof.S, + commitment, + proof.T1, + proof.T2, + self.B, + self.B_blinding, + ] + self.G + self.H + proof.ipa_proof.L + proof.ipa_proof.R + + scalars = [ + 1, + x, + c * z*z % self.E.order, + c * x % self.E.order, + c * x*x % self.E.order, + (w*(proof.t - a*b) + c*(self.__delta(y, z) - proof.t)) % self.E.order, + (-proof.e_blinding - c*proof.t_blinding) % self.E.order, + ] + scalar_mul_g + scalar_mul_h + challenges + challenges_inv + + final_check = self.E.multiexp(points, scalars) + + return final_check.is_zero() \ No newline at end of file diff --git a/python/zksnake/ecc.py b/python/zksnake/ecc.py index a1f3108..778b756 100644 --- a/python/zksnake/ecc.py +++ b/python/zksnake/ecc.py @@ -9,12 +9,19 @@ class CurveType(Enum): ALT_BN128 = ec_bn254 BLS12_381 = ec_bls12_381 - +P_BN254 = 21888242871839275222246405745257275088696311157297823662689037894645226208583 Q_BN254 = 21888242871839275222246405745257275088548364400416034343698204186575808495617 + +P_BLS12_381 = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 Q_BLS12_381 = ( 52435875175126190479447740508185965837690552500527637822603658699938581184513 ) +class CurveField(Enum): + BN128 = P_BN254 + BN254 = P_BN254 + ALT_BN128 = P_BN254 + BLS12_381 = P_BLS12_381 class CurveOrder(Enum): BN128 = Q_BN254 @@ -22,7 +29,6 @@ class CurveOrder(Enum): ALT_BN128 = Q_BN254 BLS12_381 = Q_BLS12_381 - class CurvePointSize(Enum): BN128 = 64 BN254 = 64 @@ -35,6 +41,7 @@ def __init__(self, curve: str): self.name = curve self.curve = CurveType[curve].value self.order = CurveOrder[curve].value + self.field_modulus = CurveField[curve].value def G1(self): return self.curve.g1() diff --git a/python/zksnake/groth16/prover.py b/python/zksnake/groth16/prover.py index 4a50436..e0548ad 100644 --- a/python/zksnake/groth16/prover.py +++ b/python/zksnake/groth16/prover.py @@ -1,7 +1,8 @@ """Proving module of Groth16 protocol""" +from zksnake.r1cs import R1CS from ..ecc import EllipticCurve, CurvePointSize -from ..qap import QAP +from .qap import QAP from ..utils import get_random_int, split_list @@ -167,24 +168,26 @@ class Prover: Prover object Args: - qap: QAP to be proved from + r1cs: R1CS to be proved from key: `ProvingKey` from trusted setup curve: `BN254` or `BLS12_381` """ - def __init__(self, qap: QAP, key: ProvingKey, curve: str = "BN254"): + def __init__(self, r1cs: R1CS, key: ProvingKey, curve: str = "BN254"): - self.qap = qap self.key = key self.E = EllipticCurve(curve) self.order = self.E.order + self.qap = QAP(self.order) + self.qap.from_r1cs(r1cs) + if key.delta_1.is_zero() or key.delta_2.is_zero(): raise ValueError("Key delta_1 or delta_2 is zero element!") def prove(self, public_witness: list, private_witness: list) -> Proof: """ - Prove statement from QAP by providing public and private witness + Prove statement from R1CS by providing public and private witness """ assert len(self.key.kdelta_1) == len( private_witness diff --git a/python/zksnake/qap.py b/python/zksnake/groth16/qap.py similarity index 78% rename from python/zksnake/qap.py rename to python/zksnake/groth16/qap.py index 8a839c5..bd300c6 100644 --- a/python/zksnake/qap.py +++ b/python/zksnake/groth16/qap.py @@ -1,5 +1,6 @@ -from .array import SparseArray -from .polynomial import ( +from zksnake.ecc import Q_BN254 +from zksnake.r1cs import R1CS +from ..polynomial import ( PolynomialRing, ifft, fft, @@ -9,29 +10,28 @@ class QAP: - def __init__(self, p): + def __init__(self, p=None): self.a = [] self.b = [] self.c = [] self.n_public = 0 - self.p = p + self.p = p or Q_BN254 - def from_r1cs(self, A: SparseArray, B: SparseArray, C: SparseArray, n_public: int): + def from_r1cs(self, r1cs: R1CS): """ Parse QAP from R1CS matrices Args: - A, B, C: matrix A,B,C from R1CS - n_public: number of public variables in R1CS + r1cs: R1CS object """ - self.n_public = n_public + self.n_public = r1cs.n_public - next_power_2 = 1 << (A.n_row - 1).bit_length() + next_power_2 = 1 << (r1cs.A.n_row - 1).bit_length() - self.a = A - self.b = B - self.c = C + self.a = r1cs.A + self.b = r1cs.B + self.c = r1cs.C self.a.n_row = next_power_2 self.b.n_row = next_power_2 diff --git a/python/zksnake/groth16/setup.py b/python/zksnake/groth16/setup.py index 660ab96..84b532a 100644 --- a/python/zksnake/groth16/setup.py +++ b/python/zksnake/groth16/setup.py @@ -1,8 +1,9 @@ """Trusted setup module of Groth16 protocol""" from joblib import Parallel, delayed +from zksnake.r1cs import R1CS -from ..qap import QAP +from .qap import QAP from ..ecc import EllipticCurve from ..polynomial import ( evaluate_vanishing_polynomial, @@ -15,17 +16,18 @@ class Setup: - def __init__(self, qap: QAP, curve: str = "BN254"): + def __init__(self, r1cs: R1CS, curve: str = "BN254"): """ Trusted setup object Args: - qap: QAP to be set up from + r1cs: R1CS to be set up from curve: `BN254` or `BLS12_381` """ - self.qap = qap self.E = EllipticCurve(curve) self.order = self.E.order + self.qap = QAP(self.order) + self.qap.from_r1cs(r1cs) def generate(self) -> tuple[ProvingKey, VerifyingKey]: """Generate `ProvingKey` and `VerifyingKey`""" diff --git a/python/zksnake/parser.py b/python/zksnake/parser.py index 5d3df36..9735b1f 100644 --- a/python/zksnake/parser.py +++ b/python/zksnake/parser.py @@ -97,7 +97,8 @@ def __read_constraint_section(self, content: BytesIO): n_a = int.from_bytes(content.read(4), "little") for _ in range(n_a): wire_id = int.from_bytes(content.read(4), "little") - factor = int.from_bytes(content.read(self.header["fs"]), "little") + factor = int.from_bytes( + content.read(self.header["fs"]), "little") sym = self.wires[wire_id] if a: @@ -108,7 +109,8 @@ def __read_constraint_section(self, content: BytesIO): n_b = int.from_bytes(content.read(4), "little") for _ in range(n_b): wire_id = int.from_bytes(content.read(4), "little") - factor = int.from_bytes(content.read(self.header["fs"]), "little") + factor = int.from_bytes( + content.read(self.header["fs"]), "little") sym = self.wires[wire_id] if b: @@ -119,7 +121,8 @@ def __read_constraint_section(self, content: BytesIO): n_c = int.from_bytes(content.read(4), "little") for _ in range(n_c): wire_id = int.from_bytes(content.read(4), "little") - factor = int.from_bytes(content.read(self.header["fs"]), "little") + factor = int.from_bytes( + content.read(self.header["fs"]), "little") sym = self.wires[wire_id] if rhs_c: @@ -180,7 +183,8 @@ def __construct_constraints(self): private_inputs = [ Symbol(f"priv{i+1}") for i in range(self.header["n_priv_in"]) ] - outputs = [Symbol(f"out{i+1}") for i in range(self.header["n_pub_out"])] + outputs = [Symbol(f"out{i+1}") + for i in range(self.header["n_pub_out"])] n_intermediate = self.header["n_wires"] - ( self.header["n_pub_in"] @@ -188,10 +192,12 @@ def __construct_constraints(self): + self.header["n_pub_out"] + 1 ) - intermediate_vars = [Symbol(f"v{i+1}") for i in range(n_intermediate)] + intermediate_vars = [Symbol(f"v{i+1}") + for i in range(n_intermediate)] self.wires = ( - [1] + outputs + public_inputs + private_inputs + intermediate_vars + [1] + outputs + public_inputs + + private_inputs + intermediate_vars ) for constraint in self.raw_constraints: diff --git a/python/zksnake/r1cs.py b/python/zksnake/r1cs.py index 99f389c..bdb5303 100644 --- a/python/zksnake/r1cs.py +++ b/python/zksnake/r1cs.py @@ -5,7 +5,6 @@ from .symbolic import Symbol, SymbolArray, Equation, symeval, get_unassigned_var from .array import SparseArray from .ecc import EllipticCurve -from .qap import QAP from .parser import R1CSReader from .utils import get_n_jobs @@ -333,7 +332,8 @@ def __add_var(self, eq: Symbol): self._BaseConstraint__add_var(eq) # pylint: disable=no-member def __get_witness_vector(self): - public_input = [v for v in self.vars if v in self.inputs and v in self.public] + public_input = [ + v for v in self.vars if v in self.inputs and v in self.public] private_input = [ v for v in self.vars @@ -411,13 +411,15 @@ def __consume_constraint_stack(self, constraints_stack: list): target = target_l coeff = coeff_l left = target - multiplier = pow(symeval(r, self.vars, self.p), -1, self.p) + multiplier = pow( + symeval(r, self.vars, self.p), -1, self.p) elif not target_l and target_r: target = target_r coeff = coeff_r left = target - multiplier = pow(symeval(l, self.vars, self.p), -1, self.p) + multiplier = pow( + symeval(l, self.vars, self.p), -1, self.p) else: raise ValueError() @@ -437,10 +439,12 @@ def __consume_constraint_stack(self, constraints_stack: list): # there will be 4 possible values in total: # [val_1, -val_1, val_2, -val2] val_1 = ( - (evaluated_right - diff) * inv_coeff * multiplier % self.p + (evaluated_right - diff) * + inv_coeff * multiplier % self.p ) val_2 = ( - (evaluated_right + diff) * inv_coeff * multiplier % self.p + (evaluated_right + diff) * + inv_coeff * multiplier % self.p ) for v in (val_1, -val_1, val_2, -val_2): @@ -476,10 +480,12 @@ def __consume_constraint_stack(self, constraints_stack: list): self.vars[target.name] = 0 eval_l = ( - l if isinstance(l, int) else symeval(l, self.vars, self.p) + l if isinstance(l, int) else symeval( + l, self.vars, self.p) ) eval_r = ( - r if isinstance(r, int) else symeval(r, self.vars, self.p) + r if isinstance(r, int) else symeval( + r, self.vars, self.p) ) if not target_l: @@ -546,7 +552,8 @@ def __consume_hint(self): if isinstance(arg, Symbol) and self.vars.get(arg.name) is None: break elif ( - isinstance(arg, Symbol) and self.vars.get(arg.name) is not None + isinstance(arg, Symbol) and self.vars.get( + arg.name) is not None ): evaluated_args.append(self.vars[arg.name]) else: @@ -559,7 +566,8 @@ def evaluate(self, input_values: dict, output_values: dict = None) -> bool: """Evaluate the constraint system with given inputs and output""" output_values = output_values or {} if len(input_values) != len(self.inputs): - raise ValueError("Length of input values differ with input variables") + raise ValueError( + "Length of input values differ with input variables") for k, _ in self.vars.items(): self.vars[k] = None @@ -599,12 +607,12 @@ def __add_dummy_constraints(self): eq = 0 == var * 0 self.add_constraint(eq) - def compile(self) -> QAP: + def compile(self) -> R1CS: """ - Compile R1CS into Quadratic Arithmetic Program (QAP) + Compile list of constraints into R1CS Returns: - qap: QAP object of the constraint system + r1cs: R1CS object """ self.__add_dummy_constraints() witness = self.__get_witness_vector() @@ -621,7 +629,7 @@ def compile(self) -> QAP: else: n_job = 1 - result = Parallel(n_jobs=n_job)( + result = Parallel(n_jobs=n_job, max_nbytes="100M")( delayed(consume_constraint)(row, constraint, witness, self.p) for row, constraint in enumerate(self.constraints) ) @@ -631,10 +639,7 @@ def compile(self) -> QAP: B.append(row[1]) C.append(row[2]) - qap = QAP(self.p) - qap.from_r1cs(A, B, C, len(self.public) + 1) - - return qap + return R1CS(A, B, C, len(self.public) + 1) def solve(self, input_values: dict, output_value: dict = None) -> list: """ @@ -652,11 +657,12 @@ def solve(self, input_values: dict, output_value: dict = None) -> list: witness = self.__get_witness_vector() if not self.evaluate(input_values, output_value): - raise ValueError("Evaluated constraints are not satisfied with given input") + raise ValueError( + "Evaluated constraints are not satisfied with given input") w = self.__evaluate_witness_vector(witness) - return w[: len(self.public) + 1], w[len(self.public) + 1 :] + return w[: len(self.public) + 1], w[len(self.public) + 1:] @classmethod def from_file(cls, r1csfile: str, symfile: str = None): @@ -685,3 +691,19 @@ def from_file(cls, r1csfile: str, symfile: str = None): def to_file(self, filepath): raise NotImplementedError + + +class R1CS: + + def __init__(self, A: SparseArray, B: SparseArray, C: SparseArray, n_public: int): + self.A = A + self.B = B + self.C = C + self.n_public = n_public + + def to_bytes(self): + raise NotImplementedError + + @classmethod + def from_bytes(cls, data): + raise NotImplementedError diff --git a/python/zksnake/transcript.py b/python/zksnake/transcript.py new file mode 100644 index 0000000..14ce93d --- /dev/null +++ b/python/zksnake/transcript.py @@ -0,0 +1,61 @@ +import hashlib +from .ecc import EllipticCurve + + +def hash_to_scalar(data: bytes, domain_separation_tag: bytes, modulus: int, alg: str = 'sha256'): + h = hashlib.new(alg) + h.update(domain_separation_tag) + h.update(data) + + return int.from_bytes(h.digest(), 'big') % modulus + + +def hash_to_curve(data: bytes, domain_separation_tag: bytes, curve: str = 'BN254', size: int = 1, alg: str = 'sha256'): + + E = EllipticCurve(curve) + + h = hashlib.new(alg) + h.update(domain_separation_tag) + h.update(data) + + points = [] + for _ in range(size): + while True: + digest = h.digest() + h.update(digest) + + try: + point = E.curve.PointG1.hash_to_curve(digest) + points.append(point) + break + except ValueError: + pass + + return points[0] if size == 1 else points + +class FiatShamirTranscript: + + def __init__(self, label: bytes, alg='sha256'): + self.alg = alg + self.label = label + self.hasher = hashlib.new(alg, label) + self.state = [] + + def reset(self): + self.hasher = hashlib.new(self.alg, self.label) + + def append(self, data): + + if isinstance(data, bytes): + self.hasher.update(data) + elif isinstance(data, str): + self.hasher.update(data.encode()) + elif isinstance(data, int): + data = int.to_bytes(data, data.bit_length(), 'big') + self.hasher.update(data) + elif data and isinstance(data, list) and isinstance(data[0], int): + self.hasher.update(bytes(data)) + + def get_challenge(self): + digest = self.hasher.digest() + return digest diff --git a/python/zksnake/utils.py b/python/zksnake/utils.py index fb28711..f523a2d 100644 --- a/python/zksnake/utils.py +++ b/python/zksnake/utils.py @@ -1,7 +1,7 @@ import os import random import time - +import hashlib def get_random_int(n_max): """Get random integer in [1, n_max] range""" @@ -22,6 +22,9 @@ def split_list(data, n): """Split data into n chunks""" return [data[i : i + n] for i in range(0, len(data), n)] +def next_power_of_two(n: int): + """Get next 2^x number from n""" + return 1 << (n - 1).bit_length() class Timer: def __init__(self, name): diff --git a/src/bls12_381/curve.rs b/src/bls12_381/curve.rs index 243228b..782b076 100644 --- a/src/bls12_381/curve.rs +++ b/src/bls12_381/curve.rs @@ -1,13 +1,18 @@ -use ark_bls12_381::{Bls12_381, Fr, G1Affine, G1Projective, G2Affine, G2Projective}; +use ark_bls12_381::{ g1::Config, Bls12_381, Fr, G1Affine, G1Projective, G2Affine, G2Projective }; use ark_ec::{ - pairing::{Pairing, PairingOutput}, - AffineRepr, CurveGroup, Group, VariableBaseMSM, + hashing::{ curve_maps::wb::WBMap, map_to_curve_hasher::MapToCurveBasedHasher, HashToCurve }, + pairing::{ Pairing, PairingOutput }, + short_weierstrass::Projective, + AffineRepr, + CurveGroup, + Group, + VariableBaseMSM, }; -use ark_ff::{QuadExtField, Zero}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_ff::{ field_hashers::DefaultFieldHasher, QuadExtField, Zero }; +use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize }; use num_bigint::BigUint; -use pyo3::{exceptions::PyValueError, prelude::*, types::PyType}; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use pyo3::{ exceptions::PyValueError, prelude::*, types::PyType }; +use rayon::iter::{ IntoParallelIterator, ParallelIterator }; #[pyclass] #[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] @@ -100,7 +105,10 @@ impl PointG1 { pub fn to_hex(&self) -> PyResult { let mut b = Vec::new(); let _ = self.point.serialize_compressed(&mut b); - let hex_string: String = b.iter().map(|byte| format!("{:02x}", byte)).collect(); + let hex_string: String = b + .iter() + .map(|byte| format!("{:02x}", byte)) + .collect(); Ok(hex_string) } @@ -114,15 +122,51 @@ impl PointG1 { #[classmethod] pub fn from_bytes(_cls: &PyType, hex: Vec) -> PyResult { match G1Affine::deserialize_compressed(&*hex) { - Err(e) => Err(PyValueError::new_err(format!( - "Cannot deserialize point: {}", - e.to_string() - ))), - Ok(point) => Ok(PointG1 { - point: point.into(), - }), + Err(e) => + Err(PyValueError::new_err(format!("Cannot deserialize point: {}", e.to_string()))), + Ok(point) => + Ok(PointG1 { + point: point.into(), + }), } } + + #[classmethod] + pub fn hash_to_curve(_cls: &PyType, data: Vec) -> PyResult { + use sha2::Sha256; + let hasher = MapToCurveBasedHasher::< + Projective, + DefaultFieldHasher, + WBMap + > + ::new(&[1]) + .unwrap(); + + let point = hasher.hash(&data).unwrap(); + Ok(PointG1 { + point: point.into(), + }) + } + + #[classmethod] + pub fn from_x(_cls: &PyType, x: BigUint) -> PyResult { + match G1Affine::get_point_from_x_unchecked(x.into(), true) { + Some(e) => { + if e.is_on_curve() && e.is_in_correct_subgroup_assuming_on_curve() { + return Ok(PointG1 { point: e.into() }); + } + Err(PyValueError::new_err(format!("Point is not on curve"))) + } + None => Err(PyValueError::new_err(format!("Cannot found point"))), + } + } + + #[classmethod] + pub fn identity(_cls: &PyType) -> PyResult { + Ok(PointG1 { + point: G1Affine::identity().into(), + }) + } } #[pyclass] @@ -225,7 +269,10 @@ impl PointG2 { pub fn to_hex(&self) -> PyResult { let mut b = Vec::new(); let _ = self.point.serialize_compressed(&mut b); - let hex_string: String = b.iter().map(|byte| format!("{:02x}", byte)).collect(); + let hex_string: String = b + .iter() + .map(|byte| format!("{:02x}", byte)) + .collect(); Ok(hex_string) } @@ -239,13 +286,12 @@ impl PointG2 { #[classmethod] pub fn from_bytes(_cls: &PyType, hex: Vec) -> PyResult { match G2Affine::deserialize_compressed(&*hex) { - Err(e) => Err(PyValueError::new_err(format!( - "Cannot deserialize point: {}", - e.to_string() - ))), - Ok(point) => Ok(PointG2 { - point: point.into(), - }), + Err(e) => + Err(PyValueError::new_err(format!("Cannot deserialize point: {}", e.to_string()))), + Ok(point) => + Ok(PointG2 { + point: point.into(), + }), } } } @@ -253,7 +299,7 @@ impl PointG2 { #[pyfunction] pub fn batch_multi_scalar_g1( points: Vec, - scalars: Vec, + scalars: Vec ) -> PyResult> { let result: Vec = (&points, &scalars) .into_par_iter() @@ -268,7 +314,7 @@ pub fn batch_multi_scalar_g1( #[pyfunction] pub fn batch_multi_scalar_g2( points: Vec, - scalars: Vec, + scalars: Vec ) -> PyResult> { let result: Vec = (&points, &scalars) .into_par_iter() @@ -284,7 +330,7 @@ pub fn batch_multi_scalar_g2( pub fn multiscalar_mul_g1(points: Vec, scalars: Vec) -> PyResult { let mut fr_scalars: Vec = vec![]; for scalar in scalars { - fr_scalars.push(Fr::from(scalar)) + fr_scalars.push(Fr::from(scalar)); } let mut affine_points: Vec = vec![]; for point in points { @@ -293,9 +339,7 @@ pub fn multiscalar_mul_g1(points: Vec, scalars: Vec) -> PyResu let r = G1Projective::msm(&affine_points, &fr_scalars); match r { Ok(r) => Ok(PointG1 { point: r }), - Err(_) => Err(PyValueError::new_err(format!( - "Number of points and scalars mismatch" - ))), + Err(_) => Err(PyValueError::new_err(format!("Number of points and scalars mismatch"))), } } @@ -303,7 +347,7 @@ pub fn multiscalar_mul_g1(points: Vec, scalars: Vec) -> PyResu pub fn multiscalar_mul_g2(points: Vec, scalars: Vec) -> PyResult { let mut fr_scalars: Vec = vec![]; for scalar in scalars { - fr_scalars.push(Fr::from(scalar)) + fr_scalars.push(Fr::from(scalar)); } let mut affine_points: Vec = vec![]; for point in points { @@ -312,9 +356,7 @@ pub fn multiscalar_mul_g2(points: Vec, scalars: Vec) -> PyResu let r = G2Projective::msm(&affine_points, &fr_scalars); match r { Ok(r) => Ok(PointG2 { point: r }), - Err(_) => Err(PyValueError::new_err(format!( - "Number of points and scalars mismatch" - ))), + Err(_) => Err(PyValueError::new_err(format!("Number of points and scalars mismatch"))), } } @@ -353,10 +395,10 @@ pub fn multi_pairing(a: Vec, b: Vec) -> PyResult { let mut point1: Vec = vec![]; let mut point2: Vec = vec![]; for p in a { - point1.push(p.point) + point1.push(p.point); } for p in b { - point2.push(p.point) + point2.push(p.point); } Ok(PointG12 { point: Bls12_381::multi_pairing(point1, point2), diff --git a/src/bn254/curve.rs b/src/bn254/curve.rs index 2b533dc..e3fa2ce 100644 --- a/src/bn254/curve.rs +++ b/src/bn254/curve.rs @@ -1,13 +1,10 @@ -use ark_bn254::{Bn254, Fr, G1Affine, G1Projective, G2Affine, G2Projective}; -use ark_ec::{ - pairing::{Pairing, PairingOutput}, - AffineRepr, CurveGroup, Group, VariableBaseMSM, -}; -use ark_ff::{QuadExtField, Zero}; -use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_bn254::{ Bn254, Fr, G1Affine, G1Projective, G2Affine, G2Projective }; +use ark_ec::{ pairing::{ Pairing, PairingOutput }, AffineRepr, CurveGroup, Group, VariableBaseMSM }; +use ark_ff::{ QuadExtField, Zero }; +use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize }; use num_bigint::BigUint; -use pyo3::{exceptions::PyValueError, prelude::*, types::PyType}; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use pyo3::{ exceptions::PyValueError, prelude::*, types::PyType }; +use rayon::iter::{ IntoParallelIterator, ParallelIterator }; #[pyclass] #[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] @@ -107,7 +104,10 @@ impl PointG1 { pub fn to_hex(&self) -> PyResult { let mut b = Vec::new(); let _ = self.point.serialize_compressed(&mut b); - let hex_string: String = b.iter().map(|byte| format!("{:02x}", byte)).collect(); + let hex_string: String = b + .iter() + .map(|byte| format!("{:02x}", byte)) + .collect(); Ok(hex_string) } @@ -121,15 +121,40 @@ impl PointG1 { #[classmethod] pub fn from_bytes(_cls: &PyType, hex: Vec) -> PyResult { match G1Affine::deserialize_compressed(&*hex) { - Err(e) => Err(PyValueError::new_err(format!( - "Cannot deserialize point: {}", - e.to_string() - ))), - Ok(point) => Ok(PointG1 { - point: point.into(), - }), + Err(e) => + Err(PyValueError::new_err(format!("Cannot deserialize point: {}", e.to_string()))), + Ok(point) => + Ok(PointG1 { + point: point.into(), + }), } } + + #[classmethod] + pub fn hash_to_curve(_cls: &PyType, data: Vec) -> PyResult { + let x = BigUint::from_bytes_be(&data); + Self::from_x(_cls, x) + } + + #[classmethod] + pub fn from_x(_cls: &PyType, x: BigUint) -> PyResult { + match G1Affine::get_point_from_x_unchecked(x.into(), true) { + Some(e) => { + if e.is_on_curve() && e.is_in_correct_subgroup_assuming_on_curve() { + return Ok(PointG1 { point: e.into() }); + } + Err(PyValueError::new_err(format!("Point is not on curve"))) + } + None => Err(PyValueError::new_err(format!("Cannot found point"))), + } + } + + #[classmethod] + pub fn identity(_cls: &PyType) -> PyResult { + Ok(PointG1 { + point: G1Affine::identity().into(), + }) + } } #[pyclass] @@ -239,7 +264,10 @@ impl PointG2 { pub fn to_hex(&self) -> PyResult { let mut b = Vec::new(); let _ = self.point.serialize_compressed(&mut b); - let hex_string: String = b.iter().map(|byte| format!("{:02x}", byte)).collect(); + let hex_string: String = b + .iter() + .map(|byte| format!("{:02x}", byte)) + .collect(); Ok(hex_string) } @@ -253,13 +281,12 @@ impl PointG2 { #[classmethod] pub fn from_bytes(_cls: &PyType, hex: Vec) -> PyResult { match G2Affine::deserialize_compressed(&*hex) { - Err(e) => Err(PyValueError::new_err(format!( - "Cannot deserialize point: {}", - e.to_string() - ))), - Ok(point) => Ok(PointG2 { - point: point.into(), - }), + Err(e) => + Err(PyValueError::new_err(format!("Cannot deserialize point: {}", e.to_string()))), + Ok(point) => + Ok(PointG2 { + point: point.into(), + }), } } } @@ -267,7 +294,7 @@ impl PointG2 { #[pyfunction] pub fn batch_multi_scalar_g1( points: Vec, - scalars: Vec, + scalars: Vec ) -> PyResult> { let result: Vec = (&points, &scalars) .into_par_iter() @@ -282,7 +309,7 @@ pub fn batch_multi_scalar_g1( #[pyfunction] pub fn batch_multi_scalar_g2( points: Vec, - scalars: Vec, + scalars: Vec ) -> PyResult> { let result: Vec = (&points, &scalars) .into_par_iter() @@ -298,7 +325,7 @@ pub fn batch_multi_scalar_g2( pub fn multiscalar_mul_g1(points: Vec, scalars: Vec) -> PyResult { let mut fr_scalars: Vec = vec![]; for scalar in scalars { - fr_scalars.push(Fr::from(scalar)) + fr_scalars.push(Fr::from(scalar)); } let mut affine_points: Vec = vec![]; for point in points { @@ -307,9 +334,7 @@ pub fn multiscalar_mul_g1(points: Vec, scalars: Vec) -> PyResu let r = G1Projective::msm(&affine_points, &fr_scalars); match r { Ok(r) => Ok(PointG1 { point: r }), - Err(_) => Err(PyValueError::new_err(format!( - "Number of points and scalars mismatch" - ))), + Err(_) => Err(PyValueError::new_err(format!("Number of points and scalars mismatch"))), } } @@ -317,7 +342,7 @@ pub fn multiscalar_mul_g1(points: Vec, scalars: Vec) -> PyResu pub fn multiscalar_mul_g2(points: Vec, scalars: Vec) -> PyResult { let mut fr_scalars: Vec = vec![]; for scalar in scalars { - fr_scalars.push(Fr::from(scalar)) + fr_scalars.push(Fr::from(scalar)); } let mut affine_points: Vec = vec![]; for point in points { @@ -326,9 +351,7 @@ pub fn multiscalar_mul_g2(points: Vec, scalars: Vec) -> PyResu let r = G2Projective::msm(&affine_points, &fr_scalars); match r { Ok(r) => Ok(PointG2 { point: r }), - Err(_) => Err(PyValueError::new_err(format!( - "Number of points and scalars mismatch" - ))), + Err(_) => Err(PyValueError::new_err(format!("Number of points and scalars mismatch"))), } } @@ -367,10 +390,10 @@ pub fn multi_pairing(a: Vec, b: Vec) -> PyResult { let mut point1: Vec = vec![]; let mut point2: Vec = vec![]; for p in a { - point1.push(p.point) + point1.push(p.point); } for p in b { - point2.push(p.point) + point2.push(p.point); } Ok(PointG12 { point: Bn254::multi_pairing(point1, point2), diff --git a/tests/test_bulletproofs.py b/tests/test_bulletproofs.py new file mode 100644 index 0000000..4cdc304 --- /dev/null +++ b/tests/test_bulletproofs.py @@ -0,0 +1,58 @@ +import pytest + +from zksnake.ecc import EllipticCurve +from zksnake.symbolic import Symbol +from zksnake.r1cs import ConstraintSystem +from zksnake.bulletproofs import ipa, range_proof + + +def test_ipa_bn254(): + + a = [1,3,3,7] + b = [1,2,3,4] + + prover = ipa.Prover(8, 'BN254') + proof, comm = prover.prove(a, b) + + verifier = ipa.Verifier(8, 'BN254') + assert verifier.verify(proof, comm) + + +def test_ipa_bls12_381(): + + a = [1,3,3,7] + b = [1,2,3,4] + + prover = ipa.Prover(8, 'BLS12_381') + proof, comm = prover.prove(a, b) + + verifier = ipa.Verifier(8, 'BLS12_381') + assert verifier.verify(proof, comm) + +def test_range_proof_bn254(): + + prover = range_proof.Prover(32, 'BN254') + proof, comm = prover.prove(1337) + + verifier = range_proof.Verifier(32, 'BN254') + assert verifier.verify(proof, comm) + + prover = range_proof.Prover(8, 'BN254') + proof, comm = prover.prove(500) + + verifier = range_proof.Verifier(8, 'BN254') + assert not verifier.verify(proof, comm) + +def test_range_proof_bls12_381(): + + prover = range_proof.Prover(32, 'BLS12_381') + proof, comm = prover.prove(1337) + + verifier = range_proof.Verifier(32, 'BLS12_381') + assert verifier.verify(proof, comm) + + prover = range_proof.Prover(8, 'BLS12_381') + proof, comm = prover.prove(500) + + verifier = range_proof.Verifier(8, 'BLS12_381') + assert not verifier.verify(proof, comm) diff --git a/tests/test_r1cs_qap.py b/tests/test_r1cs_qap.py index 0999882..bc4dbdc 100644 --- a/tests/test_r1cs_qap.py +++ b/tests/test_r1cs_qap.py @@ -1,5 +1,6 @@ import pytest +from zksnake.groth16.qap import QAP from zksnake.symbolic import Symbol, SymbolArray from zksnake.r1cs import ConstraintSystem, ConstraintTemplate @@ -17,7 +18,10 @@ def test_basic_r1cs_bn254(): pub, priv = cs.solve({"x": 3}, {"y": 35}) - qap = cs.compile() + r1cs = cs.compile() + + qap = QAP() + qap.from_r1cs(r1cs) qap.evaluate_witness(pub + priv) @@ -33,7 +37,10 @@ def test_basic_r1cs_bls12_381(): cs.add_constraint(y - 5 - x == v1 * x) cs.set_public(y) - qap = cs.compile() + r1cs = cs.compile() + + qap = QAP() + qap.from_r1cs(r1cs) pub, priv = cs.solve({"x": 3}, {"y": 35}) @@ -60,7 +67,9 @@ def test_constraint_structure(): cs.set_public(y) - qap = cs.compile() + r1cs = cs.compile() + qap = QAP() + qap.from_r1cs(r1cs) pub, priv = cs.solve({"x": 3}) @@ -85,7 +94,9 @@ def test_r1cs_loop_constraint(): cs.add_constraint(out == v[n_power - 2]) cs.set_public(out) - qap = cs.compile() + r1cs = cs.compile() + qap = QAP() + qap.from_r1cs(r1cs) pub, priv = cs.solve({"inp": 2}, {"out": 2**n_power}) From 10ffb4efcc0187a9a15ec9c70002e72b489c962e Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 3 Nov 2024 20:43:38 +0700 Subject: [PATCH 02/13] Fix proof serialization --- python/zksnake/bulletproofs/ipa.py | 2 +- python/zksnake/bulletproofs/range_proof.py | 4 ++-- python/zksnake/utils.py | 1 - tests/test_bulletproofs.py | 26 ++++++++++++++++++++-- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/zksnake/bulletproofs/ipa.py b/python/zksnake/bulletproofs/ipa.py index 4bb0afe..f6cb321 100644 --- a/python/zksnake/bulletproofs/ipa.py +++ b/python/zksnake/bulletproofs/ipa.py @@ -28,7 +28,7 @@ def from_bytes(cls, s: bytes, crv="BN254"): E = EllipticCurve(crv) n = CurvePointSize[crv].value // 2 - assert len(s) % n == 0, "Invalid proof length" + assert (len(s)-64) % n == 0, "Invalid proof length" Ls = [] Rs = [] diff --git a/python/zksnake/bulletproofs/range_proof.py b/python/zksnake/bulletproofs/range_proof.py index 1533236..f9aacd9 100644 --- a/python/zksnake/bulletproofs/range_proof.py +++ b/python/zksnake/bulletproofs/range_proof.py @@ -35,7 +35,7 @@ def from_bytes(cls, s: bytes, crv="BN254"): E = EllipticCurve(crv) n = CurvePointSize[crv].value // 2 - assert len(s) % n == 0, "Invalid proof length" + assert (len(s)-160) % n == 0, "Invalid proof length" point_s = split_list(s[:4*n], n) field_s = split_list(s[4*n:4*n+32*3], 32) @@ -50,7 +50,7 @@ def from_bytes(cls, s: bytes, crv="BN254"): t = int.from_bytes(field_s[0], 'little') t_blinding = int.from_bytes(field_s[1], 'little') e_blinding = int.from_bytes(field_s[2], 'little') - ipa_proof = ipa.InnerProductProof.from_bytes(ipa_s) + ipa_proof = ipa.InnerProductProof.from_bytes(ipa_s, crv) return RangeProof(A, S, T1, T2, t, t_blinding, e_blinding, ipa_proof) diff --git a/python/zksnake/utils.py b/python/zksnake/utils.py index f523a2d..6ba2424 100644 --- a/python/zksnake/utils.py +++ b/python/zksnake/utils.py @@ -1,7 +1,6 @@ import os import random import time -import hashlib def get_random_int(n_max): """Get random integer in [1, n_max] range""" diff --git a/tests/test_bulletproofs.py b/tests/test_bulletproofs.py index 4cdc304..f0f7a43 100644 --- a/tests/test_bulletproofs.py +++ b/tests/test_bulletproofs.py @@ -14,8 +14,10 @@ def test_ipa_bn254(): prover = ipa.Prover(8, 'BN254') proof, comm = prover.prove(a, b) + proof = proof.to_bytes() + verifier = ipa.Verifier(8, 'BN254') - assert verifier.verify(proof, comm) + assert verifier.verify(ipa.InnerProductProof.from_bytes(proof), comm) def test_ipa_bls12_381(): @@ -26,8 +28,10 @@ def test_ipa_bls12_381(): prover = ipa.Prover(8, 'BLS12_381') proof, comm = prover.prove(a, b) + proof = proof.to_bytes() + verifier = ipa.Verifier(8, 'BLS12_381') - assert verifier.verify(proof, comm) + assert verifier.verify(ipa.InnerProductProof.from_bytes(proof, 'BLS12_381'), comm) def test_range_proof_bn254(): @@ -56,3 +60,21 @@ def test_range_proof_bls12_381(): verifier = range_proof.Verifier(8, 'BLS12_381') assert not verifier.verify(proof, comm) + +def test_range_proof_serialization(): + + prover = range_proof.Prover(32, 'BN254') + proof, comm = prover.prove(1337) + + proof = proof.to_bytes() + + verifier = range_proof.Verifier(32, 'BN254') + assert verifier.verify(range_proof.RangeProof.from_bytes(proof), comm) + + prover = range_proof.Prover(32, 'BLS12_381') + proof, comm = prover.prove(1337) + + proof = proof.to_bytes() + + verifier = range_proof.Verifier(32, 'BLS12_381') + assert verifier.verify(range_proof.RangeProof.from_bytes(proof, 'BLS12_381'), comm) \ No newline at end of file From 54bcb374aa244a4779e8bce65a1847c9e932a135 Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 3 Nov 2024 20:51:21 +0700 Subject: [PATCH 03/13] Update README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index c14b99c..be55af0 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ zksnake currently only support **Groth16** proving scheme with `BN254` and `BLS1 ## Usage -### Build constraints into QAP +### Build constraints ```python from zksnake.symbolic import Symbol @@ -30,7 +30,7 @@ cs.add_constraint(v1 == x*x) cs.add_constraint(y - 5 - x == v1*x) cs.set_public(y) -qap = cs.compile() +r1cs = cs.compile() ``` Alternatively, you can import the constraints from [Circom](https://github.com/iden3/circom): @@ -39,7 +39,7 @@ Alternatively, you can import the constraints from [Circom](https://github.com/i from zksnake.r1cs import ConstraintSystem cs = ConstraintSystem.from_file("circuit.r1cs", "circuit.sym") -qap = cs.compile() +r1cs = cs.compile() ``` Note that some constraints that are complex or expensive (require off-circuit computation) cannot be imported directly and require you to add "hint" function to pre-define the variable value (see [Example](./examples/example_bitify_circom.py)). @@ -50,7 +50,7 @@ Note that some constraints that are complex or expensive (require off-circuit co from zksnake.groth16 import Setup # one time setup -setup = Setup(qap) +setup = Setup(r1cs) prover_key, verifier_key = setup.generate() ``` @@ -63,7 +63,7 @@ from zksnake.groth16 import Prover, Verifier public_witness, private_witness = cs.solve({'x': 3}, {'y': 35}) # proving -prover = Prover(qap, prover_key) +prover = Prover(r1cs, prover_key) proof = prover.prove(public_witness, private_witness) # verification From e24b611acae41225c7adce51f9c26213761f26ba Mon Sep 17 00:00:00 2001 From: Merricx Date: Mon, 4 Nov 2024 13:00:52 +0700 Subject: [PATCH 04/13] Update hash to curve function --- Cargo.lock | 87 ++++++++++++++++++++++ Cargo.toml | 1 + python/zksnake/bulletproofs/ipa.py | 4 +- python/zksnake/bulletproofs/range_proof.py | 18 ++--- python/zksnake/transcript.py | 28 ++----- src/bls12_381/curve.rs | 25 ++++++- src/bn254/curve.rs | 20 +++-- tests/test_bulletproofs.py | 4 - 8 files changed, 143 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 466535c..9fa7f42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -153,6 +153,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "bincode" version = "1.3.3" @@ -177,6 +183,24 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bn254_hash2curve" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f26def0d8df9bf8661edd7992cfb9f83d65ce0e8f3f0877d63b99fc1644731" +dependencies = [ + "ark-bn254", + "ark-ec", + "ark-ff", + "digest", + "elliptic-curve", + "hex", + "num-bigint", + "num-integer", + "sha2", + "subtle", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -217,6 +241,18 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "rand_core", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -254,6 +290,32 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "ff", + "generic-array", + "group", + "rand_core", + "subtle", + "zeroize", +] + +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "rand_core", + "subtle", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -262,6 +324,18 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core", + "subtle", ] [[package]] @@ -279,6 +353,12 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "indoc" version = "2.0.5" @@ -589,6 +669,12 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "1.0.109" @@ -757,6 +843,7 @@ dependencies = [ "ark-serialize", "ark-std", "bincode", + "bn254_hash2curve", "num-bigint", "pyo3", "rayon", diff --git a/Cargo.toml b/Cargo.toml index e60e85b..df5cadf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ serde = {version="1.0.200", features = ["derive"]} ark-serialize = { version = "0.4", features = ["derive"] } ark-bls12-381 = "0.4.0" sha2 = "0.10.8" +bn254_hash2curve = "0.1.2" [features] parallel = ["ark-ff/parallel", "ark-poly/parallel", "ark-ec/parallel", "ark-std/parallel"] diff --git a/python/zksnake/bulletproofs/ipa.py b/python/zksnake/bulletproofs/ipa.py index f6cb321..793f2ae 100644 --- a/python/zksnake/bulletproofs/ipa.py +++ b/python/zksnake/bulletproofs/ipa.py @@ -119,7 +119,7 @@ def prove(self, a: list, b: list): self.transcript.append(R.to_bytes()) u = hash_to_scalar( - self.transcript.get_challenge(), b'u', self.E.order) + self.transcript.get_challenge(), b'u', self.E.name) u_inv = pow(u, -1, self.E.order) u_list.append(u) @@ -174,7 +174,7 @@ def verify(self, proof: InnerProductProof, commitment): self.transcript.append(proof.R[i].to_bytes()) u = hash_to_scalar( - self.transcript.get_challenge(), b'u', self.E.order) + self.transcript.get_challenge(), b'u', self.E.name) challenges.append(pow(u, 2, self.E.order)) challenges_inv.append(pow(u, -2, self.E.order)) diff --git a/python/zksnake/bulletproofs/range_proof.py b/python/zksnake/bulletproofs/range_proof.py index f9aacd9..9ad56df 100644 --- a/python/zksnake/bulletproofs/range_proof.py +++ b/python/zksnake/bulletproofs/range_proof.py @@ -101,8 +101,8 @@ def prove(self, v: int): self.transcript.append(A.to_bytes()) self.transcript.append(S.to_bytes()) - y = hash_to_scalar(self.transcript.get_challenge(), b'y', self.E.order) - z = hash_to_scalar(self.transcript.get_challenge(), b'z', self.E.order) + y = hash_to_scalar(self.transcript.get_challenge(), b'y', self.E.name) + z = hash_to_scalar(self.transcript.get_challenge(), b'z', self.E.name) l_0 = [] l_1 = [] @@ -146,7 +146,7 @@ def prove(self, v: int): self.transcript.append(T1.to_bytes()) self.transcript.append(T2.to_bytes()) - x = hash_to_scalar(self.transcript.get_challenge(), b'x', self.E.order) + x = hash_to_scalar(self.transcript.get_challenge(), b'x', self.E.name) l_list = [poly(x) for poly in l_vecpoly] r_list = [poly(x) for poly in r_vecpoly] @@ -160,7 +160,7 @@ def prove(self, v: int): self.transcript.append(t_blinding) self.transcript.append(e_blinding) - w = hash_to_scalar(self.transcript.get_challenge(), b'w', self.E.order) + w = hash_to_scalar(self.transcript.get_challenge(), b'w', self.E.name) Q = w * self.B @@ -201,19 +201,19 @@ def verify(self, proof: RangeProof, commitment): self.transcript.append(proof.A.to_bytes()) self.transcript.append(proof.S.to_bytes()) - y = hash_to_scalar(self.transcript.get_challenge(), b'y', self.E.order) - z = hash_to_scalar(self.transcript.get_challenge(), b'z', self.E.order) + y = hash_to_scalar(self.transcript.get_challenge(), b'y', self.E.name) + z = hash_to_scalar(self.transcript.get_challenge(), b'z', self.E.name) self.transcript.append(proof.T1.to_bytes()) self.transcript.append(proof.T2.to_bytes()) - x = hash_to_scalar(self.transcript.get_challenge(), b'x', self.E.order) + x = hash_to_scalar(self.transcript.get_challenge(), b'x', self.E.name) self.transcript.append(proof.t) self.transcript.append(proof.t_blinding) self.transcript.append(proof.e_blinding) - w = hash_to_scalar(self.transcript.get_challenge(), b'w', self.E.order) + w = hash_to_scalar(self.transcript.get_challenge(), b'w', self.E.name) self.transcript.reset() @@ -235,7 +235,7 @@ def verify(self, proof: RangeProof, commitment): self.transcript.append(proof.ipa_proof.R[i].to_bytes()) u = hash_to_scalar( - self.transcript.get_challenge(), b'u', self.E.order) + self.transcript.get_challenge(), b'u', self.E.name) challenges.append(pow(u, 2, self.E.order)) challenges_inv.append(pow(u, -2, self.E.order)) diff --git a/python/zksnake/transcript.py b/python/zksnake/transcript.py index 14ce93d..c06d862 100644 --- a/python/zksnake/transcript.py +++ b/python/zksnake/transcript.py @@ -2,34 +2,22 @@ from .ecc import EllipticCurve -def hash_to_scalar(data: bytes, domain_separation_tag: bytes, modulus: int, alg: str = 'sha256'): - h = hashlib.new(alg) - h.update(domain_separation_tag) - h.update(data) - - return int.from_bytes(h.digest(), 'big') % modulus +def hash_to_scalar(data: bytes, domain_separation_tag: bytes, curve: str = 'BN254', alg: str = 'sha256'): + E = EllipticCurve(curve) + return E.curve.PointG1.hash_to_field(domain_separation_tag, data) def hash_to_curve(data: bytes, domain_separation_tag: bytes, curve: str = 'BN254', size: int = 1, alg: str = 'sha256'): E = EllipticCurve(curve) - h = hashlib.new(alg) - h.update(domain_separation_tag) - h.update(data) - points = [] for _ in range(size): - while True: - digest = h.digest() - h.update(digest) - - try: - point = E.curve.PointG1.hash_to_curve(digest) - points.append(point) - break - except ValueError: - pass + point = E.curve.PointG1.hash_to_curve(domain_separation_tag, data) + points.append(point) + + # TODO: might not be the best practice to chain hash + data = point.to_bytes() return points[0] if size == 1 else points diff --git a/src/bls12_381/curve.rs b/src/bls12_381/curve.rs index 782b076..82a9134 100644 --- a/src/bls12_381/curve.rs +++ b/src/bls12_381/curve.rs @@ -1,4 +1,13 @@ -use ark_bls12_381::{ g1::Config, Bls12_381, Fr, G1Affine, G1Projective, G2Affine, G2Projective }; +use ark_bls12_381::{ + g1::Config, + Bls12_381, + Fq, + Fr, + G1Affine, + G1Projective, + G2Affine, + G2Projective, +}; use ark_ec::{ hashing::{ curve_maps::wb::WBMap, map_to_curve_hasher::MapToCurveBasedHasher, HashToCurve }, pairing::{ Pairing, PairingOutput }, @@ -8,7 +17,7 @@ use ark_ec::{ Group, VariableBaseMSM, }; -use ark_ff::{ field_hashers::DefaultFieldHasher, QuadExtField, Zero }; +use ark_ff::{ field_hashers::{ DefaultFieldHasher, HashToField }, QuadExtField, Zero }; use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize }; use num_bigint::BigUint; use pyo3::{ exceptions::PyValueError, prelude::*, types::PyType }; @@ -132,14 +141,22 @@ impl PointG1 { } #[classmethod] - pub fn hash_to_curve(_cls: &PyType, data: Vec) -> PyResult { + pub fn hash_to_field(_cls: &PyType, dst: Vec, data: Vec) -> BigUint { + use sha2::Sha256; + let hasher = as HashToField>::new(&dst); + let x: Vec = hasher.hash_to_field(&data, 1); + x[0].into() + } + + #[classmethod] + pub fn hash_to_curve(_cls: &PyType, dst: Vec, data: Vec) -> PyResult { use sha2::Sha256; let hasher = MapToCurveBasedHasher::< Projective, DefaultFieldHasher, WBMap > - ::new(&[1]) + ::new(&dst) .unwrap(); let point = hasher.hash(&data).unwrap(); diff --git a/src/bn254/curve.rs b/src/bn254/curve.rs index e3fa2ce..4586075 100644 --- a/src/bn254/curve.rs +++ b/src/bn254/curve.rs @@ -1,10 +1,12 @@ -use ark_bn254::{ Bn254, Fr, G1Affine, G1Projective, G2Affine, G2Projective }; +use ark_bn254::{ Bn254, Fq, Fr, G1Affine, G1Projective, G2Affine, G2Projective }; use ark_ec::{ pairing::{ Pairing, PairingOutput }, AffineRepr, CurveGroup, Group, VariableBaseMSM }; -use ark_ff::{ QuadExtField, Zero }; +use ark_ff::{ field_hashers::{ DefaultFieldHasher, HashToField }, QuadExtField, Zero }; use ark_serialize::{ CanonicalDeserialize, CanonicalSerialize }; +use bn254_hash2curve::hash2g1::HashToG1; use num_bigint::BigUint; use pyo3::{ exceptions::PyValueError, prelude::*, types::PyType }; use rayon::iter::{ IntoParallelIterator, ParallelIterator }; +use sha2::Sha256; #[pyclass] #[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)] @@ -131,9 +133,17 @@ impl PointG1 { } #[classmethod] - pub fn hash_to_curve(_cls: &PyType, data: Vec) -> PyResult { - let x = BigUint::from_bytes_be(&data); - Self::from_x(_cls, x) + pub fn hash_to_field(_cls: &PyType, dst: Vec, data: Vec) -> BigUint { + let hasher = as HashToField>::new(&dst); + let x: Vec = hasher.hash_to_field(&data, 1); + x[0].into() + } + + #[classmethod] + pub fn hash_to_curve(_cls: &PyType, dst: Vec, data: Vec) -> PyResult { + let point = HashToG1(&data, &dst); + + Ok(PointG1 { point: point.into() }) } #[classmethod] diff --git a/tests/test_bulletproofs.py b/tests/test_bulletproofs.py index f0f7a43..75393fe 100644 --- a/tests/test_bulletproofs.py +++ b/tests/test_bulletproofs.py @@ -1,8 +1,4 @@ import pytest - -from zksnake.ecc import EllipticCurve -from zksnake.symbolic import Symbol -from zksnake.r1cs import ConstraintSystem from zksnake.bulletproofs import ipa, range_proof From 8273fd1d684bb18f1cb16260452195ea28578bbc Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 09:45:40 +0700 Subject: [PATCH 05/13] Update transcript type error --- python/zksnake/transcript.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/zksnake/transcript.py b/python/zksnake/transcript.py index c06d862..1d785c8 100644 --- a/python/zksnake/transcript.py +++ b/python/zksnake/transcript.py @@ -2,12 +2,12 @@ from .ecc import EllipticCurve -def hash_to_scalar(data: bytes, domain_separation_tag: bytes, curve: str = 'BN254', alg: str = 'sha256'): +def hash_to_scalar(data: bytes, domain_separation_tag: bytes, curve: str = 'BN254'): E = EllipticCurve(curve) return E.curve.PointG1.hash_to_field(domain_separation_tag, data) -def hash_to_curve(data: bytes, domain_separation_tag: bytes, curve: str = 'BN254', size: int = 1, alg: str = 'sha256'): +def hash_to_curve(data: bytes, domain_separation_tag: bytes, curve: str = 'BN254', size: int = 1): E = EllipticCurve(curve) @@ -43,6 +43,8 @@ def append(self, data): self.hasher.update(data) elif data and isinstance(data, list) and isinstance(data[0], int): self.hasher.update(bytes(data)) + else: + raise TypeError(f"Type of {type(data)} is not supported as transcript") def get_challenge(self): digest = self.hasher.digest() From 1533529e36d0f4bd32cf98b53272d72f3edba61c Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 10:00:43 +0700 Subject: [PATCH 06/13] Update unittest to use Ubuntu 20.04 --- .github/workflows/CI.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 65c9a59..6e32f32 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -59,10 +59,10 @@ jobs: pytest - name: pytest if: ${{ !startsWith(matrix.platform.target, 'x86') && matrix.platform.target != 'ppc64' }} - uses: uraimo/run-on-arch-action@v2.5.0 + uses: uraimo/run-on-arch-action@v2 with: arch: ${{ matrix.platform.target }} - distro: ubuntu22.04 + distro: ubuntu20.04 githubToken: ${{ github.token }} install: | apt-get update From 252b437ca77f2492bfa2106ae4b2e577bf26ceb3 Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 10:45:58 +0700 Subject: [PATCH 07/13] Add manual modification in aarch64 CI --- .github/workflows/CI.yml | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6e32f32..446e393 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,7 +35,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.x" - name: Build wheels uses: PyO3/maturin-action@v1 with: @@ -64,10 +64,17 @@ jobs: arch: ${{ matrix.platform.target }} distro: ubuntu20.04 githubToken: ${{ github.token }} + # Copied from https://github.com/codecov/codecov-rs/blob/main/.github/workflows/publish.yml install: | apt-get update - apt-get install -y --no-install-recommends python3 python3-pip - pip3 install -U pip pytest pylint + apt-get install -y gnupg ca-certificates + echo "deb https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy main" >> /etc/apt/sources.list.d/deadsnakes.list + apt-key adv --keyserver keyserver.ubuntu.com --recv-keys F23C5A6CF475977595C89F51BA6932366A755776 + apt-get update + apt-get install -y --no-install-recommends python3.12 python3.12-venv python3-pip + python3.12 -m venv /venv + source /venv/bin/activate + pip install -U pip pytest pylint run: | set -e pip3 install zksnake --find-links dist --force-reinstall @@ -87,7 +94,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.x" architecture: ${{ matrix.platform.target }} - name: Build wheels uses: PyO3/maturin-action@v1 @@ -123,7 +130,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.x" - name: Build wheels uses: PyO3/maturin-action@v1 with: From 181b3e97f4c9239c80be531ba41c08d40da4e967 Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 10:53:36 +0700 Subject: [PATCH 08/13] Revert to ubuntu 22 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 446e393..26f4dec 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -62,7 +62,7 @@ jobs: uses: uraimo/run-on-arch-action@v2 with: arch: ${{ matrix.platform.target }} - distro: ubuntu20.04 + distro: ubuntu22.04 githubToken: ${{ github.token }} # Copied from https://github.com/codecov/codecov-rs/blob/main/.github/workflows/publish.yml install: | From 153205fcfacd20b8df8469c17f5d4e75e4366d6a Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 11:01:57 +0700 Subject: [PATCH 09/13] Remove pylint in aarch64 --- .github/workflows/CI.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 26f4dec..88e74c5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -74,11 +74,10 @@ jobs: apt-get install -y --no-install-recommends python3.12 python3.12-venv python3-pip python3.12 -m venv /venv source /venv/bin/activate - pip install -U pip pytest pylint + pip install -U pip pytest run: | set -e pip3 install zksnake --find-links dist --force-reinstall - pylint --disable=R,C,fixme,import-error python pytest windows: From 9a3bfcad8bcce06bd78562cd685937783591c896 Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 11:41:44 +0700 Subject: [PATCH 10/13] Change pip syntax --- .github/workflows/CI.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 88e74c5..172d204 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -74,10 +74,11 @@ jobs: apt-get install -y --no-install-recommends python3.12 python3.12-venv python3-pip python3.12 -m venv /venv source /venv/bin/activate - pip install -U pip pytest + pip3 install pytest pylint run: | set -e pip3 install zksnake --find-links dist --force-reinstall + pylint --disable=R,C,fixme,import-error python pytest windows: From 2ea57c53f3c8015bc6f623317192f13fbdf6eda8 Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 12:22:16 +0700 Subject: [PATCH 11/13] Fix python syntax --- .github/workflows/CI.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 172d204..6439176 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -78,8 +78,8 @@ jobs: run: | set -e pip3 install zksnake --find-links dist --force-reinstall - pylint --disable=R,C,fixme,import-error python - pytest + python3 -m pylint --disable=R,C,fixme,import-error python + python3 -m pytest windows: runs-on: ${{ matrix.platform.runner }} From 5c7268f788488350327fbbfd9dc4565b6f8f4f0e Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 12:32:08 +0700 Subject: [PATCH 12/13] Add missing venv --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6439176..0887a01 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -77,6 +77,7 @@ jobs: pip3 install pytest pylint run: | set -e + source /venv/bin/activate pip3 install zksnake --find-links dist --force-reinstall python3 -m pylint --disable=R,C,fixme,import-error python python3 -m pytest From baaa848a50d73efa7bb41e05aceae8b2524285f4 Mon Sep 17 00:00:00 2001 From: Merricx Date: Sun, 17 Nov 2024 12:49:42 +0700 Subject: [PATCH 13/13] Reduce n_bit in test to avoid recursion limit --- tests/test_r1cs_qap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_r1cs_qap.py b/tests/test_r1cs_qap.py index bc4dbdc..ebcf4f6 100644 --- a/tests/test_r1cs_qap.py +++ b/tests/test_r1cs_qap.py @@ -204,7 +204,7 @@ def main(self, *args): f = lambda x, i: (x >> i) & 1 self.add_hint(f, b, (inp, i)) - n_bit = 256 + n_bit = 128 inp = Symbol("i") bits = [] out = SymbolArray("bit", n_bit)