diff --git a/CHANGELOG.md b/CHANGELOG.md index 79984aa311..64d29e56cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,41 @@ + + +## [v0.6.4] - 2022-02-15 + +### Build + +- update to gnark-crpto v0.6.1 + +### Feat + +- Constraint system solvers (Groth16 and PlonK) now run in parallel + +### Fix + +- `api.DivUnchecked` with PlonK between 2 constants was incorrect + +### Perf + +- **EdDSA:** `std/algebra/twistededwards` takes ~2K less constraints (Groth16). Bandersnatch benefits from same improvments. + + +### Pull Requests + +- Merge pull request [#259](https://github.com/consensys/gnark/issues/259) from ConsenSys/perf-parallel-solver +- Merge pull request [#261](https://github.com/consensys/gnark/issues/261) from ConsenSys/feat/kzg_updated +- Merge pull request [#257](https://github.com/consensys/gnark/issues/257) from ConsenSys/perf/EdDSA +- Merge pull request [#253](https://github.com/consensys/gnark/issues/253) from ConsenSys/feat/fft_cosets + ## [v0.6.3] - 2022-02-13 ### Feat + - MiMC changes: api doesn't take a "seed" parameter. MiMC impl matches Ethereum one. ### Fix + - fixes [#255](https://github.com/consensys/gnark/issues/255) variable visibility inheritance regression - counter was set with PLONK backend ID in R1CS - R1CS Solver was incorrectly calling a "MulByCoeff" instead of "DivByCoeff" (no impact, coeff was always 1 or -1) @@ -13,6 +43,7 @@ ### Pull Requests + - Merge pull request [#256](https://github.com/consensys/gnark/issues/256) from ConsenSys/fix-bug-compile-visibility - Merge pull request [#249](https://github.com/consensys/gnark/issues/249) from ConsenSys/perf-ccs-hint - Merge pull request [#248](https://github.com/consensys/gnark/issues/248) from ConsenSys/perf-ccs-solver @@ -24,9 +55,11 @@ ## [v0.6.1] - 2022-01-28 ### Build + - go version dependency bumped from 1.16 to 1.17 ### Feat + - added witness.MarshalJSON and witness.MarshalBinary - added `ccs.GetSchema()` - the schema of a circuit is required for witness json (de)serialization - added `ccs.GetConstraints()` - returns a list of human-readable constraints @@ -35,6 +68,7 @@ - addition of `Cmp` in the circuit API ### Refactor + - compiled.Visbility -> schema.Visibiility - witness.WriteSequence -> schema.WriteSequence - killed `ReadAndProve` and `ReadAndVerify` (plonk) @@ -42,11 +76,13 @@ - remove embbed struct tag for frontend.Variable fields ### Docs + - **backend:** unify documentation for options - **frontend:** unify docs for options - **test:** unify documentation for options ### Pull Requests + - Merge pull request [#244](https://github.com/consensys/gnark/issues/244) from ConsenSys/plonk-human-readable - Merge pull request [#237](https://github.com/consensys/gnark/issues/237) from ConsenSys/ccs-get-constraints - Merge pull request [#233](https://github.com/consensys/gnark/issues/233) from ConsenSys/feat/api_cmp diff --git a/debug_test.go b/debug_test.go index 3569e3f493..51910507d2 100644 --- a/debug_test.go +++ b/debug_test.go @@ -47,16 +47,16 @@ func TestPrintln(t *testing.T) { expected.WriteString("debug_test.go:27 26 42\n") expected.WriteString("debug_test.go:29 bits 1\n") expected.WriteString("debug_test.go:30 circuit {A: 2, B: 11}\n") - expected.WriteString("debug_test.go:34 m \n") + expected.WriteString("debug_test.go:34 m .*\n") { trace, _ := getGroth16Trace(&circuit, &witness) - assert.Equal(expected.String(), trace) + assert.Regexp(expected.String(), trace) } { trace, _ := getPlonkTrace(&circuit, &witness) - assert.Equal(expected.String(), trace) + assert.Regexp(expected.String(), trace) } } diff --git a/examples/rollup/account.go b/examples/rollup/account.go index f381a968fd..ad6b92b48c 100644 --- a/examples/rollup/account.go +++ b/examples/rollup/account.go @@ -25,7 +25,7 @@ import ( var ( // SizeAccount byte size of a serialized account (5*32bytes) - // index || nonce || balance || pubkeyX || pubkeyY, each chunk is 32 bytes + // index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 32 bytes SizeAccount = 160 ) @@ -48,7 +48,7 @@ func (ac *Account) Reset() { // Serialize serializes the account as a concatenation of 5 chunks of 256 bits // one chunk per field (pubKey has 2 chunks), except index and nonce that are concatenated in a single 256 bits chunk -// index || nonce || balance || pubkeyX || pubkeyY, each chunk is 256 bits +// index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 256 bits func (ac *Account) Serialize() []byte { //var buffer bytes.Buffer diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go index 2907f7fc19..83f554c2b9 100644 --- a/examples/rollup/circuit.go +++ b/examples/rollup/circuit.go @@ -159,7 +159,7 @@ func (circuit *Circuit) Define(api frontend.API) error { // verifySignatureTransfer ensures that the signature of the transfer is valid func verifyTransferSignature(api frontend.API, t TransferConstraints, hFunc mimc.MiMC) error { - // the signature is on h(nonce || amount || senderpubKey (x&y) || receiverPubkey(x&y)) + // the signature is on h(nonce ∥ amount ∥ senderpubKey (x&y) ∥ receiverPubkey(x&y)) hFunc.Write(t.Nonce, t.Amount, t.SenderPubKey.A.X, t.SenderPubKey.A.Y, t.ReceiverPubKey.A.X, t.ReceiverPubKey.A.Y) htransfer := hFunc.Sum() diff --git a/examples/rollup/operator.go b/examples/rollup/operator.go index a02dc0c07d..56c2304182 100644 --- a/examples/rollup/operator.go +++ b/examples/rollup/operator.go @@ -46,8 +46,8 @@ func NewQueue(batchSize int) Queue { // Operator represents a rollup operator type Operator struct { - State []byte // list of accounts: index || nonce || balance || pubkeyX || pubkeyY, each chunk is 256 bits - HashState []byte // Hashed version of the state, each chunk is 256bits: ... || H(index || nonce || balance || pubkeyX || pubkeyY)) || ... + State []byte // list of accounts: index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 256 bits + HashState []byte // Hashed version of the state, each chunk is 256bits: ... ∥ H(index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY)) ∥ ... AccountMap map[string]uint64 // hashmap of all available accounts (the key is the account.pubkey.X), the value is the index of the account in the state nbAccounts int // number of accounts managed by this operator h hash.Hash // hash function used to build the Merkle Tree @@ -178,7 +178,7 @@ func (o *Operator) updateState(t Transfer, numTransfer int) error { o.witnesses.Transfers[numTransfer].Signature.S = t.signature.S[:] // verifying the signature. The msg is the hash (o.h) of the transfer - // nonce || amount || senderpubKey(x&y) || receiverPubkey(x&y) + // nonce ∥ amount ∥ senderpubKey(x&y) ∥ receiverPubkey(x&y) resSig, err := t.Verify(o.h) if err != nil { return err diff --git a/examples/rollup/transfer.go b/examples/rollup/transfer.go index e12776cbf8..7507b357d3 100644 --- a/examples/rollup/transfer.go +++ b/examples/rollup/transfer.go @@ -52,7 +52,7 @@ func (t *Transfer) Sign(priv eddsa.PrivateKey, h hash.Hash) (eddsa.Signature, er //var frNonce, msg fr.Element var frNonce fr.Element - // serializing transfer. The signature is on h(nonce || amount || senderpubKey (x&y) || receiverPubkey(x&y)) + // serializing transfer. The signature is on h(nonce ∥ amount ∥ senderpubKey (x&y) ∥ receiverPubkey(x&y)) // (each pubkey consist of 2 chunks of 256bits) frNonce.SetUint64(t.nonce) b := frNonce.Bytes() @@ -91,7 +91,7 @@ func (t *Transfer) Verify(h hash.Hash) (bool, error) { var frNonce fr.Element // serializing transfer. The msg to sign is - // nonce || amount || senderpubKey(x&y) || receiverPubkey(x&y) + // nonce ∥ amount ∥ senderpubKey(x&y) ∥ receiverPubkey(x&y) // (each pubkey consist of 2 chunks of 256bits) frNonce.SetUint64(t.nonce) b := frNonce.Bytes() diff --git a/frontend/api.go b/frontend/api.go index bc34ea6202..3e7fde87c0 100644 --- a/frontend/api.go +++ b/frontend/api.go @@ -101,7 +101,7 @@ type API interface { // AssertIsDifferent fails if i1 == i2 AssertIsDifferent(i1, i2 Variable) - // AssertIsBoolean fails if v != 0 || v != 1 + // AssertIsBoolean fails if v != 0 ∥ v != 1 AssertIsBoolean(i1 Variable) // AssertIsLessOrEqual fails if v > bound diff --git a/frontend/compile.go b/frontend/compile.go index 0fff033589..4565063d0f 100644 --- a/frontend/compile.go +++ b/frontend/compile.go @@ -42,8 +42,8 @@ type NewBuilder func(ecc.ID) (Builder, error) // from the declarative code // // 3. finally, it converts that to a ConstraintSystem. -// if zkpID == backend.GROTH16 --> R1CS -// if zkpID == backend.PLONK --> SparseR1CS +// if zkpID == backend.GROTH16 → R1CS +// if zkpID == backend.PLONK → SparseR1CS // // initialCapacity is an optional parameter that reserves memory in slices // it should be set to the estimated number of constraints in the circuit, if known. diff --git a/frontend/cs/plonk/api.go b/frontend/cs/plonk/api.go index 30e0d4c289..117dc7744a 100644 --- a/frontend/cs/plonk/api.go +++ b/frontend/cs/plonk/api.go @@ -120,7 +120,7 @@ func (system *sparseR1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variab q := system.CurveID.Info().Fr.Modulus() return r.ModInverse(&r, q). Mul(&l, &r). - Mod(&l, q) + Mod(&r, q) } if system.IsConstant(i2) { c := utils.FromInterface(i2) diff --git a/frontend/cs/plonk/assertions.go b/frontend/cs/plonk/assertions.go index 5f9cebdd33..efe58e8627 100644 --- a/frontend/cs/plonk/assertions.go +++ b/frontend/cs/plonk/assertions.go @@ -63,7 +63,7 @@ func (system *sparseR1CS) AssertIsDifferent(i1, i2 frontend.Variable) { system.Inverse(system.Sub(i1, i2)) } -// AssertIsBoolean fails if v != 0 || v != 1 +// AssertIsBoolean fails if v != 0 ∥ v != 1 func (system *sparseR1CS) AssertIsBoolean(i1 frontend.Variable) { if system.IsConstant(i1) { c := utils.FromInterface(i1) @@ -125,7 +125,7 @@ func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term l := system.Sub(1, t, aBits[i]) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 - // --> this is a boolean constraint + // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too system.markBoolean(aBits[i].(compiled.Term)) // this does not create a constraint @@ -172,7 +172,7 @@ func (system *sparseR1CS) mustBeLessOrEqCst(a compiled.Term, bound big.Int) { } p := make([]frontend.Variable, nbBits+1) - // p[i] == 1 --> a[j] == c[j] for all j >= i + // p[i] == 1 → a[j] == c[j] for all j ⩾ i p[nbBits] = 1 for i := nbBits - 1; i >= t; i-- { diff --git a/frontend/cs/plonk/conversion.go b/frontend/cs/plonk/conversion.go index b82fd4b7a0..1f2e8d564e 100644 --- a/frontend/cs/plonk/conversion.go +++ b/frontend/cs/plonk/conversion.go @@ -116,6 +116,9 @@ HINTLOOP: } res.MHints = shiftedMap + // build levels + res.Levels = buildLevels(res) + switch cs.CurveID { case ecc.BLS12_377: return bls12377r1cs.NewSparseR1CS(res, cs.Coeffs), nil @@ -138,3 +141,103 @@ HINTLOOP: func (cs *sparseR1CS) SetSchema(s *schema.Schema) { cs.Schema = s } + +func buildLevels(ccs compiled.SparseR1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processTerm(c.L, cID) + b.processTerm(c.R, cID) + b.processTerm(c.O, cID) + + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.SparseR1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processTerm(t compiled.Term, cID int) { + wID := t.WireID() + if wID < b.nbInputs { + // it's a input, we ignore it + return + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := b.mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... + // we add a dependency, check if we need to increment our current level + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + return + } + + // check if it's a hint and mark all the output wires + if h, ok := b.ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.Variable: + for _, tt := range t.LinExp { + b.processTerm(tt, cID) + } + case compiled.LinearExpression: + for _, tt := range t { + b.processTerm(tt, cID) + } + case compiled.Term: + b.processTerm(t, cID) + } + } + + for _, hwid := range h.Wires { + b.mWireToNode[hwid] = cID + } + + return + } + + // mark this wire solved by current node + b.mWireToNode[wID] = cID + +} diff --git a/frontend/cs/plonk/sparse_r1cs.go b/frontend/cs/plonk/sparse_r1cs.go index 959d28c55d..2e4beac548 100644 --- a/frontend/cs/plonk/sparse_r1cs.go +++ b/frontend/cs/plonk/sparse_r1cs.go @@ -130,7 +130,7 @@ func (system *sparseR1CS) NewSecretVariable(name string) frontend.Variable { // reduces redundancy in linear expression // It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public||secret||internal||unset +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are sorted from lowest ID to highest ID func (system *sparseR1CS) reduce(l compiled.LinearExpression) compiled.LinearExpression { @@ -268,7 +268,7 @@ func (system *sparseR1CS) CheckVariables() error { sbb.WriteString(strconv.Itoa(cptHints)) sbb.WriteString(" unconstrained hints") sbb.WriteByte('\n') - // TODO we may add more debug info here --> idea, in NewHint, take the debug stack, and store in the hint map some + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some // debugInfo to find where a hint was declared (and not constrained) } return errors.New(sbb.String()) diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/assertions.go index 22b3873a42..aba87440c0 100644 --- a/frontend/cs/r1cs/assertions.go +++ b/frontend/cs/r1cs/assertions.go @@ -41,7 +41,7 @@ func (system *r1CS) AssertIsDifferent(i1, i2 frontend.Variable) { system.Inverse(system.Sub(i1, i2)) } -// AssertIsBoolean adds an assertion in the constraint system (v == 0 || v == 1) +// AssertIsBoolean adds an assertion in the constraint system (v == 0 ∥ v == 1) func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { vars, _ := system.toVariables(i1) @@ -69,7 +69,7 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) { system.addConstraint(newR1C(v, _v, o), debug) } -// AssertIsLessOrEqual adds assertion in constraint system (v <= bound) +// AssertIsLessOrEqual adds assertion in constraint system (v ⩽ bound) // // bound can be a constant or a Variable // @@ -120,7 +120,7 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) { l = system.Sub(l, t, aBits[i]) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 - // --> this is a boolean constraint + // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too system.markBoolean(aBits[i].(compiled.Variable)) // this does not create a constraint @@ -158,7 +158,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) { } p := make([]frontend.Variable, nbBits+1) - // p[i] == 1 --> a[j] == c[j] for all j >= i + // p[i] == 1 → a[j] == c[j] for all j ⩾ i p[nbBits] = system.constant(1) for i := nbBits - 1; i >= t; i-- { diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go index 5b042efc9b..c71d520e9d 100644 --- a/frontend/cs/r1cs/conversion.go +++ b/frontend/cs/r1cs/conversion.go @@ -122,13 +122,15 @@ HINTLOOP: } } for i := 0; i < len(cs.DebugInfo); i++ { - for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ { _, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack() res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility)) } } + // build levels + res.Levels = buildLevels(res) + switch cs.CurveID { case ecc.BLS12_377: return bls12377r1cs.NewR1CS(res, cs.Coeffs), nil @@ -150,3 +152,99 @@ HINTLOOP: func (cs *r1CS) SetSchema(s *schema.Schema) { cs.Schema = s } + +func buildLevels(ccs compiled.R1CS) [][]int { + + b := levelBuilder{ + mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire + nodeLevels: make([]int, len(ccs.Constraints)), // level of a node + mLevels: make(map[int]int), // level counts + ccs: ccs, + nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables, + } + + // for each constraint, we're going to find its direct dependencies + // that is, wires (solved by previous constraints) on which it depends + // each of these dependencies is tagged with a level + // current constraint will be tagged with max(level) + 1 + for cID, c := range ccs.Constraints { + + b.nodeLevel = 0 + + b.processLE(c.L.LinExp, cID) + b.processLE(c.R.LinExp, cID) + b.processLE(c.O.LinExp, cID) + b.nodeLevels[cID] = b.nodeLevel + b.mLevels[b.nodeLevel]++ + + } + + levels := make([][]int, len(b.mLevels)) + for i := 0; i < len(levels); i++ { + // allocate memory + levels[i] = make([]int, 0, b.mLevels[i]) + } + + for n, l := range b.nodeLevels { + levels[l] = append(levels[l], n) + } + + return levels +} + +type levelBuilder struct { + ccs compiled.R1CS + nbInputs int + + mWireToNode map[int]int // at which node we resolved which wire + nodeLevels []int // level per node + mLevels map[int]int // number of constraint per level + + nodeLevel int // current level +} + +func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) { + + for _, t := range l { + wID := t.WireID() + if wID < b.nbInputs { + // it's a input, we ignore it + continue + } + + // if we know a which constraint solves this wire, then it's a dependency + n, ok := b.mWireToNode[wID] + if ok { + if n != cID { // can happen with hints... + // we add a dependency, check if we need to increment our current level + if b.nodeLevels[n] >= b.nodeLevel { + b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it + } + } + continue + } + + // check if it's a hint and mark all the output wires + if h, ok := b.ccs.MHints[wID]; ok { + + for _, in := range h.Inputs { + switch t := in.(type) { + case compiled.Variable: + b.processLE(t.LinExp, cID) + case compiled.LinearExpression: + b.processLE(t, cID) + case compiled.Term: + b.processLE(compiled.LinearExpression{t}, cID) + } + } + + for _, hwid := range h.Wires { + b.mWireToNode[hwid] = cID + } + continue + } + + // mark this wire solved by current node + b.mWireToNode[wID] = cID + } +} diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go index def4e35901..eb22b890b8 100644 --- a/frontend/cs/r1cs/r1cs.go +++ b/frontend/cs/r1cs/r1cs.go @@ -146,7 +146,7 @@ func (system *r1CS) one() compiled.Variable { // reduces redundancy in linear expression // It factorizes Variable that appears multiple times with != coeff Ids -// To ensure the determinism in the compile process, Variables are stored as public||secret||internal||unset +// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset // for each visibility, the Variables are sorted from lowest ID to highest ID func (system *r1CS) reduce(l compiled.Variable) compiled.Variable { // ensure our linear expression is sorted, by visibility and by Variable ID @@ -311,7 +311,7 @@ func (system *r1CS) CheckVariables() error { sbb.WriteString(strconv.Itoa(cptHints)) sbb.WriteString(" unconstrained hints") sbb.WriteByte('\n') - // TODO we may add more debug info here --> idea, in NewHint, take the debug stack, and store in the hint map some + // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some // debugInfo to find where a hint was declared (and not constrained) } return errors.New(sbb.String()) diff --git a/go.mod b/go.mod index 8ae200dc61..deb987101e 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/consensys/gnark go 1.17 require ( - github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a - github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 + github.com/consensys/bavard v0.1.9 + github.com/consensys/gnark-crypto v0.6.1 github.com/fxamacker/cbor/v2 v2.2.0 github.com/leanovate/gopter v0.2.9 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 6b59c42b75..dca7119dbb 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a h1:AEpwbXTjBGKoqxuQ6QAcBMEuK0+PtajQj0wJkhTnSd0= -github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 h1:SPKwScbSTdl2p+QvJulHumGa2+5FO6RPh857TCPxda0= -github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU= +github.com/consensys/bavard v0.1.9 h1:t9wg3/7Ko73yE+eKcavgMYcPMO1hinadJGlbSCdXTiM= +github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/gnark-crypto v0.6.1 h1:MuWaJyWzSw8wQUOfiZOlRwYjfweIj8dM/u2NN6m0O04= +github.com/consensys/gnark-crypto v0.6.1/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go index 680941e51e..1313d715e4 100644 --- a/internal/backend/bls12-377/cs/r1cs.go +++ b/internal/backend/bls12-377/cs/r1cs.go @@ -19,11 +19,12 @@ package cs import ( "errors" "fmt" + "github.com/fxamacker/cbor/v2" "io" "math/big" + "runtime" "strings" - - "github.com/fxamacker/cbor/v2" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -32,6 +33,7 @@ import ( "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark-crypto/ecc" + "math" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( return make([]fr.Element, nbWires), err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables) } @@ -86,61 +83,146 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) ( solution.solved[0] = true // ONE_WIRE solution.values[0].SetOne() - copy(solution.values[1:], witness) // TODO factorize + copy(solution.values[1:], witness) for i := 0; i < len(witness); i++ { solution.solved[i+1] = true } // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + 1 + solution.nbSolved += uint64(len(witness) + 1) // now that we know all inputs are set, defer log printing once all solution.values are computed // (or sooner, if a constraint is not satisfied) defer solution.printLogs(opt.LoggerOut, cs.Logs) - // check if there is an inconsistant constraint - var check fr.Element - var solved bool + if err := cs.parallelSolve(a, b, c, &solution); err != nil { + return solution.values, err + } + + // sanity check; ensure all wires are marked as "instantiated" + if !solution.isValid() { + panic("solver didn't instantiate all wires") + } + + return solution.values, nil +} + +func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels // for each constraint // we are guaranteed that each R1C contains at most one unsolved wire // first we solve the unsolved wire (if any) // then we check that the constraint is valid // if a[i] * b[i] != c[i]; it means the constraint is not satisfied - for i := 0; i < len(cs.Constraints); i++ { - // solve the constraint, this will compute the missing wire of the gate - solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution) - if err != nil { - if dID, ok := cs.MDebug[i]; ok { - debugInfoStr := solution.logValue(cs.DebugInfo[dID]) - return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr) + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + } + wg.Done() } - return solution.values, err - } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() - if solved { - // a[i] * b[i] == c[i], since we just computed it. + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil { + if err == errUnsatisfiedConstraint { + if dID, ok := cs.MDebug[int(i)]; ok { + err = errors.New(solution.logValue(cs.DebugInfo[dID])) + } else { + err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) + } + } + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + } continue } - // ensure a[i] * b[i] == c[i] - check.Mul(&a[i], &b[i]) - if !check.Equal(&c[i]) { - errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String()) - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise @@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). + Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go index e52a55459b..8c5d2c087b 100644 --- a/internal/backend/bls12-377/cs/r1cs_sparse.go +++ b/internal/backend/bls12-377/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) @@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go index 2cf9bc935e..6911e72ab7 100644 --- a/internal/backend/bls12-377/cs/solution.go +++ b/internal/backend/bls12-377/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -32,14 +33,15 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) + // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bls12-377/groth16/marshal_test.go b/internal/backend/bls12-377/groth16/marshal_test.go index 0d00cba75e..901d6f885f 100644 --- a/internal/backend/bls12-377/groth16/marshal_test.go +++ b/internal/backend/bls12-377/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bls12-377/groth16/prove.go b/internal/backend/bls12-377/groth16/prove.go index 16a1b7ac32..634c4b58e7 100644 --- a/internal/backend/bls12-377/groth16/prove.go +++ b/internal/backend/bls12-377/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go index 8db0c6ac8c..21fe78713c 100644 --- a/internal/backend/bls12-377/groth16/setup.go +++ b/internal/backend/bls12-377/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bls12-377/plonk/marshal.go b/internal/backend/bls12-377/plonk/marshal.go index 411a2be9e7..2bb46eabd3 100644 --- a/internal/backend/bls12-377/plonk/marshal.go +++ b/internal/backend/bls12-377/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ @@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bls12-377/plonk/marshal_test.go b/internal/backend/bls12-377/plonk/marshal_test.go index 04933096c5..f4bea8c379 100644 --- a/internal/backend/bls12-377/plonk/marshal_test.go +++ b/internal/backend/bls12-377/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go index 7c8cb49ce7..74f54f4897 100644 --- a/internal/backend/bls12-377/plonk/prove.go +++ b/internal/backend/bls12-377/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" @@ -43,6 +41,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.Domain[0]) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn return nil, err } + // Fiat Shamir this + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. - if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn if err := <-chConstraintOrdering; err != nil { return nil, err } + <-chConstraintInd + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, - &zetaShifted, - &pk.DomainH, + blindedZCanonical, + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.DomainH, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { return nil, err @@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, } return res, nil + } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.Domain[0].Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.Domain[1].Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.Domain[1].Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) - - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bls12-377/plonk/setup.go b/internal/backend/bls12-377/plonk/setup.go index 774e1be88a..259f76e6f1 100644 --- a/internal/backend/bls12-377/plonk/setup.go +++ b/internal/backend/bls12-377/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" "github.com/consensys/gnark/internal/backend/bls12-377/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,21 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element - // Domains used for the FFTs - DomainNum, DomainH fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +71,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls12-377/plonk/verify.go b/internal/backend/bls12-377/plonk/verify.go index 17ecfd94ee..e827aef72b 100644 --- a/internal/backend/bls12-377/plonk/verify.go +++ b/internal/backend/bls12-377/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise @@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). + Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go index 6e47e344fb..106e6eb0eb 100644 --- a/internal/backend/bls12-381/cs/r1cs_sparse.go +++ b/internal/backend/bls12-381/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) @@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go index 7126962be9..9d630c8153 100644 --- a/internal/backend/bls12-381/cs/solution.go +++ b/internal/backend/bls12-381/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -32,14 +33,15 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) + // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bls12-381/groth16/marshal_test.go b/internal/backend/bls12-381/groth16/marshal_test.go index 38c1e8b038..990e3ee6b1 100644 --- a/internal/backend/bls12-381/groth16/marshal_test.go +++ b/internal/backend/bls12-381/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bls12-381/groth16/prove.go b/internal/backend/bls12-381/groth16/prove.go index fe52aafb2d..a8e7767d4e 100644 --- a/internal/backend/bls12-381/groth16/prove.go +++ b/internal/backend/bls12-381/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go index 1daf2c2f69..d481f8ad0d 100644 --- a/internal/backend/bls12-381/groth16/setup.go +++ b/internal/backend/bls12-381/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bls12-381/plonk/marshal.go b/internal/backend/bls12-381/plonk/marshal.go index 4e3945c054..d03d1be5fa 100644 --- a/internal/backend/bls12-381/plonk/marshal.go +++ b/internal/backend/bls12-381/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ @@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bls12-381/plonk/marshal_test.go b/internal/backend/bls12-381/plonk/marshal_test.go index 9076e2280c..e30d108e8b 100644 --- a/internal/backend/bls12-381/plonk/marshal_test.go +++ b/internal/backend/bls12-381/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go index 5f9dadb7bb..8d5a625122 100644 --- a/internal/backend/bls12-381/plonk/prove.go +++ b/internal/backend/bls12-381/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls12-381" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" @@ -43,6 +41,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.Domain[0]) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn return nil, err } + // Fiat Shamir this + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. - if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn if err := <-chConstraintOrdering; err != nil { return nil, err } + <-chConstraintInd + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, - &zetaShifted, - &pk.DomainH, + blindedZCanonical, + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.DomainH, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { return nil, err @@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, } return res, nil + } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.Domain[0].Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.Domain[1].Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.Domain[1].Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) - - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bls12-381/plonk/setup.go b/internal/backend/bls12-381/plonk/setup.go index 27dfd6868d..823cc25d8b 100644 --- a/internal/backend/bls12-381/plonk/setup.go +++ b/internal/backend/bls12-381/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" "github.com/consensys/gnark/internal/backend/bls12-381/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,21 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element - // Domains used for the FFTs - DomainNum, DomainH fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +71,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls12-381/plonk/verify.go b/internal/backend/bls12-381/plonk/verify.go index 53283ee4a9..620bc5e097 100644 --- a/internal/backend/bls12-381/plonk/verify.go +++ b/internal/backend/bls12-381/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise @@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). + Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go index 667d5d230e..9ecd27fb26 100644 --- a/internal/backend/bls24-315/cs/r1cs_sparse.go +++ b/internal/backend/bls24-315/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) @@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go index 272f31af54..e215ac343a 100644 --- a/internal/backend/bls24-315/cs/solution.go +++ b/internal/backend/bls24-315/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -32,14 +33,15 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) + // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bls24-315/groth16/marshal_test.go b/internal/backend/bls24-315/groth16/marshal_test.go index aa2ceda799..06bf1ec9de 100644 --- a/internal/backend/bls24-315/groth16/marshal_test.go +++ b/internal/backend/bls24-315/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bls24-315/groth16/prove.go b/internal/backend/bls24-315/groth16/prove.go index 606f9e07f9..fbc67eb021 100644 --- a/internal/backend/bls24-315/groth16/prove.go +++ b/internal/backend/bls24-315/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go index e34b8d752b..596e7786fc 100644 --- a/internal/backend/bls24-315/groth16/setup.go +++ b/internal/backend/bls24-315/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bls24-315/plonk/marshal.go b/internal/backend/bls24-315/plonk/marshal.go index 83d53ffa1f..05f24f4f36 100644 --- a/internal/backend/bls24-315/plonk/marshal.go +++ b/internal/backend/bls24-315/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ @@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bls24-315/plonk/marshal_test.go b/internal/backend/bls24-315/plonk/marshal_test.go index c24928ccf8..99763b07e3 100644 --- a/internal/backend/bls24-315/plonk/marshal_test.go +++ b/internal/backend/bls24-315/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go index 6f951143ba..904a785599 100644 --- a/internal/backend/bls24-315/plonk/prove.go +++ b/internal/backend/bls24-315/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bls24-315" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" @@ -43,6 +41,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.Domain[0]) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn return nil, err } + // Fiat Shamir this + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. - if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn if err := <-chConstraintOrdering; err != nil { return nil, err } + <-chConstraintInd + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, - &zetaShifted, - &pk.DomainH, + blindedZCanonical, + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.DomainH, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { return nil, err @@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, } return res, nil + } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.Domain[0].Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.Domain[1].Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.Domain[1].Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) - - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bls24-315/plonk/setup.go b/internal/backend/bls24-315/plonk/setup.go index 8327805f3a..c0fd45c8b2 100644 --- a/internal/backend/bls24-315/plonk/setup.go +++ b/internal/backend/bls24-315/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" "github.com/consensys/gnark/internal/backend/bls24-315/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,21 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element - // Domains used for the FFTs - DomainNum, DomainH fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +71,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bls24-315/plonk/verify.go b/internal/backend/bls24-315/plonk/verify.go index e8d9fa52b5..ae08037e77 100644 --- a/internal/backend/bls24-315/plonk/verify.go +++ b/internal/backend/bls24-315/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise @@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). + Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go index b92f7d2273..8eab1f901f 100644 --- a/internal/backend/bn254/cs/r1cs_sparse.go +++ b/internal/backend/bn254/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) @@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go index 323c472df4..46cb2eb6af 100644 --- a/internal/backend/bn254/cs/solution.go +++ b/internal/backend/bn254/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -32,14 +33,15 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) + // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bn254/groth16/marshal_test.go b/internal/backend/bn254/groth16/marshal_test.go index d8f8c38496..2db9e45415 100644 --- a/internal/backend/bn254/groth16/marshal_test.go +++ b/internal/backend/bn254/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bn254/groth16/prove.go b/internal/backend/bn254/groth16/prove.go index 87bb23c383..f4a929f426 100644 --- a/internal/backend/bn254/groth16/prove.go +++ b/internal/backend/bn254/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go index 95cccddf80..334461b25b 100644 --- a/internal/backend/bn254/groth16/setup.go +++ b/internal/backend/bn254/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bn254/plonk/marshal.go b/internal/backend/bn254/plonk/marshal.go index 6012cb114b..e4c4d7059f 100644 --- a/internal/backend/bn254/plonk/marshal.go +++ b/internal/backend/bn254/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ @@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bn254/plonk/marshal_test.go b/internal/backend/bn254/plonk/marshal_test.go index f17f8ca756..ceec7305b0 100644 --- a/internal/backend/bn254/plonk/marshal_test.go +++ b/internal/backend/bn254/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go index 25b24887ed..126277d13e 100644 --- a/internal/backend/bn254/plonk/prove.go +++ b/internal/backend/bn254/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" @@ -43,6 +41,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.Domain[0]) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, return nil, err } + // Fiat Shamir this + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. - if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, if err := <-chConstraintOrdering; err != nil { return nil, err } + <-chConstraintInd + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, - &zetaShifted, - &pk.DomainH, + blindedZCanonical, + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.DomainH, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness, } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { return nil, err @@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, } return res, nil + } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.Domain[0].Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.Domain[1].Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.Domain[1].Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) - - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go index b7b8d41869..421f035bd0 100644 --- a/internal/backend/bn254/plonk/setup.go +++ b/internal/backend/bn254/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" "github.com/consensys/gnark/internal/backend/bn254/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,21 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element - // Domains used for the FFTs - DomainNum, DomainH fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +71,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go index 0c5261a51c..564431c6c3 100644 --- a/internal/backend/bn254/plonk/verify.go +++ b/internal/backend/bn254/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness) zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise @@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). + Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bw6-633/cs/r1cs_sparse.go b/internal/backend/bw6-633/cs/r1cs_sparse.go index e6154b4227..7cc2ccfad0 100644 --- a/internal/backend/bw6-633/cs/r1cs_sparse.go +++ b/internal/backend/bw6-633/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) @@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go index 3f7daa0122..4b611e7f07 100644 --- a/internal/backend/bw6-633/cs/solution.go +++ b/internal/backend/bw6-633/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -32,14 +33,15 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) + // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bw6-633/groth16/marshal_test.go b/internal/backend/bw6-633/groth16/marshal_test.go index 1fc47f35de..b2a3a63d4b 100644 --- a/internal/backend/bw6-633/groth16/marshal_test.go +++ b/internal/backend/bw6-633/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bw6-633/groth16/prove.go b/internal/backend/bw6-633/groth16/prove.go index 45c4688959..b72cf7b41f 100644 --- a/internal/backend/bw6-633/groth16/prove.go +++ b/internal/backend/bw6-633/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bw6-633/groth16/setup.go b/internal/backend/bw6-633/groth16/setup.go index 17caeecb6f..b26489abbc 100644 --- a/internal/backend/bw6-633/groth16/setup.go +++ b/internal/backend/bw6-633/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bw6-633/plonk/marshal.go b/internal/backend/bw6-633/plonk/marshal.go index 756fa46bf9..8f2ad728cc 100644 --- a/internal/backend/bw6-633/plonk/marshal.go +++ b/internal/backend/bw6-633/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ @@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bw6-633/plonk/marshal_test.go b/internal/backend/bw6-633/plonk/marshal_test.go index ebf373a463..7ace49aef7 100644 --- a/internal/backend/bw6-633/plonk/marshal_test.go +++ b/internal/backend/bw6-633/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go index e633795585..90ac35a7aa 100644 --- a/internal/backend/bw6-633/plonk/prove.go +++ b/internal/backend/bw6-633/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-633" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" @@ -43,6 +41,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.Domain[0]) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes return nil, err } + // Fiat Shamir this + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. - if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes if err := <-chConstraintOrdering; err != nil { return nil, err } + <-chConstraintInd + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, - &zetaShifted, - &pk.DomainH, + blindedZCanonical, + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.DomainH, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { return nil, err @@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, } return res, nil + } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.Domain[0].Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.Domain[1].Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.Domain[1].Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) - - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bw6-633/plonk/setup.go b/internal/backend/bw6-633/plonk/setup.go index 06cfe86c1b..f7f6fa581c 100644 --- a/internal/backend/bw6-633/plonk/setup.go +++ b/internal/backend/bw6-633/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" "github.com/consensys/gnark/internal/backend/bw6-633/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,21 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element - // Domains used for the FFTs - DomainNum, DomainH fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +71,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bw6-633/plonk/verify.go b/internal/backend/bw6-633/plonk/verify.go index 49e6684f2c..1a7651695b 100644 --- a/internal/backend/bw6-633/plonk/verify.go +++ b/internal/backend/bw6-633/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise @@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). + Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go index fbb30462da..8cc16bcf4c 100644 --- a/internal/backend/bw6-761/cs/r1cs_sparse.go +++ b/internal/backend/bw6-761/cs/r1cs_sparse.go @@ -21,9 +21,12 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/fxamacker/cbor/v2" "io" + "math" "math/big" "os" + "runtime" "strings" + "sync" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/witness" @@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f return solution.values, err } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) @@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 // else returns the wire position (L -> 0, R -> 1, O -> 2) diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go index 560d95f7d5..fb1e1a19bd 100644 --- a/internal/backend/bw6-761/cs/solution.go +++ b/internal/backend/bw6-761/cs/solution.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/frontend/schema" @@ -32,14 +33,15 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) + // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i := 0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i := 0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/backend/bw6-761/groth16/marshal_test.go b/internal/backend/bw6-761/groth16/marshal_test.go index 73054efb74..bccf077d51 100644 --- a/internal/backend/bw6-761/groth16/marshal_test.go +++ b/internal/backend/bw6-761/groth16/marshal_test.go @@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/backend/bw6-761/groth16/prove.go b/internal/backend/bw6-761/groth16/prove.go index 8848e71c05..663bc522b1 100644 --- a/internal/backend/bw6-761/groth16/prove.go +++ b/internal/backend/bw6-761/groth16/prove.go @@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go index 137b4df3fe..0e100d3319 100644 --- a/internal/backend/bw6-761/groth16/setup.go +++ b/internal/backend/bw6-761/groth16/setup.go @@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/backend/bw6-761/plonk/marshal.go b/internal/backend/bw6-761/plonk/marshal.go index 6daaa3130f..fe81cc4e0d 100644 --- a/internal/backend/bw6-761/plonk/marshal.go +++ b/internal/backend/bw6-761/plonk/marshal.go @@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { return } n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { return } n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } @@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) + n2, err := pk.Domain[0].ReadFrom(r) n += n2 if err != nil { return n, err } - n2, err = pk.DomainH.ReadFrom(r) + n2, err = pk.Domain[1].ReadFrom(r) n += n2 if err != nil { return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ @@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } @@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], diff --git a/internal/backend/bw6-761/plonk/marshal_test.go b/internal/backend/bw6-761/plonk/marshal_test.go index 4a18f7651f..2b47f5e50a 100644 --- a/internal/backend/bw6-761/plonk/marshal_test.go +++ b/internal/backend/bw6-761/plonk/marshal_test.go @@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen @@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) { // random pk var pk ProvingKey pk.Vk = &vk - pk.DomainNum = *fft.NewDomain(42, 3, false) - pk.DomainH = *fft.NewDomain(4*42, 1, false) - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Domain[0] = *fft.NewDomain(42) + pk.Domain[1] = *fft.NewDomain(4 * 42) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < 12; i++ { pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) { pk.Qo[i].SetUint64(42) } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) pk.Permutation[0] = -12 pk.Permutation[len(pk.Permutation)-1] = 8888 @@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) { var vk VerifyingKey vk.Size = 42 vk.SizeInv = fr.One() - vk.Shifter[1].SetUint64(12) _, _, g1gen, _ := curve.Generators() vk.S[0] = g1gen diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go index 3f2981607b..4ce2467212 100644 --- a/internal/backend/bw6-761/plonk/prove.go +++ b/internal/backend/bw6-761/plonk/prove.go @@ -27,8 +27,6 @@ import ( curve "github.com/consensys/gnark-crypto/ecc/bw6-761" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" @@ -43,6 +41,7 @@ import ( ) type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} @@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.Domain[0]) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes return nil, err } + // Fiat Shamir this + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { chZ <- err close(chZ) @@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. - if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes if err := <-chConstraintOrdering; err != nil { return nil, err } + <-chConstraintInd + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, - &zetaShifted, - &pk.DomainH, + blindedZCanonical, + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.DomainH, pk.Vk.KZGSRS, ) if err != nil { @@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2) + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) chDone := make(chan error, 2) go func() { var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) chDone <- err @@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { return @@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { return nil, err @@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, } return res, nil + } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.Domain[0].Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.Domain[1].Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.Domain[1].Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) - - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/backend/bw6-761/plonk/setup.go b/internal/backend/bw6-761/plonk/setup.go index 5a0741b31e..946d153c3f 100644 --- a/internal/backend/bw6-761/plonk/setup.go +++ b/internal/backend/bw6-761/plonk/setup.go @@ -21,7 +21,6 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" "github.com/consensys/gnark/internal/backend/bw6-761/cs" kzgg "github.com/consensys/gnark-crypto/kzg" @@ -40,18 +39,21 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element - // Domains used for the FFTs - DomainNum, DomainH fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -69,13 +71,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized diff --git a/internal/backend/bw6-761/plonk/verify.go b/internal/backend/bw6-761/plonk/verify.go index 9cba6efb9a..266c8859bc 100644 --- a/internal/backend/bw6-761/plonk/verify.go +++ b/internal/backend/bw6-761/plonk/verify.go @@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] } - } + + // wait for the level to be done + wg.Wait() - // sanity check; ensure all wires are marked as "instantiated" - if !solution.isValid() { - panic("solver didn't instantiate all wires") + if len(chError) > 0 { + return <-chError + } } - return solution.values, nil + return nil } // IsSolved returns nil if given witness solves the R1CS and error otherwise @@ -167,7 +255,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) { // returns false, nil if there was no wire to solve // returns true, nil if exactly one wire was solved. In that case, it is redundant to check that // the constraint is satisfied later. -func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a,b,c fr.Element, err error) { +func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.Element) error { // the index of the non zero entry shows if L, R or O has an uninstantiated wire // the content is the ID of the wire non instantiated @@ -204,28 +292,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool return nil } - if err = processLExp(r.L.LinExp, &a, 1); err != nil { - return + if err := processLExp(r.L.LinExp, a, 1); err != nil { + return err } - if err = processLExp(r.R.LinExp, &b, 2); err != nil { - return + if err := processLExp(r.R.LinExp, b, 2); err != nil { + return err } - if err = processLExp(r.O.LinExp, &c, 3); err != nil { - return + if err := processLExp(r.O.LinExp, c, 3); err != nil { + return err } if loc == 0 { // there is nothing to solve, may happen if we have an assertion // (ie a constraints that doesn't yield any output) // or if we solved the unsolved wires with hint functions - return + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } + return nil } // we compute the wire value and instantiate it - solved = true - vID := termToCompute.WireID() + wID := termToCompute.WireID() // solver result var wire fr.Element @@ -234,36 +325,42 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool switch loc { case 1: if !b.IsZero() { - wire.Div(&c, &b). - Sub(&wire, &a) - a.Add(&a, &wire) + wire.Div(c, b). + Sub(&wire, a) + a.Add(a, &wire) } else { // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 2: if !a.IsZero() { - wire.Div(&c, &a). - Sub(&wire, &b) - b.Add(&b, &wire) + wire.Div(c, a). + Sub(&wire, b) + b.Add(b, &wire) } else { - // we didn't actually ensure that a * b == c - solved = false + var check fr.Element + if !check.Mul(a, b).Equal(c) { + return errUnsatisfiedConstraint + } } case 3: - wire.Mul(&a, &b). - Sub(&wire, &c) + wire.Mul(a, b). + Sub(&wire, c) - c.Add(&c, &wire) + c.Add(c, &wire) } // wire is the term (coeff * value) // but in the solution we want to store the value only // note that in gnark frontend, coeff here is always 1 or -1 cs.divByCoeff(&wire, termToCompute) - solution.set(vID, wire) + solution.set(wID, wire) + - return + return nil } // GetConstraints return a list of constraint formatted as L⋅R == O diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl index c572bfa1d4..dbc9915d2c 100644 --- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl +++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl @@ -6,6 +6,9 @@ import ( "github.com/consensys/gnark-crypto/ecc" "strings" "os" + "sync" + "runtime" + "math" "github.com/consensys/gnark/internal/backend/ioutils" "github.com/consensys/gnark/internal/backend/compiled" @@ -71,11 +74,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } - defer func() { - // release memory - solution.tmpHintsIO = nil - }() - // solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs copy(solution.values, witness) for i := 0; i < len(witness); i++ { @@ -84,7 +82,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f // keep track of the number of wire instantiations we do, for a sanity check to ensure // we instantiated all wires - solution.nbSolved += len(witness) + solution.nbSolved += uint64(len(witness)) // defer log printing once all solution.values are computed defer solution.printLogs(opt.LoggerOut, cs.Logs) @@ -95,19 +93,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f coefficientsNegInv[i].Neg(&coefficientsNegInv[i]) } - - // loop through the constraints to solve the variables - for i := 0; i < len(cs.Constraints); i++ { - if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil { - return solution.values, fmt.Errorf("constraint %d: %w", i, err) - } - if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil { - errMsg := err.Error() - if dID, ok := cs.MDebug[i]; ok { - errMsg = solution.logValue(cs.DebugInfo[dID]) - } - return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) - } + if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil { + return solution.values, err } // sanity check; ensure all wires are marked as "instantiated" @@ -120,6 +107,122 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f } +func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error { + // minWorkPerCPU is the minimum target number of constraint a task should hold + // in other words, if a level has less than minWorkPerCPU, it will not be parallelized and executed + // sequentially without sync. + const minWorkPerCPU = 50.0 + + // cs.Levels has a list of levels, where all constraints in a level l(n) are independent + // and may only have dependencies on previous levels + + var wg sync.WaitGroup + chTasks := make(chan []int, runtime.NumCPU()) + chError := make(chan error, runtime.NumCPU()) + + // start a worker pool + // each worker wait on chTasks + // a task is a slice of constraint indexes to be solved + for i := 0; i < runtime.NumCPU(); i++ { + go func() { + for t := range chTasks { + for _, i := range t { + // for each constraint in the task, solve it. + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + wg.Done() + return + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + wg.Done() + return + } + } + wg.Done() + } + }() + } + + // clean up pool go routines + defer func() { + close(chTasks) + close(chError) + }() + + // for each level, we push the tasks + for _, level := range cs.Levels { + + // max CPU to use + maxCPU := float64(len(level)) / minWorkPerCPU + + if maxCPU <= 1.0 { + // we do it sequentially + for _, i := range level { + if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil { + return fmt.Errorf("constraint #%d is not satisfied: %w", i, err) + } + if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil { + errMsg := err.Error() + if dID, ok := cs.MDebug[i]; ok { + errMsg = solution.logValue(cs.DebugInfo[dID]) + } + return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg) + } + } + continue + } + + // number of tasks for this level is set to num cpus + // but if we don't have enough work for all our CPUS, it can be lower. + nbTasks := runtime.NumCPU() + maxTasks := int(math.Ceil(maxCPU)) + if nbTasks > maxTasks { + nbTasks = maxTasks + } + nbIterationsPerCpus := len(level) / nbTasks + + // more CPUs than tasks: a CPU will work on exactly one iteration + // note: this depends on minWorkPerCPU constant + if nbIterationsPerCpus < 1 { + nbIterationsPerCpus = 1 + nbTasks = len(level) + } + + + extraTasks := len(level) - (nbTasks * nbIterationsPerCpus) + extraTasksOffset := 0 + + for i := 0; i < nbTasks; i++ { + wg.Add(1) + _start := i*nbIterationsPerCpus + extraTasksOffset + _end := _start + nbIterationsPerCpus + if extraTasks > 0 { + _end++ + extraTasks-- + extraTasksOffset++ + } + // since we're never pushing more than num CPU tasks + // we will never be blocked here + chTasks <- level[_start:_end] + } + + // wait for the level to be done + wg.Wait() + + if len(chError) > 0 { + return <-chError + } + } + + return nil +} + + // computeHints computes wires associated with a hint function, if any // if there is no remaining wire to solve, returns -1 diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl index 43aaca6ee7..dba02f7cb6 100644 --- a/internal/generator/backend/template/representations/solution.go.tmpl +++ b/internal/generator/backend/template/representations/solution.go.tmpl @@ -3,6 +3,7 @@ import ( "errors" "fmt" "math/big" + "sync/atomic" "github.com/consensys/gnark/backend/hint" "github.com/consensys/gnark/internal/backend/compiled" @@ -13,14 +14,15 @@ import ( {{ template "import_curve" . }} ) +var errUnsatisfiedConstraint = errors.New("unsatisfied") + // solution represents elements needed to compute // a solution to a R1CS or SparseR1CS type solution struct { values, coefficients []fr.Element solved []bool - nbSolved int + nbSolved uint64 mHintsFunctions map[hint.ID]hint.Function - tmpHintsIO []*big.Int } func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) { @@ -30,7 +32,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E coefficients: coefficients, solved: make([]bool, nbWires), mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)), - tmpHintsIO: make([]*big.Int, 0), } for _, h := range hintFunctions { @@ -49,11 +50,12 @@ func (s *solution) set(id int, value fr.Element) { } s.values[id] = value s.solved[id] = true - s.nbSolved++ + atomic.AddUint64(&s.nbSolved, 1) + // s.nbSolved++ } func (s *solution) isValid() bool { - return s.nbSolved == len(s.values) + return int(s.nbSolved) == len(s.values) } // computeTerm computes coef*variable @@ -128,15 +130,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error { // tmp IO big int memory nbInputs := len(h.Inputs) nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs)) - m := len(s.tmpHintsIO) - if m < (nbInputs + nbOutputs) { - s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) - for i := m; i < len(s.tmpHintsIO); i++ { - s.tmpHintsIO[i] = big.NewInt(0) - } + // m := len(s.tmpHintsIO) + // if m < (nbInputs + nbOutputs) { + // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...) + // for i := m; i < len(s.tmpHintsIO); i++ { + // s.tmpHintsIO[i] = big.NewInt(0) + // } + // } + inputs := make([]*big.Int, nbInputs) + outputs := make([]*big.Int, nbOutputs) + for i :=0; i < nbInputs; i++ { + inputs[i] = big.NewInt(0) + } + for i :=0; i < nbOutputs; i++ { + outputs[i] = big.NewInt(0) } - inputs := s.tmpHintsIO[:nbInputs] - outputs := s.tmpHintsIO[nbInputs:nbInputs+nbOutputs] q := fr.Modulus() diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl index 58f9cf6dd7..485fcf437a 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl @@ -259,18 +259,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) - domain.FFTInverse(a, fft.DIF, 0) - domain.FFTInverse(b, fft.DIF, 0) - domain.FFTInverse(c, fft.DIF, 0) + domain.FFTInverse(a, fft.DIF) + domain.FFTInverse(b, fft.DIF) + domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, 1) - domain.FFT(b, fft.DIT, 1) - domain.FFT(c, fft.DIT, 1) + domain.FFT(a, fft.DIT, true) + domain.FFT(b, fft.DIT, true) + domain.FFT(c, fft.DIT, true) - var minusTwoInv fr.Element - minusTwoInv.SetUint64(2) - minusTwoInv.Neg(&minusTwoInv). - Inverse(&minusTwoInv) + var den, one fr.Element + one.SetOne() + den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) // h = ifft_coset(ca o cb - cc) // reusing a to avoid unecessary memalloc @@ -278,12 +278,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). Sub(&a[i], &c[i]). - Mul(&a[i], &minusTwoInv) + Mul(&a[i], &den) } }) // ifft_coset - domain.FFTInverse(a, fft.DIF, 1) + domain.FFTInverse(a, fft.DIF, true) utils.Parallelize(len(a), func(start, end int) { for i := start; i < end; i++ { diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl index 867e2aee83..14933f3931 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl @@ -74,7 +74,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error { nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables // Setting group for fft - domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true) + domain := fft.NewDomain(uint64(len(r1cs.Constraints))) // samples toxic waste toxicWaste, err := sampleToxicWaste() @@ -394,7 +394,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error { nbConstraints := len(r1cs.Constraints) // Setting group for fft - domain := fft.NewDomain(uint64(nbConstraints), 1, true) + domain := fft.NewDomain(uint64(nbConstraints)) // count number of infinity points we would have had we a normal setup // in pk.G1.A, pk.G1.B, and pk.G2.B diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl index 67ed2b1e93..e5ad79d370 100644 --- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl @@ -166,7 +166,7 @@ func TestProvingKeySerialization(t *testing.T) { var pk, pkCompressed, pkRaw ProvingKey // create a random pk - domain := fft.NewDomain(8, 1, true) + domain := fft.NewDomain(8) pk.Domain = *domain nbWires := 6 diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl index ebf5f20ae1..5e03ff6fb2 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl @@ -5,11 +5,11 @@ import ( "errors" ) -// WriteTo writes binary encoding of Proof to w +// WriteTo writes binary encoding of Proof to w func (proof *Proof) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{} { + toEncode := []interface{}{ &proof.LRO[0], &proof.LRO[1], &proof.LRO[2], @@ -27,10 +27,10 @@ func (proof *Proof) WriteTo(w io.Writer) (int64, error) { n, err := proof.BatchedProof.WriteTo(w) if err != nil { - return n+enc.BytesWritten(), err + return n + enc.BytesWritten(), err } n2, err := proof.ZShiftedOpening.WriteTo(w) - + return n + n2 + enc.BytesWritten(), err } @@ -55,13 +55,11 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) { n, err := proof.BatchedProof.ReadFrom(r) if err != nil { - return n+dec.BytesRead(), err + return n + dec.BytesRead(), err } n2, err := proof.ZShiftedOpening.ReadFrom(r) - return n+n2+dec.BytesRead(), err -} - - + return n + n2 + dec.BytesRead(), err +} // WriteTo writes binary encoding of ProvingKey to w func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { @@ -72,40 +70,37 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { } // fft domains - n2, err := pk.DomainNum.WriteTo(w) + n2, err := pk.Domain[0].WriteTo(w) if err != nil { - return + return } - n+=n2 + n += n2 - n2, err = pk.DomainH.WriteTo(w) + n2, err = pk.Domain[1].WriteTo(w) if err != nil { - return + return } - n+=n2 + n += n2 - // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality) - if len(pk.Permutation) != (3*int(pk.DomainNum.Cardinality)) { + // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality) + if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) { return n, errors.New("invalid permutation size, expected 3*domain cardinality") } enc := curve.NewEncoder(w) - // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't - // encode the size (nor does it convert from Montgomery to Regular form) + // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't + // encode the size (nor does it convert from Montgomery to Regular form) // so we explicitly transmit []fr.Element - toEncode := []interface{} { - ([]fr.Element)(pk.Ql), + toEncode := []interface{}{ + ([]fr.Element)(pk.Ql), ([]fr.Element)(pk.Qr), ([]fr.Element)(pk.Qm), ([]fr.Element)(pk.Qo), ([]fr.Element)(pk.CQk), ([]fr.Element)(pk.LQk), - ([]fr.Element)(pk.LS1), - ([]fr.Element)(pk.LS2), - ([]fr.Element)(pk.LS3), - ([]fr.Element)(pk.CS1), - ([]fr.Element)(pk.CS2), - ([]fr.Element)(pk.CS3), + ([]fr.Element)(pk.S1Canonical), + ([]fr.Element)(pk.S2Canonical), + ([]fr.Element)(pk.S3Canonical), pk.Permutation, } @@ -118,7 +113,6 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) { return n + enc.BytesWritten(), nil } - // ReadFrom reads from binary representation in r into ProvingKey func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { pk.Vk = &VerifyingKey{} @@ -127,59 +121,53 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) { return n, err } - n2, err := pk.DomainNum.ReadFrom(r) - n+=n2 - if err != nil { - return n, err + n2, err := pk.Domain[0].ReadFrom(r) + n += n2 + if err != nil { + return n, err } - n2, err = pk.DomainH.ReadFrom(r) - n+=n2 - if err != nil { - return n, err + n2, err = pk.Domain[1].ReadFrom(r) + n += n2 + if err != nil { + return n, err } - pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality) + pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality) dec := curve.NewDecoder(r) toDecode := []interface{}{ - (*[]fr.Element)(&pk.Ql), + (*[]fr.Element)(&pk.Ql), (*[]fr.Element)(&pk.Qr), (*[]fr.Element)(&pk.Qm), (*[]fr.Element)(&pk.Qo), (*[]fr.Element)(&pk.CQk), (*[]fr.Element)(&pk.LQk), - (*[]fr.Element)(&pk.LS1), - (*[]fr.Element)(&pk.LS2), - (*[]fr.Element)(&pk.LS3), - (*[]fr.Element)(&pk.CS1), - (*[]fr.Element)(&pk.CS2), - (*[]fr.Element)(&pk.CS3), + (*[]fr.Element)(&pk.S1Canonical), + (*[]fr.Element)(&pk.S2Canonical), + (*[]fr.Element)(&pk.S3Canonical), &pk.Permutation, } for _, v := range toDecode { if err := dec.Decode(v); err != nil { - return n +dec.BytesRead(), err + return n + dec.BytesRead(), err } } - return n +dec.BytesRead(), nil + return n + dec.BytesRead(), nil } - // WriteTo writes binary encoding of VerifyingKey to w func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { enc := curve.NewEncoder(w) - toEncode := []interface{} { - vk.Size, + toEncode := []interface{}{ + vk.Size, &vk.SizeInv, &vk.Generator, vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -196,7 +184,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { } } - return enc.BytesWritten(), nil } @@ -204,12 +191,10 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) { func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) toDecode := []interface{}{ - &vk.Size, + &vk.Size, &vk.SizeInv, &vk.Generator, &vk.NbPublicVariables, - &vk.Shifter[0], - &vk.Shifter[1], &vk.S[0], &vk.S[1], &vk.S[2], @@ -227,4 +212,4 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) { } return dec.BytesRead(), nil -} \ No newline at end of file +} \ No newline at end of file diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl index a48210d151..d56a8c8b24 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl @@ -7,7 +7,6 @@ import ( {{ template "import_fr" . }} {{ template "import_curve" . }} - {{ template "import_polynomial" . }} {{ template "import_kzg" . }} {{ template "import_fft" . }} {{ template "import_witness" . }} @@ -18,7 +17,9 @@ import ( "github.com/consensys/gnark-crypto/fiat-shamir" ) + type Proof struct { + // Commitments to the solution vectors LRO [3]kzg.Digest @@ -42,14 +43,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } hFunc := sha256.New() // create a transcript manager to apply Fiat Shamir - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // result proof := &Proof{} // compute the constraint system solution var solution []fr.Element - var err error + var err error if solution, err = spr.Solve(fullWitness, opt); err != nil { if !opt.Force { return nil, err @@ -65,17 +66,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } } // query l, r, o in Lagrange basis, not blinded - ll, lr, lo := computeLRO(spr, pk, solution) + evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution) // save ll, lr, lo, and make a copy of them in canonical basis. // note that we allocate more capacity to reuse for blinded polynomials - bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum) + blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + &pk.Domain[0]) if err != nil { return nil, err } // compute kzg commitments of bcl, bcr and bco - if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -85,18 +90,28 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } return nil, err } + // Fiat Shamir this + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return nil, err + } + // compute Z, the permutation accumulator polynomial, in canonical basis // ll, lr, lo are NOT blinded - var bz polynomial.Polynomial + var blindedZCanonical []fr.Element chZ := make(chan error, 1) var alpha fr.Element go func() { - var err error - bz, err = computeBlindedZ(ll, lr, lo, pk, gamma) + var err error + blindedZCanonical, err = computeBlindedZCanonical( + evaluationLDomainSmall, + evaluationRDomainSmall, + evaluationODomainSmall, + pk, beta, gamma) if err != nil { - chZ <- err + chZ <- err close(chZ) - return + return } // commit to the blinded version of z @@ -104,7 +119,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } // this may add additional arithmetic operations, but with smaller tasks // we ensure that this commitment is well parallelized, without having a "unbalanced task" making // the rest of the code wait too long. - if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { + if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil { chZ <- err close(chZ) return @@ -117,40 +132,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } }() // evaluation of the blinded versions of l, r, o and bz - // on the odd cosets of (Z/8mZ)/(Z/mZ) - var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial + // on the coset of the big domain + var ( + evaluationBlindedLDomainBigBitReversed []fr.Element + evaluationBlindedRDomainBigBitReversed []fr.Element + evaluationBlindedODomainBigBitReversed []fr.Element + evaluationBlindedZDomainBigBitReversed []fr.Element + ) chEvalBL := make(chan struct{}, 1) chEvalBR := make(chan struct{}, 1) chEvalBO := make(chan struct{}, 1) go func() { - evalBL = evaluateHDomain(bcl, &pk.DomainH) + evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1]) close(chEvalBL) }() go func() { - evalBR = evaluateHDomain(bcr, &pk.DomainH) + evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1]) close(chEvalBR) }() go func() { - evalBO = evaluateHDomain(bco, &pk.DomainH) + evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1]) close(chEvalBO) }() - var constraintsInd, constraintsOrdering polynomial.Polynomial + var constraintsInd, constraintsOrdering []fr.Element chConstraintInd := make(chan struct{}, 1) go func() { // compute qk in canonical basis, completed with the public inputs - qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality) - copy(qk, fullWitness[:spr.NbPublicVariables]) - copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) - pk.DomainNum.FFTInverse(qk, fft.DIF, 0) - fft.BitReverse(qk) - - // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ) - // --> uses the blinded version of l, r, o + qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality) + copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables]) + copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:]) + pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF) + fft.BitReverse(qkCompletedCanonical) + + // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain + // → uses the blinded version of l, r, o <-chEvalBL <-chEvalBR <-chEvalBO - constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk) + constraintsInd = evaluateConstraintsDomainBigBitReversed( + pk, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + qkCompletedCanonical) close(chConstraintInd) }() @@ -160,13 +185,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } chConstraintOrdering <- err return } - evalBZ = evaluateHDomain(bz, &pk.DomainH) - // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ) + + evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1]) + // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain // evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o. <-chEvalBL <-chEvalBR <-chEvalBO - constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma) + constraintsOrdering = evaluateOrderingDomainBigBitReversed( + pk, + evaluationBlindedZDomainBigBitReversed, + evaluationBlindedLDomainBigBitReversed, + evaluationBlindedRDomainBigBitReversed, + evaluationBlindedODomainBigBitReversed, + beta, + gamma) chConstraintOrdering <- nil close(chConstraintOrdering) }() @@ -174,12 +207,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } if err := <-chConstraintOrdering; err != nil { return nil, err } + <-chConstraintInd + // compute h in canonical form - h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha) + h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha) // compute kzg commitments of h1, h2 and h3 - if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { + if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil { return nil, err } @@ -194,15 +229,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } var wgZetaEvals sync.WaitGroup wgZetaEvals.Add(3) go func() { - blzeta = bcl.Eval(&zeta) + blzeta = eval(blindedLCanonical, zeta) wgZetaEvals.Done() }() go func() { - brzeta = bcr.Eval(&zeta) + brzeta = eval(blindedRCanonical, zeta) wgZetaEvals.Done() }() go func() { - bozeta = bco.Eval(&zeta) + bozeta = eval(blindedOCanonical, zeta) wgZetaEvals.Done() }() @@ -210,9 +245,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } var zetaShifted fr.Element zetaShifted.Mul(&zeta, &pk.Vk.Generator) proof.ZShiftedOpening, err = kzg.Open( - bz, - &zetaShifted, - &pk.DomainH, + blindedZCanonical, + zetaShifted, pk.Vk.KZGSRS, ) if err != nil { @@ -223,53 +257,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } bzuzeta := proof.ZShiftedOpening.ClaimedValue var ( - linearizedPolynomial polynomial.Polynomial - linearizedPolynomialDigest curve.G1Affine - errLPoly error + linearizedPolynomialCanonical []fr.Element + linearizedPolynomialDigest curve.G1Affine + errLPoly error ) chLpoly := make(chan struct{}, 1) go func() { // compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k) wgZetaEvals.Wait() - linearizedPolynomial = computeLinearizedPolynomial( + linearizedPolynomialCanonical = computeLinearizedPolynomial( blzeta, brzeta, bozeta, alpha, + beta, gamma, zeta, bzuzeta, - bz, + blindedZCanonical, pk, ) // TODO this commitment is only necessary to derive the challenge, we should // be able to avoid doing it and get the challenge in another way - linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS) + linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS) close(chLpoly) }() - // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3) + // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3) var bZetaPowerm, bSize big.Int - bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) + bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1) var zetaPowerm fr.Element zetaPowerm.Exp(zeta, &bSize) zetaPowerm.ToBigIntRegular(&bZetaPowerm) foldedHDigest := proof.H[2] foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) - foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3) - foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) - foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1) + foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1) - // foldedH = h1 + zeta*h2 + zeta**2*h3 + // foldedH = h1 + ζ*h2 + ζ²*h3 foldedH := h3 utils.Parallelize(len(foldedH), func(start, end int) { for i := start; i < end; i++ { - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3 - foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3 - foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1) - foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3 + foldedH[i].Add(&foldedH[i], &h2[i]) // ζ^{m+2)*h3+h2 + foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺² + foldedH[i].Add(&foldedH[i], &h1[i]) // ζ^{2(m+2)*h3+ζᵐ⁺²*h2 + h1 } }) @@ -280,14 +315,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } // Batch open the first list of polynomials proof.BatchedProof, err = kzg.BatchOpenSinglePoint( - []polynomial.Polynomial{ + [][]fr.Element{ foldedH, - linearizedPolynomial, - bcl, - bcr, - bco, - pk.CS1, - pk.CS2, + linearizedPolynomialCanonical, + blindedLCanonical, + blindedRCanonical, + blindedOCanonical, + pk.S1Canonical, + pk.S2Canonical, }, []kzg.Digest{ foldedHDigest, @@ -298,9 +333,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } pk.Vk.S[0], pk.Vk.S[1], }, - &zeta, + zeta, hFunc, - &pk.DomainH, pk.Vk.KZGSRS, ) if err != nil { @@ -311,8 +345,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID } } +// eval evaluates c at p +func eval(c []fr.Element, p fr.Element) fr.Element { + var r fr.Element + for i := len(c) - 1; i >= 0; i-- { + r.Mul(&r, &p).Add(&r, &c[i]) + } + return r +} + // fills proof.LRO with kzg commits of bcl, bcr and bco -func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -338,7 +381,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS return err1 } -func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error { +func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error { n := runtime.NumCPU() / 2 var err0, err1, err2 error chCommit0 := make(chan struct{}, 1) @@ -364,44 +407,44 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err return err1 } -// computeBlindedLRO l, r, o in canonical basis with blinding -func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) { - +// computeBlindedLROCanonical l, r, o in canonical basis with blinding +func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) { + // note that bcl, bcr and bco reuses cl, cr and co memory - cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2) - cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2) - co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2) - + cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2) + chDone := make(chan error, 2) go func() { - var err error + var err error copy(cl, ll) - domain.FFTInverse(cl, fft.DIF, 0) + domain.FFTInverse(cl, fft.DIF) fft.BitReverse(cl) bcl, err = blindPoly(cl, domain.Cardinality, 1) - chDone <- err + chDone <- err }() go func() { var err error copy(cr, lr) - domain.FFTInverse(cr, fft.DIF, 0) + domain.FFTInverse(cr, fft.DIF) fft.BitReverse(cr) bcr, err = blindPoly(cr, domain.Cardinality, 1) - chDone <- err + chDone <- err }() copy(co, lo) - domain.FFTInverse(co, fft.DIF, 0) + domain.FFTInverse(co, fft.DIF) fft.BitReverse(co) if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil { - return + return } - err = <-chDone + err = <-chDone if err != nil { return } - err = <-chDone - return + err = <-chDone + return } @@ -412,9 +455,9 @@ func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, // * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1) // // WARNING: -// pre condition degree(cp) <= rou + bo -// pre condition cap(cp) >= int(totalDegree + 1) -func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) { +// pre condition degree(cp) ⩽ rou + bo +// pre condition cap(cp) ⩾ int(totalDegree + 1) +func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) { // degree of the blinded polynomial is max(rou+order, cp.Degree) totalDegree := rou + bo @@ -423,10 +466,10 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res := cp[:totalDegree+1] // random polynomial - blindingPoly := make(polynomial.Polynomial, bo+1) + blindingPoly := make([]fr.Element, bo+1) for i := uint64(0); i < bo+1; i++ { if _, err := blindingPoly[i].SetRandom(); err != nil { - return nil, err + return nil, err } } @@ -436,16 +479,17 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, res[rou+i].Add(&res[rou+i], &blindingPoly[i]) } - return res, nil + return res, nil + } -// computeLRO extracts the solution l, r, o, and returns it in lagrange form. +// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form. // solution = [ public | secret | internal ] -func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) { +func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) { - s := int(pk.DomainNum.Cardinality) + s := int(pk.Domain[0].Cardinality) - var l, r, o polynomial.Polynomial + var l, r, o []fr.Element l = make([]fr.Element, s) r = make([]fr.Element, s) o = make([]fr.Element, s) @@ -478,47 +522,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly // // * Z of degree n (domainNum.Cardinality) // * Z(1)=1 -// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma) -// * for i>0: Z(u**i) = Pi_{k0: Z(gⁱ) = Π_{k z**i+1 - u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 } }) @@ -528,43 +568,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele Mul(&z[i], &gInv[i]) } - pk.DomainNum.FFTInverse(z, fft.DIF, 0) + pk.Domain[0].FFTInverse(z, fft.DIF) fft.BitReverse(z) - return blindPoly(z, pk.DomainNum.Cardinality, 2) + return blindPoly(z, pk.Domain[0].Cardinality, 2) } -// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on -// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on +// the big domain coset. // // * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets // * qk is the completed version of qk, in canonical version -func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { - var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial +func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element { + var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element var wg sync.WaitGroup wg.Add(4) go func() { - evalQl = evaluateHDomain(pk.Ql, &pk.DomainH) + evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1]) wg.Done() }() go func() { - evalQr = evaluateHDomain(pk.Qr, &pk.DomainH) + evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1]) wg.Done() }() go func() { - evalQm = evaluateHDomain(pk.Qm, &pk.DomainH) + evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1]) wg.Done() }() go func() { - evalQo = evaluateHDomain(pk.Qo, &pk.DomainH) + evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1]) wg.Done() }() - evalQk = evaluateHDomain(qk, &pk.DomainH) + evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1]) wg.Wait() - // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets - // of (Z/8mZ)/(Z/mZ) + + // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain utils.Parallelize(len(evalQk), func(start, end int) { var t0, t1 fr.Element for i := start; i < end; i++ { @@ -584,211 +624,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr. return evalQk } -// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ) -func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) { - - id = make([]fr.Element, pk.DomainH.Cardinality) - - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { - var acc fr.Element - acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start))) - for i := start; i < end; i++ { - id[i].Mul(&acc, &pk.DomainH.FinerGenerator) - acc.Mul(&acc, &pk.DomainH.Generator) - } - }) - - return id -} - -// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd -// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions. +// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd +// cosets of the big domain. // -// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets -// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets +// * z evaluation of the blinded permutation accumulator polynomial on odd cosets +// * l, r, o evaluation of the blinded solution vectors on odd cosets // * gamma randomization -func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial { +func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element { - // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ) - evalID := evalIDCosets(pk) + nbElmts := int(pk.Domain[1].Cardinality) - // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ) - var wg sync.WaitGroup - wg.Add(2) - var evalS1, evalS2, evalS3 polynomial.Polynomial - go func() { - evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH) - wg.Done() - }() - go func() { - evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH) - wg.Done() - }() - evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH) - wg.Wait() + // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X))+s₂(gⁱ)*β+γ)*(o(X))+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ) + // on the big domain (coset). + res := make([]fr.Element, pk.Domain[1].Cardinality) - // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ) - res := evalS1 // re use allocated memory for evalS1 - s := uint64(len(evalZ)) - nn := uint64(64 - bits.TrailingZeros64(uint64(s))) + nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts))) // needed to shift evalZ - toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality + toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality) + + var cosetShift, cosetShiftSquare fr.Element + cosetShift.Set(&pk.Vk.CosetShift) + cosetShiftSquare.Square(&pk.Vk.CosetShift) + + utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) { + + var evaluationIDBigDomain fr.Element + evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))). + Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen) - utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) { var f [3]fr.Element var g [3]fr.Element - var eID fr.Element for i := start; i < end; i++ { - // here we want to left shift evalZ by domainH/domainNum - // however, evalZ is permuted - // we take the non permuted index - // compute the corresponding shift position - // permute it again - irev := bits.Reverse64(uint64(i)) >> nn - eID = evalID[irev] - - shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn - //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn + _i := bits.Reverse64(uint64(i)) >> nn + _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn - f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma - f[1].Mul(&eID, &pk.Vk.Shifter[0]) - f[2].Mul(&eID, &pk.Vk.Shifter[1]) - f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma - f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma + // in what follows gⁱ is understood as the generator of the chosen coset of domainBig + f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ + f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ + f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ - g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma - g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma - g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma + g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ))+s1(gⁱ)*β+γ + g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ))+s2(gⁱ)*β+γ + g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ))+s3(gⁱ)*β+γ - f[0].Mul(&f[0], &f[1]). - Mul(&f[0], &f[2]). - Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) + f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) + g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - g[0].Mul(&g[0], &g[1]). - Mul(&g[0], &g[2]). - Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma) + res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ))+s₁(gⁱ)*β+γ)*(r(gⁱ))+s₂(gⁱ)*β+γ)*(o(gⁱ))+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+g^i*β+γ)*(r(g^i)+u*g^i*β+γ)*(o(g^i)+u²*g^i*β+γ) - res[i].Sub(&g[0], &f[0]) + evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g } }) return res } -// evaluateHDomain evaluates poly (canonical form) of degree m> nn - // h[i].Mul(&h[i], &_u[irev%4]) - h[i].Mul(&h[i], &_u[irev%toShift]) + + _i := bits.Reverse64(i) >> nn + + t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain + h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t). + Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]). + Mul(&h[_i], &alpha). + Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]). + Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio]) } }) // put h in canonical form. h is of degree 3*(n+1)+2. // using fft.DIT put h revert bit reverse - pk.DomainH.FFTInverse(h, fft.DIT, 1) - // fmt.Println("h:") - // for i := 0; i < len(h); i++ { - // fmt.Printf("%s\n", h[i].String()) - // } - // fmt.Println("") + pk.Domain[1].FFTInverse(h, fft.DIT, true) // degree of hi is n+2 because of the blinding - h1 := h[:pk.DomainNum.Cardinality+2] - h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)] - h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)] + h1 := h[:pk.Domain[0].Cardinality+2] + h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)] + h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)] return h1, h2, h3 @@ -796,78 +779,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom // computeLinearizedPolynomial computes the linearized polynomial in canonical basis. // The purpose is to commit and open all in one ql, qr, qm, qo, qk. -// * a, b, c are the evaluation of l, r, o at zeta -// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z +// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta +// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z // * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk. -func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial { +// +// The Linearized polynomial is: +// +// α²*L₁(ζ)*Z(X) +// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)) +// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X) +func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element { // first part: individual constraints var rl fr.Element - rl.Mul(&r, &l) + rl.Mul(&rZeta, &lZeta) - // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) + // second part: + // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ) var s1, s2 fr.Element chS1 := make(chan struct{}, 1) go func() { - s1 = pk.CS1.Eval(&zeta) - s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma) + s1 = eval(pk.S1Canonical, zeta) // s1(ζ) + s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ) close(chS1) }() - t := pk.CS2.Eval(&zeta) - t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma) + tmp := eval(pk.S2Canonical, zeta) // s2(ζ) + tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ) <-chS1 - s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma) - Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta) - - s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma) - t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma) - t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma) - s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma) - - // third part L1(zeta)*alpha**2**Z - var lagrange, one, den, frNbElmt fr.Element + s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(β)+γ)*(r(ζ)+β*s2(β)+γ)*β*Z(μζ) + + var uzeta, uuzeta fr.Element + uzeta.Mul(&zeta, &pk.Vk.CosetShift) + uuzeta.Mul(&uzeta, &pk.Vk.CosetShift) + + s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ) + tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ) + tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ) + s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + // third part L₁(ζ)*α²*Z + var lagrangeZeta, one, den, frNbElmt fr.Element one.SetOne() - nbElmt := int64(pk.DomainNum.Cardinality) - lagrange.Set(&zeta). - Exp(lagrange, big.NewInt(nbElmt)). - Sub(&lagrange, &one) + nbElmt := int64(pk.Domain[0].Cardinality) + lagrangeZeta.Set(&zeta). + Exp(lagrangeZeta, big.NewInt(nbElmt)). + Sub(&lagrangeZeta, &one) frNbElmt.SetUint64(uint64(nbElmt)) den.Sub(&zeta, &one). - Mul(&den, &frNbElmt). Inverse(&den) - lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1) - Mul(&lagrange, &alpha). - Mul(&lagrange, &alpha) // alpha**2*L_0 + lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ⁻¹)/(ζ-1) + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &alpha). + Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ) - linPol := z.Clone() + linPol := make([]fr.Element, len(blindedZCanonical)) + copy(linPol, blindedZCanonical) utils.Parallelize(len(linPol), func(start, end int) { + var t0, t1 fr.Element + for i := start; i < end; i++ { - linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) - if i < len(pk.CS3) { - t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X) + + linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ) + + if i < len(pk.S3Canonical) { + + t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X) + linPol[i].Add(&linPol[i], &t0) } - linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) ) + linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)) if i < len(pk.Qm) { - t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm - t0.Mul(&pk.Ql[i], &l) + + t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X) + t0.Mul(&pk.Ql[i], &lZeta) t0.Add(&t0, &t1) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X) - t0.Mul(&pk.Qr[i], &r) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + t0.Mul(&pk.Qr[i], &rZeta) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X) - t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i]) - linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk + t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i]) + linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X) } - t0.Mul(&z[i], &lagrange) + t0.Mul(&blindedZCanonical[i], &lagrangeZeta) linPol[i].Add(&linPol[i], &t0) // finish the computation } }) diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index 0c326d73a8..25458538e9 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -1,6 +1,5 @@ import ( "errors" - {{- template "import_polynomial" . }} {{- template "import_kzg" . }} {{- template "import_fr" . }} {{- template "import_fft" . }} @@ -22,18 +21,21 @@ type ProvingKey struct { Vk *VerifyingKey // qr,ql,qm,qo (in canonical basis). - Ql, Qr, Qm, Qo polynomial.Polynomial + Ql, Qr, Qm, Qo []fr.Element // LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs. // Storing LQk in Lagrange basis saves a fft... - CQk, LQk polynomial.Polynomial + CQk, LQk []fr.Element - // Domains used for the FFTs - DomainNum, DomainH fft.Domain + // Domains used for the FFTs. + // Domain[0] = small Domain + // Domain[1] = big Domain + Domain [2]fft.Domain + // Domain[0], Domain[1] fft.Domain - // s1, s2, s3 (L=Lagrange basis, C=canonical basis) - LS1, LS2, LS3 polynomial.Polynomial - CS1, CS2, CS3 polynomial.Polynomial + // Permutation polynomials + EvaluationPermutationBigDomainBitReversed []fr.Element + S1Canonical, S2Canonical, S3Canonical []fr.Element // position -> permuted position (position in [0,3*sizeSystem-1]) Permutation []int64 @@ -51,13 +53,12 @@ type VerifyingKey struct { Generator fr.Element NbPublicVariables uint64 - // shifters for extending the permutation set: from s=<1,z,..,z**n-1>, - // extended domain = s || shifter[0].s || shifter[1].s - Shifter [2]fr.Element - // Commitment scheme that is used for an instantiation of PLONK KZGSRS *kzg.SRS + // cosetShift generator of the coset on the small domain + CosetShift fr.Element + // S commitments to S1, S2, S3 S [3]kzg.Digest @@ -78,37 +79,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // fft domains sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints - pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false) + pk.Domain[0] = *fft.NewDomain(sizeSystem) + pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) // h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space, // the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases // except when n<6. if sizeSystem < 6 { - pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(8 * sizeSystem) } else { - pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false) + pk.Domain[1] = *fft.NewDomain(4 * sizeSystem) } - vk.Size = pk.DomainNum.Cardinality + vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) - vk.Generator.Set(&pk.DomainNum.Generator) + vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(spr.NbPublicVariables) - // shifters - vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator) - vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator) - if err := pk.InitKZG(srs); err != nil { return nil, nil, err } // public polynomials corresponding to constraints: [ placholders | constraints | assertions ] - pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality) - pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality) - pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality) - pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality) + pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality) + pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality) + pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality) + pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality) for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant pk.Ql[i].SetOne().Neg(&pk.Ql[i]) @@ -116,7 +114,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.Qm[i].SetZero() pk.Qo[i].SetZero() pk.CQk[i].SetZero() - pk.LQk[i].SetZero() // --> to be completed by the prover + pk.LQk[i].SetZero() // → to be completed by the prover } offset := spr.NbPublicVariables for i := 0; i < nbConstraints; i++ { // constraints @@ -130,11 +128,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K]) } - pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0) + pk.Domain[0].FFTInverse(pk.Ql, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qr, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qm, fft.DIF) + pk.Domain[0].FFTInverse(pk.Qo, fft.DIF) + pk.Domain[0].FFTInverse(pk.CQk, fft.DIF) fft.BitReverse(pk.Ql) fft.BitReverse(pk.Qr) fft.BitReverse(pk.Qm) @@ -145,7 +143,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) buildPermutation(spr, &pk) // set s1, s2, s3 - computeLDE(&pk) + ccomputePermutationPolynomials(&pk) // Commit to the polynomials to set up the verifying key var err error @@ -164,13 +162,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil { + if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil { + if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil { return nil, nil, err } - if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil { + if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil { return nil, nil, err } @@ -182,18 +180,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error) // // The permutation s is composed of cycles of maximum length such that // -// s. (l||r||o) = (l||r||o) +// s. (l∥r∥o) = (l∥r∥o) // -//, where l||r||o is the concatenation of the indices of l, r, o in +//, where l∥r∥o is the concatenation of the indices of l, r, o in // ql.l+qr.r+qm.l.r+qo.O+k = 0. // // The permutation is encoded as a slice s of size 3*size(l), where the -// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab +// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab // like this: for i in tab: tab[i] = tab[permutation[i]] func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables - sizeSolution := int(pk.DomainNum.Cardinality) + sizeSolution := int(pk.Domain[0].Cardinality) // init permutation pk.Permutation = make([]int64, 3*sizeSolution) @@ -238,60 +236,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) { } } -// computeLDE computes the LDE (Lagrange basis) of the permutations +// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations // s1, s2, s3. // -// ex: z gen of Z/mZ, u gen of Z/8mZ, then -// // 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 | // | // | Permutation // s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v // \---------------/ \--------------------/ \------------------------/ // s1 (LDE) s2 (LDE) s3 (LDE) -func computeLDE(pk *ProvingKey) { +func ccomputePermutationPolynomials(pk *ProvingKey) { - nbElmt := int(pk.DomainNum.Cardinality) + nbElmts := int(pk.Domain[0].Cardinality) - // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1] - sID := make([]fr.Element, 3*nbElmt) - sID[0].SetOne() - sID[nbElmt].Set(&pk.DomainNum.FinerGenerator) - sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator) - - for i := 1; i < nbElmt; i++ { - sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1 - sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1 - sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1 - } + // Lagrange form of ID + evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0]) // Lagrange form of S1, S2, S3 - pk.LS1 = make(polynomial.Polynomial, nbElmt) - pk.LS2 = make(polynomial.Polynomial, nbElmt) - pk.LS3 = make(polynomial.Polynomial, nbElmt) - for i := 0; i < nbElmt; i++ { - pk.LS1[i].Set(&sID[pk.Permutation[i]]) - pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]]) - pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]]) + pk.S1Canonical = make([]fr.Element, nbElmts) + pk.S2Canonical = make([]fr.Element, nbElmts) + pk.S3Canonical = make([]fr.Element, nbElmts) + for i := 0; i < nbElmts; i++ { + pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]]) + pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]]) + pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]]) } // Canonical form of S1, S2, S3 - pk.CS1 = make(polynomial.Polynomial, nbElmt) - pk.CS2 = make(polynomial.Polynomial, nbElmt) - pk.CS3 = make(polynomial.Polynomial, nbElmt) - copy(pk.CS1, pk.LS1) - copy(pk.CS2, pk.LS2) - copy(pk.CS3, pk.LS3) - pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0) - pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0) - fft.BitReverse(pk.CS1) - fft.BitReverse(pk.CS2) - fft.BitReverse(pk.CS3) + pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF) + pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF) + fft.BitReverse(pk.S1Canonical) + fft.BitReverse(pk.S2Canonical) + fft.BitReverse(pk.S3Canonical) + + // evaluation of permutation on the big domain + pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality) + copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical) + copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true) + pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true) + +} + +// getIDSmallDomain returns the Lagrange form of ID on the small domain +func getIDSmallDomain(domain *fft.Domain) []fr.Element { + + res := make([]fr.Element, 3*domain.Cardinality) + + res[0].SetOne() + res[domain.Cardinality].Set(&domain.FrMultiplicativeGen) + res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen) + for i := uint64(1); i < domain.Cardinality; i++ { + res[i].Mul(&res[i-1], &domain.Generator) + res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator) + res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator) + } + + return res } -// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS +// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS // // This should be used after deserializing a ProvingKey // as pk.Vk.KZG is NOT serialized @@ -324,4 +332,4 @@ func (vk *VerifyingKey) NbPublicWitness() int { // VerifyingKey returns pk.Vk func (pk *ProvingKey) VerifyingKey() interface{} { return pk.Vk -} \ No newline at end of file +} diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl index a38645edc3..09880d83ee 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl @@ -22,7 +22,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} hFunc := sha256.New() // transcript to derive the challenge - fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta") + fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta") // derive gamma from Comm(l), Comm(r), Comm(o) gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2]) @@ -30,6 +30,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} return err } + // derive beta from Comm(l), Comm(r), Comm(o) + beta, err := deriveRandomness(&fs, "beta") + if err != nil { + return err + } + // derive alpha from Comm(l), Comm(r), Comm(o), Com(Z) alpha, err := deriveRandomness(&fs, "alpha", &proof.Z) if err != nil { @@ -42,7 +48,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} return err } - // evaluation of Z=X**m-1 at zeta + // evaluation of Z=Xⁿ⁻¹ at ζ var zetaPowerM, zzeta fr.Element var bExpo big.Int one := fr.One() @@ -50,20 +56,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }} zetaPowerM.Exp(zeta, &bExpo) zzeta.Sub(&zetaPowerM, &one) - // ccompute PI = Sum_i if the circuits contains maps or slices +// this is actually a shallow copy → if the circuits contains maps or slices // only the reference is copied. func ShallowClone(circuit frontend.Circuit) frontend.Circuit { diff --git a/std/algebra/fields_bls12377/e12.go b/std/algebra/fields_bls12377/e12.go index d63f963f64..af387437e3 100644 --- a/std/algebra/fields_bls12377/e12.go +++ b/std/algebra/fields_bls12377/e12.go @@ -163,25 +163,25 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12, ext Extension) var t [7]E2 - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.C0.B1, ext) - // t1 = g5^2 + // t1 = g5² t[1].Square(api, x.C1.B2, ext) // t5 = g1 + g5 t[5].Add(api, x.C0.B1, x.C1.B2) - // t2 = (g1 + g5)^2 + // t2 = (g1 + g5)² t[2].Square(api, t[5], ext) - // t3 = g1^2 + g5^2 + // t3 = g1² + g5² t[3].Add(api, t[0], t[1]) // t5 = 2 * g1 * g5 t[5].Sub(api, t[2], t[3]) // t6 = g3 + g2 t[6].Add(api, x.C1.B0, x.C0.B2) - // t3 = (g3 + g2)^2 + // t3 = (g3 + g2)² t[3].Square(api, t[6], ext) - // t2 = g3^2 + // t2 = g3² t[2].Square(api, x.C1.B0, ext) // t6 = 2 * nr * g1 * g5 @@ -192,33 +192,33 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12, ext Extension) // z3 = 6 * nr * g1 * g5 + 2 * g3 e.C1.B0.Add(api, t[5], t[6]) - // t4 = nr * g5^2 + // t4 = nr * g5² t[4].MulByNonResidue(api, t[1], ext) - // t5 = nr * g5^2 + g1^2 + // t5 = nr * g5² + g1² t[5].Add(api, t[0], t[4]) - // t6 = nr * g5^2 + g1^2 - g2 + // t6 = nr * g5² + g1² - g2 t[6].Sub(api, t[5], x.C0.B2) - // t1 = g2^2 + // t1 = g2² t[1].Square(api, x.C0.B2, ext) - // t6 = 2 * nr * g5^2 + 2 * g1^2 - 2*g2 + // t6 = 2 * nr * g5² + 2 * g1² - 2*g2 t[6].Double(api, t[6]) - // z2 = 3 * nr * g5^2 + 3 * g1^2 - 2*g2 + // z2 = 3 * nr * g5² + 3 * g1² - 2*g2 e.C0.B2.Add(api, t[6], t[5]) - // t4 = nr * g2^2 + // t4 = nr * g2² t[4].MulByNonResidue(api, t[1], ext) - // t5 = g3^2 + nr * g2^2 + // t5 = g3² + nr * g2² t[5].Add(api, t[2], t[4]) - // t6 = g3^2 + nr * g2^2 - g1 + // t6 = g3² + nr * g2² - g1 t[6].Sub(api, t[5], x.C0.B1) - // t6 = 2 * g3^2 + 2 * nr * g2^2 - 2 * g1 + // t6 = 2 * g3² + 2 * nr * g2² - 2 * g1 t[6].Double(api, t[6]) - // z1 = 3 * g3^2 + 3 * nr * g2^2 - 2 * g1 + // z1 = 3 * g3² + 3 * nr * g2² - 2 * g1 e.C0.B1.Add(api, t[6], t[5]) - // t0 = g2^2 + g3^2 + // t0 = g2² + g3² t[0].Add(api, t[2], t[1]) // t5 = 2 * g3 * g2 t[5].Sub(api, t[3], t[0]) @@ -239,13 +239,13 @@ func (e *E12) Decompress(api frontend.API, x E12, ext Extension) *E12 { var one E2 one.SetOne(api) - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.C0.B1, ext) - // t1 = 3 * g1^2 - 2 * g2 + // t1 = 3 * g1² - 2 * g2 t[1].Sub(api, t[0], x.C0.B2). Double(api, t[1]). Add(api, t[1], t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5² + t1 t[2].Square(api, x.C1.B2, ext) t[0].MulByNonResidue(api, t[2], ext). Add(api, t[0], t[1]) @@ -258,14 +258,14 @@ func (e *E12) Decompress(api frontend.API, x E12, ext Extension) *E12 { // t1 = g2 * g1 t[1].Mul(api, x.C0.B2, x.C0.B1, ext) - // t2 = 2 * g4^2 - 3 * g2 * g1 + // t2 = 2 * g4² - 3 * g2 * g1 t[2].Square(api, e.C1.B1, ext). Sub(api, t[2], t[1]). Double(api, t[2]). Sub(api, t[2], t[1]) // t1 = g3 * g5 t[1].Mul(api, x.C1.B0, x.C1.B2, ext) - // c_0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1 + // c₀ = E * (2 * g4² + g3 * g5 - 3 * g2 * g1) + 1 t[2].Add(api, t[2], t[1]) e.C0.B0.MulByNonResidue(api, t[2], ext). Add(api, e.C0.B0, one) @@ -296,9 +296,9 @@ func (e *E12) CyclotomicSquare(api frontend.API, x E12, ext Extension) *E12 { t[5].Square(api, x.C0.B1, ext) t[8].Add(api, x.C1.B2, x.C0.B1).Square(api, t[8], ext).Sub(api, t[8], t[4]).Sub(api, t[8], t[5]).MulByNonResidue(api, t[8], ext) // 2*x5*x1*u - t[0].MulByNonResidue(api, t[0], ext).Add(api, t[0], t[1]) // x4^2*u + x0^2 - t[2].MulByNonResidue(api, t[2], ext).Add(api, t[2], t[3]) // x2^2*u + x3^2 - t[4].MulByNonResidue(api, t[4], ext).Add(api, t[4], t[5]) // x5^2*u + x1^2 + t[0].MulByNonResidue(api, t[0], ext).Add(api, t[0], t[1]) // x4²*u + x0² + t[2].MulByNonResidue(api, t[2], ext).Add(api, t[2], t[3]) // x2²*u + x3² + t[4].MulByNonResidue(api, t[4], ext).Add(api, t[4], t[5]) // x5²*u + x1² e.C0.B0.Sub(api, t[0], x.C0.B0).Add(api, e.C0.B0, e.C0.B0).Add(api, e.C0.B0, t[0]) e.C0.B1.Sub(api, t[2], x.C0.B1).Add(api, e.C0.B1, e.C0.B1).Add(api, e.C0.B1, t[2]) diff --git a/std/algebra/fields_bls12377/e6.go b/std/algebra/fields_bls12377/e6.go index 610caec82e..caa171bbbb 100644 --- a/std/algebra/fields_bls12377/e6.go +++ b/std/algebra/fields_bls12377/e6.go @@ -109,7 +109,7 @@ func (e *E6) MulByFp2(api frontend.API, e1 E6, e2 E2, ext Extension) *E6 { return e } -// MulByNonResidue multiplies e by the imaginary elmt of Fp6 (noted a+bV+cV where V**3 in F^2) +// MulByNonResidue multiplies e by the imaginary elmt of Fp6 (noted a+bV+cV where V**3 in F²) func (e *E6) MulByNonResidue(api frontend.API, e1 E6, ext Extension) *E6 { res := E6{} res.B0.MulByNonResidue(api, e1.B2, ext) diff --git a/std/algebra/fields_bls24315/e12.go b/std/algebra/fields_bls24315/e12.go index a0d087eafe..08b8ea75a7 100644 --- a/std/algebra/fields_bls24315/e12.go +++ b/std/algebra/fields_bls24315/e12.go @@ -109,7 +109,7 @@ func (e *E12) MulByFp2(api frontend.API, e1 E12, e2 E4, ext Extension) *E12 { return e } -// MulByNonResidue multiplies e by the imaginary elmt of Fp12 (noted a+bV+cV where V**3 in F^2) +// MulByNonResidue multiplies e by the imaginary elmt of Fp12 (noted a+bV+cV where V**3 in F²) func (e *E12) MulByNonResidue(api frontend.API, e1 E12, ext Extension) *E12 { res := E12{} res.C0.MulByNonResidue(api, e1.C2, ext) diff --git a/std/algebra/fields_bls24315/e24.go b/std/algebra/fields_bls24315/e24.go index 6b12d58c6b..aa13c3ddde 100644 --- a/std/algebra/fields_bls24315/e24.go +++ b/std/algebra/fields_bls24315/e24.go @@ -163,25 +163,25 @@ func (e *E24) Square(api frontend.API, x E24, ext Extension) *E24 { func (e *E24) CyclotomicSquareCompressed(api frontend.API, x E24, ext Extension) *E24 { var t [7]E4 - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.D0.C1, ext) - // t1 = g5^2 + // t1 = g5² t[1].Square(api, x.D1.C2, ext) // t5 = g1 + g5 t[5].Add(api, x.D0.C1, x.D1.C2) - // t2 = (g1 + g5)^2 + // t2 = (g1 + g5)² t[2].Square(api, t[5], ext) - // t3 = g1^2 + g5^2 + // t3 = g1² + g5² t[3].Add(api, t[0], t[1]) // t5 = 2 * g1 * g5 t[5].Sub(api, t[2], t[3]) // t6 = g3 + g2 t[6].Add(api, x.D1.C0, x.D0.C2) - // t3 = (g3 + g2)^2 + // t3 = (g3 + g2)² t[3].Square(api, t[6], ext) - // t2 = g3^2 + // t2 = g3² t[2].Square(api, x.D1.C0, ext) // t6 = 2 * nr * g1 * g5 @@ -192,33 +192,33 @@ func (e *E24) CyclotomicSquareCompressed(api frontend.API, x E24, ext Extension) // z3 = 6 * nr * g1 * g5 + 2 * g3 e.D1.C0.Add(api, t[5], t[6]) - // t4 = nr * g5^2 + // t4 = nr * g5² t[4].MulByNonResidue(api, t[1], ext) - // t5 = nr * g5^2 + g1^2 + // t5 = nr * g5² + g1² t[5].Add(api, t[0], t[4]) - // t6 = nr * g5^2 + g1^2 - g2 + // t6 = nr * g5² + g1² - g2 t[6].Sub(api, t[5], x.D0.C2) - // t1 = g2^2 + // t1 = g2² t[1].Square(api, x.D0.C2, ext) - // t6 = 2 * nr * g5^2 + 2 * g1^2 - 2*g2 + // t6 = 2 * nr * g5² + 2 * g1² - 2*g2 t[6].Double(api, t[6]) - // z2 = 3 * nr * g5^2 + 3 * g1^2 - 2*g2 + // z2 = 3 * nr * g5² + 3 * g1² - 2*g2 e.D0.C2.Add(api, t[6], t[5]) - // t4 = nr * g2^2 + // t4 = nr * g2² t[4].MulByNonResidue(api, t[1], ext) - // t5 = g3^2 + nr * g2^2 + // t5 = g3² + nr * g2² t[5].Add(api, t[2], t[4]) - // t6 = g3^2 + nr * g2^2 - g1 + // t6 = g3² + nr * g2² - g1 t[6].Sub(api, t[5], x.D0.C1) - // t6 = 2 * g3^2 + 2 * nr * g2^2 - 2 * g1 + // t6 = 2 * g3² + 2 * nr * g2² - 2 * g1 t[6].Double(api, t[6]) - // z1 = 3 * g3^2 + 3 * nr * g2^2 - 2 * g1 + // z1 = 3 * g3² + 3 * nr * g2² - 2 * g1 e.D0.C1.Add(api, t[6], t[5]) - // t0 = g2^2 + g3^2 + // t0 = g2² + g3² t[0].Add(api, t[2], t[1]) // t5 = 2 * g3 * g2 t[5].Sub(api, t[3], t[0]) @@ -239,13 +239,13 @@ func (e *E24) Decompress(api frontend.API, x E24, ext Extension) *E24 { var one E4 one.SetOne(api) - // t0 = g1^2 + // t0 = g1² t[0].Square(api, x.D0.C1, ext) - // t1 = 3 * g1^2 - 2 * g2 + // t1 = 3 * g1² - 2 * g2 t[1].Sub(api, t[0], x.D0.C2). Double(api, t[1]). Add(api, t[1], t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5² + t1 t[2].Square(api, x.D1.C2, ext) t[0].MulByNonResidue(api, t[2], ext). Add(api, t[0], t[1]) @@ -258,14 +258,14 @@ func (e *E24) Decompress(api frontend.API, x E24, ext Extension) *E24 { // t1 = g2 * g1 t[1].Mul(api, x.D0.C2, x.D0.C1, ext) - // t2 = 2 * g4^2 - 3 * g2 * g1 + // t2 = 2 * g4² - 3 * g2 * g1 t[2].Square(api, e.D1.C1, ext). Sub(api, t[2], t[1]). Double(api, t[2]). Sub(api, t[2], t[1]) // t1 = g3 * g5 t[1].Mul(api, x.D1.C0, x.D1.C2, ext) - // c_0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1 + // c₀ = E * (2 * g4² + g3 * g5 - 3 * g2 * g1) + 1 t[2].Add(api, t[2], t[1]) e.D0.C0.MulByNonResidue(api, t[2], ext). Add(api, e.D0.C0, one) @@ -474,7 +474,7 @@ func (e *E24) FinalExponentiation(api frontend.API, e1 E24, genT uint64, ext Ext // Daiki Hayashida and Kenichiro Hayasaka // and Tadanori Teruya // https://eprint.iacr.org/2020/875.pdf - // 3*Phi_24(api, p)/r = (api, u-1)^2 * (api, u+p) * (api, u^2+p^2) * (api, u^4+p^4-1) + 3 + // 3*Phi_24(api, p)/r = (api, u-1)² * (api, u+p) * (api, u²+p²) * (api, u⁴+p⁴-1) + 3 t[0].CyclotomicSquare(api, result, ext) t[1].Expt(api, result, genT, ext) t[2].Conjugate(api, result) diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go index b38735f94d..0e33212b96 100644 --- a/std/algebra/sw_bls12377/g1.go +++ b/std/algebra/sw_bls12377/g1.go @@ -137,7 +137,7 @@ func (p *G1Jac) DoubleAssign(api frontend.API) *G1Jac { S = api.Sub(S, XX) S = api.Sub(S, YYYY) S = api.Add(S, S) - M = api.Mul(XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M = api.Mul(XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z = api.Add(p.Z, p.Y) p.Z = api.Mul(p.Z, p.Z) p.Z = api.Sub(p.Z, YY) diff --git a/std/algebra/sw_bls12377/g2.go b/std/algebra/sw_bls12377/g2.go index 0aaa4ac03f..8014f9494c 100644 --- a/std/algebra/sw_bls12377/g2.go +++ b/std/algebra/sw_bls12377/g2.go @@ -183,7 +183,7 @@ func (p *G2Jac) Double(api frontend.API, p1 *G2Jac, ext fields_bls12377.Extensio S.Sub(api, S, XX) S.Sub(api, S, YYYY) S.Add(api, S, S) - M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z.Add(api, p.Z, p.Y) p.Z.Square(api, p.Z, ext) p.Z.Sub(api, p.Z, YY) diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go index 10bb09cdab..27560ee675 100644 --- a/std/algebra/sw_bls24315/g1.go +++ b/std/algebra/sw_bls24315/g1.go @@ -137,7 +137,7 @@ func (p *G1Jac) DoubleAssign(api frontend.API) *G1Jac { S = api.Sub(S, XX) S = api.Sub(S, YYYY) S = api.Add(S, S) - M = api.Mul(XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M = api.Mul(XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z = api.Add(p.Z, p.Y) p.Z = api.Mul(p.Z, p.Z) p.Z = api.Sub(p.Z, YY) diff --git a/std/algebra/sw_bls24315/g2.go b/std/algebra/sw_bls24315/g2.go index 1dd7154686..07e020e1b4 100644 --- a/std/algebra/sw_bls24315/g2.go +++ b/std/algebra/sw_bls24315/g2.go @@ -183,7 +183,7 @@ func (p *G2Jac) Double(api frontend.API, p1 *G2Jac, ext fields_bls24315.Extensio S.Sub(api, S, XX) S.Sub(api, S, YYYY) S.Add(api, S, S) - M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0) + M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0) p.Z.Add(api, p.Z, p.Y) p.Z.Square(api, p.Z, ext) p.Z.Sub(api, p.Z, YY) diff --git a/std/algebra/twistededwards/bandersnatch/curve.go b/std/algebra/twistededwards/bandersnatch/curve.go index 15caf30980..0cfb3c071e 100644 --- a/std/algebra/twistededwards/bandersnatch/curve.go +++ b/std/algebra/twistededwards/bandersnatch/curve.go @@ -25,10 +25,16 @@ import ( "github.com/consensys/gnark/internal/utils" ) +// Coordinates of a point on a twisted Edwards curve +type Coord struct { + X, Y big.Int +} + // EdCurve stores the info on the chosen edwards curve type EdCurve struct { - A, D, Cofactor, Order, BaseX, BaseY big.Int - ID ecc.ID + A, D, Cofactor, Order big.Int + Base Coord + ID ecc.ID } var constructors map[ecc.ID]func() EdCurve @@ -60,9 +66,11 @@ func newBandersnatch() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS12_381, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS12_381, } } diff --git a/std/algebra/twistededwards/bandersnatch/point.go b/std/algebra/twistededwards/bandersnatch/point.go index ddfb8ac52a..071f2762de 100644 --- a/std/algebra/twistededwards/bandersnatch/point.go +++ b/std/algebra/twistededwards/bandersnatch/point.go @@ -27,8 +27,15 @@ type Point struct { X, Y frontend.Variable } +// Neg computes the negative of a point in SNARK coordinates +func (p *Point) Neg(api frontend.API, p1 *Point) *Point { + p.X = api.Neg(p1.X) + p.Y = p1.Y + return p +} + // MustBeOnCurve checks if a point is on the reduced twisted Edwards curve -// a*x^2 + y^2 = 1 + d*x^2*y^2. +// a*x² + y² = 1 + d*x²*y². func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { one := big.NewInt(1) @@ -46,34 +53,9 @@ func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { } -// AddFixedPoint Adds two points, among which is one fixed point (the base), on a twisted edwards curve (eg jubjub) -// p1, base, ecurve are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddFixedPoint(api frontend.API, p1 *Point /*basex*/, x /*basey*/, y interface{}, curve EdCurve) *Point { - - // https://eprint.iacr.org/2008/013.pdf - - n11 := api.Mul(p1.X, y) - n12 := api.Mul(p1.Y, x) - n1 := api.Add(n11, n12) - - n21 := api.Mul(p1.Y, y) - n22 := api.Mul(p1.X, x) - an22 := api.Mul(n22, &curve.A) - n2 := api.Sub(n21, an22) - - d11 := api.Mul(curve.D, n11, n12) - d1 := api.Add(1, d11) - d2 := api.Sub(1, d11) - - p.X = api.DivUnchecked(n1, d1) - p.Y = api.DivUnchecked(n2, d2) - - return p -} - -// AddGeneric Adds two points on a twisted edwards curve (eg jubjub) +// Add Adds two points on a twisted edwards curve (eg jubjub) // p1, p2, c are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { +func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { // https://eprint.iacr.org/2008/013.pdf @@ -103,14 +85,12 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { u := api.Mul(p1.X, p1.Y) v := api.Mul(p1.X, p1.X) w := api.Mul(p1.Y, p1.Y) - z := api.Mul(v, w) n1 := api.Mul(2, u) av := api.Mul(v, &curve.A) n2 := api.Sub(w, av) - d := api.Mul(z, curve.D) - d1 := api.Add(1, d) - d2 := api.Sub(1, d) + d1 := api.Add(w, av) + d2 := api.Sub(2, d1) p.X = api.DivUnchecked(n1, d1) p.Y = api.DivUnchecked(n2, d2) @@ -118,55 +98,72 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { return p } -// ScalarMulNonFixedBase computes the scalar multiplication of a point on a twisted Edwards curve +// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) ScalarMulNonFixedBase(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { +func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { // first unpack the scalar b := api.ToBinary(scalar) - res := Point{ - 0, - 1, + res := Point{} + tmp := Point{} + A := Point{} + B := Point{} + + A.Double(api, p1, curve) + B.Add(api, &A, p1, curve) + + n := len(b) - 1 + res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X) + res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y) + + for i := n - 2; i >= 1; i -= 2 { + res.Double(api, &res, curve). + Double(api, &res, curve) + tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X) + tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y) + res.Add(api, &res, &tmp, curve) } - for i := len(b) - 1; i >= 0; i-- { + if n%2 == 0 { res.Double(api, &res, curve) - tmp := Point{} - tmp.AddGeneric(api, &res, p1, curve) - res.X = api.Select(b[i], tmp.X, res.X) - res.Y = api.Select(b[i], tmp.Y, res.Y) + tmp.Add(api, &res, p1, curve) + res.X = api.Select(b[0], tmp.X, res.X) + res.Y = api.Select(b[0], tmp.Y, res.Y) } p.X = res.X p.Y = res.Y + return p } -// ScalarMulFixedBase computes the scalar multiplication of a point on a twisted Edwards curve -// x, y: coordinates of the base point -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar frontend.Variable, curve EdCurve) *Point { +// DoubleBaseScalarMul computes s1*P1+s2*P2 +// where P1 and P2 are points on a twisted Edwards curve +// and s1, s2 scalars. +func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point { - // first unpack the scalar - b := api.ToBinary(scalar) + // first unpack the scalars + b1 := api.ToBinary(s1) + b2 := api.ToBinary(s2) - res := Point{ - 0, - 1, - } + res := Point{} + tmp := Point{} + sum := Point{} + sum.Add(api, p1, p2, curve) + + n := len(b1) + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y) - for i := len(b) - 1; i >= 0; i-- { + for i := n - 2; i >= 0; i-- { res.Double(api, &res, curve) - tmp := Point{} - tmp.AddFixedPoint(api, &res, x, y, curve) - res.X = api.Select(b[i], tmp.X, res.X) - res.Y = api.Select(b[i], tmp.Y, res.Y) + tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y) + res.Add(api, &res, &tmp, curve) } p.X = res.X @@ -174,10 +171,3 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr return p } - -// Neg computes the negative of a point in SNARK coordinates -func (p *Point) Neg(api frontend.API, p1 *Point) *Point { - p.X = api.Neg(p1.X) - p.Y = p1.Y - return p -} diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go index ba883d8930..ff879c2fa2 100644 --- a/std/algebra/twistededwards/bandersnatch/point_test.go +++ b/std/algebra/twistededwards/bandersnatch/point_test.go @@ -55,8 +55,8 @@ func TestIsOnCurve(t *testing.T) { t.Fatal(err) } - witness.P.X = (params.BaseX) - witness.P.Y = (params.BaseY) + witness.P.X = (params.Base.X) + witness.P.Y = (params.Base.Y) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381)) @@ -74,7 +74,10 @@ func (circuit *add) Define(api frontend.API) error { return err } - res := circuit.P.AddFixedPoint(api, &circuit.P, params.BaseX, params.BaseY, params) + p := Point{} + p.X = params.Base.X + p.Y = params.Base.Y + res := circuit.P.Add(api, &circuit.P, &p, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -94,8 +97,8 @@ func TestAddFixedPoint(t *testing.T) { t.Fatal(err) } var base, point, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) point.Set(&base) r := big.NewInt(5) point.ScalarMul(&point, r) @@ -112,6 +115,9 @@ func TestAddFixedPoint(t *testing.T) { } +//------------------------------------------------------------- +// addGeneric + type addGeneric struct { P1, P2, E Point } @@ -124,7 +130,7 @@ func (circuit *addGeneric) Define(api frontend.API) error { return err } - res := circuit.P1.AddGeneric(api, &circuit.P1, &circuit.P2, params) + res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -143,8 +149,8 @@ func TestAddGeneric(t *testing.T) { t.Fatal(err) } var point1, point2, expected bandersnatch.PointAffine - point1.X.SetBigInt(¶ms.BaseX) - point1.Y.SetBigInt(¶ms.BaseY) + point1.X.SetBigInt(¶ms.Base.X) + point1.Y.SetBigInt(¶ms.Base.Y) point2.Set(&point1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -165,6 +171,8 @@ func TestAddGeneric(t *testing.T) { } +//------------------------------------------------------------- +// Double type double struct { P, E Point } @@ -197,8 +205,8 @@ func TestDouble(t *testing.T) { t.Fatal(err) } var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) // populate witness @@ -225,8 +233,10 @@ func (circuit *scalarMulFixed) Define(api frontend.API) error { return err } - var resFixed Point - resFixed.ScalarMulFixedBase(api, params.BaseX, params.BaseY, circuit.S, params) + var resFixed, p Point + p.X = params.Base.X + p.Y = params.Base.Y + resFixed.ScalarMul(api, &p, circuit.S, params) api.AssertIsEqual(resFixed.X, circuit.E.X) api.AssertIsEqual(resFixed.Y, circuit.E.Y) @@ -246,8 +256,8 @@ func TestScalarMulFixed(t *testing.T) { t.Fatal(err) } var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) @@ -274,7 +284,7 @@ func (circuit *scalarMulGeneric) Define(api frontend.API) error { return err } - resGeneric := circuit.P.ScalarMulNonFixedBase(api, &circuit.P, circuit.S, params) + resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params) api.AssertIsEqual(resGeneric.X, circuit.E.X) api.AssertIsEqual(resGeneric.Y, circuit.E.Y) @@ -294,8 +304,8 @@ func TestScalarMulGeneric(t *testing.T) { t.Fatal(err) } var base, point, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) s := big.NewInt(902) point.ScalarMul(&base, s) // random point r := big.NewInt(230928302) @@ -336,8 +346,8 @@ func TestNeg(t *testing.T) { t.Fatal(err) } var base, expected bandersnatch.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Neg(&base) // generate witness @@ -351,25 +361,39 @@ func TestNeg(t *testing.T) { } -// benches +// Bench +func BenchmarkDouble(b *testing.B) { + var c double + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} -var ccsBench frontend.CompiledConstraintSystem +func BenchmarkAddGeneric(b *testing.B) { + var c addGeneric + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} -func BenchmarkScalarMulG1(b *testing.B) { - var c scalarMulGeneric - b.Run("groth16", func(b *testing.B) { - for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) - } +func BenchmarkAddFixedPoint(b *testing.B) { + var c add + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} - }) +func BenchmarkMustBeOnCurve(b *testing.B) { + var c mustBeOnCurve + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) b.Log("groth16", ccsBench.GetNbConstraints()) - b.Run("plonk", func(b *testing.B) { - for i := 0; i < b.N; i++ { - ccsBench, _ = frontend.Compile(ecc.BLS12_381, backend.PLONK, &c) - } +} - }) - b.Log("plonk", ccsBench.GetNbConstraints()) +func BenchmarkScalarMulGeneric(b *testing.B) { + var c scalarMulGeneric + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) +} +func BenchmarkScalarMulFixed(b *testing.B) { + var c scalarMulFixed + ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c) + b.Log("groth16", ccsBench.GetNbConstraints()) } diff --git a/std/algebra/twistededwards/curve.go b/std/algebra/twistededwards/curve.go index 798b01271f..432210c588 100644 --- a/std/algebra/twistededwards/curve.go +++ b/std/algebra/twistededwards/curve.go @@ -30,11 +30,17 @@ import ( "github.com/consensys/gnark/internal/utils" ) +// Coordinates of a point on a twisted Edwards curve +type Coord struct { + X, Y big.Int +} + // EdCurve stores the info on the chosen edwards curve // note that all curves implemented in gnark-crypto have A = -1 type EdCurve struct { - A, D, Cofactor, Order, BaseX, BaseY big.Int - ID ecc.ID + A, D, Cofactor, Order big.Int + Base Coord + ID ecc.ID } var constructors map[ecc.ID]func() EdCurve @@ -71,9 +77,11 @@ func newEdBN254() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BN254, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BN254, } } @@ -88,9 +96,11 @@ func newEdBLS381() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS12_381, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS12_381, } } @@ -105,9 +115,11 @@ func newEdBLS377() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS12_377, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS12_377, } } @@ -122,9 +134,11 @@ func newEdBW633() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BW6_633, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BW6_633, } } @@ -139,9 +153,11 @@ func newEdBW761() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BW6_761, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BW6_761, } } @@ -156,9 +172,11 @@ func newEdBLS315() EdCurve { D: utils.FromInterface(edcurve.D), Cofactor: utils.FromInterface(edcurve.Cofactor), Order: utils.FromInterface(edcurve.Order), - BaseX: utils.FromInterface(edcurve.Base.X), - BaseY: utils.FromInterface(edcurve.Base.Y), - ID: ecc.BLS24_315, + Base: Coord{ + X: utils.FromInterface(edcurve.Base.X), + Y: utils.FromInterface(edcurve.Base.Y), + }, + ID: ecc.BLS24_315, } } diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go index 7faf85faf4..72a712e517 100644 --- a/std/algebra/twistededwards/point.go +++ b/std/algebra/twistededwards/point.go @@ -27,8 +27,15 @@ type Point struct { X, Y frontend.Variable } +// Neg computes the negative of a point in SNARK coordinates +func (p *Point) Neg(api frontend.API, p1 *Point) *Point { + p.X = api.Neg(p1.X) + p.Y = p1.Y + return p +} + // MustBeOnCurve checks if a point is on the reduced twisted Edwards curve -// a*x^2 + y^2 = 1 + d*x^2*y^2. +// a*x² + y² = 1 + d*x²*y². func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { one := big.NewInt(1) @@ -46,34 +53,9 @@ func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) { } -// AddFixedPoint Adds two points, among which is one fixed point (the base), on a twisted edwards curve (eg jubjub) -// p1, base, ecurve are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddFixedPoint(api frontend.API, p1 *Point /*basex*/, x /*basey*/, y interface{}, curve EdCurve) *Point { - - // https://eprint.iacr.org/2008/013.pdf - - n11 := api.Mul(p1.X, y) - n12 := api.Mul(p1.Y, x) - n1 := api.Add(n11, n12) - - n21 := api.Mul(p1.Y, y) - n22 := api.Mul(p1.X, x) - an22 := api.Mul(n22, &curve.A) - n2 := api.Sub(n21, an22) - - d11 := api.Mul(curve.D, n11, n12) - d1 := api.Add(1, d11) - d2 := api.Sub(1, d11) - - p.X = api.DivUnchecked(n1, d1) - p.Y = api.DivUnchecked(n2, d2) - - return p -} - -// AddGeneric Adds two points on a twisted edwards curve (eg jubjub) +// Add Adds two points on a twisted edwards curve (eg jubjub) // p1, p2, c are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve -func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { +func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point { // https://eprint.iacr.org/2008/013.pdf @@ -103,14 +85,12 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { u := api.Mul(p1.X, p1.Y) v := api.Mul(p1.X, p1.X) w := api.Mul(p1.Y, p1.Y) - z := api.Mul(v, w) n1 := api.Mul(2, u) av := api.Mul(v, &curve.A) n2 := api.Sub(w, av) - d := api.Mul(z, curve.D) - d1 := api.Add(1, d) - d2 := api.Sub(1, d) + d1 := api.Add(w, av) + d2 := api.Sub(2, d1) p.X = api.DivUnchecked(n1, d1) p.Y = api.DivUnchecked(n2, d2) @@ -118,55 +98,72 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point { return p } -// ScalarMulNonFixedBase computes the scalar multiplication of a point on a twisted Edwards curve +// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve // p1: base point (as snark point) // curve: parameters of the Edwards curve // scal: scalar as a SNARK constraint // Standard left to right double and add -func (p *Point) ScalarMulNonFixedBase(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { +func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point { // first unpack the scalar b := api.ToBinary(scalar) - res := Point{ - 0, - 1, + res := Point{} + tmp := Point{} + A := Point{} + B := Point{} + + A.Double(api, p1, curve) + B.Add(api, &A, p1, curve) + + n := len(b) - 1 + res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X) + res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y) + + for i := n - 2; i >= 1; i -= 2 { + res.Double(api, &res, curve). + Double(api, &res, curve) + tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X) + tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y) + res.Add(api, &res, &tmp, curve) } - for i := len(b) - 1; i >= 0; i-- { + if n%2 == 0 { res.Double(api, &res, curve) - tmp := Point{} - tmp.AddGeneric(api, &res, p1, curve) - res.X = api.Select(b[i], tmp.X, res.X) - res.Y = api.Select(b[i], tmp.Y, res.Y) + tmp.Add(api, &res, p1, curve) + res.X = api.Select(b[0], tmp.X, res.X) + res.Y = api.Select(b[0], tmp.Y, res.Y) } p.X = res.X p.Y = res.Y + return p } -// ScalarMulFixedBase computes the scalar multiplication of a point on a twisted Edwards curve -// x, y: coordinates of the base point -// curve: parameters of the Edwards curve -// scal: scalar as a SNARK constraint -// Standard left to right double and add -func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar frontend.Variable, curve EdCurve) *Point { +// DoubleBaseScalarMul computes s1*P1+s2*P2 +// where P1 and P2 are points on a twisted Edwards curve +// and s1, s2 scalars. +func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point { - // first unpack the scalar - b := api.ToBinary(scalar) + // first unpack the scalars + b1 := api.ToBinary(s1) + b2 := api.ToBinary(s2) - res := Point{ - 0, - 1, - } + res := Point{} + tmp := Point{} + sum := Point{} + sum.Add(api, p1, p2, curve) + + n := len(b1) + res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X) + res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y) - for i := len(b) - 1; i >= 0; i-- { + for i := n - 2; i >= 0; i-- { res.Double(api, &res, curve) - tmp := Point{} - tmp.AddFixedPoint(api, &res, x, y, curve) - res.X = api.Select(b[i], tmp.X, res.X) - res.Y = api.Select(b[i], tmp.Y, res.Y) + tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X) + tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y) + res.Add(api, &res, &tmp, curve) } p.X = res.X @@ -174,10 +171,3 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr return p } - -// Neg computes the negative of a point in SNARK coordinates -func (p *Point) Neg(api frontend.API, p1 *Point) *Point { - p.X = api.Neg(p1.X) - p.Y = p1.Y - return p -} diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go index b8f9d1959b..d9387e94fd 100644 --- a/std/algebra/twistededwards/point_test.go +++ b/std/algebra/twistededwards/point_test.go @@ -28,7 +28,6 @@ import ( tbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards" tbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/test" ) @@ -61,8 +60,8 @@ func TestIsOnCurve(t *testing.T) { t.Fatal(err) } - witness.P.X = (params.BaseX) - witness.P.Y = (params.BaseY) + witness.P.X = (params.Base.X) + witness.P.Y = (params.Base.Y) assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) @@ -80,7 +79,10 @@ func (circuit *add) Define(api frontend.API) error { return err } - res := circuit.P.AddFixedPoint(api, &circuit.P, params.BaseX, params.BaseY, params) + p := Point{} + p.X = params.Base.X + p.Y = params.Base.Y + res := circuit.P.Add(api, &circuit.P, &p, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -100,8 +102,8 @@ func TestAddFixedPoint(t *testing.T) { t.Fatal(err) } var base, point, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) point.Set(&base) r := big.NewInt(5) point.ScalarMul(&point, r) @@ -133,7 +135,7 @@ func (circuit *addGeneric) Define(api frontend.API) error { return err } - res := circuit.P1.AddGeneric(api, &circuit.P1, &circuit.P2, params) + res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params) api.AssertIsEqual(res.X, circuit.E.X) api.AssertIsEqual(res.Y, circuit.E.Y) @@ -157,8 +159,8 @@ func TestAddGeneric(t *testing.T) { switch id { case ecc.BN254: var op1, op2, expected tbn254.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -173,8 +175,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_381: var op1, op2, expected tbls12381.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -189,8 +191,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_377: var op1, op2, expected tbls12377.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -205,8 +207,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS24_315: var op1, op2, expected tbls24315.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -221,8 +223,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_633: var op1, op2, expected tbw6633.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -237,8 +239,8 @@ func TestAddGeneric(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_761: var op1, op2, expected tbw6761.PointAffine - op1.X.SetBigInt(¶ms.BaseX) - op1.Y.SetBigInt(¶ms.BaseY) + op1.X.SetBigInt(¶ms.Base.X) + op1.Y.SetBigInt(¶ms.Base.Y) op2.Set(&op1) r1 := big.NewInt(5) r2 := big.NewInt(12) @@ -299,8 +301,8 @@ func TestDouble(t *testing.T) { switch id { case ecc.BN254: var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -308,8 +310,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_381: var base, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -317,8 +319,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS12_377: var base, expected tbls12377.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -326,8 +328,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BLS24_315: var base, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -335,8 +337,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_633: var base, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -344,8 +346,8 @@ func TestDouble(t *testing.T) { witness.E.Y = (expected.Y.String()) case ecc.BW6_761: var base, expected tbw6761.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Double(&base) witness.P.X = (base.X.String()) witness.P.Y = (base.Y.String()) @@ -375,8 +377,10 @@ func (circuit *scalarMulFixed) Define(api frontend.API) error { return err } - var resFixed Point - resFixed.ScalarMulFixedBase(api, params.BaseX, params.BaseY, circuit.S, params) + var resFixed, p Point + p.X = params.Base.X + p.Y = params.Base.Y + resFixed.ScalarMul(api, &p, circuit.S, params) api.AssertIsEqual(resFixed.X, circuit.E.X) api.AssertIsEqual(resFixed.Y, circuit.E.Y) @@ -391,8 +395,7 @@ func TestScalarMulFixed(t *testing.T) { var circuit, witness scalarMulFixed // generate witness data - //for _, id := range ecc.Implemented() { - for _, id := range []ecc.ID{ecc.BLS12_377} { + for _, id := range ecc.Implemented() { params, err := NewEdCurve(id) if err != nil { @@ -402,8 +405,8 @@ func TestScalarMulFixed(t *testing.T) { switch id { case ecc.BN254: var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -411,8 +414,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BLS12_381: var base, expected tbls12381.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -420,8 +423,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BLS12_377: var base, expected tbls12377.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -429,8 +432,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BLS24_315: var base, expected tbls24315.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -438,8 +441,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BW6_633: var base, expected tbw6633.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -447,8 +450,8 @@ func TestScalarMulFixed(t *testing.T) { witness.S = (r) case ecc.BW6_761: var base, expected tbw6761.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) r := big.NewInt(928323002) expected.ScalarMul(&base, r) witness.E.X = (expected.X.String()) @@ -457,7 +460,7 @@ func TestScalarMulFixed(t *testing.T) { } // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithBackends(backend.PLONK), test.WithCurves(id)) + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) } } @@ -475,7 +478,7 @@ func (circuit *scalarMulGeneric) Define(api frontend.API) error { return err } - resGeneric := circuit.P.ScalarMulNonFixedBase(api, &circuit.P, circuit.S, params) + resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params) api.AssertIsEqual(resGeneric.X, circuit.E.X) api.AssertIsEqual(resGeneric.Y, circuit.E.Y) @@ -490,28 +493,292 @@ func TestScalarMulGeneric(t *testing.T) { var circuit, witness scalarMulGeneric // generate witness data - params, err := NewEdCurve(ecc.BN254) + for _, id := range ecc.Implemented() { + + params, err := NewEdCurve(id) + if err != nil { + t.Fatal(err) + } + + switch id { + case ecc.BN254: + var base, point, expected tbn254.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BLS12_377: + var base, point, expected tbls12377.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BLS12_381: + var base, point, expected tbls12381.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BLS24_315: + var base, point, expected tbls24315.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BW6_761: + var base, point, expected tbw6761.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + case ecc.BW6_633: + var base, point, expected tbw6633.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s := big.NewInt(902) + point.ScalarMul(&base, s) // random point + r := big.NewInt(230928302) + expected.ScalarMul(&point, r) + + // populate witness + witness.P.X = (point.X.String()) + witness.P.Y = (point.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S = (r) + } + + // creates r1cs + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) + } +} + +// + +type doubleScalarMulGeneric struct { + P1, P2, E Point + S1, S2 frontend.Variable +} + +func (circuit *doubleScalarMulGeneric) Define(api frontend.API) error { + + // get edwards curve params + params, err := NewEdCurve(api.Curve()) if err != nil { - t.Fatal(err) + return err } - var base, point, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) - s := big.NewInt(902) - point.ScalarMul(&base, s) // random point - r := big.NewInt(230928302) - expected.ScalarMul(&point, r) - // populate witness - witness.P.X = (point.X.String()) - witness.P.Y = (point.Y.String()) - witness.E.X = (expected.X.String()) - witness.E.Y = (expected.Y.String()) - witness.S = (r) + resGeneric := circuit.P1.DoubleBaseScalarMul(api, &circuit.P1, &circuit.P2, circuit.S1, circuit.S2, params) - // creates r1cs - assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254)) + api.AssertIsEqual(resGeneric.X, circuit.E.X) + api.AssertIsEqual(resGeneric.Y, circuit.E.Y) + + return nil +} + +func TestDoubleScalarMulGeneric(t *testing.T) { + + assert := test.NewAssert(t) + + var circuit, witness doubleScalarMulGeneric + + // generate witness data + for _, id := range ecc.Implemented() { + + params, err := NewEdCurve(id) + if err != nil { + t.Fatal(err) + } + switch id { + case ecc.BN254: + var base, point1, point2, tmp, expected tbn254.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BLS12_377: + var base, point1, point2, tmp, expected tbls12377.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BLS12_381: + var base, point1, point2, tmp, expected tbls12381.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BLS24_315: + var base, point1, point2, tmp, expected tbls24315.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BW6_761: + var base, point1, point2, tmp, expected tbw6761.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + case ecc.BW6_633: + var base, point1, point2, tmp, expected tbw6633.PointAffine + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) + s1 := big.NewInt(902) + s2 := big.NewInt(891) + point1.ScalarMul(&base, s1) // random point + point2.ScalarMul(&base, s2) // random point + r1 := big.NewInt(230928303) + r2 := big.NewInt(2830309) + tmp.ScalarMul(&point1, r1) + expected.ScalarMul(&point2, r2). + Add(&expected, &tmp) + + // populate witness + witness.P1.X = (point1.X.String()) + witness.P1.Y = (point1.Y.String()) + witness.P2.X = (point2.X.String()) + witness.P2.Y = (point2.Y.String()) + witness.E.X = (expected.X.String()) + witness.E.Y = (expected.Y.String()) + witness.S1 = (r1) + witness.S2 = (r2) + } + + // creates r1cs + assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id)) + } } type neg struct { @@ -537,8 +804,8 @@ func TestNeg(t *testing.T) { t.Fatal(err) } var base, expected tbn254.PointAffine - base.X.SetBigInt(¶ms.BaseX) - base.Y.SetBigInt(¶ms.BaseY) + base.X.SetBigInt(¶ms.Base.X) + base.Y.SetBigInt(¶ms.Base.Y) expected.Neg(&base) // generate witness diff --git a/std/fiat-shamir/transcript.go b/std/fiat-shamir/transcript.go index 3491a548a5..97a35ca62a 100644 --- a/std/fiat-shamir/transcript.go +++ b/std/fiat-shamir/transcript.go @@ -91,8 +91,8 @@ func (t *Transcript) Bind(challengeID string, values []frontend.Variable) error // ComputeChallenge computes the challenge corresponding to the given name. // The resulting variable is: -// * H(name || previous_challenge || binded_values...) if the challenge is not the first one -// * H(name || binded_values... ) if it's is the first challenge +// * H(name ∥ previous_challenge ∥ binded_values...) if the challenge is not the first one +// * H(name ∥ binded_values... ) if it's is the first challenge func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, error) { challenge, ok := t.challenges[challengeID] diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go index c46b0273c3..c266b33a7e 100644 --- a/std/signature/eddsa/eddsa.go +++ b/std/signature/eddsa/eddsa.go @@ -58,37 +58,35 @@ func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey Publi return err } hash.Write(data...) - //hramConstant := hash.Sum(data...) hramConstant := hash.Sum() - // lhs = [S]G - cofactor := pubKey.Curve.Cofactor.Uint64() - lhs := twistededwards.Point{} - lhs.ScalarMulFixedBase(api, pubKey.Curve.BaseX, pubKey.Curve.BaseY, sig.S, pubKey.Curve) - lhs.MustBeOnCurve(api, pubKey.Curve) + base := twistededwards.Point{} + base.X = pubKey.Curve.Base.X + base.Y = pubKey.Curve.Base.Y - // rhs = R+[H(R,A,M)]*A - rhs := twistededwards.Point{} - rhs.ScalarMulNonFixedBase(api, &pubKey.A, hramConstant, pubKey.Curve). - AddGeneric(api, &rhs, &sig.R, pubKey.Curve) - rhs.MustBeOnCurve(api, pubKey.Curve) + //[S]G-[H(R,A,M)]*A + cofactor := pubKey.Curve.Cofactor.Uint64() + Q := twistededwards.Point{} + _A := twistededwards.Point{} + _A.Neg(api, &pubKey.A) + Q.DoubleBaseScalarMul(api, &base, &_A, sig.S, hramConstant, pubKey.Curve) + Q.MustBeOnCurve(api, pubKey.Curve) - // lhs-rhs - rhs.Neg(api, &rhs).AddGeneric(api, &lhs, &rhs, pubKey.Curve) + //[S]G-[H(R,A,M)]*A-R + Q.Neg(api, &Q).Add(api, &Q, &sig.R, pubKey.Curve) - // [cofactor](lhs-rhs) + // [cofactor]*(lhs-rhs) switch cofactor { case 4: - rhs.Double(api, &rhs, pubKey.Curve). - Double(api, &rhs, pubKey.Curve) + Q.Double(api, &Q, pubKey.Curve). + Double(api, &Q, pubKey.Curve) case 8: - rhs.Double(api, &rhs, pubKey.Curve). - Double(api, &rhs, pubKey.Curve).Double(api, &rhs, pubKey.Curve) + Q.Double(api, &Q, pubKey.Curve). + Double(api, &Q, pubKey.Curve).Double(api, &Q, pubKey.Curve) } - //rhs.MustBeOnCurve(api, pubKey.Curve) - api.AssertIsEqual(rhs.X, 0) - api.AssertIsEqual(rhs.Y, 1) + api.AssertIsEqual(Q.X, 0) + api.AssertIsEqual(Q.Y, 1) return nil } diff --git a/test/kzg_srs.go b/test/kzg_srs.go index 7f96d5bb03..4e8225ce80 100644 --- a/test/kzg_srs.go +++ b/test/kzg_srs.go @@ -34,7 +34,7 @@ import ( const srsCachedSize = (1 << 14) + 3 // NewKZGSRS uses ccs nb variables and nb constraints to initialize a kzg srs -// for sizes < 2^15, returns a pre-computed cached SRS +// for sizes < 2¹⁵, returns a pre-computed cached SRS // // /!\ warning /!\: this method is here for convenience only: in production, a SRS generated through MPC should be used. func NewKZGSRS(ccs frontend.CompiledConstraintSystem) (kzg.SRS, error) {