From d8c6b6dcc807eb5cf8dc5583eccd3656b49b3985 Mon Sep 17 00:00:00 2001
From: Youssef El Housni
Date: Fri, 4 Feb 2022 11:54:15 +0100
Subject: [PATCH 01/37] perf(std/tEd): first bit in ScalarMul handled separately

---
 std/algebra/twistededwards/point.go | 12 ++++++++++--
 std/algebra/twistededwards/point_test.go | 19 +++++++++++++++++++
 std/signature/eddsa/eddsa_test.go | 8 ++++++++
 3 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go
index 7faf85faf4..d6ae374da7 100644
--- a/std/algebra/twistededwards/point.go
+++ b/std/algebra/twistededwards/point.go
@@ -133,7 +133,11 @@ func (p *Point) ScalarMulNonFixedBase(api frontend.API, p1 *Point, scalar fronte
 		1,
 	}
 
-	for i := len(b) - 1; i >= 0; i-- {
+	n := len(b) - 1
+	res.X = api.Select(b[n], p1.X, res.X)
+	res.Y = api.Select(b[n], p1.Y, res.Y)
+
+	for i := len(b) - 2; i >= 0; i-- {
 		res.Double(api, &res, curve)
 		tmp := Point{}
 		tmp.AddGeneric(api, &res, p1, curve)
@@ -161,7 +165,11 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr
 		1,
 	}
 
-	for i := len(b) - 1; i >= 0; i-- {
+	n := len(b) - 1
+	res.X = api.Select(b[n], x, res.X)
+	res.Y = api.Select(b[n], y, res.Y)
+
+	for i := len(b) - 2; i >= 0; i-- {
 		res.Double(api, &res, curve)
 		tmp := Point{}
 		tmp.AddFixedPoint(api, &res, x, y, curve)
diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go
index b8f9d1959b..b46a025104 100644
--- a/std/algebra/twistededwards/point_test.go
+++ b/std/algebra/twistededwards/point_test.go
@@ -551,3 +551,22 @@ func TestNeg(t *testing.T) {
 
 	assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254))
 }
+
+// Bench
+func BenchmarkDouble(b *testing.B) {
+	var c double
+	ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c)
+	b.Log("groth16", ccsBench.GetNbConstraints())
+}
+
+func BenchmarkAddGeneric(b *testing.B) {
+	var c addGeneric
+	ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c)
+	b.Log("groth16", ccsBench.GetNbConstraints())
+}
+
+func BenchmarkAddFixedPoint(b *testing.B) {
+	var c add
+	ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c)
+	b.Log("groth16", ccsBench.GetNbConstraints())
+}
diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go
index 84636c0a4b..ecacd505e4 100644
--- a/std/signature/eddsa/eddsa_test.go
+++ b/std/signature/eddsa/eddsa_test.go
@@ -36,6 +36,7 @@ import (
 	eddsabw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards/eddsa"
 	"github.com/consensys/gnark-crypto/hash"
 	"github.com/consensys/gnark-crypto/signature"
+	"github.com/consensys/gnark/backend"
 	"github.com/consensys/gnark/frontend"
 	"github.com/consensys/gnark/std/algebra/twistededwards"
 	"github.com/consensys/gnark/test"
@@ -252,3 +253,10 @@ func TestEddsa(t *testing.T) {
 	}
 
 }
+
+// Bench
+func BenchmarkEdDSA(b *testing.B) {
+	var c eddsaCircuit
+	ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c)
+	b.Log("groth16", ccsBench.GetNbConstraints())
+}

From 77ba09bd94c40d097f81e1bdac14b7d0eb3b383c Mon Sep 17 00:00:00 2001
From: Youssef El Housni
Date: Fri, 4 Feb 2022 12:34:15 +0100
Subject: [PATCH 02/37] perf(std/tEd): rearrange Double --> less constraints

---
 std/algebra/twistededwards/point.go | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go
index d6ae374da7..722aeb8298 100644
--- a/std/algebra/twistededwards/point.go
+++ b/std/algebra/twistededwards/point.go
@@ -103,14 +103,12 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point {
 	u := api.Mul(p1.X, p1.Y)
 	v := api.Mul(p1.X, p1.X)
 	w := api.Mul(p1.Y, p1.Y)
-	z := api.Mul(v, w)
 
 	n1 := api.Mul(2, u)
 	av := api.Mul(v, &curve.A)
 	n2 := api.Sub(w, av)
-	d := api.Mul(z, curve.D)
-	d1 := api.Add(1, d)
-	d2 := api.Sub(1, d)
+	d1 := api.Add(w, av)
+	d2 := api.Sub(2, d1)
 
 	p.X = api.DivUnchecked(n1, d1)
 	p.Y = api.DivUnchecked(n2, d2)

From a3c37ae0384b4de02276a6b2cf4085d7de9f6a0d Mon Sep 17 00:00:00 2001
From: Youssef El Housni
Date: Fri, 4 Feb 2022 13:16:11 +0100
Subject: [PATCH 03/37] perf(std/EdDSA): rearrange eddsa verify (-1 addtion, -1 MustBeOnCurve)

---
 std/algebra/twistededwards/point_test.go | 6 ++++++
 std/signature/eddsa/eddsa.go | 17 ++++++++---------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go
index b46a025104..f36305abad 100644
--- a/std/algebra/twistededwards/point_test.go
+++ b/std/algebra/twistededwards/point_test.go
@@ -570,3 +570,9 @@ func BenchmarkAddFixedPoint(b *testing.B) {
 	ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c)
 	b.Log("groth16", ccsBench.GetNbConstraints())
 }
+
+func BenchmarkMustBeOnCurve(b *testing.B) {
+	var c mustBeOnCurve
+	ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c)
+	b.Log("groth16", ccsBench.GetNbConstraints())
+}
diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go
index c46b0273c3..7cffff0b8b 100644
--- a/std/signature/eddsa/eddsa.go
+++ b/std/signature/eddsa/eddsa.go
@@ -58,7 +58,6 @@ func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey Publi
 		return err
 	}
 	hash.Write(data...)
-	//hramConstant := hash.Sum(data...)
 	hramConstant := hash.Sum()
 
 	// lhs = [S]G
@@ -71,24 +70,24 @@ func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey Publi
 	rhs := twistededwards.Point{}
 	rhs.ScalarMulNonFixedBase(api, &pubKey.A, hramConstant, pubKey.Curve).
 		AddGeneric(api, &rhs, &sig.R, pubKey.Curve)
-	rhs.MustBeOnCurve(api, pubKey.Curve)
+	// rhs.MustBeOnCurve(api, pubKey.Curve)
 
-	// lhs-rhs
-	rhs.Neg(api, &rhs).AddGeneric(api, &lhs, &rhs, pubKey.Curve)
-
-	// [cofactor](lhs-rhs)
+	// [cofactor]*lhs and [cofactor]*rhs
 	switch cofactor {
 	case 4:
 		rhs.Double(api, &rhs, pubKey.Curve).
 			Double(api, &rhs, pubKey.Curve)
+		lhs.Double(api, &lhs, pubKey.Curve).
+			Double(api, &lhs, pubKey.Curve)
 	case 8:
 		rhs.Double(api, &rhs, pubKey.Curve).
 			Double(api, &rhs, pubKey.Curve).Double(api, &rhs, pubKey.Curve)
+		lhs.Double(api, &lhs, pubKey.Curve).
+ Double(api, &lhs, pubKey.Curve).Double(api, &lhs, pubKey.Curve) } - //rhs.MustBeOnCurve(api, pubKey.Curve) - api.AssertIsEqual(rhs.X, 0) - api.AssertIsEqual(rhs.Y, 1) + api.AssertIsEqual(rhs.X, lhs.X) + api.AssertIsEqual(rhs.Y, lhs.Y) return nil } From 936bd06c0d8be480ffea83caa691ed4fe7d9a635 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Fri, 4 Feb 2022 14:38:38 +0100 Subject: [PATCH 04/37] feat: groth16 prover adapted to new fft OK --- go.mod | 2 +- go.sum | 4 ++ .../backend/bls12-377/groth16/marshal_test.go | 2 +- internal/backend/bls12-377/groth16/prove.go | 24 +++++----- internal/backend/bls12-377/groth16/setup.go | 4 +- .../backend/bls12-377/plonk/marshal_test.go | 4 +- internal/backend/bls12-377/plonk/prove.go | 46 +++++------------- internal/backend/bls12-377/plonk/setup.go | 30 ++++++------ .../backend/bls12-381/groth16/marshal_test.go | 2 +- internal/backend/bls12-381/groth16/prove.go | 24 +++++----- internal/backend/bls12-381/groth16/setup.go | 4 +- .../backend/bls12-381/plonk/marshal_test.go | 4 +- internal/backend/bls12-381/plonk/prove.go | 46 +++++------------- internal/backend/bls12-381/plonk/setup.go | 30 ++++++------ .../backend/bls24-315/groth16/marshal_test.go | 2 +- internal/backend/bls24-315/groth16/prove.go | 24 +++++----- internal/backend/bls24-315/groth16/setup.go | 4 +- .../backend/bls24-315/plonk/marshal_test.go | 4 +- internal/backend/bls24-315/plonk/prove.go | 46 +++++------------- internal/backend/bls24-315/plonk/setup.go | 30 ++++++------ .../backend/bn254/groth16/marshal_test.go | 2 +- internal/backend/bn254/groth16/prove.go | 24 +++++----- internal/backend/bn254/groth16/setup.go | 4 +- internal/backend/bn254/plonk/marshal_test.go | 4 +- internal/backend/bn254/plonk/prove.go | 46 +++++------------- internal/backend/bn254/plonk/setup.go | 30 ++++++------ .../backend/bw6-633/groth16/marshal_test.go | 2 +- internal/backend/bw6-633/groth16/prove.go | 24 +++++----- internal/backend/bw6-633/groth16/setup.go | 4 +- .../backend/bw6-633/plonk/marshal_test.go | 4 +- internal/backend/bw6-633/plonk/prove.go | 46 +++++------------- internal/backend/bw6-633/plonk/setup.go | 30 ++++++------ .../backend/bw6-761/groth16/marshal_test.go | 2 +- internal/backend/bw6-761/groth16/prove.go | 24 +++++----- internal/backend/bw6-761/groth16/setup.go | 4 +- .../backend/bw6-761/plonk/marshal_test.go | 4 +- internal/backend/bw6-761/plonk/prove.go | 46 +++++------------- internal/backend/bw6-761/plonk/setup.go | 30 ++++++------ .../zkpschemes/groth16/groth16.prove.go.tmpl | 24 +++++----- .../zkpschemes/groth16/groth16.setup.go.tmpl | 4 +- .../groth16/tests/groth16.marshal.go.tmpl | 2 +- .../zkpschemes/plonk/plonk.prove.go.tmpl | 47 ++++++------------- .../zkpschemes/plonk/plonk.setup.go.tmpl | 30 ++++++------ .../zkpschemes/plonk/tests/marshal.go.tmpl | 4 +- 44 files changed, 321 insertions(+), 456 deletions(-) diff --git a/go.mod b/go.mod index 8ae200dc61..56ce699366 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a - github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 + github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f github.com/fxamacker/cbor/v2 v2.2.0 github.com/leanovate/gopter v0.2.9 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 6b59c42b75..f6c42236c3 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a h1:AEpwbXTjBGKo github.com/consensys/bavard 
v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 h1:SPKwScbSTdl2p+QvJulHumGa2+5FO6RPh857TCPxda0= github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= +github.com/consensys/gnark-crypto v0.6.1-0.20220203135532-a5667210247a h1:Jfr3vYmkw4xxWvNAnavhGiN0pVyhmpPer5sq1zFJFAk= +github.com/consensys/gnark-crypto v0.6.1-0.20220203135532-a5667210247a/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= +github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f h1:55DRDYCFD64OIJh/Yz1Bch9Va14lwKgA/xk0n8JUIjE= +github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/backend/bls12-377/groth16/marshal_test.go b/internal/backend/bls12-377/groth16/marshal_test.go index 0d00cba75e..901d6f885f 100644 --- a/internal/backend/bls12-377/groth16/marshal_test.go +++ b/internal/backend/bls12-377/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bls12-377/groth16/prove.go b/internal/backend/bls12-377/groth16/prove.go index 16a1b7ac32..634c4b58e7 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/internal/backend/bls12-377/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). 
- Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go index 8db0c6ac8c..21fe78713c 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/internal/backend/bls12-377/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bls12-377/plonk/marshal_test.go b/internal/backend/bls12-377/plonk/marshal_test.go index 04933096c5..c7179f96cc 100644 --- a/internal/backend/bls12-377/plonk/marshal_test.go +++ b/internal/backend/bls12-377/plonk/marshal_test.go @@ -48,8 +48,8 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) + pk.DomainNum = *fft.NewDomain(42) + pk.DomainH = *fft.NewDomain(4 * 42) pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 7c8cb49ce7..667deea130 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -166,7 +166,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) + pk.DomainNum.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -401,7 +401,7 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +409,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -552,7 +552,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.DomainNum.FFTInverse(z, fft.DIF) fft.BitReverse(z) return blindPoly(z, 
pk.DomainNum.Cardinality, 2) @@ -608,16 +608,17 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) +// evalIDCosets id, uid, u**2id on (Z/4mZ) func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { id = make([]fr.Element, pk.DomainH.Cardinality) + // TODO doing an expo per chunk is useless utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var acc fr.Element acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) + id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) acc.Mul(&acc, &pk.DomainH.Generator) } }) @@ -707,18 +708,10 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia // // Puts the result in res of size n. // Warning: result is in bit reversed order, we do a bit reverse operation only once in computeH +// TODO remove this function func evaluateHDomain(poly []fr.Element, domainH *fft.Domain) []fr.Element { - res := make([]fr.Element, domainH.Cardinality) - - // we copy poly in res and scale by coset here - // to avoid FFT scaling on domainH.Cardinality (res is very sparse) - utils.Parallelize(len(poly), func(start, end int) { - for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainH.CosetTable[0][i]) - } - }, runtime.NumCPU()/2) - domainH.FFT(res, fft.DIF, 0) + domainH.FFT(res, fft.DIF, true) return res } @@ -742,7 +735,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom var one fr.Element one.SetOne() uu.Set(&pk.DomainH.Generator) - u[0].Set(&pk.DomainH.FinerGenerator) + u[0].Set(&pk.DomainH.FrMultiplicativeGen) u[1].Mul(&u[0], &uu) u[2].Mul(&u[1], &uu) u[3].Mul(&u[2], &uu) @@ -768,15 +761,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computes L1 (canonical form) startsAtOne := make(polynomial.Polynomial, pk.DomainH.Cardinality) - utils.Parallelize(int(pk.DomainNum.Cardinality), func(start, end int) { - for i := start; i < end; i++ { - startsAtOne[i].Mul(&pk.DomainNum.CardinalityInv, &pk.DomainH.CosetTable[0][i]) - } - }) - - // evaluates L1 on the odd cosets of (Z/8mZ)/(Z/mZ) - // / ! \ note that we scaled by the coset in the previous loop, hence we pass 0 as coset here. - pk.DomainH.FFT(startsAtOne, fft.DIF, 0) + pk.DomainH.FFT(startsAtOne, fft.DIF, true) // evaluate qlL+qrR+qmL.R+qoO+k + alpha.(zu*g1*g2*g3*l-z*f1*f2*f3*l) + alpha**2*L1(X)(Z(X)-1) // on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -802,12 +787,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // put h in canonical form. h is of degree 3*(n+1)+2. 
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.DomainH.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding h1 := h[:pk.DomainNum.Cardinality+2] diff --git a/internal/backend/bls12-377/plonk/setup.go b/internal/backend/bls12-377/plonk/setup.go index 774e1be88a..5f5652e3b2 100644 --- a/internal/backend/bls12-377/plonk/setup.go +++ b/internal/backend/bls12-377/plonk/setup.go @@ -96,15 +96,15 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.DomainNum = *fft.NewDomain(sizeSystem) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(4 * sizeSystem) } vk.Size = pk.DomainNum.Cardinality @@ -113,8 +113,8 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) vk.NbPublicVariables = uint64(spr.NbPublicVariables) // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) + vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) + vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) if err := pk.InitKZG(srs); err != nil { return nil, nil, err @@ -148,11 +148,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) + pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -274,8 +274,8 @@ func computeLDE(pk *ProvingKey) { // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] sID := make([]fr.Element, 3*nbElmt) sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) + sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) + sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) for i := 1; i < nbElmt; i++ { sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 @@ -300,9 +300,9 @@ func computeLDE(pk *ProvingKey) { copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) diff --git a/internal/backend/bls12-381/groth16/marshal_test.go b/internal/backend/bls12-381/groth16/marshal_test.go index 
38c1e8b038..990e3ee6b1 100644 --- a/internal/backend/bls12-381/groth16/marshal_test.go +++ b/internal/backend/bls12-381/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bls12-381/groth16/prove.go b/internal/backend/bls12-381/groth16/prove.go index fe52aafb2d..a8e7767d4e 100644 --- a/internal/backend/bls12-381/groth16/prove.go +++ b/internal/backend/bls12-381/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go index 1daf2c2f69..d481f8ad0d 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/internal/backend/bls12-381/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bls12-381/plonk/marshal_test.go b/internal/backend/bls12-381/plonk/marshal_test.go index 9076e2280c..3671c52294 100644 --- a/internal/backend/bls12-381/plonk/marshal_test.go +++ b/internal/backend/bls12-381/plonk/marshal_test.go @@ -48,8 +48,8 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) + pk.DomainNum = *fft.NewDomain(42) + pk.DomainH = *fft.NewDomain(4 * 42) pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) diff --git a/internal/backend/bls12-381/plonk/prove.go 
b/internal/backend/bls12-381/plonk/prove.go index 5f9dadb7bb..a26d234bfe 100644 --- a/internal/backend/bls12-381/plonk/prove.go +++ b/internal/backend/bls12-381/plonk/prove.go @@ -166,7 +166,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) + pk.DomainNum.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -401,7 +401,7 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +409,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -552,7 +552,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.DomainNum.FFTInverse(z, fft.DIF) fft.BitReverse(z) return blindPoly(z, pk.DomainNum.Cardinality, 2) @@ -608,16 +608,17 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) +// evalIDCosets id, uid, u**2id on (Z/4mZ) func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { id = make([]fr.Element, pk.DomainH.Cardinality) + // TODO doing an expo per chunk is useless utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var acc fr.Element acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) + id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) acc.Mul(&acc, &pk.DomainH.Generator) } }) @@ -707,18 +708,10 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia // // Puts the result in res of size n. 
// Warning: result is in bit reversed order, we do a bit reverse operation only once in computeH +// TODO remove this function func evaluateHDomain(poly []fr.Element, domainH *fft.Domain) []fr.Element { - res := make([]fr.Element, domainH.Cardinality) - - // we copy poly in res and scale by coset here - // to avoid FFT scaling on domainH.Cardinality (res is very sparse) - utils.Parallelize(len(poly), func(start, end int) { - for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainH.CosetTable[0][i]) - } - }, runtime.NumCPU()/2) - domainH.FFT(res, fft.DIF, 0) + domainH.FFT(res, fft.DIF, true) return res } @@ -742,7 +735,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom var one fr.Element one.SetOne() uu.Set(&pk.DomainH.Generator) - u[0].Set(&pk.DomainH.FinerGenerator) + u[0].Set(&pk.DomainH.FrMultiplicativeGen) u[1].Mul(&u[0], &uu) u[2].Mul(&u[1], &uu) u[3].Mul(&u[2], &uu) @@ -768,15 +761,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computes L1 (canonical form) startsAtOne := make(polynomial.Polynomial, pk.DomainH.Cardinality) - utils.Parallelize(int(pk.DomainNum.Cardinality), func(start, end int) { - for i := start; i < end; i++ { - startsAtOne[i].Mul(&pk.DomainNum.CardinalityInv, &pk.DomainH.CosetTable[0][i]) - } - }) - - // evaluates L1 on the odd cosets of (Z/8mZ)/(Z/mZ) - // / ! \ note that we scaled by the coset in the previous loop, hence we pass 0 as coset here. - pk.DomainH.FFT(startsAtOne, fft.DIF, 0) + pk.DomainH.FFT(startsAtOne, fft.DIF, true) // evaluate qlL+qrR+qmL.R+qoO+k + alpha.(zu*g1*g2*g3*l-z*f1*f2*f3*l) + alpha**2*L1(X)(Z(X)-1) // on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -802,12 +787,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.DomainH.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding h1 := h[:pk.DomainNum.Cardinality+2] diff --git a/internal/backend/bls12-381/plonk/setup.go b/internal/backend/bls12-381/plonk/setup.go index 27dfd6868d..10a59218fe 100644 --- a/internal/backend/bls12-381/plonk/setup.go +++ b/internal/backend/bls12-381/plonk/setup.go @@ -96,15 +96,15 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.DomainNum = *fft.NewDomain(sizeSystem) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. 
if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(4 * sizeSystem) } vk.Size = pk.DomainNum.Cardinality @@ -113,8 +113,8 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) vk.NbPublicVariables = uint64(spr.NbPublicVariables) // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) + vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) + vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) if err := pk.InitKZG(srs); err != nil { return nil, nil, err @@ -148,11 +148,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) + pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -274,8 +274,8 @@ func computeLDE(pk *ProvingKey) { // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] sID := make([]fr.Element, 3*nbElmt) sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) + sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) + sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) for i := 1; i < nbElmt; i++ { sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 @@ -300,9 +300,9 @@ func computeLDE(pk *ProvingKey) { copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) diff --git a/internal/backend/bls24-315/groth16/marshal_test.go b/internal/backend/bls24-315/groth16/marshal_test.go index aa2ceda799..06bf1ec9de 100644 --- a/internal/backend/bls24-315/groth16/marshal_test.go +++ b/internal/backend/bls24-315/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bls24-315/groth16/prove.go b/internal/backend/bls24-315/groth16/prove.go index 606f9e07f9..fbc67eb021 100644 --- a/internal/backend/bls24-315/groth16/prove.go +++ b/internal/backend/bls24-315/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) 
n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go index e34b8d752b..596e7786fc 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/internal/backend/bls24-315/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bls24-315/plonk/marshal_test.go b/internal/backend/bls24-315/plonk/marshal_test.go index c24928ccf8..eb449b3d2b 100644 --- a/internal/backend/bls24-315/plonk/marshal_test.go +++ b/internal/backend/bls24-315/plonk/marshal_test.go @@ -48,8 +48,8 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) + pk.DomainNum = *fft.NewDomain(42) + pk.DomainH = *fft.NewDomain(4 * 42) pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go index 6f951143ba..15dfe7ea24 100644 --- a/internal/backend/bls24-315/plonk/prove.go +++ b/internal/backend/bls24-315/plonk/prove.go @@ -166,7 +166,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) + pk.DomainNum.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -401,7 +401,7 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc 
go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +409,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -552,7 +552,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.DomainNum.FFTInverse(z, fft.DIF) fft.BitReverse(z) return blindPoly(z, pk.DomainNum.Cardinality, 2) @@ -608,16 +608,17 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) +// evalIDCosets id, uid, u**2id on (Z/4mZ) func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { id = make([]fr.Element, pk.DomainH.Cardinality) + // TODO doing an expo per chunk is useless utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var acc fr.Element acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) + id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) acc.Mul(&acc, &pk.DomainH.Generator) } }) @@ -707,18 +708,10 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia // // Puts the result in res of size n. // Warning: result is in bit reversed order, we do a bit reverse operation only once in computeH +// TODO remove this function func evaluateHDomain(poly []fr.Element, domainH *fft.Domain) []fr.Element { - res := make([]fr.Element, domainH.Cardinality) - - // we copy poly in res and scale by coset here - // to avoid FFT scaling on domainH.Cardinality (res is very sparse) - utils.Parallelize(len(poly), func(start, end int) { - for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainH.CosetTable[0][i]) - } - }, runtime.NumCPU()/2) - domainH.FFT(res, fft.DIF, 0) + domainH.FFT(res, fft.DIF, true) return res } @@ -742,7 +735,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom var one fr.Element one.SetOne() uu.Set(&pk.DomainH.Generator) - u[0].Set(&pk.DomainH.FinerGenerator) + u[0].Set(&pk.DomainH.FrMultiplicativeGen) u[1].Mul(&u[0], &uu) u[2].Mul(&u[1], &uu) u[3].Mul(&u[2], &uu) @@ -768,15 +761,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computes L1 (canonical form) startsAtOne := make(polynomial.Polynomial, pk.DomainH.Cardinality) - utils.Parallelize(int(pk.DomainNum.Cardinality), func(start, end int) { - for i := start; i < end; i++ { - startsAtOne[i].Mul(&pk.DomainNum.CardinalityInv, &pk.DomainH.CosetTable[0][i]) - } - }) - - // evaluates L1 on the odd cosets of (Z/8mZ)/(Z/mZ) - // / ! \ note that we scaled by the coset in the previous loop, hence we pass 0 as coset here. 
- pk.DomainH.FFT(startsAtOne, fft.DIF, 0) + pk.DomainH.FFT(startsAtOne, fft.DIF, true) // evaluate qlL+qrR+qmL.R+qoO+k + alpha.(zu*g1*g2*g3*l-z*f1*f2*f3*l) + alpha**2*L1(X)(Z(X)-1) // on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -802,12 +787,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.DomainH.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding h1 := h[:pk.DomainNum.Cardinality+2] diff --git a/internal/backend/bls24-315/plonk/setup.go b/internal/backend/bls24-315/plonk/setup.go index 8327805f3a..184c3a5677 100644 --- a/internal/backend/bls24-315/plonk/setup.go +++ b/internal/backend/bls24-315/plonk/setup.go @@ -96,15 +96,15 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.DomainNum = *fft.NewDomain(sizeSystem) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(4 * sizeSystem) } vk.Size = pk.DomainNum.Cardinality @@ -113,8 +113,8 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) vk.NbPublicVariables = uint64(spr.NbPublicVariables) // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) + vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) + vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) if err := pk.InitKZG(srs); err != nil { return nil, nil, err @@ -148,11 +148,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) + pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -274,8 +274,8 @@ func computeLDE(pk *ProvingKey) { // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] sID := make([]fr.Element, 3*nbElmt) sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) + sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) + sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) for i := 1; i < nbElmt; i++ { sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 @@ -300,9 +300,9 @@ func computeLDE(pk *ProvingKey) { copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - 
pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) diff --git a/internal/backend/bn254/groth16/marshal_test.go b/internal/backend/bn254/groth16/marshal_test.go index d8f8c38496..2db9e45415 100644 --- a/internal/backend/bn254/groth16/marshal_test.go +++ b/internal/backend/bn254/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bn254/groth16/prove.go b/internal/backend/bn254/groth16/prove.go index 87bb23c383..f4a929f426 100644 --- a/internal/backend/bn254/groth16/prove.go +++ b/internal/backend/bn254/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). 
- Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go index 95cccddf80..334461b25b 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/internal/backend/bn254/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bn254/plonk/marshal_test.go b/internal/backend/bn254/plonk/marshal_test.go index f17f8ca756..9a7a714d3d 100644 --- a/internal/backend/bn254/plonk/marshal_test.go +++ b/internal/backend/bn254/plonk/marshal_test.go @@ -48,8 +48,8 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) + pk.DomainNum = *fft.NewDomain(42) + pk.DomainH = *fft.NewDomain(4 * 42) pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 25b24887ed..4b8a62637c 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -166,7 +166,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) + pk.DomainNum.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -401,7 +401,7 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +409,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -552,7 +552,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.DomainNum.FFTInverse(z, fft.DIF) fft.BitReverse(z) return blindPoly(z, pk.DomainNum.Cardinality, 2) @@ -608,16 +608,17 
@@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) +// evalIDCosets id, uid, u**2id on (Z/4mZ) func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { id = make([]fr.Element, pk.DomainH.Cardinality) + // TODO doing an expo per chunk is useless utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var acc fr.Element acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) + id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) acc.Mul(&acc, &pk.DomainH.Generator) } }) @@ -707,18 +708,10 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia // // Puts the result in res of size n. // Warning: result is in bit reversed order, we do a bit reverse operation only once in computeH +// TODO remove this function func evaluateHDomain(poly []fr.Element, domainH *fft.Domain) []fr.Element { - res := make([]fr.Element, domainH.Cardinality) - - // we copy poly in res and scale by coset here - // to avoid FFT scaling on domainH.Cardinality (res is very sparse) - utils.Parallelize(len(poly), func(start, end int) { - for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainH.CosetTable[0][i]) - } - }, runtime.NumCPU()/2) - domainH.FFT(res, fft.DIF, 0) + domainH.FFT(res, fft.DIF, true) return res } @@ -742,7 +735,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom var one fr.Element one.SetOne() uu.Set(&pk.DomainH.Generator) - u[0].Set(&pk.DomainH.FinerGenerator) + u[0].Set(&pk.DomainH.FrMultiplicativeGen) u[1].Mul(&u[0], &uu) u[2].Mul(&u[1], &uu) u[3].Mul(&u[2], &uu) @@ -768,15 +761,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computes L1 (canonical form) startsAtOne := make(polynomial.Polynomial, pk.DomainH.Cardinality) - utils.Parallelize(int(pk.DomainNum.Cardinality), func(start, end int) { - for i := start; i < end; i++ { - startsAtOne[i].Mul(&pk.DomainNum.CardinalityInv, &pk.DomainH.CosetTable[0][i]) - } - }) - - // evaluates L1 on the odd cosets of (Z/8mZ)/(Z/mZ) - // / ! \ note that we scaled by the coset in the previous loop, hence we pass 0 as coset here. - pk.DomainH.FFT(startsAtOne, fft.DIF, 0) + pk.DomainH.FFT(startsAtOne, fft.DIF, true) // evaluate qlL+qrR+qmL.R+qoO+k + alpha.(zu*g1*g2*g3*l-z*f1*f2*f3*l) + alpha**2*L1(X)(Z(X)-1) // on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -802,12 +787,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // put h in canonical form. h is of degree 3*(n+1)+2. 
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.DomainH.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding h1 := h[:pk.DomainNum.Cardinality+2] diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go index b7b8d41869..012df10020 100644 --- a/internal/backend/bn254/plonk/setup.go +++ b/internal/backend/bn254/plonk/setup.go @@ -96,15 +96,15 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.DomainNum = *fft.NewDomain(sizeSystem) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(4 * sizeSystem) } vk.Size = pk.DomainNum.Cardinality @@ -113,8 +113,8 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) vk.NbPublicVariables = uint64(spr.NbPublicVariables) // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) + vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) + vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) if err := pk.InitKZG(srs); err != nil { return nil, nil, err @@ -148,11 +148,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) + pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -274,8 +274,8 @@ func computeLDE(pk *ProvingKey) { // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] sID := make([]fr.Element, 3*nbElmt) sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) + sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) + sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) for i := 1; i < nbElmt; i++ { sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 @@ -300,9 +300,9 @@ func computeLDE(pk *ProvingKey) { copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) diff --git a/internal/backend/bw6-633/groth16/marshal_test.go b/internal/backend/bw6-633/groth16/marshal_test.go index 1fc47f35de..b2a3a63d4b 100644 --- 
a/internal/backend/bw6-633/groth16/marshal_test.go +++ b/internal/backend/bw6-633/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bw6-633/groth16/prove.go b/internal/backend/bw6-633/groth16/prove.go index 45c4688959..b72cf7b41f 100644 --- a/internal/backend/bw6-633/groth16/prove.go +++ b/internal/backend/bw6-633/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bw6-633/groth16/setup.go b/internal/backend/bw6-633/groth16/setup.go index 17caeecb6f..b26489abbc 100644 --- a/internal/backend/bw6-633/groth16/setup.go +++ b/internal/backend/bw6-633/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bw6-633/plonk/marshal_test.go b/internal/backend/bw6-633/plonk/marshal_test.go index ebf373a463..dd04a1b03b 100644 --- a/internal/backend/bw6-633/plonk/marshal_test.go +++ b/internal/backend/bw6-633/plonk/marshal_test.go @@ -48,8 +48,8 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) + pk.DomainNum = *fft.NewDomain(42) + pk.DomainH = *fft.NewDomain(4 * 42) pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go index 
e633795585..cb85a15cb5 100644 --- a/internal/backend/bw6-633/plonk/prove.go +++ b/internal/backend/bw6-633/plonk/prove.go @@ -166,7 +166,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) + pk.DomainNum.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -401,7 +401,7 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +409,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -552,7 +552,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.DomainNum.FFTInverse(z, fft.DIF) fft.BitReverse(z) return blindPoly(z, pk.DomainNum.Cardinality, 2) @@ -608,16 +608,17 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) +// evalIDCosets id, uid, u**2id on (Z/4mZ) func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { id = make([]fr.Element, pk.DomainH.Cardinality) + // TODO doing an expo per chunk is useless utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var acc fr.Element acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) + id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) acc.Mul(&acc, &pk.DomainH.Generator) } }) @@ -707,18 +708,10 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia // // Puts the result in res of size n. 
// Warning: result is in bit reversed order, we do a bit reverse operation only once in computeH +// TODO remove this function func evaluateHDomain(poly []fr.Element, domainH *fft.Domain) []fr.Element { - res := make([]fr.Element, domainH.Cardinality) - - // we copy poly in res and scale by coset here - // to avoid FFT scaling on domainH.Cardinality (res is very sparse) - utils.Parallelize(len(poly), func(start, end int) { - for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainH.CosetTable[0][i]) - } - }, runtime.NumCPU()/2) - domainH.FFT(res, fft.DIF, 0) + domainH.FFT(res, fft.DIF, true) return res } @@ -742,7 +735,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom var one fr.Element one.SetOne() uu.Set(&pk.DomainH.Generator) - u[0].Set(&pk.DomainH.FinerGenerator) + u[0].Set(&pk.DomainH.FrMultiplicativeGen) u[1].Mul(&u[0], &uu) u[2].Mul(&u[1], &uu) u[3].Mul(&u[2], &uu) @@ -768,15 +761,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computes L1 (canonical form) startsAtOne := make(polynomial.Polynomial, pk.DomainH.Cardinality) - utils.Parallelize(int(pk.DomainNum.Cardinality), func(start, end int) { - for i := start; i < end; i++ { - startsAtOne[i].Mul(&pk.DomainNum.CardinalityInv, &pk.DomainH.CosetTable[0][i]) - } - }) - - // evaluates L1 on the odd cosets of (Z/8mZ)/(Z/mZ) - // / ! \ note that we scaled by the coset in the previous loop, hence we pass 0 as coset here. - pk.DomainH.FFT(startsAtOne, fft.DIF, 0) + pk.DomainH.FFT(startsAtOne, fft.DIF, true) // evaluate qlL+qrR+qmL.R+qoO+k + alpha.(zu*g1*g2*g3*l-z*f1*f2*f3*l) + alpha**2*L1(X)(Z(X)-1) // on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -802,12 +787,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.DomainH.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding h1 := h[:pk.DomainNum.Cardinality+2] diff --git a/internal/backend/bw6-633/plonk/setup.go b/internal/backend/bw6-633/plonk/setup.go index 06cfe86c1b..383674aff1 100644 --- a/internal/backend/bw6-633/plonk/setup.go +++ b/internal/backend/bw6-633/plonk/setup.go @@ -96,15 +96,15 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.DomainNum = *fft.NewDomain(sizeSystem) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. 
if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(4 * sizeSystem) } vk.Size = pk.DomainNum.Cardinality @@ -113,8 +113,8 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) vk.NbPublicVariables = uint64(spr.NbPublicVariables) // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) + vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) + vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) if err := pk.InitKZG(srs); err != nil { return nil, nil, err @@ -148,11 +148,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) + pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -274,8 +274,8 @@ func computeLDE(pk *ProvingKey) { // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] sID := make([]fr.Element, 3*nbElmt) sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) + sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) + sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) for i := 1; i < nbElmt; i++ { sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 @@ -300,9 +300,9 @@ func computeLDE(pk *ProvingKey) { copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) diff --git a/internal/backend/bw6-761/groth16/marshal_test.go b/internal/backend/bw6-761/groth16/marshal_test.go index 73054efb74..bccf077d51 100644 --- a/internal/backend/bw6-761/groth16/marshal_test.go +++ b/internal/backend/bw6-761/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bw6-761/groth16/prove.go b/internal/backend/bw6-761/groth16/prove.go index 8848e71c05..663bc522b1 100644 --- a/internal/backend/bw6-761/groth16/prove.go +++ b/internal/backend/bw6-761/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) 
n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go index 137b4df3fe..0e100d3319 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/internal/backend/bw6-761/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bw6-761/plonk/marshal_test.go b/internal/backend/bw6-761/plonk/marshal_test.go index 4a18f7651f..7587dc2021 100644 --- a/internal/backend/bw6-761/plonk/marshal_test.go +++ b/internal/backend/bw6-761/plonk/marshal_test.go @@ -48,8 +48,8 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) + pk.DomainNum = *fft.NewDomain(42) + pk.DomainH = *fft.NewDomain(4 * 42) pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go index 3f2981607b..d646ec81cf 100644 --- a/internal/backend/bw6-761/plonk/prove.go +++ b/internal/backend/bw6-761/plonk/prove.go @@ -166,7 +166,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) + pk.DomainNum.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -401,7 +401,7 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error 
copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +409,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -552,7 +552,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.DomainNum.FFTInverse(z, fft.DIF) fft.BitReverse(z) return blindPoly(z, pk.DomainNum.Cardinality, 2) @@ -608,16 +608,17 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) +// evalIDCosets id, uid, u**2id on (Z/4mZ) func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { id = make([]fr.Element, pk.DomainH.Cardinality) + // TODO doing an expo per chunk is useless utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var acc fr.Element acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) + id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) acc.Mul(&acc, &pk.DomainH.Generator) } }) @@ -707,18 +708,10 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia // // Puts the result in res of size n. // Warning: result is in bit reversed order, we do a bit reverse operation only once in computeH +// TODO remove this function func evaluateHDomain(poly []fr.Element, domainH *fft.Domain) []fr.Element { - res := make([]fr.Element, domainH.Cardinality) - - // we copy poly in res and scale by coset here - // to avoid FFT scaling on domainH.Cardinality (res is very sparse) - utils.Parallelize(len(poly), func(start, end int) { - for i := start; i < end; i++ { - res[i].Mul(&poly[i], &domainH.CosetTable[0][i]) - } - }, runtime.NumCPU()/2) - domainH.FFT(res, fft.DIF, 0) + domainH.FFT(res, fft.DIF, true) return res } @@ -742,7 +735,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom var one fr.Element one.SetOne() uu.Set(&pk.DomainH.Generator) - u[0].Set(&pk.DomainH.FinerGenerator) + u[0].Set(&pk.DomainH.FrMultiplicativeGen) u[1].Mul(&u[0], &uu) u[2].Mul(&u[1], &uu) u[3].Mul(&u[2], &uu) @@ -768,15 +761,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computes L1 (canonical form) startsAtOne := make(polynomial.Polynomial, pk.DomainH.Cardinality) - utils.Parallelize(int(pk.DomainNum.Cardinality), func(start, end int) { - for i := start; i < end; i++ { - startsAtOne[i].Mul(&pk.DomainNum.CardinalityInv, &pk.DomainH.CosetTable[0][i]) - } - }) - - // evaluates L1 on the odd cosets of (Z/8mZ)/(Z/mZ) - // / ! \ note that we scaled by the coset in the previous loop, hence we pass 0 as coset here. 
- pk.DomainH.FFT(startsAtOne, fft.DIF, 0) + pk.DomainH.FFT(startsAtOne, fft.DIF, true) // evaluate qlL+qrR+qmL.R+qoO+k + alpha.(zu*g1*g2*g3*l-z*f1*f2*f3*l) + alpha**2*L1(X)(Z(X)-1) // on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -802,12 +787,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.DomainH.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding h1 := h[:pk.DomainNum.Cardinality+2] diff --git a/internal/backend/bw6-761/plonk/setup.go b/internal/backend/bw6-761/plonk/setup.go index 5a0741b31e..80bde803bf 100644 --- a/internal/backend/bw6-761/plonk/setup.go +++ b/internal/backend/bw6-761/plonk/setup.go @@ -96,15 +96,15 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.DomainNum = *fft.NewDomain(sizeSystem) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.DomainH = *fft.NewDomain(4 * sizeSystem) } vk.Size = pk.DomainNum.Cardinality @@ -113,8 +113,8 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) vk.NbPublicVariables = uint64(spr.NbPublicVariables) // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) + vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) + vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) if err := pk.InitKZG(srs); err != nil { return nil, nil, err @@ -148,11 +148,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) + pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) + pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -274,8 +274,8 @@ func computeLDE(pk *ProvingKey) { // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] sID := make([]fr.Element, 3*nbElmt) sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) + sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) + sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) for i := 1; i < nbElmt; i++ { sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 @@ -300,9 +300,9 @@ func computeLDE(pk *ProvingKey) { copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - 
pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl index 58f9cf6dd7..485fcf437a 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl @@ -259,18 +259,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -278,12 +278,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 867e2aee83..14933f3931 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -74,7 +74,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -394,7 +394,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl index 67ed2b1e93..e5ad79d370 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl @@ -166,7 +166,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain 
nbWires := 6 diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index a48210d151..e4472d8a7a 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -142,7 +142,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) + pk.DomainNum.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -377,7 +377,7 @@ func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -385,13 +385,13 @@ func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -528,7 +528,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.DomainNum.FFTInverse(z, fft.DIF) fft.BitReverse(z) return blindPoly(z, pk.DomainNum.Cardinality, 2) @@ -584,16 +584,17 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. 
return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) +// evalIDCosets id, uid, u**2id on (Z/4mZ) func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { id = make([]fr.Element, pk.DomainH.Cardinality) + // TODO doing an expo per chunk is useless utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var acc fr.Element acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) + id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) acc.Mul(&acc, &pk.DomainH.Generator) } }) @@ -678,23 +679,16 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia return res } + // evaluateHDomain evaluates poly (canonical form) of degree m z**i+1 @@ -282,9 +282,9 @@ func computeLDE(pk *ProvingKey) { copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) + pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) + pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) diff --git a/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl index 99e19a1319..cacef9b5ba 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl @@ -30,8 +30,8 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) + pk.DomainNum = *fft.NewDomain(42) + pk.DomainH = *fft.NewDomain(4*42) pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) From 0d7649fee98d9dfaf213a65fdefd8bfadc6f1ee4 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 4 Feb 2022 16:49:08 +0100 Subject: [PATCH 05/37] perf(std/tEd): Lookup2 for first 2 bits in ScalarMulFixedBase --- std/algebra/twistededwards/point.go | 30 ++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go index 722aeb8298..5ed9a10e5b 100644 --- a/std/algebra/twistededwards/point.go +++ b/std/algebra/twistededwards/point.go @@ -97,6 +97,25 @@ func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Poin return p } +// DoubleFixedPoint doubles a points in SNARK coordinates +func (p *Point) DoubleFixedPoint(api frontend.API, x, y interface{}, curve EdCurve) *Point { + + u := api.Mul(x, y) + v := api.Mul(x, x) + w := api.Mul(y, y) + + n1 := api.Mul(2, u) + av := api.Mul(v, &curve.A) + n2 := api.Sub(w, av) + d1 := api.Add(w, av) + d2 := api.Sub(2, d1) + + p.X = api.DivUnchecked(n1, d1) + p.Y = api.DivUnchecked(n2, d2) + + return p +} + // Double doubles a points in SNARK coordinates func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { @@ -163,11 +182,16 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr 1, } + pp := Point{} + ppp := Point{} + pp.DoubleFixedPoint(api, x, y, curve) + ppp.AddFixedPoint(api, &pp, x, y, curve) + n := len(b) - 1 - res.X = api.Select(b[n], x, res.X) - res.Y = 
api.Select(b[n], y, res.Y) + res.X = api.Lookup2(b[n], b[n-1], res.X, pp.X, x, ppp.X) + res.Y = api.Lookup2(b[n], b[n-1], res.Y, pp.Y, y, ppp.Y) - for i := len(b) - 2; i >= 0; i-- { + for i := len(b) - 3; i >= 0; i-- { res.Double(api, &res, curve) tmp := Point{} tmp.AddFixedPoint(api, &res, x, y, curve) From 208235ea5cc9993908c20ee2ca2808804ecbdd06 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Fri, 4 Feb 2022 17:51:14 +0100 Subject: [PATCH 06/37] feat: polynomial --> []frElement --- internal/backend/bn254/plonk/marshal.go | 14 +- internal/backend/bn254/plonk/marshal_test.go | 18 +- internal/backend/bn254/plonk/prove.go | 262 ++++++++++--------- internal/backend/bn254/plonk/setup.go | 85 +++--- 4 files changed, 193 insertions(+), 186 deletions(-) diff --git a/internal/backend/bn254/plonk/marshal.go b/internal/backend/bn254/plonk/marshal.go index 6012cb114b..3dc3291430 100644 --- a/internal/backend/bn254/plonk/marshal.go +++ b/internal/backend/bn254/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.DomainSmall.WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.DomainBig.WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) + if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -143,19 +143,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.DomainSmall.ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.DomainBig.ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/backend/bn254/plonk/marshal_test.go b/internal/backend/bn254/plonk/marshal_test.go index 9a7a714d3d..ef9085a96a 100644 --- a/internal/backend/bn254/plonk/marshal_test.go +++ b/internal/backend/bn254/plonk/marshal_test.go @@ -48,14 +48,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42) - pk.DomainH = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.DomainSmall = *fft.NewDomain(42) + pk.DomainBig = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +63,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 
3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 4b8a62637c..51deac1b2b 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by gnark DO NOT EDIT - package plonk import ( @@ -27,9 +25,8 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" @@ -37,7 +34,7 @@ import ( "github.com/consensys/gnark/internal/backend/bn254/cs" - "github.com/consensys/gnark-crypto/fiat-shamir" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/internal/utils" ) @@ -89,11 +86,11 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + ll, lr, lo := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + bcl, bcr, bco, err := computeBlindedLROCanonical(ll, lr, lo, &pk.DomainSmall) if err != nil { return nil, err } @@ -111,12 +108,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var bz []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + bz, err = computeBlindedZCanonical(ll, lr, lo, pk, gamma) if err != nil { chZ <- err close(chZ) @@ -142,31 +139,31 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // evaluation of the blinded versions of l, r, o and bz // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + var evalBL, evalBR, evalBO, evalBZ []fr.Element chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evalBL = evaluateDomainBigBitReversed(bcl, &pk.DomainBig) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evalBR = evaluateDomainBigBitReversed(bcr, &pk.DomainBig) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evalBO = evaluateDomainBigBitReversed(bco, &pk.DomainBig) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) + qk := make([]fr.Element, pk.DomainSmall.Cardinality) copy(qk, fullWitness[:spr.NbPublicVariables]) copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF) + 
pk.DomainSmall.FFTInverse(qk, fft.DIF) fft.BitReverse(qk) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -174,7 +171,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed(pk, evalBL, evalBR, evalBO, qk) close(chConstraintInd) }() @@ -184,13 +181,13 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) + evalBZ = evaluateDomainBigBitReversed(bz, &pk.DomainBig) // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed(pk, evalBZ, evalBL, evalBR, evalBO, gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -200,10 +197,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, } <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +215,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(bcl, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(bcr, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(bco, zeta) wgZetaEvals.Done() }() @@ -236,7 +233,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, proof.ZShiftedOpening, err = kzg.Open( bz, &zetaShifted, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -247,7 +244,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial + linearizedPolynomial []fr.Element linearizedPolynomialDigest curve.G1Affine errLPoly error ) @@ -276,7 +273,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -324,7 +321,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, }, &zeta, hFunc, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +332,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, 
&p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +368,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,13 +394,13 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) @@ -438,7 +444,7 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // WARNING: // pre condition degree(cp) <= rou + bo // pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -447,7 +453,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { return nil, err @@ -463,13 +469,13 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, return res, nil } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
// solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.DomainSmall.Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -507,12 +513,12 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // (l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) // // * l, r, o are the solution in Lagrange basis -func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Element) (polynomial.Polynomial, error) { +func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make(polynomial.Polynomial, pk.DomainNum.Cardinality, pk.DomainNum.Cardinality+3) - nbElmts := int(pk.DomainNum.Cardinality) - gInv := make(polynomial.Polynomial, pk.DomainNum.Cardinality) + z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) + nbElmts := int(pk.DomainSmall.Cardinality) + gInv := make([]fr.Element, pk.DomainSmall.Cardinality) z[0].SetOne() gInv[0].SetOne() @@ -521,7 +527,7 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele var f [3]fr.Element var g [3]fr.Element var u [3]fr.Element - u[0].Exp(pk.DomainNum.Generator, new(big.Int).SetInt64(int64(start))) + u[0].Exp(pk.DomainSmall.Generator, new(big.Int).SetInt64(int64(start))) u[1].Mul(&u[0], &pk.Vk.Shifter[0]) u[2].Mul(&u[0], &pk.Vk.Shifter[1]) @@ -540,9 +546,9 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele gInv[i+1] = g[0] z[i+1] = f[0] - u[0].Mul(&u[0], &pk.DomainNum.Generator) // z**i -> z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 + u[0].Mul(&u[0], &pk.DomainSmall.Generator) // z**i -> z**i+1 + u[1].Mul(&u[1], &pk.DomainSmall.Generator) // u*z**i -> u*z**i+1 + u[2].Mul(&u[2], &pk.DomainSmall.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,40 +558,40 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF) + pk.DomainSmall.FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.DomainSmall.Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on // the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. 
// // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets // of (Z/8mZ)/(Z/mZ) @@ -608,48 +614,48 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on (Z/4mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { +// evaluationIdDomainBigCoset id, uid, u**2id on (Z/4mZ) +func evaluationIdDomainBigCoset(pk *ProvingKey) (id []fr.Element) { - id = make([]fr.Element, pk.DomainH.Cardinality) + id = make([]fr.Element, pk.DomainBig.Cardinality) // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) + acc.Exp(pk.DomainBig.Generator, new(big.Int).SetInt64(int64(start))) for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainH.Generator) + id[i].Mul(&acc, &pk.DomainBig.FrMultiplicativeGen) + acc.Mul(&acc, &pk.DomainBig.Generator) } }) return id } -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd // cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. 
// // * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets // * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, evalZ, evalL, evalR, evalO []fr.Element, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + // evalutation of ID on domainBig shifted + evalID := evaluationIdDomainBigCoset(pk) // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) var wg sync.WaitGroup wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial + var evalS1, evalS2, evalS3 []fr.Element go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) + evalS1 = evaluateDomainBigBitReversed(pk.CS1, &pk.DomainBig) wg.Done() }() go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) + evalS2 = evaluateDomainBigBitReversed(pk.CS2, &pk.DomainBig) wg.Done() }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) + evalS3 = evaluateDomainBigBitReversed(pk.CS3, &pk.DomainBig) wg.Wait() // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) @@ -658,9 +664,9 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia nn := uint64(64 - bits.TrailingZeros64(uint64(s))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element var eID fr.Element @@ -703,71 +709,73 @@ func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomia return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + h[i].Mul(&h[i], &evaluationXnMinusOneInverse[irev%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, true) + pk.DomainBig.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.DomainSmall.Cardinality+2] + h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] + h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] return h1, h2, h3 @@ -803,7 +810,7 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // * a, b, c are the evaluation of l, r, o at zeta // * z is the permutation polynomial, zu is Z(uX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. 
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element @@ -813,11 +820,11 @@ func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z p var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) + s1 = eval(pk.CS1, zeta) s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) close(chS1) }() - t := pk.CS2.Eval(&zeta) + t := eval(pk.CS2, zeta) t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) <-chS1 s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) @@ -833,7 +840,7 @@ func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z p // third part L1(zeta)*alpha**2**Z var lagrange, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) + nbElmt := int64(pk.DomainSmall.Cardinality) lagrange.Set(&zeta). Exp(lagrange, big.NewInt(nbElmt)). Sub(&lagrange, &one) @@ -845,7 +852,8 @@ func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z p Mul(&lagrange, &alpha). Mul(&lagrange, &alpha) // alpha**2*L_0 - linPol := z.Clone() + linPol := make([]fr.Element, len(z)) + copy(linPol, z) utils.Parallelize(len(linPol), func(start, end int) { var t0, t1 fr.Element diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go index 012df10020..e19555a9c2 100644 --- a/internal/backend/bn254/plonk/setup.go +++ b/internal/backend/bn254/plonk/setup.go @@ -12,16 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by gnark DO NOT EDIT - package plonk import ( "errors" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" "github.com/consensys/gnark/internal/backend/bn254/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +38,19 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element // Domains used for the FFTs - DomainNum, DomainH fft.Domain + DomainSmall, DomainBig fft.Domain // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + LsID []fr.Element + LS1, LS2, LS3 []fr.Element + CS1, CS2, CS3 []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -96,37 +95,37 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem) + pk.DomainSmall = *fft.NewDomain(sizeSystem) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. 
if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8 * sizeSystem) + pk.DomainBig = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4 * sizeSystem) + pk.DomainBig = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.DomainSmall.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) // shifters - vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) + vk.Shifter[0].Set(&pk.DomainSmall.FrMultiplicativeGen) + vk.Shifter[1].Square(&pk.DomainSmall.FrMultiplicativeGen) if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -148,11 +147,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -211,7 +210,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.DomainSmall.Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -269,24 +268,24 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func computeLDE(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmt := int(pk.DomainSmall.Cardinality) // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] sID := make([]fr.Element, 3*nbElmt) sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) + sID[nbElmt].Set(&pk.DomainSmall.FrMultiplicativeGen) + sID[2*nbElmt].Square(&pk.DomainSmall.FrMultiplicativeGen) for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // 
u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 + sID[i].Mul(&sID[i-1], &pk.DomainSmall.Generator) // z**i -> z**i+1 + sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainSmall.Generator) // u*z**i -> u*z**i+1 + sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainSmall.Generator) // u**2*z**i -> u**2*z**i+1 } // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) + pk.LS1 = make([]fr.Element, nbElmt) + pk.LS2 = make([]fr.Element, nbElmt) + pk.LS3 = make([]fr.Element, nbElmt) for i := 0; i < nbElmt; i++ { pk.LS1[i].Set(&sID[pk.Permutation[i]]) pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) @@ -294,22 +293,22 @@ func computeLDE(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) + pk.CS1 = make([]fr.Element, nbElmt) + pk.CS2 = make([]fr.Element, nbElmt) + pk.CS3 = make([]fr.Element, nbElmt) copy(pk.CS1, pk.LS1) copy(pk.CS2, pk.LS2) copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CS1, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CS2, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CS3, fft.DIF) fft.BitReverse(pk.CS1) fft.BitReverse(pk.CS2) fft.BitReverse(pk.CS3) } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized From 80807eceb7a8337bb8793dfa0969e1be0ff10322 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 4 Feb 2022 17:55:03 +0100 Subject: [PATCH 07/37] perf(std/tEd): FixedPoint should be hidden by the API --- std/algebra/twistededwards/curve.go | 58 ++++++++----- std/algebra/twistededwards/point.go | 94 ++------------------- std/algebra/twistededwards/point_test.go | 103 ++++++++++++----------- std/signature/eddsa/eddsa.go | 10 ++- 4 files changed, 108 insertions(+), 157 deletions(-) diff --git a/std/algebra/twistededwards/curve.go b/std/algebra/twistededwards/curve.go index 798b01271f..432210c588 100644 --- a/std/algebra/twistededwards/curve.go +++ b/std/algebra/twistededwards/curve.go @@ -30,11 +30,17 @@ import ( "github.com/consensys/gnark/internal/utils" ) +// Coordinates of a point on a twisted Edwards curve +type Coord struct { + X, Y big.Int +} + // EdCurve stores the info on the chosen edwards curve // note that all curves implemented in gnark-crypto have A = -1 type EdCurve struct { - A, D, Cofactor, Order, BaseX, BaseY big.Int - ID ecc.ID + A, D, Cofactor, Order big.Int + Base Coord + ID ecc.ID } var constructors map[ecc.ID]func() EdCurve @@ -71,9 +77,11 @@ func newEdBN254() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BN254, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BN254, } } @@ -88,9 +96,11 @@ func newEdBLS381() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: 
utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS12_381, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS12_381, } } @@ -105,9 +115,11 @@ func newEdBLS377() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS12_377, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS12_377, } } @@ -122,9 +134,11 @@ func newEdBW633() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BW6_633, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BW6_633, } } @@ -139,9 +153,11 @@ func newEdBW761() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BW6_761, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BW6_761, } } @@ -156,9 +172,11 @@ func newEdBLS315() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS24_315, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS24_315, } } diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go index 5ed9a10e5b..573ffe88cb 100644 --- a/std/algebra/twistededwards/point.go +++ b/std/algebra/twistededwards/point.go @@ -46,34 +46,9 @@ func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { } -// AddFixedPoint Adds two points, among which is one fixed point (the base), on a twisted edwards curve (eg jubjub) -// p1, base, ecurve are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddFixedPoint(api frontend.API, p1 *Point /*basex*/, x /*basey*/, y interface{}, curve EdCurve) *Point { - - // https://eprint.iacr.org/2008/013.pdf - - n11 := api.Mul(p1.X, y) - n12 := api.Mul(p1.Y, x) - n1 := api.Add(n11, n12) - - n21 := api.Mul(p1.Y, y) - n22 := api.Mul(p1.X, x) - an22 := api.Mul(n22, &curve.A) - n2 := api.Sub(n21, an22) - - d11 := api.Mul(curve.D, n11, n12) - d1 := api.Add(1, d11) - d2 := api.Sub(1, d11) - - p.X = api.DivUnchecked(n1, d1) - p.Y = api.DivUnchecked(n2, d2) - - return p -} - -// AddGeneric Adds two points on a twisted edwards curve (eg jubjub) +// Add Adds two points on a twisted edwards curve (eg jubjub) // p1, p2, c are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { +func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { // https://eprint.iacr.org/2008/013.pdf @@ -97,25 +72,6 @@ func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Poin return p } -// DoubleFixedPoint 
doubles a points in SNARK coordinates -func (p *Point) DoubleFixedPoint(api frontend.API, x, y interface{}, curve EdCurve) *Point { - - u := api.Mul(x, y) - v := api.Mul(x, x) - w := api.Mul(y, y) - - n1 := api.Mul(2, u) - av := api.Mul(v, &curve.A) - n2 := api.Sub(w, av) - d1 := api.Add(w, av) - d2 := api.Sub(2, d1) - - p.X = api.DivUnchecked(n1, d1) - p.Y = api.DivUnchecked(n2, d2) - - return p -} - // Double doubles a points in SNARK coordinates func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { @@ -135,44 +91,12 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { return p } -// ScalarMulNonFixedBase computes the scalar multiplication of a point on a twisted Edwards curve +// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) ScalarMulNonFixedBase(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { - - // first unpack the scalar - b := api.ToBinary(scalar) - - res := Point{ - 0, - 1, - } - - n := len(b) - 1 - res.X = api.Select(b[n], p1.X, res.X) - res.Y = api.Select(b[n], p1.Y, res.Y) - - for i := len(b) - 2; i >= 0; i-- { - res.Double(api, &res, curve) - tmp := Point{} - tmp.AddGeneric(api, &res, p1, curve) - res.X = api.Select(b[i], tmp.X, res.X) - res.Y = api.Select(b[i], tmp.Y, res.Y) - } - - p.X = res.X - p.Y = res.Y - return p -} - -// ScalarMulFixedBase computes the scalar multiplication of a point on a twisted Edwards curve -// x, y: coordinates of the base point -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar frontend.Variable, curve EdCurve) *Point { +func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { // first unpack the scalar b := api.ToBinary(scalar) @@ -184,17 +108,17 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr pp := Point{} ppp := Point{} - pp.DoubleFixedPoint(api, x, y, curve) - ppp.AddFixedPoint(api, &pp, x, y, curve) + pp.Double(api, p1, curve) + ppp.Add(api, &pp, p1, curve) n := len(b) - 1 - res.X = api.Lookup2(b[n], b[n-1], res.X, pp.X, x, ppp.X) - res.Y = api.Lookup2(b[n], b[n-1], res.Y, pp.Y, y, ppp.Y) + res.X = api.Lookup2(b[n], b[n-1], res.X, pp.X, p1.X, ppp.X) + res.Y = api.Lookup2(b[n], b[n-1], res.Y, pp.Y, p1.Y, ppp.Y) for i := len(b) - 3; i >= 0; i-- { res.Double(api, &res, curve) tmp := Point{} - tmp.AddFixedPoint(api, &res, x, y, curve) + tmp.Add(api, &res, p1, curve) res.X = api.Select(b[i], tmp.X, res.X) res.Y = api.Select(b[i], tmp.Y, res.Y) } diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index f36305abad..7b7a696a28 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -61,8 +61,8 @@ func TestIsOnCurve(t *testing.T) { t.Fatal(err) } - witness.P.X = (params.BaseX) - witness.P.Y = (params.BaseY) + witness.P.X = (params.Base.X) + witness.P.Y = (params.Base.Y) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) @@ -80,7 +80,10 @@ func (circuit *add) Define(api frontend.API) error { return err } - res := circuit.P.AddFixedPoint(api, &circuit.P, params.BaseX, params.BaseY, params) + p := Point{} + p.X = params.Base.X 
+ p.Y = params.Base.Y + res := circuit.P.Add(api, &circuit.P, &p, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -100,8 +103,8 @@ func TestAddFixedPoint(t *testing.T) { t.Fatal(err) } var base, point, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) point.Set(&base) r := big.NewInt(5) point.ScalarMul(&point, r) @@ -133,7 +136,7 @@ func (circuit *addGeneric) Define(api frontend.API) error { return err } - res := circuit.P1.AddGeneric(api, &circuit.P1, &circuit.P2, params) + res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -157,8 +160,8 @@ func TestAddGeneric(t *testing.T) { switch id { case ecc.BN254: var op1, op2, expected tbn254.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -173,8 +176,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_381: var op1, op2, expected tbls12381.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -189,8 +192,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_377: var op1, op2, expected tbls12377.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -205,8 +208,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS24_315: var op1, op2, expected tbls24315.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -221,8 +224,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_633: var op1, op2, expected tbw6633.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -237,8 +240,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_761: var op1, op2, expected tbw6761.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -299,8 +302,8 @@ func TestDouble(t *testing.T) { switch id { case ecc.BN254: var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -308,8 +311,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_381: var base, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -317,8 +320,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_377: var base, expected tbls12377.PointAffine - 
base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -326,8 +329,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS24_315: var base, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -335,8 +338,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_633: var base, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -344,8 +347,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_761: var base, expected tbw6761.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -375,8 +378,10 @@ func (circuit *scalarMulFixed) Define(api frontend.API) error { return err } - var resFixed Point - resFixed.ScalarMulFixedBase(api, params.BaseX, params.BaseY, circuit.S, params) + var resFixed, p Point + p.X = params.Base.X + p.Y = params.Base.Y + resFixed.ScalarMul(api, &p, circuit.S, params) api.AssertIsEqual(resFixed.X, circuit.E.X) api.AssertIsEqual(resFixed.Y, circuit.E.Y) @@ -402,8 +407,8 @@ func TestScalarMulFixed(t *testing.T) { switch id { case ecc.BN254: var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -411,8 +416,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BLS12_381: var base, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -420,8 +425,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BLS12_377: var base, expected tbls12377.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -429,8 +434,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BLS24_315: var base, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -438,8 +443,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BW6_633: var base, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -447,8 +452,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BW6_761: var base, expected tbw6761.PointAffine - 
base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -475,7 +480,7 @@ func (circuit *scalarMulGeneric) Define(api frontend.API) error { return err } - resGeneric := circuit.P.ScalarMulNonFixedBase(api, &circuit.P, circuit.S, params) + resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params) api.AssertIsEqual(resGeneric.X, circuit.E.X) api.AssertIsEqual(resGeneric.Y, circuit.E.Y) @@ -495,8 +500,8 @@ func TestScalarMulGeneric(t *testing.T) { t.Fatal(err) } var base, point, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) s := big.NewInt(902) point.ScalarMul(&base, s) // random point r := big.NewInt(230928302) @@ -537,8 +542,8 @@ func TestNeg(t *testing.T) { t.Fatal(err) } var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Neg(&base) // generate witness diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index 7cffff0b8b..94b218023a 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -60,16 +60,20 @@ func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey Publi hash.Write(data...) hramConstant := hash.Sum() + base := twistededwards.Point{} + base.X = pubKey.Curve.Base.X + base.Y = pubKey.Curve.Base.Y + // lhs = [S]G cofactor := pubKey.Curve.Cofactor.Uint64() lhs := twistededwards.Point{} - lhs.ScalarMulFixedBase(api, pubKey.Curve.BaseX, pubKey.Curve.BaseY, sig.S, pubKey.Curve) + lhs.ScalarMul(api, &base, sig.S, pubKey.Curve) lhs.MustBeOnCurve(api, pubKey.Curve) // rhs = R+[H(R,A,M)]*A rhs := twistededwards.Point{} - rhs.ScalarMulNonFixedBase(api, &pubKey.A, hramConstant, pubKey.Curve). - AddGeneric(api, &rhs, &sig.R, pubKey.Curve) + rhs.ScalarMul(api, &pubKey.A, hramConstant, pubKey.Curve). 
+ Add(api, &rhs, &sig.R, pubKey.Curve) // rhs.MustBeOnCurve(api, pubKey.Curve) // [cofactor]*lhs and [cofactor]*rhs From 0c69bf1b41efc9e7c09e73e2f8d2484db8f1a3da Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Tue, 8 Feb 2022 17:33:20 +0100 Subject: [PATCH 08/37] fix: fixed plonk up to permutation polynomial --- internal/backend/bn254/plonk/prove.go | 437 ++++++++++++++----------- internal/backend/bn254/plonk/setup.go | 140 +++++--- internal/backend/bn254/plonk/verify.go | 72 ++-- 3 files changed, 367 insertions(+), 282 deletions(-) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 51deac1b2b..e74de5b1a0 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -16,6 +16,7 @@ package plonk import ( "crypto/sha256" + "fmt" "math/big" "math/bits" "runtime" @@ -40,6 +41,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -59,6 +61,11 @@ type Proof struct { // Prove from the public data func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, opt backend.ProverConfig) (*Proof, error) { + // printPoly("cql", pk.Ql) + // printPoly("cqr", pk.Qr) + // printPoly("cqm", pk.Qm) + // printPoly("cqo", pk.Qo) + // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -86,17 +93,25 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := evaluateLROSmallDomain(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLROCanonical(ll, lr, lo, &pk.DomainSmall) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.DomainSmall) if err != nil { return nil, err } + // printPoly("cl", blindedLCanonical) + // printPoly("cr", blindedRCanonical) + // printPoly("co", blindedOCanonical) + // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -106,14 +121,22 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, return nil, err } + // Fiat Shamir this + var beta fr.Element + beta.SetUint64(10) + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz []fr.Element + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZCanonical(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -125,7 +148,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. 
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -138,21 +161,26 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ []fr.Element + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateDomainBigBitReversed(bcl, &pk.DomainBig) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT close(chEvalBL) }() go func() { - evalBR = evaluateDomainBigBitReversed(bcr, &pk.DomainBig) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT close(chEvalBR) }() go func() { - evalBO = evaluateDomainBigBitReversed(bco, &pk.DomainBig) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT close(chEvalBO) }() @@ -160,18 +188,23 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make([]fr.Element, pk.DomainSmall.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qk, fft.DIF) - fft.BitReverse(qk) + qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evaluateConstraintsDomainBigBitReversed(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -181,13 +214,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, chConstraintOrdering <- err return } - evalBZ = evaluateDomainBigBitReversed(bz, &pk.DomainBig) + printPoly("z", blindedZCanonical) + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
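For reference, the quantity accumulated point by point in evaluateConstraintsDomainBigBitReversed above is the gate constraint ql·l + qr·r + qm·l·r + qo·o + qk. A minimal single-point sketch using the fr arithmetic already imported in this file; gateConstraint is an illustrative name, not a function of this package:

// gateConstraint is illustrative only (not part of the patch): it computes
// ql·l + qr·r + qm·l·r + qo·o + qk for a single evaluation point.
func gateConstraint(ql, qr, qm, qo, qk, l, r, o fr.Element) fr.Element {
	var res, t fr.Element
	res.Mul(&qm, &l).Mul(&res, &r) // qm·l·r
	t.Mul(&ql, &l)
	res.Add(&res, &t) // + ql·l
	t.Mul(&qr, &r)
	res.Add(&res, &t) // + qr·r
	t.Mul(&qo, &o)
	res.Add(&res, &t) // + qo·o
	res.Add(&res, &qk) // + qk
	return res
}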
<-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evaluateOrderingDomainBigBitReversed(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -195,9 +236,16 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd + + check := make([]fr.Element, len(constraintsOrdering)) + copy(check, constraintsOrdering) + fft.BitReverse(constraintsOrdering) + // printVector("gordering", constraintsOrdering, true) + + <-chConstraintInd // CORRECT + // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -210,20 +258,25 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, return nil, err } + fmt.Printf("beta = Fr(%s)\n", beta.String()) + fmt.Printf("gamma = Fr(%s)\n", gamma.String()) + fmt.Printf("alpha = Fr(%s)\n", alpha.String()) + fmt.Printf("zeta = Fr(%s)\n", zeta.String()) + // compute evaluations of (blinded version of) l, r, o, z at zeta var blzeta, brzeta, bozeta fr.Element var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = eval(bcl, zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = eval(bcr, zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = eval(bco, zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -231,7 +284,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, + blindedZCanonical, &zetaShifted, &pk.DomainBig, pk.Vk.KZGSRS, @@ -258,10 +311,11 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) @@ -304,11 +358,11 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, []polynomial.Polynomial{ foldedH, linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -442,31 +496,34 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - totalDegree := rou + bo + // totalDegree := rou + bo - // re-use cp - res := cp[:totalDegree+1] + // // re-use cp + // res := cp[:totalDegree+1] - // random polynomial - blindingPoly := make([]fr.Element, bo+1) - for i 
:= uint64(0); i < bo+1; i++ { - if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err - } - } + // // random polynomial + // blindingPoly := make([]fr.Element, bo+1) + // for i := uint64(0); i < bo+1; i++ { + // if _, err := blindingPoly[i].SetRandom(); err != nil { + // return nil, err + // } + // } - // blinding - for i := uint64(0); i < bo+1; i++ { - res[i].Sub(&res[i], &blindingPoly[i]) - res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - } + // // blinding + // for i := uint64(0); i < bo+1; i++ { + // res[i].Sub(&res[i], &blindingPoly[i]) + // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + // } + + // return res, nil - return res, nil + // TODO reactivate blinding + return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. @@ -508,12 +565,12 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainSmall.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainSmall.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -593,6 +646,7 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO }() evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets // of (Z/8mZ)/(Z/mZ) utils.Parallelize(len(evalQk), func(start, end int) { @@ -614,95 +668,75 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO return evalQk } -// evaluationIdDomainBigCoset id, uid, u**2id on (Z/4mZ) -func evaluationIdDomainBigCoset(pk *ProvingKey) (id []fr.Element) { - - id = make([]fr.Element, pk.DomainBig.Cardinality) - - // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainBig.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainBig.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainBig.Generator) - } - }) - - return id -} - // evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// cosets of the big domain. 
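Both the accumulator Z built in computeBlindedZCanonical and the ordering check described here use the same per-row factors: (w_j + β·id_j + γ) over the identity and (w_j + β·σ_j + γ) over the permutation, multiplied over the three wires. A minimal sketch of one such ratio, assuming the fr package already imported in this file; permutationRatio is an illustrative helper, not part of the patch, and the row-by-row inversion is for clarity only:

// permutationRatio is illustrative only: the factor by which the accumulator
// grows from one row to the next, i.e.
//   Z(g^(i+1)) = Z(g^i) · Π_j (w_j + β·id_j + γ) / (w_j + β·σ_j + γ)
// with all factors evaluated at g^i.
func permutationRatio(w, id, sigma [3]fr.Element, beta, gamma fr.Element) fr.Element {
	var num, den, t fr.Element
	num.SetOne()
	den.SetOne()
	for j := 0; j < 3; j++ {
		t.Mul(&beta, &id[j]).Add(&t, &w[j]).Add(&t, &gamma) // w_j + β·id_j + γ
		num.Mul(&num, &t)
		t.Mul(&beta, &sigma[j]).Add(&t, &w[j]).Add(&t, &gamma) // w_j + β·σ_j + γ
		den.Mul(&den, &t)
	}
	num.Mul(&num, den.Inverse(&den)) // product of numerators / product of denominators
	return num
}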
// -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, evalZ, evalL, evalR, evalO []fr.Element, gamma fr.Element) []fr.Element { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID on domainBig shifted - evalID := evaluationIdDomainBigCoset(pk) + nbElmts := int(pk.DomainBig.Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 []fr.Element - go func() { - evalS1 = evaluateDomainBigBitReversed(pk.CS1, &pk.DomainBig) - wg.Done() - }() - go func() { - evalS2 = evaluateDomainBigBitReversed(pk.CS2, &pk.DomainBig) - wg.Done() - }() - evalS3 = evaluateDomainBigBitReversed(pk.CS3, &pk.DomainBig) - wg.Wait() + // printVector("z", z, true) // CORRECT + // printVector("l", l, true) // CORRECT + // printVector("r", r, true) // CORRECT + // printVector("o", o, true) // CORRECT - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + // printVector("s1", pk.EvaluationPermutationBigDomainBitReversed[:nbElmts], true) + // printVector("s2", pk.EvaluationPermutationBigDomainBitReversed[nbElmts:2*nbElmts], true) + // printVector("s3", pk.EvaluationPermutationBigDomainBitReversed[2*nbElmts:], true) + + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.DomainBig.Cardinality) + + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + // fft.BitReverse(z) + // fft.BitReverse(l) + // fft.BitReverse(r) + // fft.BitReverse(o) + // fft.BitReverse(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality]) + // fft.BitReverse(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality : 2*pk.DomainBig.Cardinality]) + // fft.BitReverse(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:]) utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g } }) @@ -710,17 +744,18 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, evalZ, evalL, evalR, e } // evaluateDomainBigBitReversed evaluates poly (canonical form) of degree m> nn h[i].Mul(&h[i], &evaluationXnMinusOneInverse[irev%ratio]) } @@ -807,79 +838,89 @@ func computeQuotientCanonical(pk *ProvingKey, constraintsInd, constraintOrdering // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. 
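// (A note on the division by Xⁿ-1 just above, with n = DomainSmall.Cardinality: on the coset
// u·<g> of DomainBig, whose size is ratio·n, a point x = u·gʲ gives xⁿ-1 = uⁿ·gʲⁿ-1, and gʲⁿ
// depends only on j mod ratio. So Xⁿ-1 takes only ratio distinct values over the whole coset,
// which is why indexing a small table of precomputed inverses by irev%ratio is enough to
// perform the division at every point.)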
-// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z []fr.Element, pk *ProvingKey) []fr.Element { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = eval(pk.CS1, zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := eval(pk.CS2, zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() nbElmt := int64(pk.DomainSmall.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (1/n)*(ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &alpha) // α²*L₁(ζ) - linPol := make([]fr.Element, len(z)) - copy(linPol, z) + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + if i < len(pk.S3Canonical) { + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go index e19555a9c2..cc1b53bf78 100644 --- a/internal/backend/bn254/plonk/setup.go +++ b/internal/backend/bn254/plonk/setup.go @@ -16,6 +16,7 @@ package plonk import ( "errors" + "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" @@ -47,10 +48,9 @@ type ProvingKey struct { // Domains used for the FFTs DomainSmall, DomainBig fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LsID []fr.Element - LS1, LS2, LS3 []fr.Element - CS1, CS2, CS3 []fr.Element + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -68,13 +68,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,6 +95,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints pk.DomainSmall = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) // h, the quotient polynomial is of 
degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases @@ -111,10 +111,6 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainSmall.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainSmall.FrMultiplicativeGen) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } @@ -133,7 +129,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -162,7 +158,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -181,13 +177,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -199,13 +195,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { @@ -255,7 +251,7 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // // ex: z gen of Z/mZ, u gen of Z/8mZ, then @@ -266,46 +262,84 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.DomainSmall.Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainSmall.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainSmall.FrMultiplicativeGen) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainSmall.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainSmall.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainSmall.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) // Lagrange form of S1, S2, S3 - pk.LS1 = make([]fr.Element, nbElmt) - pk.LS2 = make([]fr.Element, nbElmt) - pk.LS3 = make([]fr.Element, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make([]fr.Element, nbElmt) - pk.CS2 = make([]fr.Element, nbElmt) - pk.CS3 = make([]fr.Element, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainSmall.FFTInverse(pk.CS1, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CS2, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CS3, fft.DIF) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + 
res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res +} + +func printVector(name string, vector []fr.Element, reverse ...bool) { + + a := make([]fr.Element, len(vector)) + copy(a, vector) + if len(reverse) > 0 { + fft.BitReverse(a) + } + fmt.Printf("%s = [", name) + for i := 0; i < len(a); i++ { + fmt.Printf("%s, ", a[i].String()) + } + fmt.Println("]") +} + +func printPoly(name string, vector []fr.Element) { + fmt.Printf("%s = ", name) + for i := 0; i < len(vector); i++ { + fmt.Printf("%s*x**%d", vector[i].String(), i) + if i < len(vector)-1 { + fmt.Printf(" + ") + } + } + fmt.Println("") } // InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index 0c5261a51c..26e5ab5c6a 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -63,7 +63,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +71,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i Date: Tue, 8 Feb 2022 17:49:41 +0100 Subject: [PATCH 09/37] test(tEd): test scalarMul for all curves and schemes --- std/algebra/twistededwards/point_test.go | 124 +++++++++++++++++++---- 1 file changed, 102 insertions(+), 22 deletions(-) diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index 7b7a696a28..99d1023092 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -396,8 +396,7 @@ func TestScalarMulFixed(t *testing.T) { var circuit, witness scalarMulFixed // generate witness data - //for _, id := range ecc.Implemented() { - for _, id := range []ecc.ID{ecc.BLS12_377} { + for _, id := range ecc.Implemented() { params, err := NewEdCurve(id) if err != nil { @@ -462,7 +461,7 @@ func TestScalarMulFixed(t *testing.T) { } // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithBackends(backend.PLONK), test.WithCurves(id)) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) } } @@ -495,28 +494,109 @@ func TestScalarMulGeneric(t *testing.T) { var circuit, witness scalarMulGeneric // generate witness data - params, err := NewEdCurve(ecc.BN254) - if err != nil { - t.Fatal(err) - } - var base, point, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.Base.X) - base.Y.SetBigInt(¶ms.Base.Y) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) + for _, id := range ecc.Implemented() { - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) + params, err := NewEdCurve(id) + if err != nil { + t.Fatal(err) + } - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) + switch id { + case ecc.BN254: + var base, point, expected tbn254.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + 
expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BLS12_377: + var base, point, expected tbls12377.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BLS12_381: + var base, point, expected tbls12381.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BLS24_315: + var base, point, expected tbls24315.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BW6_761: + var base, point, expected tbw6761.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BW6_633: + var base, point, expected tbw6633.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + } + // creates r1cs + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) + } } type neg struct { From a336a73e5f930fbab848a33157d8e18270e33f65 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Tue, 8 Feb 2022 22:49:34 +0100 Subject: [PATCH 10/37] fix: correct up to quotient --- internal/backend/bn254/plonk/marshal.go | 27 ++--- internal/backend/bn254/plonk/marshal_test.go | 2 - internal/backend/bn254/plonk/prove.go | 103 +++++++++---------- 3 files changed, 55 insertions(+), 77 deletions(-) diff --git a/internal/backend/bn254/plonk/marshal.go b/internal/backend/bn254/plonk/marshal.go index 3dc3291430..92a7dde042 100644 --- a/internal/backend/bn254/plonk/marshal.go +++ b/internal/backend/bn254/plonk/marshal.go @@ -12,16 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-// Code generated by gnark DO NOT EDIT - package plonk import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" "errors" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" "io" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) // WriteTo writes binary encoding of Proof to w @@ -117,12 +116,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -165,12 +161,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +186,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +213,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bn254/plonk/marshal_test.go b/internal/backend/bn254/plonk/marshal_test.go index ef9085a96a..b508906177 100644 --- a/internal/backend/bn254/plonk/marshal_test.go +++ b/internal/backend/bn254/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index e74de5b1a0..9721acbeef 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -233,19 +233,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { + if err := <-chConstraintOrdering; err != nil { // CORRECT return nil, err } - check := make([]fr.Element, len(constraintsOrdering)) - copy(check, constraintsOrdering) - fft.BitReverse(constraintsOrdering) - // printVector("gordering", constraintsOrdering, true) - <-chConstraintInd // CORRECT // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -297,16 +292,16 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial []fr.Element - 
linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, @@ -321,11 +316,11 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + \zeta^{m+2}*Comm(h2) + \zeta^{2(m+2)}*Comm(h3) var bZetaPowerm, bSize big.Int bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element @@ -333,18 +328,18 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // \zeta^{m+2}*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // \zeta^{2(m+2)}*Comm(h3) + \zeta^{m+2}*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // \zeta^{2(m+2)}*Comm(h3) + \zeta^{m+2}*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + \zeta*h2 + \zeta^{2}*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // \zeta^{m+2}*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // \zeta^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // \zeta^{2(m+2)}*h3+h2*\zeta^{m+2} + foldedH[i].Add(&foldedH[i], &h1[i]) // \zeta^{2(m+2)*h3+\zeta^{m+2}*h2 + h1 } }) @@ -357,7 +352,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, proof.BatchedProof, err = kzg.BatchOpenSinglePoint( []polynomial.Polynomial{ foldedH, - linearizedPolynomial, + linearizedPolynomialCanonical, blindedLCanonical, blindedRCanonical, blindedOCanonical, @@ -619,7 +614,7 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma } // evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// the big domain coset. 
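// A reminder of what this evaluation is assumed to compute at each coset point:
// ql·l + qr·r + qm·l·r + qo·o + qk, using the blinded wire values. Below is a minimal
// out-of-circuit sketch of that pointwise combination; gateEval and its parameters are
// illustrative names only, not this package's API:
//
//	func gateEval(ql, qr, qm, qo, qk, l, r, o fr.Element) fr.Element {
//		var res, t fr.Element
//		res.Mul(&ql, &l)                              // ql·l
//		t.Mul(&qr, &r); res.Add(&res, &t)             // + qr·r
//		t.Mul(&qm, &l).Mul(&t, &r); res.Add(&res, &t) // + qm·l·r
//		t.Mul(&qo, &o); res.Add(&res, &t)             // + qo·o
//		res.Add(&res, &qk)                            // + qk
//		return res
//	}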
// // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version @@ -700,14 +695,6 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - // fft.BitReverse(z) - // fft.BitReverse(l) - // fft.BitReverse(r) - // fft.BitReverse(o) - // fft.BitReverse(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality]) - // fft.BitReverse(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality : 2*pk.DomainBig.Cardinality]) - // fft.BitReverse(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:]) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element @@ -762,16 +749,18 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E res := make([]fr.Element, ratio) - var g fr.Element expo := big.NewInt(int64(domainSmall.Cardinality)) - g.Exp(domainBig.Generator, expo) + res[0].Exp(domainBig.FrMultiplicativeGen, expo) + + var t fr.Element + t.Exp(domainBig.Generator, big.NewInt(int64(domainSmall.Cardinality))) - res[0].Set(&domainBig.FrMultiplicativeGen) for i := 1; i < int(ratio); i++ { - res[i].Mul(&res[i-1], &g) + res[i].Mul(&res[i-1], &t) } var one fr.Element + one.SetOne() for i := 0; i < int(ratio); i++ { res[i].Sub(&res[i], &one) } @@ -784,19 +773,20 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α²*L₁(X)*(Z(X)-1)= h(X)Z(X) // // constraintInd, constraintOrdering are evaluated on the big domain (coset). -func computeQuotientCanonical(pk *ProvingKey, constraintsInd, constraintOrdering, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { +func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { h := make([]fr.Element, pk.DomainBig.Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - var bExpo big.Int - bExpo.SetUint64(pk.DomainSmall.Cardinality) evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT // computes L_{1} (canonical form) startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) + for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { + startsAtOne[i].SetOne() + } + pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L_{1}(X)(Z(X)-1) // on a coset of the big domain @@ -810,16 +800,15 @@ func computeQuotientCanonical(pk *ProvingKey, constraintsInd, constraintOrdering utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { - t.Sub(&evaluationBlindedZDomainBigBitReversed[i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain - h[i].Mul(&startsAtOne[i], &alpha).Mul(&h[i], &t). - Add(&h[i], &constraintOrdering[i]). 
- Mul(&h[i], &alpha). - Add(&h[i], &constraintsInd[i]) - - // evaluate qlL+qrR+qmL.R+qoO+k + α.(zu*g1*g2*g3*l-z*f1*f2*f3*l)/Z - // on the big domain (coset) - irev := bits.Reverse64(i) >> nn - h[i].Mul(&h[i], &evaluationXnMinusOneInverse[irev%ratio]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) @@ -832,7 +821,7 @@ func computeQuotientCanonical(pk *ProvingKey, constraintsInd, constraintOrdering h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 + return h1, h2, h3 // CORRECT } @@ -845,7 +834,7 @@ func computeQuotientCanonical(pk *ProvingKey, constraintsInd, constraintOrdering // The Linearized polynomial is: // // α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + α*( (l(ζ)+β*s1(\zeta)+γ)*(r(ζ)+β*s2(\zeta)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { @@ -853,19 +842,22 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var rl fr.Element rl.Mul(&rZeta, &lZeta) + fmt.Printf("Z(μζ) = %s\n", zu.String()) + // second part: // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = eval(pk.S1Canonical, zeta) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - tmp := eval(pk.S2Canonical, zeta) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 s1.Mul(&s1, &tmp).Mul(&s1, &zu) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ) + fmt.Printf("l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ) = %s\n", s1.String()) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -887,9 +879,8 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (1/n)*(ζⁿ⁻¹)/(ζ-1) + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). 
Mul(&lagrangeZeta, &alpha) // α²*L₁(ζ) From 5c04c285c2bd3c069964246f14eb69fd15f8596e Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Tue, 8 Feb 2022 23:03:22 +0100 Subject: [PATCH 11/37] fix: linearized polynomial OK --- internal/backend/bn254/plonk/prove.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 9721acbeef..e7f7afc8e8 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -842,8 +842,6 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, var rl fr.Element rl.Mul(&rZeta, &lZeta) - fmt.Printf("Z(μζ) = %s\n", zu.String()) - // second part: // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element @@ -856,8 +854,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ) - fmt.Printf("l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ) = %s\n", s1.String()) + s1.Mul(&s1, &tmp).Mul(&s1, &zu) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ) // CORRECT var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -868,7 +865,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element @@ -882,7 +879,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, Inverse(&den) lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha) // α²*L₁(ζ) + Mul(&lagrangeZeta, &alpha) // α²*L₁(ζ) // CORRECT linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) From 793934028c5722f82f1b9d624fdba65c46541c72 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Wed, 9 Feb 2022 00:01:35 +0100 Subject: [PATCH 12/37] fix: missing beta in linearized polynomial --- internal/backend/bn254/plonk/prove.go | 13 ++++++-- internal/backend/bn254/plonk/verify.go | 45 ++++++++++++++------------ 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index e7f7afc8e8..23801a9749 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -879,23 +879,32 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, Inverse(&den) lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &alpha) // α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + if i < len(pk.S3Canonical) { - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) + + t0.Mul(&pk.S3Canonical[i], &s1). + Mul(&t0, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*\beta*s3(X) + linPol[i].Add(&linPol[i], &t0) } linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index 26e5ab5c6a..e27a702a1f 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by gnark DO NOT EDIT - package plonk import ( "crypto/sha256" "errors" + "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -30,7 +29,7 @@ import ( bn254witness "github.com/consensys/gnark/internal/backend/bn254/witness" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/fiat-shamir" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" ) var ( @@ -84,7 +83,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) xiLi.Mul(&lagrange, &publicWitness[i]) pi.Add(&pi, &xiLi) - // use Lᵢ₊₁ = w*L_i*(X-z^i)/(X-zⁱ⁺¹) + // use Lᵢ₊₁ = w*L_i*(X-z^{i})/(X-zⁱ⁺¹) lagrange.Mul(&lagrange, &vk.Generator). 
Mul(&lagrange, &den) acc.Mul(&acc, &vk.Generator) @@ -92,27 +91,34 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) lagrange.Div(&lagrange, &den) } - // linearizedpolynomial + pi(ζ) + α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ) - α²*L₁(ζ) + // linearizedpolynomial + pi(ζ) + α*(Z(μζ))*(l(ζ)+\beta*s1(ζ)+γ)*(r(ζ)+\beta*s2(ζ)+γ)*(o(ζ)+γ) - α²*L₁(ζ) var _s1, _s2, _o, alphaSquareLagrange fr.Element zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] - l := proof.BatchedProof.ClaimedValues[2] - r := proof.BatchedProof.ClaimedValues[3] - o := proof.BatchedProof.ClaimedValues[4] - s1 := proof.BatchedProof.ClaimedValues[5] - s2 := proof.BatchedProof.ClaimedValues[6] + claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT + linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT + l := proof.BatchedProof.ClaimedValues[2] // CORRECT + r := proof.BatchedProof.ClaimedValues[3] // CORRECT + o := proof.BatchedProof.ClaimedValues[4] // CORRECT + s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT + s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + + fmt.Printf("h(zeta) = %s\n", claimedQuotient.String()) + + var beta fr.Element + beta.SetUint64(10) - _s1.Add(&l, &s1).Add(&_s1, &gamma) // (l(ζ)+s1(ζ)+γ) - _s2.Add(&r, &s2).Add(&_s2, &gamma) // (r(ζ)+s2(ζ)+γ) - _o.Add(&o, &gamma) // (o(ζ)+γ) + _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+\beta*s1(ζ)+γ) + _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+\beta*s2(ζ)+γ) + _o.Add(&o, &gamma) // (o(ζ)+γ) _s1.Mul(&_s1, &_s2). Mul(&_s1, &_o). Mul(&_s1, &alpha). - Mul(&_s1, &zu) // α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ) + Mul(&_s1, &zu) // α*(Z(μζ))*(l(ζ)+\beta*s1(ζ)+γ)*(r(ζ)+\beta*s2(ζ)+γ)*(o(ζ)+γ) + + fmt.Printf("α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ) = %s\n", _s1.String()) alphaSquareLagrange.Mul(&lagrangeOne, &alpha). Mul(&alphaSquareLagrange, &alpha) // α²*L₁(ζ) @@ -122,7 +128,9 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) Add(&linearizedPolynomialZeta, &_s1). 
// linearizedpolynomial+pi(zeta)+α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ) Sub(&linearizedPolynomialZeta, &alphaSquareLagrange) // linearizedpolynomial+pi(zeta)+α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ)-α²*L₁(ζ) - // Compute H(ζ) using the previous result: H(ζ) = prev_result/(ζⁿ⁻¹) + fmt.Printf("linpolcompleted(zeta) = %s\n", linearizedPolynomialZeta.String()) + + // Compute H(ζ) using the previous result: H(ζ) = prev_result/(ζⁿ-1) var zetaPowerMMinusOne fr.Element zetaPowerMMinusOne.Sub(&zetaPowerM, &one) linearizedPolynomialZeta.Div(&linearizedPolynomialZeta, &zetaPowerMMinusOne) @@ -155,9 +163,6 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) var linearizedPolynomialDigest curve.G1Affine - var beta fr.Element - beta.SetUint64(10) - // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) var t fr.Element _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) From 1890b52fbbc4e3bd9c8acbf70fd07b701328f27b Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Wed, 9 Feb 2022 00:11:53 +0100 Subject: [PATCH 13/37] fix: verifier obtains correct quotient --- internal/backend/bn254/plonk/prove.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 23801a9749..671f53b0a8 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -784,7 +784,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // computes L_{1} (canonical form) startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].SetOne() + startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) } pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT From c85e67c7e34fd7ee1a7f811d0fb9c88979d97457 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Wed, 9 Feb 2022 01:06:43 +0100 Subject: [PATCH 14/37] fix: fixed verifier --- internal/backend/bn254/plonk/prove.go | 7 +++-- internal/backend/bn254/plonk/verify.go | 37 ++++++++++++-------------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 671f53b0a8..7f94b69fa1 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -843,7 +843,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, rl.Mul(&rZeta, &lZeta) // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*\beta*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { @@ -854,7 +854,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*\beta*Z(μζ) // CORRECT var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -895,8 +895,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, if i < len(pk.S3Canonical) { - t0.Mul(&pk.S3Canonical[i], &s1). 
- Mul(&t0, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*\beta*s3(X) + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*\beta*s3(X) linPol[i].Add(&linPol[i], &t0) } diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index e27a702a1f..2b6b05da5d 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -140,7 +140,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) return errWrongClaimedQuotient } - // compute the folded commitment to H: Comm(h₁) + ζᵐ*Comm(h₂) + ζ²ᵐ*Comm(h₃) + // compute the folded commitment to H: Comm(h₁) + ζ^{m+2}*Comm(h₂) + ζ^{2(m+2)}*Comm(h₃) mPlusTwo := big.NewInt(int64(vk.Size) + 2) var zetaMPlusTwo fr.Element zetaMPlusTwo.Exp(zeta, mPlusTwo) @@ -163,27 +163,24 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) var linearizedPolynomialDigest curve.G1Affine - // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - var t fr.Element - _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) - t.Mul(&s2, &beta).Add(&t, &t).Add(&t, &gamma) - _s1.Mul(&_s1, &t). - Mul(&_s1, &zu). - Mul(&_s1, &alpha) // α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ) - - var cosetShift, cosetShiftSquare fr.Element - cosetShift.Set(&vk.CosetShift) - cosetShiftSquare.Square(&cosetShift) - _s2.Mul(&beta, &zeta).Add(&_s2, &l).Add(&_s2, &gamma) // (l(ζ)+β*ζ+γ) - t.Mul(&zeta, &cosetShift).Mul(&t, &zeta).Add(&t, &r).Add(&t, &gamma) // (r(ζ)+β*u*ζ+γ) - _s2.Mul(&_s2, &t) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) - t.Mul(&t, &cosetShiftSquare).Mul(&t, &zeta).Add(&t, &o).Add(&t, &gamma) // (o(ζ)+β*u²*ζ+γ) - _s2.Mul(&_s2, &t) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - _s2.Mul(&_s2, &alpha) - _s2.Sub(&alphaSquareLagrange, &_s2) + // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*\beta*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) + + // CORRECT + var u, v, w, cosetsquare fr.Element + u.Mul(&zu, &beta) + v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) + w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) + _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // \alpha*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*\beta + + // CORRECT + cosetsquare.Square(&vk.CosetShift) + u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) + v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*\mu*ζ+γ) + w.Mul(&beta, &zeta).Mul(&w, &cosetsquare).Add(&w, &o).Add(&w, &gamma) // (o(ζ)+β*\mu^{2}*ζ+γ) + _s2.Mul(&u, &v).Mul(&_s2, &w).Neg(&_s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // note since third part = α²*L₁(ζ)*Z - // we add alphaSquareLagrange to _s2 + _s2.Mul(&_s2, &alpha).Add(&_s2, &alphaSquareLagrange) // -\alpha*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + α²*L₁(ζ) points := []curve.G1Affine{ vk.Ql, vk.Qr, vk.Qm, vk.Qo, vk.Qk, // first part From cd0abd9f91261defd81805c45178d3ba8aaa8420 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Wed, 9 Feb 2022 09:52:03 +0100 Subject: [PATCH 15/37] style: removed debug comments --- internal/backend/bn254/plonk/prove.go | 56 ++++++++------------------ internal/backend/bn254/plonk/setup.go | 27 ------------- internal/backend/bn254/plonk/verify.go | 29 +++++-------- 3 files changed, 27 insertions(+), 85 deletions(-) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 7f94b69fa1..7b6552f273 100644 --- 
a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -16,7 +16,6 @@ package plonk import ( "crypto/sha256" - "fmt" "math/big" "math/bits" "runtime" @@ -61,11 +60,6 @@ type Proof struct { // Prove from the public data func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, opt backend.ProverConfig) (*Proof, error) { - // printPoly("cql", pk.Ql) - // printPoly("cqr", pk.Qr) - // printPoly("cqm", pk.Qm) - // printPoly("cqo", pk.Qo) - // pick a hash function that will be used to derive the challenges hFunc := sha256.New() @@ -106,10 +100,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, return nil, err } - // printPoly("cl", blindedLCanonical) - // printPoly("cr", blindedRCanonical) - // printPoly("co", blindedOCanonical) - // compute kzg commitments of bcl, bcr and bco if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err @@ -214,7 +204,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, chConstraintOrdering <- err return } - printPoly("z", blindedZCanonical) + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. @@ -253,11 +243,6 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, return nil, err } - fmt.Printf("beta = Fr(%s)\n", beta.String()) - fmt.Printf("gamma = Fr(%s)\n", gamma.String()) - fmt.Printf("alpha = Fr(%s)\n", alpha.String()) - fmt.Printf("zeta = Fr(%s)\n", zeta.String()) - // compute evaluations of (blinded version of) l, r, o, z at zeta var blzeta, brzeta, bozeta fr.Element var wgZetaEvals sync.WaitGroup @@ -320,7 +305,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, close(chLpoly) }() - // foldedHDigest = Comm(h1) + \zeta^{m+2}*Comm(h2) + \zeta^{2(m+2)}*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element @@ -328,18 +313,18 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // \zeta^{m+2}*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // \zeta^{2(m+2)}*Comm(h3) + \zeta^{m+2}*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // \zeta^{2(m+2)}*Comm(h3) + \zeta^{m+2}*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + \zeta*h2 + \zeta^{2}*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // \zeta^{m+2}*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // \zeta^{m+2)*h3+h2 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // \zeta^{2(m+2)}*h3+h2*\zeta^{m+2} - foldedH[i].Add(&foldedH[i], &h1[i]) // 
\zeta^{2(m+2)*h3+\zeta^{m+2}*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -673,15 +658,6 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen nbElmts := int(pk.DomainBig.Cardinality) - // printVector("z", z, true) // CORRECT - // printVector("l", l, true) // CORRECT - // printVector("r", r, true) // CORRECT - // printVector("o", o, true) // CORRECT - - // printVector("s1", pk.EvaluationPermutationBigDomainBitReversed[:nbElmts], true) - // printVector("s2", pk.EvaluationPermutationBigDomainBitReversed[nbElmts:2*nbElmts], true) - // printVector("s3", pk.EvaluationPermutationBigDomainBitReversed[2*nbElmts:], true) - // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). res := make([]fr.Element, pk.DomainBig.Cardinality) @@ -781,14 +757,14 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT - // computes L_{1} (canonical form) + // computes L₁ (canonical form) startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) } pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT - // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L_{1}(X)(Z(X)-1) + // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) @@ -834,7 +810,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // The Linearized polynomial is: // // α²*L₁(ζ)*Z(X) -// + α*( (l(ζ)+β*s1(\zeta)+γ)*(r(ζ)+β*s2(\zeta)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) // + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { @@ -843,7 +819,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, rl.Mul(&rZeta, &lZeta) // second part: - // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*\beta*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { @@ -854,7 +830,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*\beta*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, 
&pk.Vk.CosetShift) @@ -895,7 +871,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, if i < len(pk.S3Canonical) { - t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*\beta*s3(X) + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) linPol[i].Add(&linPol[i], &t0) } diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go index cc1b53bf78..1e9cefa08c 100644 --- a/internal/backend/bn254/plonk/setup.go +++ b/internal/backend/bn254/plonk/setup.go @@ -16,7 +16,6 @@ package plonk import ( "errors" - "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" @@ -316,32 +315,6 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -func printVector(name string, vector []fr.Element, reverse ...bool) { - - a := make([]fr.Element, len(vector)) - copy(a, vector) - if len(reverse) > 0 { - fft.BitReverse(a) - } - - fmt.Printf("%s = [", name) - for i := 0; i < len(a); i++ { - fmt.Printf("%s, ", a[i].String()) - } - fmt.Println("]") -} - -func printPoly(name string, vector []fr.Element) { - fmt.Printf("%s = ", name) - for i := 0; i < len(vector); i++ { - fmt.Printf("%s*x**%d", vector[i].String(), i) - if i < len(vector)-1 { - fmt.Printf(" + ") - } - } - fmt.Println("") -} - // InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index 2b6b05da5d..eb38ff56c1 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -17,7 +17,6 @@ package plonk import ( "crypto/sha256" "errors" - "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -83,7 +82,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) xiLi.Mul(&lagrange, &publicWitness[i]) pi.Add(&pi, &xiLi) - // use Lᵢ₊₁ = w*L_i*(X-z^{i})/(X-zⁱ⁺¹) + // use Lᵢ₊₁ = w*L_i*(X-zⁱ)/(X-zⁱ⁺¹) lagrange.Mul(&lagrange, &vk.Generator). Mul(&lagrange, &den) acc.Mul(&acc, &vk.Generator) @@ -91,7 +90,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) lagrange.Div(&lagrange, &den) } - // linearizedpolynomial + pi(ζ) + α*(Z(μζ))*(l(ζ)+\beta*s1(ζ)+γ)*(r(ζ)+\beta*s2(ζ)+γ)*(o(ζ)+γ) - α²*L₁(ζ) + // linearizedpolynomial + pi(ζ) + α*(Z(μζ))*(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*(o(ζ)+γ) - α²*L₁(ζ) var _s1, _s2, _o, alphaSquareLagrange fr.Element zu := proof.ZShiftedOpening.ClaimedValue @@ -104,21 +103,17 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - fmt.Printf("h(zeta) = %s\n", claimedQuotient.String()) - var beta fr.Element beta.SetUint64(10) - _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+\beta*s1(ζ)+γ) - _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+\beta*s2(ζ)+γ) + _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) + _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) _o.Add(&o, &gamma) // (o(ζ)+γ) _s1.Mul(&_s1, &_s2). Mul(&_s1, &_o). Mul(&_s1, &alpha). 
- Mul(&_s1, &zu) // α*(Z(μζ))*(l(ζ)+\beta*s1(ζ)+γ)*(r(ζ)+\beta*s2(ζ)+γ)*(o(ζ)+γ) - - fmt.Printf("α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ) = %s\n", _s1.String()) + Mul(&_s1, &zu) // α*(Z(μζ))*(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*(o(ζ)+γ) alphaSquareLagrange.Mul(&lagrangeOne, &alpha). Mul(&alphaSquareLagrange, &alpha) // α²*L₁(ζ) @@ -128,8 +123,6 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) Add(&linearizedPolynomialZeta, &_s1). // linearizedpolynomial+pi(zeta)+α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ) Sub(&linearizedPolynomialZeta, &alphaSquareLagrange) // linearizedpolynomial+pi(zeta)+α*(Z(μζ))*(l(ζ)+s1(ζ)+γ)*(r(ζ)+s2(ζ)+γ)*(o(ζ)+γ)-α²*L₁(ζ) - fmt.Printf("linpolcompleted(zeta) = %s\n", linearizedPolynomialZeta.String()) - // Compute H(ζ) using the previous result: H(ζ) = prev_result/(ζⁿ-1) var zetaPowerMMinusOne fr.Element zetaPowerMMinusOne.Sub(&zetaPowerM, &one) @@ -140,7 +133,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) return errWrongClaimedQuotient } - // compute the folded commitment to H: Comm(h₁) + ζ^{m+2}*Comm(h₂) + ζ^{2(m+2)}*Comm(h₃) + // compute the folded commitment to H: Comm(h₁) + ζᵐ⁺²*Comm(h₂) + ζ²⁽ᵐ⁺²⁾*Comm(h₃) mPlusTwo := big.NewInt(int64(vk.Size) + 2) var zetaMPlusTwo fr.Element zetaMPlusTwo.Exp(zeta, mPlusTwo) @@ -163,24 +156,24 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) var linearizedPolynomialDigest curve.G1Affine - // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*\beta*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) + // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) // CORRECT var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) - _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // \alpha*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*\beta + _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β // CORRECT cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) - v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*\mu*ζ+γ) - w.Mul(&beta, &zeta).Mul(&w, &cosetsquare).Add(&w, &o).Add(&w, &gamma) // (o(ζ)+β*\mu^{2}*ζ+γ) + v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) + w.Mul(&beta, &zeta).Mul(&w, &cosetsquare).Add(&w, &o).Add(&w, &gamma) // (o(ζ)+β*μ²*ζ+γ) _s2.Mul(&u, &v).Mul(&_s2, &w).Neg(&_s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // note since third part = α²*L₁(ζ)*Z - _s2.Mul(&_s2, &alpha).Add(&_s2, &alphaSquareLagrange) // -\alpha*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + α²*L₁(ζ) + _s2.Mul(&_s2, &alpha).Add(&_s2, &alphaSquareLagrange) // -α*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + α²*L₁(ζ) points := []curve.G1Affine{ vk.Ql, vk.Qr, vk.Qm, vk.Qo, vk.Qk, // first part From 50459e8e584827780a4ec0b496d8900658e34c14 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Wed, 9 Feb 2022 10:14:00 +0100 Subject: [PATCH 16/37] feat: code gen for plonk --- examples/rollup/account.go | 4 +- examples/rollup/circuit.go | 2 +- examples/rollup/operator.go | 6 +- examples/rollup/transfer.go | 4 +- frontend/api.go | 2 +- frontend/compile.go | 4 +- frontend/cs/plonk/assertions.go | 6 +- frontend/cs/plonk/sparse_r1cs.go | 4 +- frontend/cs/r1cs/assertions.go | 8 +- frontend/cs/r1cs/r1cs.go 
| 4 +- internal/backend/bls12-377/plonk/marshal.go | 36 +- .../backend/bls12-377/plonk/marshal_test.go | 20 +- internal/backend/bls12-377/plonk/prove.go | 623 +++++++++-------- internal/backend/bls12-377/plonk/setup.go | 157 +++-- internal/backend/bls12-377/plonk/verify.go | 93 +-- internal/backend/bls12-381/plonk/marshal.go | 36 +- .../backend/bls12-381/plonk/marshal_test.go | 20 +- internal/backend/bls12-381/plonk/prove.go | 623 +++++++++-------- internal/backend/bls12-381/plonk/setup.go | 157 +++-- internal/backend/bls12-381/plonk/verify.go | 93 +-- internal/backend/bls24-315/plonk/marshal.go | 36 +- .../backend/bls24-315/plonk/marshal_test.go | 20 +- internal/backend/bls24-315/plonk/prove.go | 623 +++++++++-------- internal/backend/bls24-315/plonk/setup.go | 157 +++-- internal/backend/bls24-315/plonk/verify.go | 93 +-- internal/backend/bn254/plonk/marshal.go | 5 +- internal/backend/bn254/plonk/prove.go | 14 +- internal/backend/bn254/plonk/setup.go | 5 +- internal/backend/bn254/plonk/verify.go | 4 +- internal/backend/bw6-633/plonk/marshal.go | 36 +- .../backend/bw6-633/plonk/marshal_test.go | 20 +- internal/backend/bw6-633/plonk/prove.go | 623 +++++++++-------- internal/backend/bw6-633/plonk/setup.go | 157 +++-- internal/backend/bw6-633/plonk/verify.go | 93 +-- internal/backend/bw6-761/cs/to_delete.go | 81 +++ internal/backend/bw6-761/plonk/marshal.go | 36 +- .../backend/bw6-761/plonk/marshal_test.go | 20 +- internal/backend/bw6-761/plonk/prove.go | 623 +++++++++-------- internal/backend/bw6-761/plonk/setup.go | 157 +++-- internal/backend/bw6-761/plonk/verify.go | 93 +-- .../zkpschemes/plonk/plonk.marshal.go.tmpl | 97 ++- .../zkpschemes/plonk/plonk.prove.go.tmpl | 651 +++++++++--------- .../zkpschemes/plonk/plonk.setup.go.tmpl | 159 ++--- .../zkpschemes/plonk/plonk.verify.go.tmpl | 97 +-- .../zkpschemes/plonk/tests/marshal.go.tmpl | 180 +++-- internal/utils/circuit.go | 2 +- std/algebra/fields_bls12377/e12.go | 52 +- std/algebra/fields_bls12377/e6.go | 2 +- std/algebra/fields_bls24315/e12.go | 2 +- std/algebra/fields_bls24315/e24.go | 48 +- std/algebra/sw_bls12377/g1.go | 2 +- std/algebra/sw_bls12377/g2.go | 2 +- std/algebra/sw_bls24315/g1.go | 2 +- std/algebra/sw_bls24315/g2.go | 2 +- .../twistededwards/bandersnatch/point.go | 2 +- std/algebra/twistededwards/point.go | 2 +- std/fiat-shamir/transcript.go | 4 +- test/kzg_srs.go | 2 +- 58 files changed, 3161 insertions(+), 2945 deletions(-) create mode 100644 internal/backend/bw6-761/cs/to_delete.go diff --git a/examples/rollup/account.go b/examples/rollup/account.go index f381a968fd..ad6b92b48c 100644 --- a/examples/rollup/account.go +++ b/examples/rollup/account.go @@ -25,7 +25,7 @@ import ( var ( // SizeAccount byte size of a serialized account (5*32bytes) - // index || nonce || balance || pubkeyX || pubkeyY, each chunk is 32 bytes + // index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 32 bytes SizeAccount = 160 ) @@ -48,7 +48,7 @@ func (ac *Account) Reset() { // Serialize serializes the account as a concatenation of 5 chunks of 256 bits // one chunk per field (pubKey has 2 chunks), except index and nonce that are concatenated in a single 256 bits chunk -// index || nonce || balance || pubkeyX || pubkeyY, each chunk is 256 bits +// index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 256 bits func (ac *Account) Serialize() []byte { //var buffer bytes.Buffer diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go index 2907f7fc19..83f554c2b9 100644 --- a/examples/rollup/circuit.go +++ 
b/examples/rollup/circuit.go @@ -159,7 +159,7 @@ func (circuit *Circuit) Define(api frontend.API) error { // verifySignatureTransfer ensures that the signature of the transfer is valid func verifyTransferSignature(api frontend.API, t TransferConstraints, hFunc mimc.MiMC) error { - // the signature is on h(nonce || amount || senderpubKey (x&y) || receiverPubkey(x&y)) + // the signature is on h(nonce ∥ amount ∥ senderpubKey (x&y) ∥ receiverPubkey(x&y)) hFunc.Write(t.Nonce, t.Amount, t.SenderPubKey.A.X, t.SenderPubKey.A.Y, t.ReceiverPubKey.A.X, t.ReceiverPubKey.A.Y) htransfer := hFunc.Sum() diff --git a/examples/rollup/operator.go b/examples/rollup/operator.go index a02dc0c07d..56c2304182 100644 --- a/examples/rollup/operator.go +++ b/examples/rollup/operator.go @@ -46,8 +46,8 @@ func NewQueue(batchSize int) Queue { // Operator represents a rollup operator type Operator struct { - State []byte // list of accounts: index || nonce || balance || pubkeyX || pubkeyY, each chunk is 256 bits - HashState []byte // Hashed version of the state, each chunk is 256bits: ... || H(index || nonce || balance || pubkeyX || pubkeyY)) || ... + State []byte // list of accounts: index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 256 bits + HashState []byte // Hashed version of the state, each chunk is 256bits: ... ∥ H(index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY)) ∥ ... AccountMap map[string]uint64 // hashmap of all available accounts (the key is the account.pubkey.X), the value is the index of the account in the state nbAccounts int // number of accounts managed by this operator h hash.Hash // hash function used to build the Merkle Tree @@ -178,7 +178,7 @@ func (o *Operator) updateState(t Transfer, numTransfer int) error { o.witnesses.Transfers[numTransfer].Signature.S = t.signature.S[:] // verifying the signature. The msg is the hash (o.h) of the transfer - // nonce || amount || senderpubKey(x&y) || receiverPubkey(x&y) + // nonce ∥ amount ∥ senderpubKey(x&y) ∥ receiverPubkey(x&y) resSig, err := t.Verify(o.h) if err != nil { return err diff --git a/examples/rollup/transfer.go b/examples/rollup/transfer.go index e12776cbf8..7507b357d3 100644 --- a/examples/rollup/transfer.go +++ b/examples/rollup/transfer.go @@ -52,7 +52,7 @@ func (t *Transfer) Sign(priv eddsa.PrivateKey, h hash.Hash) (eddsa.Signature, er //var frNonce, msg fr.Element var frNonce fr.Element - // serializing transfer. The signature is on h(nonce || amount || senderpubKey (x&y) || receiverPubkey(x&y)) + // serializing transfer. The signature is on h(nonce ∥ amount ∥ senderpubKey (x&y) ∥ receiverPubkey(x&y)) // (each pubkey consist of 2 chunks of 256bits) frNonce.SetUint64(t.nonce) b := frNonce.Bytes() @@ -91,7 +91,7 @@ func (t *Transfer) Verify(h hash.Hash) (bool, error) { var frNonce fr.Element // serializing transfer. 
The msg to sign is - // nonce || amount || senderpubKey(x&y) || receiverPubkey(x&y) + // nonce ∥ amount ∥ senderpubKey(x&y) ∥ receiverPubkey(x&y) // (each pubkey consist of 2 chunks of 256bits) frNonce.SetUint64(t.nonce) b := frNonce.Bytes() diff --git a/frontend/api.go b/frontend/api.go index bc34ea6202..3e7fde87c0 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -101,7 +101,7 @@ type API interface { // AssertIsDifferent fails if i1 == i2 AssertIsDifferent(i1, i2 Variable) - // AssertIsBoolean fails if v != 0 || v != 1 + // AssertIsBoolean fails if v != 0 ∥ v != 1 AssertIsBoolean(i1 Variable) // AssertIsLessOrEqual fails if v > bound diff --git a/frontend/compile.go b/frontend/compile.go index 0fff033589..4565063d0f 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -42,8 +42,8 @@ type NewBuilder func(ecc.ID) (Builder, error) // from the declarative code // // 3. finally, it converts that to a ConstraintSystem. -// if zkpID == backend.GROTH16 --> R1CS -// if zkpID == backend.PLONK --> SparseR1CS +// if zkpID == backend.GROTH16 → R1CS +// if zkpID == backend.PLONK → SparseR1CS // // initialCapacity is an optional parameter that reserves memory in slices // it should be set to the estimated number of constraints in the circuit, if known. diff --git a/frontend/cs/plonk/assertions.go b/frontend/cs/plonk/assertions.go index 5f9cebdd33..efe58e8627 100644 --- a/frontend/cs/plonk/assertions.go +++ b/frontend/cs/plonk/assertions.go @@ -63,7 +63,7 @@ func (system *sparseR1CS) AssertIsDifferent(i1, i2 frontend.Variable) { system.Inverse(system.Sub(i1, i2)) } -// AssertIsBoolean fails if v != 0 || v != 1 +// AssertIsBoolean fails if v != 0 ∥ v != 1 func (system *sparseR1CS) AssertIsBoolean(i1 frontend.Variable) { if system.IsConstant(i1) { c := utils.FromInterface(i1) @@ -125,7 +125,7 @@ func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term l := system.Sub(1, t, aBits[i]) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 - // --> this is a boolean constraint + // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too system.markBoolean(aBits[i].(compiled.Term)) // this does not create a constraint @@ -172,7 +172,7 @@ func (system *sparseR1CS) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { } p := make([]frontend.Variable, nbBits+1) - // p[i] == 1 --> a[j] == c[j] for all j >= i + // p[i] == 1 → a[j] == c[j] for all j ⩾ i p[nbBits] = 1 for i := nbBits - 1; i >= t; i-- { diff --git a/frontend/cs/plonk/sparse_r1cs.go b/frontend/cs/plonk/sparse_r1cs.go index 959d28c55d..2e4beac548 100644 --- a/frontend/cs/plonk/sparse_r1cs.go +++ b/frontend/cs/plonk/sparse_r1cs.go @@ -130,7 +130,7 @@ func (system *sparseR1CS) NewSecretVariable(name string) frontend.Variable { // reduces redundancy in linear expression // It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public||secret||internal||unset +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are sorted from lowest ID to highest ID func (system *sparseR1CS) reduce(l compiled.LinearExpression) compiled.LinearExpression { @@ -268,7 +268,7 @@ func (system *sparseR1CS) CheckVariables() error { sbb.WriteString(strconv.Itoa(cptHints)) sbb.WriteString(" unconstrained hints") sbb.WriteByte('\n') - // TODO we may add more debug info here --> idea, in NewHint, take the debug stack, 
and store in the hint map some + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some // debugInfo to find where a hint was declared (and not constrained) } return errors.New(sbb.String()) diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/assertions.go index 22b3873a42..aba87440c0 100644 --- a/frontend/cs/r1cs/assertions.go +++ b/frontend/cs/r1cs/assertions.go @@ -41,7 +41,7 @@ func (system *r1CS) AssertIsDifferent(i1, i2 frontend.Variable) { system.Inverse(system.Sub(i1, i2)) } -// AssertIsBoolean adds an assertion in the constraint system (v == 0 || v == 1) +// AssertIsBoolean adds an assertion in the constraint system (v == 0 ∥ v == 1) func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { vars, _ := system.toVariables(i1) @@ -69,7 +69,7 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { system.addConstraint(newR1C(v, _v, o), debug) } -// AssertIsLessOrEqual adds assertion in constraint system (v <= bound) +// AssertIsLessOrEqual adds assertion in constraint system (v ⩽ bound) // // bound can be a constant or a Variable // @@ -120,7 +120,7 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { l = system.Sub(l, t, aBits[i]) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 - // --> this is a boolean constraint + // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too system.markBoolean(aBits[i].(compiled.Variable)) // this does not create a constraint @@ -158,7 +158,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { } p := make([]frontend.Variable, nbBits+1) - // p[i] == 1 --> a[j] == c[j] for all j >= i + // p[i] == 1 → a[j] == c[j] for all j ⩾ i p[nbBits] = system.constant(1) for i := nbBits - 1; i >= t; i-- { diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go index def4e35901..eb22b890b8 100644 --- a/frontend/cs/r1cs/r1cs.go +++ b/frontend/cs/r1cs/r1cs.go @@ -146,7 +146,7 @@ func (system *r1CS) one() compiled.Variable { // reduces redundancy in linear expression // It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public||secret||internal||unset +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are sorted from lowest ID to highest ID func (system *r1CS) reduce(l compiled.Variable) compiled.Variable { // ensure our linear expression is sorted, by visibility and by Variable ID @@ -311,7 +311,7 @@ func (system *r1CS) CheckVariables() error { sbb.WriteString(strconv.Itoa(cptHints)) sbb.WriteString(" unconstrained hints") sbb.WriteByte('\n') - // TODO we may add more debug info here --> idea, in NewHint, take the debug stack, and store in the hint map some + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some // debugInfo to find where a hint was declared (and not constrained) } return errors.New(sbb.String()) diff --git a/internal/backend/bls12-377/plonk/marshal.go b/internal/backend/bls12-377/plonk/marshal.go index 411a2be9e7..54de931064 100644 --- a/internal/backend/bls12-377/plonk/marshal.go +++ b/internal/backend/bls12-377/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.DomainSmall.WriteTo(w) if err != nil { 
return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.DomainBig.WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) + if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.DomainSmall.ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.DomainBig.ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ @@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bls12-377/plonk/marshal_test.go b/internal/backend/bls12-377/plonk/marshal_test.go index c7179f96cc..0e8e2e90b4 100644 --- a/internal/backend/bls12-377/plonk/marshal_test.go +++ b/internal/backend/bls12-377/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42) - pk.DomainH = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.DomainSmall = *fft.NewDomain(42) + pk.DomainBig = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, 
pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 667deea130..1cd205a3a2 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -43,6 +43,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -89,17 +90,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.DomainSmall) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -109,14 +114,22 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn return nil, err } + // Fiat Shamir this + var beta fr.Element + beta.SetUint64(10) + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. 
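	// As a point of reference for the commitment below, the (unblinded) permutation
	// accumulator boils down to a grand product over the small domain:
	//   z[0] = 1,  z[i+1] = z[i] * Π_col (w_col[i] + β·id_col[i] + γ) / (w_col[i] + β·σ_col[i] + γ)
	// with w ranging over the three witness columns l, r, o. The sketch below is
	// illustrative only (grandProduct and its explicit id/σ column arguments are
	// assumptions, not part of this patch); it only uses the fr API already imported here.
	grandProduct := func(l, r, o, id1, id2, id3, s1, s2, s3 []fr.Element, beta, gamma fr.Element) []fr.Element {
		z := make([]fr.Element, len(l))
		z[0].SetOne()
		for i := 0; i < len(l)-1; i++ {
			var num, den, t fr.Element
			num.SetOne()
			den.SetOne()
			cols := [3][3]*fr.Element{{&l[i], &id1[i], &s1[i]}, {&r[i], &id2[i], &s2[i]}, {&o[i], &id3[i], &s3[i]}}
			for _, c := range cols {
				t.Mul(&beta, c[1]).Add(&t, c[0]).Add(&t, &gamma) // w + β·id + γ
				num.Mul(&num, &t)
				t.Mul(&beta, c[2]).Add(&t, c[0]).Add(&t, &gamma) // w + β·σ + γ
				den.Mul(&den, &t)
			}
			den.Inverse(&den)
			z[i+1].Mul(&z[i], &num).Mul(&z[i+1], &den)
		}
		return z
	}
	_ = grandProduct // sketch only; computeBlindedZCanonical above is the real path (plus blinding)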
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,26 +207,36 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
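		// The recurring "_i := bits.Reverse64(uint64(i)) >> nn" pattern used by the
		// *DomainBigBitReversed helpers in this file maps a natural-order index to the
		// position it occupies after an fft.DIF transform whose final BitReverse pass was
		// skipped. A minimal sketch of that mapping (bitRev is illustrative, not part of
		// this patch; n must be a power of two):
		bitRev := func(i, n uint64) uint64 {
			nn := uint64(64 - bits.TrailingZeros64(n)) // with n = 2^k this keeps the k low bits of i, reversed
			return bits.Reverse64(i) >> nn
		}
		_ = bitRev(3, 8) // == 6, i.e. 0b011 reversed over 3 bits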
<-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { + if err := <-chConstraintOrdering; err != nil { // CORRECT return nil, err } - <-chConstraintInd + + <-chConstraintInd // CORRECT + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,9 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, + blindedZCanonical, &zetaShifted, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +280,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, 
&proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -306,12 +340,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn proof.BatchedProof, err = kzg.BatchOpenSinglePoint( []polynomial.Polynomial{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -324,7 +358,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn }, &zeta, hFunc, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +369,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +405,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,13 +431,13 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, 
domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) @@ -436,40 +479,43 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - totalDegree := rou + bo + // totalDegree := rou + bo - // re-use cp - res := cp[:totalDegree+1] + // // re-use cp + // res := cp[:totalDegree+1] - // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) - for i := uint64(0); i < bo+1; i++ { - if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err - } - } + // // random polynomial + // blindingPoly := make([]fr.Element, bo+1) + // for i := uint64(0); i < bo+1; i++ { + // if _, err := blindingPoly[i].SetRandom(); err != nil { + // return nil, err + // } + // } - // blinding - for i := uint64(0); i < bo+1; i++ { - res[i].Sub(&res[i], &blindingPoly[i]) - res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - } + // // blinding + // for i := uint64(0); i < bo+1; i++ { + // res[i].Sub(&res[i], &blindingPoly[i]) + // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + // } + + // return res, nil - return res, nil + // TODO reactivate blinding + return cp, nil } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.DomainSmall.Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +548,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +594,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF) + pk.DomainSmall.FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.DomainSmall.Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. 
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,270 +650,251 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on (Z/4mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. 
// -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.DomainBig.Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.DomainBig.Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. 
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, true) + pk.DomainBig.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.DomainSmall.Cardinality+2] + h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] + h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 + return h1, h2, h3 // CORRECT } // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). 
// (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.DomainSmall.Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bls12-377/plonk/setup.go b/internal/backend/bls12-377/plonk/setup.go index 5f5652e3b2..87f8e8e406 100644 --- a/internal/backend/bls12-377/plonk/setup.go +++ b/internal/backend/bls12-377/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" "github.com/consensys/gnark/internal/backend/bls12-377/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,18 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... 
- CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element // Domains used for the FFTs - DomainNum, DomainH fft.Domain + DomainSmall, DomainBig fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +68,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +94,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem) + pk.DomainSmall = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8 * sizeSystem) + pk.DomainBig = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4 * sizeSystem) + pk.DomainBig = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.DomainSmall.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +129,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 
+143,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +158,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +177,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +195,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.DomainSmall.Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +251,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.DomainSmall.Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + 
res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls12-377/plonk/verify.go b/internal/backend/bls12-377/plonk/verify.go index 17ecfd94ee..9432e539ec 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/internal/backend/bls12-377/plonk/verify.go @@ -63,7 +63,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +71,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,26 +207,36 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
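		// The gate identity enforced point-wise by evaluateConstraintsDomainBigBitReversed
		// above is ql·l + qr·r + qm·l·r + qo·o + qk == 0 on every row. A small-domain sanity
		// check sketch (checkGates is illustrative, not part of this patch):
		checkGates := func(ql, qr, qm, qo, qk, l, r, o []fr.Element) bool {
			for i := range ql {
				var acc, t fr.Element
				acc.Mul(&ql[i], &l[i])              // ql·l
				t.Mul(&qr[i], &r[i])                // qr·r
				acc.Add(&acc, &t)
				t.Mul(&qm[i], &l[i]).Mul(&t, &r[i]) // qm·l·r
				acc.Add(&acc, &t)
				t.Mul(&qo[i], &o[i])                // qo·o
				acc.Add(&acc, &t)
				acc.Add(&acc, &qk[i])               // + qk
				if !acc.IsZero() {
					return false
				}
			}
			return true
		}
		_ = checkGates // sketch only; the prover works with the blinded big-domain evaluations instead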
<-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { + if err := <-chConstraintOrdering; err != nil { // CORRECT return nil, err } - <-chConstraintInd + + <-chConstraintInd // CORRECT + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,9 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, + blindedZCanonical, &zetaShifted, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +280,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, 
&proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -306,12 +340,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn proof.BatchedProof, err = kzg.BatchOpenSinglePoint( []polynomial.Polynomial{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -324,7 +358,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn }, &zeta, hFunc, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +369,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +405,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,13 +431,13 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, 
domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) @@ -436,40 +479,43 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - totalDegree := rou + bo + // totalDegree := rou + bo - // re-use cp - res := cp[:totalDegree+1] + // // re-use cp + // res := cp[:totalDegree+1] - // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) - for i := uint64(0); i < bo+1; i++ { - if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err - } - } + // // random polynomial + // blindingPoly := make([]fr.Element, bo+1) + // for i := uint64(0); i < bo+1; i++ { + // if _, err := blindingPoly[i].SetRandom(); err != nil { + // return nil, err + // } + // } - // blinding - for i := uint64(0); i < bo+1; i++ { - res[i].Sub(&res[i], &blindingPoly[i]) - res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - } + // // blinding + // for i := uint64(0); i < bo+1; i++ { + // res[i].Sub(&res[i], &blindingPoly[i]) + // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + // } + + // return res, nil - return res, nil + // TODO reactivate blinding + return cp, nil } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.DomainSmall.Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +548,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +594,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF) + pk.DomainSmall.FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.DomainSmall.Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. 
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,270 +650,251 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on (Z/4mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. 
// -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.DomainBig.Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.DomainBig.Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. 
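The parallel loop above is a Horner evaluation in α of the quotient numerator, divided pointwise by Xⁿ−1 on the coset: h = (gates + α·ordering + α²·L₁·(Z−1))·(Xⁿ−1)⁻¹. A single-point sketch (l1 stands in for the startsAtOne evaluation, xnInv for evaluationXnMinusOneInverse):

// Sketch (illustration only): the quotient value at one coset point, matching
// the Horner-in-α chain above, before the inverse FFT that returns h to
// canonical form.
func quotientAt(gates, ordering, z, l1, alpha, xnInv fr.Element) fr.Element {
	var one, t, h fr.Element
	one.SetOne()
	t.Sub(&z, &one)                      // Z(x) − 1
	h.Mul(&l1, &alpha).Mul(&h, &t)       // α·L₁(x)·(Z(x)−1)
	h.Add(&h, &ordering).Mul(&h, &alpha) // α·(α·L₁·(Z−1) + ordering)
	h.Add(&h, &gates).Mul(&h, &xnInv)    // (+ gates) / (xⁿ−1)
	return h
}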
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, true) + pk.DomainBig.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.DomainSmall.Cardinality+2] + h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] + h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 + return h1, h2, h3 // CORRECT } // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). 
// (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.DomainSmall.Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bls12-381/plonk/setup.go b/internal/backend/bls12-381/plonk/setup.go index 10a59218fe..057d2aea04 100644 --- a/internal/backend/bls12-381/plonk/setup.go +++ b/internal/backend/bls12-381/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" "github.com/consensys/gnark/internal/backend/bls12-381/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,18 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... 
- CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element // Domains used for the FFTs - DomainNum, DomainH fft.Domain + DomainSmall, DomainBig fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +68,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +94,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem) + pk.DomainSmall = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8 * sizeSystem) + pk.DomainBig = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4 * sizeSystem) + pk.DomainBig = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.DomainSmall.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +129,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 
+143,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +158,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +177,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +195,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.DomainSmall.Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +251,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.DomainSmall.Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + 
res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls12-381/plonk/verify.go b/internal/backend/bls12-381/plonk/verify.go index 53283ee4a9..060659fc42 100644 --- a/internal/backend/bls12-381/plonk/verify.go +++ b/internal/backend/bls12-381/plonk/verify.go @@ -63,7 +63,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +71,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,26 +207,36 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
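The ordering evaluation invoked just below checks, point by point, that the grand product Z is consistent with the permutation: at a coset point x it returns Z(μx)·g₁g₂g₃ − Z(x)·f₁f₂f₃, where f is built from the identity values (x, u·x, u²·x) and g from s₁, s₂, s₃. A single-point sketch (u is the coset shift, zx = Z(x), zmx = Z(μx); fr is the curve's scalar field package):

// Sketch (illustration only): one point of the copy-constraint check filled in
// by evaluateOrderingDomainBigBitReversed.
func orderingAt(l, r, o, x, u, s1, s2, s3, zx, zmx, beta, gamma fr.Element) fr.Element {
	var f, g, t, res fr.Element

	f.Mul(&x, &beta).Add(&f, &l).Add(&f, &gamma)             // l + β·x + γ
	t.Mul(&x, &u).Mul(&t, &beta).Add(&t, &r).Add(&t, &gamma) // r + β·u·x + γ
	f.Mul(&f, &t)
	t.Mul(&x, &u).Mul(&t, &u).Mul(&t, &beta).Add(&t, &o).Add(&t, &gamma) // o + β·u²·x + γ
	f.Mul(&f, &t).Mul(&f, &zx) // Z(x)·f₁·f₂·f₃

	g.Mul(&s1, &beta).Add(&g, &l).Add(&g, &gamma) // l + β·s₁(x) + γ
	t.Mul(&s2, &beta).Add(&t, &r).Add(&t, &gamma) // r + β·s₂(x) + γ
	g.Mul(&g, &t)
	t.Mul(&s3, &beta).Add(&t, &o).Add(&t, &gamma) // o + β·s₃(x) + γ
	g.Mul(&g, &t).Mul(&g, &zmx) // Z(μx)·g₁·g₂·g₃

	res.Sub(&g, &f)
	return res
}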
<-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { + if err := <-chConstraintOrdering; err != nil { // CORRECT return nil, err } - <-chConstraintInd + + <-chConstraintInd // CORRECT + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,9 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, + blindedZCanonical, &zetaShifted, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +280,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, 
&proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -306,12 +340,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn proof.BatchedProof, err = kzg.BatchOpenSinglePoint( []polynomial.Polynomial{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -324,7 +358,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn }, &zeta, hFunc, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +369,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +405,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,13 +431,13 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, 
domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) @@ -436,40 +479,43 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - totalDegree := rou + bo + // totalDegree := rou + bo - // re-use cp - res := cp[:totalDegree+1] + // // re-use cp + // res := cp[:totalDegree+1] - // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) - for i := uint64(0); i < bo+1; i++ { - if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err - } - } + // // random polynomial + // blindingPoly := make([]fr.Element, bo+1) + // for i := uint64(0); i < bo+1; i++ { + // if _, err := blindingPoly[i].SetRandom(); err != nil { + // return nil, err + // } + // } - // blinding - for i := uint64(0); i < bo+1; i++ { - res[i].Sub(&res[i], &blindingPoly[i]) - res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - } + // // blinding + // for i := uint64(0); i < bo+1; i++ { + // res[i].Sub(&res[i], &blindingPoly[i]) + // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + // } + + // return res, nil - return res, nil + // TODO reactivate blinding + return cp, nil } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.DomainSmall.Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +548,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +594,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF) + pk.DomainSmall.FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.DomainSmall.Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. 
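The accumulator described above is a running product: Z(1)=1 and each step multiplies by the ratio of the identity-side factor to the permuted-side factor of the previous row. A compact Lagrange-basis sketch, ignoring the blinding and the batch inversion (gInv above) that the real code uses; num and den are the per-row products with β and γ already folded in:

// Sketch (illustration only): the grand product Z on the small domain, given
// num[i] = (l_i+β·id₁+γ)(r_i+β·id₂+γ)(o_i+β·id₃+γ) and
// den[i] = (l_i+β·s₁+γ)(r_i+β·s₂+γ)(o_i+β·s₃+γ).
func grandProduct(num, den []fr.Element) []fr.Element {
	z := make([]fr.Element, len(num))
	z[0].SetOne() // Z(1) = 1
	for i := 1; i < len(num); i++ {
		var t fr.Element
		t.Inverse(&den[i-1]).Mul(&t, &num[i-1]) // num[i-1]/den[i-1]
		z[i].Mul(&z[i-1], &t)                   // Z(gⁱ) = Z(gⁱ⁻¹)·num[i-1]/den[i-1]
	}
	return z
}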
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,270 +650,251 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on (Z/4mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. 
// -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.DomainBig.Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.DomainBig.Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. 
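The divisor lookup evaluationXnMinusOneInverse[i%ratio] in the loop above works because, on the coset u·⟨ω⟩ of the big domain, (u·ωⁱ)ⁿ − 1 = uⁿ·(ωⁿ)ⁱ − 1 and ωⁿ has order ratio = |DomainBig|/|DomainSmall|, so only ratio distinct values occur. A sketch of that precomputation (names are illustrative; assumes math/big and the curve's fr package):

// Sketch (illustration only): the ratio distinct values of (Xⁿ−1)⁻¹ on the
// big-domain coset, indexed by i mod ratio as in the loop above.
func xnMinusOneInverseOnCoset(u, omega fr.Element, n, ratio uint64) []fr.Element {
	var un, wn fr.Element
	bn := new(big.Int).SetUint64(n)
	un.Exp(u, bn)     // uⁿ
	wn.Exp(omega, bn) // ωⁿ, a primitive ratio-th root of unity

	one := fr.One()
	res := make([]fr.Element, ratio)
	acc := un
	for i := uint64(0); i < ratio; i++ {
		res[i].Sub(&acc, &one) // uⁿ·(ωⁿ)ⁱ − 1
		acc.Mul(&acc, &wn)
	}
	return fr.BatchInvert(res) // one inversion pass for all ratio values
}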
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, true) + pk.DomainBig.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.DomainSmall.Cardinality+2] + h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] + h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 + return h1, h2, h3 // CORRECT } // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). 
// (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.DomainSmall.Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bls24-315/plonk/setup.go b/internal/backend/bls24-315/plonk/setup.go index 184c3a5677..c37c7d7a60 100644 --- a/internal/backend/bls24-315/plonk/setup.go +++ b/internal/backend/bls24-315/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" "github.com/consensys/gnark/internal/backend/bls24-315/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,18 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... 
- CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element // Domains used for the FFTs - DomainNum, DomainH fft.Domain + DomainSmall, DomainBig fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +68,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +94,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem) + pk.DomainSmall = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8 * sizeSystem) + pk.DomainBig = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4 * sizeSystem) + pk.DomainBig = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.DomainSmall.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +129,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 
+143,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +158,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +177,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +195,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.DomainSmall.Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +251,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.DomainSmall.Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + 
res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls24-315/plonk/verify.go b/internal/backend/bls24-315/plonk/verify.go index e8d9fa52b5..9e32a48ca8 100644 --- a/internal/backend/bls24-315/plonk/verify.go +++ b/internal/backend/bls24-315/plonk/verify.go @@ -63,7 +63,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +71,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,26 +207,36 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
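As a rough illustration of the qk "completion" step above (not part of this patch; the helper name and signature are illustrative, and it assumes the fr/fft packages already imported in these files): the public witness values fill the first Lagrange slots, the remaining slots come from LQk, and an inverse FFT brings the vector back to canonical form.

func completeQkSketch(lqk, publicWitness []fr.Element, domain *fft.Domain) []fr.Element {
	qk := make([]fr.Element, domain.Cardinality)
	copy(qk, publicWitness)                                  // public inputs occupy the first Lagrange slots
	copy(qk[len(publicWitness):], lqk[len(publicWitness):])  // the rest comes from LQk
	domain.FFTInverse(qk, fft.DIF)                           // Lagrange -> canonical, bit-reversed
	fft.BitReverse(qk)                                       // back to natural coefficient order
	return qk
}
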
<-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { + if err := <-chConstraintOrdering; err != nil { // CORRECT return nil, err } - <-chConstraintInd + + <-chConstraintInd // CORRECT + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,9 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, + blindedZCanonical, &zetaShifted, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +280,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, 
&proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -306,12 +340,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes proof.BatchedProof, err = kzg.BatchOpenSinglePoint( []polynomial.Polynomial{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -324,7 +358,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes }, &zeta, hFunc, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +369,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +405,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,13 +431,13 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, 
domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) @@ -436,40 +479,43 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - totalDegree := rou + bo + // totalDegree := rou + bo - // re-use cp - res := cp[:totalDegree+1] + // // re-use cp + // res := cp[:totalDegree+1] - // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) - for i := uint64(0); i < bo+1; i++ { - if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err - } - } + // // random polynomial + // blindingPoly := make([]fr.Element, bo+1) + // for i := uint64(0); i < bo+1; i++ { + // if _, err := blindingPoly[i].SetRandom(); err != nil { + // return nil, err + // } + // } - // blinding - for i := uint64(0); i < bo+1; i++ { - res[i].Sub(&res[i], &blindingPoly[i]) - res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - } + // // blinding + // for i := uint64(0); i < bo+1; i++ { + // res[i].Sub(&res[i], &blindingPoly[i]) + // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + // } + + // return res, nil - return res, nil + // TODO reactivate blinding + return cp, nil } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.DomainSmall.Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +548,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +594,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF) + pk.DomainSmall.FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.DomainSmall.Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. 
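A minimal sketch (not part of the patch) of the accumulator recurrence documented above: Z(1)=1 and Z(gⁱ⁺¹) = Z(gⁱ)·fᵢ/gᵢ, with fᵢ built from the identity permutation on the coset shifts 1, u, u² and gᵢ from s1, s2, s3. It assumes l, r, o and s1, s2, s3 are given in Lagrange form on the small domain; names are illustrative and it inverts per step instead of batch-inverting as the real code does.

func permutationAccumulatorSketch(l, r, o, s1, s2, s3 []fr.Element, beta, gamma, u fr.Element, domain *fft.Domain) []fr.Element {
	n := int(domain.Cardinality)
	z := make([]fr.Element, n)
	z[0].SetOne() // Z(1) = 1

	// factor computes w + beta*id + gamma
	factor := func(w, id fr.Element) fr.Element {
		var res fr.Element
		res.Mul(&id, &beta).Add(&res, &w).Add(&res, &gamma)
		return res
	}

	var x, u2, num, den, t fr.Element
	x.SetOne() // gⁱ, starting at i = 0
	u2.Square(&u)

	for i := 0; i < n-1; i++ {
		var xu, xu2 fr.Element
		xu.Mul(&x, &u)
		xu2.Mul(&x, &u2)

		num = factor(l[i], x) // l_i + β·gⁱ + γ
		t = factor(r[i], xu)  // r_i + β·u·gⁱ + γ
		num.Mul(&num, &t)
		t = factor(o[i], xu2) // o_i + β·u²·gⁱ + γ
		num.Mul(&num, &t)

		den = factor(l[i], s1[i]) // l_i + β·s1_i + γ
		t = factor(r[i], s2[i])   // r_i + β·s2_i + γ
		den.Mul(&den, &t)
		t = factor(o[i], s3[i])   // o_i + β·s3_i + γ
		den.Mul(&den, &t)

		den.Inverse(&den) // the real code batch-inverts all denominators at once
		z[i+1].Mul(&z[i], &num).Mul(&z[i+1], &den)

		x.Mul(&x, &domain.Generator) // gⁱ -> gⁱ⁺¹
	}

	return z // Lagrange form on the small domain; the prover then FFT-inverses and blinds it
}
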
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,270 +650,251 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on (Z/4mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. 
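For reference, the pointwise combination performed by the function above reduces, at a single evaluation point, to the gate expression below (a sketch, not the patch's code; names are illustrative). On the small domain this value is zero for a satisfied constraint; on the coset of the big domain it is generally non-zero, which is what later gets divided by Xⁿ−1.

func gateConstraintAtPoint(ql, qr, qm, qo, qk, l, r, o fr.Element) fr.Element {
	var res, t fr.Element
	res.Mul(&qm, &l).Mul(&res, &r) // qm·l·r
	t.Mul(&ql, &l)
	res.Add(&res, &t) // + ql·l
	t.Mul(&qr, &r)
	res.Add(&res, &t) // + qr·r
	t.Mul(&qo, &o)
	res.Add(&res, &t) // + qo·o
	res.Add(&res, &qk) // + qk
	return res
}
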
// -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.DomainBig.Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.DomainBig.Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. 
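At a single coset point the folding with α performed in the loop above amounts to the expression sketched below (illustrative names, not the patch's helper): the three constraint families are combined as gate + α·ordering + α²·L₁·(Z−1) and divided by Xⁿ−1, the division being exact only when taken over the whole coset.

func quotientAtPoint(gate, ordering, lOne, z, alpha, xnMinusOneInv fr.Element) fr.Element {
	var h, t, one fr.Element
	one.SetOne()
	t.Sub(&z, &one)                  // Z(x) − 1
	h.Mul(&lOne, &alpha).Mul(&h, &t) // α·L₁(x)·(Z(x)−1)
	h.Add(&h, &ordering)             // + ordering constraint
	h.Mul(&h, &alpha)                // α·(α·L₁·(Z−1) + ordering)
	h.Add(&h, &gate)                 // + gate constraint
	h.Mul(&h, &xnMinusOneInv)        // divide by (Xⁿ − 1) on the coset
	return h
}
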
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, true) + pk.DomainBig.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.DomainSmall.Cardinality+2] + h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] + h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 + return h1, h2, h3 // CORRECT } // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). 
// (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.DomainSmall.Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bw6-633/plonk/setup.go b/internal/backend/bw6-633/plonk/setup.go index 383674aff1..6ac99c4cae 100644 --- a/internal/backend/bw6-633/plonk/setup.go +++ b/internal/backend/bw6-633/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" "github.com/consensys/gnark/internal/backend/bw6-633/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,18 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... 
- CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element // Domains used for the FFTs - DomainNum, DomainH fft.Domain + DomainSmall, DomainBig fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +68,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +94,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem) + pk.DomainSmall = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8 * sizeSystem) + pk.DomainBig = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4 * sizeSystem) + pk.DomainBig = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.DomainSmall.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +129,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 
+143,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +158,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +177,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +195,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.DomainSmall.Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +251,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.DomainSmall.Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + 
res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bw6-633/plonk/verify.go b/internal/backend/bw6-633/plonk/verify.go index 49e6684f2c..cf7ee19b7b 100644 --- a/internal/backend/bw6-633/plonk/verify.go +++ b/internal/backend/bw6-633/plonk/verify.go @@ -63,7 +63,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +71,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,26 +207,36 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
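What "evaluate on the big domain, bit reversed" boils down to is sketched below (not part of the patch; name is illustrative, using the same fft.Domain API called elsewhere in this diff): pad the canonical coefficients to the big size and run a forward FFT on the coset with DIF decimation, which leaves the values in bit-reversed order.

func evaluateOnBigDomainSketch(pCanonical []fr.Element, domainBig *fft.Domain) []fr.Element {
	ev := make([]fr.Element, domainBig.Cardinality)
	copy(ev, pCanonical)             // zero-pad the canonical coefficients to the big size
	domainBig.FFT(ev, fft.DIF, true) // forward FFT on the coset; DIF output is bit-reversed
	return ev
}
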
<-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { + if err := <-chConstraintOrdering; err != nil { // CORRECT return nil, err } - <-chConstraintInd + + <-chConstraintInd // CORRECT + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,9 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, + blindedZCanonical, &zetaShifted, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +280,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, 
&proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -306,12 +340,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes proof.BatchedProof, err = kzg.BatchOpenSinglePoint( []polynomial.Polynomial{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -324,7 +358,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes }, &zeta, hFunc, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +369,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +405,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,13 +431,13 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, 
domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) @@ -436,40 +479,43 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - totalDegree := rou + bo + // totalDegree := rou + bo - // re-use cp - res := cp[:totalDegree+1] + // // re-use cp + // res := cp[:totalDegree+1] - // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) - for i := uint64(0); i < bo+1; i++ { - if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err - } - } + // // random polynomial + // blindingPoly := make([]fr.Element, bo+1) + // for i := uint64(0); i < bo+1; i++ { + // if _, err := blindingPoly[i].SetRandom(); err != nil { + // return nil, err + // } + // } - // blinding - for i := uint64(0); i < bo+1; i++ { - res[i].Sub(&res[i], &blindingPoly[i]) - res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - } + // // blinding + // for i := uint64(0); i < bo+1; i++ { + // res[i].Sub(&res[i], &blindingPoly[i]) + // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + // } + + // return res, nil - return res, nil + // TODO reactivate blinding + return cp, nil } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.DomainSmall.Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +548,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +594,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF) + pk.DomainSmall.FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.DomainSmall.Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. 
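The blinding that the TODO above deactivates adds b(X)·(Xⁿ−1) with a random b of degree bo, which leaves the evaluations on the n-th roots of unity unchanged while hiding the coefficients. A sketch restating the commented-out logic (illustrative name, same precondition on cap(cp) as stated above):

func blindSketch(cp []fr.Element, n, bo uint64) ([]fr.Element, error) {
	res := cp[:n+bo+1] // assumes cap(cp) ≥ n+bo+1, the stated precondition
	b := make([]fr.Element, bo+1)
	for i := range b {
		if _, err := b[i].SetRandom(); err != nil {
			return nil, err
		}
	}
	for i := uint64(0); i <= bo; i++ {
		res[i].Sub(&res[i], &b[i])     // coefficient of Xⁱ   gets −bᵢ
		res[n+i].Add(&res[n+i], &b[i]) // coefficient of Xⁿ⁺ⁱ gets +bᵢ
	}
	return res, nil
}
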
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,270 +650,251 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on (Z/4mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. 
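Because the big-domain evaluations are stored in bit-reversed order (DIF output), reading Z at the point shifted by ratio = |H_big|/|H_small| positions, i.e. the shifted accumulator Z(μX), goes through the index mapping sketched below (illustrative, assuming math/bits; it mirrors the _i/_is computation used in these loops).

func shiftedIndexSketch(i, toShift, n int) (cur, next uint64) {
	nn := uint64(64 - bits.TrailingZeros64(uint64(n)))    // n is a power of two
	cur = bits.Reverse64(uint64(i)) >> nn                 // storage slot of the value at point i
	next = bits.Reverse64(uint64((i+toShift)%n)) >> nn    // storage slot of the value at point i+toShift
	return cur, next
}
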
// -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.DomainBig.Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.DomainBig.Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. 
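// At each point x of the big-domain coset, the quotient is assembled as
//
//   H(x) = ( ind(x) + α*ordering(x) + α²*L₁(x)*(Z(x)-1) ) / (xⁿ - 1)
//
// where ind is the gate-constraint combination, ordering is the permutation term
// above, and L₁ is the first Lagrange polynomial of the small domain; h is then
// interpolated and split into h1, h2, h3 of degree n+1 each. A minimal per-point
// sketch; the helper name and the bw6-761 fr package are illustrative only.
package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
)

func quotientTerm(ind, ordering, l1, z, alpha, xnMinusOneInv fr.Element) fr.Element {
	var h, t, one fr.Element
	one.SetOne()
	t.Sub(&z, &one)                      // Z(x) - 1
	h.Mul(&l1, &alpha).Mul(&h, &t)       // α*L₁(x)*(Z(x)-1)
	h.Add(&h, &ordering).Mul(&h, &alpha) // α*( ... + ordering(x) )
	h.Add(&h, &ind)                      // + ind(x)
	h.Mul(&h, &xnMinusOneInv)            // divide by xⁿ-1 (x lies off the small domain)
	return h
}

func main() {
	var zero fr.Element
	one := fr.One()
	// if every constraint term vanishes and Z(x) = 1, the quotient term is zero
	fmt.Println(quotientTerm(zero, zero, one, one, one, one).IsZero())
}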
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, true) + pk.DomainBig.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.DomainSmall.Cardinality+2] + h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] + h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 + return h1, h2, h3 // CORRECT } // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). 
// (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.DomainSmall.Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bw6-761/plonk/setup.go b/internal/backend/bw6-761/plonk/setup.go index 80bde803bf..36a18a3ba8 100644 --- a/internal/backend/bw6-761/plonk/setup.go +++ b/internal/backend/bw6-761/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" "github.com/consensys/gnark/internal/backend/bw6-761/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,18 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... 
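// In the "third part" above, lagrangeZeta ends up holding α²*L₁(ζ), where L₁ is the
// Lagrange polynomial equal to 1 on the first element of the small domain; the
// comment "L₁ = (ζⁿ⁻¹)/(ζ-1)" is to be read as (ζⁿ - 1)/(ζ - 1), and the missing
// 1/n factor is supplied by DomainSmall.CardinalityInv. A standalone sketch of the
// closed form L₁(ζ) = (ζⁿ - 1)/(n*(ζ - 1)); the helper name and the bw6-761 fr
// package are illustrative only.
package main

import (
	"fmt"
	"math/big"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
)

func lagrangeOneAt(zeta fr.Element, n uint64) fr.Element {
	var num, den, frN fr.Element
	one := fr.One()

	num.Exp(zeta, new(big.Int).SetUint64(n)) // ζⁿ
	num.Sub(&num, &one)                      // ζⁿ - 1

	frN.SetUint64(n)
	den.Sub(&zeta, &one).Mul(&den, &frN).Inverse(&den) // 1/(n*(ζ-1))

	num.Mul(&num, &den)
	return num
}

func main() {
	var zeta fr.Element
	zeta.SetUint64(3)
	fmt.Println(lagrangeOneAt(zeta, 4).String()) // (3⁴-1)/(4*(3-1)) = 80/8 = 10
}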
- CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element // Domains used for the FFTs - DomainNum, DomainH fft.Domain + DomainSmall, DomainBig fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +68,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +94,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem) + pk.DomainSmall = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8 * sizeSystem) + pk.DomainBig = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4 * sizeSystem) + pk.DomainBig = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.DomainSmall.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +129,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 
+143,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +158,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +177,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +195,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.DomainSmall.Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +251,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.DomainSmall.Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + 
res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bw6-761/plonk/verify.go b/internal/backend/bw6-761/plonk/verify.go index 9cba6efb9a..8937310621 100644 --- a/internal/backend/bw6-761/plonk/verify.go +++ b/internal/backend/bw6-761/plonk/verify.go @@ -63,7 +63,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +71,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -160,26 +184,36 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
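// In the verifier hunk above, the comment "evaluation of Z=Xⁿ⁻¹ at ζ" denotes the
// vanishing polynomial Zₕ(X) = Xⁿ - 1 of the small domain (ζⁿ minus one, not ζ to
// the power n-1), as the replaced comment "Z=X**m-1" confirms. A short standalone
// sketch of that evaluation; the bw6-761 fr package is chosen only for illustration.
package main

import (
	"fmt"
	"math/big"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
)

func vanishingAt(zeta fr.Element, n uint64) fr.Element {
	var zh fr.Element
	one := fr.One()
	zh.Exp(zeta, new(big.Int).SetUint64(n)) // ζⁿ
	zh.Sub(&zh, &one)                       // ζⁿ - 1
	return zh
}

func main() {
	var zeta fr.Element
	zeta.SetUint64(2)
	fmt.Println(vanishingAt(zeta, 8).String()) // 2⁸ - 1 = 255
}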
<-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { + if err := <-chConstraintOrdering; err != nil { // CORRECT return nil, err } - <-chConstraintInd + + <-chConstraintInd // CORRECT + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -194,15 +228,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -210,9 +244,9 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, + blindedZCanonical, &zetaShifted, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -223,53 +257,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, 
&proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -282,12 +317,12 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } proof.BatchedProof, err = kzg.BatchOpenSinglePoint( []polynomial.Polynomial{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -300,7 +335,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } }, &zeta, hFunc, - &pk.DomainH, + &pk.DomainBig, pk.Vk.KZGSRS, ) if err != nil { @@ -311,8 +346,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -338,7 +382,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -364,23 +408,23 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { - +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { + // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2) - co := 
make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2) - + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + chDone := make(chan error, 2) go func() { - var err error + var err error copy(cl, ll) domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) - chDone <- err + chDone <- err }() go func() { var err error @@ -388,20 +432,20 @@ func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) - chDone <- err + chDone <- err }() copy(co, lo) domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { - return + return } - err = <-chDone + err = <-chDone if err != nil { return } - err = <-chDone - return + err = <-chDone + return } @@ -412,40 +456,43 @@ func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - totalDegree := rou + bo + // totalDegree := rou + bo - // re-use cp - res := cp[:totalDegree+1] + // // re-use cp + // res := cp[:totalDegree+1] - // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) - for i := uint64(0); i < bo+1; i++ { - if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err - } - } + // // random polynomial + // blindingPoly := make([]fr.Element, bo+1) + // for i := uint64(0); i < bo+1; i++ { + // if _, err := blindingPoly[i].SetRandom(); err != nil { + // return nil, err + // } + // } - // blinding - for i := uint64(0); i < bo+1; i++ { - res[i].Sub(&res[i], &blindingPoly[i]) - res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - } + // // blinding + // for i := uint64(0); i < bo+1; i++ { + // res[i].Sub(&res[i], &blindingPoly[i]) + // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + // } + + // return res, nil - return res, nil + // TODO reactivate blinding + return cp, nil } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
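// The pattern used above for l, r and o, namely FFTInverse(·, fft.DIF) followed by
// fft.BitReverse, turns a vector of evaluations over the small domain (Lagrange
// basis) into coefficients in canonical order. A standalone sketch, assuming the
// single-argument fft.NewDomain of the gnark-crypto version used by this patch;
// the bw6-761 packages are chosen only for illustration.
package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft"
)

func main() {
	domain := fft.NewDomain(4)

	// evaluations of the constant polynomial 1 on the 4 points of the domain
	evals := make([]fr.Element, domain.Cardinality)
	for i := range evals {
		evals[i].SetOne()
	}

	// Lagrange -> canonical: interpolate, then undo the bit-reversed ordering
	domain.FFTInverse(evals, fft.DIF)
	fft.BitReverse(evals)

	// expected coefficients: [1, 0, 0, 0]
	fmt.Println(evals[0].String(), evals[1].IsZero(), evals[2].IsZero(), evals[3].IsZero())
}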
// solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.DomainSmall.Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -478,47 +525,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -528,43 +571,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF) + pk.DomainSmall.FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.DomainSmall.Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -584,271 +627,251 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. 
return evalQk } -// evalIDCosets id, uid, u**2id on (Z/4mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - // TODO doing an expo per chunk is useless - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FrMultiplicativeGen) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.DomainBig.Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.DomainBig.Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g } }) return res } - -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. 
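// The indices _i and _is above compensate for the fact that the FFT outputs on the
// big domain are stored in bit-reversed order: _i locates the evaluation at the
// i-th coset point, and _is the evaluation toShift positions further, which is how
// the shifted accumulator (Z(μX) in the linearization comments) is read from the
// same array. A standalone, index-only sketch with no field arithmetic.
package main

import (
	"fmt"
	"math/bits"
)

func main() {
	const n = 8                                // size of the (big) domain, a power of two
	nn := uint64(64 - bits.TrailingZeros64(n)) // shift so Reverse64 acts on log2(n) bits
	toShift := 2                               // ratio big/small domain in the prover

	for i := 0; i < n; i++ {
		_i := bits.Reverse64(uint64(i)) >> nn              // position of point i in the bit-reversed array
		_is := bits.Reverse64(uint64((i+toShift)%n)) >> nn // position of point i+toShift (wraps around)
		fmt.Printf("i=%d -> stored at %d, shifted value at %d\n", i, _i, _is)
	}
}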
// using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, true) + pk.DomainBig.FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.DomainSmall.Cardinality+2] + h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] + h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 + return h1, h2, h3 // CORRECT } // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). 
// (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.DomainSmall.Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). 
+ Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index 255930ee03..bf514c2e13 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -1,6 +1,5 @@ import ( "errors" - {{- template "import_polynomial" . }} {{- template "import_kzg" . }} {{- template "import_fr" . }} {{- template "import_fft" . }} @@ -22,18 +21,18 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... 
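// The eval helper introduced by this patch (see the prover hunks) evaluates a
// polynomial given by its canonical coefficients with Horner's rule:
// r = (...((c[n-1]*p + c[n-2])*p + ...)*p + c[0]. A standalone sketch with a small
// sanity check; the helper name and the bw6-761 fr package are illustrative only.
package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
)

func evalSketch(c []fr.Element, p fr.Element) fr.Element {
	var r fr.Element // starts at 0
	for i := len(c) - 1; i >= 0; i-- {
		r.Mul(&r, &p).Add(&r, &c[i])
	}
	return r
}

func main() {
	// q(X) = 3 + 2X + X², evaluated at X = 5: 3 + 10 + 25 = 38
	var c [3]fr.Element
	c[0].SetUint64(3)
	c[1].SetUint64(2)
	c[2].SetOne()
	var p fr.Element
	p.SetUint64(5)
	fmt.Println(evalSketch(c[:], p).String()) // 38
}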
- CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element // Domains used for the FFTs - DomainNum, DomainH fft.Domain + DomainSmall, DomainBig fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -51,13 +50,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -78,37 +76,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem) + pk.DomainSmall = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem) + pk.DomainBig = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem) + pk.DomainBig = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.DomainSmall.Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.DomainSmall.Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FrMultiplicativeGen) - vk.Shifter[1].Square(&pk.DomainNum.FrMultiplicativeGen) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -116,7 +111,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -130,11 
+125,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) + pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) + pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -145,7 +140,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -164,13 +159,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -182,18 +177,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.DomainSmall.Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -238,60 +233,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. 
s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.DomainSmall.Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FrMultiplicativeGen) - sID[2*nbElmt].Square(&pk.DomainNum.FrMultiplicativeGen) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) + pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) + pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + 
res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized @@ -324,4 +329,4 @@ func (vk *VerifyingKey) NbPublicWitness() int { // VerifyingKey returns pk.Vk func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk -} \ No newline at end of file +} diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index a38645edc3..1deb18c706 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -42,7 +42,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -50,20 +50,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i if the circuits contains maps or slices +// this is actually a shallow copy → if the circuits contains maps or slices // only the reference is copied. func ShallowClone(circuit frontend.Circuit) frontend.Circuit { diff --git a/std/algebra/fields_bls12377/e12.go b/std/algebra/fields_bls12377/e12.go index d63f963f64..af387437e3 100644 --- a/std/algebra/fields_bls12377/e12.go +++ b/std/algebra/fields_bls12377/e12.go @@ -163,25 +163,25 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12, ext Extension) var t [7]E2 - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.C0.B1, ext) - // t1 = g5^2 + // t1 = g5² t[1].Square(api, x.C1.B2, ext) // t5 = g1 + g5 t[5].Add(api, x.C0.B1, x.C1.B2) - // t2 = (g1 + g5)^2 + // t2 = (g1 + g5)² t[2].Square(api, t[5], ext) - // t3 = g1^2 + g5^2 + // t3 = g1² + g5² t[3].Add(api, t[0], t[1]) // t5 = 2 * g1 * g5 t[5].Sub(api, t[2], t[3]) // t6 = g3 + g2 t[6].Add(api, x.C1.B0, x.C0.B2) - // t3 = (g3 + g2)^2 + // t3 = (g3 + g2)² t[3].Square(api, t[6], ext) - // t2 = g3^2 + // t2 = g3² t[2].Square(api, x.C1.B0, ext) // t6 = 2 * nr * g1 * g5 @@ -192,33 +192,33 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12, ext Extension) // z3 = 6 * nr * g1 * g5 + 2 * g3 e.C1.B0.Add(api, t[5], t[6]) - // t4 = nr * g5^2 + // t4 = nr * g5² t[4].MulByNonResidue(api, t[1], ext) - // t5 = nr * g5^2 + g1^2 + // t5 = nr * g5² + g1² t[5].Add(api, t[0], t[4]) - // t6 = nr * g5^2 + g1^2 - g2 + // t6 = nr * g5² + g1² - g2 t[6].Sub(api, t[5], x.C0.B2) - // t1 = g2^2 + // t1 = g2² t[1].Square(api, x.C0.B2, ext) - // t6 = 2 * nr * g5^2 + 2 * g1^2 - 2*g2 + // t6 = 2 * nr * g5² + 2 * g1² - 2*g2 t[6].Double(api, t[6]) - // z2 = 3 * nr * g5^2 + 3 * g1^2 - 2*g2 + // z2 = 3 * nr * g5² + 3 * g1² - 2*g2 e.C0.B2.Add(api, t[6], t[5]) - // t4 = nr * g2^2 + // t4 = nr * g2² t[4].MulByNonResidue(api, t[1], ext) - // t5 = g3^2 + nr * g2^2 + // t5 = g3² + nr * g2² t[5].Add(api, t[2], t[4]) - // t6 = g3^2 + nr * g2^2 - g1 + // t6 = g3² + nr * g2² - g1 t[6].Sub(api, t[5], x.C0.B1) - // t6 = 2 * g3^2 + 2 * nr * g2^2 - 2 * g1 + // t6 = 2 * g3² + 2 * nr * g2² - 2 * g1 t[6].Double(api, t[6]) - // z1 
= 3 * g3^2 + 3 * nr * g2^2 - 2 * g1 + // z1 = 3 * g3² + 3 * nr * g2² - 2 * g1 e.C0.B1.Add(api, t[6], t[5]) - // t0 = g2^2 + g3^2 + // t0 = g2² + g3² t[0].Add(api, t[2], t[1]) // t5 = 2 * g3 * g2 t[5].Sub(api, t[3], t[0]) @@ -239,13 +239,13 @@ func (e *E12) Decompress(api frontend.API, x E12, ext Extension) *E12 { var one E2 one.SetOne(api) - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.C0.B1, ext) - // t1 = 3 * g1^2 - 2 * g2 + // t1 = 3 * g1² - 2 * g2 t[1].Sub(api, t[0], x.C0.B2). Double(api, t[1]). Add(api, t[1], t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5² + t1 t[2].Square(api, x.C1.B2, ext) t[0].MulByNonResidue(api, t[2], ext). Add(api, t[0], t[1]) @@ -258,14 +258,14 @@ func (e *E12) Decompress(api frontend.API, x E12, ext Extension) *E12 { // t1 = g2 * g1 t[1].Mul(api, x.C0.B2, x.C0.B1, ext) - // t2 = 2 * g4^2 - 3 * g2 * g1 + // t2 = 2 * g4² - 3 * g2 * g1 t[2].Square(api, e.C1.B1, ext). Sub(api, t[2], t[1]). Double(api, t[2]). Sub(api, t[2], t[1]) // t1 = g3 * g5 t[1].Mul(api, x.C1.B0, x.C1.B2, ext) - // c_0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1 + // c₀ = E * (2 * g4² + g3 * g5 - 3 * g2 * g1) + 1 t[2].Add(api, t[2], t[1]) e.C0.B0.MulByNonResidue(api, t[2], ext). Add(api, e.C0.B0, one) @@ -296,9 +296,9 @@ func (e *E12) CyclotomicSquare(api frontend.API, x E12, ext Extension) *E12 { t[5].Square(api, x.C0.B1, ext) t[8].Add(api, x.C1.B2, x.C0.B1).Square(api, t[8], ext).Sub(api, t[8], t[4]).Sub(api, t[8], t[5]).MulByNonResidue(api, t[8], ext) // 2*x5*x1*u - t[0].MulByNonResidue(api, t[0], ext).Add(api, t[0], t[1]) // x4^2*u + x0^2 - t[2].MulByNonResidue(api, t[2], ext).Add(api, t[2], t[3]) // x2^2*u + x3^2 - t[4].MulByNonResidue(api, t[4], ext).Add(api, t[4], t[5]) // x5^2*u + x1^2 + t[0].MulByNonResidue(api, t[0], ext).Add(api, t[0], t[1]) // x4²*u + x0² + t[2].MulByNonResidue(api, t[2], ext).Add(api, t[2], t[3]) // x2²*u + x3² + t[4].MulByNonResidue(api, t[4], ext).Add(api, t[4], t[5]) // x5²*u + x1² e.C0.B0.Sub(api, t[0], x.C0.B0).Add(api, e.C0.B0, e.C0.B0).Add(api, e.C0.B0, t[0]) e.C0.B1.Sub(api, t[2], x.C0.B1).Add(api, e.C0.B1, e.C0.B1).Add(api, e.C0.B1, t[2]) diff --git a/std/algebra/fields_bls12377/e6.go b/std/algebra/fields_bls12377/e6.go index 610caec82e..caa171bbbb 100644 --- a/std/algebra/fields_bls12377/e6.go +++ b/std/algebra/fields_bls12377/e6.go @@ -109,7 +109,7 @@ func (e *E6) MulByFp2(api frontend.API, e1 E6, e2 E2, ext Extension) *E6 { return e } -// MulByNonResidue multiplies e by the imaginary elmt of Fp6 (noted a+bV+cV where V**3 in F^2) +// MulByNonResidue multiplies e by the imaginary elmt of Fp6 (noted a+bV+cV where V**3 in F²) func (e *E6) MulByNonResidue(api frontend.API, e1 E6, ext Extension) *E6 { res := E6{} res.B0.MulByNonResidue(api, e1.B2, ext) diff --git a/std/algebra/fields_bls24315/e12.go b/std/algebra/fields_bls24315/e12.go index a0d087eafe..08b8ea75a7 100644 --- a/std/algebra/fields_bls24315/e12.go +++ b/std/algebra/fields_bls24315/e12.go @@ -109,7 +109,7 @@ func (e *E12) MulByFp2(api frontend.API, e1 E12, e2 E4, ext Extension) *E12 { return e } -// MulByNonResidue multiplies e by the imaginary elmt of Fp12 (noted a+bV+cV where V**3 in F^2) +// MulByNonResidue multiplies e by the imaginary elmt of Fp12 (noted a+bV+cV where V**3 in F²) func (e *E12) MulByNonResidue(api frontend.API, e1 E12, ext Extension) *E12 { res := E12{} res.C0.MulByNonResidue(api, e1.C2, ext) diff --git a/std/algebra/fields_bls24315/e24.go b/std/algebra/fields_bls24315/e24.go index 6b12d58c6b..aa13c3ddde 100644 --- a/std/algebra/fields_bls24315/e24.go +++ 
b/std/algebra/fields_bls24315/e24.go @@ -163,25 +163,25 @@ func (e *E24) Square(api frontend.API, x E24, ext Extension) *E24 { func (e *E24) CyclotomicSquareCompressed(api frontend.API, x E24, ext Extension) *E24 { var t [7]E4 - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.D0.C1, ext) - // t1 = g5^2 + // t1 = g5² t[1].Square(api, x.D1.C2, ext) // t5 = g1 + g5 t[5].Add(api, x.D0.C1, x.D1.C2) - // t2 = (g1 + g5)^2 + // t2 = (g1 + g5)² t[2].Square(api, t[5], ext) - // t3 = g1^2 + g5^2 + // t3 = g1² + g5² t[3].Add(api, t[0], t[1]) // t5 = 2 * g1 * g5 t[5].Sub(api, t[2], t[3]) // t6 = g3 + g2 t[6].Add(api, x.D1.C0, x.D0.C2) - // t3 = (g3 + g2)^2 + // t3 = (g3 + g2)² t[3].Square(api, t[6], ext) - // t2 = g3^2 + // t2 = g3² t[2].Square(api, x.D1.C0, ext) // t6 = 2 * nr * g1 * g5 @@ -192,33 +192,33 @@ func (e *E24) CyclotomicSquareCompressed(api frontend.API, x E24, ext Extension) // z3 = 6 * nr * g1 * g5 + 2 * g3 e.D1.C0.Add(api, t[5], t[6]) - // t4 = nr * g5^2 + // t4 = nr * g5² t[4].MulByNonResidue(api, t[1], ext) - // t5 = nr * g5^2 + g1^2 + // t5 = nr * g5² + g1² t[5].Add(api, t[0], t[4]) - // t6 = nr * g5^2 + g1^2 - g2 + // t6 = nr * g5² + g1² - g2 t[6].Sub(api, t[5], x.D0.C2) - // t1 = g2^2 + // t1 = g2² t[1].Square(api, x.D0.C2, ext) - // t6 = 2 * nr * g5^2 + 2 * g1^2 - 2*g2 + // t6 = 2 * nr * g5² + 2 * g1² - 2*g2 t[6].Double(api, t[6]) - // z2 = 3 * nr * g5^2 + 3 * g1^2 - 2*g2 + // z2 = 3 * nr * g5² + 3 * g1² - 2*g2 e.D0.C2.Add(api, t[6], t[5]) - // t4 = nr * g2^2 + // t4 = nr * g2² t[4].MulByNonResidue(api, t[1], ext) - // t5 = g3^2 + nr * g2^2 + // t5 = g3² + nr * g2² t[5].Add(api, t[2], t[4]) - // t6 = g3^2 + nr * g2^2 - g1 + // t6 = g3² + nr * g2² - g1 t[6].Sub(api, t[5], x.D0.C1) - // t6 = 2 * g3^2 + 2 * nr * g2^2 - 2 * g1 + // t6 = 2 * g3² + 2 * nr * g2² - 2 * g1 t[6].Double(api, t[6]) - // z1 = 3 * g3^2 + 3 * nr * g2^2 - 2 * g1 + // z1 = 3 * g3² + 3 * nr * g2² - 2 * g1 e.D0.C1.Add(api, t[6], t[5]) - // t0 = g2^2 + g3^2 + // t0 = g2² + g3² t[0].Add(api, t[2], t[1]) // t5 = 2 * g3 * g2 t[5].Sub(api, t[3], t[0]) @@ -239,13 +239,13 @@ func (e *E24) Decompress(api frontend.API, x E24, ext Extension) *E24 { var one E4 one.SetOne(api) - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.D0.C1, ext) - // t1 = 3 * g1^2 - 2 * g2 + // t1 = 3 * g1² - 2 * g2 t[1].Sub(api, t[0], x.D0.C2). Double(api, t[1]). Add(api, t[1], t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5² + t1 t[2].Square(api, x.D1.C2, ext) t[0].MulByNonResidue(api, t[2], ext). Add(api, t[0], t[1]) @@ -258,14 +258,14 @@ func (e *E24) Decompress(api frontend.API, x E24, ext Extension) *E24 { // t1 = g2 * g1 t[1].Mul(api, x.D0.C2, x.D0.C1, ext) - // t2 = 2 * g4^2 - 3 * g2 * g1 + // t2 = 2 * g4² - 3 * g2 * g1 t[2].Square(api, e.D1.C1, ext). Sub(api, t[2], t[1]). Double(api, t[2]). Sub(api, t[2], t[1]) // t1 = g3 * g5 t[1].Mul(api, x.D1.C0, x.D1.C2, ext) - // c_0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1 + // c₀ = E * (2 * g4² + g3 * g5 - 3 * g2 * g1) + 1 t[2].Add(api, t[2], t[1]) e.D0.C0.MulByNonResidue(api, t[2], ext). 
Add(api, e.D0.C0, one) @@ -474,7 +474,7 @@ func (e *E24) FinalExponentiation(api frontend.API, e1 E24, genT uint64, ext Ext // Daiki Hayashida and Kenichiro Hayasaka // and Tadanori Teruya // https://eprint.iacr.org/2020/875.pdf - // 3*Phi_24(api, p)/r = (api, u-1)^2 * (api, u+p) * (api, u^2+p^2) * (api, u^4+p^4-1) + 3 + // 3*Phi_24(api, p)/r = (api, u-1)² * (api, u+p) * (api, u²+p²) * (api, u⁴+p⁴-1) + 3 t[0].CyclotomicSquare(api, result, ext) t[1].Expt(api, result, genT, ext) t[2].Conjugate(api, result) diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index b38735f94d..0e33212b96 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -137,7 +137,7 @@ func (p *G1Jac) DoubleAssign(api frontend.API) *G1Jac { S = api.Sub(S, XX) S = api.Sub(S, YYYY) S = api.Add(S, S) - M = api.Mul(XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M = api.Mul(XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z = api.Add(p.Z, p.Y) p.Z = api.Mul(p.Z, p.Z) p.Z = api.Sub(p.Z, YY) diff --git a/std/algebra/sw_bls12377/g2.go b/std/algebra/sw_bls12377/g2.go index 0aaa4ac03f..8014f9494c 100644 --- a/std/algebra/sw_bls12377/g2.go +++ b/std/algebra/sw_bls12377/g2.go @@ -183,7 +183,7 @@ func (p *G2Jac) Double(api frontend.API, p1 *G2Jac, ext fields_bls12377.Extensio S.Sub(api, S, XX) S.Sub(api, S, YYYY) S.Add(api, S, S) - M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z.Add(api, p.Z, p.Y) p.Z.Square(api, p.Z, ext) p.Z.Sub(api, p.Z, YY) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index 10bb09cdab..27560ee675 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -137,7 +137,7 @@ func (p *G1Jac) DoubleAssign(api frontend.API) *G1Jac { S = api.Sub(S, XX) S = api.Sub(S, YYYY) S = api.Add(S, S) - M = api.Mul(XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M = api.Mul(XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z = api.Add(p.Z, p.Y) p.Z = api.Mul(p.Z, p.Z) p.Z = api.Sub(p.Z, YY) diff --git a/std/algebra/sw_bls24315/g2.go b/std/algebra/sw_bls24315/g2.go index 1dd7154686..07e020e1b4 100644 --- a/std/algebra/sw_bls24315/g2.go +++ b/std/algebra/sw_bls24315/g2.go @@ -183,7 +183,7 @@ func (p *G2Jac) Double(api frontend.API, p1 *G2Jac, ext fields_bls24315.Extensio S.Sub(api, S, XX) S.Sub(api, S, YYYY) S.Add(api, S, S) - M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z.Add(api, p.Z, p.Y) p.Z.Square(api, p.Z, ext) p.Z.Sub(api, p.Z, YY) diff --git a/std/algebra/twistededwards/bandersnatch/point.go b/std/algebra/twistededwards/bandersnatch/point.go index ddfb8ac52a..dbba4e3ae6 100644 --- a/std/algebra/twistededwards/bandersnatch/point.go +++ b/std/algebra/twistededwards/bandersnatch/point.go @@ -28,7 +28,7 @@ type Point struct { } // MustBeOnCurve checks if a point is on the reduced twisted Edwards curve -// a*x^2 + y^2 = 1 + d*x^2*y^2. +// a*x² + y² = 1 + d*x²*y². 
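
Editor's note on the comment above: the reduced twisted Edwards relation a*x² + y² = 1 + d*x²*y² can be checked outside a circuit with plain modular arithmetic. The sketch below only illustrates that relation; the parameters a, d, p are toy values (not a real curve) and the function name is illustrative, it does not use the gnark API. The gadget that follows enforces the same equation over the circuit's native field.

package main

import (
	"fmt"
	"math/big"
)

// isOnTwistedEdwards reports whether a*x² + y² == 1 + d*x²*y² (mod p).
func isOnTwistedEdwards(x, y, a, d, p *big.Int) bool {
	xx := new(big.Int).Mul(x, x)
	yy := new(big.Int).Mul(y, y)
	lhs := new(big.Int).Mul(a, xx)
	lhs.Add(lhs, yy).Mod(lhs, p) // a*x² + y²
	rhs := new(big.Int).Mul(xx, yy)
	rhs.Mul(rhs, d).Add(rhs, big.NewInt(1)).Mod(rhs, p) // 1 + d*x²*y²
	return lhs.Cmp(rhs) == 0
}

func main() {
	p := big.NewInt(13) // toy modulus
	a, d := big.NewInt(1), big.NewInt(2)
	// (0,1) is the neutral element and satisfies the equation for any a, d
	fmt.Println(isOnTwistedEdwards(big.NewInt(0), big.NewInt(1), a, d, p)) // true
}
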
func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { one := big.NewInt(1) diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go index 7faf85faf4..bc826b5ca2 100644 --- a/std/algebra/twistededwards/point.go +++ b/std/algebra/twistededwards/point.go @@ -28,7 +28,7 @@ type Point struct { } // MustBeOnCurve checks if a point is on the reduced twisted Edwards curve -// a*x^2 + y^2 = 1 + d*x^2*y^2. +// a*x² + y² = 1 + d*x²*y². func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { one := big.NewInt(1) diff --git a/std/fiat-shamir/transcript.go b/std/fiat-shamir/transcript.go index 3491a548a5..97a35ca62a 100644 --- a/std/fiat-shamir/transcript.go +++ b/std/fiat-shamir/transcript.go @@ -91,8 +91,8 @@ func (t *Transcript) Bind(challengeID string, values []frontend.Variable) error // ComputeChallenge computes the challenge corresponding to the given name. // The resulting variable is: -// * H(name || previous_challenge || binded_values...) if the challenge is not the first one -// * H(name || binded_values... ) if it's is the first challenge +// * H(name ∥ previous_challenge ∥ binded_values...) if the challenge is not the first one +// * H(name ∥ binded_values... ) if it's is the first challenge func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, error) { challenge, ok := t.challenges[challengeID] diff --git a/test/kzg_srs.go b/test/kzg_srs.go index 7f96d5bb03..4e8225ce80 100644 --- a/test/kzg_srs.go +++ b/test/kzg_srs.go @@ -34,7 +34,7 @@ import ( const srsCachedSize = (1 << 14) + 3 // NewKZGSRS uses ccs nb variables and nb constraints to initialize a kzg srs -// for sizes < 2^15, returns a pre-computed cached SRS +// for sizes < 2¹⁵, returns a pre-computed cached SRS // // /!\ warning /!\: this method is here for convenience only: in production, a SRS generated through MPC should be used. 
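
Editor's note on the ComputeChallenge rule documented in the std/fiat-shamir hunk above (H(name ∥ previous_challenge ∥ bound_values...)): the chaining is easy to prototype out of circuit. The snippet below is a standalone sketch using crypto/sha256; the function name and structure are illustrative assumptions, not the std/fiat-shamir API, which additionally enforces challenge ordering and performs the hashing in-circuit.

package main

import (
	"crypto/sha256"
	"fmt"
)

// challenge hashes name ∥ previous ∥ bound...; previous is nil for the first challenge.
func challenge(name string, previous []byte, bound ...[]byte) []byte {
	h := sha256.New()
	h.Write([]byte(name))
	if previous != nil {
		h.Write(previous)
	}
	for _, v := range bound {
		h.Write(v)
	}
	return h.Sum(nil)
}

func main() {
	// e.g. gamma is bound to some commitment bytes, beta then chains gamma
	gamma := challenge("gamma", nil, []byte("serialized commitments"))
	beta := challenge("beta", gamma)
	fmt.Printf("gamma=%x\nbeta=%x\n", gamma, beta)
}
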
 func NewKZGSRS(ccs frontend.CompiledConstraintSystem) (kzg.SRS, error) {

From 6926710c0864f7f5f4fac1d62feb36921186283b Mon Sep 17 00:00:00 2001
From: Thomas Piellard
Date: Wed, 9 Feb 2022 11:38:48 +0100
Subject: [PATCH 17/37] feat: update gnark-crypto

---
 go.mod | 2 +-
 go.sum | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/go.mod b/go.mod
index 56ce699366..535811db35 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@ go 1.17
 
 require (
 	github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a
-	github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f
+	github.com/consensys/gnark-crypto v0.6.1-0.20220209103408-f71b1fc783da
 	github.com/fxamacker/cbor/v2 v2.2.0
 	github.com/leanovate/gopter v0.2.9
 	github.com/stretchr/testify v1.7.0
diff --git a/go.sum b/go.sum
index f6c42236c3..6ff55dc1f3 100644
--- a/go.sum
+++ b/go.sum
@@ -6,6 +6,8 @@ github.com/consensys/gnark-crypto v0.6.1-0.20220203135532-a5667210247a h1:Jfr3vY
 github.com/consensys/gnark-crypto v0.6.1-0.20220203135532-a5667210247a/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU=
 github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f h1:55DRDYCFD64OIJh/Yz1Bch9Va14lwKgA/xk0n8JUIjE=
 github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU=
+github.com/consensys/gnark-crypto v0.6.1-0.20220209103408-f71b1fc783da h1:dfeAHW2Yx/ceM+ft3UP/f5fufoU0LyxqF8655Cp88TI=
+github.com/consensys/gnark-crypto v0.6.1-0.20220209103408-f71b1fc783da/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

From 83623ee7c84330cc93be4a018e76d92e75f5320a Mon Sep 17 00:00:00 2001
From: Youssef El Housni
Date: Wed, 9 Feb 2022 12:21:01 +0100
Subject: [PATCH 18/37] fix(tEd): case when scalar size is odd

---
 std/algebra/twistededwards/point.go | 37 +++++++++++++++--------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go
index 573ffe88cb..8789743ac9 100644
--- a/std/algebra/twistededwards/point.go
+++ b/std/algebra/twistededwards/point.go
@@ -101,26 +101,27 @@ func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable,
 	// first unpack the scalar
 	b := api.ToBinary(scalar)
 
-	res := Point{
-		0,
-		1,
+	n := len(b) - 1
+	res := Point{}
+	res.X = api.Select(b[n], p1.X, 0)
+	res.Y = api.Select(b[n], p1.Y, 1)
+
+	tmp := Point{}
+	A := Point{}
+	A.Double(api, p1, curve)
+	B := Point{}
+	B.Add(api, &A, p1, curve)
+
+	if n%2 == 0 {
+		n += 1
 	}
-	pp := Point{}
-	ppp := Point{}
-	pp.Double(api, p1, curve)
-	ppp.Add(api, &pp, p1, curve)
-
-	n := len(b) - 1
-	res.X = api.Lookup2(b[n], b[n-1], res.X, pp.X, p1.X, ppp.X)
-	res.Y = api.Lookup2(b[n], b[n-1], res.Y, pp.Y, p1.Y, ppp.Y)
-
-	for i := len(b) - 3; i >= 0; i-- {
-		res.Double(api, &res, curve)
-		tmp := Point{}
-		tmp.Add(api, &res, p1, curve)
-		res.X = api.Select(b[i], tmp.X, res.X)
-		res.Y = api.Select(b[i], tmp.Y, res.Y)
+	for i := n - 2; i >= 1; i -= 2 {
+		res.Double(api, &res, curve).
+ Double(api, &res, curve) + tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X) + tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y) + res.Add(api, &res, &tmp, curve) } p.X = res.X From 255d640e3f167f024f7c1ab9108260d22f28a5b5 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Wed, 9 Feb 2022 13:33:34 +0100 Subject: [PATCH 19/37] fix(tEd): case when scalar size is odd --- std/algebra/twistededwards/point.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go index 8789743ac9..b5fd4ca2e0 100644 --- a/std/algebra/twistededwards/point.go +++ b/std/algebra/twistededwards/point.go @@ -101,20 +101,17 @@ func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, // first unpack the scalar b := api.ToBinary(scalar) - n := len(b) - 1 res := Point{} - res.X = api.Select(b[n], p1.X, 0) - res.Y = api.Select(b[n], p1.Y, 1) - tmp := Point{} A := Point{} - A.Double(api, p1, curve) B := Point{} + + A.Double(api, p1, curve) B.Add(api, &A, p1, curve) - if n%2 == 0 { - n += 1 - } + n := len(b) - 1 + res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X) + res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y) for i := n - 2; i >= 1; i -= 2 { res.Double(api, &res, curve). @@ -124,6 +121,13 @@ func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, res.Add(api, &res, &tmp, curve) } + if n%2 == 0 { + res.Double(api, &res, curve) + tmp.Add(api, &res, p1, curve) + res.X = api.Select(b[0], tmp.X, res.X) + res.Y = api.Select(b[0], tmp.Y, res.Y) + } + p.X = res.X p.Y = res.Y From e8cdcda98858e8ba24d601e2ccfabe0f50d42c66 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Wed, 9 Feb 2022 15:15:25 +0100 Subject: [PATCH 20/37] feat(plonk): beta is dervied using Fiat Shamir --- internal/backend/bls12-377/plonk/prove.go | 8 +++++--- internal/backend/bls12-377/plonk/verify.go | 12 +++++++++--- internal/backend/bls12-381/plonk/prove.go | 8 +++++--- internal/backend/bls12-381/plonk/verify.go | 12 +++++++++--- internal/backend/bls24-315/plonk/prove.go | 8 +++++--- internal/backend/bls24-315/plonk/verify.go | 12 +++++++++--- internal/backend/bn254/plonk/prove.go | 8 +++++--- internal/backend/bn254/plonk/verify.go | 12 +++++++++--- internal/backend/bw6-633/plonk/prove.go | 8 +++++--- internal/backend/bw6-633/plonk/verify.go | 12 +++++++++--- internal/backend/bw6-761/plonk/prove.go | 8 +++++--- internal/backend/bw6-761/plonk/verify.go | 12 +++++++++--- .../template/zkpschemes/plonk/plonk.prove.go.tmpl | 8 +++++--- .../template/zkpschemes/plonk/plonk.verify.go.tmpl | 12 +++++++++--- 14 files changed, 98 insertions(+), 42 deletions(-) diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 1cd205a3a2..4fda660783 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -67,7 +67,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -115,8 +115,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn } // Fiat Shamir this - var beta fr.Element - beta.SetUint64(10) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } // compute Z, the 
permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded diff --git a/internal/backend/bls12-377/plonk/verify.go b/internal/backend/bls12-377/plonk/verify.go index 9432e539ec..3693a98cf1 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/internal/backend/bls12-377/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -105,8 +111,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - var beta fr.Element - beta.SetUint64(10) + // var beta fr.Element + // beta.SetUint64(10) _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go index c8c26cf826..17a847347d 100644 --- a/internal/backend/bls12-381/plonk/prove.go +++ b/internal/backend/bls12-381/plonk/prove.go @@ -67,7 +67,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -115,8 +115,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn } // Fiat Shamir this - var beta fr.Element - beta.SetUint64(10) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded diff --git a/internal/backend/bls12-381/plonk/verify.go b/internal/backend/bls12-381/plonk/verify.go index 060659fc42..8c77d0cdb1 100644 --- a/internal/backend/bls12-381/plonk/verify.go +++ b/internal/backend/bls12-381/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -105,8 +111,8 @@ 
func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - var beta fr.Element - beta.SetUint64(10) + // var beta fr.Element + // beta.SetUint64(10) _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go index c153401141..51eb4bc40a 100644 --- a/internal/backend/bls24-315/plonk/prove.go +++ b/internal/backend/bls24-315/plonk/prove.go @@ -67,7 +67,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -115,8 +115,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn } // Fiat Shamir this - var beta fr.Element - beta.SetUint64(10) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded diff --git a/internal/backend/bls24-315/plonk/verify.go b/internal/backend/bls24-315/plonk/verify.go index 9e32a48ca8..136239291a 100644 --- a/internal/backend/bls24-315/plonk/verify.go +++ b/internal/backend/bls24-315/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -105,8 +111,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - var beta fr.Element - beta.SetUint64(10) + // var beta fr.Element + // beta.SetUint64(10) _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 7da7acad13..a25cde4816 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -67,7 +67,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -115,8 +115,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, } // Fiat Shamir this - var beta fr.Element - 
beta.SetUint64(10) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index 36f4faa795..7ff483b226 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -105,8 +111,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - var beta fr.Element - beta.SetUint64(10) + // var beta fr.Element + // beta.SetUint64(10) _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go index 35d52e077c..ebf3f3068c 100644 --- a/internal/backend/bw6-633/plonk/prove.go +++ b/internal/backend/bw6-633/plonk/prove.go @@ -67,7 +67,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -115,8 +115,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes } // Fiat Shamir this - var beta fr.Element - beta.SetUint64(10) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded diff --git a/internal/backend/bw6-633/plonk/verify.go b/internal/backend/bw6-633/plonk/verify.go index cf7ee19b7b..a0365c311e 100644 --- a/internal/backend/bw6-633/plonk/verify.go +++ b/internal/backend/bw6-633/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) 
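
Editor's note: each challenge derived above (gamma, beta, alpha, zeta) ends up as an fr.Element, so the transcript digest has to be reduced into the scalar field. A minimal sketch of that last step is given below, assuming gnark-crypto's fr.Element.SetBytes (big-endian bytes reduced mod r); it is not the deriveRandomness helper itself, which also writes the listed commitments into the transcript before asking for the challenge. Alpha and zeta are then obtained the same way, as the following hunks show.

package main

import (
	"crypto/sha256"
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
)

// digestToFr maps a hash output to a field element by reduction mod r.
func digestToFr(digest []byte) fr.Element {
	var e fr.Element
	e.SetBytes(digest)
	return e
}

func main() {
	d := sha256.Sum256([]byte("beta"))
	beta := digestToFr(d[:])
	fmt.Println(beta.String())
}
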
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -105,8 +111,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - var beta fr.Element - beta.SetUint64(10) + // var beta fr.Element + // beta.SetUint64(10) _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go index c483d610e9..36e92fd481 100644 --- a/internal/backend/bw6-761/plonk/prove.go +++ b/internal/backend/bw6-761/plonk/prove.go @@ -67,7 +67,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -115,8 +115,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes } // Fiat Shamir this - var beta fr.Element - beta.SetUint64(10) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded diff --git a/internal/backend/bw6-761/plonk/verify.go b/internal/backend/bw6-761/plonk/verify.go index 8937310621..b08c4faa47 100644 --- a/internal/backend/bw6-761/plonk/verify.go +++ b/internal/backend/bw6-761/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -105,8 +111,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - var beta fr.Element - beta.SetUint64(10) + // var beta fr.Element + // beta.SetUint64(10) _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index 1206e36736..b08177c580 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -44,7 +44,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", 
"beta", "alpha", "zeta") // result proof := &Proof{} @@ -92,8 +92,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } } // Fiat Shamir this - var beta fr.Element - beta.SetUint64(10) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index 1deb18c706..d7c4dfb7a8 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -22,7 +22,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -30,6 +30,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -84,8 +90,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT - var beta fr.Element - beta.SetUint64(10) + // var beta fr.Element + // beta.SetUint64(10) _s1.Mul(&s1, &beta).Add(&_s1, &l).Add(&_s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) _s2.Mul(&s2, &beta).Add(&_s2, &r).Add(&_s2, &gamma) // (r(ζ)+β*s2(ζ)+γ) From e568d4d8cbbd10e5b61c43f486ce5c4f0bb9865f Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 11 Feb 2022 10:29:23 +0100 Subject: [PATCH 21/37] refactor(eddsa): rearrange eddsa verif as cofactor clearing counts --- std/signature/eddsa/eddsa.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index 94b218023a..127e82f5f7 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -68,26 +68,24 @@ func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey Publi cofactor := pubKey.Curve.Cofactor.Uint64() lhs := twistededwards.Point{} lhs.ScalarMul(api, &base, sig.S, pubKey.Curve) - lhs.MustBeOnCurve(api, pubKey.Curve) // rhs = R+[H(R,A,M)]*A rhs := twistededwards.Point{} rhs.ScalarMul(api, &pubKey.A, hramConstant, pubKey.Curve). Add(api, &rhs, &sig.R, pubKey.Curve) - // rhs.MustBeOnCurve(api, pubKey.Curve) + rhs.MustBeOnCurve(api, pubKey.Curve) - // [cofactor]*lhs and [cofactor]*rhs + // lhs-rhs + rhs.Neg(api, &rhs).Add(api, &lhs, &rhs, pubKey.Curve) + + // [cofactor]*(lhs-rhs) switch cofactor { case 4: rhs.Double(api, &rhs, pubKey.Curve). Double(api, &rhs, pubKey.Curve) - lhs.Double(api, &lhs, pubKey.Curve). - Double(api, &lhs, pubKey.Curve) case 8: rhs.Double(api, &rhs, pubKey.Curve). Double(api, &rhs, pubKey.Curve).Double(api, &rhs, pubKey.Curve) - lhs.Double(api, &lhs, pubKey.Curve). 
- Double(api, &lhs, pubKey.Curve).Double(api, &lhs, pubKey.Curve) } api.AssertIsEqual(rhs.X, lhs.X) From 66b3452b59c6abc08efc3ba43948ad373356d50a Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 11 Feb 2022 16:48:29 +0100 Subject: [PATCH 22/37] feat(tEd): implements double-base scalar mul --- std/algebra/twistededwards/point.go | 39 ++++- std/algebra/twistededwards/point_test.go | 183 +++++++++++++++++++++++ 2 files changed, 218 insertions(+), 4 deletions(-) diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go index b5fd4ca2e0..42d468ab92 100644 --- a/std/algebra/twistededwards/point.go +++ b/std/algebra/twistededwards/point.go @@ -27,6 +27,13 @@ type Point struct { X, Y frontend.Variable } +// Neg computes the negative of a point in SNARK coordinates +func (p *Point) Neg(api frontend.API, p1 *Point) *Point { + p.X = api.Neg(p1.X) + p.Y = p1.Y + return p +} + // MustBeOnCurve checks if a point is on the reduced twisted Edwards curve // a*x^2 + y^2 = 1 + d*x^2*y^2. func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { @@ -134,9 +141,33 @@ func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, return p } -// Neg computes the negative of a point in SNARK coordinates -func (p *Point) Neg(api frontend.API, p1 *Point) *Point { - p.X = api.Neg(p1.X) - p.Y = p1.Y +// DoubleBaseScalarMul computes s1*P1+s2*P2 +// where P1 and P2 are points on a twisted Edwards curve +// and s1, s2 scalars. +func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point { + + // first unpack the scalars + b1 := api.ToBinary(s1) + b2 := api.ToBinary(s2) + + res := Point{} + tmp := Point{} + sum := Point{} + sum.Add(api, p1, p2, curve) + + n := len(b1) + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y) + + for i := n - 2; i >= 0; i-- { + res.Double(api, &res, curve) + tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y) + res.Add(api, &res, &tmp, curve) + } + + p.X = res.X + p.Y = res.Y + return p } diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index 99d1023092..12aebf8de2 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -599,6 +599,189 @@ func TestScalarMulGeneric(t *testing.T) { } } +// + +type doubleScalarMulGeneric struct { + P1, P2, E Point + S1, S2 frontend.Variable +} + +func (circuit *doubleScalarMulGeneric) Define(api frontend.API) error { + + // get edwards curve params + params, err := NewEdCurve(api.Curve()) + if err != nil { + return err + } + + resGeneric := circuit.P1.DoubleBaseScalarMul(api, &circuit.P1, &circuit.P2, circuit.S1, circuit.S2, params) + + api.AssertIsEqual(resGeneric.X, circuit.E.X) + api.AssertIsEqual(resGeneric.Y, circuit.E.Y) + + return nil +} + +func TestDoubleScalarMulGeneric(t *testing.T) { + + assert := test.NewAssert(t) + + var circuit, witness doubleScalarMulGeneric + + // generate witness data + for _, id := range ecc.Implemented() { + + params, err := NewEdCurve(id) + if err != nil { + t.Fatal(err) + } + + switch id { + case ecc.BN254: + var base, point1, point2, tmp, expected tbn254.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := 
big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BLS12_377: + var base, point1, point2, tmp, expected tbls12377.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BLS12_381: + var base, point1, point2, tmp, expected tbls12381.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BLS24_315: + var base, point1, point2, tmp, expected tbls24315.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BW6_761: + var base, point1, point2, tmp, expected tbw6761.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). 
+ Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BW6_633: + var base, point1, point2, tmp, expected tbw6633.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + } + + // creates r1cs + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) + } +} + type neg struct { P, E Point } From 39b4e0ffa19b258ed7d899851763f8e8cd326949 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 11 Feb 2022 17:14:22 +0100 Subject: [PATCH 23/37] perf(EdDSA): eddsa gadget using double-base scalar mul --- std/signature/eddsa/eddsa.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index 127e82f5f7..c266b33a7e 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -64,32 +64,29 @@ func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey Publi base.X = pubKey.Curve.Base.X base.Y = pubKey.Curve.Base.Y - // lhs = [S]G + //[S]G-[H(R,A,M)]*A cofactor := pubKey.Curve.Cofactor.Uint64() - lhs := twistededwards.Point{} - lhs.ScalarMul(api, &base, sig.S, pubKey.Curve) + Q := twistededwards.Point{} + _A := twistededwards.Point{} + _A.Neg(api, &pubKey.A) + Q.DoubleBaseScalarMul(api, &base, &_A, sig.S, hramConstant, pubKey.Curve) + Q.MustBeOnCurve(api, pubKey.Curve) - // rhs = R+[H(R,A,M)]*A - rhs := twistededwards.Point{} - rhs.ScalarMul(api, &pubKey.A, hramConstant, pubKey.Curve). - Add(api, &rhs, &sig.R, pubKey.Curve) - rhs.MustBeOnCurve(api, pubKey.Curve) - - // lhs-rhs - rhs.Neg(api, &rhs).Add(api, &lhs, &rhs, pubKey.Curve) + //[S]G-[H(R,A,M)]*A-R + Q.Neg(api, &Q).Add(api, &Q, &sig.R, pubKey.Curve) // [cofactor]*(lhs-rhs) switch cofactor { case 4: - rhs.Double(api, &rhs, pubKey.Curve). - Double(api, &rhs, pubKey.Curve) + Q.Double(api, &Q, pubKey.Curve). + Double(api, &Q, pubKey.Curve) case 8: - rhs.Double(api, &rhs, pubKey.Curve). - Double(api, &rhs, pubKey.Curve).Double(api, &rhs, pubKey.Curve) + Q.Double(api, &Q, pubKey.Curve). 
+ Double(api, &Q, pubKey.Curve).Double(api, &Q, pubKey.Curve) } - api.AssertIsEqual(rhs.X, lhs.X) - api.AssertIsEqual(rhs.Y, lhs.Y) + api.AssertIsEqual(Q.X, 0) + api.AssertIsEqual(Q.Y, 1) return nil } From f541b195c72aae2aa835794f45f878c7def02501 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Mon, 14 Feb 2022 11:27:55 +0100 Subject: [PATCH 24/37] fix: restored commented code blinding polynomial --- internal/backend/bw6-761/cs/to_delete.go | 81 ------------------- .../zkpschemes/plonk/plonk.prove.go.tmpl | 60 +++++++------- .../zkpschemes/plonk/plonk.verify.go.tmpl | 18 ++--- 3 files changed, 38 insertions(+), 121 deletions(-) delete mode 100644 internal/backend/bw6-761/cs/to_delete.go diff --git a/internal/backend/bw6-761/cs/to_delete.go b/internal/backend/bw6-761/cs/to_delete.go deleted file mode 100644 index 8844b9dc3f..0000000000 --- a/internal/backend/bw6-761/cs/to_delete.go +++ /dev/null @@ -1,81 +0,0 @@ -package cs - -import ( - "fmt" - "strings" - - "github.com/consensys/gnark/internal/backend/compiled" -) - -// r1cs -func (cs *R1CS) printTerm(t compiled.Term) string { - coefID, varID, _ := t.Unpack() - coef := cs.Coefficients[coefID] - return fmt.Sprintf("%s*%d", coef.String(), varID) -} - -func (cs *R1CS) printLinExp(l compiled.Variable) string { - var sbb strings.Builder - for i, t := range l.LinExp { - sbb.WriteString(cs.printTerm(t)) - if i < len(l.LinExp)-1 { - sbb.WriteString(" + ") - } - } - return sbb.String() -} - -// func (cs *R1CS) printLinExp(l compiled.LinearExpression) string { -// var sbb strings.Builder -// for i, t := range l { -// sbb.WriteString(cs.printTerm(t)) -// if i < len(l)-1 { -// sbb.WriteString(" + ") -// } -// } -// return sbb.String() -// } - -func (cs *R1CS) Print() string { - var sbb strings.Builder - for i := 0; i < len(cs.Constraints); i++ { - sbb.WriteString("(") - sbb.WriteString(cs.printLinExp(cs.Constraints[i].L)) - sbb.WriteString(") * (") - sbb.WriteString(cs.printLinExp(cs.Constraints[i].R)) - sbb.WriteString(" ) = ") - sbb.WriteString(cs.printLinExp(cs.Constraints[i].O)) - sbb.WriteString("\n") - } - return sbb.String() -} - -// sparse r1cs -func (cs *SparseR1CS) printTerm(t compiled.Term) string { - coefID, varID, _ := t.Unpack() - coef := cs.Coefficients[coefID] - return fmt.Sprintf("%s*%d", coef.String(), varID) - -} - -func (cs *SparseR1CS) Print() string { - var sbb strings.Builder - for i := 0; i < len(cs.Constraints); i++ { - c := cs.Constraints[i] - sbb.WriteString(fmt.Sprintf("%d: ", i)) - sbb.WriteString(cs.printTerm(c.L)) - sbb.WriteString(" + ") - sbb.WriteString(cs.printTerm(c.R)) - sbb.WriteString(" + ( ") - sbb.WriteString(cs.printTerm(c.M[0])) - sbb.WriteString(" * ") - sbb.WriteString(cs.printTerm(c.M[1])) - sbb.WriteString(" ) + ") - sbb.WriteString(cs.printTerm(c.O)) - sbb.WriteString(" + ") - k := cs.Coefficients[c.K] - sbb.WriteString(k.String()) - sbb.WriteString("\n") - } - return sbb.String() -} diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index b08177c580..0d3e728755 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -144,15 +144,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = 
evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) close(chEvalBL) }() go func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) close(chEvalBO) }() @@ -187,7 +187,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL @@ -205,14 +205,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { // CORRECT + if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd // CORRECT + <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -463,29 +463,27 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - // totalDegree := rou + bo + totalDegree := rou + bo - // // re-use cp - // res := cp[:totalDegree+1] + // re-use cp + res := cp[:totalDegree+1] - // // random polynomial - // blindingPoly := make([]fr.Element, bo+1) - // for i := uint64(0); i < bo+1; i++ { - // if _, err := blindingPoly[i].SetRandom(); err != nil { - // return nil, err - // } - // } + // random polynomial + blindingPoly := make([]fr.Element, bo+1) + for i := uint64(0); i < bo+1; i++ { + if _, err := blindingPoly[i].SetRandom(); err != nil { + return nil, err + } + } - // // blinding - // for i := uint64(0); i < bo+1; i++ { - // res[i].Sub(&res[i], &blindingPoly[i]) - // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - // } + // blinding + for i := uint64(0); i < bo+1; i++ { + res[i].Sub(&res[i], &blindingPoly[i]) + res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + } - // return res, nil + return res, nil - // TODO reactivate blinding - return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
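
Editor's note on the blinding restored in blindPoly above: subtracting blindingPoly[i] from coefficient i and adding it to coefficient rou+i replaces p(X) by p(X) + b(X)·(X^rou − 1). Assuming rou is the cardinality of the small evaluation domain, X^rou − 1 vanishes on that domain, so every in-domain evaluation (hence every constraint) is unchanged while openings outside the domain are randomized. A toy check over the integers mod 17, where 2 has multiplicative order 8 (all names and values below are illustrative only):

package main

import "fmt"

const p = 17 // toy field
const n = 8  // domain size; omega = 2 has order 8 mod 17

// eval evaluates the coefficient slice at x modulo p (handles negative coefficients).
func eval(coeffs []int, x int) int {
	res := 0
	for i := len(coeffs) - 1; i >= 0; i-- {
		res = ((res*x+coeffs[i])%p + p) % p
	}
	return res
}

func main() {
	cp := []int{3, 1, 4, 1, 5, 9, 2, 6} // original polynomial, degree < n
	b := []int{7, 11}                   // blinding polynomial b(X) of degree 1

	// blinded(X) = cp(X) + b(X)*(X^n - 1), mirroring blindPoly
	blinded := make([]int, n+len(b))
	copy(blinded, cp)
	for i := 0; i < len(b); i++ {
		blinded[i] -= b[i]   // -b_i * X^i
		blinded[n+i] += b[i] // +b_i * X^(n+i)
	}

	// evaluations on the subgroup generated by omega = 2 are unchanged
	x := 1
	for k := 0; k < n; k++ {
		fmt.Println(eval(cp, x) == eval(blinded, x)) // always true
		x = x * 2 % p
	}
}
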
@@ -736,14 +734,14 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // evaluate Z = Xᵐ-1 on a coset of the big domain evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT + pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain @@ -778,7 +776,7 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] - return h1, h2, h3 // CORRECT + return h1, h2, h3 } @@ -811,7 +809,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -822,7 +820,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element @@ -837,7 +835,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). 
- Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index d7c4dfb7a8..1c64d0f118 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -82,13 +82,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT - l := proof.BatchedProof.ClaimedValues[2] // CORRECT - r := proof.BatchedProof.ClaimedValues[3] // CORRECT - o := proof.BatchedProof.ClaimedValues[4] // CORRECT - s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT - s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -145,14 +145,14 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - // CORRECT + var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - // CORRECT + cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) From 7c128171d1c9835a59f8c988058a9e5afa5018a3 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Mon, 14 Feb 2022 11:35:54 +0100 Subject: [PATCH 25/37] fix: resolve comments --- internal/backend/bls12-377/plonk/marshal.go | 14 +- .../backend/bls12-377/plonk/marshal_test.go | 18 +-- internal/backend/bls12-377/plonk/prove.go | 136 +++++++++--------- internal/backend/bls12-377/plonk/setup.go | 67 ++++----- internal/backend/bls12-377/plonk/verify.go | 16 +-- internal/backend/bls12-381/plonk/marshal.go | 14 +- .../backend/bls12-381/plonk/marshal_test.go | 18 +-- internal/backend/bls12-381/plonk/prove.go | 136 +++++++++--------- internal/backend/bls12-381/plonk/setup.go | 67 ++++----- internal/backend/bls12-381/plonk/verify.go | 16 +-- internal/backend/bls24-315/plonk/marshal.go | 14 +- .../backend/bls24-315/plonk/marshal_test.go | 18 +-- internal/backend/bls24-315/plonk/prove.go | 136 +++++++++--------- internal/backend/bls24-315/plonk/setup.go | 67 ++++----- internal/backend/bls24-315/plonk/verify.go | 16 +-- internal/backend/bn254/plonk/marshal.go | 14 +- internal/backend/bn254/plonk/marshal_test.go | 18 +-- internal/backend/bn254/plonk/prove.go | 136 +++++++++--------- internal/backend/bn254/plonk/setup.go | 67 ++++----- internal/backend/bn254/plonk/verify.go | 16 +-- 
internal/backend/bw6-633/plonk/marshal.go | 14 +- .../backend/bw6-633/plonk/marshal_test.go | 18 +-- internal/backend/bw6-633/plonk/prove.go | 136 +++++++++--------- internal/backend/bw6-633/plonk/setup.go | 67 ++++----- internal/backend/bw6-633/plonk/verify.go | 16 +-- internal/backend/bw6-761/plonk/marshal.go | 14 +- .../backend/bw6-761/plonk/marshal_test.go | 18 +-- internal/backend/bw6-761/plonk/prove.go | 136 +++++++++--------- internal/backend/bw6-761/plonk/setup.go | 67 ++++----- internal/backend/bw6-761/plonk/verify.go | 16 +-- .../zkpschemes/plonk/plonk.marshal.go.tmpl | 14 +- .../zkpschemes/plonk/plonk.prove.go.tmpl | 88 ++++++------ .../zkpschemes/plonk/plonk.setup.go.tmpl | 67 ++++----- .../zkpschemes/plonk/plonk.verify.go.tmpl | 16 +-- .../zkpschemes/plonk/tests/marshal.go.tmpl | 18 +-- 35 files changed, 852 insertions(+), 857 deletions(-) diff --git a/internal/backend/bls12-377/plonk/marshal.go b/internal/backend/bls12-377/plonk/marshal.go index 54de931064..2bb46eabd3 100644 --- a/internal/backend/bls12-377/plonk/marshal.go +++ b/internal/backend/bls12-377/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainSmall.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainBig.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -140,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainSmall.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainBig.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/backend/bls12-377/plonk/marshal_test.go b/internal/backend/bls12-377/plonk/marshal_test.go index 0e8e2e90b4..f4bea8c379 100644 --- a/internal/backend/bls12-377/plonk/marshal_test.go +++ b/internal/backend/bls12-377/plonk/marshal_test.go @@ -47,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainSmall = *fft.NewDomain(42) - pk.DomainBig = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ 
-62,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 4fda660783..b8a821b6f4 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -98,7 +98,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall, - &pk.DomainSmall) + &pk.Domain[0]) if err != nil { return nil, err } @@ -167,15 +167,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() @@ -183,10 +183,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain @@ -210,7 +210,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
<-chEvalBL @@ -228,14 +228,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { // CORRECT + if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd // CORRECT + <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -271,7 +271,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, &zetaShifted, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -312,7 +312,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -360,7 +360,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn }, &zeta, hFunc, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -486,36 +486,34 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - // totalDegree := rou + bo + totalDegree := rou + bo - // // re-use cp - // res := cp[:totalDegree+1] + // re-use cp + res := cp[:totalDegree+1] - // // random polynomial - // blindingPoly := make([]fr.Element, bo+1) - // for i := uint64(0); i < bo+1; i++ { - // if _, err := blindingPoly[i].SetRandom(); err != nil { - // return nil, err - // } - // } + // random polynomial + blindingPoly := make([]fr.Element, bo+1) + for i := uint64(0); i < bo+1; i++ { + if _, err := blindingPoly[i].SetRandom(); err != nil { + return nil, err + } + } - // // blinding - // for i := uint64(0); i < bo+1; i++ { - // res[i].Sub(&res[i], &blindingPoly[i]) - // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - // } + // blinding + for i := uint64(0); i < bo+1; i++ { + res[i].Sub(&res[i], &blindingPoly[i]) + res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + } - // return res, nil + return res, nil - // TODO reactivate blinding - return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
// solution = [ public | secret | internal ] func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainSmall.Cardinality) + s := int(pk.Domain[0].Cardinality) var l, r, o []fr.Element l = make([]fr.Element, s) @@ -558,14 +556,14 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) - nbElmts := int(pk.DomainSmall.Cardinality) - gInv := make([]fr.Element, pk.DomainSmall.Cardinality) + z := make([]fr.Element, pk.Domain[0].Cardinality, pk.Domain[0].Cardinality+3) + nbElmts := int(pk.Domain[0].Cardinality) + gInv := make([]fr.Element, pk.Domain[0].Cardinality) z[0].SetOne() gInv[0].SetOne() - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) utils.Parallelize(nbElmts-1, func(start, end int) { @@ -596,10 +594,10 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma Mul(&z[i], &gInv[i]) } - pk.DomainSmall.FFTInverse(z, fft.DIF) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainSmall.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } @@ -614,22 +612,22 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO wg.Add(4) go func() { - evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain @@ -660,26 +658,26 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO // * gamma randomization func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - nbElmts := int(pk.DomainBig.Cardinality) + nbElmts := int(pk.Domain[1].Cardinality) // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). 
- res := make([]fr.Element, pk.DomainBig.Cardinality) + res := make([]fr.Element, pk.Domain[1].Cardinality) nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) var cosetShift, cosetShiftSquare fr.Element cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element - evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). - Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) var f [3]fr.Element var g [3]fr.Element @@ -703,7 +701,7 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) @@ -755,29 +753,29 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // constraintInd, constraintOrdering are evaluated on the big domain (coset). func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - h := make([]fr.Element, pk.DomainBig.Cardinality) + h := make([]fr.Element, pk.Domain[1].Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT + evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.Domain[1], &pk.Domain[0]) + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) - startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) + startsAtOne := make([]fr.Element, pk.Domain[1].Cardinality) + for i := 0; i < int(pk.Domain[0].Cardinality); i++ { + startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT + pk.Domain[1].FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain - nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) + nn := uint64(64 - bits.TrailingZeros64(pk.Domain[1].Cardinality)) var one fr.Element one.SetOne() - ratio := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + ratio := pk.Domain[1].Cardinality / pk.Domain[0].Cardinality - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { @@ -794,14 +792,14 @@ func computeQuotientCanonical(pk *ProvingKey, 
evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainBig.FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainSmall.Cardinality+2] - h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] - h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - return h1, h2, h3 // CORRECT + return h1, h2, h3 } @@ -834,7 +832,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -845,12 +843,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainSmall.Cardinality) + nbElmt := int64(pk.Domain[0].Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -860,7 +858,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/backend/bls12-377/plonk/setup.go b/internal/backend/bls12-377/plonk/setup.go index 87f8e8e406..259f76e6f1 100644 --- a/internal/backend/bls12-377/plonk/setup.go +++ b/internal/backend/bls12-377/plonk/setup.go @@ -45,8 +45,11 @@ type ProvingKey struct { // Storing LQk in Lagrange basis saves a fft... CQk, LQk []fr.Element - // Domains used for the FFTs - DomainSmall, DomainBig fft.Domain + // Domains used for the FFTs. 
+ // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain // Permutation polynomials EvaluationPermutationBigDomainBitReversed []fr.Element @@ -94,21 +97,21 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainSmall = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainBig = *fft.NewDomain(8 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainBig = *fft.NewDomain(4 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainSmall.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainSmall.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) if err := pk.InitKZG(srs); err != nil { @@ -116,12 +119,12 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -143,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -206,7 +209,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainSmall.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -262,10 +265,10 @@ func buildPermutation(spr 
*cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmts := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 pk.S1Canonical = make([]fr.Element, nbElmts) @@ -278,21 +281,21 @@ func ccomputePermutationPolynomials(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) fft.BitReverse(pk.S1Canonical) fft.BitReverse(pk.S2Canonical) fft.BitReverse(pk.S3Canonical) // evaluation of permutation on the big domain - pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) } @@ -314,7 +317,7 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls12-377/plonk/verify.go b/internal/backend/bls12-377/plonk/verify.go index 3693a98cf1..dee48705ee 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/internal/backend/bls12-377/plonk/verify.go @@ -103,13 +103,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT - l := proof.BatchedProof.ClaimedValues[2] // CORRECT - r := proof.BatchedProof.ClaimedValues[3] // CORRECT - o := proof.BatchedProof.ClaimedValues[4] // CORRECT - s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT - s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta 
:= proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -166,14 +166,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - // CORRECT var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - // CORRECT cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) diff --git a/internal/backend/bls12-381/plonk/marshal.go b/internal/backend/bls12-381/plonk/marshal.go index edf3c93c23..d03d1be5fa 100644 --- a/internal/backend/bls12-381/plonk/marshal.go +++ b/internal/backend/bls12-381/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainSmall.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainBig.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -140,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainSmall.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainBig.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/backend/bls12-381/plonk/marshal_test.go b/internal/backend/bls12-381/plonk/marshal_test.go index cbb4b9a537..e30d108e8b 100644 --- a/internal/backend/bls12-381/plonk/marshal_test.go +++ b/internal/backend/bls12-381/plonk/marshal_test.go @@ -47,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainSmall = *fft.NewDomain(42) - pk.DomainBig = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, 
pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -62,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go index 17a847347d..7b434fd24e 100644 --- a/internal/backend/bls12-381/plonk/prove.go +++ b/internal/backend/bls12-381/plonk/prove.go @@ -98,7 +98,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall, - &pk.DomainSmall) + &pk.Domain[0]) if err != nil { return nil, err } @@ -167,15 +167,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() @@ -183,10 +183,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain @@ -210,7 +210,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
<-chEvalBL @@ -228,14 +228,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { // CORRECT + if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd // CORRECT + <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -271,7 +271,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, &zetaShifted, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -312,7 +312,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -360,7 +360,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn }, &zeta, hFunc, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -486,36 +486,34 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - // totalDegree := rou + bo + totalDegree := rou + bo - // // re-use cp - // res := cp[:totalDegree+1] + // re-use cp + res := cp[:totalDegree+1] - // // random polynomial - // blindingPoly := make([]fr.Element, bo+1) - // for i := uint64(0); i < bo+1; i++ { - // if _, err := blindingPoly[i].SetRandom(); err != nil { - // return nil, err - // } - // } + // random polynomial + blindingPoly := make([]fr.Element, bo+1) + for i := uint64(0); i < bo+1; i++ { + if _, err := blindingPoly[i].SetRandom(); err != nil { + return nil, err + } + } - // // blinding - // for i := uint64(0); i < bo+1; i++ { - // res[i].Sub(&res[i], &blindingPoly[i]) - // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - // } + // blinding + for i := uint64(0); i < bo+1; i++ { + res[i].Sub(&res[i], &blindingPoly[i]) + res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + } - // return res, nil + return res, nil - // TODO reactivate blinding - return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
// solution = [ public | secret | internal ] func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainSmall.Cardinality) + s := int(pk.Domain[0].Cardinality) var l, r, o []fr.Element l = make([]fr.Element, s) @@ -558,14 +556,14 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) - nbElmts := int(pk.DomainSmall.Cardinality) - gInv := make([]fr.Element, pk.DomainSmall.Cardinality) + z := make([]fr.Element, pk.Domain[0].Cardinality, pk.Domain[0].Cardinality+3) + nbElmts := int(pk.Domain[0].Cardinality) + gInv := make([]fr.Element, pk.Domain[0].Cardinality) z[0].SetOne() gInv[0].SetOne() - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) utils.Parallelize(nbElmts-1, func(start, end int) { @@ -596,10 +594,10 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma Mul(&z[i], &gInv[i]) } - pk.DomainSmall.FFTInverse(z, fft.DIF) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainSmall.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } @@ -614,22 +612,22 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO wg.Add(4) go func() { - evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain @@ -660,26 +658,26 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO // * gamma randomization func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - nbElmts := int(pk.DomainBig.Cardinality) + nbElmts := int(pk.Domain[1].Cardinality) // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). 
- res := make([]fr.Element, pk.DomainBig.Cardinality) + res := make([]fr.Element, pk.Domain[1].Cardinality) nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) var cosetShift, cosetShiftSquare fr.Element cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element - evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). - Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) var f [3]fr.Element var g [3]fr.Element @@ -703,7 +701,7 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) @@ -755,29 +753,29 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // constraintInd, constraintOrdering are evaluated on the big domain (coset). func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - h := make([]fr.Element, pk.DomainBig.Cardinality) + h := make([]fr.Element, pk.Domain[1].Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT + evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.Domain[1], &pk.Domain[0]) + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) - startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) + startsAtOne := make([]fr.Element, pk.Domain[1].Cardinality) + for i := 0; i < int(pk.Domain[0].Cardinality); i++ { + startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT + pk.Domain[1].FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain - nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) + nn := uint64(64 - bits.TrailingZeros64(pk.Domain[1].Cardinality)) var one fr.Element one.SetOne() - ratio := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + ratio := pk.Domain[1].Cardinality / pk.Domain[0].Cardinality - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { @@ -794,14 +792,14 @@ func computeQuotientCanonical(pk *ProvingKey, 
evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainBig.FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainSmall.Cardinality+2] - h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] - h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - return h1, h2, h3 // CORRECT + return h1, h2, h3 } @@ -834,7 +832,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -845,12 +843,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainSmall.Cardinality) + nbElmt := int64(pk.Domain[0].Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -860,7 +858,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/backend/bls12-381/plonk/setup.go b/internal/backend/bls12-381/plonk/setup.go index 057d2aea04..823cc25d8b 100644 --- a/internal/backend/bls12-381/plonk/setup.go +++ b/internal/backend/bls12-381/plonk/setup.go @@ -45,8 +45,11 @@ type ProvingKey struct { // Storing LQk in Lagrange basis saves a fft... CQk, LQk []fr.Element - // Domains used for the FFTs - DomainSmall, DomainBig fft.Domain + // Domains used for the FFTs. 
+ // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain // Permutation polynomials EvaluationPermutationBigDomainBitReversed []fr.Element @@ -94,21 +97,21 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainSmall = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainBig = *fft.NewDomain(8 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainBig = *fft.NewDomain(4 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainSmall.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainSmall.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) if err := pk.InitKZG(srs); err != nil { @@ -116,12 +119,12 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -143,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -206,7 +209,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainSmall.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -262,10 +265,10 @@ func buildPermutation(spr 
*cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmts := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 pk.S1Canonical = make([]fr.Element, nbElmts) @@ -278,21 +281,21 @@ func ccomputePermutationPolynomials(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) fft.BitReverse(pk.S1Canonical) fft.BitReverse(pk.S2Canonical) fft.BitReverse(pk.S3Canonical) // evaluation of permutation on the big domain - pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) } @@ -314,7 +317,7 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls12-381/plonk/verify.go b/internal/backend/bls12-381/plonk/verify.go index 8c77d0cdb1..bf514edae9 100644 --- a/internal/backend/bls12-381/plonk/verify.go +++ b/internal/backend/bls12-381/plonk/verify.go @@ -103,13 +103,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT - l := proof.BatchedProof.ClaimedValues[2] // CORRECT - r := proof.BatchedProof.ClaimedValues[3] // CORRECT - o := proof.BatchedProof.ClaimedValues[4] // CORRECT - s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT - s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta 
:= proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -166,14 +166,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - // CORRECT var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - // CORRECT cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) diff --git a/internal/backend/bls24-315/plonk/marshal.go b/internal/backend/bls24-315/plonk/marshal.go index 3c5d5c9b2b..05f24f4f36 100644 --- a/internal/backend/bls24-315/plonk/marshal.go +++ b/internal/backend/bls24-315/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainSmall.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainBig.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -140,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainSmall.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainBig.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/backend/bls24-315/plonk/marshal_test.go b/internal/backend/bls24-315/plonk/marshal_test.go index d9fe4ec39f..99763b07e3 100644 --- a/internal/backend/bls24-315/plonk/marshal_test.go +++ b/internal/backend/bls24-315/plonk/marshal_test.go @@ -47,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainSmall = *fft.NewDomain(42) - pk.DomainBig = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, 
pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -62,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go index 51eb4bc40a..ee92a6436c 100644 --- a/internal/backend/bls24-315/plonk/prove.go +++ b/internal/backend/bls24-315/plonk/prove.go @@ -98,7 +98,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall, - &pk.DomainSmall) + &pk.Domain[0]) if err != nil { return nil, err } @@ -167,15 +167,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() @@ -183,10 +183,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain @@ -210,7 +210,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
<-chEvalBL @@ -228,14 +228,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { // CORRECT + if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd // CORRECT + <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -271,7 +271,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, &zetaShifted, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -312,7 +312,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -360,7 +360,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn }, &zeta, hFunc, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -486,36 +486,34 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - // totalDegree := rou + bo + totalDegree := rou + bo - // // re-use cp - // res := cp[:totalDegree+1] + // re-use cp + res := cp[:totalDegree+1] - // // random polynomial - // blindingPoly := make([]fr.Element, bo+1) - // for i := uint64(0); i < bo+1; i++ { - // if _, err := blindingPoly[i].SetRandom(); err != nil { - // return nil, err - // } - // } + // random polynomial + blindingPoly := make([]fr.Element, bo+1) + for i := uint64(0); i < bo+1; i++ { + if _, err := blindingPoly[i].SetRandom(); err != nil { + return nil, err + } + } - // // blinding - // for i := uint64(0); i < bo+1; i++ { - // res[i].Sub(&res[i], &blindingPoly[i]) - // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - // } + // blinding + for i := uint64(0); i < bo+1; i++ { + res[i].Sub(&res[i], &blindingPoly[i]) + res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + } - // return res, nil + return res, nil - // TODO reactivate blinding - return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
// solution = [ public | secret | internal ] func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainSmall.Cardinality) + s := int(pk.Domain[0].Cardinality) var l, r, o []fr.Element l = make([]fr.Element, s) @@ -558,14 +556,14 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) - nbElmts := int(pk.DomainSmall.Cardinality) - gInv := make([]fr.Element, pk.DomainSmall.Cardinality) + z := make([]fr.Element, pk.Domain[0].Cardinality, pk.Domain[0].Cardinality+3) + nbElmts := int(pk.Domain[0].Cardinality) + gInv := make([]fr.Element, pk.Domain[0].Cardinality) z[0].SetOne() gInv[0].SetOne() - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) utils.Parallelize(nbElmts-1, func(start, end int) { @@ -596,10 +594,10 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma Mul(&z[i], &gInv[i]) } - pk.DomainSmall.FFTInverse(z, fft.DIF) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainSmall.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } @@ -614,22 +612,22 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO wg.Add(4) go func() { - evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain @@ -660,26 +658,26 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO // * gamma randomization func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - nbElmts := int(pk.DomainBig.Cardinality) + nbElmts := int(pk.Domain[1].Cardinality) // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). 
- res := make([]fr.Element, pk.DomainBig.Cardinality) + res := make([]fr.Element, pk.Domain[1].Cardinality) nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) var cosetShift, cosetShiftSquare fr.Element cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element - evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). - Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) var f [3]fr.Element var g [3]fr.Element @@ -703,7 +701,7 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) @@ -755,29 +753,29 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // constraintInd, constraintOrdering are evaluated on the big domain (coset). func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - h := make([]fr.Element, pk.DomainBig.Cardinality) + h := make([]fr.Element, pk.Domain[1].Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT + evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.Domain[1], &pk.Domain[0]) + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) - startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) + startsAtOne := make([]fr.Element, pk.Domain[1].Cardinality) + for i := 0; i < int(pk.Domain[0].Cardinality); i++ { + startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT + pk.Domain[1].FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain - nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) + nn := uint64(64 - bits.TrailingZeros64(pk.Domain[1].Cardinality)) var one fr.Element one.SetOne() - ratio := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + ratio := pk.Domain[1].Cardinality / pk.Domain[0].Cardinality - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { @@ -794,14 +792,14 @@ func computeQuotientCanonical(pk *ProvingKey, 
evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainBig.FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainSmall.Cardinality+2] - h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] - h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - return h1, h2, h3 // CORRECT + return h1, h2, h3 } @@ -834,7 +832,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -845,12 +843,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainSmall.Cardinality) + nbElmt := int64(pk.Domain[0].Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -860,7 +858,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/backend/bls24-315/plonk/setup.go b/internal/backend/bls24-315/plonk/setup.go index c37c7d7a60..c0fd45c8b2 100644 --- a/internal/backend/bls24-315/plonk/setup.go +++ b/internal/backend/bls24-315/plonk/setup.go @@ -45,8 +45,11 @@ type ProvingKey struct { // Storing LQk in Lagrange basis saves a fft... CQk, LQk []fr.Element - // Domains used for the FFTs - DomainSmall, DomainBig fft.Domain + // Domains used for the FFTs. 
+ // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain // Permutation polynomials EvaluationPermutationBigDomainBitReversed []fr.Element @@ -94,21 +97,21 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainSmall = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainBig = *fft.NewDomain(8 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainBig = *fft.NewDomain(4 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainSmall.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainSmall.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) if err := pk.InitKZG(srs); err != nil { @@ -116,12 +119,12 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -143,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -206,7 +209,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainSmall.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -262,10 +265,10 @@ func buildPermutation(spr 
*cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmts := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 pk.S1Canonical = make([]fr.Element, nbElmts) @@ -278,21 +281,21 @@ func ccomputePermutationPolynomials(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) fft.BitReverse(pk.S1Canonical) fft.BitReverse(pk.S2Canonical) fft.BitReverse(pk.S3Canonical) // evaluation of permutation on the big domain - pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) } @@ -314,7 +317,7 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls24-315/plonk/verify.go b/internal/backend/bls24-315/plonk/verify.go index 136239291a..1cdcb5ea06 100644 --- a/internal/backend/bls24-315/plonk/verify.go +++ b/internal/backend/bls24-315/plonk/verify.go @@ -103,13 +103,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT - l := proof.BatchedProof.ClaimedValues[2] // CORRECT - r := proof.BatchedProof.ClaimedValues[3] // CORRECT - o := proof.BatchedProof.ClaimedValues[4] // CORRECT - s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT - s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta 
:= proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -166,14 +166,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - // CORRECT var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - // CORRECT cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) diff --git a/internal/backend/bn254/plonk/marshal.go b/internal/backend/bn254/plonk/marshal.go index c51dd9c698..e4c4d7059f 100644 --- a/internal/backend/bn254/plonk/marshal.go +++ b/internal/backend/bn254/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainSmall.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainBig.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -140,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainSmall.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainBig.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/backend/bn254/plonk/marshal_test.go b/internal/backend/bn254/plonk/marshal_test.go index b508906177..ceec7305b0 100644 --- a/internal/backend/bn254/plonk/marshal_test.go +++ b/internal/backend/bn254/plonk/marshal_test.go @@ -47,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainSmall = *fft.NewDomain(42) - pk.DomainBig = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = 
make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -62,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index a25cde4816..e672397a91 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -98,7 +98,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall, - &pk.DomainSmall) + &pk.Domain[0]) if err != nil { return nil, err } @@ -167,15 +167,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() @@ -183,10 +183,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain @@ -210,7 +210,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
<-chEvalBL @@ -228,14 +228,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { // CORRECT + if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd // CORRECT + <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -271,7 +271,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, &zetaShifted, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -312,7 +312,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -360,7 +360,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, }, &zeta, hFunc, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -486,36 +486,34 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - // totalDegree := rou + bo + totalDegree := rou + bo - // // re-use cp - // res := cp[:totalDegree+1] + // re-use cp + res := cp[:totalDegree+1] - // // random polynomial - // blindingPoly := make([]fr.Element, bo+1) - // for i := uint64(0); i < bo+1; i++ { - // if _, err := blindingPoly[i].SetRandom(); err != nil { - // return nil, err - // } - // } + // random polynomial + blindingPoly := make([]fr.Element, bo+1) + for i := uint64(0); i < bo+1; i++ { + if _, err := blindingPoly[i].SetRandom(); err != nil { + return nil, err + } + } - // // blinding - // for i := uint64(0); i < bo+1; i++ { - // res[i].Sub(&res[i], &blindingPoly[i]) - // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - // } + // blinding + for i := uint64(0); i < bo+1; i++ { + res[i].Sub(&res[i], &blindingPoly[i]) + res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + } - // return res, nil + return res, nil - // TODO reactivate blinding - return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
// solution = [ public | secret | internal ] func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainSmall.Cardinality) + s := int(pk.Domain[0].Cardinality) var l, r, o []fr.Element l = make([]fr.Element, s) @@ -558,14 +556,14 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) - nbElmts := int(pk.DomainSmall.Cardinality) - gInv := make([]fr.Element, pk.DomainSmall.Cardinality) + z := make([]fr.Element, pk.Domain[0].Cardinality, pk.Domain[0].Cardinality+3) + nbElmts := int(pk.Domain[0].Cardinality) + gInv := make([]fr.Element, pk.Domain[0].Cardinality) z[0].SetOne() gInv[0].SetOne() - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) utils.Parallelize(nbElmts-1, func(start, end int) { @@ -596,10 +594,10 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma Mul(&z[i], &gInv[i]) } - pk.DomainSmall.FFTInverse(z, fft.DIF) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainSmall.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } @@ -614,22 +612,22 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO wg.Add(4) go func() { - evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain @@ -660,26 +658,26 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO // * gamma randomization func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - nbElmts := int(pk.DomainBig.Cardinality) + nbElmts := int(pk.Domain[1].Cardinality) // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). 
- res := make([]fr.Element, pk.DomainBig.Cardinality) + res := make([]fr.Element, pk.Domain[1].Cardinality) nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) var cosetShift, cosetShiftSquare fr.Element cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element - evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). - Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) var f [3]fr.Element var g [3]fr.Element @@ -703,7 +701,7 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) @@ -755,29 +753,29 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // constraintInd, constraintOrdering are evaluated on the big domain (coset). func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - h := make([]fr.Element, pk.DomainBig.Cardinality) + h := make([]fr.Element, pk.Domain[1].Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT + evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.Domain[1], &pk.Domain[0]) + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) - startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) + startsAtOne := make([]fr.Element, pk.Domain[1].Cardinality) + for i := 0; i < int(pk.Domain[0].Cardinality); i++ { + startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT + pk.Domain[1].FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain - nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) + nn := uint64(64 - bits.TrailingZeros64(pk.Domain[1].Cardinality)) var one fr.Element one.SetOne() - ratio := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + ratio := pk.Domain[1].Cardinality / pk.Domain[0].Cardinality - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { @@ -794,14 +792,14 @@ func computeQuotientCanonical(pk *ProvingKey, 
evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainBig.FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainSmall.Cardinality+2] - h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] - h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - return h1, h2, h3 // CORRECT + return h1, h2, h3 } @@ -834,7 +832,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -845,12 +843,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainSmall.Cardinality) + nbElmt := int64(pk.Domain[0].Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -860,7 +858,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go index 1af1d30756..421f035bd0 100644 --- a/internal/backend/bn254/plonk/setup.go +++ b/internal/backend/bn254/plonk/setup.go @@ -45,8 +45,11 @@ type ProvingKey struct { // Storing LQk in Lagrange basis saves a fft... CQk, LQk []fr.Element - // Domains used for the FFTs - DomainSmall, DomainBig fft.Domain + // Domains used for the FFTs. 
+ // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain // Permutation polynomials EvaluationPermutationBigDomainBitReversed []fr.Element @@ -94,21 +97,21 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainSmall = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainBig = *fft.NewDomain(8 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainBig = *fft.NewDomain(4 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainSmall.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainSmall.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) if err := pk.InitKZG(srs); err != nil { @@ -116,12 +119,12 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -143,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -206,7 +209,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainSmall.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -262,10 +265,10 @@ func buildPermutation(spr 
*cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmts := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 pk.S1Canonical = make([]fr.Element, nbElmts) @@ -278,21 +281,21 @@ func ccomputePermutationPolynomials(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) fft.BitReverse(pk.S1Canonical) fft.BitReverse(pk.S2Canonical) fft.BitReverse(pk.S3Canonical) // evaluation of permutation on the big domain - pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) } @@ -314,7 +317,7 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index 7ff483b226..c4ca3c638a 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -103,13 +103,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT - l := proof.BatchedProof.ClaimedValues[2] // CORRECT - r := proof.BatchedProof.ClaimedValues[3] // CORRECT - o := proof.BatchedProof.ClaimedValues[4] // CORRECT - s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT - s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta := 
proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -166,14 +166,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - // CORRECT var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - // CORRECT cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) diff --git a/internal/backend/bw6-633/plonk/marshal.go b/internal/backend/bw6-633/plonk/marshal.go index fb13f653e6..8f2ad728cc 100644 --- a/internal/backend/bw6-633/plonk/marshal.go +++ b/internal/backend/bw6-633/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainSmall.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainBig.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -140,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainSmall.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainBig.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/backend/bw6-633/plonk/marshal_test.go b/internal/backend/bw6-633/plonk/marshal_test.go index 0996615b2d..7ace49aef7 100644 --- a/internal/backend/bw6-633/plonk/marshal_test.go +++ b/internal/backend/bw6-633/plonk/marshal_test.go @@ -47,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainSmall = *fft.NewDomain(42) - pk.DomainBig = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = 
make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -62,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go index ebf3f3068c..6210f2c40b 100644 --- a/internal/backend/bw6-633/plonk/prove.go +++ b/internal/backend/bw6-633/plonk/prove.go @@ -98,7 +98,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall, - &pk.DomainSmall) + &pk.Domain[0]) if err != nil { return nil, err } @@ -167,15 +167,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() @@ -183,10 +183,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain @@ -210,7 +210,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
<-chEvalBL @@ -228,14 +228,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { // CORRECT + if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd // CORRECT + <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -271,7 +271,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, &zetaShifted, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -312,7 +312,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -360,7 +360,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes }, &zeta, hFunc, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -486,36 +486,34 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - // totalDegree := rou + bo + totalDegree := rou + bo - // // re-use cp - // res := cp[:totalDegree+1] + // re-use cp + res := cp[:totalDegree+1] - // // random polynomial - // blindingPoly := make([]fr.Element, bo+1) - // for i := uint64(0); i < bo+1; i++ { - // if _, err := blindingPoly[i].SetRandom(); err != nil { - // return nil, err - // } - // } + // random polynomial + blindingPoly := make([]fr.Element, bo+1) + for i := uint64(0); i < bo+1; i++ { + if _, err := blindingPoly[i].SetRandom(); err != nil { + return nil, err + } + } - // // blinding - // for i := uint64(0); i < bo+1; i++ { - // res[i].Sub(&res[i], &blindingPoly[i]) - // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - // } + // blinding + for i := uint64(0); i < bo+1; i++ { + res[i].Sub(&res[i], &blindingPoly[i]) + res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + } - // return res, nil + return res, nil - // TODO reactivate blinding - return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
// solution = [ public | secret | internal ] func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainSmall.Cardinality) + s := int(pk.Domain[0].Cardinality) var l, r, o []fr.Element l = make([]fr.Element, s) @@ -558,14 +556,14 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) - nbElmts := int(pk.DomainSmall.Cardinality) - gInv := make([]fr.Element, pk.DomainSmall.Cardinality) + z := make([]fr.Element, pk.Domain[0].Cardinality, pk.Domain[0].Cardinality+3) + nbElmts := int(pk.Domain[0].Cardinality) + gInv := make([]fr.Element, pk.Domain[0].Cardinality) z[0].SetOne() gInv[0].SetOne() - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) utils.Parallelize(nbElmts-1, func(start, end int) { @@ -596,10 +594,10 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma Mul(&z[i], &gInv[i]) } - pk.DomainSmall.FFTInverse(z, fft.DIF) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainSmall.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } @@ -614,22 +612,22 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO wg.Add(4) go func() { - evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain @@ -660,26 +658,26 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO // * gamma randomization func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - nbElmts := int(pk.DomainBig.Cardinality) + nbElmts := int(pk.Domain[1].Cardinality) // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). 
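For readability, the per-point quantity accumulated by evaluateOrderingDomainBigBitReversed can be written as z(u·x)·g(x) − z(x)·f(x), with u the coset shift and id(x) the identity permutation evaluated at x. A hypothetical single-point helper (assuming the fr package already imported by this file; not part of the patch) spells the factors out:

// orderingTerm returns z(u·x)·g(x) − z(x)·f(x) for one evaluation point x,
// where id is the identity permutation at x, u the coset shift (u2 = u²),
// and s1, s2, s3 the permutation polynomials evaluated at x.
func orderingTerm(l, r, o, s1, s2, s3, id, u, u2, z, zShifted, beta, gamma fr.Element) fr.Element {
	var f, g, t, res fr.Element

	// f(x) = (l+β·id+γ)·(r+β·u·id+γ)·(o+β·u²·id+γ), then multiplied by z(x)
	f.Mul(&beta, &id).Add(&f, &l).Add(&f, &gamma)
	t.Mul(&beta, &u).Mul(&t, &id).Add(&t, &r).Add(&t, &gamma)
	f.Mul(&f, &t)
	t.Mul(&beta, &u2).Mul(&t, &id).Add(&t, &o).Add(&t, &gamma)
	f.Mul(&f, &t).Mul(&f, &z)

	// g(x) = (l+β·s₁+γ)·(r+β·s₂+γ)·(o+β·s₃+γ), then multiplied by z(u·x)
	g.Mul(&beta, &s1).Add(&g, &l).Add(&g, &gamma)
	t.Mul(&beta, &s2).Add(&t, &r).Add(&t, &gamma)
	g.Mul(&g, &t)
	t.Mul(&beta, &s3).Add(&t, &o).Add(&t, &gamma)
	g.Mul(&g, &t).Mul(&g, &zShifted)

	return *res.Sub(&g, &f)
}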
- res := make([]fr.Element, pk.DomainBig.Cardinality) + res := make([]fr.Element, pk.Domain[1].Cardinality) nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) var cosetShift, cosetShiftSquare fr.Element cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element - evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). - Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) var f [3]fr.Element var g [3]fr.Element @@ -703,7 +701,7 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) @@ -755,29 +753,29 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // constraintInd, constraintOrdering are evaluated on the big domain (coset). func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - h := make([]fr.Element, pk.DomainBig.Cardinality) + h := make([]fr.Element, pk.Domain[1].Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT + evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.Domain[1], &pk.Domain[0]) + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) - startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) + startsAtOne := make([]fr.Element, pk.Domain[1].Cardinality) + for i := 0; i < int(pk.Domain[0].Cardinality); i++ { + startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT + pk.Domain[1].FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain - nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) + nn := uint64(64 - bits.TrailingZeros64(pk.Domain[1].Cardinality)) var one fr.Element one.SetOne() - ratio := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + ratio := pk.Domain[1].Cardinality / pk.Domain[0].Cardinality - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { @@ -794,14 +792,14 @@ func computeQuotientCanonical(pk *ProvingKey, 
evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainBig.FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainSmall.Cardinality+2] - h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] - h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - return h1, h2, h3 // CORRECT + return h1, h2, h3 } @@ -834,7 +832,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -845,12 +843,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainSmall.Cardinality) + nbElmt := int64(pk.Domain[0].Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -860,7 +858,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/backend/bw6-633/plonk/setup.go b/internal/backend/bw6-633/plonk/setup.go index 6ac99c4cae..f7f6fa581c 100644 --- a/internal/backend/bw6-633/plonk/setup.go +++ b/internal/backend/bw6-633/plonk/setup.go @@ -45,8 +45,11 @@ type ProvingKey struct { // Storing LQk in Lagrange basis saves a fft... CQk, LQk []fr.Element - // Domains used for the FFTs - DomainSmall, DomainBig fft.Domain + // Domains used for the FFTs. 
+ // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain // Permutation polynomials EvaluationPermutationBigDomainBitReversed []fr.Element @@ -94,21 +97,21 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainSmall = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainBig = *fft.NewDomain(8 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainBig = *fft.NewDomain(4 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainSmall.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainSmall.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) if err := pk.InitKZG(srs); err != nil { @@ -116,12 +119,12 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -143,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -206,7 +209,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainSmall.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -262,10 +265,10 @@ func buildPermutation(spr 
*cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmts := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 pk.S1Canonical = make([]fr.Element, nbElmts) @@ -278,21 +281,21 @@ func ccomputePermutationPolynomials(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) fft.BitReverse(pk.S1Canonical) fft.BitReverse(pk.S2Canonical) fft.BitReverse(pk.S3Canonical) // evaluation of permutation on the big domain - pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) } @@ -314,7 +317,7 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bw6-633/plonk/verify.go b/internal/backend/bw6-633/plonk/verify.go index a0365c311e..c15c433778 100644 --- a/internal/backend/bw6-633/plonk/verify.go +++ b/internal/backend/bw6-633/plonk/verify.go @@ -103,13 +103,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT - l := proof.BatchedProof.ClaimedValues[2] // CORRECT - r := proof.BatchedProof.ClaimedValues[3] // CORRECT - o := proof.BatchedProof.ClaimedValues[4] // CORRECT - s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT - s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta := 
proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -166,14 +166,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - // CORRECT var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - // CORRECT cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) diff --git a/internal/backend/bw6-761/plonk/marshal.go b/internal/backend/bw6-761/plonk/marshal.go index fa4161a746..fe81cc4e0d 100644 --- a/internal/backend/bw6-761/plonk/marshal.go +++ b/internal/backend/bw6-761/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainSmall.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainBig.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -140,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainSmall.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainBig.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/backend/bw6-761/plonk/marshal_test.go b/internal/backend/bw6-761/plonk/marshal_test.go index a5d505856a..2b47f5e50a 100644 --- a/internal/backend/bw6-761/plonk/marshal_test.go +++ b/internal/backend/bw6-761/plonk/marshal_test.go @@ -47,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainSmall = *fft.NewDomain(42) - pk.DomainBig = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = 
make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -62,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go index 36e92fd481..e04a6c6a02 100644 --- a/internal/backend/bw6-761/plonk/prove.go +++ b/internal/backend/bw6-761/plonk/prove.go @@ -98,7 +98,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall, - &pk.DomainSmall) + &pk.Domain[0]) if err != nil { return nil, err } @@ -167,15 +167,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() @@ -183,10 +183,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain @@ -210,7 +210,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) // CORRECT + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. 
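The "blinded" polynomials referred to above come out of blindPoly, which (with blinding re-enabled) returns p(X) + (Xⁿ − 1)·b(X) for a random b of small degree; since ωⁿ = 1 for every element ω of the size-n small domain, the blinding is invisible there. A standalone numerical check of that identity (curve package and sizes chosen arbitrarily, single-argument NewDomain mirroring the calls in this file):

package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
	"github.com/consensys/gnark-crypto/ecc/bn254/fr/fft"
)

// eval evaluates the coefficient slice p at x (Horner).
func eval(p []fr.Element, x fr.Element) fr.Element {
	var res fr.Element
	for i := len(p) - 1; i >= 0; i-- {
		res.Mul(&res, &x).Add(&res, &p[i])
	}
	return res
}

func main() {
	const n, bo = 8, 2
	domain := fft.NewDomain(n)

	p := make([]fr.Element, n)
	b := make([]fr.Element, bo+1)
	for i := range p {
		p[i].SetRandom()
	}
	for i := range b {
		b[i].SetRandom()
	}

	// q = p + (Xⁿ − 1)·b, the same index shuffle blindPoly performs in place
	q := make([]fr.Element, n+bo+1)
	copy(q, p)
	for i := 0; i <= bo; i++ {
		q[i].Sub(&q[i], &b[i])
		q[n+i].Add(&q[n+i], &b[i])
	}

	x := domain.Generator // any n-th root of unity satisfies xⁿ = 1
	pv, qv := eval(p, x), eval(q, x)
	fmt.Println(pv.Equal(&qv)) // expected: true, the blinding is invisible on the domain
}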
<-chEvalBL @@ -228,14 +228,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes close(chConstraintOrdering) }() - if err := <-chConstraintOrdering; err != nil { // CORRECT + if err := <-chConstraintOrdering; err != nil { return nil, err } - <-chConstraintInd // CORRECT + <-chConstraintInd // compute h in canonical form - h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // CORRECT + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { @@ -271,7 +271,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, &zetaShifted, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -312,7 +312,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -360,7 +360,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes }, &zeta, hFunc, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -486,36 +486,34 @@ func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bc func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) - // totalDegree := rou + bo + totalDegree := rou + bo - // // re-use cp - // res := cp[:totalDegree+1] + // re-use cp + res := cp[:totalDegree+1] - // // random polynomial - // blindingPoly := make([]fr.Element, bo+1) - // for i := uint64(0); i < bo+1; i++ { - // if _, err := blindingPoly[i].SetRandom(); err != nil { - // return nil, err - // } - // } + // random polynomial + blindingPoly := make([]fr.Element, bo+1) + for i := uint64(0); i < bo+1; i++ { + if _, err := blindingPoly[i].SetRandom(); err != nil { + return nil, err + } + } - // // blinding - // for i := uint64(0); i < bo+1; i++ { - // res[i].Sub(&res[i], &blindingPoly[i]) - // res[rou+i].Add(&res[rou+i], &blindingPoly[i]) - // } + // blinding + for i := uint64(0); i < bo+1; i++ { + res[i].Sub(&res[i], &blindingPoly[i]) + res[rou+i].Add(&res[rou+i], &blindingPoly[i]) + } - // return res, nil + return res, nil - // TODO reactivate blinding - return cp, nil } // evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. 
// solution = [ public | secret | internal ] func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainSmall.Cardinality) + s := int(pk.Domain[0].Cardinality) var l, r, o []fr.Element l = make([]fr.Element, s) @@ -558,14 +556,14 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) - nbElmts := int(pk.DomainSmall.Cardinality) - gInv := make([]fr.Element, pk.DomainSmall.Cardinality) + z := make([]fr.Element, pk.Domain[0].Cardinality, pk.Domain[0].Cardinality+3) + nbElmts := int(pk.Domain[0].Cardinality) + gInv := make([]fr.Element, pk.Domain[0].Cardinality) z[0].SetOne() gInv[0].SetOne() - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) utils.Parallelize(nbElmts-1, func(start, end int) { @@ -596,10 +594,10 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma Mul(&z[i], &gInv[i]) } - pk.DomainSmall.FFTInverse(z, fft.DIF) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainSmall.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } @@ -614,22 +612,22 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO wg.Add(4) go func() { - evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain @@ -660,26 +658,26 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO // * gamma randomization func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - nbElmts := int(pk.DomainBig.Cardinality) + nbElmts := int(pk.Domain[1].Cardinality) // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). 
- res := make([]fr.Element, pk.DomainBig.Cardinality) + res := make([]fr.Element, pk.Domain[1].Cardinality) nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) var cosetShift, cosetShiftSquare fr.Element cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element - evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). - Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) var f [3]fr.Element var g [3]fr.Element @@ -703,7 +701,7 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) @@ -755,29 +753,29 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // constraintInd, constraintOrdering are evaluated on the big domain (coset). func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - h := make([]fr.Element, pk.DomainBig.Cardinality) + h := make([]fr.Element, pk.Domain[1].Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) - evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // CORRECT + evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.Domain[1], &pk.Domain[0]) + evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) - startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) + startsAtOne := make([]fr.Element, pk.Domain[1].Cardinality) + for i := 0; i < int(pk.Domain[0].Cardinality); i++ { + startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) // CORRECT + pk.Domain[1].FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain - nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) + nn := uint64(64 - bits.TrailingZeros64(pk.Domain[1].Cardinality)) var one fr.Element one.SetOne() - ratio := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + ratio := pk.Domain[1].Cardinality / pk.Domain[0].Cardinality - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { @@ -794,14 +792,14 @@ func computeQuotientCanonical(pk *ProvingKey, 
evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainBig.FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainSmall.Cardinality+2] - h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] - h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] - return h1, h2, h3 // CORRECT + return h1, h2, h3 } @@ -834,7 +832,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, tmp := eval(pk.S2Canonical, zeta) // s2(ζ) tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) // CORRECT + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) var uzeta, uuzeta fr.Element uzeta.Mul(&zeta, &pk.Vk.CosetShift) @@ -845,12 +843,12 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) - s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // CORRECT + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainSmall.Cardinality) + nbElmt := int64(pk.Domain[0].Cardinality) lagrangeZeta.Set(&zeta). Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -860,7 +858,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) // CORRECT + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/backend/bw6-761/plonk/setup.go b/internal/backend/bw6-761/plonk/setup.go index 36a18a3ba8..946d153c3f 100644 --- a/internal/backend/bw6-761/plonk/setup.go +++ b/internal/backend/bw6-761/plonk/setup.go @@ -45,8 +45,11 @@ type ProvingKey struct { // Storing LQk in Lagrange basis saves a fft... CQk, LQk []fr.Element - // Domains used for the FFTs - DomainSmall, DomainBig fft.Domain + // Domains used for the FFTs. 
+ // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain // Permutation polynomials EvaluationPermutationBigDomainBitReversed []fr.Element @@ -94,21 +97,21 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainSmall = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainBig = *fft.NewDomain(8 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainBig = *fft.NewDomain(4 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainSmall.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainSmall.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) if err := pk.InitKZG(srs); err != nil { @@ -116,12 +119,12 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -143,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -206,7 +209,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainSmall.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -262,10 +265,10 @@ func buildPermutation(spr 
*cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmts := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 pk.S1Canonical = make([]fr.Element, nbElmts) @@ -278,21 +281,21 @@ func ccomputePermutationPolynomials(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) fft.BitReverse(pk.S1Canonical) fft.BitReverse(pk.S2Canonical) fft.BitReverse(pk.S3Canonical) // evaluation of permutation on the big domain - pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) } @@ -314,7 +317,7 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bw6-761/plonk/verify.go b/internal/backend/bw6-761/plonk/verify.go index b08c4faa47..139ca0e173 100644 --- a/internal/backend/bw6-761/plonk/verify.go +++ b/internal/backend/bw6-761/plonk/verify.go @@ -103,13 +103,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] // CORRECT - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] // CORRECT - l := proof.BatchedProof.ClaimedValues[2] // CORRECT - r := proof.BatchedProof.ClaimedValues[3] // CORRECT - o := proof.BatchedProof.ClaimedValues[4] // CORRECT - s1 := proof.BatchedProof.ClaimedValues[5] // CORRECT - s2 := proof.BatchedProof.ClaimedValues[6] // CORRECT + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta := 
proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -166,14 +166,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - // CORRECT var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - // CORRECT cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl index a3c70a5541..5e03ff6fb2 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl @@ -70,20 +70,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainSmall.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainBig.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainSmall.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainSmall.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -121,19 +121,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainSmall.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainBig.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index 0d3e728755..a5a31b8645 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -75,7 +75,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall, - &pk.DomainSmall) + &pk.Domain[0]) if err != nil { return nil, err } @@ -144,15 +144,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.DomainBig) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go 
func() { - evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.DomainBig) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.DomainBig) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() @@ -160,10 +160,10 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qkCompletedCanonical := make([]fr.Element, pk.DomainSmall.Cardinality) + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainSmall.FFTInverse(qkCompletedCanonical, fft.DIF) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) fft.BitReverse(qkCompletedCanonical) // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain @@ -187,7 +187,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } return } - evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.DomainBig) + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL @@ -248,7 +248,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, &zetaShifted, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -289,7 +289,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainSmall.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) @@ -337,7 +337,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } }, &zeta, hFunc, - &pk.DomainBig, + &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { @@ -490,7 +490,7 @@ func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // solution = [ public | secret | internal ] func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainSmall.Cardinality) + s := int(pk.Domain[0].Cardinality) var l, r, o []fr.Element l = make([]fr.Element, s) @@ -533,14 +533,14 @@ func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.El func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma fr.Element) ([]fr.Element, error) { // note that z has more capacity has its memory is reused for blinded z later on - z := make([]fr.Element, pk.DomainSmall.Cardinality, pk.DomainSmall.Cardinality+3) - nbElmts := int(pk.DomainSmall.Cardinality) - gInv := make([]fr.Element, pk.DomainSmall.Cardinality) + z := make([]fr.Element, 
pk.Domain[0].Cardinality, pk.Domain[0].Cardinality+3) + nbElmts := int(pk.Domain[0].Cardinality) + gInv := make([]fr.Element, pk.Domain[0].Cardinality) z[0].SetOne() gInv[0].SetOne() - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) utils.Parallelize(nbElmts-1, func(start, end int) { @@ -571,10 +571,10 @@ func computeBlindedZCanonical(l, r, o []fr.Element, pk *ProvingKey, beta, gamma Mul(&z[i], &gInv[i]) } - pk.DomainSmall.FFTInverse(z, fft.DIF) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainSmall.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } @@ -589,22 +589,22 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO wg.Add(4) go func() { - evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.DomainBig) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.DomainBig) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.DomainBig) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.DomainBig) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateDomainBigBitReversed(qk, &pk.DomainBig) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain @@ -635,26 +635,26 @@ func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO // * gamma randomization func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - nbElmts := int(pk.DomainBig.Cardinality) + nbElmts := int(pk.Domain[1].Cardinality) // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) // on the big domain (coset). - res := make([]fr.Element, pk.DomainBig.Cardinality) + res := make([]fr.Element, pk.Domain[1].Cardinality) nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := int(pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality) + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) var cosetShift, cosetShiftSquare fr.Element cosetShift.Set(&pk.Vk.CosetShift) cosetShiftSquare.Square(&pk.Vk.CosetShift) - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var evaluationIDBigDomain fr.Element - evaluationIDBigDomain.Exp(pk.DomainBig.Generator, big.NewInt(int64(start))). - Mul(&evaluationIDBigDomain, &pk.DomainBig.FrMultiplicativeGen) + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). 
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) var f [3]fr.Element var g [3]fr.Element @@ -678,7 +678,7 @@ func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Elemen res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.DomainBig.Generator) // gⁱ*g + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) @@ -730,29 +730,29 @@ func evaluateXnMinusOneDomainBigCoset(domainBig, domainSmall *fft.Domain) []fr.E // constraintInd, constraintOrdering are evaluated on the big domain (coset). func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReversed, evaluationConstraintOrderingBitReversed, evaluationBlindedZDomainBigBitReversed []fr.Element, alpha fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - h := make([]fr.Element, pk.DomainBig.Cardinality) + h := make([]fr.Element, pk.Domain[1].Cardinality) // evaluate Z = Xᵐ-1 on a coset of the big domain - evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.DomainBig, &pk.DomainSmall) + evaluationXnMinusOneInverse := evaluateXnMinusOneDomainBigCoset(&pk.Domain[1], &pk.Domain[0]) evaluationXnMinusOneInverse = fr.BatchInvert(evaluationXnMinusOneInverse) // computes L₁ (canonical form) - startsAtOne := make([]fr.Element, pk.DomainBig.Cardinality) - for i := 0; i < int(pk.DomainSmall.Cardinality); i++ { - startsAtOne[i].Set(&pk.DomainSmall.CardinalityInv) + startsAtOne := make([]fr.Element, pk.Domain[1].Cardinality) + for i := 0; i < int(pk.Domain[0].Cardinality); i++ { + startsAtOne[i].Set(&pk.Domain[0].CardinalityInv) } - pk.DomainBig.FFT(startsAtOne, fft.DIF, true) + pk.Domain[1].FFT(startsAtOne, fft.DIF, true) // ql(X)L(X)+qr(X)R(X)+qm(X)L(X)R(X)+qo(X)O(X)+k(X) + α.(z(μX)*g₁(X)*g₂(X)*g₃(X)-z(X)*f₁(X)*f₂(X)*f₃(X)) + α**2*L₁(X)(Z(X)-1) // on a coset of the big domain - nn := uint64(64 - bits.TrailingZeros64(pk.DomainBig.Cardinality)) + nn := uint64(64 - bits.TrailingZeros64(pk.Domain[1].Cardinality)) var one fr.Element one.SetOne() - ratio := pk.DomainBig.Cardinality / pk.DomainSmall.Cardinality + ratio := pk.Domain[1].Cardinality / pk.Domain[0].Cardinality - utils.Parallelize(int(pk.DomainBig.Cardinality), func(start, end int) { + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { var t fr.Element for i := uint64(start); i < uint64(end); i++ { @@ -769,12 +769,12 @@ func computeQuotientCanonical(pk *ProvingKey, evaluationConstraintsIndBitReverse // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainBig.FFTInverse(h, fft.DIT, true) + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainSmall.Cardinality+2] - h2 := h[pk.DomainSmall.Cardinality+2 : 2*(pk.DomainSmall.Cardinality+2)] - h3 := h[2*(pk.DomainSmall.Cardinality+2) : 3*(pk.DomainSmall.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -825,7 +825,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, // third part L₁(ζ)*α²*Z var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainSmall.Cardinality) + nbElmt := int64(pk.Domain[0].Cardinality) lagrangeZeta.Set(&zeta). 
Exp(lagrangeZeta, big.NewInt(nbElmt)). Sub(&lagrangeZeta, &one) @@ -835,7 +835,7 @@ func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) Mul(&lagrangeZeta, &alpha). Mul(&lagrangeZeta, &alpha). - Mul(&lagrangeZeta, &pk.DomainSmall.CardinalityInv) // (1/n)*α²*L₁(ζ) + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) linPol := make([]fr.Element, len(blindedZCanonical)) copy(linPol, blindedZCanonical) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index bf514c2e13..25458538e9 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -27,8 +27,11 @@ type ProvingKey struct { // Storing LQk in Lagrange basis saves a fft... CQk, LQk []fr.Element - // Domains used for the FFTs - DomainSmall, DomainBig fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain // Permutation polynomials EvaluationPermutationBigDomainBitReversed []fr.Element @@ -76,21 +79,21 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainSmall = *fft.NewDomain(sizeSystem) - pk.Vk.CosetShift.Set(&pk.DomainSmall.FrMultiplicativeGen) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. 
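	// Illustrative check of the bound above: h lives in a vector space of dimension 3(n+2).
	// For n = 8 we get 3(n+2) = 30 <= 32 = 4n, and the slack 4n - 3(n+2) = n - 6 only grows
	// with n; for n = 4, however, 3(n+2) = 18 > 16 = 4n, hence the 8*sizeSystem fallback for
	// the very small systems handled below.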
if sizeSystem < 6 { - pk.DomainBig = *fft.NewDomain(8 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainBig = *fft.NewDomain(4 * sizeSystem) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainSmall.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainSmall.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) if err := pk.InitKZG(srs); err != nil { @@ -98,12 +101,12 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -125,11 +128,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainSmall.FFTInverse(pk.Ql, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qr, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qm, fft.DIF) - pk.DomainSmall.FFTInverse(pk.Qo, fft.DIF) - pk.DomainSmall.FFTInverse(pk.CQk, fft.DIF) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -188,7 +191,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainSmall.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -244,10 +247,10 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { // s1 (LDE) s2 (LDE) s3 (LDE) func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmts := int(pk.DomainSmall.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) // Lagrange form of ID - evaluationIDSmallDomain := getIDSmallDomain(&pk.DomainSmall) + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 pk.S1Canonical = make([]fr.Element, nbElmts) @@ -260,21 +263,21 @@ func ccomputePermutationPolynomials(pk *ProvingKey) { } // Canonical form of S1, S2, S3 - pk.DomainSmall.FFTInverse(pk.S1Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S2Canonical, fft.DIF) - pk.DomainSmall.FFTInverse(pk.S3Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) 
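	// Note: with the DIF layout, FFTInverse consumes the Lagrange-basis values in natural order
	// and leaves the canonical coefficients in bit-reversed order, so the BitReverse calls below
	// put S1, S2, S3 back in natural order (the same FFTInverse/BitReverse pairing is used for
	// the selector polynomials in Setup above).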
fft.BitReverse(pk.S1Canonical) fft.BitReverse(pk.S2Canonical) fft.BitReverse(pk.S3Canonical) // evaluation of permutation on the big domain - pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.DomainBig.Cardinality) + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:], pk.S2Canonical) - copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], pk.S3Canonical) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.DomainBig.Cardinality:2*pk.DomainBig.Cardinality], fft.DIF, true) - pk.DomainBig.FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.DomainBig.Cardinality:], fft.DIF, true) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) } @@ -296,7 +299,7 @@ func getIDSmallDomain(domain *fft.Domain) []fr.Element { return res } -// InitKZG inits pk.Vk.KZG using pk.DomainSmall cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index 1c64d0f118..f73acdb89c 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -82,13 +82,13 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} zu := proof.ZShiftedOpening.ClaimedValue - claimedQuotient := proof.BatchedProof.ClaimedValues[0] - linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] - l := proof.BatchedProof.ClaimedValues[2] - r := proof.BatchedProof.ClaimedValues[3] - o := proof.BatchedProof.ClaimedValues[4] - s1 := proof.BatchedProof.ClaimedValues[5] - s2 := proof.BatchedProof.ClaimedValues[6] + claimedQuotient := proof.BatchedProof.ClaimedValues[0] + linearizedPolynomialZeta := proof.BatchedProof.ClaimedValues[1] + l := proof.BatchedProof.ClaimedValues[2] + r := proof.BatchedProof.ClaimedValues[3] + o := proof.BatchedProof.ClaimedValues[4] + s1 := proof.BatchedProof.ClaimedValues[5] + s2 := proof.BatchedProof.ClaimedValues[6] // var beta fr.Element // beta.SetUint64(10) @@ -145,14 +145,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} // second part: α*( Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β*s₃(X)-Z(X)(l(ζ)+β*id_1(ζ)+γ)*(r(ζ)+β*id_2(ζ)+γ)*(o(ζ)+β*id_3(ζ)+γ) ) ) - var u, v, w, cosetsquare fr.Element u.Mul(&zu, &beta) v.Mul(&beta, &s1).Add(&v, &l).Add(&v, &gamma) w.Mul(&beta, &s2).Add(&w, &r).Add(&w, &gamma) _s1.Mul(&u, &v).Mul(&_s1, &w).Mul(&_s1, &alpha) // α*Z(μζ)(l(ζ)+β*s₁(ζ)+γ)*(r(ζ)+β*s₂(ζ)+γ)*β - 
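	// In the expression above id_1(ζ) = ζ, id_2(ζ) = μ·ζ and id_3(ζ) = μ²·ζ, with μ = vk.CosetShift;
	// the coset shift is squared once below and reused when the o(ζ) factor is assembled.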
cosetsquare.Square(&vk.CosetShift) u.Mul(&beta, &zeta).Add(&u, &l).Add(&u, &gamma) // (l(ζ)+β*ζ+γ) v.Mul(&beta, &zeta).Mul(&v, &vk.CosetShift).Add(&v, &r).Add(&v, &gamma) // (r(ζ)+β*μ*ζ+γ) diff --git a/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl index 7644e9e55c..a289d09f4e 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/tests/marshal.go.tmpl @@ -28,14 +28,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainSmall = *fft.NewDomain(42) - pk.DomainBig = *fft.NewDomain(4 * 42) - pk.Ql = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainSmall.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainSmall.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -43,7 +43,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainSmall.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 From 0c07a1ff551263a6a07b59e9693a3d448bceb6d4 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Mon, 14 Feb 2022 12:39:55 +0100 Subject: [PATCH 26/37] perf(bandersnatch): apply tEd perf changes to Bandersnatch --- .../twistededwards/bandersnatch/curve.go | 18 ++- .../twistededwards/bandersnatch/point.go | 120 ++++++++---------- .../twistededwards/bandersnatch/point_test.go | 92 +++++++++----- 3 files changed, 126 insertions(+), 104 deletions(-) diff --git a/std/algebra/twistededwards/bandersnatch/curve.go b/std/algebra/twistededwards/bandersnatch/curve.go index 15caf30980..0cfb3c071e 100644 --- a/std/algebra/twistededwards/bandersnatch/curve.go +++ b/std/algebra/twistededwards/bandersnatch/curve.go @@ -25,10 +25,16 @@ import ( "github.com/consensys/gnark/internal/utils" ) +// Coordinates of a point on a twisted Edwards curve +type Coord struct { + X, Y big.Int +} + // EdCurve stores the info on the chosen edwards curve type EdCurve struct { - A, D, Cofactor, Order, BaseX, BaseY big.Int - ID ecc.ID + A, D, Cofactor, Order big.Int + Base Coord + ID ecc.ID } var constructors map[ecc.ID]func() EdCurve @@ -60,9 +66,11 @@ func newBandersnatch() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS12_381, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS12_381, } } diff --git a/std/algebra/twistededwards/bandersnatch/point.go b/std/algebra/twistededwards/bandersnatch/point.go index ddfb8ac52a..052dcd1580 100644 --- 
a/std/algebra/twistededwards/bandersnatch/point.go +++ b/std/algebra/twistededwards/bandersnatch/point.go @@ -27,6 +27,13 @@ type Point struct { X, Y frontend.Variable } +// Neg computes the negative of a point in SNARK coordinates +func (p *Point) Neg(api frontend.API, p1 *Point) *Point { + p.X = api.Neg(p1.X) + p.Y = p1.Y + return p +} + // MustBeOnCurve checks if a point is on the reduced twisted Edwards curve // a*x^2 + y^2 = 1 + d*x^2*y^2. func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { @@ -46,34 +53,9 @@ func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { } -// AddFixedPoint Adds two points, among which is one fixed point (the base), on a twisted edwards curve (eg jubjub) -// p1, base, ecurve are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddFixedPoint(api frontend.API, p1 *Point /*basex*/, x /*basey*/, y interface{}, curve EdCurve) *Point { - - // https://eprint.iacr.org/2008/013.pdf - - n11 := api.Mul(p1.X, y) - n12 := api.Mul(p1.Y, x) - n1 := api.Add(n11, n12) - - n21 := api.Mul(p1.Y, y) - n22 := api.Mul(p1.X, x) - an22 := api.Mul(n22, &curve.A) - n2 := api.Sub(n21, an22) - - d11 := api.Mul(curve.D, n11, n12) - d1 := api.Add(1, d11) - d2 := api.Sub(1, d11) - - p.X = api.DivUnchecked(n1, d1) - p.Y = api.DivUnchecked(n2, d2) - - return p -} - -// AddGeneric Adds two points on a twisted edwards curve (eg jubjub) +// Add Adds two points on a twisted edwards curve (eg jubjub) // p1, p2, c are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { +func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { // https://eprint.iacr.org/2008/013.pdf @@ -103,14 +85,12 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { u := api.Mul(p1.X, p1.Y) v := api.Mul(p1.X, p1.X) w := api.Mul(p1.Y, p1.Y) - z := api.Mul(v, w) n1 := api.Mul(2, u) av := api.Mul(v, &curve.A) n2 := api.Sub(w, av) - d := api.Mul(z, curve.D) - d1 := api.Add(1, d) - d2 := api.Sub(1, d) + d1 := api.Add(w, av) + d2 := api.Sub(2, d1) p.X = api.DivUnchecked(n1, d1) p.Y = api.DivUnchecked(n2, d2) @@ -118,55 +98,72 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { return p } -// ScalarMulNonFixedBase computes the scalar multiplication of a point on a twisted Edwards curve +// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) ScalarMulNonFixedBase(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { +func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { // first unpack the scalar b := api.ToBinary(scalar) - res := Point{ - 0, - 1, + res := Point{} + tmp := Point{} + A := Point{} + B := Point{} + + A.Double(api, p1, curve) + B.Add(api, &A, p1, curve) + + n := len(b) - 1 + res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X) + res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y) + + for i := n - 2; i >= 1; i -= 2 { + res.Double(api, &res, curve). 
+ Double(api, &res, curve) + tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X) + tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y) + res.Add(api, &res, &tmp, curve) } - for i := len(b) - 1; i >= 0; i-- { + if n%2 == 0 { res.Double(api, &res, curve) - tmp := Point{} - tmp.AddGeneric(api, &res, p1, curve) - res.X = api.Select(b[i], tmp.X, res.X) - res.Y = api.Select(b[i], tmp.Y, res.Y) + tmp.Add(api, &res, p1, curve) + res.X = api.Select(b[0], tmp.X, res.X) + res.Y = api.Select(b[0], tmp.Y, res.Y) } p.X = res.X p.Y = res.Y + return p } -// ScalarMulFixedBase computes the scalar multiplication of a point on a twisted Edwards curve -// x, y: coordinates of the base point -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar frontend.Variable, curve EdCurve) *Point { +// DoubleBaseScalarMul computes s1*P1+s2*P2 +// where P1 and P2 are points on a twisted Edwards curve +// and s1, s2 scalars. +func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point { - // first unpack the scalar - b := api.ToBinary(scalar) + // first unpack the scalars + b1 := api.ToBinary(s1) + b2 := api.ToBinary(s2) - res := Point{ - 0, - 1, - } + res := Point{} + tmp := Point{} + sum := Point{} + sum.Add(api, p1, p2, curve) + + n := len(b1) + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y) - for i := len(b) - 1; i >= 0; i-- { + for i := n - 2; i >= 0; i-- { res.Double(api, &res, curve) - tmp := Point{} - tmp.AddFixedPoint(api, &res, x, y, curve) - res.X = api.Select(b[i], tmp.X, res.X) - res.Y = api.Select(b[i], tmp.Y, res.Y) + tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y) + res.Add(api, &res, &tmp, curve) } p.X = res.X @@ -174,10 +171,3 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr return p } - -// Neg computes the negative of a point in SNARK coordinates -func (p *Point) Neg(api frontend.API, p1 *Point) *Point { - p.X = api.Neg(p1.X) - p.Y = p1.Y - return p -} diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go index ba883d8930..ff879c2fa2 100644 --- a/std/algebra/twistededwards/bandersnatch/point_test.go +++ b/std/algebra/twistededwards/bandersnatch/point_test.go @@ -55,8 +55,8 @@ func TestIsOnCurve(t *testing.T) { t.Fatal(err) } - witness.P.X = (params.BaseX) - witness.P.Y = (params.BaseY) + witness.P.X = (params.Base.X) + witness.P.Y = (params.Base.Y) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) @@ -74,7 +74,10 @@ func (circuit *add) Define(api frontend.API) error { return err } - res := circuit.P.AddFixedPoint(api, &circuit.P, params.BaseX, params.BaseY, params) + p := Point{} + p.X = params.Base.X + p.Y = params.Base.Y + res := circuit.P.Add(api, &circuit.P, &p, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -94,8 +97,8 @@ func TestAddFixedPoint(t *testing.T) { t.Fatal(err) } var base, point, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) point.Set(&base) r := big.NewInt(5) point.ScalarMul(&point, r) @@ -112,6 +115,9 @@ func TestAddFixedPoint(t *testing.T) { } 
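// A minimal out-of-circuit sketch of the 2-bit window strategy used by the new ScalarMul above:
// precompute P, 2P and 3P, then scan the scalar two bits at a time, doubling twice and adding the
// table entry selected by the window; in-circuit that selection is exactly the api.Lookup2 calls
// over {0, 2P, P, 3P} (the circuit additionally handles a possible leftover low bit with
// api.Select, and DoubleBaseScalarMul applies the same trick to one bit of each scalar with the
// table {0, P1, P2, P1+P2}). windowedMul below is purely illustrative, not part of gnark; it
// assumes only math/big and the gnark-crypto bandersnatch package this test file already imports.
func windowedMul(p *bandersnatch.PointAffine, s *big.Int) bandersnatch.PointAffine {
	var table [4]bandersnatch.PointAffine
	table[0].X.SetZero() // (0,1) is the identity of the twisted Edwards group
	table[0].Y.SetOne()
	table[1].Set(p)
	table[2].Double(p)
	table[3].Add(&table[2], p)

	res := table[0] // start from the identity
	n := s.BitLen()
	if n%2 == 1 {
		n++ // pad to an even bit length so every window holds exactly two bits
	}
	for i := n - 2; i >= 0; i -= 2 {
		res.Double(&res) // res <- 4*res
		res.Double(&res)
		w := s.Bit(i+1)<<1 | s.Bit(i) // window value in {0,1,2,3}
		res.Add(&res, &table[w])      // complete Edwards addition also covers w == 0
	}
	return res
}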
+//------------------------------------------------------------- +// addGeneric + type addGeneric struct { P1, P2, E Point } @@ -124,7 +130,7 @@ func (circuit *addGeneric) Define(api frontend.API) error { return err } - res := circuit.P1.AddGeneric(api, &circuit.P1, &circuit.P2, params) + res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -143,8 +149,8 @@ func TestAddGeneric(t *testing.T) { t.Fatal(err) } var point1, point2, expected bandersnatch.PointAffine - point1.X.SetBigInt(¶ms.BaseX) - point1.Y.SetBigInt(¶ms.BaseY) + point1.X.SetBigInt(¶ms.Base.X) + point1.Y.SetBigInt(¶ms.Base.Y) point2.Set(&point1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -165,6 +171,8 @@ func TestAddGeneric(t *testing.T) { } +//------------------------------------------------------------- +// Double type double struct { P, E Point } @@ -197,8 +205,8 @@ func TestDouble(t *testing.T) { t.Fatal(err) } var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) // populate witness @@ -225,8 +233,10 @@ func (circuit *scalarMulFixed) Define(api frontend.API) error { return err } - var resFixed Point - resFixed.ScalarMulFixedBase(api, params.BaseX, params.BaseY, circuit.S, params) + var resFixed, p Point + p.X = params.Base.X + p.Y = params.Base.Y + resFixed.ScalarMul(api, &p, circuit.S, params) api.AssertIsEqual(resFixed.X, circuit.E.X) api.AssertIsEqual(resFixed.Y, circuit.E.Y) @@ -246,8 +256,8 @@ func TestScalarMulFixed(t *testing.T) { t.Fatal(err) } var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) @@ -274,7 +284,7 @@ func (circuit *scalarMulGeneric) Define(api frontend.API) error { return err } - resGeneric := circuit.P.ScalarMulNonFixedBase(api, &circuit.P, circuit.S, params) + resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params) api.AssertIsEqual(resGeneric.X, circuit.E.X) api.AssertIsEqual(resGeneric.Y, circuit.E.Y) @@ -294,8 +304,8 @@ func TestScalarMulGeneric(t *testing.T) { t.Fatal(err) } var base, point, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) s := big.NewInt(902) point.ScalarMul(&base, s) // random point r := big.NewInt(230928302) @@ -336,8 +346,8 @@ func TestNeg(t *testing.T) { t.Fatal(err) } var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Neg(&base) // generate witness @@ -351,25 +361,39 @@ func TestNeg(t *testing.T) { } -// benches +// Bench +func BenchmarkDouble(b *testing.B) { + var c double + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} -var ccsBench frontend.CompiledConstraintSystem +func BenchmarkAddGeneric(b *testing.B) { + var c addGeneric + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} -func BenchmarkScalarMulG1(b *testing.B) { - var c scalarMulGeneric - b.Run("groth16", func(b *testing.B) { - for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) - } +func 
BenchmarkAddFixedPoint(b *testing.B) { + var c add + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} - }) +func BenchmarkMustBeOnCurve(b *testing.B) { + var c mustBeOnCurve + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) b.Log("groth16", ccsBench.GetNbConstraints()) - b.Run("plonk", func(b *testing.B) { - for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BLS12_381, backend.PLONK, &c) - } +} - }) - b.Log("plonk", ccsBench.GetNbConstraints()) +func BenchmarkScalarMulGeneric(b *testing.B) { + var c scalarMulGeneric + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} +func BenchmarkScalarMulFixed(b *testing.B) { + var c scalarMulFixed + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) } From 5e875a40be19d780e707e0ca0fa6e91e5571bd9d Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 14 Feb 2022 10:27:10 -0600 Subject: [PATCH 27/37] build: updatd to latezst gnarkcrypto --- go.mod | 4 ++-- go.sum | 14 ++++---------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 535811db35..41cd1c1a37 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/consensys/gnark go 1.17 require ( - github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a - github.com/consensys/gnark-crypto v0.6.1-0.20220209103408-f71b1fc783da + github.com/consensys/bavard v0.1.9 + github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8 github.com/fxamacker/cbor/v2 v2.2.0 github.com/leanovate/gopter v0.2.9 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 6ff55dc1f3..840401ee36 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,7 @@ -github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a h1:AEpwbXTjBGKoqxuQ6QAcBMEuK0+PtajQj0wJkhTnSd0= -github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 h1:SPKwScbSTdl2p+QvJulHumGa2+5FO6RPh857TCPxda0= -github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= -github.com/consensys/gnark-crypto v0.6.1-0.20220203135532-a5667210247a h1:Jfr3vYmkw4xxWvNAnavhGiN0pVyhmpPer5sq1zFJFAk= -github.com/consensys/gnark-crypto v0.6.1-0.20220203135532-a5667210247a/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= -github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f h1:55DRDYCFD64OIJh/Yz1Bch9Va14lwKgA/xk0n8JUIjE= -github.com/consensys/gnark-crypto v0.6.1-0.20220204095423-2fb0ec48a36f/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= -github.com/consensys/gnark-crypto v0.6.1-0.20220209103408-f71b1fc783da h1:dfeAHW2Yx/ceM+ft3UP/f5fufoU0LyxqF8655Cp88TI= -github.com/consensys/gnark-crypto v0.6.1-0.20220209103408-f71b1fc783da/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= +github.com/consensys/bavard v0.1.9 h1:t9wg3/7Ko73yE+eKcavgMYcPMO1hinadJGlbSCdXTiM= +github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8 h1:6WJWeTs2BMRKrRmGRtZ0h+uSCB395x7GgHdsLbFLndM= +github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70= github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From b4398bfb1c6ccbb4be31a30bbb7847e882704e03 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 14 Feb 2022 10:33:13 -0600 Subject: [PATCH 28/37] perf: R1CS solver may now run in parallel --- frontend/cs/r1cs/conversion.go | 83 +++++++- internal/backend/bls12-377/cs/r1cs.go | 197 ++++++++++++------ internal/backend/bls12-377/cs/r1cs_sparse.go | 7 +- internal/backend/bls12-377/cs/solution.go | 32 +-- internal/backend/bls12-381/cs/r1cs.go | 197 ++++++++++++------ internal/backend/bls12-381/cs/r1cs_sparse.go | 7 +- internal/backend/bls12-381/cs/solution.go | 32 +-- internal/backend/bls24-315/cs/r1cs.go | 197 ++++++++++++------ internal/backend/bls24-315/cs/r1cs_sparse.go | 7 +- internal/backend/bls24-315/cs/solution.go | 32 +-- internal/backend/bn254/cs/r1cs.go | 197 ++++++++++++------ internal/backend/bn254/cs/r1cs_sparse.go | 7 +- internal/backend/bn254/cs/solution.go | 32 +-- internal/backend/bw6-633/cs/r1cs.go | 197 ++++++++++++------ internal/backend/bw6-633/cs/r1cs_sparse.go | 7 +- internal/backend/bw6-633/cs/solution.go | 32 +-- internal/backend/bw6-761/cs/r1cs.go | 197 ++++++++++++------ internal/backend/bw6-761/cs/r1cs_sparse.go | 7 +- internal/backend/bw6-761/cs/solution.go | 32 +-- internal/backend/compiled/r1cs.go | 5 + .../template/representations/r1cs.go.tmpl | 181 +++++++++++----- .../representations/r1cs.sparse.go.tmpl | 7 +- .../template/representations/solution.go.tmpl | 32 +-- 23 files changed, 1180 insertions(+), 544 deletions(-) diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go index 5b042efc9b..4113f12344 100644 --- a/frontend/cs/r1cs/conversion.go +++ b/frontend/cs/r1cs/conversion.go @@ -122,13 +122,15 @@ HINTLOOP: } } for i := 0; i < len(cs.DebugInfo); i++ { - for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) } } + // build levels + res.Levels = buildLevels(res) + switch cs.CurveID { case ecc.BLS12_377: return bls12377r1cs.NewR1CS(res, cs.Coeffs), nil @@ -147,6 +149,85 @@ HINTLOOP: } } +func processLE(ccs compiled.R1CS, l compiled.LinearExpression, mWireToNode, mLevels map[int]int, nodeLevels []int, nodeLevel, cID int) int { + nbInputs := ccs.NbPublicVariables + ccs.NbSecretVariables + + for _, t := range l { + wID := t.WireID() + if wID < nbInputs { + // it's a input, we ignore it + continue + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... 
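			// Worked example of the level rule applied here: if the current constraint consumes one
			// wire solved by a constraint at level 0 and another solved at level 1, nodeLevel ends up
			// as max(0,1)+1 = 2, so the constraint lands in Levels[2] and is only scheduled once
			// levels 0 and 1 are fully solved; constraints within a single level never depend on one
			// another, which is what makes the per-level parallelism in Solve safe.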
+ // we add a dependency, check if we need to increment our current level + if nodeLevels[n] >= nodeLevel { + nodeLevel = nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + continue + } + + // check if it's a hint and mark all the output wires + if h, ok := ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.Variable: + nodeLevel = processLE(ccs, t.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + case compiled.LinearExpression: + nodeLevel = processLE(ccs, t, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + case compiled.Term: + nodeLevel = processLE(ccs, compiled.LinearExpression{t}, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + } + } + + for _, hwid := range h.Wires { + mWireToNode[hwid] = cID + } + continue + } + + // mark this wire solved by current node + mWireToNode[wID] = cID + } + + return nodeLevel +} + +func buildLevels(ccs compiled.R1CS) [][]int { + + mWireToNode := make(map[int]int, ccs.NbInternalVariables) // at which node we resolved which wire + nodeLevels := make([]int, len(ccs.Constraints)) // level of a node + mLevels := make(map[int]int) // level counts + + for cID, c := range ccs.Constraints { + + nodeLevel := 0 + + nodeLevel = processLE(ccs, c.L.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + nodeLevel = processLE(ccs, c.R.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + nodeLevel = processLE(ccs, c.O.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + nodeLevels[cID] = nodeLevel + mLevels[nodeLevel]++ + + } + + levels := make([][]int, len(mLevels)) + for i := 0; i < len(levels); i++ { + levels[i] = make([]int, 0, mLevels[i]) + } + + for n, l := range nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + func (cs *r1CS) SetSchema(s *schema.Schema) { cs.Schema = s } diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 680941e51e..5e8408895f 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -19,11 +19,12 @@ package cs import ( "errors" "fmt" + "github.com/fxamacker/cbor/v2" "io" "math/big" + "runtime" "strings" - - "github.com/fxamacker/cbor/v2" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -32,6 +33,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" + "math" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) } @@ -93,45 +90,117 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, 
cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool - - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + if len(cs.Levels) != 0 { + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a pool + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() } - if solved { - // a[i] * b[i] == c[i], since we just computed it. - continue + // for each level, we push the tasks + for _, level := range cs.Levels { + + const minWorkPerCPU = 50.0 + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + if maxCPU <= 1.0 { + // we do it sequentially + for _, n := range level { + i := n + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + + close(chTasks) + close(chError) + return solution.values, err + } + } + continue + } + + nbTasks := runtime.NumCPU() + mm := int(math.Ceil(maxCPU)) + if nbTasks > mm { + nbTasks = mm + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + chTasks <- level[_start:_end] + } + + wg.Wait() + if len(chError) > 0 { + close(chTasks) + close(chError) + return solution.values, <-chError + } } + close(chTasks) + close(chError) + + } else { - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + for i := 0; i < len(cs.Constraints); i++ { + // solve the constraint, this will compute the missing wire of the gate + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil 
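// Note on the refactored signature below: solveConstraint no longer returns (solved, a, b, c, err);
// it fills a, b, c through the pointers it receives and returns only an error, performing the
// a ⋅ b == c consistency check itself whenever no wire remains to be solved. This is what lets
// Solve above hand whole levels of constraints to the worker pool without a separate verification
// loop, with solution.set tracking progress atomically.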
{ + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + } + return solution.values, err } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) } } @@ -183,7 +252,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +289,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +321,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). 
+ Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index e52a55459b..43706c174d 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -84,11 +84,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 2cf9bc935e..5f03b43efb 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -37,9 +38,8 @@ import ( type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +49,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +67,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +147,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) 
+ // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index 45565a4a47..3d13ec15a8 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -19,11 +19,12 @@ package cs import ( "errors" "fmt" + "github.com/fxamacker/cbor/v2" "io" "math/big" + "runtime" "strings" - - "github.com/fxamacker/cbor/v2" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -32,6 +33,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" + "math" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) } @@ -93,45 +90,117 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool - - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + if len(cs.Levels) != 0 { + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a pool + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() } - if solved { - // a[i] * b[i] == c[i], since we just computed it. 
- continue + // for each level, we push the tasks + for _, level := range cs.Levels { + + const minWorkPerCPU = 50.0 + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + if maxCPU <= 1.0 { + // we do it sequentially + for _, n := range level { + i := n + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + + close(chTasks) + close(chError) + return solution.values, err + } + } + continue + } + + nbTasks := runtime.NumCPU() + mm := int(math.Ceil(maxCPU)) + if nbTasks > mm { + nbTasks = mm + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + chTasks <- level[_start:_end] + } + + wg.Wait() + if len(chError) > 0 { + close(chTasks) + close(chError) + return solution.values, <-chError + } } + close(chTasks) + close(chError) + + } else { - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + for i := 0; i < len(cs.Constraints); i++ { + // solve the constraint, this will compute the missing wire of the gate + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + } + return solution.values, err } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) } } @@ -183,7 +252,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. 
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +289,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +321,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). 
+ Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 6e47e344fb..4899ef2661 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -84,11 +84,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 7126962be9..92a6db35d9 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -37,9 +38,8 @@ import ( type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +49,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +67,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +147,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) 
+ // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 4edfbe850f..211e201fb9 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -19,11 +19,12 @@ package cs import ( "errors" "fmt" + "github.com/fxamacker/cbor/v2" "io" "math/big" + "runtime" "strings" - - "github.com/fxamacker/cbor/v2" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -32,6 +33,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" + "math" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) } @@ -93,45 +90,117 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool - - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + if len(cs.Levels) != 0 { + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a pool + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() } - if solved { - // a[i] * b[i] == c[i], since we just computed it. 
- continue + // for each level, we push the tasks + for _, level := range cs.Levels { + + const minWorkPerCPU = 50.0 + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + if maxCPU <= 1.0 { + // we do it sequentially + for _, n := range level { + i := n + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + + close(chTasks) + close(chError) + return solution.values, err + } + } + continue + } + + nbTasks := runtime.NumCPU() + mm := int(math.Ceil(maxCPU)) + if nbTasks > mm { + nbTasks = mm + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + chTasks <- level[_start:_end] + } + + wg.Wait() + if len(chError) > 0 { + close(chTasks) + close(chError) + return solution.values, <-chError + } } + close(chTasks) + close(chError) + + } else { - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + for i := 0; i < len(cs.Constraints); i++ { + // solve the constraint, this will compute the missing wire of the gate + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + } + return solution.values, err } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) } } @@ -183,7 +252,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. 
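Stepping back from this hunk: the dispatch loop a few lines above is repeated verbatim in every per-curve r1cs.go, so it is worth seeing in isolation. The sketch below is an editor's condensation, not code from the patch — solveLevels and solveOne are invented names, the chunking heuristic and debug-info lookup are trimmed, and error handling is simplified.

package main

// Sketch of the worker-pool dispatch introduced by this patch series.
// solveLevels / solveOne are illustrative names only.
import (
	"fmt"
	"runtime"
	"sync"
)

func solveLevels(levels [][]int, solveOne func(i int) error) error {
	var wg sync.WaitGroup
	chTasks := make(chan []int, runtime.NumCPU())
	chError := make(chan error, runtime.NumCPU())
	defer close(chTasks)
	defer close(chError)

	// worker pool: each worker consumes slices of constraint indexes
	for w := 0; w < runtime.NumCPU(); w++ {
		go func() {
			for task := range chTasks {
				for _, i := range task {
					if err := solveOne(i); err != nil {
						chError <- err
						wg.Done()
						return
					}
				}
				wg.Done()
			}
		}()
	}

	// levels are processed one after the other, so every dependency of a
	// level is already solved when its tasks are pushed
	for _, level := range levels {
		// the real code sizes chunks with a minWorkPerCPU heuristic; here the
		// level is simply cut into at most NumCPU contiguous chunks
		nbTasks := runtime.NumCPU()
		if nbTasks > len(level) {
			nbTasks = len(level)
		}
		chunk := (len(level) + nbTasks - 1) / nbTasks
		for start := 0; start < len(level); start += chunk {
			end := start + chunk
			if end > len(level) {
				end = len(level)
			}
			wg.Add(1)
			chTasks <- level[start:end]
		}
		wg.Wait()
		if len(chError) > 0 {
			return <-chError
		}
	}
	return nil
}

func main() {
	levels := [][]int{{0, 1, 2}, {3, 4}}
	err := solveLevels(levels, func(i int) error { fmt.Println("solving constraint", i); return nil })
	fmt.Println("done, err =", err)
}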
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +289,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +321,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). 
+ Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 667d5d230e..4585876f4f 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -84,11 +84,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index 272f31af54..6c50c44e2c 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -37,9 +38,8 @@ import ( type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +49,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +67,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +147,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) 
+ // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 6e8084b7cf..02a9db2758 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -19,11 +19,12 @@ package cs import ( "errors" "fmt" + "github.com/fxamacker/cbor/v2" "io" "math/big" + "runtime" "strings" - - "github.com/fxamacker/cbor/v2" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -32,6 +33,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" + "math" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) } @@ -93,45 +90,117 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool - - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + if len(cs.Levels) != 0 { + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a pool + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() } - if solved { - // a[i] * b[i] == c[i], since we just computed it. 
- continue + // for each level, we push the tasks + for _, level := range cs.Levels { + + const minWorkPerCPU = 50.0 + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + if maxCPU <= 1.0 { + // we do it sequentially + for _, n := range level { + i := n + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + + close(chTasks) + close(chError) + return solution.values, err + } + } + continue + } + + nbTasks := runtime.NumCPU() + mm := int(math.Ceil(maxCPU)) + if nbTasks > mm { + nbTasks = mm + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + chTasks <- level[_start:_end] + } + + wg.Wait() + if len(chError) > 0 { + close(chTasks) + close(chError) + return solution.values, <-chError + } } + close(chTasks) + close(chError) + + } else { - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + for i := 0; i < len(cs.Constraints); i++ { + // solve the constraint, this will compute the missing wire of the gate + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + } + return solution.values, err } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) } } @@ -183,7 +252,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. 
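The task-splitting arithmetic in the level loop is easy to misread, so a worked example helps: with 8 CPUs and the patch's minWorkPerCPU = 50, a level of 1030 constraints gives maxCPU = 20.6, nbTasks = 8, nbIterationsPerCpus = 128 and extraTasks = 6, so the first six tasks carry 129 constraints and the last two carry 128. The standalone sketch below (splitLevel is an invented helper, not from the patch) reproduces just that splitting.

package main

import (
	"fmt"
	"math"
)

// splitLevel cuts a level of constraint indexes into at most nbCPU contiguous
// chunks, giving the first chunks one extra constraint when the division is
// not exact — the same arithmetic as the solver loop above.
func splitLevel(level []int, nbCPU int, minWorkPerCPU float64) [][]int {
	maxCPU := float64(len(level)) / minWorkPerCPU
	if maxCPU <= 1.0 {
		return [][]int{level} // too small to parallelize: one sequential chunk
	}
	nbTasks := nbCPU
	if mm := int(math.Ceil(maxCPU)); nbTasks > mm {
		nbTasks = mm
	}
	nbIterationsPerCpus := len(level) / nbTasks
	if nbIterationsPerCpus < 1 {
		nbIterationsPerCpus = 1
		nbTasks = len(level)
	}
	extraTasks := len(level) - nbTasks*nbIterationsPerCpus
	extraTasksOffset := 0

	tasks := make([][]int, 0, nbTasks)
	for i := 0; i < nbTasks; i++ {
		start := i*nbIterationsPerCpus + extraTasksOffset
		end := start + nbIterationsPerCpus
		if extraTasks > 0 {
			end++
			extraTasks--
			extraTasksOffset++
		}
		tasks = append(tasks, level[start:end])
	}
	return tasks
}

func main() {
	level := make([]int, 1030)
	for i := range level {
		level[i] = i
	}
	for i, t := range splitLevel(level, 8, 50.0) {
		fmt.Printf("task %d: %d constraints\n", i, len(t))
	}
}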
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +289,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +321,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). 
+ Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index b92f7d2273..67b0a551e0 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -84,11 +84,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 323c472df4..9cc432f793 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -37,9 +38,8 @@ import ( type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +49,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +67,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +147,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) 
+ // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index 46a34b44d9..c70d014d62 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -19,11 +19,12 @@ package cs import ( "errors" "fmt" + "github.com/fxamacker/cbor/v2" "io" "math/big" + "runtime" "strings" - - "github.com/fxamacker/cbor/v2" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -32,6 +33,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" + "math" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) } @@ -93,45 +90,117 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool - - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + if len(cs.Levels) != 0 { + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a pool + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() } - if solved { - // a[i] * b[i] == c[i], since we just computed it. 
- continue + // for each level, we push the tasks + for _, level := range cs.Levels { + + const minWorkPerCPU = 50.0 + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + if maxCPU <= 1.0 { + // we do it sequentially + for _, n := range level { + i := n + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + + close(chTasks) + close(chError) + return solution.values, err + } + } + continue + } + + nbTasks := runtime.NumCPU() + mm := int(math.Ceil(maxCPU)) + if nbTasks > mm { + nbTasks = mm + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + chTasks <- level[_start:_end] + } + + wg.Wait() + if len(chError) > 0 { + close(chTasks) + close(chError) + return solution.values, <-chError + } } + close(chTasks) + close(chError) + + } else { - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + for i := 0; i < len(cs.Constraints); i++ { + // solve the constraint, this will compute the missing wire of the gate + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + } + return solution.values, err } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) } } @@ -183,7 +252,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. 
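The solution.go hunks that accompany these r1cs.go changes swap nbSolved from int to uint64 and bump it with atomic.AddUint64, which follows directly from the parallel solver: several workers may now call solution.set concurrently. Each wire is still written by exactly one constraint, so the per-wire stores do not overlap, but the shared counter needs an atomic increment. A toy illustration (toySolution is not a type from the patch):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// toySolution keeps only the fields needed to show why the counter is atomic.
type toySolution struct {
	values   []uint64
	nbSolved uint64 // incremented atomically: set() may run from several goroutines
}

func (s *toySolution) set(id int, v uint64) {
	s.values[id] = v // distinct ids per goroutine: no race on slice entries
	atomic.AddUint64(&s.nbSolved, 1)
}

func (s *toySolution) isValid() bool { return int(s.nbSolved) == len(s.values) }

func main() {
	s := &toySolution{values: make([]uint64, 1000)}
	var wg sync.WaitGroup
	for w := 0; w < 4; w++ {
		wg.Add(1)
		go func(w int) {
			defer wg.Done()
			for i := w; i < len(s.values); i += 4 {
				s.set(i, uint64(i))
			}
		}(w)
	}
	wg.Wait()
	fmt.Println("all wires instantiated:", s.isValid())
}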
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +289,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +321,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). 
+ Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bw6-633/cs/r1cs_sparse.go b/internal/backend/bw6-633/cs/r1cs_sparse.go index e6154b4227..ae06af7fb6 100644 --- a/internal/backend/bw6-633/cs/r1cs_sparse.go +++ b/internal/backend/bw6-633/cs/r1cs_sparse.go @@ -84,11 +84,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index 3f7daa0122..d1175967db 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -37,9 +38,8 @@ import ( type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +49,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +67,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +147,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) 
+ // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 5c8b357d3b..776f8b9285 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -19,11 +19,12 @@ package cs import ( "errors" "fmt" + "github.com/fxamacker/cbor/v2" "io" "math/big" + "runtime" "strings" - - "github.com/fxamacker/cbor/v2" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -32,6 +33,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" + "math" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) } @@ -93,45 +90,117 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool - - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + if len(cs.Levels) != 0 { + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a pool + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() } - if solved { - // a[i] * b[i] == c[i], since we just computed it. 
- continue + // for each level, we push the tasks + for _, level := range cs.Levels { + + const minWorkPerCPU = 50.0 + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + if maxCPU <= 1.0 { + // we do it sequentially + for _, n := range level { + i := n + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + + close(chTasks) + close(chError) + return solution.values, err + } + } + continue + } + + nbTasks := runtime.NumCPU() + mm := int(math.Ceil(maxCPU)) + if nbTasks > mm { + nbTasks = mm + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + chTasks <- level[_start:_end] + } + + wg.Wait() + if len(chError) > 0 { + close(chTasks) + close(chError) + return solution.values, <-chError + } } + close(chTasks) + close(chError) + + } else { - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + for i := 0; i < len(cs.Constraints); i++ { + // solve the constraint, this will compute the missing wire of the gate + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + } + return solution.values, err } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) } } @@ -183,7 +252,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. 
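The other recurring solution.go change, dropping the shared tmpHintsIO scratch slice in favour of fresh inputs/outputs allocations inside solveWithHint, reads like a thread-safety fix: a scratch buffer hanging off the shared solution struct could be handed to two concurrent hint calls once constraints are solved in parallel. That motivation is an assumption on my part; the allocation pattern itself is just the following (callHint and hintFunc are invented names used only for the sketch).

package main

import (
	"fmt"
	"math/big"
)

// callHint allocates fresh big.Int buffers per invocation, so nothing is
// shared between concurrent hint calls.
func callHint(hintFunc func(in, out []*big.Int) error, inValues []int64, nbOutputs int) ([]*big.Int, error) {
	inputs := make([]*big.Int, len(inValues))
	outputs := make([]*big.Int, nbOutputs)
	for i := range inputs {
		inputs[i] = big.NewInt(inValues[i])
	}
	for i := range outputs {
		outputs[i] = big.NewInt(0)
	}
	if err := hintFunc(inputs, outputs); err != nil {
		return nil, err
	}
	return outputs, nil
}

func main() {
	// toy hint: out[0] = in[0] + in[1]
	sum := func(in, out []*big.Int) error { out[0].Add(in[0], in[1]); return nil }
	out, err := callHint(sum, []int64{2, 3}, 1)
	fmt.Println(out[0], err) // 5 <nil>
}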
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +289,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +321,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). 
+ Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index fbb30462da..789db9e68f 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -84,11 +84,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +92,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index 560d95f7d5..2e1d7c360d 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -37,9 +38,8 @@ import ( type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +49,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +67,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +147,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) 
+ // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go index 8831f6da01..fc51a453bd 100644 --- a/internal/backend/compiled/r1cs.go +++ b/internal/backend/compiled/r1cs.go @@ -18,6 +18,11 @@ package compiled type R1CS struct { CS Constraints []R1C + + // each level contains independent constraints and can be parallelized + // it is guaranteed that all dependncies for constraints in a level l are solved + // in previous levels + Levels [][]int } // GetNbConstraints returns the number of constraints diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 6619d8e6b3..c035aed0c8 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -3,8 +3,9 @@ import ( "fmt" "io" "math/big" + "runtime" "strings" - + "sync" "github.com/fxamacker/cbor/v2" "github.com/consensys/gnark/internal/backend/ioutils" @@ -13,6 +14,7 @@ import ( "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" + "math" "github.com/consensys/gnark-crypto/ecc" {{ template "import_fr" . }} @@ -52,10 +54,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) @@ -75,15 +73,104 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool + if len(cs.Levels) != 0 { + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a pool + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + chError <- err + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // for each level, we push the tasks + for _, level := range cs.Levels { + + const minWorkPerCPU = 50.0 + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + if maxCPU <= 1.0 { + // we do it 
sequentially + for _, n := range level { + i := n + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + + close(chTasks) + close(chError) + return solution.values, err + } + } + continue + } + + nbTasks := runtime.NumCPU() + mm := int(math.Ceil(maxCPU)) + if nbTasks > mm { + nbTasks = mm + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + chTasks <- level[_start:_end] + } + + + wg.Wait() + if len(chError) > 0 { + close(chTasks) + close(chError) + return solution.values, <-chError + } + } + close(chTasks) + close(chError) + + } else { // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire @@ -92,30 +179,15 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // if a[i] * b[i] != c[i]; it means the constraint is not satisfied for i := 0; i < len(cs.Constraints); i++ { // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { + if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) } return solution.values, err } - - if solved { - // a[i] * b[i] == c[i], since we just computed it. - continue - } - - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } } +} // sanity check; ensure all wires are marked as "instantiated" if !solution.isValid() { @@ -167,7 +239,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. 
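The comment block just above still describes the old (solved bool, a, b, c, err) signature; in the rewritten solveConstraint the caller no longer re-checks anything, because whenever no wire can be deduced (loc == 0, or a zero denominator in cases 1 and 2) the function verifies a·b = c itself and returns the error. The check is the fr.Element pattern below, shown with gnark-crypto's bn254 fr package purely to have a concrete element type.

package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
)

// checkR1C returns an error unless a*b == c, mirroring the inline check that
// replaced the old `solved` return value.
func checkR1C(a, b, c *fr.Element) error {
	var check fr.Element
	if !check.Mul(a, b).Equal(c) {
		return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String())
	}
	return nil
}

func main() {
	var a, b, c fr.Element
	a.SetUint64(3)
	b.SetUint64(5)
	c.SetUint64(15)
	fmt.Println(checkR1C(&a, &b, &c)) // <nil>
	c.SetUint64(16)
	fmt.Println(checkR1C(&a, &b, &c)) // error
}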
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a,b,c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -204,28 +276,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -234,36 +309,42 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). 
+ Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) + - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index c572bfa1d4..6c5cffce2f 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -71,11 +71,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -84,7 +79,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index 43aaca6ee7..e8b043932a 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -3,6 +3,7 @@ import ( "errors" "fmt" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/internal/backend/compiled" @@ -18,9 +19,8 @@ import ( type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -30,7 +30,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -49,11 +48,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -128,15 +128,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) 
+ // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i :=0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i :=0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs:nbInputs+nbOutputs] q := fr.Modulus() From e3effd9f860d13559a5d8a0ecee974392aa3b66b Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 14 Feb 2022 11:08:39 -0600 Subject: [PATCH 29/37] style: code cleaning --- frontend/cs/r1cs/conversion.go | 113 ++++++---- internal/backend/bls12-377/cs/r1cs.go | 191 ++++++++-------- internal/backend/bls12-381/cs/r1cs.go | 191 ++++++++-------- internal/backend/bls24-315/cs/r1cs.go | 191 ++++++++-------- internal/backend/bn254/cs/r1cs.go | 191 ++++++++-------- internal/backend/bw6-633/cs/r1cs.go | 191 ++++++++-------- internal/backend/bw6-761/cs/r1cs.go | 191 ++++++++-------- .../template/representations/r1cs.go.tmpl | 210 +++++++++--------- 8 files changed, 769 insertions(+), 700 deletions(-) diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go index 4113f12344..c71d520e9d 100644 --- a/frontend/cs/r1cs/conversion.go +++ b/frontend/cs/r1cs/conversion.go @@ -149,85 +149,102 @@ HINTLOOP: } } -func processLE(ccs compiled.R1CS, l compiled.LinearExpression, mWireToNode, mLevels map[int]int, nodeLevels []int, nodeLevel, cID int) int { - nbInputs := ccs.NbPublicVariables + ccs.NbSecretVariables +func (cs *r1CS) SetSchema(s *schema.Schema) { + cs.Schema = s +} + +func buildLevels(ccs compiled.R1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processLE(c.L.LinExp, cID) + b.processLE(c.R.LinExp, cID) + b.processLE(c.O.LinExp, cID) + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.R1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { for _, t := range l { wID := t.WireID() - if wID < nbInputs { + if wID < b.nbInputs { // it's a input, we ignore it continue } // if we know a which constraint solves this wire, then it's a dependency - n, ok := mWireToNode[wID] + n, ok := b.mWireToNode[wID] if ok { if n != cID { // can happen with hints... 
// we add a dependency, check if we need to increment our current level - if nodeLevels[n] >= nodeLevel { - nodeLevel = nodeLevels[n] + 1 // we are at the next level at least since we depend on it + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it } } continue } // check if it's a hint and mark all the output wires - if h, ok := ccs.MHints[wID]; ok { + if h, ok := b.ccs.MHints[wID]; ok { for _, in := range h.Inputs { switch t := in.(type) { case compiled.Variable: - nodeLevel = processLE(ccs, t.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + b.processLE(t.LinExp, cID) case compiled.LinearExpression: - nodeLevel = processLE(ccs, t, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + b.processLE(t, cID) case compiled.Term: - nodeLevel = processLE(ccs, compiled.LinearExpression{t}, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) + b.processLE(compiled.LinearExpression{t}, cID) } } for _, hwid := range h.Wires { - mWireToNode[hwid] = cID + b.mWireToNode[hwid] = cID } continue } // mark this wire solved by current node - mWireToNode[wID] = cID - } - - return nodeLevel -} - -func buildLevels(ccs compiled.R1CS) [][]int { - - mWireToNode := make(map[int]int, ccs.NbInternalVariables) // at which node we resolved which wire - nodeLevels := make([]int, len(ccs.Constraints)) // level of a node - mLevels := make(map[int]int) // level counts - - for cID, c := range ccs.Constraints { - - nodeLevel := 0 - - nodeLevel = processLE(ccs, c.L.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) - nodeLevel = processLE(ccs, c.R.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) - nodeLevel = processLE(ccs, c.O.LinExp, mWireToNode, mLevels, nodeLevels, nodeLevel, cID) - nodeLevels[cID] = nodeLevel - mLevels[nodeLevel]++ - + b.mWireToNode[wID] = cID } - - levels := make([][]int, len(mLevels)) - for i := 0; i < len(levels); i++ { - levels[i] = make([]int, 0, mLevels[i]) - } - - for n, l := range nodeLevels { - levels[l] = append(levels[l], n) - } - - return levels -} - -func (cs *r1CS) SetSchema(s *schema.Schema) { - cs.Schema = s } diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 5e8408895f..905976e552 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -83,7 +83,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } @@ -96,120 +96,127 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - if len(cs.Levels) != 0 { - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan error, runtime.NumCPU()) - - // start a pool - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) - } - chError <- err - wg.Done() - return - } - } - wg.Done() - } - }() - } + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return 
solution.values, err + } - // for each level, we push the tasks - for _, level := range cs.Levels { + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } - const minWorkPerCPU = 50.0 + return solution.values, nil +} - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - if maxCPU <= 1.0 { - // we do it sequentially - for _, n := range level { - i := n - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) err = fmt.Errorf("%w: %s", err, debugInfoStr) } - - close(chTasks) - close(chError) - return solution.values, err + chError <- err + wg.Done() + return } } - continue + wg.Done() } + }() + } - nbTasks := runtime.NumCPU() - mm := int(math.Ceil(maxCPU)) - if nbTasks > mm { - nbTasks = mm - } - nbIterationsPerCpus := len(level) / nbTasks + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } + // for each level, we push the tasks + for _, level := range cs.Levels { - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + return err } - chTasks <- level[_start:_end] } + continue + } - wg.Wait() - if len(chError) > 0 { - close(chTasks) - close(chError) - return solution.values, <-chError - } + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
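The split below hands out the len(level) constraint indexes of a level to nbTasks chunks whose sizes differ by at most one. As a standalone sketch of the same arithmetic (the splitLevel helper is hypothetical, not part of this patch), 103 constraints over 4 tasks come out as chunks of 26, 26, 26 and 25:

package main

import "fmt"

// splitLevel mirrors the chunking used below: nbTasks chunks, with the first
// len(level) % nbTasks chunks holding one extra constraint index.
func splitLevel(level []int, nbTasks int) [][]int {
	perTask := len(level) / nbTasks
	extra := len(level) - nbTasks*perTask
	chunks := make([][]int, 0, nbTasks)
	start := 0
	for i := 0; i < nbTasks; i++ {
		end := start + perTask
		if extra > 0 {
			end++
			extra--
		}
		chunks = append(chunks, level[start:end])
		start = end
	}
	return chunks
}

func main() {
	level := make([]int, 103) // pretend this level holds 103 constraints
	for _, chunk := range splitLevel(level, 4) {
		fmt.Println(len(chunk)) // 26 26 26 25
	}
}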
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks } - close(chTasks) - close(chError) + nbIterationsPerCpus := len(level) / nbTasks - } else { + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index 3d13ec15a8..d3791a3969 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -83,7 +83,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } @@ -96,120 +96,127 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - if len(cs.Levels) != 0 { - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan error, runtime.NumCPU()) - - // start a pool - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) - } - chError <- err - wg.Done() - return - } - } - wg.Done() - } - }() - } + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return solution.values, err + } - // for each level, we push the tasks - for _, level := range cs.Levels { + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } - const minWorkPerCPU = 50.0 + return solution.values, nil 
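// (note: the earlier sequential fallback that ran when cs.Levels was empty is
// gone; solving now always goes through parallelSolve, which still executes
// small levels sequentially and in place, without touching the worker pool)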
+} - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - if maxCPU <= 1.0 { - // we do it sequentially - for _, n := range level { - i := n - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) err = fmt.Errorf("%w: %s", err, debugInfoStr) } - - close(chTasks) - close(chError) - return solution.values, err + chError <- err + wg.Done() + return } } - continue + wg.Done() } + }() + } - nbTasks := runtime.NumCPU() - mm := int(math.Ceil(maxCPU)) - if nbTasks > mm { - nbTasks = mm - } - nbIterationsPerCpus := len(level) / nbTasks + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } + // for each level, we push the tasks + for _, level := range cs.Levels { - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + return err } - chTasks <- level[_start:_end] } + continue + } - wg.Wait() - if len(chError) > 0 { - close(chTasks) - close(chError) - return solution.values, <-chError - } + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks } - close(chTasks) - close(chError) + nbIterationsPerCpus := len(level) / nbTasks - } else { + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 211e201fb9..e7d0f0946f 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -83,7 +83,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } @@ -96,120 +96,127 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - if len(cs.Levels) != 0 { - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan error, runtime.NumCPU()) - - // start a pool - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) - } - chError <- err - wg.Done() - return - } - } - wg.Done() - } - }() - } + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return solution.values, err + } - // for each level, we push the tasks - for _, level := range cs.Levels { + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } - const minWorkPerCPU = 50.0 + return solution.values, nil 
+} - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - if maxCPU <= 1.0 { - // we do it sequentially - for _, n := range level { - i := n - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) err = fmt.Errorf("%w: %s", err, debugInfoStr) } - - close(chTasks) - close(chError) - return solution.values, err + chError <- err + wg.Done() + return } } - continue + wg.Done() } + }() + } - nbTasks := runtime.NumCPU() - mm := int(math.Ceil(maxCPU)) - if nbTasks > mm { - nbTasks = mm - } - nbIterationsPerCpus := len(level) / nbTasks + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } + // for each level, we push the tasks + for _, level := range cs.Levels { - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + return err } - chTasks <- level[_start:_end] } + continue + } - wg.Wait() - if len(chError) > 0 { - close(chTasks) - close(chError) - return solution.values, <-chError - } + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks } - close(chTasks) - close(chError) + nbIterationsPerCpus := len(level) / nbTasks - } else { + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 02a9db2758..77b5a7edac 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -83,7 +83,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } @@ -96,120 +96,127 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - if len(cs.Levels) != 0 { - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan error, runtime.NumCPU()) - - // start a pool - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) - } - chError <- err - wg.Done() - return - } - } - wg.Done() - } - }() - } + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return solution.values, err + } - // for each level, we push the tasks - for _, level := range cs.Levels { + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } - const minWorkPerCPU = 50.0 + return solution.values, nil +} - // max CPU 
to use - maxCPU := float64(len(level)) / minWorkPerCPU - if maxCPU <= 1.0 { - // we do it sequentially - for _, n := range level { - i := n - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) err = fmt.Errorf("%w: %s", err, debugInfoStr) } - - close(chTasks) - close(chError) - return solution.values, err + chError <- err + wg.Done() + return } } - continue + wg.Done() } + }() + } - nbTasks := runtime.NumCPU() - mm := int(math.Ceil(maxCPU)) - if nbTasks > mm { - nbTasks = mm - } - nbIterationsPerCpus := len(level) / nbTasks + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } + // for each level, we push the tasks + for _, level := range cs.Levels { - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + return err } - chTasks <- level[_start:_end] } + continue + } - wg.Wait() - if len(chError) > 0 { - close(chTasks) - close(chError) - return solution.values, <-chError - } + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks } - close(chTasks) - close(chError) + nbIterationsPerCpus := len(level) / nbTasks - } else { + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index c70d014d62..35f9d358d3 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -83,7 +83,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } @@ -96,120 +96,127 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - if len(cs.Levels) != 0 { - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan error, runtime.NumCPU()) - - // start a pool - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) - } - chError <- err - wg.Done() - return - } - } - wg.Done() - } - }() - } + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return solution.values, err + } - // for each level, we push the tasks - for _, level := range cs.Levels { + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } - const minWorkPerCPU = 50.0 + return solution.values, nil +} - // 
max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - if maxCPU <= 1.0 { - // we do it sequentially - for _, n := range level { - i := n - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) err = fmt.Errorf("%w: %s", err, debugInfoStr) } - - close(chTasks) - close(chError) - return solution.values, err + chError <- err + wg.Done() + return } } - continue + wg.Done() } + }() + } - nbTasks := runtime.NumCPU() - mm := int(math.Ceil(maxCPU)) - if nbTasks > mm { - nbTasks = mm - } - nbIterationsPerCpus := len(level) / nbTasks + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } + // for each level, we push the tasks + for _, level := range cs.Levels { - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + return err } - chTasks <- level[_start:_end] } + continue + } - wg.Wait() - if len(chError) > 0 { - close(chTasks) - close(chError) - return solution.values, <-chError - } + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks } - close(chTasks) - close(chError) + nbIterationsPerCpus := len(level) / nbTasks - } else { + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 776f8b9285..4c4023cee5 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -83,7 +83,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } @@ -96,120 +96,127 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - if len(cs.Levels) != 0 { - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan error, runtime.NumCPU()) - - // start a pool - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) - } - chError <- err - wg.Done() - return - } - } - wg.Done() - } - }() - } + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return solution.values, err + } - // for each level, we push the tasks - for _, level := range cs.Levels { + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } - const minWorkPerCPU = 50.0 + return solution.values, nil +} - // 
max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - if maxCPU <= 1.0 { - // we do it sequentially - for _, n := range level { - i := n - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) err = fmt.Errorf("%w: %s", err, debugInfoStr) } - - close(chTasks) - close(chError) - return solution.values, err + chError <- err + wg.Done() + return } } - continue + wg.Done() } + }() + } - nbTasks := runtime.NumCPU() - mm := int(math.Ceil(maxCPU)) - if nbTasks > mm { - nbTasks = mm - } - nbIterationsPerCpus := len(level) / nbTasks + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } + // for each level, we push the tasks + for _, level := range cs.Levels { - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + return err } - chTasks <- level[_start:_end] } + continue + } - wg.Wait() - if len(chError) > 0 { - close(chTasks) - close(chError) - return solution.values, <-chError - } + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks } - close(chTasks) - close(chError) + nbIterationsPerCpus := len(level) / nbTasks - } else { + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) - } - return solution.values, err + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index c035aed0c8..fb1d6da10b 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -41,6 +41,7 @@ func NewR1CS(cs compiled.R1CS, coefficients []big.Int) *R1CS { return &r } + // Solve sets all the wires and returns the a, b, c vectors. // the cs system should have been compiled before. The entries in a, b, c are in Montgomery form. 
// a, b, c vectors: ab-c = hz @@ -66,7 +67,7 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } @@ -79,122 +80,131 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - if len(cs.Levels) != 0 { - - var wg sync.WaitGroup - chTasks := make(chan []int, runtime.NumCPU()) - chError := make(chan error, runtime.NumCPU()) - - // start a pool - for i := 0; i < runtime.NumCPU(); i++ { - go func() { - for t := range chTasks { - for _, i := range t { - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) - } - chError <- err - wg.Done() - return - } - } - wg.Done() - } - }() - } + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return solution.values, err + } + + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } + + return solution.values, nil +} + + - // for each level, we push the tasks - for _, level := range cs.Levels { - - const minWorkPerCPU = 50.0 - - // max CPU to use - maxCPU := float64(len(level)) / minWorkPerCPU - if maxCPU <= 1.0 { - // we do it sequentially - for _, n := range level { - i := n - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + // for each constraint + // we are guaranteed that each R1C contains at most one unsolved wire + // first we solve the unsolved wire (if any) + // then we check that the constraint is valid + // if a[i] * b[i] != c[i]; it means the constraint is not satisfied + + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. 
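// (a failing worker appends any debug info to the error, pushes it to chError
// and returns; chError is buffered with one slot per worker, so this send can
// never block)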
+ if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) err = fmt.Errorf("%w: %s", err, debugInfoStr) } - - close(chTasks) - close(chError) - return solution.values, err + chError <- err + wg.Done() + return } } - continue + wg.Done() } + }() + } - nbTasks := runtime.NumCPU() - mm := int(math.Ceil(maxCPU)) - if nbTasks > mm { - nbTasks = mm - } - nbIterationsPerCpus := len(level) / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = len(level) - } - - - extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if dID, ok := cs.MDebug[int(i)]; ok { + debugInfoStr := solution.logValue(cs.DebugInfo[dID]) + err = fmt.Errorf("%w: %s", err, debugInfoStr) + } + return err } - chTasks <- level[_start:_end] - } - - - wg.Wait() - if len(chError) > 0 { - close(chTasks) - close(chError) - return solution.values, <-chError } + continue } - close(chTasks) - close(chError) - - } else { - // for each constraint - // we are guaranteed that each R1C contains at most one unsolved wire - // first we solve the unsolved wire (if any) - // then we check that the constraint is valid - // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - if err := cs.solveConstraint(cs.Constraints[i], &solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, err + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } -} + + // wait for the level to be done + wg.Wait() - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise From 1be4ff4340a815693fdface83ec0ec5839459811 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 14 Feb 2022 11:35:50 -0600 Subject: [PATCH 30/37] perf: sparse R1CS solver is parallel --- frontend/cs/plonk/conversion.go | 103 ++++++++++++++ internal/backend/bls12-377/cs/r1cs.go | 8 +- internal/backend/bls12-377/cs/r1cs_sparse.go | 131 +++++++++++++++-- internal/backend/bls12-381/cs/r1cs.go | 8 +- internal/backend/bls12-381/cs/r1cs_sparse.go | 131 +++++++++++++++-- internal/backend/bls24-315/cs/r1cs.go | 8 +- internal/backend/bls24-315/cs/r1cs_sparse.go | 131 +++++++++++++++-- internal/backend/bn254/cs/r1cs.go | 8 +- internal/backend/bn254/cs/r1cs_sparse.go | 131 +++++++++++++++-- internal/backend/bw6-633/cs/r1cs.go | 8 +- internal/backend/bw6-633/cs/r1cs_sparse.go | 131 +++++++++++++++-- internal/backend/bw6-761/cs/r1cs.go | 8 +- internal/backend/bw6-761/cs/r1cs_sparse.go | 131 +++++++++++++++-- internal/backend/compiled/cs.go | 5 + internal/backend/compiled/r1cs.go | 5 - .../template/representations/r1cs.go.tmpl | 8 +- .../representations/r1cs.sparse.go.tmpl | 134 ++++++++++++++++-- 17 files changed, 971 insertions(+), 118 deletions(-) diff --git a/frontend/cs/plonk/conversion.go b/frontend/cs/plonk/conversion.go index b82fd4b7a0..1f2e8d564e 100644 --- a/frontend/cs/plonk/conversion.go +++ b/frontend/cs/plonk/conversion.go @@ -116,6 +116,9 @@ HINTLOOP: } res.MHints = shiftedMap + // build levels + res.Levels = buildLevels(res) + switch cs.CurveID { case ecc.BLS12_377: return bls12377r1cs.NewSparseR1CS(res, cs.Coeffs), nil @@ -138,3 +141,103 @@ HINTLOOP: func (cs *sparseR1CS) SetSchema(s *schema.Schema) { cs.Schema = s } + +func buildLevels(ccs compiled.SparseR1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processTerm(c.L, cID) + 
b.processTerm(c.R, cID) + b.processTerm(c.O, cID) + + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.SparseR1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processTerm(t compiled.Term, cID int) { + wID := t.WireID() + if wID < b.nbInputs { + // it's a input, we ignore it + return + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := b.mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... + // we add a dependency, check if we need to increment our current level + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + return + } + + // check if it's a hint and mark all the output wires + if h, ok := b.ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.Variable: + for _, tt := range t.LinExp { + b.processTerm(tt, cID) + } + case compiled.LinearExpression: + for _, tt := range t { + b.processTerm(tt, cID) + } + case compiled.Term: + b.processTerm(t, cID) + } + } + + for _, hwid := range h.Wires { + b.mWireToNode[hwid] = cID + } + + return + } + + // mark this wire solved by current node + b.mWireToNode[wID] = cID + +} diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 905976e552..87ce34d00f 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -137,9 +137,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - chError <- err + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() return } @@ -167,9 +167,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[int(i)]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - return err + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } } continue diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index 43706c174d..8c5d2c087b 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -103,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop 
through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -126,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
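Stepping back from the chunk-size computation below, the overall scheme is the same as in the R1CS solver: constraints are grouped by cs.Levels, levels are processed one after another, and wg.Wait() acts as a barrier so that a level only starts once every wire it may depend on has been solved. A minimal standalone sketch of that pattern, with toy data and a stand-in for solveConstraint rather than the solver's real API:

package main

import (
	"fmt"
	"runtime"
	"sync"
)

func main() {
	// toy dependency levels: constraints in the same level are independent
	levels := [][]int{{0}, {1, 2, 3, 4}, {5, 6}}
	results := make([]int, 7)

	var wg sync.WaitGroup
	chTasks := make(chan []int, runtime.NumCPU())

	// fixed worker pool, one goroutine per CPU
	for w := 0; w < runtime.NumCPU(); w++ {
		go func() {
			for task := range chTasks {
				for _, i := range task {
					results[i] = i * i // stand-in for solveConstraint
				}
				wg.Done()
			}
		}()
	}

	for _, level := range levels {
		// one task per constraint keeps the sketch short; the real code
		// groups several constraints per task as computed below
		for _, i := range level {
			wg.Add(1)
			chTasks <- []int{i}
		}
		wg.Wait() // barrier: the next level may read wires solved in this one
	}
	close(chTasks)
	fmt.Println(results) // [0 1 4 9 16 25 36]
}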
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index d3791a3969..9440d858ff 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -137,9 +137,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - chError <- err + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() return } @@ -167,9 +167,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[int(i)]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - return err + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } } continue diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 4899ef2661..106e6eb0eb 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -103,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return 
solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -126,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index e7d0f0946f..185c6ce2b3 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -137,9 +137,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - chError <- err + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() return } @@ -167,9 +167,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[int(i)]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - return err + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } } continue diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 4585876f4f..9ecd27fb26 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -103,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return 
solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -126,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
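+ // each level must be fully solved (wg.Wait below) before the next one is dispatched,
+ // since constraints in later levels may depend on wires solved in this one.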
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index 77b5a7edac..d58a5ca5ec 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -137,9 +137,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - chError <- err + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() return } @@ -167,9 +167,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[int(i)]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - return err + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } } continue diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index 67b0a551e0..8eab1f901f 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -103,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure 
all wires are marked as "instantiated" @@ -126,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
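+ // the remainder len(level) % nbTasks is spread over the first tasks via extraTasks
+ // below, so task sizes differ by at most one constraint.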
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index 35f9d358d3..9c34184586 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -137,9 +137,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - chError <- err + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() return } @@ -167,9 +167,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[int(i)]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - return err + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } } continue diff --git a/internal/backend/bw6-633/cs/r1cs_sparse.go b/internal/backend/bw6-633/cs/r1cs_sparse.go index ae06af7fb6..7cc2ccfad0 100644 --- a/internal/backend/bw6-633/cs/r1cs_sparse.go +++ b/internal/backend/bw6-633/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -103,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // 
sanity check; ensure all wires are marked as "instantiated" @@ -126,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index 4c4023cee5..b3bf91ca56 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -137,9 +137,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - chError <- err + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() return } @@ -167,9 +167,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[int(i)]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - return err + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } } continue diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index 789db9e68f..8cc16bcf4c 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -103,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // 
sanity check; ensure all wires are marked as "instantiated" @@ -126,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. 
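+ // levels with at most minWorkPerCPU constraints took the sequential path above,
+ // avoiding goroutine and channel overhead for small amounts of work.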
+ nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/compiled/cs.go b/internal/backend/compiled/cs.go index d74b2a9f3d..b91dbba9f4 100644 --- a/internal/backend/compiled/cs.go +++ b/internal/backend/compiled/cs.go @@ -38,6 +38,11 @@ type CS struct { MHints map[int]*Hint Schema *schema.Schema + + // each level contains independent constraints and can be parallelized + // it is guaranteed that all dependncies for constraints in a level l are solved + // in previous levels + Levels [][]int } // Hint represents a solver hint diff --git a/internal/backend/compiled/r1cs.go b/internal/backend/compiled/r1cs.go index fc51a453bd..8831f6da01 100644 --- a/internal/backend/compiled/r1cs.go +++ b/internal/backend/compiled/r1cs.go @@ -18,11 +18,6 @@ package compiled type R1CS struct { CS Constraints []R1C - - // each level contains independent constraints and can be parallelized - // it is guaranteed that all dependncies for constraints in a level l are solved - // in previous levels - Levels [][]int } // GetNbConstraints returns the number of constraints diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index fb1d6da10b..2f70d99896 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -124,9 +124,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[i]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - chError <- err + chError <- fmt.Errorf("constraint #%d is not satisfied: %w",i, err) wg.Done() return } @@ -154,9 +154,9 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { if dID, ok := cs.MDebug[int(i)]; ok { debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w: %s", err, debugInfoStr) + err = fmt.Errorf("%w - %s", err, debugInfoStr) } - return err + return fmt.Errorf("constraint #%d is not satisfied: %w",i, err) } } continue diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl 
b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index 6c5cffce2f..dbc9915d2c 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -6,6 +6,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "strings" "os" + "sync" + "runtime" + "math" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark/internal/backend/compiled" @@ -90,19 +93,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -115,6 +107,122 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. 
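+ // if solving or checking fails, the worker reports on chError (buffered with one
+ // slot per CPU) and exits; the level loop below returns the first error after wg.Wait().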
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 From fd3ce5814f2781e4386ba0dacccb7470e45f15fb Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Mon, 14 Feb 2022 20:08:00 +0100 Subject: [PATCH 31/37] fix: fixed wrong bigInt op in plonk api --- frontend/cs/plonk/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/cs/plonk/api.go b/frontend/cs/plonk/api.go index 30e0d4c289..117dc7744a 100644 --- a/frontend/cs/plonk/api.go +++ b/frontend/cs/plonk/api.go @@ -120,7 +120,7 @@ func (system *sparseR1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variab q := system.CurveID.Info().Fr.Modulus() return r.ModInverse(&r, q). Mul(&l, &r). 
- Mod(&l, q) + Mod(&r, q) } if system.IsConstant(i2) { c := utils.FromInterface(i2) From 42b4e4277bde5a961adac077a9159ccdf3f7af42 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 14 Feb 2022 13:48:44 -0600 Subject: [PATCH 32/37] fix: fixed trace and println tests --- debug_test.go | 6 ++--- internal/backend/bls12-377/cs/r1cs.go | 24 ++++++++++++------- internal/backend/bls12-377/cs/solution.go | 2 ++ internal/backend/bls12-381/cs/r1cs.go | 24 ++++++++++++------- internal/backend/bls12-381/cs/solution.go | 2 ++ internal/backend/bls24-315/cs/r1cs.go | 24 ++++++++++++------- internal/backend/bls24-315/cs/solution.go | 2 ++ internal/backend/bn254/cs/r1cs.go | 24 ++++++++++++------- internal/backend/bn254/cs/solution.go | 2 ++ internal/backend/bw6-633/cs/r1cs.go | 24 ++++++++++++------- internal/backend/bw6-633/cs/solution.go | 2 ++ internal/backend/bw6-761/cs/r1cs.go | 24 ++++++++++++------- internal/backend/bw6-761/cs/solution.go | 2 ++ .../template/representations/r1cs.go.tmpl | 24 ++++++++++++------- .../template/representations/solution.go.tmpl | 2 ++ 15 files changed, 122 insertions(+), 66 deletions(-) diff --git a/debug_test.go b/debug_test.go index 3569e3f493..51910507d2 100644 --- a/debug_test.go +++ b/debug_test.go @@ -47,16 +47,16 @@ func TestPrintln(t *testing.T) { expected.WriteString("debug_test.go:27 26 42\n") expected.WriteString("debug_test.go:29 bits 1\n") expected.WriteString("debug_test.go:30 circuit {A: 2, B: 11}\n") - expected.WriteString("debug_test.go:34 m \n") + expected.WriteString("debug_test.go:34 m .*\n") { trace, _ := getGroth16Trace(&circuit, &witness) - assert.Equal(expected.String(), trace) + assert.Regexp(expected.String(), trace) } { trace, _ := getPlonkTrace(&circuit, &witness) - assert.Equal(expected.String(), trace) + assert.Regexp(expected.String(), trace) } } diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 87ce34d00f..1313d715e4 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -135,9 +135,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { for _, i := range t { // for each constraint in the task, solve it. if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() @@ -165,9 +168,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // we do it sequentially for _, i := range level { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } @@ -314,7 +320,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. 
// or if we solved the unsolved wires with hint functions var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } return nil } @@ -335,7 +341,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // we didn't actually ensure that a * b == c var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 2: @@ -346,7 +352,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. } else { var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 3: diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 5f03b43efb..6911e72ab7 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -33,6 +33,8 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { diff --git a/internal/backend/bls12-381/cs/r1cs.go b/internal/backend/bls12-381/cs/r1cs.go index 9440d858ff..0a32652ad9 100644 --- a/internal/backend/bls12-381/cs/r1cs.go +++ b/internal/backend/bls12-381/cs/r1cs.go @@ -135,9 +135,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { for _, i := range t { // for each constraint in the task, solve it. if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() @@ -165,9 +168,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // we do it sequentially for _, i := range level { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } @@ -314,7 +320,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // or if we solved the unsolved wires with hint functions var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } return nil } @@ -335,7 +341,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. 
// we didn't actually ensure that a * b == c var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 2: @@ -346,7 +352,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. } else { var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 3: diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 92a6db35d9..9d630c8153 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -33,6 +33,8 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { diff --git a/internal/backend/bls24-315/cs/r1cs.go b/internal/backend/bls24-315/cs/r1cs.go index 185c6ce2b3..964acf1f50 100644 --- a/internal/backend/bls24-315/cs/r1cs.go +++ b/internal/backend/bls24-315/cs/r1cs.go @@ -135,9 +135,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { for _, i := range t { // for each constraint in the task, solve it. if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() @@ -165,9 +168,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // we do it sequentially for _, i := range level { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } @@ -314,7 +320,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // or if we solved the unsolved wires with hint functions var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } return nil } @@ -335,7 +341,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // we didn't actually ensure that a * b == c var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 2: @@ -346,7 +352,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. 
} else { var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 3: diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index 6c50c44e2c..e215ac343a 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -33,6 +33,8 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { diff --git a/internal/backend/bn254/cs/r1cs.go b/internal/backend/bn254/cs/r1cs.go index d58a5ca5ec..c44696bae1 100644 --- a/internal/backend/bn254/cs/r1cs.go +++ b/internal/backend/bn254/cs/r1cs.go @@ -135,9 +135,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { for _, i := range t { // for each constraint in the task, solve it. if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() @@ -165,9 +168,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // we do it sequentially for _, i := range level { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } @@ -314,7 +320,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // or if we solved the unsolved wires with hint functions var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } return nil } @@ -335,7 +341,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // we didn't actually ensure that a * b == c var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 2: @@ -346,7 +352,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. 
} else { var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 3: diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 9cc432f793..46cb2eb6af 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -33,6 +33,8 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { diff --git a/internal/backend/bw6-633/cs/r1cs.go b/internal/backend/bw6-633/cs/r1cs.go index 9c34184586..7101b1a5b5 100644 --- a/internal/backend/bw6-633/cs/r1cs.go +++ b/internal/backend/bw6-633/cs/r1cs.go @@ -135,9 +135,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { for _, i := range t { // for each constraint in the task, solve it. if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() @@ -165,9 +168,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // we do it sequentially for _, i := range level { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } @@ -314,7 +320,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // or if we solved the unsolved wires with hint functions var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } return nil } @@ -335,7 +341,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // we didn't actually ensure that a * b == c var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 2: @@ -346,7 +352,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. 
} else { var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 3: diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index d1175967db..4b611e7f07 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -33,6 +33,8 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { diff --git a/internal/backend/bw6-761/cs/r1cs.go b/internal/backend/bw6-761/cs/r1cs.go index b3bf91ca56..097d53581c 100644 --- a/internal/backend/bw6-761/cs/r1cs.go +++ b/internal/backend/bw6-761/cs/r1cs.go @@ -135,9 +135,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { for _, i := range t { // for each constraint in the task, solve it. if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) wg.Done() @@ -165,9 +168,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // we do it sequentially for _, i := range level { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) } @@ -314,7 +320,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // or if we solved the unsolved wires with hint functions var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } return nil } @@ -335,7 +341,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. // we didn't actually ensure that a * b == c var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 2: @@ -346,7 +352,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr. 
} else { var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 3: diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index 2e1d7c360d..fb1e1a19bd 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -33,6 +33,8 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { diff --git a/internal/generator/backend/template/representations/r1cs.go.tmpl b/internal/generator/backend/template/representations/r1cs.go.tmpl index 2f70d99896..ee3b99f846 100644 --- a/internal/generator/backend/template/representations/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.go.tmpl @@ -122,9 +122,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { for _, i := range t { // for each constraint in the task, solve it. if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } chError <- fmt.Errorf("constraint #%d is not satisfied: %w",i, err) wg.Done() @@ -152,9 +155,12 @@ func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { // we do it sequentially for _, i := range level { if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { - if dID, ok := cs.MDebug[int(i)]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - err = fmt.Errorf("%w - %s", err, debugInfoStr) + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } } return fmt.Errorf("constraint #%d is not satisfied: %w",i, err) } @@ -304,7 +310,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.El // or if we solved the unsolved wires with hint functions var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } return nil } @@ -326,7 +332,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.El // we didn't actually ensure that a * b == c var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 2: @@ -337,7 +343,7 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.El } else { var check fr.Element if !check.Mul(a, b).Equal(c) { - return fmt.Errorf("%s ⋅ %s != %s", a.String(), b.String(), c.String()) + return errUnsatisfiedConstraint } } case 3: diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index e8b043932a..dba02f7cb6 100644 --- 
a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -14,6 +14,8 @@ import ( {{ template "import_curve" . }} ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { From 665d42681ae034d22eb8d343705621449677f0de Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Mon, 14 Feb 2022 23:07:17 +0100 Subject: [PATCH 33/37] style(eddsa, tEd): no benchmarks --- std/algebra/twistededwards/point_test.go | 25 ------------------------ std/signature/eddsa/eddsa_test.go | 7 ------- 2 files changed, 32 deletions(-) diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index 12aebf8de2..232db9c6c2 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -819,28 +819,3 @@ func TestNeg(t *testing.T) { assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) } - -// Bench -func BenchmarkDouble(b *testing.B) { - var c double - ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkAddGeneric(b *testing.B) { - var c addGeneric - ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkAddFixedPoint(b *testing.B) { - var c add - ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} - -func BenchmarkMustBeOnCurve(b *testing.B) { - var c mustBeOnCurve - ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go index ecacd505e4..282d158aa5 100644 --- a/std/signature/eddsa/eddsa_test.go +++ b/std/signature/eddsa/eddsa_test.go @@ -253,10 +253,3 @@ func TestEddsa(t *testing.T) { } } - -// Bench -func BenchmarkEdDSA(b *testing.B) { - var c eddsaCircuit - ccsBench, _ := frontend.Compile(ecc.BN254, backend.GROTH16, &c) - b.Log("groth16", ccsBench.GetNbConstraints()) -} From c4c7bf5b284fee2b97ec87dee8e607926ea965bd Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Mon, 14 Feb 2022 23:09:38 +0100 Subject: [PATCH 34/37] style(eddsa, tEd): no benchmarks --- std/algebra/twistededwards/point_test.go | 1 - std/signature/eddsa/eddsa_test.go | 1 - 2 files changed, 2 deletions(-) diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index 232db9c6c2..d9387e94fd 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -28,7 +28,6 @@ import ( tbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" tbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/test" ) diff --git a/std/signature/eddsa/eddsa_test.go b/std/signature/eddsa/eddsa_test.go index 282d158aa5..84636c0a4b 100644 --- a/std/signature/eddsa/eddsa_test.go +++ b/std/signature/eddsa/eddsa_test.go @@ -36,7 +36,6 @@ import ( eddsabw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards/eddsa" "github.com/consensys/gnark-crypto/hash" "github.com/consensys/gnark-crypto/signature" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" 
"github.com/consensys/gnark/std/algebra/twistededwards" "github.com/consensys/gnark/test" From 1c14ce891ac0a5e0369cc36a900a30101baeb6e1 Mon Sep 17 00:00:00 2001 From: Thomas Piellard Date: Tue, 15 Feb 2022 16:16:38 +0100 Subject: [PATCH 35/37] feat: plonk adapted to kzg modifications --- go.mod | 2 +- go.sum | 2 ++ internal/backend/bls12-377/plonk/prove.go | 10 +++------- internal/backend/bls12-377/plonk/verify.go | 7 +++++++ internal/backend/bls12-381/plonk/prove.go | 10 +++------- internal/backend/bls12-381/plonk/verify.go | 7 +++++++ internal/backend/bls24-315/plonk/prove.go | 10 +++------- internal/backend/bls24-315/plonk/verify.go | 7 +++++++ internal/backend/bn254/plonk/prove.go | 10 +++------- internal/backend/bn254/plonk/verify.go | 7 +++++++ internal/backend/bw6-633/plonk/prove.go | 10 +++------- internal/backend/bw6-633/plonk/verify.go | 7 +++++++ internal/backend/bw6-761/plonk/prove.go | 10 +++------- internal/backend/bw6-761/plonk/verify.go | 7 +++++++ .../template/zkpschemes/plonk/plonk.prove.go.tmpl | 9 +++------ .../template/zkpschemes/plonk/plonk.verify.go.tmpl | 7 +++++++ 16 files changed, 73 insertions(+), 49 deletions(-) diff --git a/go.mod b/go.mod index 41cd1c1a37..720e643b18 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/consensys/bavard v0.1.9 - github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8 + github.com/consensys/gnark-crypto v0.6.1-0.20220215134556-f8ab1746cc1f github.com/fxamacker/cbor/v2 v2.2.0 github.com/leanovate/gopter v0.2.9 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 840401ee36..12356e5a17 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/consensys/bavard v0.1.9 h1:t9wg3/7Ko73yE+eKcavgMYcPMO1hinadJGlbSCdXTi github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8 h1:6WJWeTs2BMRKrRmGRtZ0h+uSCB395x7GgHdsLbFLndM= github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70= +github.com/consensys/gnark-crypto v0.6.1-0.20220215134556-f8ab1746cc1f h1:CTqL+BaOO2yVv/TBlfDINOUrDamX0/3ke968mj004UE= +github.com/consensys/gnark-crypto v0.6.1-0.20220215134556-f8ab1746cc1f/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index b8a821b6f4..74f54f4897 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" @@ -270,8 +268,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, - &zetaShifted, - &pk.Domain[1], + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -340,7 +337,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn // Batch open the first list of polynomials proof.BatchedProof, err = 
kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, linearizedPolynomialCanonical, blindedLCanonical, @@ -358,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { diff --git a/internal/backend/bls12-377/plonk/verify.go b/internal/backend/bls12-377/plonk/verify.go index dee48705ee..e827aef72b 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/internal/backend/bls12-377/plonk/verify.go @@ -205,6 +205,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne vk.S[1], }, &proof.BatchedProof, + zeta, hFunc, ) if err != nil { @@ -212,6 +213,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne } // Batch verify + var shiftedZeta fr.Element + shiftedZeta.Mul(&zeta, &vk.Generator) return kzg.BatchVerifyMultiPoints([]kzg.Digest{ foldedDigest, proof.Z, @@ -220,6 +223,10 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne foldedProof, proof.ZShiftedOpening, }, + []fr.Element{ + zeta, + shiftedZeta, + }, vk.KZGSRS, ) } diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go index 7b434fd24e..8d5a625122 100644 --- a/internal/backend/bls12-381/plonk/prove.go +++ b/internal/backend/bls12-381/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" @@ -270,8 +268,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, - &zetaShifted, - &pk.Domain[1], + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -340,7 +337,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, linearizedPolynomialCanonical, blindedLCanonical, @@ -358,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { diff --git a/internal/backend/bls12-381/plonk/verify.go b/internal/backend/bls12-381/plonk/verify.go index bf514edae9..620bc5e097 100644 --- a/internal/backend/bls12-381/plonk/verify.go +++ b/internal/backend/bls12-381/plonk/verify.go @@ -205,6 +205,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne vk.S[1], }, &proof.BatchedProof, + zeta, hFunc, ) if err != nil { @@ -212,6 +213,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne } // Batch verify + var shiftedZeta fr.Element + shiftedZeta.Mul(&zeta, &vk.Generator) return kzg.BatchVerifyMultiPoints([]kzg.Digest{ foldedDigest, proof.Z, @@ -220,6 +223,10 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne foldedProof, proof.ZShiftedOpening, }, + []fr.Element{ + zeta, + shiftedZeta, + }, vk.KZGSRS, ) } diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go index ee92a6436c..904a785599 100644 --- a/internal/backend/bls24-315/plonk/prove.go +++ 
b/internal/backend/bls24-315/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" @@ -270,8 +268,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, - &zetaShifted, - &pk.Domain[1], + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -340,7 +337,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, linearizedPolynomialCanonical, blindedLCanonical, @@ -358,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { diff --git a/internal/backend/bls24-315/plonk/verify.go b/internal/backend/bls24-315/plonk/verify.go index 1cdcb5ea06..ae08037e77 100644 --- a/internal/backend/bls24-315/plonk/verify.go +++ b/internal/backend/bls24-315/plonk/verify.go @@ -205,6 +205,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne vk.S[1], }, &proof.BatchedProof, + zeta, hFunc, ) if err != nil { @@ -212,6 +213,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne } // Batch verify + var shiftedZeta fr.Element + shiftedZeta.Mul(&zeta, &vk.Generator) return kzg.BatchVerifyMultiPoints([]kzg.Digest{ foldedDigest, proof.Z, @@ -220,6 +223,10 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne foldedProof, proof.ZShiftedOpening, }, + []fr.Element{ + zeta, + shiftedZeta, + }, vk.KZGSRS, ) } diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index e672397a91..126277d13e 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" @@ -270,8 +268,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, - &zetaShifted, - &pk.Domain[1], + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -340,7 +337,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, linearizedPolynomialCanonical, blindedLCanonical, @@ -358,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index c4ca3c638a..564431c6c3 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -205,6 +205,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) vk.S[1], }, &proof.BatchedProof, + 
zeta, hFunc, ) if err != nil { @@ -212,6 +213,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) } // Batch verify + var shiftedZeta fr.Element + shiftedZeta.Mul(&zeta, &vk.Generator) return kzg.BatchVerifyMultiPoints([]kzg.Digest{ foldedDigest, proof.Z, @@ -220,6 +223,10 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) foldedProof, proof.ZShiftedOpening, }, + []fr.Element{ + zeta, + shiftedZeta, + }, vk.KZGSRS, ) } diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go index 6210f2c40b..90ac35a7aa 100644 --- a/internal/backend/bw6-633/plonk/prove.go +++ b/internal/backend/bw6-633/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" @@ -270,8 +268,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, - &zetaShifted, - &pk.Domain[1], + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -340,7 +337,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, linearizedPolynomialCanonical, blindedLCanonical, @@ -358,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { diff --git a/internal/backend/bw6-633/plonk/verify.go b/internal/backend/bw6-633/plonk/verify.go index c15c433778..1a7651695b 100644 --- a/internal/backend/bw6-633/plonk/verify.go +++ b/internal/backend/bw6-633/plonk/verify.go @@ -205,6 +205,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness vk.S[1], }, &proof.BatchedProof, + zeta, hFunc, ) if err != nil { @@ -212,6 +213,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness } // Batch verify + var shiftedZeta fr.Element + shiftedZeta.Mul(&zeta, &vk.Generator) return kzg.BatchVerifyMultiPoints([]kzg.Digest{ foldedDigest, proof.Z, @@ -220,6 +223,10 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness foldedProof, proof.ZShiftedOpening, }, + []fr.Element{ + zeta, + shiftedZeta, + }, vk.KZGSRS, ) } diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go index e04a6c6a02..4ce2467212 100644 --- a/internal/backend/bw6-761/plonk/prove.go +++ b/internal/backend/bw6-761/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" @@ -270,8 +268,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, - &zetaShifted, - &pk.Domain[1], + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -340,7 +337,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes // Batch open the first list of polynomials proof.BatchedProof, err = 
kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, linearizedPolynomialCanonical, blindedLCanonical, @@ -358,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { diff --git a/internal/backend/bw6-761/plonk/verify.go b/internal/backend/bw6-761/plonk/verify.go index 139ca0e173..266c8859bc 100644 --- a/internal/backend/bw6-761/plonk/verify.go +++ b/internal/backend/bw6-761/plonk/verify.go @@ -205,6 +205,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness vk.S[1], }, &proof.BatchedProof, + zeta, hFunc, ) if err != nil { @@ -212,6 +213,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness } // Batch verify + var shiftedZeta fr.Element + shiftedZeta.Mul(&zeta, &vk.Generator) return kzg.BatchVerifyMultiPoints([]kzg.Digest{ foldedDigest, proof.Z, @@ -220,6 +223,10 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness foldedProof, proof.ZShiftedOpening, }, + []fr.Element{ + zeta, + shiftedZeta, + }, vk.KZGSRS, ) } diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index a5a31b8645..d56a8c8b24 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -7,7 +7,6 @@ import ( {{ template "import_fr" . }} {{ template "import_curve" . }} - {{ template "import_polynomial" . }} {{ template "import_kzg" . }} {{ template "import_fft" . }} {{ template "import_witness" . }} @@ -247,8 +246,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( blindedZCanonical, - &zetaShifted, - &pk.Domain[1], + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -317,7 +315,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, linearizedPolynomialCanonical, blindedLCanonical, @@ -335,9 +333,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.Domain[1], pk.Vk.KZGSRS, ) if err != nil { diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index f73acdb89c..09880d83ee 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -184,6 +184,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} vk.S[1], }, &proof.BatchedProof, + zeta, hFunc, ) if err != nil { @@ -191,6 +192,8 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} } // Batch verify + var shiftedZeta fr.Element + shiftedZeta.Mul(&zeta, &vk.Generator) return kzg.BatchVerifyMultiPoints([]kzg.Digest{ foldedDigest, proof.Z, @@ -199,6 +202,10 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} foldedProof, proof.ZShiftedOpening, }, + []fr.Element{ + zeta, + shiftedZeta, + }, vk.KZGSRS, ) } From 
1e600a1d8add9c3b683643efc3b4a0f82b4d6b0b Mon Sep 17 00:00:00 2001
From: Gautam Botrel
Date: Tue, 15 Feb 2022 11:29:39 -0600
Subject: [PATCH 36/37] build: update to gnark-crypto v0.6.1

---
 go.mod | 2 +-
 go.sum | 6 ++----
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/go.mod b/go.mod
index 720e643b18..deb987101e 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@ go 1.17
 
 require (
 	github.com/consensys/bavard v0.1.9
-	github.com/consensys/gnark-crypto v0.6.1-0.20220215134556-f8ab1746cc1f
+	github.com/consensys/gnark-crypto v0.6.1
 	github.com/fxamacker/cbor/v2 v2.2.0
 	github.com/leanovate/gopter v0.2.9
 	github.com/stretchr/testify v1.7.0
diff --git a/go.sum b/go.sum
index 12356e5a17..dca7119dbb 100644
--- a/go.sum
+++ b/go.sum
@@ -1,9 +1,7 @@
 github.com/consensys/bavard v0.1.9 h1:t9wg3/7Ko73yE+eKcavgMYcPMO1hinadJGlbSCdXTiM=
 github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI=
-github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8 h1:6WJWeTs2BMRKrRmGRtZ0h+uSCB395x7GgHdsLbFLndM=
-github.com/consensys/gnark-crypto v0.6.1-0.20220214162454-2cb4678775e8/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70=
-github.com/consensys/gnark-crypto v0.6.1-0.20220215134556-f8ab1746cc1f h1:CTqL+BaOO2yVv/TBlfDINOUrDamX0/3ke968mj004UE=
-github.com/consensys/gnark-crypto v0.6.1-0.20220215134556-f8ab1746cc1f/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70=
+github.com/consensys/gnark-crypto v0.6.1 h1:MuWaJyWzSw8wQUOfiZOlRwYjfweIj8dM/u2NN6m0O04=
+github.com/consensys/gnark-crypto v0.6.1/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

From 91d463f1980485707e594b36676aaad004f4dd61 Mon Sep 17 00:00:00 2001
From: Gautam Botrel
Date: Tue, 15 Feb 2022 12:41:12 -0600
Subject: [PATCH 37/37] docs: updated changelog for v0.6.4

---
 CHANGELOG.md | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 79984aa311..64d29e56cb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,11 +1,41 @@
+
+
+## [v0.6.4] - 2022-02-15
+
+### Build
+
+- update to gnark-crypto v0.6.1
+
+### Feat
+
+- Constraint system solvers (Groth16 and PlonK) now run in parallel
+
+### Fix
+
+- `api.DivUnchecked` with PlonK between 2 constants was incorrect
+
+### Perf
+
+- **EdDSA:** `std/algebra/twistededwards` takes ~2K fewer constraints (Groth16). Bandersnatch benefits from the same improvements.
+
+
+### Pull Requests
+
+- Merge pull request [#259](https://github.com/consensys/gnark/issues/259) from ConsenSys/perf-parallel-solver
+- Merge pull request [#261](https://github.com/consensys/gnark/issues/261) from ConsenSys/feat/kzg_updated
+- Merge pull request [#257](https://github.com/consensys/gnark/issues/257) from ConsenSys/perf/EdDSA
+- Merge pull request [#253](https://github.com/consensys/gnark/issues/253) from ConsenSys/feat/fft_cosets
+
 ## [v0.6.3] - 2022-02-13
 
 ### Feat
+
 - MiMC changes: api doesn't take a "seed" parameter. MiMC impl matches Ethereum one.
 
 ### Fix
+
 - fixes [#255](https://github.com/consensys/gnark/issues/255) variable visibility inheritance regression
 - counter was set with PLONK backend ID in R1CS
 - R1CS Solver was incorrectly calling a "MulByCoeff" instead of "DivByCoeff" (no impact, coeff was always 1 or -1)
@@ -13,6 +43,7 @@
 ### Pull Requests
+
 - Merge pull request [#256](https://github.com/consensys/gnark/issues/256) from ConsenSys/fix-bug-compile-visibility
 - Merge pull request [#249](https://github.com/consensys/gnark/issues/249) from ConsenSys/perf-ccs-hint
 - Merge pull request [#248](https://github.com/consensys/gnark/issues/248) from ConsenSys/perf-ccs-solver
@@ -24,9 +55,11 @@
 ## [v0.6.1] - 2022-01-28
 
 ### Build
+
 - go version dependency bumped from 1.16 to 1.17
 
 ### Feat
+
 - added witness.MarshalJSON and witness.MarshalBinary
 - added `ccs.GetSchema()` - the schema of a circuit is required for witness json (de)serialization
 - added `ccs.GetConstraints()` - returns a list of human-readable constraints
@@ -35,6 +68,7 @@
 - addition of `Cmp` in the circuit API
 
 ### Refactor
+
 - compiled.Visbility -> schema.Visibiility
 - witness.WriteSequence -> schema.WriteSequence
 - killed `ReadAndProve` and `ReadAndVerify` (plonk)
@@ -42,11 +76,13 @@
 - remove embbed struct tag for frontend.Variable fields
 
 ### Docs
+
 - **backend:** unify documentation for options
 - **frontend:** unify docs for options
 - **test:** unify documentation for options
 
 ### Pull Requests
+
 - Merge pull request [#244](https://github.com/consensys/gnark/issues/244) from ConsenSys/plonk-human-readable
 - Merge pull request [#237](https://github.com/consensys/gnark/issues/237) from ConsenSys/ccs-get-constraints
 - Merge pull request [#233](https://github.com/consensys/gnark/issues/233) from ConsenSys/feat/api_cmp