diff --git a/CHANGELOG.md b/CHANGELOG.md
index 79984aa311..64d29e56cb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,11 +1,41 @@
+## [v0.6.4] - 2022-02-15
+
+### Build
+
+- update to gnark-crypto v0.6.1
+
+### Feat
+
+- Constraint system solvers (Groth16 and PlonK) now run in parallel
+
+### Fix
+
+- `api.DivUnchecked` was incorrect with the PlonK frontend when both operands were constants
+
+### Perf
+
+- **EdDSA:** `std/algebra/twistededwards` takes ~2K fewer constraints (Groth16). Bandersnatch benefits from the same improvements.
+
+### Pull Requests
+
+- Merge pull request [#259](https://github.com/consensys/gnark/issues/259) from ConsenSys/perf-parallel-solver
+- Merge pull request [#261](https://github.com/consensys/gnark/issues/261) from ConsenSys/feat/kzg_updated
+- Merge pull request [#257](https://github.com/consensys/gnark/issues/257) from ConsenSys/perf/EdDSA
+- Merge pull request [#253](https://github.com/consensys/gnark/issues/253) from ConsenSys/feat/fft_cosets
+
## [v0.6.3] - 2022-02-13
### Feat
+
- MiMC changes: api doesn't take a "seed" parameter. MiMC impl matches Ethereum one.
### Fix
+
- fixes [#255](https://github.com/consensys/gnark/issues/255) variable visibility inheritance regression
- counter was set with PLONK backend ID in R1CS
- R1CS Solver was incorrectly calling a "MulByCoeff" instead of "DivByCoeff" (no impact, coeff was always 1 or -1)
@@ -13,6 +43,7 @@
### Pull Requests
+
- Merge pull request [#256](https://github.com/consensys/gnark/issues/256) from ConsenSys/fix-bug-compile-visibility
- Merge pull request [#249](https://github.com/consensys/gnark/issues/249) from ConsenSys/perf-ccs-hint
- Merge pull request [#248](https://github.com/consensys/gnark/issues/248) from ConsenSys/perf-ccs-solver
@@ -24,9 +55,11 @@
## [v0.6.1] - 2022-01-28
### Build
+
- go version dependency bumped from 1.16 to 1.17
### Feat
+
- added witness.MarshalJSON and witness.MarshalBinary
- added `ccs.GetSchema()` - the schema of a circuit is required for witness json (de)serialization
- added `ccs.GetConstraints()` - returns a list of human-readable constraints
@@ -35,6 +68,7 @@
- addition of `Cmp` in the circuit API
### Refactor
+
- compiled.Visbility -> schema.Visibiility
- witness.WriteSequence -> schema.WriteSequence
- killed `ReadAndProve` and `ReadAndVerify` (plonk)
@@ -42,11 +76,13 @@
- remove embbed struct tag for frontend.Variable fields
### Docs
+
- **backend:** unify documentation for options
- **frontend:** unify docs for options
- **test:** unify documentation for options
### Pull Requests
+
- Merge pull request [#244](https://github.com/consensys/gnark/issues/244) from ConsenSys/plonk-human-readable
- Merge pull request [#237](https://github.com/consensys/gnark/issues/237) from ConsenSys/ccs-get-constraints
- Merge pull request [#233](https://github.com/consensys/gnark/issues/233) from ConsenSys/feat/api_cmp
diff --git a/debug_test.go b/debug_test.go
index 3569e3f493..51910507d2 100644
--- a/debug_test.go
+++ b/debug_test.go
@@ -47,16 +47,16 @@ func TestPrintln(t *testing.T) {
expected.WriteString("debug_test.go:27 26 42\n")
expected.WriteString("debug_test.go:29 bits 1\n")
expected.WriteString("debug_test.go:30 circuit {A: 2, B: 11}\n")
- expected.WriteString("debug_test.go:34 m \n")
+ expected.WriteString("debug_test.go:34 m .*\n")
{
trace, _ := getGroth16Trace(&circuit, &witness)
- assert.Equal(expected.String(), trace)
+ assert.Regexp(expected.String(), trace)
}
{
trace, _ := getPlonkTrace(&circuit, &witness)
- assert.Equal(expected.String(), trace)
+ assert.Regexp(expected.String(), trace)
}
}
diff --git a/examples/rollup/account.go b/examples/rollup/account.go
index f381a968fd..ad6b92b48c 100644
--- a/examples/rollup/account.go
+++ b/examples/rollup/account.go
@@ -25,7 +25,7 @@ import (
var (
// SizeAccount byte size of a serialized account (5*32bytes)
- // index || nonce || balance || pubkeyX || pubkeyY, each chunk is 32 bytes
+ // index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 32 bytes
SizeAccount = 160
)
@@ -48,7 +48,7 @@ func (ac *Account) Reset() {
// Serialize serializes the account as a concatenation of 5 chunks of 256 bits
// one chunk per field (pubKey has 2 chunks), except index and nonce that are concatenated in a single 256 bits chunk
-// index || nonce || balance || pubkeyX || pubkeyY, each chunk is 256 bits
+// index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 256 bits
func (ac *Account) Serialize() []byte {
//var buffer bytes.Buffer
diff --git a/examples/rollup/circuit.go b/examples/rollup/circuit.go
index 2907f7fc19..83f554c2b9 100644
--- a/examples/rollup/circuit.go
+++ b/examples/rollup/circuit.go
@@ -159,7 +159,7 @@ func (circuit *Circuit) Define(api frontend.API) error {
// verifySignatureTransfer ensures that the signature of the transfer is valid
func verifyTransferSignature(api frontend.API, t TransferConstraints, hFunc mimc.MiMC) error {
- // the signature is on h(nonce || amount || senderpubKey (x&y) || receiverPubkey(x&y))
+ // the signature is on h(nonce ∥ amount ∥ senderpubKey (x&y) ∥ receiverPubkey(x&y))
hFunc.Write(t.Nonce, t.Amount, t.SenderPubKey.A.X, t.SenderPubKey.A.Y, t.ReceiverPubKey.A.X, t.ReceiverPubKey.A.Y)
htransfer := hFunc.Sum()
diff --git a/examples/rollup/operator.go b/examples/rollup/operator.go
index a02dc0c07d..56c2304182 100644
--- a/examples/rollup/operator.go
+++ b/examples/rollup/operator.go
@@ -46,8 +46,8 @@ func NewQueue(batchSize int) Queue {
// Operator represents a rollup operator
type Operator struct {
- State []byte // list of accounts: index || nonce || balance || pubkeyX || pubkeyY, each chunk is 256 bits
- HashState []byte // Hashed version of the state, each chunk is 256bits: ... || H(index || nonce || balance || pubkeyX || pubkeyY)) || ...
+ State []byte // list of accounts: index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY, each chunk is 256 bits
+ HashState []byte // Hashed version of the state, each chunk is 256bits: ... ∥ H(index ∥ nonce ∥ balance ∥ pubkeyX ∥ pubkeyY)) ∥ ...
AccountMap map[string]uint64 // hashmap of all available accounts (the key is the account.pubkey.X), the value is the index of the account in the state
nbAccounts int // number of accounts managed by this operator
h hash.Hash // hash function used to build the Merkle Tree
@@ -178,7 +178,7 @@ func (o *Operator) updateState(t Transfer, numTransfer int) error {
o.witnesses.Transfers[numTransfer].Signature.S = t.signature.S[:]
// verifying the signature. The msg is the hash (o.h) of the transfer
- // nonce || amount || senderpubKey(x&y) || receiverPubkey(x&y)
+ // nonce ∥ amount ∥ senderpubKey(x&y) ∥ receiverPubkey(x&y)
resSig, err := t.Verify(o.h)
if err != nil {
return err
diff --git a/examples/rollup/transfer.go b/examples/rollup/transfer.go
index e12776cbf8..7507b357d3 100644
--- a/examples/rollup/transfer.go
+++ b/examples/rollup/transfer.go
@@ -52,7 +52,7 @@ func (t *Transfer) Sign(priv eddsa.PrivateKey, h hash.Hash) (eddsa.Signature, er
//var frNonce, msg fr.Element
var frNonce fr.Element
- // serializing transfer. The signature is on h(nonce || amount || senderpubKey (x&y) || receiverPubkey(x&y))
+ // serializing transfer. The signature is on h(nonce ∥ amount ∥ senderpubKey (x&y) ∥ receiverPubkey(x&y))
// (each pubkey consist of 2 chunks of 256bits)
frNonce.SetUint64(t.nonce)
b := frNonce.Bytes()
@@ -91,7 +91,7 @@ func (t *Transfer) Verify(h hash.Hash) (bool, error) {
var frNonce fr.Element
// serializing transfer. The msg to sign is
- // nonce || amount || senderpubKey(x&y) || receiverPubkey(x&y)
+ // nonce ∥ amount ∥ senderpubKey(x&y) ∥ receiverPubkey(x&y)
// (each pubkey consist of 2 chunks of 256bits)
frNonce.SetUint64(t.nonce)
b := frNonce.Bytes()
diff --git a/frontend/api.go b/frontend/api.go
index bc34ea6202..3e7fde87c0 100644
--- a/frontend/api.go
+++ b/frontend/api.go
@@ -101,7 +101,7 @@ type API interface {
// AssertIsDifferent fails if i1 == i2
AssertIsDifferent(i1, i2 Variable)
- // AssertIsBoolean fails if v != 0 || v != 1
+	// AssertIsBoolean fails if v is neither 0 nor 1
AssertIsBoolean(i1 Variable)
// AssertIsLessOrEqual fails if v > bound
diff --git a/frontend/compile.go b/frontend/compile.go
index 0fff033589..4565063d0f 100644
--- a/frontend/compile.go
+++ b/frontend/compile.go
@@ -42,8 +42,8 @@ type NewBuilder func(ecc.ID) (Builder, error)
// from the declarative code
//
// 3. finally, it converts that to a ConstraintSystem.
-// if zkpID == backend.GROTH16 --> R1CS
-// if zkpID == backend.PLONK --> SparseR1CS
+// if zkpID == backend.GROTH16 → R1CS
+// if zkpID == backend.PLONK → SparseR1CS
//
// initialCapacity is an optional parameter that reserves memory in slices
// it should be set to the estimated number of constraints in the circuit, if known.
diff --git a/frontend/cs/plonk/api.go b/frontend/cs/plonk/api.go
index 30e0d4c289..117dc7744a 100644
--- a/frontend/cs/plonk/api.go
+++ b/frontend/cs/plonk/api.go
@@ -120,7 +120,7 @@ func (system *sparseR1CS) DivUnchecked(i1, i2 frontend.Variable) frontend.Variab
q := system.CurveID.Info().Fr.Modulus()
return r.ModInverse(&r, q).
Mul(&l, &r).
- Mod(&l, q)
+ Mod(&r, q)
}
if system.IsConstant(i2) {
c := utils.FromInterface(i2)
diff --git a/frontend/cs/plonk/assertions.go b/frontend/cs/plonk/assertions.go
index 5f9cebdd33..efe58e8627 100644
--- a/frontend/cs/plonk/assertions.go
+++ b/frontend/cs/plonk/assertions.go
@@ -63,7 +63,7 @@ func (system *sparseR1CS) AssertIsDifferent(i1, i2 frontend.Variable) {
system.Inverse(system.Sub(i1, i2))
}
-// AssertIsBoolean fails if v != 0 || v != 1
+// AssertIsBoolean fails if v is neither 0 nor 1
func (system *sparseR1CS) AssertIsBoolean(i1 frontend.Variable) {
if system.IsConstant(i1) {
c := utils.FromInterface(i1)
@@ -125,7 +125,7 @@ func (system *sparseR1CS) mustBeLessOrEqVar(a compiled.Term, bound compiled.Term
l := system.Sub(1, t, aBits[i])
// note if bound[i] == 1, this constraint is (1 - ai) * ai == 0
- // --> this is a boolean constraint
+ // → this is a boolean constraint
// if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too
system.markBoolean(aBits[i].(compiled.Term)) // this does not create a constraint
@@ -172,7 +172,7 @@ func (system *sparseR1CS) mustBeLessOrEqCst(a compiled.Term, bound big.Int) {
}
p := make([]frontend.Variable, nbBits+1)
- // p[i] == 1 --> a[j] == c[j] for all j >= i
+ // p[i] == 1 → a[j] == c[j] for all j ⩾ i
p[nbBits] = 1
for i := nbBits - 1; i >= t; i-- {
diff --git a/frontend/cs/plonk/conversion.go b/frontend/cs/plonk/conversion.go
index b82fd4b7a0..1f2e8d564e 100644
--- a/frontend/cs/plonk/conversion.go
+++ b/frontend/cs/plonk/conversion.go
@@ -116,6 +116,9 @@ HINTLOOP:
}
res.MHints = shiftedMap
+ // build levels
+ res.Levels = buildLevels(res)
+
switch cs.CurveID {
case ecc.BLS12_377:
return bls12377r1cs.NewSparseR1CS(res, cs.Coeffs), nil
@@ -138,3 +141,103 @@ HINTLOOP:
func (cs *sparseR1CS) SetSchema(s *schema.Schema) {
cs.Schema = s
}
+
+func buildLevels(ccs compiled.SparseR1CS) [][]int {
+
+ b := levelBuilder{
+ mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire
+ nodeLevels: make([]int, len(ccs.Constraints)), // level of a node
+ mLevels: make(map[int]int), // level counts
+ ccs: ccs,
+ nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables,
+ }
+
+ // for each constraint, we're going to find its direct dependencies
+ // that is, wires (solved by previous constraints) on which it depends
+ // each of these dependencies is tagged with a level
+ // current constraint will be tagged with max(level) + 1
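+	//
+	// illustrative example: if constraints c0 and c1 only read input wires they
+	// land in level 0; a constraint c2 reading a wire solved by c0 and a wire
+	// solved by c1 lands in level 1. All constraints within a level are
+	// independent and can be solved in parallel.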
+ for cID, c := range ccs.Constraints {
+
+ b.nodeLevel = 0
+
+ b.processTerm(c.L, cID)
+ b.processTerm(c.R, cID)
+ b.processTerm(c.O, cID)
+
+ b.nodeLevels[cID] = b.nodeLevel
+ b.mLevels[b.nodeLevel]++
+
+ }
+
+ levels := make([][]int, len(b.mLevels))
+ for i := 0; i < len(levels); i++ {
+ // allocate memory
+ levels[i] = make([]int, 0, b.mLevels[i])
+ }
+
+ for n, l := range b.nodeLevels {
+ levels[l] = append(levels[l], n)
+ }
+
+ return levels
+}
+
+type levelBuilder struct {
+ ccs compiled.SparseR1CS
+ nbInputs int
+
+ mWireToNode map[int]int // at which node we resolved which wire
+ nodeLevels []int // level per node
+	mLevels     map[int]int // number of constraints per level
+
+ nodeLevel int // current level
+}
+
+func (b *levelBuilder) processTerm(t compiled.Term, cID int) {
+ wID := t.WireID()
+ if wID < b.nbInputs {
+		// it's an input, we ignore it
+ return
+ }
+
+	// if we know which constraint solves this wire, then it's a dependency
+ n, ok := b.mWireToNode[wID]
+ if ok {
+ if n != cID { // can happen with hints...
+ // we add a dependency, check if we need to increment our current level
+ if b.nodeLevels[n] >= b.nodeLevel {
+ b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it
+ }
+ }
+ return
+ }
+
+ // check if it's a hint and mark all the output wires
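+	// all of the hint's output wires are attributed to the first constraint
+	// that references one of them, and the hint's inputs are treated as
+	// dependencies of that constraint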
+ if h, ok := b.ccs.MHints[wID]; ok {
+
+ for _, in := range h.Inputs {
+ switch t := in.(type) {
+ case compiled.Variable:
+ for _, tt := range t.LinExp {
+ b.processTerm(tt, cID)
+ }
+ case compiled.LinearExpression:
+ for _, tt := range t {
+ b.processTerm(tt, cID)
+ }
+ case compiled.Term:
+ b.processTerm(t, cID)
+ }
+ }
+
+ for _, hwid := range h.Wires {
+ b.mWireToNode[hwid] = cID
+ }
+
+ return
+ }
+
+ // mark this wire solved by current node
+ b.mWireToNode[wID] = cID
+
+}
diff --git a/frontend/cs/plonk/sparse_r1cs.go b/frontend/cs/plonk/sparse_r1cs.go
index 959d28c55d..2e4beac548 100644
--- a/frontend/cs/plonk/sparse_r1cs.go
+++ b/frontend/cs/plonk/sparse_r1cs.go
@@ -130,7 +130,7 @@ func (system *sparseR1CS) NewSecretVariable(name string) frontend.Variable {
// reduces redundancy in linear expression
// It factorizes Variable that appears multiple times with != coeff Ids
-// To ensure the determinism in the compile process, Variables are stored as public||secret||internal||unset
+// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset
// for each visibility, the Variables are sorted from lowest ID to highest ID
func (system *sparseR1CS) reduce(l compiled.LinearExpression) compiled.LinearExpression {
@@ -268,7 +268,7 @@ func (system *sparseR1CS) CheckVariables() error {
sbb.WriteString(strconv.Itoa(cptHints))
sbb.WriteString(" unconstrained hints")
sbb.WriteByte('\n')
- // TODO we may add more debug info here --> idea, in NewHint, take the debug stack, and store in the hint map some
+ // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some
// debugInfo to find where a hint was declared (and not constrained)
}
return errors.New(sbb.String())
diff --git a/frontend/cs/r1cs/assertions.go b/frontend/cs/r1cs/assertions.go
index 22b3873a42..aba87440c0 100644
--- a/frontend/cs/r1cs/assertions.go
+++ b/frontend/cs/r1cs/assertions.go
@@ -41,7 +41,7 @@ func (system *r1CS) AssertIsDifferent(i1, i2 frontend.Variable) {
system.Inverse(system.Sub(i1, i2))
}
-// AssertIsBoolean adds an assertion in the constraint system (v == 0 || v == 1)
+// AssertIsBoolean adds an assertion in the constraint system (v == 0 ∨ v == 1)
func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) {
vars, _ := system.toVariables(i1)
@@ -69,7 +69,7 @@ func (system *r1CS) AssertIsBoolean(i1 frontend.Variable) {
system.addConstraint(newR1C(v, _v, o), debug)
}
-// AssertIsLessOrEqual adds assertion in constraint system (v <= bound)
+// AssertIsLessOrEqual adds an assertion in the constraint system (v ⩽ bound)
//
// bound can be a constant or a Variable
//
@@ -120,7 +120,7 @@ func (system *r1CS) mustBeLessOrEqVar(a, bound compiled.Variable) {
l = system.Sub(l, t, aBits[i])
// note if bound[i] == 1, this constraint is (1 - ai) * ai == 0
- // --> this is a boolean constraint
+ // → this is a boolean constraint
// if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too
system.markBoolean(aBits[i].(compiled.Variable)) // this does not create a constraint
@@ -158,7 +158,7 @@ func (system *r1CS) mustBeLessOrEqCst(a compiled.Variable, bound big.Int) {
}
p := make([]frontend.Variable, nbBits+1)
- // p[i] == 1 --> a[j] == c[j] for all j >= i
+ // p[i] == 1 → a[j] == c[j] for all j ⩾ i
p[nbBits] = system.constant(1)
for i := nbBits - 1; i >= t; i-- {
diff --git a/frontend/cs/r1cs/conversion.go b/frontend/cs/r1cs/conversion.go
index 5b042efc9b..c71d520e9d 100644
--- a/frontend/cs/r1cs/conversion.go
+++ b/frontend/cs/r1cs/conversion.go
@@ -122,13 +122,15 @@ HINTLOOP:
}
}
for i := 0; i < len(cs.DebugInfo); i++ {
-
for j := 0; j < len(res.DebugInfo[i].ToResolve); j++ {
_, vID, visibility := res.DebugInfo[i].ToResolve[j].Unpack()
res.DebugInfo[i].ToResolve[j].SetWireID(shiftVID(vID, visibility))
}
}
+ // build levels
+ res.Levels = buildLevels(res)
+
switch cs.CurveID {
case ecc.BLS12_377:
return bls12377r1cs.NewR1CS(res, cs.Coeffs), nil
@@ -150,3 +152,99 @@ HINTLOOP:
func (cs *r1CS) SetSchema(s *schema.Schema) {
cs.Schema = s
}
+
+func buildLevels(ccs compiled.R1CS) [][]int {
+
+ b := levelBuilder{
+ mWireToNode: make(map[int]int, ccs.NbInternalVariables), // at which node we resolved which wire
+ nodeLevels: make([]int, len(ccs.Constraints)), // level of a node
+ mLevels: make(map[int]int), // level counts
+ ccs: ccs,
+ nbInputs: ccs.NbPublicVariables + ccs.NbSecretVariables,
+ }
+
+ // for each constraint, we're going to find its direct dependencies
+ // that is, wires (solved by previous constraints) on which it depends
+ // each of these dependencies is tagged with a level
+ // current constraint will be tagged with max(level) + 1
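+	//
+	// the resulting levels are stored in res.Levels and consumed by the solver:
+	// constraints within a level are solved in parallel, while the levels
+	// themselves are processed sequentially (see parallelSolve)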
+ for cID, c := range ccs.Constraints {
+
+ b.nodeLevel = 0
+
+ b.processLE(c.L.LinExp, cID)
+ b.processLE(c.R.LinExp, cID)
+ b.processLE(c.O.LinExp, cID)
+ b.nodeLevels[cID] = b.nodeLevel
+ b.mLevels[b.nodeLevel]++
+
+ }
+
+ levels := make([][]int, len(b.mLevels))
+ for i := 0; i < len(levels); i++ {
+ // allocate memory
+ levels[i] = make([]int, 0, b.mLevels[i])
+ }
+
+ for n, l := range b.nodeLevels {
+ levels[l] = append(levels[l], n)
+ }
+
+ return levels
+}
+
+type levelBuilder struct {
+ ccs compiled.R1CS
+ nbInputs int
+
+ mWireToNode map[int]int // at which node we resolved which wire
+ nodeLevels []int // level per node
+	mLevels     map[int]int // number of constraints per level
+
+ nodeLevel int // current level
+}
+
+func (b *levelBuilder) processLE(l compiled.LinearExpression, cID int) {
+
+ for _, t := range l {
+ wID := t.WireID()
+ if wID < b.nbInputs {
+			// it's an input, we ignore it
+ continue
+ }
+
+		// if we know which constraint solves this wire, then it's a dependency
+ n, ok := b.mWireToNode[wID]
+ if ok {
+ if n != cID { // can happen with hints...
+ // we add a dependency, check if we need to increment our current level
+ if b.nodeLevels[n] >= b.nodeLevel {
+ b.nodeLevel = b.nodeLevels[n] + 1 // we are at the next level at least since we depend on it
+ }
+ }
+ continue
+ }
+
+ // check if it's a hint and mark all the output wires
+ if h, ok := b.ccs.MHints[wID]; ok {
+
+ for _, in := range h.Inputs {
+ switch t := in.(type) {
+ case compiled.Variable:
+ b.processLE(t.LinExp, cID)
+ case compiled.LinearExpression:
+ b.processLE(t, cID)
+ case compiled.Term:
+ b.processLE(compiled.LinearExpression{t}, cID)
+ }
+ }
+
+ for _, hwid := range h.Wires {
+ b.mWireToNode[hwid] = cID
+ }
+ continue
+ }
+
+ // mark this wire solved by current node
+ b.mWireToNode[wID] = cID
+ }
+}
diff --git a/frontend/cs/r1cs/r1cs.go b/frontend/cs/r1cs/r1cs.go
index def4e35901..eb22b890b8 100644
--- a/frontend/cs/r1cs/r1cs.go
+++ b/frontend/cs/r1cs/r1cs.go
@@ -146,7 +146,7 @@ func (system *r1CS) one() compiled.Variable {
// reduces redundancy in linear expression
// It factorizes Variable that appears multiple times with != coeff Ids
-// To ensure the determinism in the compile process, Variables are stored as public||secret||internal||unset
+// To ensure the determinism in the compile process, Variables are stored as public∥secret∥internal∥unset
// for each visibility, the Variables are sorted from lowest ID to highest ID
func (system *r1CS) reduce(l compiled.Variable) compiled.Variable {
// ensure our linear expression is sorted, by visibility and by Variable ID
@@ -311,7 +311,7 @@ func (system *r1CS) CheckVariables() error {
sbb.WriteString(strconv.Itoa(cptHints))
sbb.WriteString(" unconstrained hints")
sbb.WriteByte('\n')
- // TODO we may add more debug info here --> idea, in NewHint, take the debug stack, and store in the hint map some
+ // TODO we may add more debug info here → idea, in NewHint, take the debug stack, and store in the hint map some
// debugInfo to find where a hint was declared (and not constrained)
}
return errors.New(sbb.String())
diff --git a/go.mod b/go.mod
index 8ae200dc61..deb987101e 100644
--- a/go.mod
+++ b/go.mod
@@ -3,8 +3,8 @@ module github.com/consensys/gnark
go 1.17
require (
- github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a
- github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969
+ github.com/consensys/bavard v0.1.9
+ github.com/consensys/gnark-crypto v0.6.1
github.com/fxamacker/cbor/v2 v2.2.0
github.com/leanovate/gopter v0.2.9
github.com/stretchr/testify v1.7.0
diff --git a/go.sum b/go.sum
index 6b59c42b75..dca7119dbb 100644
--- a/go.sum
+++ b/go.sum
@@ -1,7 +1,7 @@
-github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a h1:AEpwbXTjBGKoqxuQ6QAcBMEuK0+PtajQj0wJkhTnSd0=
-github.com/consensys/bavard v0.1.8-0.20210915155054-088da2f7f54a/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI=
-github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969 h1:SPKwScbSTdl2p+QvJulHumGa2+5FO6RPh857TCPxda0=
-github.com/consensys/gnark-crypto v0.6.1-0.20220203133229-a70fdc7da969/go.mod h1:PicAZJP763+7N9LZFfj+MquTXq98pwjD6l8Ry8WdHSU=
+github.com/consensys/bavard v0.1.9 h1:t9wg3/7Ko73yE+eKcavgMYcPMO1hinadJGlbSCdXTiM=
+github.com/consensys/bavard v0.1.9/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI=
+github.com/consensys/gnark-crypto v0.6.1 h1:MuWaJyWzSw8wQUOfiZOlRwYjfweIj8dM/u2NN6m0O04=
+github.com/consensys/gnark-crypto v0.6.1/go.mod h1:s41Bl3YIpNgu/zdvlSzf/xZkyV8MUmoBY96RmuB8x70=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
diff --git a/internal/backend/bls12-377/cs/r1cs.go b/internal/backend/bls12-377/cs/r1cs.go
index 680941e51e..1313d715e4 100644
--- a/internal/backend/bls12-377/cs/r1cs.go
+++ b/internal/backend/bls12-377/cs/r1cs.go
@@ -19,11 +19,12 @@ package cs
import (
"errors"
"fmt"
+ "github.com/fxamacker/cbor/v2"
"io"
"math/big"
+ "runtime"
"strings"
-
- "github.com/fxamacker/cbor/v2"
+ "sync"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
@@ -32,6 +33,7 @@ import (
"github.com/consensys/gnark/internal/backend/ioutils"
"github.com/consensys/gnark-crypto/ecc"
+ "math"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
@@ -70,11 +72,6 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) (
return make([]fr.Element, nbWires), err
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
if len(witness) != int(cs.NbPublicVariables-1+cs.NbSecretVariables) { // - 1 for ONE_WIRE
return solution.values, fmt.Errorf("invalid witness size, got %d, expected %d = %d (public - ONE_WIRE) + %d (secret)", len(witness), int(cs.NbPublicVariables-1+cs.NbSecretVariables), cs.NbPublicVariables-1, cs.NbSecretVariables)
}
@@ -86,61 +83,146 @@ func (cs *R1CS) Solve(witness, a, b, c []fr.Element, opt backend.ProverConfig) (
solution.solved[0] = true // ONE_WIRE
solution.values[0].SetOne()
- copy(solution.values[1:], witness) // TODO factorize
+ copy(solution.values[1:], witness)
for i := 0; i < len(witness); i++ {
solution.solved[i+1] = true
}
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness) + 1
+ solution.nbSolved += uint64(len(witness) + 1)
// now that we know all inputs are set, defer log printing once all solution.values are computed
// (or sooner, if a constraint is not satisfied)
defer solution.printLogs(opt.LoggerOut, cs.Logs)
- // check if there is an inconsistant constraint
- var check fr.Element
- var solved bool
+ if err := cs.parallelSolve(a, b, c, &solution); err != nil {
+ return solution.values, err
+ }
+
+ // sanity check; ensure all wires are marked as "instantiated"
+ if !solution.isValid() {
+ panic("solver didn't instantiate all wires")
+ }
+
+ return solution.values, nil
+}
+
+func (cs *R1CS) parallelSolve(a, b, c []fr.Element, solution *solution) error {
+	// minWorkPerCPU is the minimum target number of constraints a task should hold;
+	// in other words, if a level has fewer than minWorkPerCPU constraints, it is not
+	// parallelized and is executed sequentially, without synchronization.
+ const minWorkPerCPU = 50.0
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
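+	//
+	// scheduling sketch: each level is split into at most NumCPU contiguous
+	// tasks; workers pull tasks from chTasks and the main goroutine waits on wg
+	// before moving to the next level. Levels with fewer than minWorkPerCPU
+	// constraints are solved sequentially, without going through the pool.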
// for each constraint
// we are guaranteed that each R1C contains at most one unsolved wire
// first we solve the unsolved wire (if any)
// then we check that the constraint is valid
// if a[i] * b[i] != c[i]; it means the constraint is not satisfied
- for i := 0; i < len(cs.Constraints); i++ {
- // solve the constraint, this will compute the missing wire of the gate
- solved, a[i], b[i], c[i], err = cs.solveConstraint(cs.Constraints[i], &solution)
- if err != nil {
- if dID, ok := cs.MDebug[i]; ok {
- debugInfoStr := solution.logValue(cs.DebugInfo[dID])
- return solution.values, fmt.Errorf("%w: %s", err, debugInfoStr)
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
+
+ // start a worker pool
+	// each worker waits on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil {
+ if err == errUnsatisfiedConstraint {
+ if dID, ok := cs.MDebug[int(i)]; ok {
+ err = errors.New(solution.logValue(cs.DebugInfo[dID]))
+ } else {
+ err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String())
+ }
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
}
- return solution.values, err
- }
+ }()
+ }
+
+	// clean up the worker goroutines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
- if solved {
- // a[i] * b[i] == c[i], since we just computed it.
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, &a[i], &b[i], &c[i]); err != nil {
+ if err == errUnsatisfiedConstraint {
+ if dID, ok := cs.MDebug[int(i)]; ok {
+ err = errors.New(solution.logValue(cs.DebugInfo[dID]))
+ } else {
+ err = fmt.Errorf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String())
+ }
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ }
continue
}
- // ensure a[i] * b[i] == c[i]
- check.Mul(&a[i], &b[i])
- if !check.Equal(&c[i]) {
- errMsg := fmt.Sprintf("%s ⋅ %s != %s", a[i].String(), b[i].String(), c[i].String())
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
+		// the number of tasks for this level defaults to the number of CPUs,
+		// but if there is not enough work for all of them, it can be lower
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
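+		// illustrative example: len(level) = 103 and nbTasks = 4 give
+		// nbIterationsPerCpus = 25 and extraTasks = 3, so the first three tasks
+		// hold 26 constraints and the last one holds 25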
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
}
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
}
- }
- // sanity check; ensure all wires are marked as "instantiated"
- if !solution.isValid() {
- panic("solver didn't instantiate all wires")
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
}
- return solution.values, nil
+ return nil
}
// IsSolved returns nil if given witness solves the R1CS and error otherwise
@@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) {
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) {
+func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error {
// the index of the non zero entry shows if L, R or O has an uninstantiated wire
// the content is the ID of the wire non instantiated
@@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
return nil
}
- if err = processLExp(r.L.LinExp, &a, 1); err != nil {
- return
+ if err := processLExp(r.L.LinExp, a, 1); err != nil {
+ return err
}
- if err = processLExp(r.R.LinExp, &b, 2); err != nil {
- return
+ if err := processLExp(r.R.LinExp, b, 2); err != nil {
+ return err
}
- if err = processLExp(r.O.LinExp, &c, 3); err != nil {
- return
+ if err := processLExp(r.O.LinExp, c, 3); err != nil {
+ return err
}
if loc == 0 {
// there is nothing to solve, may happen if we have an assertion
// (ie a constraints that doesn't yield any output)
// or if we solved the unsolved wires with hint functions
- return
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
+ return nil
}
// we compute the wire value and instantiate it
- solved = true
- vID := termToCompute.WireID()
+ wID := termToCompute.WireID()
// solver result
var wire fr.Element
@@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
switch loc {
case 1:
if !b.IsZero() {
- wire.Div(&c, &b).
- Sub(&wire, &a)
- a.Add(&a, &wire)
+ wire.Div(c, b).
+ Sub(&wire, a)
+ a.Add(a, &wire)
} else {
// we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 2:
if !a.IsZero() {
- wire.Div(&c, &a).
- Sub(&wire, &b)
- b.Add(&b, &wire)
+ wire.Div(c, a).
+ Sub(&wire, b)
+ b.Add(b, &wire)
} else {
- // we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 3:
- wire.Mul(&a, &b).
- Sub(&wire, &c)
+ wire.Mul(a, b).
+ Sub(&wire, c)
- c.Add(&c, &wire)
+ c.Add(c, &wire)
}
// wire is the term (coeff * value)
// but in the solution we want to store the value only
// note that in gnark frontend, coeff here is always 1 or -1
cs.divByCoeff(&wire, termToCompute)
- solution.set(vID, wire)
+ solution.set(wID, wire)
- return
+ return nil
}
// GetConstraints return a list of constraint formatted as L⋅R == O
diff --git a/internal/backend/bls12-377/cs/r1cs_sparse.go b/internal/backend/bls12-377/cs/r1cs_sparse.go
index e52a55459b..8c5d2c087b 100644
--- a/internal/backend/bls12-377/cs/r1cs_sparse.go
+++ b/internal/backend/bls12-377/cs/r1cs_sparse.go
@@ -21,9 +21,12 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/fxamacker/cbor/v2"
"io"
+ "math"
"math/big"
"os"
+ "runtime"
"strings"
+ "sync"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
@@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
return solution.values, err
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
// solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs
copy(solution.values, witness)
for i := 0; i < len(witness); i++ {
@@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness)
+ solution.nbSolved += uint64(len(witness))
// defer log printing once all solution.values are computed
defer solution.printLogs(opt.LoggerOut, cs.Logs)
@@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
coefficientsNegInv[i].Neg(&coefficientsNegInv[i])
}
- // loop through the constraints to solve the variables
- for i := 0; i < len(cs.Constraints); i++ {
- if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil {
- return solution.values, fmt.Errorf("constraint %d: %w", i, err)
- }
- if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil {
- errMsg := err.Error()
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
- }
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
- }
+ if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil {
+ return solution.values, err
}
// sanity check; ensure all wires are marked as "instantiated"
@@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
+func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error {
+	// minWorkPerCPU is the minimum target number of constraints a task should hold;
+	// in other words, if a level has fewer than minWorkPerCPU constraints, it is not
+	// parallelized and is executed sequentially, without synchronization.
+ const minWorkPerCPU = 50.0
+
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
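+	// chError is buffered so that a failing worker can report its error and
+	// return without blocking; after a level completes (wg.Wait), the first
+	// buffered error, if any, is returned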
+
+ // start a worker pool
+	// each worker waits on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+	// clean up the worker goroutines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
+
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ }
+ }
+ continue
+ }
+
+		// the number of tasks for this level defaults to the number of CPUs,
+		// but if there is not enough work for all of them, it can be lower
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
+ }
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
+ }
+
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
+ }
+
+ return nil
+}
+
// computeHints computes wires associated with a hint function, if any
// if there is no remaining wire to solve, returns -1
// else returns the wire position (L -> 0, R -> 1, O -> 2)
diff --git a/internal/backend/bls12-377/cs/solution.go b/internal/backend/bls12-377/cs/solution.go
index 2cf9bc935e..6911e72ab7 100644
--- a/internal/backend/bls12-377/cs/solution.go
+++ b/internal/backend/bls12-377/cs/solution.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"math/big"
+ "sync/atomic"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend/schema"
@@ -32,14 +33,15 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bls12-377"
)
+var errUnsatisfiedConstraint = errors.New("unsatisfied")
+
// solution represents elements needed to compute
// a solution to a R1CS or SparseR1CS
type solution struct {
values, coefficients []fr.Element
solved []bool
- nbSolved int
+ nbSolved uint64
mHintsFunctions map[hint.ID]hint.Function
- tmpHintsIO []*big.Int
}
func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) {
@@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E
coefficients: coefficients,
solved: make([]bool, nbWires),
mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)),
- tmpHintsIO: make([]*big.Int, 0),
}
for _, h := range hintFunctions {
@@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) {
}
s.values[id] = value
s.solved[id] = true
- s.nbSolved++
+	// set may be called concurrently by the solver workers, so the solved-wire
+	// counter is updated atomically
+	atomic.AddUint64(&s.nbSolved, 1)
}
func (s *solution) isValid() bool {
- return s.nbSolved == len(s.values)
+ return int(s.nbSolved) == len(s.values)
}
// computeTerm computes coef*variable
@@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error {
// tmp IO big int memory
nbInputs := len(h.Inputs)
nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs))
- m := len(s.tmpHintsIO)
- if m < (nbInputs + nbOutputs) {
- s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...)
- for i := m; i < len(s.tmpHintsIO); i++ {
- s.tmpHintsIO[i] = big.NewInt(0)
- }
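+	// inputs and outputs are now allocated per call: the shared tmpHintsIO
+	// scratch buffer was removed because it is no longer safe to share once
+	// constraints (and their hints) may be solved concurrently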
+ inputs := make([]*big.Int, nbInputs)
+ outputs := make([]*big.Int, nbOutputs)
+ for i := 0; i < nbInputs; i++ {
+ inputs[i] = big.NewInt(0)
+ }
+ for i := 0; i < nbOutputs; i++ {
+ outputs[i] = big.NewInt(0)
}
- inputs := s.tmpHintsIO[:nbInputs]
- outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs]
q := fr.Modulus()
diff --git a/internal/backend/bls12-377/groth16/marshal_test.go b/internal/backend/bls12-377/groth16/marshal_test.go
index 0d00cba75e..901d6f885f 100644
--- a/internal/backend/bls12-377/groth16/marshal_test.go
+++ b/internal/backend/bls12-377/groth16/marshal_test.go
@@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) {
var pk, pkCompressed, pkRaw ProvingKey
// create a random pk
- domain := fft.NewDomain(8, 1, true)
+ domain := fft.NewDomain(8)
pk.Domain = *domain
nbWires := 6
diff --git a/internal/backend/bls12-377/groth16/prove.go b/internal/backend/bls12-377/groth16/prove.go
index 16a1b7ac32..634c4b58e7 100644
--- a/internal/backend/bls12-377/groth16/prove.go
+++ b/internal/backend/bls12-377/groth16/prove.go
@@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
c = append(c, padding...)
n = len(a)
- domain.FFTInverse(a, fft.DIF, 0)
- domain.FFTInverse(b, fft.DIF, 0)
- domain.FFTInverse(c, fft.DIF, 0)
+ domain.FFTInverse(a, fft.DIF)
+ domain.FFTInverse(b, fft.DIF)
+ domain.FFTInverse(c, fft.DIF)
- domain.FFT(a, fft.DIT, 1)
- domain.FFT(b, fft.DIT, 1)
- domain.FFT(c, fft.DIT, 1)
+ domain.FFT(a, fft.DIT, true)
+ domain.FFT(b, fft.DIT, true)
+ domain.FFT(c, fft.DIT, true)
- var minusTwoInv fr.Element
- minusTwoInv.SetUint64(2)
- minusTwoInv.Neg(&minusTwoInv).
- Inverse(&minusTwoInv)
+ var den, one fr.Element
+ one.SetOne()
+ den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
+ den.Sub(&den, &one).Inverse(&den)
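+	// den = 1/(g^n - 1), with g = FrMultiplicativeGen (the coset generator) and
+	// n the domain cardinality: on the coset the vanishing polynomial x^n - 1
+	// takes the constant value g^n - 1, so multiplying by den divides by it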
// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unecessary memalloc
@@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
- Mul(&a[i], &minusTwoInv)
+ Mul(&a[i], &den)
}
})
// ifft_coset
- domain.FFTInverse(a, fft.DIF, 1)
+ domain.FFTInverse(a, fft.DIF, true)
utils.Parallelize(len(a), func(start, end int) {
for i := start; i < end; i++ {
diff --git a/internal/backend/bls12-377/groth16/setup.go b/internal/backend/bls12-377/groth16/setup.go
index 8db0c6ac8c..21fe78713c 100644
--- a/internal/backend/bls12-377/groth16/setup.go
+++ b/internal/backend/bls12-377/groth16/setup.go
@@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error {
nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables
// Setting group for fft
- domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true)
+ domain := fft.NewDomain(uint64(len(r1cs.Constraints)))
// samples toxic waste
toxicWaste, err := sampleToxicWaste()
@@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error {
nbConstraints := len(r1cs.Constraints)
// Setting group for fft
- domain := fft.NewDomain(uint64(nbConstraints), 1, true)
+ domain := fft.NewDomain(uint64(nbConstraints))
// count number of infinity points we would have had we a normal setup
// in pk.G1.A, pk.G1.B, and pk.G2.B
diff --git a/internal/backend/bls12-377/plonk/marshal.go b/internal/backend/bls12-377/plonk/marshal.go
index 411a2be9e7..2bb46eabd3 100644
--- a/internal/backend/bls12-377/plonk/marshal.go
+++ b/internal/backend/bls12-377/plonk/marshal.go
@@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
}
// fft domains
- n2, err := pk.DomainNum.WriteTo(w)
+ n2, err := pk.Domain[0].WriteTo(w)
if err != nil {
return
}
n += n2
- n2, err = pk.DomainH.WriteTo(w)
+ n2, err = pk.Domain[1].WriteTo(w)
if err != nil {
return
}
n += n2
- // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality)
- if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) {
+ // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality)
+ if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) {
return n, errors.New("invalid permutation size, expected 3*domain cardinality")
}
@@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
([]fr.Element)(pk.Qo),
([]fr.Element)(pk.CQk),
([]fr.Element)(pk.LQk),
- ([]fr.Element)(pk.LS1),
- ([]fr.Element)(pk.LS2),
- ([]fr.Element)(pk.LS3),
- ([]fr.Element)(pk.CS1),
- ([]fr.Element)(pk.CS2),
- ([]fr.Element)(pk.CS3),
+ ([]fr.Element)(pk.S1Canonical),
+ ([]fr.Element)(pk.S2Canonical),
+ ([]fr.Element)(pk.S3Canonical),
pk.Permutation,
}
@@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
- n2, err := pk.DomainNum.ReadFrom(r)
+ n2, err := pk.Domain[0].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- n2, err = pk.DomainH.ReadFrom(r)
+ n2, err = pk.Domain[1].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
dec := curve.NewDecoder(r)
toDecode := []interface{}{
@@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
(*[]fr.Element)(&pk.Qo),
(*[]fr.Element)(&pk.CQk),
(*[]fr.Element)(&pk.LQk),
- (*[]fr.Element)(&pk.LS1),
- (*[]fr.Element)(&pk.LS2),
- (*[]fr.Element)(&pk.LS3),
- (*[]fr.Element)(&pk.CS1),
- (*[]fr.Element)(&pk.CS2),
- (*[]fr.Element)(&pk.CS3),
+ (*[]fr.Element)(&pk.S1Canonical),
+ (*[]fr.Element)(&pk.S2Canonical),
+ (*[]fr.Element)(&pk.S3Canonical),
&pk.Permutation,
}
@@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
&vk.SizeInv,
&vk.Generator,
vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
&vk.SizeInv,
&vk.Generator,
&vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
diff --git a/internal/backend/bls12-377/plonk/marshal_test.go b/internal/backend/bls12-377/plonk/marshal_test.go
index 04933096c5..f4bea8c379 100644
--- a/internal/backend/bls12-377/plonk/marshal_test.go
+++ b/internal/backend/bls12-377/plonk/marshal_test.go
@@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
@@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) {
// random pk
var pk ProvingKey
pk.Vk = &vk
- pk.DomainNum = *fft.NewDomain(42, 3, false)
- pk.DomainH = *fft.NewDomain(4*42, 1, false)
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Domain[0] = *fft.NewDomain(42)
+ pk.Domain[1] = *fft.NewDomain(4 * 42)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < 12; i++ {
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) {
pk.Qo[i].SetUint64(42)
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
pk.Permutation[0] = -12
pk.Permutation[len(pk.Permutation)-1] = 8888
@@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
diff --git a/internal/backend/bls12-377/plonk/prove.go b/internal/backend/bls12-377/plonk/prove.go
index 7c8cb49ce7..74f54f4897 100644
--- a/internal/backend/bls12-377/plonk/prove.go
+++ b/internal/backend/bls12-377/plonk/prove.go
@@ -27,8 +27,6 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bls12-377"
- "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial"
-
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft"
@@ -43,6 +41,7 @@ import (
)
type Proof struct {
+
// Commitments to the solution vectors
LRO [3]kzg.Digest
@@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
hFunc := sha256.New()
// create a transcript manager to apply Fiat Shamir
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
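+	// the transcript now registers a "beta" challenge: beta and gamma both feed
+	// the permutation argument (computeBlindedZCanonical and
+	// evaluateOrderingDomainBigBitReversed)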
// result
proof := &Proof{}
@@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
}
// query l, r, o in Lagrange basis, not blinded
- ll, lr, lo := computeLRO(spr, pk, solution)
+ evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution)
// save ll, lr, lo, and make a copy of them in canonical basis.
// note that we allocate more capacity to reuse for blinded polynomials
- bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum)
+ blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ &pk.Domain[0])
if err != nil {
return nil, err
}
// compute kzg commitments of bcl, bcr and bco
- if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
return nil, err
}
+	// derive the "beta" challenge from the Fiat-Shamir transcript
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return nil, err
+ }
+
// compute Z, the permutation accumulator polynomial, in canonical basis
// ll, lr, lo are NOT blinded
- var bz polynomial.Polynomial
+ var blindedZCanonical []fr.Element
chZ := make(chan error, 1)
var alpha fr.Element
go func() {
var err error
- bz, err = computeBlindedZ(ll, lr, lo, pk, gamma)
+ blindedZCanonical, err = computeBlindedZCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ pk, beta, gamma)
if err != nil {
chZ <- err
close(chZ)
@@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
// this may add additional arithmetic operations, but with smaller tasks
// we ensure that this commitment is well parallelized, without having a "unbalanced task" making
// the rest of the code wait too long.
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
+ if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
chZ <- err
close(chZ)
return
@@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
}()
// evaluation of the blinded versions of l, r, o and bz
- // on the odd cosets of (Z/8mZ)/(Z/mZ)
- var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial
+ // on the coset of the big domain
+ var (
+ evaluationBlindedLDomainBigBitReversed []fr.Element
+ evaluationBlindedRDomainBigBitReversed []fr.Element
+ evaluationBlindedODomainBigBitReversed []fr.Element
+ evaluationBlindedZDomainBigBitReversed []fr.Element
+ )
chEvalBL := make(chan struct{}, 1)
chEvalBR := make(chan struct{}, 1)
chEvalBO := make(chan struct{}, 1)
go func() {
- evalBL = evaluateHDomain(bcl, &pk.DomainH)
+ evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1])
close(chEvalBL)
}()
go func() {
- evalBR = evaluateHDomain(bcr, &pk.DomainH)
+ evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1])
close(chEvalBR)
}()
go func() {
- evalBO = evaluateHDomain(bco, &pk.DomainH)
+ evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1])
close(chEvalBO)
}()
- var constraintsInd, constraintsOrdering polynomial.Polynomial
+ var constraintsInd, constraintsOrdering []fr.Element
chConstraintInd := make(chan struct{}, 1)
go func() {
// compute qk in canonical basis, completed with the public inputs
- qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality)
- copy(qk, fullWitness[:spr.NbPublicVariables])
- copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
- pk.DomainNum.FFTInverse(qk, fft.DIF, 0)
- fft.BitReverse(qk)
-
- // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ)
- // --> uses the blinded version of l, r, o
+ qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality)
+ copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables])
+ copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
+ pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF)
+ fft.BitReverse(qkCompletedCanonical)
+
+ // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain
+ // → uses the blinded version of l, r, o
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk)
+ constraintsInd = evaluateConstraintsDomainBigBitReversed(
+ pk,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ qkCompletedCanonical)
close(chConstraintInd)
}()
@@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
chConstraintOrdering <- err
return
}
- evalBZ = evaluateHDomain(bz, &pk.DomainH)
- // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ)
+
+ evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1])
+ // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain
// evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o.
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma)
+ constraintsOrdering = evaluateOrderingDomainBigBitReversed(
+ pk,
+ evaluationBlindedZDomainBigBitReversed,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ beta,
+ gamma)
chConstraintOrdering <- nil
close(chConstraintOrdering)
}()
@@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
if err := <-chConstraintOrdering; err != nil {
return nil, err
}
+
<-chConstraintInd
+
// compute h in canonical form
- h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha)
+ h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha)
// compute kzg commitments of h1, h2 and h3
- if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
var wgZetaEvals sync.WaitGroup
wgZetaEvals.Add(3)
go func() {
- blzeta = bcl.Eval(&zeta)
+ blzeta = eval(blindedLCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- brzeta = bcr.Eval(&zeta)
+ brzeta = eval(blindedRCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- bozeta = bco.Eval(&zeta)
+ bozeta = eval(blindedOCanonical, zeta)
wgZetaEvals.Done()
}()
@@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
var zetaShifted fr.Element
zetaShifted.Mul(&zeta, &pk.Vk.Generator)
proof.ZShiftedOpening, err = kzg.Open(
- bz,
- &zetaShifted,
- &pk.DomainH,
+ blindedZCanonical,
+ zetaShifted,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
bzuzeta := proof.ZShiftedOpening.ClaimedValue
var (
- linearizedPolynomial polynomial.Polynomial
- linearizedPolynomialDigest curve.G1Affine
- errLPoly error
+ linearizedPolynomialCanonical []fr.Element
+ linearizedPolynomialDigest curve.G1Affine
+ errLPoly error
)
chLpoly := make(chan struct{}, 1)
go func() {
// compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k)
wgZetaEvals.Wait()
- linearizedPolynomial = computeLinearizedPolynomial(
+ linearizedPolynomialCanonical = computeLinearizedPolynomial(
blzeta,
brzeta,
bozeta,
alpha,
+ beta,
gamma,
zeta,
bzuzeta,
- bz,
+ blindedZCanonical,
pk,
)
// TODO this commitment is only necessary to derive the challenge, we should
// be able to avoid doing it and get the challenge in another way
- linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS)
+ linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS)
close(chLpoly)
}()
- // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3)
+ // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3)
var bZetaPowerm, bSize big.Int
- bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
+ bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
var zetaPowerm fr.Element
zetaPowerm.Exp(zeta, &bSize)
zetaPowerm.ToBigIntRegular(&bZetaPowerm)
foldedHDigest := proof.H[2]
foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm)
- foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3)
- foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2)
- foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3)
+ foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1)
- // foldedH = h1 + zeta*h2 + zeta**2*h3
+ // foldedH = h1 + ζ*h2 + ζ²*h3
foldedH := h3
utils.Parallelize(len(foldedH), func(start, end int) {
for i := start; i < end; i++ {
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3
- foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1)
- foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3
+ foldedH[i].Add(&foldedH[i], &h2[i]) // ζᵐ⁺²*h3+h2
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺²
+ foldedH[i].Add(&foldedH[i], &h1[i]) // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2+h1
}
})
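
The folding loop is Horner evaluation in the chunk index: foldedH accumulates h3, is scaled by ζᵐ⁺², accumulates h2, is scaled again, then accumulates h1, so that foldedH(ζ) = h1(ζ) + ζᵐ⁺²·h2(ζ) + ζ²⁽ᵐ⁺²⁾·h3(ζ). A toy check of that identity over plain integers; the values below are made up and merely stand in for field elements:

```go
package main

import "fmt"

func main() {
	// chunks of a split polynomial h = h1 + X^k*h2 + X^(2k)*h3, evaluated at a point z
	h1, h2, h3 := 5, 7, 11 // stand-ins for h1(z), h2(z), h3(z)
	zk := 3                // stand-in for z^k (ζ^(m+2) in the prover)

	// Horner folding, as in the loop above
	folded := h3
	folded = folded*zk + h2
	folded = folded*zk + h1

	// direct evaluation h1 + z^k*h2 + z^(2k)*h3
	direct := h1 + zk*h2 + zk*zk*h3

	fmt.Println(folded == direct) // true
}
```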
@@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
// Batch open the first list of polynomials
proof.BatchedProof, err = kzg.BatchOpenSinglePoint(
- []polynomial.Polynomial{
+ [][]fr.Element{
foldedH,
- linearizedPolynomial,
- bcl,
- bcr,
- bco,
- pk.CS1,
- pk.CS2,
+ linearizedPolynomialCanonical,
+ blindedLCanonical,
+ blindedRCanonical,
+ blindedOCanonical,
+ pk.S1Canonical,
+ pk.S2Canonical,
},
[]kzg.Digest{
foldedHDigest,
@@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
pk.Vk.S[0],
pk.Vk.S[1],
},
- &zeta,
+ zeta,
hFunc,
- &pk.DomainH,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_377witness.Witn
}
+// eval evaluates c at p
+func eval(c []fr.Element, p fr.Element) fr.Element {
+ var r fr.Element
+ for i := len(c) - 1; i >= 0; i-- {
+ r.Mul(&r, &p).Add(&r, &c[i])
+ }
+ return r
+}
+
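
For reference, the new `eval` helper is plain Horner evaluation of a polynomial held in canonical (coefficient) form. A minimal sketch of the same routine, assuming gnark-crypto's BLS12-377 `fr` package; the `horner` name, the toy coefficients and the `main` wrapper are illustrative only:

```go
package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

// horner evaluates c[0] + c[1]*p + ... + c[d]*p^d, mirroring the prover's eval.
func horner(c []fr.Element, p fr.Element) fr.Element {
	var r fr.Element
	for i := len(c) - 1; i >= 0; i-- {
		r.Mul(&r, &p).Add(&r, &c[i])
	}
	return r
}

func main() {
	// 3 + 2X + X² at X = 5  →  3 + 10 + 25 = 38
	var c0, c1, c2, x fr.Element
	c0.SetUint64(3)
	c1.SetUint64(2)
	c2.SetUint64(1)
	x.SetUint64(5)
	v := horner([]fr.Element{c0, c1, c2}, x)
	fmt.Println(v.String()) // 38
}
```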
// fills proof.LRO with kzg commits of bcl, bcr and bco
-func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS
return err1
}
-func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err
return err1
}
-// computeBlindedLRO l, r, o in canonical basis with blinding
-func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) {
+// computeBlindedLROCanonical computes l, r, o in canonical basis with blinding
+func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) {
// note that bcl, bcr and bco reuses cl, cr and co memory
- cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
+ cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
chDone := make(chan error, 2)
go func() {
var err error
copy(cl, ll)
- domain.FFTInverse(cl, fft.DIF, 0)
+ domain.FFTInverse(cl, fft.DIF)
fft.BitReverse(cl)
bcl, err = blindPoly(cl, domain.Cardinality, 1)
chDone <- err
@@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
go func() {
var err error
copy(cr, lr)
- domain.FFTInverse(cr, fft.DIF, 0)
+ domain.FFTInverse(cr, fft.DIF)
fft.BitReverse(cr)
bcr, err = blindPoly(cr, domain.Cardinality, 1)
chDone <- err
}()
copy(co, lo)
- domain.FFTInverse(co, fft.DIF, 0)
+ domain.FFTInverse(co, fft.DIF)
fft.BitReverse(co)
if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil {
return
@@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
// * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1)
//
// WARNING:
-// pre condition degree(cp) <= rou + bo
-// pre condition cap(cp) >= int(totalDegree + 1)
-func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) {
+// pre condition degree(cp) ⩽ rou + bo
+// pre condition cap(cp) ⩾ int(totalDegree + 1)
+func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) {
// degree of the blinded polynomial is max(rou+order, cp.Degree)
totalDegree := rou + bo
@@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res := cp[:totalDegree+1]
// random polynomial
- blindingPoly := make(polynomial.Polynomial, bo+1)
+ blindingPoly := make([]fr.Element, bo+1)
for i := uint64(0); i < bo+1; i++ {
if _, err := blindingPoly[i].SetRandom(); err != nil {
return nil, err
@@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
}
return res, nil
+
}
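
For intuition, `blindPoly` adds a random multiple of the vanishing polynomial, b(X) = p(X) + q(X)·(Xʳᵒᵘ − 1) with deg q = bo, so b agrees with p on the whole evaluation domain (where Xʳᵒᵘ = 1) while hiding the committed coefficients. A toy sketch over integer coefficients; field arithmetic and the RNG are stripped out and all names and values are illustrative:

```go
package main

import "fmt"

// blind returns p(X) + q(X)*(X^n - 1) as a coefficient slice.
func blind(p, q []int, n int) []int {
	res := make([]int, len(p))
	copy(res, p)
	if need := n + len(q); need > len(res) {
		res = append(res, make([]int, need-len(res))...)
	}
	for i, qi := range q {
		res[i] -= qi   // -q_i * X^i
		res[i+n] += qi // +q_i * X^(i+n)
	}
	return res
}

// evalAt evaluates a coefficient slice at x (Horner).
func evalAt(c []int, x int) int {
	r := 0
	for i := len(c) - 1; i >= 0; i-- {
		r = r*x + c[i]
	}
	return r
}

func pow(x, n int) int {
	r := 1
	for i := 0; i < n; i++ {
		r *= x
	}
	return r
}

func main() {
	p := []int{4, 1, 0, 2} // 4 + X + 2X³, domain size n = 4
	q := []int{3, 5}       // blinding polynomial of degree 1 (random in the real code)
	b := blind(p, q, 4)

	// over the integers we check that b - p is exactly q(X)*(X⁴ - 1) at an arbitrary point;
	// on the domain itself X⁴ = 1, so the blinding term vanishes.
	x := 7
	fmt.Println(evalAt(b, x) == evalAt(p, x)+evalAt(q, x)*(pow(x, 4)-1)) // true
}
```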
-// computeLRO extracts the solution l, r, o, and returns it in lagrange form.
+// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form.
// solution = [ public | secret | internal ]
-func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) {
+func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {
- s := int(pk.DomainNum.Cardinality)
+ s := int(pk.Domain[0].Cardinality)
- var l, r, o polynomial.Polynomial
+ var l, r, o []fr.Element
l = make([]fr.Element, s)
r = make([]fr.Element, s)
o = make([]fr.Element, s)
@@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly
//
// * Z of degree n (domainNum.Cardinality)
// * Z(1)=1
-//                     (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma)
-// * for i>0: Z(u**i) = Pi_{k<i} ------------------------------------------------------
-//                     (l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
+// * for i>0: Z(gⁱ) = Π_{k<i} (l_k+β*gᵏ+γ)*(r_k+β*u*gᵏ+γ)*(o_k+β*u²*gᵏ+γ) / ((l_k+β*s₁(gᵏ)+γ)*(r_k+β*s₂(gᵏ)+γ)*(o_k+β*s₃(gᵏ)+γ))
- u[0].Mul(&u[0], &pk.DomainNum.Generator) // z**i -> z**i+1
- u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
}
})
@@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele
Mul(&z[i], &gInv[i])
}
- pk.DomainNum.FFTInverse(z, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(z, fft.DIF)
fft.BitReverse(z)
- return blindPoly(z, pk.DomainNum.Cardinality, 2)
+ return blindPoly(z, pk.Domain[0].Cardinality, 2)
}
-// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on
-// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on
+// the big domain coset.
//
// * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets
// * qk is the completed version of qk, in canonical version
-func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
- var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial
+func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
+ var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element
var wg sync.WaitGroup
wg.Add(4)
go func() {
- evalQl = evaluateHDomain(pk.Ql, &pk.DomainH)
+ evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQr = evaluateHDomain(pk.Qr, &pk.DomainH)
+ evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQm = evaluateHDomain(pk.Qm, &pk.DomainH)
+ evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQo = evaluateHDomain(pk.Qo, &pk.DomainH)
+ evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1])
wg.Done()
}()
- evalQk = evaluateHDomain(qk, &pk.DomainH)
+ evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1])
wg.Wait()
- // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets
- // of (Z/8mZ)/(Z/mZ)
+
+ // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain
utils.Parallelize(len(evalQk), func(start, end int) {
var t0, t1 fr.Element
for i := start; i < end; i++ {
@@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.
return evalQk
}
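
The value accumulated here is the PLONK gate polynomial qₗ·l + qᵣ·r + qₘ·l·r + qₒ·o + qₖ, which must vanish on the small domain for a valid witness. A scalar-level sketch of that check, again assuming gnark-crypto's `fr` package; the selector values encode an addition gate and are chosen by hand for the example:

```go
package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

// gate returns ql*l + qr*r + qm*l*r + qo*o + qk.
func gate(ql, qr, qm, qo, qk, l, r, o fr.Element) fr.Element {
	var t0, t1, res fr.Element
	t0.Mul(&ql, &l)              // ql*l
	t1.Mul(&qr, &r)              // qr*r
	res.Add(&t0, &t1)            // ql*l + qr*r
	t0.Mul(&qm, &l).Mul(&t0, &r) // qm*l*r
	res.Add(&res, &t0)
	t0.Mul(&qo, &o)
	res.Add(&res, &t0).Add(&res, &qk)
	return res
}

func main() {
	// addition gate: l + r - o = 0  →  ql=1, qr=1, qm=0, qo=-1, qk=0
	var one, minusOne, zero, l, r, o fr.Element
	one.SetOne()
	minusOne.Neg(&one)
	l.SetUint64(3)
	r.SetUint64(4)
	o.SetUint64(7)
	res := gate(one, one, zero, minusOne, zero, l, r, o)
	fmt.Println(res.IsZero()) // true
}
```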
-// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ)
-func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) {
-
- id = make([]fr.Element, pk.DomainH.Cardinality)
-
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
- var acc fr.Element
- acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start)))
- for i := start; i < end; i++ {
- id[i].Mul(&acc, &pk.DomainH.FinerGenerator)
- acc.Mul(&acc, &pk.DomainH.Generator)
- }
- })
-
- return id
-}
-
-// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
-// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
+// cosets of the big domain.
//
-// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets
-// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets
+// * z evaluation of the blinded permutation accumulator polynomial on odd cosets
+// * l, r, o evaluation of the blinded solution vectors on odd cosets
// * gamma randomization
-func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial {
+func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element {
- // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ)
- evalID := evalIDCosets(pk)
+ nbElmts := int(pk.Domain[1].Cardinality)
- // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ)
- var wg sync.WaitGroup
- wg.Add(2)
- var evalS1, evalS2, evalS3 polynomial.Polynomial
- go func() {
- evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH)
- wg.Done()
- }()
- go func() {
- evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH)
- wg.Done()
- }()
- evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH)
- wg.Wait()
+ // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X)+s₂(X)*β+γ)*(o(X)+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ)
+ // on the big domain (coset).
+ res := make([]fr.Element, pk.Domain[1].Cardinality)
- // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ)
- res := evalS1 // re use allocated memory for evalS1
- s := uint64(len(evalZ))
- nn := uint64(64 - bits.TrailingZeros64(uint64(s)))
+ nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts)))
// needed to shift evalZ
- toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality
+ toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality)
+
+ var cosetShift, cosetShiftSquare fr.Element
+ cosetShift.Set(&pk.Vk.CosetShift)
+ cosetShiftSquare.Square(&pk.Vk.CosetShift)
+
+ utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) {
+
+ var evaluationIDBigDomain fr.Element
+ evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))).
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen)
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
var f [3]fr.Element
var g [3]fr.Element
- var eID fr.Element
for i := start; i < end; i++ {
- // here we want to left shift evalZ by domainH/domainNum
- // however, evalZ is permuted
- // we take the non permuted index
- // compute the corresponding shift position
- // permute it again
- irev := bits.Reverse64(uint64(i)) >> nn
- eID = evalID[irev]
+ _i := bits.Reverse64(uint64(i)) >> nn
+ _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn
- shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn
- //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn
+ // in what follows gⁱ is understood as the generator of the chosen coset of domainBig
+ f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ
+ f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ
+ f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ
- f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma
- f[1].Mul(&eID, &pk.Vk.Shifter[0])
- f[2].Mul(&eID, &pk.Vk.Shifter[1])
- f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma
- f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma
+ g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ)+s1(gⁱ)*β+γ
+ g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ)+s2(gⁱ)*β+γ
+ g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ)+s3(gⁱ)*β+γ
- g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma
- g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma
- g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma
+ f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
+ g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ)
- f[0].Mul(&f[0], &f[1]).
- Mul(&f[0], &f[2]).
- Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma)
+ res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
- g[0].Mul(&g[0], &g[1]).
- Mul(&g[0], &g[2]).
- Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
-
- res[i].Sub(&g[0], &f[0])
+ evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g
}
})
return res
}
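
Because the big-domain FFTs are left in bit-reversed order (fft.DIF with no final BitReverse), evaluation i is addressed through `bits.Reverse64(uint64(i)) >> nn` with nn = 64 − log₂(size). A standalone sketch of that index mapping; the domain size 8 is an arbitrary choice for the printout:

```go
package main

import (
	"fmt"
	"math/bits"
)

func main() {
	const size = 8 // must be a power of two
	nn := uint64(64 - bits.TrailingZeros64(size))

	// natural index i ↔ position of the i-th evaluation in a bit-reversed slice
	for i := uint64(0); i < size; i++ {
		fmt.Printf("natural %d -> bit-reversed %d\n", i, bits.Reverse64(i)>>nn)
	}
	// prints the pairs 0↔0, 1↔4, 2↔2, 3↔6, 4↔1, 5↔5, 6↔3, 7↔7
}
```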
-// evaluateHDomain evaluates poly (canonical form) of degree m on the odd cosets of (Z/8mZ)/(Z/mZ)
- irev := bits.Reverse64(uint64(i)) >> nn
- // h[i].Mul(&h[i], &_u[irev%4])
- h[i].Mul(&h[i], &_u[irev%toShift])
+
+ _i := bits.Reverse64(i) >> nn
+
+ t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain
+ h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t).
+ Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]).
+ Mul(&h[_i], &alpha).
+ Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]).
+ Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio])
}
})
// put h in canonical form. h is of degree 3*(n+1)+2.
// using fft.DIT put h revert bit reverse
- pk.DomainH.FFTInverse(h, fft.DIT, 1)
- // fmt.Println("h:")
- // for i := 0; i < len(h); i++ {
- // fmt.Printf("%s\n", h[i].String())
- // }
- // fmt.Println("")
+ pk.Domain[1].FFTInverse(h, fft.DIT, true)
// degree of hi is n+2 because of the blinding
- h1 := h[:pk.DomainNum.Cardinality+2]
- h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)]
- h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)]
+ h1 := h[:pk.Domain[0].Cardinality+2]
+ h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)]
+ h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)]
return h1, h2, h3
@@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom
// computeLinearizedPolynomial computes the linearized polynomial in canonical basis.
// The purpose is to commit and open all in one ql, qr, qm, qo, qk.
-// * a, b, c are the evaluation of l, r, o at zeta
-// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z
+// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta
+// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z
// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk.
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial {
+//
+// The Linearized polynomial is:
+//
+// α²*L₁(ζ)*Z(X)
+// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ))
+// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X)
+func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element {
// first part: individual constraints
var rl fr.Element
- rl.Mul(&r, &l)
+ rl.Mul(&rZeta, &lZeta)
- // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
+ // second part:
+ // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)
var s1, s2 fr.Element
chS1 := make(chan struct{}, 1)
go func() {
- s1 = pk.CS1.Eval(&zeta)
- s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma)
+ s1 = eval(pk.S1Canonical, zeta) // s1(ζ)
+ s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ)
close(chS1)
}()
- t := pk.CS2.Eval(&zeta)
- t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma)
+ tmp := eval(pk.S2Canonical, zeta) // s2(ζ)
+ tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ)
<-chS1
- s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma)
- Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)
-
- s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma)
- t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)
- t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
- s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
-
- // third part L1(zeta)*alpha**2**Z
- var lagrange, one, den, frNbElmt fr.Element
+ s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ)
+
+ var uzeta, uuzeta fr.Element
+ uzeta.Mul(&zeta, &pk.Vk.CosetShift)
+ uuzeta.Mul(&uzeta, &pk.Vk.CosetShift)
+
+ s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ)
+ tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)
+ tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+ s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ // third part L₁(ζ)*α²*Z
+ var lagrangeZeta, one, den, frNbElmt fr.Element
one.SetOne()
- nbElmt := int64(pk.DomainNum.Cardinality)
- lagrange.Set(&zeta).
- Exp(lagrange, big.NewInt(nbElmt)).
- Sub(&lagrange, &one)
+ nbElmt := int64(pk.Domain[0].Cardinality)
+ lagrangeZeta.Set(&zeta).
+ Exp(lagrangeZeta, big.NewInt(nbElmt)).
+ Sub(&lagrangeZeta, &one)
frNbElmt.SetUint64(uint64(nbElmt))
den.Sub(&zeta, &one).
- Mul(&den, &frNbElmt).
Inverse(&den)
- lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1)
- Mul(&lagrange, &alpha).
- Mul(&lagrange, &alpha) // alpha**2*L_0
+ lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ-1)/(ζ-1)
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ)
- linPol := z.Clone()
+ linPol := make([]fr.Element, len(blindedZCanonical))
+ copy(linPol, blindedZCanonical)
utils.Parallelize(len(linPol), func(start, end int) {
+
var t0, t1 fr.Element
+
for i := start; i < end; i++ {
- linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
- if i < len(pk.CS3) {
- t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X)
+
+ linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ if i < len(pk.S3Canonical) {
+
+ t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X)
+
linPol[i].Add(&linPol[i], &t0)
}
- linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) )
+ linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ))
if i < len(pk.Qm) {
- t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm
- t0.Mul(&pk.Ql[i], &l)
+
+ t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X)
+ t0.Mul(&pk.Ql[i], &lZeta)
t0.Add(&t0, &t1)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X)
- t0.Mul(&pk.Qr[i], &r)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr
+ t0.Mul(&pk.Qr[i], &rZeta)
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X)
- t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i])
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk
+ t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i])
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X)
}
- t0.Mul(&z[i], &lagrange)
+ t0.Mul(&blindedZCanonical[i], &lagrangeZeta)
linPol[i].Add(&linPol[i], &t0) // finish the computation
}
})
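
The "third part" above computes α²·L₁(ζ) from the closed form L₁(X) = (Xⁿ − 1)/(n·(X − 1)); the 1/n factor now comes from `pk.Domain[0].CardinalityInv` rather than an explicit inverse of n. A toy check of that closed form against the interpolation definition, over a small prime field with parameters invented for the example:

```go
package main

import "fmt"

const p = 17 // toy prime, for illustration only

func pow(b, e int64) int64 {
	r := int64(1)
	b %= p
	for ; e > 0; e >>= 1 {
		if e&1 == 1 {
			r = r * b % p
		}
		b = b * b % p
	}
	return r
}

func inv(a int64) int64 { return pow(a, p-2) } // Fermat inverse

func main() {
	const n = 4
	w := int64(4)    // generator of the order-4 subgroup mod 17 (4⁴ ≡ 1)
	zeta := int64(5) // arbitrary point outside the subgroup

	// closed form used by the prover: L₁(ζ) = (ζⁿ-1) / (n*(ζ-1))
	closed := (pow(zeta, n) - 1 + p) % p
	closed = closed * inv((zeta-1+p)%p) % p
	closed = closed * inv(n) % p

	// interpolation definition: L₁(ζ) = Π_{j≠0} (ζ-ωʲ)/(1-ωʲ)
	interp := int64(1)
	for j := int64(1); j < n; j++ {
		wj := pow(w, j)
		interp = interp * ((zeta - wj + p) % p) % p
		interp = interp * inv((1-wj+p)%p) % p
	}

	fmt.Println(closed == interp) // true
}
```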
diff --git a/internal/backend/bls12-377/plonk/setup.go b/internal/backend/bls12-377/plonk/setup.go
index 774e1be88a..259f76e6f1 100644
--- a/internal/backend/bls12-377/plonk/setup.go
+++ b/internal/backend/bls12-377/plonk/setup.go
@@ -21,7 +21,6 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/kzg"
- "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial"
"github.com/consensys/gnark/internal/backend/bls12-377/cs"
kzgg "github.com/consensys/gnark-crypto/kzg"
@@ -40,18 +39,21 @@ type ProvingKey struct {
Vk *VerifyingKey
// qr,ql,qm,qo (in canonical basis).
- Ql, Qr, Qm, Qo polynomial.Polynomial
+ Ql, Qr, Qm, Qo []fr.Element
// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
// Storing LQk in Lagrange basis saves a fft...
- CQk, LQk polynomial.Polynomial
+ CQk, LQk []fr.Element
- // Domains used for the FFTs
- DomainNum, DomainH fft.Domain
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
- // s1, s2, s3 (L=Lagrange basis, C=canonical basis)
- LS1, LS2, LS3 polynomial.Polynomial
- CS1, CS2, CS3 polynomial.Polynomial
+ // Permutation polynomials
+ EvaluationPermutationBigDomainBitReversed []fr.Element
+ S1Canonical, S2Canonical, S3Canonical []fr.Element
// position -> permuted position (position in [0,3*sizeSystem-1])
Permutation []int64
@@ -69,13 +71,12 @@ type VerifyingKey struct {
Generator fr.Element
NbPublicVariables uint64
- // shifters for extending the permutation set: from s=<1,z,..,z**n-1>,
- // extended domain = s || shifter[0].s || shifter[1].s
- Shifter [2]fr.Element
-
// Commitment scheme that is used for an instantiation of PLONK
KZGSRS *kzg.SRS
+ // cosetShift generator of the coset on the small domain
+ CosetShift fr.Element
+
// S commitments to S1, S2, S3
S [3]kzg.Digest
@@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
// fft domains
sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints
- pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false)
+ pk.Domain[0] = *fft.NewDomain(sizeSystem)
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
// h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space,
// the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases
// except when n<6.
if sizeSystem < 6 {
- pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(8 * sizeSystem)
} else {
- pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(4 * sizeSystem)
}
- vk.Size = pk.DomainNum.Cardinality
+ vk.Size = pk.Domain[0].Cardinality
vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
- vk.Generator.Set(&pk.DomainNum.Generator)
+ vk.Generator.Set(&pk.Domain[0].Generator)
vk.NbPublicVariables = uint64(spr.NbPublicVariables)
- // shifters
- vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator)
- vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator)
-
if err := pk.InitKZG(srs); err != nil {
return nil, nil, err
}
// public polynomials corresponding to constraints: [ placholders | constraints | assertions ]
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error is size is inconsistant
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.Qm[i].SetZero()
pk.Qo[i].SetZero()
pk.CQk[i].SetZero()
- pk.LQk[i].SetZero() // --> to be completed by the prover
+ pk.LQk[i].SetZero() // → to be completed by the prover
}
offset := spr.NbPublicVariables
for i := 0; i < nbConstraints; i++ { // constraints
@@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K])
}
- pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(pk.Ql, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qr, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qm, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qo, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.CQk, fft.DIF)
fft.BitReverse(pk.Ql)
fft.BitReverse(pk.Qr)
fft.BitReverse(pk.Qm)
@@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
buildPermutation(spr, &pk)
// set s1, s2, s3
- computeLDE(&pk)
+ ccomputePermutationPolynomials(&pk)
// Commit to the polynomials to set up the verifying key
var err error
@@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil {
+ if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil {
+ if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil {
+ if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
@@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
//
// The permutation s is composed of cycles of maximum length such that
//
-// s. (l||r||o) = (l||r||o)
+// s. (l∥r∥o) = (l∥r∥o)
//
-//, where l||r||o is the concatenation of the indices of l, r, o in
+//, where l∥r∥o is the concatenation of the indices of l, r, o in
// ql.l+qr.r+qm.l.r+qo.O+k = 0.
//
// The permutation is encoded as a slice s of size 3*size(l), where the
-// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab
+// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab
// like this: for i in tab: tab[i] = tab[permutation[i]]
func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables
- sizeSolution := int(pk.DomainNum.Cardinality)
+ sizeSolution := int(pk.Domain[0].Cardinality)
// init permutation
pk.Permutation = make([]int64, 3*sizeSolution)
@@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
}
}
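
To make the permutation encoding concrete: `pk.Permutation` stores, for every slot of l∥r∥o, the slot it is mapped to, so each set of wires that must be equal forms a cycle and a valid assignment has to be constant along every cycle. A hand-built example; the cycle structure and values are invented, not derived from a real circuit:

```go
package main

import "fmt"

func main() {
	// 2 constraints → l∥r∥o has 6 slots: [l0 l1 r0 r1 o0 o1].
	// Suppose l1, r0 and o1 are the same wire, and the other slots stand alone.
	// Cycles: (1 2 5), (0), (3), (4); each entry points to the next slot of its cycle.
	permutation := []int64{0, 2, 5, 3, 4, 1}

	// a valid assignment must be constant along each cycle
	values := []int{10, 42, 42, 7, 8, 42}

	ok := true
	for i, p := range permutation {
		if values[i] != values[p] {
			ok = false
		}
	}
	fmt.Println(ok) // true: the copy constraints hold
}
```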
-// computeLDE computes the LDE (Lagrange basis) of the permutations
+// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations
// s1, s2, s3.
//
-// ex: z gen of Z/mZ, u gen of Z/8mZ, then
-//
// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 |
// |
// | Permutation
// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v
// \---------------/ \--------------------/ \------------------------/
// s1 (LDE) s2 (LDE) s3 (LDE)
-func computeLDE(pk *ProvingKey) {
+func ccomputePermutationPolynomials(pk *ProvingKey) {
- nbElmt := int(pk.DomainNum.Cardinality)
+ nbElmts := int(pk.Domain[0].Cardinality)
- // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1]
- sID := make([]fr.Element, 3*nbElmt)
- sID[0].SetOne()
- sID[nbElmt].Set(&pk.DomainNum.FinerGenerator)
- sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator)
-
- for i := 1; i < nbElmt; i++ {
- sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1
- sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
- }
+ // Lagrange form of ID
+ evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0])
// Lagrange form of S1, S2, S3
- pk.LS1 = make(polynomial.Polynomial, nbElmt)
- pk.LS2 = make(polynomial.Polynomial, nbElmt)
- pk.LS3 = make(polynomial.Polynomial, nbElmt)
- for i := 0; i < nbElmt; i++ {
- pk.LS1[i].Set(&sID[pk.Permutation[i]])
- pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]])
- pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]])
+ pk.S1Canonical = make([]fr.Element, nbElmts)
+ pk.S2Canonical = make([]fr.Element, nbElmts)
+ pk.S3Canonical = make([]fr.Element, nbElmts)
+ for i := 0; i < nbElmts; i++ {
+ pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]])
+ pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]])
+ pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]])
}
// Canonical form of S1, S2, S3
- pk.CS1 = make(polynomial.Polynomial, nbElmt)
- pk.CS2 = make(polynomial.Polynomial, nbElmt)
- pk.CS3 = make(polynomial.Polynomial, nbElmt)
- copy(pk.CS1, pk.LS1)
- copy(pk.CS2, pk.LS2)
- copy(pk.CS3, pk.LS3)
- pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0)
- fft.BitReverse(pk.CS1)
- fft.BitReverse(pk.CS2)
- fft.BitReverse(pk.CS3)
+ pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF)
+ fft.BitReverse(pk.S1Canonical)
+ fft.BitReverse(pk.S2Canonical)
+ fft.BitReverse(pk.S3Canonical)
+
+ // evaluation of permutation on the big domain
+ pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality)
+ copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true)
+
+}
+
+// getIDSmallDomain returns the Lagrange form of ID on the small domain
+func getIDSmallDomain(domain *fft.Domain) []fr.Element {
+
+ res := make([]fr.Element, 3*domain.Cardinality)
+
+ res[0].SetOne()
+ res[domain.Cardinality].Set(&domain.FrMultiplicativeGen)
+ res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen)
+
+ for i := uint64(1); i < domain.Cardinality; i++ {
+ res[i].Mul(&res[i-1], &domain.Generator)
+ res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator)
+ res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator)
+ }
+ return res
}
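
getIDSmallDomain lays out the identity permutation's evaluations as three blocks, one per wire column, each a coset of the small domain: [1, g, …, gⁿ⁻¹ | u, ug, … | u², u²g, …]. The same layout over a toy prime field; the prime, generator and shift are invented (in the real code u is `FrMultiplicativeGen`):

```go
package main

import "fmt"

const p = 17 // toy prime

func main() {
	const n = 4
	g := int64(4) // generator of the order-4 subgroup mod 17
	u := int64(3) // coset shift, not in the subgroup

	// same layout as getIDSmallDomain: [1 g .. gⁿ⁻¹ | u ug .. ugⁿ⁻¹ | u² u²g .. u²gⁿ⁻¹]
	res := make([]int64, 3*n)
	res[0], res[n], res[2*n] = 1, u, u*u%p
	for i := 1; i < n; i++ {
		res[i] = res[i-1] * g % p
		res[n+i] = res[n+i-1] * g % p
		res[2*n+i] = res[2*n+i-1] * g % p
	}
	fmt.Println(res) // [1 4 16 13 3 12 14 5 9 2 8 15]
}
```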
-// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS
+// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS
//
// This should be used after deserializing a ProvingKey
// as pk.Vk.KZG is NOT serialized
diff --git a/internal/backend/bls12-377/plonk/verify.go b/internal/backend/bls12-377/plonk/verify.go
index 17ecfd94ee..e827aef72b 100644
--- a/internal/backend/bls12-377/plonk/verify.go
+++ b/internal/backend/bls12-377/plonk/verify.go
@@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne
hFunc := sha256.New()
// transcript to derive the challenge
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// derive gamma from Comm(l), Comm(r), Comm(o)
gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2])
@@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne
return err
}
+ // derive beta from Comm(l), Comm(r), Comm(o)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return err
+ }
+
// derive alpha from Comm(l), Comm(r), Comm(o), Com(Z)
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z)
if err != nil {
@@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne
return err
}
- // evaluation of Z=X**m-1 at zeta
+ // evaluation of Z=Xⁿ-1 at ζ
var zetaPowerM, zzeta fr.Element
var bExpo big.Int
one := fr.One()
@@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_377witness.Witne
zetaPowerM.Exp(zeta, &bExpo)
zzeta.Sub(&zetaPowerM, &one)
- // ccompute PI = Sum_i<n L_i*w_i
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
}
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
}
- }
- // sanity check; ensure all wires are marked as "instantiated"
- if !solution.isValid() {
- panic("solver didn't instantiate all wires")
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
}
- return solution.values, nil
+ return nil
}
// IsSolved returns nil if given witness solves the R1CS and error otherwise
@@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) {
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) {
+func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error {
// the index of the non zero entry shows if L, R or O has an uninstantiated wire
// the content is the ID of the wire non instantiated
@@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
return nil
}
- if err = processLExp(r.L.LinExp, &a, 1); err != nil {
- return
+ if err := processLExp(r.L.LinExp, a, 1); err != nil {
+ return err
}
- if err = processLExp(r.R.LinExp, &b, 2); err != nil {
- return
+ if err := processLExp(r.R.LinExp, b, 2); err != nil {
+ return err
}
- if err = processLExp(r.O.LinExp, &c, 3); err != nil {
- return
+ if err := processLExp(r.O.LinExp, c, 3); err != nil {
+ return err
}
if loc == 0 {
// there is nothing to solve, may happen if we have an assertion
// (ie a constraints that doesn't yield any output)
// or if we solved the unsolved wires with hint functions
- return
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
+ return nil
}
// we compute the wire value and instantiate it
- solved = true
- vID := termToCompute.WireID()
+ wID := termToCompute.WireID()
// solver result
var wire fr.Element
@@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
switch loc {
case 1:
if !b.IsZero() {
- wire.Div(&c, &b).
- Sub(&wire, &a)
- a.Add(&a, &wire)
+ wire.Div(c, b).
+ Sub(&wire, a)
+ a.Add(a, &wire)
} else {
// we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 2:
if !a.IsZero() {
- wire.Div(&c, &a).
- Sub(&wire, &b)
- b.Add(&b, &wire)
+ wire.Div(c, a).
+ Sub(&wire, b)
+ b.Add(b, &wire)
} else {
- // we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 3:
- wire.Mul(&a, &b).
- Sub(&wire, &c)
+ wire.Mul(a, b).
+ Sub(&wire, c)
- c.Add(&c, &wire)
+ c.Add(c, &wire)
}
// wire is the term (coeff * value)
// but in the solution we want to store the value only
// note that in gnark frontend, coeff here is always 1 or -1
cs.divByCoeff(&wire, termToCompute)
- solution.set(vID, wire)
+ solution.set(wID, wire)
- return
+ return nil
}
// GetConstraints return a list of constraint formatted as L⋅R == O
diff --git a/internal/backend/bls12-381/cs/r1cs_sparse.go b/internal/backend/bls12-381/cs/r1cs_sparse.go
index 6e47e344fb..106e6eb0eb 100644
--- a/internal/backend/bls12-381/cs/r1cs_sparse.go
+++ b/internal/backend/bls12-381/cs/r1cs_sparse.go
@@ -21,9 +21,12 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/fxamacker/cbor/v2"
"io"
+ "math"
"math/big"
"os"
+ "runtime"
"strings"
+ "sync"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
@@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
return solution.values, err
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
// solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs
copy(solution.values, witness)
for i := 0; i < len(witness); i++ {
@@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness)
+ solution.nbSolved += uint64(len(witness))
// defer log printing once all solution.values are computed
defer solution.printLogs(opt.LoggerOut, cs.Logs)
@@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
coefficientsNegInv[i].Neg(&coefficientsNegInv[i])
}
- // loop through the constraints to solve the variables
- for i := 0; i < len(cs.Constraints); i++ {
- if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil {
- return solution.values, fmt.Errorf("constraint %d: %w", i, err)
- }
- if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil {
- errMsg := err.Error()
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
- }
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
- }
+ if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil {
+ return solution.values, err
}
// sanity check; ensure all wires are marked as "instantiated"
@@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
+func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error {
+ // minWorkPerCPU is the minimum target number of constraints a task should hold
+ // in other words, if a level has fewer than minWorkPerCPU constraints, it is not parallelized
+ // and is executed sequentially without sync.
+ const minWorkPerCPU = 50.0
+
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
+
+ // start a worker pool
+ // each worker wait on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ // clean up pool go routines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
+
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ }
+ }
+ continue
+ }
+
+ // number of tasks for this level is set to num cpus
+ // but if we don't have enough work for all our CPUS, it can be lower.
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
+ }
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
+ }
+
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
+ }
+
+ return nil
+}
+
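
The chunking arithmetic above splits a level of `len(level)` constraints into `nbTasks` contiguous ranges of near-equal size, handing one extra constraint to each of the first `extraTasks` ranges. The same splitting in isolation; the `split` helper and the sizes are illustrative only:

```go
package main

import "fmt"

// split returns the [start,end) ranges used to dispatch n items to nbTasks workers.
func split(n, nbTasks int) [][2]int {
	perTask := n / nbTasks
	if perTask < 1 {
		// more workers than items: one item per worker
		perTask = 1
		nbTasks = n
	}
	extra := n - nbTasks*perTask // remainder, spread one-by-one over the first ranges
	offset := 0

	ranges := make([][2]int, 0, nbTasks)
	for i := 0; i < nbTasks; i++ {
		start := i*perTask + offset
		end := start + perTask
		if extra > 0 {
			end++
			extra--
			offset++
		}
		ranges = append(ranges, [2]int{start, end})
	}
	return ranges
}

func main() {
	fmt.Println(split(10, 4)) // [[0 3] [3 6] [6 8] [8 10]]
}
```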
// computeHints computes wires associated with a hint function, if any
// if there is no remaining wire to solve, returns -1
// else returns the wire position (L -> 0, R -> 1, O -> 2)
diff --git a/internal/backend/bls12-381/cs/solution.go b/internal/backend/bls12-381/cs/solution.go
index 7126962be9..9d630c8153 100644
--- a/internal/backend/bls12-381/cs/solution.go
+++ b/internal/backend/bls12-381/cs/solution.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"math/big"
+ "sync/atomic"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend/schema"
@@ -32,14 +33,15 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bls12-381"
)
+var errUnsatisfiedConstraint = errors.New("unsatisfied")
+
// solution represents elements needed to compute
// a solution to a R1CS or SparseR1CS
type solution struct {
values, coefficients []fr.Element
solved []bool
- nbSolved int
+ nbSolved uint64
mHintsFunctions map[hint.ID]hint.Function
- tmpHintsIO []*big.Int
}
func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) {
@@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E
coefficients: coefficients,
solved: make([]bool, nbWires),
mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)),
- tmpHintsIO: make([]*big.Int, 0),
}
for _, h := range hintFunctions {
@@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) {
}
s.values[id] = value
s.solved[id] = true
- s.nbSolved++
+ atomic.AddUint64(&s.nbSolved, 1)
}
func (s *solution) isValid() bool {
- return s.nbSolved == len(s.values)
+ return int(s.nbSolved) == len(s.values)
}
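
`nbSolved` is now updated through `sync/atomic` because the parallel solvers may call `solution.set` from several goroutines at once; a plain integer increment would be a data race. A minimal illustration of the pattern, with arbitrary worker and iteration counts:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var nbSolved uint64
	var wg sync.WaitGroup

	for w := 0; w < 4; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < 1000; i++ {
				atomic.AddUint64(&nbSolved, 1) // safe under concurrent writers
			}
		}()
	}
	wg.Wait()
	fmt.Println(atomic.LoadUint64(&nbSolved) == 4000) // true
}
```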
// computeTerm computes coef*variable
@@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error {
// tmp IO big int memory
nbInputs := len(h.Inputs)
nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs))
- m := len(s.tmpHintsIO)
- if m < (nbInputs + nbOutputs) {
- s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...)
- for i := m; i < len(s.tmpHintsIO); i++ {
- s.tmpHintsIO[i] = big.NewInt(0)
- }
+ inputs := make([]*big.Int, nbInputs)
+ outputs := make([]*big.Int, nbOutputs)
+ for i := 0; i < nbInputs; i++ {
+ inputs[i] = big.NewInt(0)
+ }
+ for i := 0; i < nbOutputs; i++ {
+ outputs[i] = big.NewInt(0)
}
- inputs := s.tmpHintsIO[:nbInputs]
- outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs]
q := fr.Modulus()
diff --git a/internal/backend/bls12-381/groth16/marshal_test.go b/internal/backend/bls12-381/groth16/marshal_test.go
index 38c1e8b038..990e3ee6b1 100644
--- a/internal/backend/bls12-381/groth16/marshal_test.go
+++ b/internal/backend/bls12-381/groth16/marshal_test.go
@@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) {
var pk, pkCompressed, pkRaw ProvingKey
// create a random pk
- domain := fft.NewDomain(8, 1, true)
+ domain := fft.NewDomain(8)
pk.Domain = *domain
nbWires := 6
diff --git a/internal/backend/bls12-381/groth16/prove.go b/internal/backend/bls12-381/groth16/prove.go
index fe52aafb2d..a8e7767d4e 100644
--- a/internal/backend/bls12-381/groth16/prove.go
+++ b/internal/backend/bls12-381/groth16/prove.go
@@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
c = append(c, padding...)
n = len(a)
- domain.FFTInverse(a, fft.DIF, 0)
- domain.FFTInverse(b, fft.DIF, 0)
- domain.FFTInverse(c, fft.DIF, 0)
+ domain.FFTInverse(a, fft.DIF)
+ domain.FFTInverse(b, fft.DIF)
+ domain.FFTInverse(c, fft.DIF)
- domain.FFT(a, fft.DIT, 1)
- domain.FFT(b, fft.DIT, 1)
- domain.FFT(c, fft.DIT, 1)
+ domain.FFT(a, fft.DIT, true)
+ domain.FFT(b, fft.DIT, true)
+ domain.FFT(c, fft.DIT, true)
- var minusTwoInv fr.Element
- minusTwoInv.SetUint64(2)
- minusTwoInv.Neg(&minusTwoInv).
- Inverse(&minusTwoInv)
+ var den, one fr.Element
+ one.SetOne()
+ den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
+ den.Sub(&den, &one).Inverse(&den)
// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unecessary memalloc
@@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
- Mul(&a[i], &minusTwoInv)
+ Mul(&a[i], &den)
}
})
// ifft_coset
- domain.FFTInverse(a, fft.DIF, 1)
+ domain.FFTInverse(a, fft.DIF, true)
utils.Parallelize(len(a), func(start, end int) {
for i := start; i < end; i++ {
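
The new denominator relies on the vanishing polynomial Xⁿ − 1 being constant on a multiplicative coset: (u·ωⁱ)ⁿ − 1 = uⁿ·(ωⁿ)ⁱ − 1 = uⁿ − 1 for every domain element ωⁱ, so a single inverse serves all coset points (the old hard-coded −1/2 corresponds to a coset whose shift satisfies uⁿ = −1). A toy check over a small prime field; the prime, generator and shift are invented for the example:

```go
package main

import "fmt"

const p = 17 // toy prime

func pow(b, e int64) int64 {
	r := int64(1)
	b %= p
	for ; e > 0; e >>= 1 {
		if e&1 == 1 {
			r = r * b % p
		}
		b = b * b % p
	}
	return r
}

func main() {
	const n = 4
	w := int64(4) // ω generates the order-4 subgroup mod 17
	u := int64(3) // coset shift, not in the subgroup

	want := (pow(u, n) - 1 + p) % p // uⁿ - 1
	for i := int64(0); i < n; i++ {
		x := u * pow(w, i) % p
		got := (pow(x, n) - 1 + p) % p
		fmt.Println(got == want) // true for every coset point
	}
}
```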
diff --git a/internal/backend/bls12-381/groth16/setup.go b/internal/backend/bls12-381/groth16/setup.go
index 1daf2c2f69..d481f8ad0d 100644
--- a/internal/backend/bls12-381/groth16/setup.go
+++ b/internal/backend/bls12-381/groth16/setup.go
@@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error {
nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables
// Setting group for fft
- domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true)
+ domain := fft.NewDomain(uint64(len(r1cs.Constraints)))
// samples toxic waste
toxicWaste, err := sampleToxicWaste()
@@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error {
nbConstraints := len(r1cs.Constraints)
// Setting group for fft
- domain := fft.NewDomain(uint64(nbConstraints), 1, true)
+ domain := fft.NewDomain(uint64(nbConstraints))
// count number of infinity points we would have had we a normal setup
// in pk.G1.A, pk.G1.B, and pk.G2.B
diff --git a/internal/backend/bls12-381/plonk/marshal.go b/internal/backend/bls12-381/plonk/marshal.go
index 4e3945c054..d03d1be5fa 100644
--- a/internal/backend/bls12-381/plonk/marshal.go
+++ b/internal/backend/bls12-381/plonk/marshal.go
@@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
}
// fft domains
- n2, err := pk.DomainNum.WriteTo(w)
+ n2, err := pk.Domain[0].WriteTo(w)
if err != nil {
return
}
n += n2
- n2, err = pk.DomainH.WriteTo(w)
+ n2, err = pk.Domain[1].WriteTo(w)
if err != nil {
return
}
n += n2
- // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality)
- if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) {
+ // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality)
+ if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) {
return n, errors.New("invalid permutation size, expected 3*domain cardinality")
}
@@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
([]fr.Element)(pk.Qo),
([]fr.Element)(pk.CQk),
([]fr.Element)(pk.LQk),
- ([]fr.Element)(pk.LS1),
- ([]fr.Element)(pk.LS2),
- ([]fr.Element)(pk.LS3),
- ([]fr.Element)(pk.CS1),
- ([]fr.Element)(pk.CS2),
- ([]fr.Element)(pk.CS3),
+ ([]fr.Element)(pk.S1Canonical),
+ ([]fr.Element)(pk.S2Canonical),
+ ([]fr.Element)(pk.S3Canonical),
pk.Permutation,
}
@@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
- n2, err := pk.DomainNum.ReadFrom(r)
+ n2, err := pk.Domain[0].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- n2, err = pk.DomainH.ReadFrom(r)
+ n2, err = pk.Domain[1].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
dec := curve.NewDecoder(r)
toDecode := []interface{}{
@@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
(*[]fr.Element)(&pk.Qo),
(*[]fr.Element)(&pk.CQk),
(*[]fr.Element)(&pk.LQk),
- (*[]fr.Element)(&pk.LS1),
- (*[]fr.Element)(&pk.LS2),
- (*[]fr.Element)(&pk.LS3),
- (*[]fr.Element)(&pk.CS1),
- (*[]fr.Element)(&pk.CS2),
- (*[]fr.Element)(&pk.CS3),
+ (*[]fr.Element)(&pk.S1Canonical),
+ (*[]fr.Element)(&pk.S2Canonical),
+ (*[]fr.Element)(&pk.S3Canonical),
&pk.Permutation,
}
@@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
&vk.SizeInv,
&vk.Generator,
vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
&vk.SizeInv,
&vk.Generator,
&vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
diff --git a/internal/backend/bls12-381/plonk/marshal_test.go b/internal/backend/bls12-381/plonk/marshal_test.go
index 9076e2280c..e30d108e8b 100644
--- a/internal/backend/bls12-381/plonk/marshal_test.go
+++ b/internal/backend/bls12-381/plonk/marshal_test.go
@@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
@@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) {
// random pk
var pk ProvingKey
pk.Vk = &vk
- pk.DomainNum = *fft.NewDomain(42, 3, false)
- pk.DomainH = *fft.NewDomain(4*42, 1, false)
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Domain[0] = *fft.NewDomain(42)
+ pk.Domain[1] = *fft.NewDomain(4 * 42)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < 12; i++ {
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) {
pk.Qo[i].SetUint64(42)
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
pk.Permutation[0] = -12
pk.Permutation[len(pk.Permutation)-1] = 8888
@@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
diff --git a/internal/backend/bls12-381/plonk/prove.go b/internal/backend/bls12-381/plonk/prove.go
index 5f9dadb7bb..8d5a625122 100644
--- a/internal/backend/bls12-381/plonk/prove.go
+++ b/internal/backend/bls12-381/plonk/prove.go
@@ -27,8 +27,6 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bls12-381"
- "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial"
-
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft"
@@ -43,6 +41,7 @@ import (
)
type Proof struct {
+
// Commitments to the solution vectors
LRO [3]kzg.Digest
@@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
hFunc := sha256.New()
// create a transcript manager to apply Fiat Shamir
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// result
proof := &Proof{}
@@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
}
// query l, r, o in Lagrange basis, not blinded
- ll, lr, lo := computeLRO(spr, pk, solution)
+ evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution)
// save ll, lr, lo, and make a copy of them in canonical basis.
// note that we allocate more capacity to reuse for blinded polynomials
- bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum)
+ blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ &pk.Domain[0])
if err != nil {
return nil, err
}
// compute kzg commitments of bcl, bcr and bco
- if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
return nil, err
}
+ // derive beta from the transcript (Fiat-Shamir)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return nil, err
+ }
+
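An aside on the new "beta" challenge above: it is drawn from the same running Fiat-Shamir transcript that already absorbed the LRO commitments for gamma, so it is bound to them without re-hashing them explicitly. The toy transcript below (plain sha256; the names toyTranscript/bind/challenge are illustrative and are not the gnark-crypto fiatshamir API) sketches that chaining.

    // Minimal, self-contained illustration of Fiat-Shamir challenge chaining:
    // each challenge hashes the running state, so "beta" depends on everything
    // bound before it (here, the stand-in for the LRO commitments).
    package main

    import (
        "crypto/sha256"
        "fmt"
    )

    type toyTranscript struct{ state []byte }

    func (t *toyTranscript) bind(data []byte) { t.state = append(t.state, data...) }

    func (t *toyTranscript) challenge(name string) []byte {
        h := sha256.New()
        h.Write(t.state)
        h.Write([]byte(name))
        c := h.Sum(nil)
        t.bind(c) // chain: later challenges depend on earlier ones
        return c
    }

    func main() {
        var fs toyTranscript
        fs.bind([]byte("Comm(l)|Comm(r)|Comm(o)")) // stand-in for the LRO commitments
        gamma := fs.challenge("gamma")
        beta := fs.challenge("beta") // the separate challenge added by this change
        fmt.Printf("gamma=%x… beta=%x…\n", gamma[:4], beta[:4])
    }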
// compute Z, the permutation accumulator polynomial, in canonical basis
// ll, lr, lo are NOT blinded
- var bz polynomial.Polynomial
+ var blindedZCanonical []fr.Element
chZ := make(chan error, 1)
var alpha fr.Element
go func() {
var err error
- bz, err = computeBlindedZ(ll, lr, lo, pk, gamma)
+ blindedZCanonical, err = computeBlindedZCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ pk, beta, gamma)
if err != nil {
chZ <- err
close(chZ)
@@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
// this may add additional arithmetic operations, but with smaller tasks
// we ensure that this commitment is well parallelized, without having an "unbalanced task" making
// the rest of the code wait too long.
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
+ if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
chZ <- err
close(chZ)
return
@@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
}()
// evaluation of the blinded versions of l, r, o and bz
- // on the odd cosets of (Z/8mZ)/(Z/mZ)
- var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial
+ // on the coset of the big domain
+ var (
+ evaluationBlindedLDomainBigBitReversed []fr.Element
+ evaluationBlindedRDomainBigBitReversed []fr.Element
+ evaluationBlindedODomainBigBitReversed []fr.Element
+ evaluationBlindedZDomainBigBitReversed []fr.Element
+ )
chEvalBL := make(chan struct{}, 1)
chEvalBR := make(chan struct{}, 1)
chEvalBO := make(chan struct{}, 1)
go func() {
- evalBL = evaluateHDomain(bcl, &pk.DomainH)
+ evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1])
close(chEvalBL)
}()
go func() {
- evalBR = evaluateHDomain(bcr, &pk.DomainH)
+ evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1])
close(chEvalBR)
}()
go func() {
- evalBO = evaluateHDomain(bco, &pk.DomainH)
+ evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1])
close(chEvalBO)
}()
- var constraintsInd, constraintsOrdering polynomial.Polynomial
+ var constraintsInd, constraintsOrdering []fr.Element
chConstraintInd := make(chan struct{}, 1)
go func() {
// compute qk in canonical basis, completed with the public inputs
- qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality)
- copy(qk, fullWitness[:spr.NbPublicVariables])
- copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
- pk.DomainNum.FFTInverse(qk, fft.DIF, 0)
- fft.BitReverse(qk)
-
- // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ)
- // --> uses the blinded version of l, r, o
+ qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality)
+ copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables])
+ copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
+ pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF)
+ fft.BitReverse(qkCompletedCanonical)
+
+ // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain
+ // → uses the blinded version of l, r, o
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk)
+ constraintsInd = evaluateConstraintsDomainBigBitReversed(
+ pk,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ qkCompletedCanonical)
close(chConstraintInd)
}()
@@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
chConstraintOrdering <- err
return
}
- evalBZ = evaluateHDomain(bz, &pk.DomainH)
- // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ)
+
+ evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1])
+ // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain
// evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o.
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma)
+ constraintsOrdering = evaluateOrderingDomainBigBitReversed(
+ pk,
+ evaluationBlindedZDomainBigBitReversed,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ beta,
+ gamma)
chConstraintOrdering <- nil
close(chConstraintOrdering)
}()
@@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
if err := <-chConstraintOrdering; err != nil {
return nil, err
}
+
<-chConstraintInd
+
// compute h in canonical form
- h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha)
+ h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha)
// compute kzg commitments of h1, h2 and h3
- if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
var wgZetaEvals sync.WaitGroup
wgZetaEvals.Add(3)
go func() {
- blzeta = bcl.Eval(&zeta)
+ blzeta = eval(blindedLCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- brzeta = bcr.Eval(&zeta)
+ brzeta = eval(blindedRCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- bozeta = bco.Eval(&zeta)
+ bozeta = eval(blindedOCanonical, zeta)
wgZetaEvals.Done()
}()
@@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
var zetaShifted fr.Element
zetaShifted.Mul(&zeta, &pk.Vk.Generator)
proof.ZShiftedOpening, err = kzg.Open(
- bz,
- &zetaShifted,
- &pk.DomainH,
+ blindedZCanonical,
+ zetaShifted,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
bzuzeta := proof.ZShiftedOpening.ClaimedValue
var (
- linearizedPolynomial polynomial.Polynomial
- linearizedPolynomialDigest curve.G1Affine
- errLPoly error
+ linearizedPolynomialCanonical []fr.Element
+ linearizedPolynomialDigest curve.G1Affine
+ errLPoly error
)
chLpoly := make(chan struct{}, 1)
go func() {
// compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k)
wgZetaEvals.Wait()
- linearizedPolynomial = computeLinearizedPolynomial(
+ linearizedPolynomialCanonical = computeLinearizedPolynomial(
blzeta,
brzeta,
bozeta,
alpha,
+ beta,
gamma,
zeta,
bzuzeta,
- bz,
+ blindedZCanonical,
pk,
)
// TODO this commitment is only necessary to derive the challenge, we should
// be able to avoid doing it and get the challenge in another way
- linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS)
+ linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS)
close(chLpoly)
}()
- // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3)
+ // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3)
var bZetaPowerm, bSize big.Int
- bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
+ bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
var zetaPowerm fr.Element
zetaPowerm.Exp(zeta, &bSize)
zetaPowerm.ToBigIntRegular(&bZetaPowerm)
foldedHDigest := proof.H[2]
foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm)
- foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3)
- foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2)
- foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + Comm(h2)
+ foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1)
- // foldedH = h1 + zeta*h2 + zeta**2*h3
+ // foldedH = h1 + ζᵐ⁺²*h2 + ζ²⁽ᵐ⁺²⁾*h3
foldedH := h3
utils.Parallelize(len(foldedH), func(start, end int) {
for i := start; i < end; i++ {
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3
- foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1)
- foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3
+ foldedH[i].Add(&foldedH[i], &h2[i]) // ζᵐ⁺²*h3 + h2
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3 + ζᵐ⁺²*h2
+ foldedH[i].Add(&foldedH[i], &h1[i]) // ζ²⁽ᵐ⁺²⁾*h3 + ζᵐ⁺²*h2 + h1
}
})
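A minimal sketch of the folding performed above, with math/big standing in for fr.Element and a toy modulus: foldedH = h1 + ζᵐ⁺²·h2 + ζ²⁽ᵐ⁺²⁾·h3, accumulated coefficient-wise in Horner fashion exactly as in the loop.

    // Coefficient-wise folding of three equal-length polynomials by a power of ζ.
    package main

    import (
        "fmt"
        "math/big"
    )

    func foldH(h1, h2, h3 []*big.Int, zetaPowerM, q *big.Int) []*big.Int {
        folded := make([]*big.Int, len(h1))
        for i := range folded {
            v := new(big.Int).Set(h3[i])
            v.Mul(v, zetaPowerM).Add(v, h2[i]) // ζᵐ⁺²*h3 + h2
            v.Mul(v, zetaPowerM).Add(v, h1[i]) // ζ²⁽ᵐ⁺²⁾*h3 + ζᵐ⁺²*h2 + h1
            folded[i] = v.Mod(v, q)
        }
        return folded
    }

    func main() {
        q := big.NewInt(97) // toy modulus
        one, two, three := big.NewInt(1), big.NewInt(2), big.NewInt(3)
        out := foldH([]*big.Int{one}, []*big.Int{two}, []*big.Int{three}, big.NewInt(5), q)
        fmt.Println(out[0]) // 3*25 + 2*5 + 1 = 86
    }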
@@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
// Batch open the first list of polynomials
proof.BatchedProof, err = kzg.BatchOpenSinglePoint(
- []polynomial.Polynomial{
+ [][]fr.Element{
foldedH,
- linearizedPolynomial,
- bcl,
- bcr,
- bco,
- pk.CS1,
- pk.CS2,
+ linearizedPolynomialCanonical,
+ blindedLCanonical,
+ blindedRCanonical,
+ blindedOCanonical,
+ pk.S1Canonical,
+ pk.S2Canonical,
},
[]kzg.Digest{
foldedHDigest,
@@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
pk.Vk.S[0],
pk.Vk.S[1],
},
- &zeta,
+ zeta,
hFunc,
- &pk.DomainH,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls12_381witness.Witn
}
+// eval evaluates the polynomial c (coefficients in canonical form, low degree first) at the point p, using Horner's method
+func eval(c []fr.Element, p fr.Element) fr.Element {
+ var r fr.Element
+ for i := len(c) - 1; i >= 0; i-- {
+ r.Mul(&r, &p).Add(&r, &c[i])
+ }
+ return r
+}
+
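eval above is plain Horner evaluation of a coefficient slice; a standalone integer version for reference (big.Int standing in for fr.Element):

    // Horner evaluation of c(X) = c[0] + c[1]·X + c[2]·X² + … at p.
    package main

    import (
        "fmt"
        "math/big"
    )

    func hornerEval(c []*big.Int, p *big.Int) *big.Int {
        r := new(big.Int)
        for i := len(c) - 1; i >= 0; i-- {
            r.Mul(r, p).Add(r, c[i])
        }
        return r
    }

    func main() {
        // c(X) = 3 + 2X + X², so c(5) = 3 + 10 + 25 = 38
        c := []*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1)}
        fmt.Println(hornerEval(c, big.NewInt(5))) // 38
    }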
// fills proof.LRO with kzg commits of bcl, bcr and bco
-func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS
return err1
}
-func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err
return err1
}
-// computeBlindedLRO l, r, o in canonical basis with blinding
-func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) {
+// computeBlindedLROCanonical computes l, r, o in canonical basis, blinded
+func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) {
// note that bcl, bcr and bco reuses cl, cr and co memory
- cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
+ cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
chDone := make(chan error, 2)
go func() {
var err error
copy(cl, ll)
- domain.FFTInverse(cl, fft.DIF, 0)
+ domain.FFTInverse(cl, fft.DIF)
fft.BitReverse(cl)
bcl, err = blindPoly(cl, domain.Cardinality, 1)
chDone <- err
@@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
go func() {
var err error
copy(cr, lr)
- domain.FFTInverse(cr, fft.DIF, 0)
+ domain.FFTInverse(cr, fft.DIF)
fft.BitReverse(cr)
bcr, err = blindPoly(cr, domain.Cardinality, 1)
chDone <- err
}()
copy(co, lo)
- domain.FFTInverse(co, fft.DIF, 0)
+ domain.FFTInverse(co, fft.DIF)
fft.BitReverse(co)
if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil {
return
@@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
// * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1)
//
// WARNING:
-// pre condition degree(cp) <= rou + bo
-// pre condition cap(cp) >= int(totalDegree + 1)
-func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) {
+// pre condition degree(cp) ⩽ rou + bo
+// pre condition cap(cp) ⩾ int(totalDegree + 1)
+func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) {
// degree of the blinded polynomial is max(rou+order, cp.Degree)
totalDegree := rou + bo
@@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res := cp[:totalDegree+1]
// random polynomial
- blindingPoly := make(polynomial.Polynomial, bo+1)
+ blindingPoly := make([]fr.Element, bo+1)
for i := uint64(0); i < bo+1; i++ {
if _, err := blindingPoly[i].SetRandom(); err != nil {
return nil, err
@@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
}
return res, nil
+
}
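Why blindPoly leaves the committed evaluations unchanged on the interpolation domain: it adds Q(X)·(Xʳᵒᵘ−1), and Xʳᵒᵘ−1 vanishes at every rou-th root of unity. A toy check in F_97 (ω = 22 has order 4; math/big stands in for fr.Element; values are illustrative):

    // Adding Q(X)*(Xⁿ-1) does not change evaluations on the n-th roots of unity.
    package main

    import (
        "fmt"
        "math/big"
    )

    var q = big.NewInt(97)

    // evalMod evaluates the polynomial given by coefficients c at x, mod q.
    func evalMod(c []int64, x *big.Int) *big.Int {
        r := new(big.Int)
        for i := len(c) - 1; i >= 0; i-- {
            r.Mul(r, x).Add(r, big.NewInt(c[i])).Mod(r, q)
        }
        return r
    }

    func main() {
        n := 4
        cp := []int64{7, 3, 0, 5} // polynomial to blind, degree < n
        blind := []int64{11, 2}   // "random" blinding polynomial Q(X)
        // res(X) = cp(X) + Q(X)*(Xⁿ-1): subtract Q from the low-order
        // coefficients and add it back shifted by n.
        res := make([]int64, n+len(blind))
        copy(res, cp)
        for i, b := range blind {
            res[i] -= b
            res[i+n] += b
        }
        omega := big.NewInt(22) // 22 has order 4 mod 97: 22² ≡ -1, 22⁴ ≡ 1
        x := big.NewInt(1)
        for i := 0; i < n; i++ {
            fmt.Println(evalMod(cp, x).Cmp(evalMod(res, x)) == 0) // true every time
            x.Mul(x, omega).Mod(x, q)
        }
    }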
-// computeLRO extracts the solution l, r, o, and returns it in lagrange form.
+// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form.
// solution = [ public | secret | internal ]
-func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) {
+func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {
- s := int(pk.DomainNum.Cardinality)
+ s := int(pk.Domain[0].Cardinality)
- var l, r, o polynomial.Polynomial
+ var l, r, o []fr.Element
l = make([]fr.Element, s)
r = make([]fr.Element, s)
o = make([]fr.Element, s)
@@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly
//
// * Z of degree n (domainNum.Cardinality)
// * Z(1)=1
-// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma)
-// * for i>0: Z(u**i) = Pi_{k<i} …
…
+// * for i>0: Z(gⁱ) = Π_{k<i} …
…
- … // z**i -> z**i+1
- u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
}
})
@@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele
Mul(&z[i], &gInv[i])
}
- pk.DomainNum.FFTInverse(z, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(z, fft.DIF)
fft.BitReverse(z)
- return blindPoly(z, pk.DomainNum.Cardinality, 2)
+ return blindPoly(z, pk.Domain[0].Cardinality, 2)
}
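As described in the comment above, Z is built in Lagrange basis as a running product, Z(1) = 1 and Z(gⁱ⁺¹) = Z(gⁱ)·fᵢ/gᵢ, before the inverse FFT and blinding shown here. A stripped-down sketch of just that recurrence (toy modulus; num/den stand in for the fᵢ and gᵢ factors, whose construction is omitted; this is not the gnark code):

    // Running-product accumulator in Lagrange basis: z[0]=1, z[i+1]=z[i]*num[i]/den[i] mod q.
    package main

    import (
        "fmt"
        "math/big"
    )

    func accumulate(num, den []*big.Int, q *big.Int) []*big.Int {
        n := len(num)
        z := make([]*big.Int, n)
        z[0] = big.NewInt(1)
        for i := 0; i < n-1; i++ {
            inv := new(big.Int).ModInverse(den[i], q)
            z[i+1] = new(big.Int).Mul(z[i], num[i])
            z[i+1].Mul(z[i+1], inv).Mod(z[i+1], q)
        }
        return z
    }

    func main() {
        q := big.NewInt(97)
        num := []*big.Int{big.NewInt(3), big.NewInt(4), big.NewInt(5), big.NewInt(6)}
        den := []*big.Int{big.NewInt(2), big.NewInt(2), big.NewInt(5), big.NewInt(3)}
        fmt.Println(accumulate(num, den, q)) // [1 50 3 3]
    }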
-// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on
-// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of qlL+qrR+qmL.R+qoO+k on
+// the big domain coset.
//
// * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets
// * qk is the completed version of qk, in canonical version
-func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
- var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial
+func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
+ var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element
var wg sync.WaitGroup
wg.Add(4)
go func() {
- evalQl = evaluateHDomain(pk.Ql, &pk.DomainH)
+ evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQr = evaluateHDomain(pk.Qr, &pk.DomainH)
+ evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQm = evaluateHDomain(pk.Qm, &pk.DomainH)
+ evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQo = evaluateHDomain(pk.Qo, &pk.DomainH)
+ evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1])
wg.Done()
}()
- evalQk = evaluateHDomain(qk, &pk.DomainH)
+ evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1])
wg.Wait()
- // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets
- // of (Z/8mZ)/(Z/mZ)
+
+ // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain
utils.Parallelize(len(evalQk), func(start, end int) {
var t0, t1 fr.Element
for i := start; i < end; i++ {
@@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.
return evalQk
}
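Per evaluation point, the value accumulated above is the PlonK gate polynomial qL·l + qR·r + qM·l·r + qO·o + qK. A tiny integer sketch of that gate (illustrative only; on the small domain a satisfied constraint gives 0, while on the big-domain coset the value is generally non-zero and later divided by Xⁿ−1):

    // PlonK gate at a single evaluation point.
    package main

    import "fmt"

    func gate(qL, qR, qM, qO, qK, l, r, o int64) int64 {
        return qL*l + qR*r + qM*l*r + qO*o + qK
    }

    func main() {
        // toy multiplication gate l*r = o encoded as qM=1, qO=-1: 3*4 - 12 = 0
        fmt.Println(gate(0, 0, 1, -1, 0, 3, 4, 12)) // 0
    }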
-// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ)
-func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) {
-
- id = make([]fr.Element, pk.DomainH.Cardinality)
-
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
- var acc fr.Element
- acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start)))
- for i := start; i < end; i++ {
- id[i].Mul(&acc, &pk.DomainH.FinerGenerator)
- acc.Mul(&acc, &pk.DomainH.Generator)
- }
- })
-
- return id
-}
-
-// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
-// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the
+// coset of the big domain.
//
-// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets
-// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets
+// * z evaluation of the blinded permutation accumulator polynomial on the coset of the big domain
+// * l, r, o evaluation of the blinded solution vectors on the coset of the big domain
// * gamma randomization
-func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial {
+func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element {
- // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ)
- evalID := evalIDCosets(pk)
+ nbElmts := int(pk.Domain[1].Cardinality)
- // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ)
- var wg sync.WaitGroup
- wg.Add(2)
- var evalS1, evalS2, evalS3 polynomial.Polynomial
- go func() {
- evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH)
- wg.Done()
- }()
- go func() {
- evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH)
- wg.Done()
- }()
- evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH)
- wg.Wait()
+ // computes z(uX)*(l(X)+s₁(X)*β+γ)*(r(X)+s₂(X)*β+γ)*(o(X)+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ)
+ // on the big domain (coset).
+ res := make([]fr.Element, pk.Domain[1].Cardinality)
- // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ)
- res := evalS1 // re use allocated memory for evalS1
- s := uint64(len(evalZ))
- nn := uint64(64 - bits.TrailingZeros64(uint64(s)))
+ nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts)))
// needed to shift evalZ
- toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality
+ toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality)
+
+ var cosetShift, cosetShiftSquare fr.Element
+ cosetShift.Set(&pk.Vk.CosetShift)
+ cosetShiftSquare.Square(&pk.Vk.CosetShift)
+
+ utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) {
+
+ var evaluationIDBigDomain fr.Element
+ evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))).
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen)
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
var f [3]fr.Element
var g [3]fr.Element
- var eID fr.Element
for i := start; i < end; i++ {
- // here we want to left shift evalZ by domainH/domainNum
- // however, evalZ is permuted
- // we take the non permuted index
- // compute the corresponding shift position
- // permute it again
- irev := bits.Reverse64(uint64(i)) >> nn
- eID = evalID[irev]
+ _i := bits.Reverse64(uint64(i)) >> nn
+ _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn
- shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn
- //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn
+ // in what follows gⁱ is shorthand for the i-th element of the chosen coset of domainBig
+ f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ
+ f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ
+ f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ
- f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma
- f[1].Mul(&eID, &pk.Vk.Shifter[0])
- f[2].Mul(&eID, &pk.Vk.Shifter[1])
- f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma
- f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma
+ g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ)+s₁(gⁱ)*β+γ
+ g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ)+s₂(gⁱ)*β+γ
+ g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ)+s₃(gⁱ)*β+γ
- g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma
- g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma
- g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma
+ f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
+ g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ)
- f[0].Mul(&f[0], &f[1]).
- Mul(&f[0], &f[2]).
- Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma)
+ res[_i].Sub(&g[0], &f[0]) // z(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
- g[0].Mul(&g[0], &g[1]).
- Mul(&g[0], &g[2]).
- Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
-
- res[i].Sub(&g[0], &f[0])
+ evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g
}
})
return res
}
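The index arithmetic above relies on the evaluations being stored in bit-reversed order: the value at gᵏ lives at index reverse(k), and its shifted partner at g^(k+toShift) lives at reverse((k+toShift) mod n). A standalone check of that mapping (assuming, as in the code, that n is a power of two):

    // Stored position of gᵏ when evaluations are kept in bit-reversed order.
    package main

    import (
        "fmt"
        "math/bits"
    )

    func storedIndex(k, n uint64) uint64 {
        nn := uint(64 - bits.TrailingZeros64(n)) // n must be a power of two
        return bits.Reverse64(k) >> nn
    }

    func main() {
        n, toShift := uint64(8), uint64(4)
        for k := uint64(0); k < n; k++ {
            fmt.Printf("g^%d stored at %d, g^%d stored at %d\n",
                k, storedIndex(k, n), (k+toShift)%n, storedIndex((k+toShift)%n, n))
        }
    }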
-// evaluateHDomain evaluates poly (canonical form) of degree m < …
…
- … >> nn
- // h[i].Mul(&h[i], &_u[irev%4])
- h[i].Mul(&h[i], &_u[irev%toShift])
+
+ _i := bits.Reverse64(i) >> nn
+
+ t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain
+ h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t).
+ Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]).
+ Mul(&h[_i], &alpha).
+ Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]).
+ Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio])
}
})
// put h in canonical form. h is of degree 3*(n+1)+2.
// using fft.DIT put h revert bit reverse
- pk.DomainH.FFTInverse(h, fft.DIT, 1)
- // fmt.Println("h:")
- // for i := 0; i < len(h); i++ {
- // fmt.Printf("%s\n", h[i].String())
- // }
- // fmt.Println("")
+ pk.Domain[1].FFTInverse(h, fft.DIT, true)
// degree of hi is n+2 because of the blinding
- h1 := h[:pk.DomainNum.Cardinality+2]
- h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)]
- h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)]
+ h1 := h[:pk.Domain[0].Cardinality+2]
+ h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)]
+ h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)]
return h1, h2, h3
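The split above is the usual degree decomposition h(X) = h1(X) + X^(n+2)·h2(X) + X^(2(n+2))·h3(X), each hᵢ holding n+2 coefficients. A small integer check of that identity (toy values):

    // Splitting a coefficient slice of length 3*(n+2) and re-evaluating both sides.
    package main

    import "fmt"

    func evalPoly(c []int, x int) int {
        r := 0
        for i := len(c) - 1; i >= 0; i-- {
            r = r*x + c[i]
        }
        return r
    }

    func pow(x, e int) int {
        r := 1
        for i := 0; i < e; i++ {
            r *= x
        }
        return r
    }

    func main() {
        n := 2
        m := n + 2
        h := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11} // 3*(n+2) coefficients
        h1, h2, h3 := h[:m], h[m:2*m], h[2*m:3*m]
        x := 3
        lhs := evalPoly(h, x)
        rhs := evalPoly(h1, x) + pow(x, m)*evalPoly(h2, x) + pow(x, 2*m)*evalPoly(h3, x)
        fmt.Println(lhs == rhs) // true: h(X) = h1(X) + X^(n+2)*h2(X) + X^(2(n+2))*h3(X)
    }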
@@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom
// computeLinearizedPolynomial computes the linearized polynomial in canonical basis.
// The purpose is to commit and open all in one ql, qr, qm, qo, qk.
-// * a, b, c are the evaluation of l, r, o at zeta
-// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z
+// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta
+// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z
// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk.
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial {
+//
+// The Linearized polynomial is:
+//
+// α²*L₁(ζ)*Z(X)
+// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ))
+// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X)
+func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element {
// first part: individual constraints
var rl fr.Element
- rl.Mul(&r, &l)
+ rl.Mul(&rZeta, &lZeta)
- // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
+ // second part:
+ // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)
var s1, s2 fr.Element
chS1 := make(chan struct{}, 1)
go func() {
- s1 = pk.CS1.Eval(&zeta)
- s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma)
+ s1 = eval(pk.S1Canonical, zeta) // s1(ζ)
+ s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ)
close(chS1)
}()
- t := pk.CS2.Eval(&zeta)
- t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma)
+ tmp := eval(pk.S2Canonical, zeta) // s2(ζ)
+ tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ)
<-chS1
- s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma)
- Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)
-
- s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma)
- t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)
- t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
- s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
-
- // third part L1(zeta)*alpha**2**Z
- var lagrange, one, den, frNbElmt fr.Element
+ s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ)
+
+ var uzeta, uuzeta fr.Element
+ uzeta.Mul(&zeta, &pk.Vk.CosetShift)
+ uuzeta.Mul(&uzeta, &pk.Vk.CosetShift)
+
+ s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ)
+ tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)
+ tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+ s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ // third part L₁(ζ)*α²*Z
+ var lagrangeZeta, one, den, frNbElmt fr.Element
one.SetOne()
- nbElmt := int64(pk.DomainNum.Cardinality)
- lagrange.Set(&zeta).
- Exp(lagrange, big.NewInt(nbElmt)).
- Sub(&lagrange, &one)
+ nbElmt := int64(pk.Domain[0].Cardinality)
+ lagrangeZeta.Set(&zeta).
+ Exp(lagrangeZeta, big.NewInt(nbElmt)).
+ Sub(&lagrangeZeta, &one)
frNbElmt.SetUint64(uint64(nbElmt))
den.Sub(&zeta, &one).
- Mul(&den, &frNbElmt).
Inverse(&den)
- lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1)
- Mul(&lagrange, &alpha).
- Mul(&lagrange, &alpha) // alpha**2*L_0
+ lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ-1)/(ζ-1)
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ)
- linPol := z.Clone()
+ linPol := make([]fr.Element, len(blindedZCanonical))
+ copy(linPol, blindedZCanonical)
utils.Parallelize(len(linPol), func(start, end int) {
+
var t0, t1 fr.Element
+
for i := start; i < end; i++ {
- linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
- if i < len(pk.CS3) {
- t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X)
+
+ linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ if i < len(pk.S3Canonical) {
+
+ t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X)
+
linPol[i].Add(&linPol[i], &t0)
}
- linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) )
+ linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ))
if i < len(pk.Qm) {
- t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm
- t0.Mul(&pk.Ql[i], &l)
+
+ t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X)
+ t0.Mul(&pk.Ql[i], &lZeta)
t0.Add(&t0, &t1)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X)
- t0.Mul(&pk.Qr[i], &r)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr
+ t0.Mul(&pk.Qr[i], &rZeta)
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X)
- t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i])
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk
+ t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i])
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X)
}
- t0.Mul(&z[i], &lagrange)
+ t0.Mul(&blindedZCanonical[i], &lagrangeZeta)
linPol[i].Add(&linPol[i], &t0) // finish the computation
}
})
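The "third part" above uses the closed form L₁(ζ) = (ζⁿ−1)/(n·(ζ−1)) for the first Lagrange polynomial at ζ. A toy field check of that formula against the product definition (F_97, n = 4, ω = 22; values chosen for illustration):

    // L₁(ζ) two ways: closed form vs Π_{i=1..n-1} (ζ-ωⁱ)/(1-ωⁱ).
    package main

    import (
        "fmt"
        "math/big"
    )

    func main() {
        q := big.NewInt(97)
        n := int64(4)
        omega := big.NewInt(22) // primitive 4th root of unity mod 97
        zeta := big.NewInt(5)
        one := big.NewInt(1)

        // closed form: (ζⁿ - 1) / (n·(ζ - 1))
        num := new(big.Int).Exp(zeta, big.NewInt(n), q)
        num.Sub(num, one)
        den := new(big.Int).Sub(zeta, one)
        den.Mul(den, big.NewInt(n))
        closed := num.Mul(num, new(big.Int).ModInverse(den, q))
        closed.Mod(closed, q)

        // product form
        prod := big.NewInt(1)
        w := new(big.Int).Set(omega)
        for i := int64(1); i < n; i++ {
            t := new(big.Int).Sub(zeta, w)
            b := new(big.Int).Sub(one, w)
            b.Mod(b, q)
            t.Mul(t, new(big.Int).ModInverse(b, q))
            prod.Mul(prod, t).Mod(prod, q)
            w.Mul(w, omega).Mod(w, q)
        }

        fmt.Println(closed, prod) // both 39
    }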
diff --git a/internal/backend/bls12-381/plonk/setup.go b/internal/backend/bls12-381/plonk/setup.go
index 27dfd6868d..823cc25d8b 100644
--- a/internal/backend/bls12-381/plonk/setup.go
+++ b/internal/backend/bls12-381/plonk/setup.go
@@ -21,7 +21,6 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr/kzg"
- "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial"
"github.com/consensys/gnark/internal/backend/bls12-381/cs"
kzgg "github.com/consensys/gnark-crypto/kzg"
@@ -40,18 +39,21 @@ type ProvingKey struct {
Vk *VerifyingKey
// qr,ql,qm,qo (in canonical basis).
- Ql, Qr, Qm, Qo polynomial.Polynomial
+ Ql, Qr, Qm, Qo []fr.Element
// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
// Storing LQk in Lagrange basis saves a fft...
- CQk, LQk polynomial.Polynomial
+ CQk, LQk []fr.Element
- // Domains used for the FFTs
- DomainNum, DomainH fft.Domain
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
+ // Domain[0], Domain[1] fft.Domain
- // s1, s2, s3 (L=Lagrange basis, C=canonical basis)
- LS1, LS2, LS3 polynomial.Polynomial
- CS1, CS2, CS3 polynomial.Polynomial
+ // Permutation polynomials
+ EvaluationPermutationBigDomainBitReversed []fr.Element
+ S1Canonical, S2Canonical, S3Canonical []fr.Element
// position -> permuted position (position in [0,3*sizeSystem-1])
Permutation []int64
@@ -69,13 +71,12 @@ type VerifyingKey struct {
Generator fr.Element
NbPublicVariables uint64
- // shifters for extending the permutation set: from s=<1,z,..,z**n-1>,
- // extended domain = s || shifter[0].s || shifter[1].s
- Shifter [2]fr.Element
-
// Commitment scheme that is used for an instantiation of PLONK
KZGSRS *kzg.SRS
+ // cosetShift generator of the coset on the small domain
+ CosetShift fr.Element
+
// S commitments to S1, S2, S3
S [3]kzg.Digest
@@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
// fft domains
sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints
- pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false)
+ pk.Domain[0] = *fft.NewDomain(sizeSystem)
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
// h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space,
// the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases
// except when n<6.
if sizeSystem < 6 {
- pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(8 * sizeSystem)
} else {
- pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(4 * sizeSystem)
}
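The sizing rule above follows from h having degree 3(n+2)−1 once blinding is included: the big domain must be the next power of two ≥ 3(n+2), which is 4n for most sizes but 8n for the smallest ones. A quick check of that bound (power-of-two n for simplicity; in Setup, sizeSystem is itself rounded up by NewDomain):

    // Smallest power of two holding 3*(n+2) evaluations.
    package main

    import "fmt"

    func nextPowerOfTwo(x uint64) uint64 {
        n := uint64(1)
        for n < x {
            n <<= 1
        }
        return n
    }

    func main() {
        for _, n := range []uint64{2, 4, 8, 16, 32} {
            d := nextPowerOfTwo(3 * (n + 2))
            fmt.Printf("n=%d need=%d domain=%d (= %d*n)\n", n, 3*(n+2), d, d/n)
        }
    }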
- vk.Size = pk.DomainNum.Cardinality
+ vk.Size = pk.Domain[0].Cardinality
vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
- vk.Generator.Set(&pk.DomainNum.Generator)
+ vk.Generator.Set(&pk.Domain[0].Generator)
vk.NbPublicVariables = uint64(spr.NbPublicVariables)
- // shifters
- vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator)
- vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator)
-
if err := pk.InitKZG(srs); err != nil {
return nil, nil, err
}
// public polynomials corresponding to constraints: [ placholders | constraints | assertions ]
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return an error if size is inconsistent
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.Qm[i].SetZero()
pk.Qo[i].SetZero()
pk.CQk[i].SetZero()
- pk.LQk[i].SetZero() // --> to be completed by the prover
+ pk.LQk[i].SetZero() // → to be completed by the prover
}
offset := spr.NbPublicVariables
for i := 0; i < nbConstraints; i++ { // constraints
@@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K])
}
- pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(pk.Ql, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qr, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qm, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qo, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.CQk, fft.DIF)
fft.BitReverse(pk.Ql)
fft.BitReverse(pk.Qr)
fft.BitReverse(pk.Qm)
@@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
buildPermutation(spr, &pk)
// set s1, s2, s3
- computeLDE(&pk)
+ ccomputePermutationPolynomials(&pk)
// Commit to the polynomials to set up the verifying key
var err error
@@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil {
+ if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil {
+ if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil {
+ if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
@@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
//
// The permutation s is composed of cycles of maximum length such that
//
-// s. (l||r||o) = (l||r||o)
+// s. (l∥r∥o) = (l∥r∥o)
//
-//, where l||r||o is the concatenation of the indices of l, r, o in
+//, where l∥r∥o is the concatenation of the indices of l, r, o in
// ql.l+qr.r+qm.l.r+qo.O+k = 0.
//
// The permutation is encoded as a slice s of size 3*size(l), where the
-// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab
+// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab
// like this: for i in tab: tab[i] = tab[permutation[i]]
func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables
- sizeSolution := int(pk.DomainNum.Cardinality)
+ sizeSolution := int(pk.Domain[0].Cardinality)
// init permutation
pk.Permutation = make([]int64, 3*sizeSolution)
@@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
}
}
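A tiny illustration of the encoding documented above (tab[i] = tab[permutation[i]]): copy constraints form cycles over the concatenation l∥r∥o. The indices and values below are made up for the example:

    // Copy constraints as a permutation over l‖r‖o: equal wires form a cycle.
    package main

    import "fmt"

    func main() {
        n := 4 // constraints
        perm := make([]int64, 3*n)
        for i := range perm {
            perm[i] = int64(i) // identity: no copy constraints
        }
        // cycle 0 -> 4 -> 8 -> 0 : l[0] = r[0] = o[0]
        perm[0], perm[4], perm[8] = 4, 8, 0

        tab := []int{7, 1, 2, 3, 7, 5, 6, 7, 7, 9, 10, 11} // a satisfying assignment
        ok := true
        for i, p := range perm {
            ok = ok && tab[i] == tab[p]
        }
        fmt.Println(ok) // true: the assignment respects the copy constraints
    }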
-// computeLDE computes the LDE (Lagrange basis) of the permutations
+// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations
// s1, s2, s3.
//
-// ex: z gen of Z/mZ, u gen of Z/8mZ, then
-//
// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 |
// |
// | Permutation
// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v
// \---------------/ \--------------------/ \------------------------/
// s1 (LDE) s2 (LDE) s3 (LDE)
-func computeLDE(pk *ProvingKey) {
+func ccomputePermutationPolynomials(pk *ProvingKey) {
- nbElmt := int(pk.DomainNum.Cardinality)
+ nbElmts := int(pk.Domain[0].Cardinality)
- // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1]
- sID := make([]fr.Element, 3*nbElmt)
- sID[0].SetOne()
- sID[nbElmt].Set(&pk.DomainNum.FinerGenerator)
- sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator)
-
- for i := 1; i < nbElmt; i++ {
- sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1
- sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
- }
+ // Lagrange form of ID
+ evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0])
// Lagrange form of S1, S2, S3
- pk.LS1 = make(polynomial.Polynomial, nbElmt)
- pk.LS2 = make(polynomial.Polynomial, nbElmt)
- pk.LS3 = make(polynomial.Polynomial, nbElmt)
- for i := 0; i < nbElmt; i++ {
- pk.LS1[i].Set(&sID[pk.Permutation[i]])
- pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]])
- pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]])
+ pk.S1Canonical = make([]fr.Element, nbElmts)
+ pk.S2Canonical = make([]fr.Element, nbElmts)
+ pk.S3Canonical = make([]fr.Element, nbElmts)
+ for i := 0; i < nbElmts; i++ {
+ pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]])
+ pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]])
+ pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]])
}
// Canonical form of S1, S2, S3
- pk.CS1 = make(polynomial.Polynomial, nbElmt)
- pk.CS2 = make(polynomial.Polynomial, nbElmt)
- pk.CS3 = make(polynomial.Polynomial, nbElmt)
- copy(pk.CS1, pk.LS1)
- copy(pk.CS2, pk.LS2)
- copy(pk.CS3, pk.LS3)
- pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0)
- fft.BitReverse(pk.CS1)
- fft.BitReverse(pk.CS2)
- fft.BitReverse(pk.CS3)
+ pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF)
+ fft.BitReverse(pk.S1Canonical)
+ fft.BitReverse(pk.S2Canonical)
+ fft.BitReverse(pk.S3Canonical)
+
+ // evaluation of permutation on the big domain
+ pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality)
+ copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true)
+
+}
+
+// getIDSmallDomain returns the Lagrange form of ID on the small domain
+func getIDSmallDomain(domain *fft.Domain) []fr.Element {
+
+ res := make([]fr.Element, 3*domain.Cardinality)
+
+ res[0].SetOne()
+ res[domain.Cardinality].Set(&domain.FrMultiplicativeGen)
+ res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen)
+
+ for i := uint64(1); i < domain.Cardinality; i++ {
+ res[i].Mul(&res[i-1], &domain.Generator)
+ res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator)
+ res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator)
+ }
+ return res
}
-// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS
+// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS
//
// This should be used after deserializing a ProvingKey
// as pk.Vk.KZG is NOT serialized
diff --git a/internal/backend/bls12-381/plonk/verify.go b/internal/backend/bls12-381/plonk/verify.go
index 53283ee4a9..620bc5e097 100644
--- a/internal/backend/bls12-381/plonk/verify.go
+++ b/internal/backend/bls12-381/plonk/verify.go
@@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne
hFunc := sha256.New()
// transcript to derive the challenge
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// derive gamma from Comm(l), Comm(r), Comm(o)
gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2])
@@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne
return err
}
+ // derive beta from Comm(l), Comm(r), Comm(o)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return err
+ }
+
// derive alpha from Comm(l), Comm(r), Comm(o), Com(Z)
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z)
if err != nil {
@@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne
return err
}
- // evaluation of Z=X**m-1 at zeta
+ // evaluation of Z = Xⁿ-1 at ζ
var zetaPowerM, zzeta fr.Element
var bExpo big.Int
one := fr.One()
@@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls12_381witness.Witne
zetaPowerM.Exp(zeta, &bExpo)
zzeta.Sub(&zetaPowerM, &one)
- // ccompute PI = Sum_i …
…
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
}
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
}
- }
- // sanity check; ensure all wires are marked as "instantiated"
- if !solution.isValid() {
- panic("solver didn't instantiate all wires")
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
}
- return solution.values, nil
+ return nil
}
// IsSolved returns nil if given witness solves the R1CS and error otherwise
@@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) {
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) {
+func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error {
// the index of the non zero entry shows if L, R or O has an uninstantiated wire
// the content is the ID of the wire non instantiated
@@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
return nil
}
- if err = processLExp(r.L.LinExp, &a, 1); err != nil {
- return
+ if err := processLExp(r.L.LinExp, a, 1); err != nil {
+ return err
}
- if err = processLExp(r.R.LinExp, &b, 2); err != nil {
- return
+ if err := processLExp(r.R.LinExp, b, 2); err != nil {
+ return err
}
- if err = processLExp(r.O.LinExp, &c, 3); err != nil {
- return
+ if err := processLExp(r.O.LinExp, c, 3); err != nil {
+ return err
}
if loc == 0 {
// there is nothing to solve, may happen if we have an assertion
// (i.e. a constraint that doesn't yield any output)
// or if we solved the unsolved wires with hint functions
- return
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
+ return nil
}
// we compute the wire value and instantiate it
- solved = true
- vID := termToCompute.WireID()
+ wID := termToCompute.WireID()
// solver result
var wire fr.Element
@@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
switch loc {
case 1:
if !b.IsZero() {
- wire.Div(&c, &b).
- Sub(&wire, &a)
- a.Add(&a, &wire)
+ wire.Div(c, b).
+ Sub(&wire, a)
+ a.Add(a, &wire)
} else {
// we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 2:
if !a.IsZero() {
- wire.Div(&c, &a).
- Sub(&wire, &b)
- b.Add(&b, &wire)
+ wire.Div(c, a).
+ Sub(&wire, b)
+ b.Add(b, &wire)
} else {
- // we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 3:
- wire.Mul(&a, &b).
- Sub(&wire, &c)
+ wire.Mul(a, b).
+ Sub(&wire, c)
- c.Add(&c, &wire)
+ c.Add(c, &wire)
}
// wire is the term (coeff * value)
// but in the solution we want to store the value only
// note that in gnark frontend, coeff here is always 1 or -1
cs.divByCoeff(&wire, termToCompute)
- solution.set(vID, wire)
+ solution.set(wID, wire)
- return
+ return nil
}
// GetConstraints return a list of constraint formatted as L⋅R == O
diff --git a/internal/backend/bls24-315/cs/r1cs_sparse.go b/internal/backend/bls24-315/cs/r1cs_sparse.go
index 667d5d230e..9ecd27fb26 100644
--- a/internal/backend/bls24-315/cs/r1cs_sparse.go
+++ b/internal/backend/bls24-315/cs/r1cs_sparse.go
@@ -21,9 +21,12 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/fxamacker/cbor/v2"
"io"
+ "math"
"math/big"
"os"
+ "runtime"
"strings"
+ "sync"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
@@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
return solution.values, err
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
// solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs
copy(solution.values, witness)
for i := 0; i < len(witness); i++ {
@@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness)
+ solution.nbSolved += uint64(len(witness))
// defer log printing once all solution.values are computed
defer solution.printLogs(opt.LoggerOut, cs.Logs)
@@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
coefficientsNegInv[i].Neg(&coefficientsNegInv[i])
}
- // loop through the constraints to solve the variables
- for i := 0; i < len(cs.Constraints); i++ {
- if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil {
- return solution.values, fmt.Errorf("constraint %d: %w", i, err)
- }
- if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil {
- errMsg := err.Error()
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
- }
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
- }
+ if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil {
+ return solution.values, err
}
// sanity check; ensure all wires are marked as "instantiated"
@@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
+func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error {
+ // minWorkPerCPU is the minimum target number of constraints a task should hold
+ // in other words, if a level has fewer than minWorkPerCPU constraints, it is not parallelized
+ // and is executed sequentially without sync.
+ const minWorkPerCPU = 50.0
+
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
+
+ // start a worker pool
+ // each worker wait on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ // clean up pool go routines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
+
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ }
+ }
+ continue
+ }
+
+ // number of tasks for this level is set to num cpus
+ // but if we don't have enough work for all our CPUS, it can be lower.
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
+ }
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
+ }
+
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
+ }
+
+ return nil
+}
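The scheduling above splits a level into at most NumCPU contiguous chunks whose sizes differ by at most one, handing out the len(level) mod nbTasks leftover constraints one per task. A standalone sketch of that arithmetic (the function name split is illustrative):

    // Split `total` items into `nbTasks` contiguous chunks of near-equal size,
    // mirroring the level scheduling above.
    package main

    import "fmt"

    func split(total, nbTasks int) [][2]int {
        perTask := total / nbTasks
        extra := total - nbTasks*perTask
        chunks := make([][2]int, 0, nbTasks)
        start := 0
        for i := 0; i < nbTasks; i++ {
            end := start + perTask
            if extra > 0 {
                end++
                extra--
            }
            chunks = append(chunks, [2]int{start, end})
            start = end
        }
        return chunks
    }

    func main() {
        fmt.Println(split(10, 4)) // [[0 3] [3 6] [6 8] [8 10]]
    }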
+
// computeHints computes wires associated with a hint function, if any
// if there is no remaining wire to solve, returns -1
// else returns the wire position (L -> 0, R -> 1, O -> 2)
diff --git a/internal/backend/bls24-315/cs/solution.go b/internal/backend/bls24-315/cs/solution.go
index 272f31af54..e215ac343a 100644
--- a/internal/backend/bls24-315/cs/solution.go
+++ b/internal/backend/bls24-315/cs/solution.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"math/big"
+ "sync/atomic"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend/schema"
@@ -32,14 +33,15 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bls24-315"
)
+var errUnsatisfiedConstraint = errors.New("unsatisfied")
+
// solution represents elements needed to compute
// a solution to a R1CS or SparseR1CS
type solution struct {
values, coefficients []fr.Element
solved []bool
- nbSolved int
+ nbSolved uint64
mHintsFunctions map[hint.ID]hint.Function
- tmpHintsIO []*big.Int
}
func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) {
@@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E
coefficients: coefficients,
solved: make([]bool, nbWires),
mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)),
- tmpHintsIO: make([]*big.Int, 0),
}
for _, h := range hintFunctions {
@@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) {
}
s.values[id] = value
s.solved[id] = true
- s.nbSolved++
+ atomic.AddUint64(&s.nbSolved, 1)
+ // s.nbSolved++
}
func (s *solution) isValid() bool {
- return s.nbSolved == len(s.values)
+ return int(s.nbSolved) == len(s.values)
}
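Since constraints within a level are now solved by several goroutines, solution.set bumps the counter with atomic.AddUint64 instead of a plain increment. A minimal standalone illustration of that pattern (not gnark code):

    // Concurrent counter update: the atomic add is what keeps nbSolved exact.
    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    func main() {
        var nbSolved uint64
        var wg sync.WaitGroup
        for i := 0; i < 8; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                for j := 0; j < 1000; j++ {
                    atomic.AddUint64(&nbSolved, 1) // a plain nbSolved++ here would be a data race
                }
            }()
        }
        wg.Wait()
        fmt.Println(atomic.LoadUint64(&nbSolved)) // always 8000
    }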
// computeTerm computes coef*variable
@@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error {
// tmp IO big int memory
nbInputs := len(h.Inputs)
nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs))
- m := len(s.tmpHintsIO)
- if m < (nbInputs + nbOutputs) {
- s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...)
- for i := m; i < len(s.tmpHintsIO); i++ {
- s.tmpHintsIO[i] = big.NewInt(0)
- }
+ // m := len(s.tmpHintsIO)
+ // if m < (nbInputs + nbOutputs) {
+ // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...)
+ // for i := m; i < len(s.tmpHintsIO); i++ {
+ // s.tmpHintsIO[i] = big.NewInt(0)
+ // }
+ // }
+ inputs := make([]*big.Int, nbInputs)
+ outputs := make([]*big.Int, nbOutputs)
+ for i := 0; i < nbInputs; i++ {
+ inputs[i] = big.NewInt(0)
+ }
+ for i := 0; i < nbOutputs; i++ {
+ outputs[i] = big.NewInt(0)
}
- inputs := s.tmpHintsIO[:nbInputs]
- outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs]
q := fr.Modulus()
diff --git a/internal/backend/bls24-315/groth16/marshal_test.go b/internal/backend/bls24-315/groth16/marshal_test.go
index aa2ceda799..06bf1ec9de 100644
--- a/internal/backend/bls24-315/groth16/marshal_test.go
+++ b/internal/backend/bls24-315/groth16/marshal_test.go
@@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) {
var pk, pkCompressed, pkRaw ProvingKey
// create a random pk
- domain := fft.NewDomain(8, 1, true)
+ domain := fft.NewDomain(8)
pk.Domain = *domain
nbWires := 6
diff --git a/internal/backend/bls24-315/groth16/prove.go b/internal/backend/bls24-315/groth16/prove.go
index 606f9e07f9..fbc67eb021 100644
--- a/internal/backend/bls24-315/groth16/prove.go
+++ b/internal/backend/bls24-315/groth16/prove.go
@@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
c = append(c, padding...)
n = len(a)
- domain.FFTInverse(a, fft.DIF, 0)
- domain.FFTInverse(b, fft.DIF, 0)
- domain.FFTInverse(c, fft.DIF, 0)
+ domain.FFTInverse(a, fft.DIF)
+ domain.FFTInverse(b, fft.DIF)
+ domain.FFTInverse(c, fft.DIF)
- domain.FFT(a, fft.DIT, 1)
- domain.FFT(b, fft.DIT, 1)
- domain.FFT(c, fft.DIT, 1)
+ domain.FFT(a, fft.DIT, true)
+ domain.FFT(b, fft.DIT, true)
+ domain.FFT(c, fft.DIT, true)
- var minusTwoInv fr.Element
- minusTwoInv.SetUint64(2)
- minusTwoInv.Neg(&minusTwoInv).
- Inverse(&minusTwoInv)
+ var den, one fr.Element
+ one.SetOne()
+ den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
+ den.Sub(&den, &one).Inverse(&den)
// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unnecessary memalloc
@@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
- Mul(&a[i], &minusTwoInv)
+ Mul(&a[i], &den)
}
})
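The new divisor den is 1/(vⁿ−1), the inverse of the vanishing polynomial Xⁿ−1 evaluated at the coset generator v = FrMultiplicativeGen. It is a single constant because (v·ωᵏ)ⁿ = vⁿ at every point of the coset, which is consistent with the previous hard-coded −1/2 (the special case vⁿ = −1). A toy check in F_97 (n = 4, ω = 22, v = 5 chosen for illustration):

    // Xⁿ-1 takes a single value on the multiplicative coset v·<ω>.
    package main

    import (
        "fmt"
        "math/big"
    )

    func main() {
        q := big.NewInt(97)
        n := big.NewInt(4)
        omega := big.NewInt(22) // primitive 4th root of unity in F_97
        v := big.NewInt(5)      // coset generator (stand-in for FrMultiplicativeGen)

        x := new(big.Int).Set(v)
        for k := 0; k < 4; k++ {
            val := new(big.Int).Exp(x, n, q)
            val.Sub(val, big.NewInt(1)).Mod(val, q)
            fmt.Println(val) // 42 each time: (v·ωᵏ)ⁿ - 1 is constant on the coset
            x.Mul(x, omega).Mod(x, q)
        }
    }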
// ifft_coset
- domain.FFTInverse(a, fft.DIF, 1)
+ domain.FFTInverse(a, fft.DIF, true)
utils.Parallelize(len(a), func(start, end int) {
for i := start; i < end; i++ {
diff --git a/internal/backend/bls24-315/groth16/setup.go b/internal/backend/bls24-315/groth16/setup.go
index e34b8d752b..596e7786fc 100644
--- a/internal/backend/bls24-315/groth16/setup.go
+++ b/internal/backend/bls24-315/groth16/setup.go
@@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error {
nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables
// Setting group for fft
- domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true)
+ domain := fft.NewDomain(uint64(len(r1cs.Constraints)))
// samples toxic waste
toxicWaste, err := sampleToxicWaste()
@@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error {
nbConstraints := len(r1cs.Constraints)
// Setting group for fft
- domain := fft.NewDomain(uint64(nbConstraints), 1, true)
+ domain := fft.NewDomain(uint64(nbConstraints))
// count number of infinity points we would have had we a normal setup
// in pk.G1.A, pk.G1.B, and pk.G2.B
diff --git a/internal/backend/bls24-315/plonk/marshal.go b/internal/backend/bls24-315/plonk/marshal.go
index 83d53ffa1f..05f24f4f36 100644
--- a/internal/backend/bls24-315/plonk/marshal.go
+++ b/internal/backend/bls24-315/plonk/marshal.go
@@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
}
// fft domains
- n2, err := pk.DomainNum.WriteTo(w)
+ n2, err := pk.Domain[0].WriteTo(w)
if err != nil {
return
}
n += n2
- n2, err = pk.DomainH.WriteTo(w)
+ n2, err = pk.Domain[1].WriteTo(w)
if err != nil {
return
}
n += n2
- // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality)
- if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) {
+ // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality)
+ if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) {
return n, errors.New("invalid permutation size, expected 3*domain cardinality")
}
@@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
([]fr.Element)(pk.Qo),
([]fr.Element)(pk.CQk),
([]fr.Element)(pk.LQk),
- ([]fr.Element)(pk.LS1),
- ([]fr.Element)(pk.LS2),
- ([]fr.Element)(pk.LS3),
- ([]fr.Element)(pk.CS1),
- ([]fr.Element)(pk.CS2),
- ([]fr.Element)(pk.CS3),
+ ([]fr.Element)(pk.S1Canonical),
+ ([]fr.Element)(pk.S2Canonical),
+ ([]fr.Element)(pk.S3Canonical),
pk.Permutation,
}
@@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
- n2, err := pk.DomainNum.ReadFrom(r)
+ n2, err := pk.Domain[0].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- n2, err = pk.DomainH.ReadFrom(r)
+ n2, err = pk.Domain[1].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
dec := curve.NewDecoder(r)
toDecode := []interface{}{
@@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
(*[]fr.Element)(&pk.Qo),
(*[]fr.Element)(&pk.CQk),
(*[]fr.Element)(&pk.LQk),
- (*[]fr.Element)(&pk.LS1),
- (*[]fr.Element)(&pk.LS2),
- (*[]fr.Element)(&pk.LS3),
- (*[]fr.Element)(&pk.CS1),
- (*[]fr.Element)(&pk.CS2),
- (*[]fr.Element)(&pk.CS3),
+ (*[]fr.Element)(&pk.S1Canonical),
+ (*[]fr.Element)(&pk.S2Canonical),
+ (*[]fr.Element)(&pk.S3Canonical),
&pk.Permutation,
}
@@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
&vk.SizeInv,
&vk.Generator,
vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
&vk.SizeInv,
&vk.Generator,
&vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
diff --git a/internal/backend/bls24-315/plonk/marshal_test.go b/internal/backend/bls24-315/plonk/marshal_test.go
index c24928ccf8..99763b07e3 100644
--- a/internal/backend/bls24-315/plonk/marshal_test.go
+++ b/internal/backend/bls24-315/plonk/marshal_test.go
@@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
@@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) {
// random pk
var pk ProvingKey
pk.Vk = &vk
- pk.DomainNum = *fft.NewDomain(42, 3, false)
- pk.DomainH = *fft.NewDomain(4*42, 1, false)
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Domain[0] = *fft.NewDomain(42)
+ pk.Domain[1] = *fft.NewDomain(4 * 42)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < 12; i++ {
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) {
pk.Qo[i].SetUint64(42)
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
pk.Permutation[0] = -12
pk.Permutation[len(pk.Permutation)-1] = 8888
@@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
diff --git a/internal/backend/bls24-315/plonk/prove.go b/internal/backend/bls24-315/plonk/prove.go
index 6f951143ba..904a785599 100644
--- a/internal/backend/bls24-315/plonk/prove.go
+++ b/internal/backend/bls24-315/plonk/prove.go
@@ -27,8 +27,6 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bls24-315"
- "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial"
-
"github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg"
"github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft"
@@ -43,6 +41,7 @@ import (
)
type Proof struct {
+
// Commitments to the solution vectors
LRO [3]kzg.Digest
@@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
hFunc := sha256.New()
// create a transcript manager to apply Fiat Shamir
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// result
proof := &Proof{}
@@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
}
// query l, r, o in Lagrange basis, not blinded
- ll, lr, lo := computeLRO(spr, pk, solution)
+ evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution)
// save ll, lr, lo, and make a copy of them in canonical basis.
// note that we allocate more capacity to reuse for blinded polynomials
- bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum)
+ blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ &pk.Domain[0])
if err != nil {
return nil, err
}
// compute kzg commitments of bcl, bcr and bco
- if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
return nil, err
}
+ // Fiat Shamir this
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return nil, err
+ }
+
// compute Z, the permutation accumulator polynomial, in canonical basis
// ll, lr, lo are NOT blinded
- var bz polynomial.Polynomial
+ var blindedZCanonical []fr.Element
chZ := make(chan error, 1)
var alpha fr.Element
go func() {
var err error
- bz, err = computeBlindedZ(ll, lr, lo, pk, gamma)
+ blindedZCanonical, err = computeBlindedZCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ pk, beta, gamma)
if err != nil {
chZ <- err
close(chZ)
@@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
// this may add additional arithmetic operations, but with smaller tasks
// we ensure that this commitment is well parallelized, without having an "unbalanced task" making
// the rest of the code wait too long.
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
+ if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
chZ <- err
close(chZ)
return
@@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
}()
// evaluation of the blinded versions of l, r, o and bz
- // on the odd cosets of (Z/8mZ)/(Z/mZ)
- var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial
+ // on the coset of the big domain
+ var (
+ evaluationBlindedLDomainBigBitReversed []fr.Element
+ evaluationBlindedRDomainBigBitReversed []fr.Element
+ evaluationBlindedODomainBigBitReversed []fr.Element
+ evaluationBlindedZDomainBigBitReversed []fr.Element
+ )
chEvalBL := make(chan struct{}, 1)
chEvalBR := make(chan struct{}, 1)
chEvalBO := make(chan struct{}, 1)
go func() {
- evalBL = evaluateHDomain(bcl, &pk.DomainH)
+ evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1])
close(chEvalBL)
}()
go func() {
- evalBR = evaluateHDomain(bcr, &pk.DomainH)
+ evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1])
close(chEvalBR)
}()
go func() {
- evalBO = evaluateHDomain(bco, &pk.DomainH)
+ evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1])
close(chEvalBO)
}()
- var constraintsInd, constraintsOrdering polynomial.Polynomial
+ var constraintsInd, constraintsOrdering []fr.Element
chConstraintInd := make(chan struct{}, 1)
go func() {
// compute qk in canonical basis, completed with the public inputs
- qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality)
- copy(qk, fullWitness[:spr.NbPublicVariables])
- copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
- pk.DomainNum.FFTInverse(qk, fft.DIF, 0)
- fft.BitReverse(qk)
-
- // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ)
- // --> uses the blinded version of l, r, o
+ qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality)
+ copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables])
+ copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
+ pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF)
+ fft.BitReverse(qkCompletedCanonical)
+
+ // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain
+ // → uses the blinded version of l, r, o
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk)
+ constraintsInd = evaluateConstraintsDomainBigBitReversed(
+ pk,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ qkCompletedCanonical)
close(chConstraintInd)
}()
@@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
chConstraintOrdering <- err
return
}
- evalBZ = evaluateHDomain(bz, &pk.DomainH)
- // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ)
+
+ evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1])
+ // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain
// evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o.
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma)
+ constraintsOrdering = evaluateOrderingDomainBigBitReversed(
+ pk,
+ evaluationBlindedZDomainBigBitReversed,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ beta,
+ gamma)
chConstraintOrdering <- nil
close(chConstraintOrdering)
}()
@@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
if err := <-chConstraintOrdering; err != nil {
return nil, err
}
+
<-chConstraintInd
+
// compute h in canonical form
- h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha)
+ h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha)
// compute kzg commitments of h1, h2 and h3
- if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
var wgZetaEvals sync.WaitGroup
wgZetaEvals.Add(3)
go func() {
- blzeta = bcl.Eval(&zeta)
+ blzeta = eval(blindedLCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- brzeta = bcr.Eval(&zeta)
+ brzeta = eval(blindedRCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- bozeta = bco.Eval(&zeta)
+ bozeta = eval(blindedOCanonical, zeta)
wgZetaEvals.Done()
}()
@@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
var zetaShifted fr.Element
zetaShifted.Mul(&zeta, &pk.Vk.Generator)
proof.ZShiftedOpening, err = kzg.Open(
- bz,
- &zetaShifted,
- &pk.DomainH,
+ blindedZCanonical,
+ zetaShifted,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
bzuzeta := proof.ZShiftedOpening.ClaimedValue
var (
- linearizedPolynomial polynomial.Polynomial
- linearizedPolynomialDigest curve.G1Affine
- errLPoly error
+ linearizedPolynomialCanonical []fr.Element
+ linearizedPolynomialDigest curve.G1Affine
+ errLPoly error
)
chLpoly := make(chan struct{}, 1)
go func() {
// compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k)
wgZetaEvals.Wait()
- linearizedPolynomial = computeLinearizedPolynomial(
+ linearizedPolynomialCanonical = computeLinearizedPolynomial(
blzeta,
brzeta,
bozeta,
alpha,
+ beta,
gamma,
zeta,
bzuzeta,
- bz,
+ blindedZCanonical,
pk,
)
// TODO this commitment is only necessary to derive the challenge, we should
// be able to avoid doing it and get the challenge in another way
- linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS)
+ linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS)
close(chLpoly)
}()
- // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3)
+ // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3)
var bZetaPowerm, bSize big.Int
- bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
+ bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
var zetaPowerm fr.Element
zetaPowerm.Exp(zeta, &bSize)
zetaPowerm.ToBigIntRegular(&bZetaPowerm)
foldedHDigest := proof.H[2]
foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm)
- foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3)
- foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2)
- foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + Comm(h2)
+ foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1)
- // foldedH = h1 + zeta*h2 + zeta**2*h3
+ // foldedH = h1 + ζ*h2 + ζ²*h3
foldedH := h3
utils.Parallelize(len(foldedH), func(start, end int) {
for i := start; i < end; i++ {
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3
- foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1)
- foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3
+ foldedH[i].Add(&foldedH[i], &h2[i]) // ζᵐ⁺²*h3+h2
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺²
+ foldedH[i].Add(&foldedH[i], &h1[i]) // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2+h1
}
})
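For intuition, a standalone sketch (not part of the patch): the loop above folds the three quotient parts Horner-style with z = ζᵐ⁺², i.e. ((h3·z + h2)·z + h1) = h1 + z·h2 + z²·h3, the same combination used for the folded commitment just before. A tiny integer check of that identity:

package main

import "fmt"

func main() {
	h1, h2, h3, z := 7, 11, 13, 5

	// folding as in the loop: multiply-then-add twice
	folded := h3
	folded = folded*z + h2
	folded = folded*z + h1

	fmt.Println(folded, h1+z*h2+z*z*h3) // both 387
}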
@@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
// Batch open the first list of polynomials
proof.BatchedProof, err = kzg.BatchOpenSinglePoint(
- []polynomial.Polynomial{
+ [][]fr.Element{
foldedH,
- linearizedPolynomial,
- bcl,
- bcr,
- bco,
- pk.CS1,
- pk.CS2,
+ linearizedPolynomialCanonical,
+ blindedLCanonical,
+ blindedRCanonical,
+ blindedOCanonical,
+ pk.S1Canonical,
+ pk.S2Canonical,
},
[]kzg.Digest{
foldedHDigest,
@@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
pk.Vk.S[0],
pk.Vk.S[1],
},
- &zeta,
+ zeta,
hFunc,
- &pk.DomainH,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bls24_315witness.Witn
}
+// eval evaluates c at p
+func eval(c []fr.Element, p fr.Element) fr.Element {
+ var r fr.Element
+ for i := len(c) - 1; i >= 0; i-- {
+ r.Mul(&r, &p).Add(&r, &c[i])
+ }
+ return r
+}
+
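For intuition, a standalone sketch (not part of the patch): the new eval helper is plain Horner evaluation over a coefficient slice (lowest degree first), replacing the removed polynomial.Polynomial.Eval. The same loop over machine integers:

package main

import "fmt"

// horner evaluates c[0] + c[1]*x + c[2]*x^2 + ... exactly like the eval helper above.
func horner(c []int, x int) int {
	r := 0
	for i := len(c) - 1; i >= 0; i-- {
		r = r*x + c[i]
	}
	return r
}

func main() {
	// c(X) = 3 + 2X + X^2, evaluated at X = 5
	fmt.Println(horner([]int{3, 2, 1}, 5)) // 38
}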
// fills proof.LRO with kzg commits of bcl, bcr and bco
-func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS
return err1
}
-func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err
return err1
}
-// computeBlindedLRO l, r, o in canonical basis with blinding
-func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) {
+// computeBlindedLROCanonical computes l, r, o in canonical basis with blinding
+func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) {
// note that bcl, bcr and bco reuses cl, cr and co memory
- cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
+ cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
chDone := make(chan error, 2)
go func() {
var err error
copy(cl, ll)
- domain.FFTInverse(cl, fft.DIF, 0)
+ domain.FFTInverse(cl, fft.DIF)
fft.BitReverse(cl)
bcl, err = blindPoly(cl, domain.Cardinality, 1)
chDone <- err
@@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
go func() {
var err error
copy(cr, lr)
- domain.FFTInverse(cr, fft.DIF, 0)
+ domain.FFTInverse(cr, fft.DIF)
fft.BitReverse(cr)
bcr, err = blindPoly(cr, domain.Cardinality, 1)
chDone <- err
}()
copy(co, lo)
- domain.FFTInverse(co, fft.DIF, 0)
+ domain.FFTInverse(co, fft.DIF)
fft.BitReverse(co)
if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil {
return
@@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
// * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1)
//
// WARNING:
-// pre condition degree(cp) <= rou + bo
-// pre condition cap(cp) >= int(totalDegree + 1)
-func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) {
+// pre condition degree(cp) ⩽ rou + bo
+// pre condition cap(cp) ⩾ int(totalDegree + 1)
+func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) {
// degree of the blinded polynomial is max(rou+order, cp.Degree)
totalDegree := rou + bo
@@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res := cp[:totalDegree+1]
// random polynomial
- blindingPoly := make(polynomial.Polynomial, bo+1)
+ blindingPoly := make([]fr.Element, bo+1)
for i := uint64(0); i < bo+1; i++ {
if _, err := blindingPoly[i].SetRandom(); err != nil {
return nil, err
@@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
}
return res, nil
+
}
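For intuition, a standalone sketch (not part of the patch): blindPoly adds Q(X)·(Xⁿ−1) with a random Q of degree bo, so every evaluation on the size-n subgroup (where Xⁿ−1 vanishes) is unchanged while the coefficients are hidden. A toy check over F₁₇ with n = 4 and generator 4 (illustrative values only, and a fixed q instead of a random one):

package main

import "fmt"

const p = 17 // toy prime (illustration only)

func pow(b, e int) int {
	r := 1
	for i := 0; i < e; i++ {
		r = r * b % p
	}
	return r
}

// evalMod evaluates the coefficient slice c (lowest degree first) at x, mod p.
func evalMod(c []int, x int) int {
	r := 0
	for i := len(c) - 1; i >= 0; i-- {
		r = (r*x + c[i]) % p
	}
	return r
}

func main() {
	const n = 4
	cp := []int{5, 9, 2, 7} // some polynomial of degree < n
	q := []int{3, 11}       // stand-in for the random blinding polynomial (degree bo = 1)

	// blinded = cp + q(X)*(X^n - 1): subtract q from the low coefficients, add it at X^n
	blinded := make([]int, n+len(q))
	copy(blinded, cp)
	for i, qi := range q {
		blinded[i] = (blinded[i] - qi + p) % p
		blinded[n+i] = (blinded[n+i] + qi) % p
	}

	for i := 0; i < n; i++ {
		x := pow(4, i) // the size-4 subgroup: 1, 4, 16, 13
		fmt.Println(evalMod(cp, x) == evalMod(blinded, x)) // true: blinding is invisible on the subgroup
	}
}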
-// computeLRO extracts the solution l, r, o, and returns it in lagrange form.
+// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in Lagrange form.
// solution = [ public | secret | internal ]
-func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) {
+func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {
- s := int(pk.DomainNum.Cardinality)
+ s := int(pk.Domain[0].Cardinality)
- var l, r, o polynomial.Polynomial
+ var l, r, o []fr.Element
l = make([]fr.Element, s)
r = make([]fr.Element, s)
o = make([]fr.Element, s)
@@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly
//
// * Z of degree n (domainNum.Cardinality)
// * Z(1)=1
-// * for i>0: Z(u**i) = Pi_{k<i} (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma) / ((l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma))
+// * for i>0: Z(gⁱ) = Π_{k<i} (l_i+gⁱ+γ)*(r_i+u*gⁱ+γ)*(o_i+u²*gⁱ+γ) / ((l_i+s1+γ)*(r_i+s2+γ)*(o_i+s3+γ))
-	u[0].Mul(&u[0], &pk.DomainNum.Generator) // z**i -> z**i+1
- u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
}
})
@@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele
Mul(&z[i], &gInv[i])
}
- pk.DomainNum.FFTInverse(z, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(z, fft.DIF)
fft.BitReverse(z)
- return blindPoly(z, pk.DomainNum.Cardinality, 2)
+ return blindPoly(z, pk.Domain[0].Cardinality, 2)
}
-// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on
-// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of qlL+qrR+qmL.R+qoO+k on
+// the big domain coset.
//
// * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets
// * qk is the completed version of qk, in canonical basis
-func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
- var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial
+func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
+ var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element
var wg sync.WaitGroup
wg.Add(4)
go func() {
- evalQl = evaluateHDomain(pk.Ql, &pk.DomainH)
+ evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQr = evaluateHDomain(pk.Qr, &pk.DomainH)
+ evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQm = evaluateHDomain(pk.Qm, &pk.DomainH)
+ evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQo = evaluateHDomain(pk.Qo, &pk.DomainH)
+ evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1])
wg.Done()
}()
- evalQk = evaluateHDomain(qk, &pk.DomainH)
+ evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1])
wg.Wait()
- // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets
- // of (Z/8mZ)/(Z/mZ)
+
+ // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain
utils.Parallelize(len(evalQk), func(start, end int) {
var t0, t1 fr.Element
for i := start; i < end; i++ {
@@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.
return evalQk
}
-// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ)
-func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) {
-
- id = make([]fr.Element, pk.DomainH.Cardinality)
-
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
- var acc fr.Element
- acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start)))
- for i := start; i < end; i++ {
- id[i].Mul(&acc, &pk.DomainH.FinerGenerator)
- acc.Mul(&acc, &pk.DomainH.Generator)
- }
- })
-
- return id
-}
-
-// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
-// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3
+// on the coset of the big domain.
//
-// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets
-// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets
+// * z evaluation of the blinded permutation accumulator polynomial on the coset of the big domain
+// * l, r, o evaluation of the blinded solution vectors on the coset of the big domain
// * gamma randomization
-func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial {
+func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element {
- // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ)
- evalID := evalIDCosets(pk)
+ nbElmts := int(pk.Domain[1].Cardinality)
- // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ)
- var wg sync.WaitGroup
- wg.Add(2)
- var evalS1, evalS2, evalS3 polynomial.Polynomial
- go func() {
- evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH)
- wg.Done()
- }()
- go func() {
- evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH)
- wg.Done()
- }()
- evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH)
- wg.Wait()
+ // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X)+s₂(X)*β+γ)*(o(X)+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ)
+ // on the big domain (coset).
+ res := make([]fr.Element, pk.Domain[1].Cardinality)
- // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ)
- res := evalS1 // re use allocated memory for evalS1
- s := uint64(len(evalZ))
- nn := uint64(64 - bits.TrailingZeros64(uint64(s)))
+ nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts)))
// needed to shift evalZ
- toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality
+ toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality)
+
+ var cosetShift, cosetShiftSquare fr.Element
+ cosetShift.Set(&pk.Vk.CosetShift)
+ cosetShiftSquare.Square(&pk.Vk.CosetShift)
+
+ utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) {
+
+ var evaluationIDBigDomain fr.Element
+ evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))).
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen)
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
var f [3]fr.Element
var g [3]fr.Element
- var eID fr.Element
for i := start; i < end; i++ {
- // here we want to left shift evalZ by domainH/domainNum
- // however, evalZ is permuted
- // we take the non permuted index
- // compute the corresponding shift position
- // permute it again
- irev := bits.Reverse64(uint64(i)) >> nn
- eID = evalID[irev]
+ _i := bits.Reverse64(uint64(i)) >> nn
+ _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn
- shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn
- //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn
+ // in what follows gⁱ is understood as the generator of the chosen coset of domainBig
+ f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ
+ f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ
+ f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ
- f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma
- f[1].Mul(&eID, &pk.Vk.Shifter[0])
- f[2].Mul(&eID, &pk.Vk.Shifter[1])
- f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma
- f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma
+ g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ)+s1(gⁱ)*β+γ
+ g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ)+s2(gⁱ)*β+γ
+ g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ)+s3(gⁱ)*β+γ
- g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma
- g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma
- g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma
+ f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
+ g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ)
- f[0].Mul(&f[0], &f[1]).
- Mul(&f[0], &f[2]).
- Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma)
+ res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
- g[0].Mul(&g[0], &g[1]).
- Mul(&g[0], &g[2]).
- Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
-
- res[i].Sub(&g[0], &f[0])
+ evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g
}
})
return res
}
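For intuition, a standalone sketch (not part of the patch): the big-domain evaluations above are stored in bit-reversed order (a DIF FFT without a final BitReverse), so the natural index i is mapped with bits.Reverse64(uint64(i)) >> nn, where nn = 64 − log2(n); the shifted index _is does the same after adding toShift modulo n, which is how Z(μX) is read off the same table. Illustration for n = 8:

package main

import (
	"fmt"
	"math/bits"
)

func main() {
	const n = 8
	nn := uint64(64 - bits.TrailingZeros64(n)) // 61 for n = 8

	perm := make([]uint64, n)
	for i := uint64(0); i < n; i++ {
		perm[i] = bits.Reverse64(i) >> nn
	}
	fmt.Println(perm) // [0 4 2 6 1 5 3 7]: position of element i in a bit-reversed table

	// reading a "shifted" element (i+toShift) mod n through the same permutation
	const toShift = 2
	for i := uint64(0); i < n; i++ {
		fmt.Print(bits.Reverse64((i+toShift)%n)>>nn, " ")
	}
	fmt.Println() // 2 6 1 5 3 7 0 4
}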
-// evaluateHDomain evaluates poly (canonical form) of degree m<n on the big domain
-			irev := bits.Reverse64(uint64(i)) >> nn
- // h[i].Mul(&h[i], &_u[irev%4])
- h[i].Mul(&h[i], &_u[irev%toShift])
+
+ _i := bits.Reverse64(i) >> nn
+
+ t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain
+ h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t).
+ Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]).
+ Mul(&h[_i], &alpha).
+ Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]).
+ Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio])
}
})
// put h in canonical form. h is of degree 3*(n+1)+2.
// using fft.DIT, this also reverts the bit-reversed ordering of h
- pk.DomainH.FFTInverse(h, fft.DIT, 1)
- // fmt.Println("h:")
- // for i := 0; i < len(h); i++ {
- // fmt.Printf("%s\n", h[i].String())
- // }
- // fmt.Println("")
+ pk.Domain[1].FFTInverse(h, fft.DIT, true)
// degree of hi is n+2 because of the blinding
- h1 := h[:pk.DomainNum.Cardinality+2]
- h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)]
- h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)]
+ h1 := h[:pk.Domain[0].Cardinality+2]
+ h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)]
+ h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)]
return h1, h2, h3
@@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom
// computeLinearizedPolynomial computes the linearized polynomial in canonical basis.
// The purpose is to commit and open all in one ql, qr, qm, qo, qk.
-// * a, b, c are the evaluation of l, r, o at zeta
-// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z
+// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta
+// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z
// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk.
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial {
+//
+// The Linearized polynomial is:
+//
+// α²*L₁(ζ)*Z(X)
+// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ))
+// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X)
+func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element {
// first part: individual constraints
var rl fr.Element
- rl.Mul(&r, &l)
+ rl.Mul(&rZeta, &lZeta)
- // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
+ // second part:
+ // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)
var s1, s2 fr.Element
chS1 := make(chan struct{}, 1)
go func() {
- s1 = pk.CS1.Eval(&zeta)
- s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma)
+ s1 = eval(pk.S1Canonical, zeta) // s1(ζ)
+ s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ)
close(chS1)
}()
- t := pk.CS2.Eval(&zeta)
- t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma)
+ tmp := eval(pk.S2Canonical, zeta) // s2(ζ)
+ tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ)
<-chS1
- s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma)
- Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)
-
- s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma)
- t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)
- t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
- s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
-
- // third part L1(zeta)*alpha**2**Z
- var lagrange, one, den, frNbElmt fr.Element
+ s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ)
+
+ var uzeta, uuzeta fr.Element
+ uzeta.Mul(&zeta, &pk.Vk.CosetShift)
+ uuzeta.Mul(&uzeta, &pk.Vk.CosetShift)
+
+ s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ)
+ tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)
+ tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+ s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ // third part L₁(ζ)*α²*Z
+ var lagrangeZeta, one, den, frNbElmt fr.Element
one.SetOne()
- nbElmt := int64(pk.DomainNum.Cardinality)
- lagrange.Set(&zeta).
- Exp(lagrange, big.NewInt(nbElmt)).
- Sub(&lagrange, &one)
+ nbElmt := int64(pk.Domain[0].Cardinality)
+ lagrangeZeta.Set(&zeta).
+ Exp(lagrangeZeta, big.NewInt(nbElmt)).
+ Sub(&lagrangeZeta, &one)
frNbElmt.SetUint64(uint64(nbElmt))
den.Sub(&zeta, &one).
- Mul(&den, &frNbElmt).
Inverse(&den)
- lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1)
- Mul(&lagrange, &alpha).
- Mul(&lagrange, &alpha) // alpha**2*L_0
+ lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ-1)/(ζ-1)
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ)
- linPol := z.Clone()
+ linPol := make([]fr.Element, len(blindedZCanonical))
+ copy(linPol, blindedZCanonical)
utils.Parallelize(len(linPol), func(start, end int) {
+
var t0, t1 fr.Element
+
for i := start; i < end; i++ {
- linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
- if i < len(pk.CS3) {
- t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X)
+
+ linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ if i < len(pk.S3Canonical) {
+
+ t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X)
+
linPol[i].Add(&linPol[i], &t0)
}
- linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) )
+ linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ))
if i < len(pk.Qm) {
- t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm
- t0.Mul(&pk.Ql[i], &l)
+
+ t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X)
+ t0.Mul(&pk.Ql[i], &lZeta)
t0.Add(&t0, &t1)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X)
- t0.Mul(&pk.Qr[i], &r)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr
+ t0.Mul(&pk.Qr[i], &rZeta)
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X)
- t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i])
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk
+ t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i])
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X)
}
- t0.Mul(&z[i], &lagrange)
+ t0.Mul(&blindedZCanonical[i], &lagrangeZeta)
linPol[i].Add(&linPol[i], &t0) // finish the computation
}
})
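For intuition, a standalone sketch (not part of the patch): the "third part" above relies on the closed form L₁(ζ) = (1/n)·(ζⁿ−1)/(ζ−1) for the first Lagrange polynomial of the size-n domain; the 1/n factor is now applied via pk.Domain[0].CardinalityInv, which is why the old Mul(&den, &frNbElmt) disappears. A toy numeric check over F₁₇ (n = 4, generator 4, ζ = 5; illustrative values only):

package main

import "fmt"

const p = 17 // toy prime (illustration only)

func pow(b, e int) int {
	r := 1
	for i := 0; i < e; i++ {
		r = r * b % p
	}
	return r
}

func inv(a int) int { return pow(a, p-2) } // Fermat inverse

func main() {
	const n, g, zeta = 4, 4, 5

	// closed form: (1/n) * (zeta^n - 1) / (zeta - 1)
	closed := inv(n) * ((pow(zeta, n) - 1 + p) % p) % p * inv((zeta-1+p)%p) % p

	// direct definition of L1: product over the other domain points g^i of (zeta-g^i)/(1-g^i)
	direct := 1
	for i := 1; i < n; i++ {
		direct = direct * ((zeta - pow(g, i) + p) % p) % p
		direct = direct * inv((1-pow(g, i)+2*p)%p) % p
	}
	fmt.Println(closed, direct) // both 5
}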
diff --git a/internal/backend/bls24-315/plonk/setup.go b/internal/backend/bls24-315/plonk/setup.go
index 8327805f3a..c0fd45c8b2 100644
--- a/internal/backend/bls24-315/plonk/setup.go
+++ b/internal/backend/bls24-315/plonk/setup.go
@@ -21,7 +21,6 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls24-315/fr"
"github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft"
"github.com/consensys/gnark-crypto/ecc/bls24-315/fr/kzg"
- "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial"
"github.com/consensys/gnark/internal/backend/bls24-315/cs"
kzgg "github.com/consensys/gnark-crypto/kzg"
@@ -40,18 +39,21 @@ type ProvingKey struct {
Vk *VerifyingKey
// qr,ql,qm,qo (in canonical basis).
- Ql, Qr, Qm, Qo polynomial.Polynomial
+ Ql, Qr, Qm, Qo []fr.Element
// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
// Storing LQk in Lagrange basis saves a fft...
- CQk, LQk polynomial.Polynomial
+ CQk, LQk []fr.Element
- // Domains used for the FFTs
- DomainNum, DomainH fft.Domain
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
+ // Domain[0], Domain[1] fft.Domain
- // s1, s2, s3 (L=Lagrange basis, C=canonical basis)
- LS1, LS2, LS3 polynomial.Polynomial
- CS1, CS2, CS3 polynomial.Polynomial
+ // Permutation polynomials
+ EvaluationPermutationBigDomainBitReversed []fr.Element
+ S1Canonical, S2Canonical, S3Canonical []fr.Element
// position -> permuted position (position in [0,3*sizeSystem-1])
Permutation []int64
@@ -69,13 +71,12 @@ type VerifyingKey struct {
Generator fr.Element
NbPublicVariables uint64
- // shifters for extending the permutation set: from s=<1,z,..,z**n-1>,
- // extended domain = s || shifter[0].s || shifter[1].s
- Shifter [2]fr.Element
-
// Commitment scheme that is used for an instantiation of PLONK
KZGSRS *kzg.SRS
+ // cosetShift generator of the coset on the small domain
+ CosetShift fr.Element
+
// S commitments to S1, S2, S3
S [3]kzg.Digest
@@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
// fft domains
sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints
- pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false)
+ pk.Domain[0] = *fft.NewDomain(sizeSystem)
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
// h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space,
// the domain is the next power of 2 greater than 3(n+2). 4*domainNum is enough in all cases
// except when n<6.
if sizeSystem < 6 {
- pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(8 * sizeSystem)
} else {
- pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(4 * sizeSystem)
}
- vk.Size = pk.DomainNum.Cardinality
+ vk.Size = pk.Domain[0].Cardinality
vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
- vk.Generator.Set(&pk.DomainNum.Generator)
+ vk.Generator.Set(&pk.Domain[0].Generator)
vk.NbPublicVariables = uint64(spr.NbPublicVariables)
- // shifters
- vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator)
- vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator)
-
if err := pk.InitKZG(srs); err != nil {
return nil, nil, err
}
// public polynomials corresponding to constraints: [ placeholders | constraints | assertions ]
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return an error if the size is inconsistent
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.Qm[i].SetZero()
pk.Qo[i].SetZero()
pk.CQk[i].SetZero()
- pk.LQk[i].SetZero() // --> to be completed by the prover
+ pk.LQk[i].SetZero() // → to be completed by the prover
}
offset := spr.NbPublicVariables
for i := 0; i < nbConstraints; i++ { // constraints
@@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K])
}
- pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(pk.Ql, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qr, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qm, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qo, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.CQk, fft.DIF)
fft.BitReverse(pk.Ql)
fft.BitReverse(pk.Qr)
fft.BitReverse(pk.Qm)
@@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
buildPermutation(spr, &pk)
// set s1, s2, s3
- computeLDE(&pk)
+ ccomputePermutationPolynomials(&pk)
// Commit to the polynomials to set up the verifying key
var err error
@@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil {
+ if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil {
+ if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil {
+ if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
@@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
//
// The permutation s is composed of cycles of maximum length such that
//
-// s. (l||r||o) = (l||r||o)
+// s. (l∥r∥o) = (l∥r∥o)
//
-//, where l||r||o is the concatenation of the indices of l, r, o in
+//, where l∥r∥o is the concatenation of the indices of l, r, o in
// ql.l+qr.r+qm.l.r+qo.O+k = 0.
//
// The permutation is encoded as a slice s of size 3*size(l), where the
-// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab
+// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab
// like this: for i in tab: tab[i] = tab[permutation[i]]
func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables
- sizeSolution := int(pk.DomainNum.Cardinality)
+ sizeSolution := int(pk.Domain[0].Cardinality)
// init permutation
pk.Permutation = make([]int64, 3*sizeSolution)
@@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
}
}
-// computeLDE computes the LDE (Lagrange basis) of the permutations
+// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations
// s1, s2, s3.
//
-// ex: z gen of Z/mZ, u gen of Z/8mZ, then
-//
// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 |
// |
// | Permutation
// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v
// \---------------/ \--------------------/ \------------------------/
// s1 (LDE) s2 (LDE) s3 (LDE)
-func computeLDE(pk *ProvingKey) {
+func ccomputePermutationPolynomials(pk *ProvingKey) {
- nbElmt := int(pk.DomainNum.Cardinality)
+ nbElmts := int(pk.Domain[0].Cardinality)
- // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1]
- sID := make([]fr.Element, 3*nbElmt)
- sID[0].SetOne()
- sID[nbElmt].Set(&pk.DomainNum.FinerGenerator)
- sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator)
-
- for i := 1; i < nbElmt; i++ {
- sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1
- sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
- }
+ // Lagrange form of ID
+ evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0])
// Lagrange form of S1, S2, S3
- pk.LS1 = make(polynomial.Polynomial, nbElmt)
- pk.LS2 = make(polynomial.Polynomial, nbElmt)
- pk.LS3 = make(polynomial.Polynomial, nbElmt)
- for i := 0; i < nbElmt; i++ {
- pk.LS1[i].Set(&sID[pk.Permutation[i]])
- pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]])
- pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]])
+ pk.S1Canonical = make([]fr.Element, nbElmts)
+ pk.S2Canonical = make([]fr.Element, nbElmts)
+ pk.S3Canonical = make([]fr.Element, nbElmts)
+ for i := 0; i < nbElmts; i++ {
+ pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]])
+ pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]])
+ pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]])
}
// Canonical form of S1, S2, S3
- pk.CS1 = make(polynomial.Polynomial, nbElmt)
- pk.CS2 = make(polynomial.Polynomial, nbElmt)
- pk.CS3 = make(polynomial.Polynomial, nbElmt)
- copy(pk.CS1, pk.LS1)
- copy(pk.CS2, pk.LS2)
- copy(pk.CS3, pk.LS3)
- pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0)
- fft.BitReverse(pk.CS1)
- fft.BitReverse(pk.CS2)
- fft.BitReverse(pk.CS3)
+ pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF)
+ fft.BitReverse(pk.S1Canonical)
+ fft.BitReverse(pk.S2Canonical)
+ fft.BitReverse(pk.S3Canonical)
+
+ // evaluation of permutation on the big domain
+ pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality)
+ copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true)
+
+}
+
+// getIDSmallDomain returns the Lagrange form of ID on the small domain
+func getIDSmallDomain(domain *fft.Domain) []fr.Element {
+
+ res := make([]fr.Element, 3*domain.Cardinality)
+
+ res[0].SetOne()
+ res[domain.Cardinality].Set(&domain.FrMultiplicativeGen)
+ res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen)
+
+ for i := uint64(1); i < domain.Cardinality; i++ {
+ res[i].Mul(&res[i-1], &domain.Generator)
+ res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator)
+ res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator)
+ }
+ return res
}
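For intuition, a standalone sketch (not part of the patch): getIDSmallDomain lays the identity permutation values out as three contiguous blocks [1, g, …, gⁿ⁻¹ | u, u·g, … | u², u²·g, …] with u = FrMultiplicativeGen, and the caller builds S1/S2/S3 by reading this table through pk.Permutation. A toy version over F₁₇ (n = 4, g = 4, u = 3, and a made-up permutation; illustrative values only):

package main

import "fmt"

func main() {
	const p, n, g, u = 17, 4, 4, 3

	// same layout as getIDSmallDomain: [ID | u*ID | u²*ID] on the small domain
	res := make([]int, 3*n)
	res[0], res[n], res[2*n] = 1, u, u*u%p
	for i := 1; i < n; i++ {
		res[i] = res[i-1] * g % p
		res[n+i] = res[n+i-1] * g % p
		res[2*n+i] = res[2*n+i-1] * g % p
	}
	fmt.Println(res) // [1 4 16 13 3 12 14 5 9 2 8 15]

	// a made-up permutation: s1/s2/s3 are just this table re-read through it
	perm := []int{4, 0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11}
	s := make([]int, len(perm))
	for i, pi := range perm {
		s[i] = res[pi]
	}
	fmt.Println(s[:n]) // Lagrange values of s1 on the small domain
}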
-// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS
+// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS
//
// This should be used after deserializing a ProvingKey
// as pk.Vk.KZG is NOT serialized
diff --git a/internal/backend/bls24-315/plonk/verify.go b/internal/backend/bls24-315/plonk/verify.go
index e8d9fa52b5..ae08037e77 100644
--- a/internal/backend/bls24-315/plonk/verify.go
+++ b/internal/backend/bls24-315/plonk/verify.go
@@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne
hFunc := sha256.New()
// transcript to derive the challenge
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// derive gamma from Comm(l), Comm(r), Comm(o)
gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2])
@@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne
return err
}
+ // derive beta from Comm(l), Comm(r), Comm(o)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return err
+ }
+
// derive alpha from Comm(l), Comm(r), Comm(o), Com(Z)
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z)
if err != nil {
@@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne
return err
}
- // evaluation of Z=X**m-1 at zeta
+ // evaluation of Z = Xⁿ-1 at ζ
var zetaPowerM, zzeta fr.Element
var bExpo big.Int
one := fr.One()
@@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bls24_315witness.Witne
zetaPowerM.Exp(zeta, &bExpo)
zzeta.Sub(&zetaPowerM, &one)
-	// compute PI = Sum_i<n L_i*w_i
+	if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
}
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
}
- }
- // sanity check; ensure all wires are marked as "instantiated"
- if !solution.isValid() {
- panic("solver didn't instantiate all wires")
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
}
- return solution.values, nil
+ return nil
}
// IsSolved returns nil if given witness solves the R1CS and error otherwise
@@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) {
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) {
+func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error {
// the index of the non zero entry shows if L, R or O has an uninstantiated wire
// the content is the ID of the wire non instantiated
@@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
return nil
}
- if err = processLExp(r.L.LinExp, &a, 1); err != nil {
- return
+ if err := processLExp(r.L.LinExp, a, 1); err != nil {
+ return err
}
- if err = processLExp(r.R.LinExp, &b, 2); err != nil {
- return
+ if err := processLExp(r.R.LinExp, b, 2); err != nil {
+ return err
}
- if err = processLExp(r.O.LinExp, &c, 3); err != nil {
- return
+ if err := processLExp(r.O.LinExp, c, 3); err != nil {
+ return err
}
if loc == 0 {
// there is nothing to solve, may happen if we have an assertion
// (ie a constraint that doesn't yield any output)
// or if we solved the unsolved wires with hint functions
- return
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
+ return nil
}
// we compute the wire value and instantiate it
- solved = true
- vID := termToCompute.WireID()
+ wID := termToCompute.WireID()
// solver result
var wire fr.Element
@@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
switch loc {
case 1:
if !b.IsZero() {
- wire.Div(&c, &b).
- Sub(&wire, &a)
- a.Add(&a, &wire)
+ wire.Div(c, b).
+ Sub(&wire, a)
+ a.Add(a, &wire)
} else {
// we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 2:
if !a.IsZero() {
- wire.Div(&c, &a).
- Sub(&wire, &b)
- b.Add(&b, &wire)
+ wire.Div(c, a).
+ Sub(&wire, b)
+ b.Add(b, &wire)
} else {
- // we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 3:
- wire.Mul(&a, &b).
- Sub(&wire, &c)
+ wire.Mul(a, b).
+ Sub(&wire, c)
- c.Add(&c, &wire)
+ c.Add(c, &wire)
}
// wire is the term (coeff * value)
// but in the solution we want to store the value only
// note that in gnark frontend, coeff here is always 1 or -1
cs.divByCoeff(&wire, termToCompute)
- solution.set(vID, wire)
+ solution.set(wID, wire)
- return
+ return nil
}
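For intuition, a standalone sketch (not part of the patch): stripped of the linear-combination and coefficient bookkeeping, solveConstraint solves a·b = c for whichever of the three values is still unknown (loc 1, 2 or 3), and when nothing is left to solve it now checks the product directly. A simplified stand-in using float64 instead of field elements:

package main

import "fmt"

// solveR1C solves a*b = c when exactly one of the three values is unknown
// (loc: 1 -> a, 2 -> b, 3 -> c), mirroring the switch in solveConstraint
// (the real code also handles partial linear combinations and coefficients).
func solveR1C(a, b, c float64, loc int) (float64, bool) {
	switch loc {
	case 1:
		if b == 0 {
			return 0, false // cannot solve; the caller must check a*b == c instead
		}
		return c / b, true
	case 2:
		if a == 0 {
			return 0, false
		}
		return c / a, true
	case 3:
		return a * b, true
	}
	return 0, false
}

func main() {
	v, ok := solveR1C(0, 4, 20, 1) // unknown a in a*4 = 20
	fmt.Println(v, ok)             // 5 true
}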
// GetConstraints returns a list of constraints formatted as L⋅R == O
diff --git a/internal/backend/bn254/cs/r1cs_sparse.go b/internal/backend/bn254/cs/r1cs_sparse.go
index b92f7d2273..8eab1f901f 100644
--- a/internal/backend/bn254/cs/r1cs_sparse.go
+++ b/internal/backend/bn254/cs/r1cs_sparse.go
@@ -21,9 +21,12 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/fxamacker/cbor/v2"
"io"
+ "math"
"math/big"
"os"
+ "runtime"
"strings"
+ "sync"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
@@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
return solution.values, err
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
// solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs
copy(solution.values, witness)
for i := 0; i < len(witness); i++ {
@@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness)
+ solution.nbSolved += uint64(len(witness))
// defer log printing once all solution.values are computed
defer solution.printLogs(opt.LoggerOut, cs.Logs)
@@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
coefficientsNegInv[i].Neg(&coefficientsNegInv[i])
}
- // loop through the constraints to solve the variables
- for i := 0; i < len(cs.Constraints); i++ {
- if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil {
- return solution.values, fmt.Errorf("constraint %d: %w", i, err)
- }
- if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil {
- errMsg := err.Error()
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
- }
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
- }
+ if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil {
+ return solution.values, err
}
// sanity check; ensure all wires are marked as "instantiated"
@@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
+func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error {
+ // minWorkPerCPU is the minimum target number of constraints a task should hold
+ // in other words, if a level has fewer than minWorkPerCPU constraints, it is not parallelized
+ // and is executed sequentially, without synchronization.
+ const minWorkPerCPU = 50.0
+
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
+
+ // start a worker pool
+ // each worker waits on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ // clean up pool goroutines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
+
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ }
+ }
+ continue
+ }
+
+ // number of tasks for this level is set to num cpus
+ // but if we don't have enough work for all our CPUs, it can be lower.
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
+ }
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
+ }
+
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
+ }
+
+ return nil
+}
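
The chunking above can be looked at in isolation. The following standalone sketch (a hypothetical `splitLevel` helper, not part of gnark) reproduces the arithmetic used to split one level into tasks: a level with too little work stays sequential, otherwise the constraints are divided into at most `NumCPU` chunks, with the remainder spread one extra constraint per leading chunk.

```go
// Hypothetical standalone sketch (not gnark API) of the level-splitting arithmetic.
package main

import (
	"fmt"
	"math"
)

// splitLevel returns [start, end) chunk boundaries for a level of size n,
// given nbCPU workers and a minimum amount of work per worker.
func splitLevel(n, nbCPU int, minWorkPerCPU float64) [][2]int {
	maxCPU := float64(n) / minWorkPerCPU
	if maxCPU <= 1.0 {
		return [][2]int{{0, n}} // too small: solved sequentially
	}
	nbTasks := nbCPU
	if maxTasks := int(math.Ceil(maxCPU)); nbTasks > maxTasks {
		nbTasks = maxTasks
	}
	perTask := n / nbTasks
	if perTask < 1 {
		perTask = 1
		nbTasks = n
	}
	extra := n - nbTasks*perTask // first `extra` tasks get one more constraint
	chunks := make([][2]int, 0, nbTasks)
	offset := 0
	for i := 0; i < nbTasks; i++ {
		start := i*perTask + offset
		end := start + perTask
		if extra > 0 {
			end++
			extra--
			offset++
		}
		chunks = append(chunks, [2]int{start, end})
	}
	return chunks
}

func main() {
	fmt.Println(splitLevel(103, 8, 50.0)) // [[0 35] [35 69] [69 103]]
}
```
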
+
// computeHints computes wires associated with a hint function, if any
// if there is no remaining wire to solve, returns -1
// else returns the wire position (L -> 0, R -> 1, O -> 2)
diff --git a/internal/backend/bn254/cs/solution.go b/internal/backend/bn254/cs/solution.go
index 323c472df4..46cb2eb6af 100644
--- a/internal/backend/bn254/cs/solution.go
+++ b/internal/backend/bn254/cs/solution.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"math/big"
+ "sync/atomic"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend/schema"
@@ -32,14 +33,15 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bn254"
)
+var errUnsatisfiedConstraint = errors.New("unsatisfied")
+
// solution represents elements needed to compute
// a solution to a R1CS or SparseR1CS
type solution struct {
values, coefficients []fr.Element
solved []bool
- nbSolved int
+ nbSolved uint64
mHintsFunctions map[hint.ID]hint.Function
- tmpHintsIO []*big.Int
}
func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) {
@@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E
coefficients: coefficients,
solved: make([]bool, nbWires),
mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)),
- tmpHintsIO: make([]*big.Int, 0),
}
for _, h := range hintFunctions {
@@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) {
}
s.values[id] = value
s.solved[id] = true
- s.nbSolved++
+ atomic.AddUint64(&s.nbSolved, 1)
+ // s.nbSolved++
}
func (s *solution) isValid() bool {
- return s.nbSolved == len(s.values)
+ return int(s.nbSolved) == len(s.values)
}
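
With constraints now solved by several goroutines, `set` may be called concurrently, which is why the counter moves from a plain increment to `sync/atomic`. A minimal standalone illustration of the pattern (toy values, not gnark code):

```go
// Minimal illustration of a concurrently-incremented progress counter.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var nbSolved uint64
	var wg sync.WaitGroup
	for w := 0; w < 4; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < 1000; i++ {
				atomic.AddUint64(&nbSolved, 1) // a plain nbSolved++ here would be a data race
			}
		}()
	}
	wg.Wait()
	fmt.Println(atomic.LoadUint64(&nbSolved) == 4000) // true
}
```
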
// computeTerm computes coef*variable
@@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error {
// tmp IO big int memory
nbInputs := len(h.Inputs)
nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs))
- m := len(s.tmpHintsIO)
- if m < (nbInputs + nbOutputs) {
- s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...)
- for i := m; i < len(s.tmpHintsIO); i++ {
- s.tmpHintsIO[i] = big.NewInt(0)
- }
+ // m := len(s.tmpHintsIO)
+ // if m < (nbInputs + nbOutputs) {
+ // s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...)
+ // for i := m; i < len(s.tmpHintsIO); i++ {
+ // s.tmpHintsIO[i] = big.NewInt(0)
+ // }
+ // }
+ inputs := make([]*big.Int, nbInputs)
+ outputs := make([]*big.Int, nbOutputs)
+ for i := 0; i < nbInputs; i++ {
+ inputs[i] = big.NewInt(0)
+ }
+ for i := 0; i < nbOutputs; i++ {
+ outputs[i] = big.NewInt(0)
}
- inputs := s.tmpHintsIO[:nbInputs]
- outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs]
q := fr.Modulus()
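
The shared `tmpHintsIO` scratch buffer is dropped for the same reason: two goroutines solving hints at once would race on it. A hypothetical helper sketching the per-call allocation that replaces it:

```go
// Hypothetical sketch: per-call allocation keeps hint solving goroutine-safe,
// at the cost of a few extra allocations compared to a shared scratch buffer.
package main

import (
	"fmt"
	"math/big"
)

func newHintIO(nbInputs, nbOutputs int) (inputs, outputs []*big.Int) {
	inputs = make([]*big.Int, nbInputs)
	outputs = make([]*big.Int, nbOutputs)
	for i := range inputs {
		inputs[i] = big.NewInt(0)
	}
	for i := range outputs {
		outputs[i] = big.NewInt(0)
	}
	return
}

func main() {
	in, out := newHintIO(2, 1)
	fmt.Println(len(in), len(out)) // 2 1
}
```
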
diff --git a/internal/backend/bn254/groth16/marshal_test.go b/internal/backend/bn254/groth16/marshal_test.go
index d8f8c38496..2db9e45415 100644
--- a/internal/backend/bn254/groth16/marshal_test.go
+++ b/internal/backend/bn254/groth16/marshal_test.go
@@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) {
var pk, pkCompressed, pkRaw ProvingKey
// create a random pk
- domain := fft.NewDomain(8, 1, true)
+ domain := fft.NewDomain(8)
pk.Domain = *domain
nbWires := 6
diff --git a/internal/backend/bn254/groth16/prove.go b/internal/backend/bn254/groth16/prove.go
index 87bb23c383..f4a929f426 100644
--- a/internal/backend/bn254/groth16/prove.go
+++ b/internal/backend/bn254/groth16/prove.go
@@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
c = append(c, padding...)
n = len(a)
- domain.FFTInverse(a, fft.DIF, 0)
- domain.FFTInverse(b, fft.DIF, 0)
- domain.FFTInverse(c, fft.DIF, 0)
+ domain.FFTInverse(a, fft.DIF)
+ domain.FFTInverse(b, fft.DIF)
+ domain.FFTInverse(c, fft.DIF)
- domain.FFT(a, fft.DIT, 1)
- domain.FFT(b, fft.DIT, 1)
- domain.FFT(c, fft.DIT, 1)
+ domain.FFT(a, fft.DIT, true)
+ domain.FFT(b, fft.DIT, true)
+ domain.FFT(c, fft.DIT, true)
- var minusTwoInv fr.Element
- minusTwoInv.SetUint64(2)
- minusTwoInv.Neg(&minusTwoInv).
- Inverse(&minusTwoInv)
+ var den, one fr.Element
+ one.SetOne()
+ den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
+ den.Sub(&den, &one).Inverse(&den)
// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unnecessary memalloc
@@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
- Mul(&a[i], &minusTwoInv)
+ Mul(&a[i], &den)
}
})
// ifft_coset
- domain.FFTInverse(a, fft.DIF, 1)
+ domain.FFTInverse(a, fft.DIF, true)
utils.Parallelize(len(a), func(start, end int) {
for i := start; i < end; i++ {
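
For context on the new `den` computation above: on the coset `u·H` of the `n`-th roots of unity `H` (with `u = domain.FrMultiplicativeGen` and `n = domain.Cardinality`), the vanishing polynomial of `H` takes the same value at every point, since `gⁿ = 1`:

$$
Z_H(u\,g^i) \;=\; (u\,g^i)^n - 1 \;=\; u^n\,(g^n)^i - 1 \;=\; u^n - 1,
\qquad\text{hence}\qquad
\mathtt{den} = \frac{1}{u^n - 1}.
$$

A single inversion therefore covers the whole coset; the removed `minusTwoInv` (−1/2) is the special case of the same formula for the previous coset choice, where `uⁿ = −1`.
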
diff --git a/internal/backend/bn254/groth16/setup.go b/internal/backend/bn254/groth16/setup.go
index 95cccddf80..334461b25b 100644
--- a/internal/backend/bn254/groth16/setup.go
+++ b/internal/backend/bn254/groth16/setup.go
@@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error {
nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables
// Setting group for fft
- domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true)
+ domain := fft.NewDomain(uint64(len(r1cs.Constraints)))
// samples toxic waste
toxicWaste, err := sampleToxicWaste()
@@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error {
nbConstraints := len(r1cs.Constraints)
// Setting group for fft
- domain := fft.NewDomain(uint64(nbConstraints), 1, true)
+ domain := fft.NewDomain(uint64(nbConstraints))
// count number of infinity points we would have had we a normal setup
// in pk.G1.A, pk.G1.B, and pk.G2.B
diff --git a/internal/backend/bn254/plonk/marshal.go b/internal/backend/bn254/plonk/marshal.go
index 6012cb114b..e4c4d7059f 100644
--- a/internal/backend/bn254/plonk/marshal.go
+++ b/internal/backend/bn254/plonk/marshal.go
@@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
}
// fft domains
- n2, err := pk.DomainNum.WriteTo(w)
+ n2, err := pk.Domain[0].WriteTo(w)
if err != nil {
return
}
n += n2
- n2, err = pk.DomainH.WriteTo(w)
+ n2, err = pk.Domain[1].WriteTo(w)
if err != nil {
return
}
n += n2
- // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality)
- if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) {
+ // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality)
+ if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) {
return n, errors.New("invalid permutation size, expected 3*domain cardinality")
}
@@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
([]fr.Element)(pk.Qo),
([]fr.Element)(pk.CQk),
([]fr.Element)(pk.LQk),
- ([]fr.Element)(pk.LS1),
- ([]fr.Element)(pk.LS2),
- ([]fr.Element)(pk.LS3),
- ([]fr.Element)(pk.CS1),
- ([]fr.Element)(pk.CS2),
- ([]fr.Element)(pk.CS3),
+ ([]fr.Element)(pk.S1Canonical),
+ ([]fr.Element)(pk.S2Canonical),
+ ([]fr.Element)(pk.S3Canonical),
pk.Permutation,
}
@@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
- n2, err := pk.DomainNum.ReadFrom(r)
+ n2, err := pk.Domain[0].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- n2, err = pk.DomainH.ReadFrom(r)
+ n2, err = pk.Domain[1].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
dec := curve.NewDecoder(r)
toDecode := []interface{}{
@@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
(*[]fr.Element)(&pk.Qo),
(*[]fr.Element)(&pk.CQk),
(*[]fr.Element)(&pk.LQk),
- (*[]fr.Element)(&pk.LS1),
- (*[]fr.Element)(&pk.LS2),
- (*[]fr.Element)(&pk.LS3),
- (*[]fr.Element)(&pk.CS1),
- (*[]fr.Element)(&pk.CS2),
- (*[]fr.Element)(&pk.CS3),
+ (*[]fr.Element)(&pk.S1Canonical),
+ (*[]fr.Element)(&pk.S2Canonical),
+ (*[]fr.Element)(&pk.S3Canonical),
&pk.Permutation,
}
@@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
&vk.SizeInv,
&vk.Generator,
vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
&vk.SizeInv,
&vk.Generator,
&vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
diff --git a/internal/backend/bn254/plonk/marshal_test.go b/internal/backend/bn254/plonk/marshal_test.go
index f17f8ca756..ceec7305b0 100644
--- a/internal/backend/bn254/plonk/marshal_test.go
+++ b/internal/backend/bn254/plonk/marshal_test.go
@@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
@@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) {
// random pk
var pk ProvingKey
pk.Vk = &vk
- pk.DomainNum = *fft.NewDomain(42, 3, false)
- pk.DomainH = *fft.NewDomain(4*42, 1, false)
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Domain[0] = *fft.NewDomain(42)
+ pk.Domain[1] = *fft.NewDomain(4 * 42)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < 12; i++ {
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) {
pk.Qo[i].SetUint64(42)
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
pk.Permutation[0] = -12
pk.Permutation[len(pk.Permutation)-1] = 8888
@@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
diff --git a/internal/backend/bn254/plonk/prove.go b/internal/backend/bn254/plonk/prove.go
index 25b24887ed..126277d13e 100644
--- a/internal/backend/bn254/plonk/prove.go
+++ b/internal/backend/bn254/plonk/prove.go
@@ -27,8 +27,6 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bn254"
- "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial"
-
"github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg"
"github.com/consensys/gnark-crypto/ecc/bn254/fr/fft"
@@ -43,6 +41,7 @@ import (
)
type Proof struct {
+
// Commitments to the solution vectors
LRO [3]kzg.Digest
@@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
hFunc := sha256.New()
// create a transcript manager to apply Fiat Shamir
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// result
proof := &Proof{}
@@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
}
// query l, r, o in Lagrange basis, not blinded
- ll, lr, lo := computeLRO(spr, pk, solution)
+ evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution)
// save ll, lr, lo, and make a copy of them in canonical basis.
// note that we allocate more capacity to reuse for blinded polynomials
- bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum)
+ blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ &pk.Domain[0])
if err != nil {
return nil, err
}
// compute kzg commitments of bcl, bcr and bco
- if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
return nil, err
}
+ // derive beta from the transcript (Fiat-Shamir)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return nil, err
+ }
+
// compute Z, the permutation accumulator polynomial, in canonical basis
// ll, lr, lo are NOT blinded
- var bz polynomial.Polynomial
+ var blindedZCanonical []fr.Element
chZ := make(chan error, 1)
var alpha fr.Element
go func() {
var err error
- bz, err = computeBlindedZ(ll, lr, lo, pk, gamma)
+ blindedZCanonical, err = computeBlindedZCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ pk, beta, gamma)
if err != nil {
chZ <- err
close(chZ)
@@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
// this may add additional arithmetic operations, but with smaller tasks
// we ensure that this commitment is well parallelized, without having a "unbalanced task" making
// the rest of the code wait too long.
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
+ if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
chZ <- err
close(chZ)
return
@@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
}()
// evaluation of the blinded versions of l, r, o and bz
- // on the odd cosets of (Z/8mZ)/(Z/mZ)
- var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial
+ // on the coset of the big domain
+ var (
+ evaluationBlindedLDomainBigBitReversed []fr.Element
+ evaluationBlindedRDomainBigBitReversed []fr.Element
+ evaluationBlindedODomainBigBitReversed []fr.Element
+ evaluationBlindedZDomainBigBitReversed []fr.Element
+ )
chEvalBL := make(chan struct{}, 1)
chEvalBR := make(chan struct{}, 1)
chEvalBO := make(chan struct{}, 1)
go func() {
- evalBL = evaluateHDomain(bcl, &pk.DomainH)
+ evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1])
close(chEvalBL)
}()
go func() {
- evalBR = evaluateHDomain(bcr, &pk.DomainH)
+ evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1])
close(chEvalBR)
}()
go func() {
- evalBO = evaluateHDomain(bco, &pk.DomainH)
+ evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1])
close(chEvalBO)
}()
- var constraintsInd, constraintsOrdering polynomial.Polynomial
+ var constraintsInd, constraintsOrdering []fr.Element
chConstraintInd := make(chan struct{}, 1)
go func() {
// compute qk in canonical basis, completed with the public inputs
- qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality)
- copy(qk, fullWitness[:spr.NbPublicVariables])
- copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
- pk.DomainNum.FFTInverse(qk, fft.DIF, 0)
- fft.BitReverse(qk)
-
- // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ)
- // --> uses the blinded version of l, r, o
+ qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality)
+ copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables])
+ copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
+ pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF)
+ fft.BitReverse(qkCompletedCanonical)
+
+ // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain
+ // → uses the blinded version of l, r, o
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk)
+ constraintsInd = evaluateConstraintsDomainBigBitReversed(
+ pk,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ qkCompletedCanonical)
close(chConstraintInd)
}()
@@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
chConstraintOrdering <- err
return
}
- evalBZ = evaluateHDomain(bz, &pk.DomainH)
- // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ)
+
+ evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1])
+ // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain
// evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o.
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma)
+ constraintsOrdering = evaluateOrderingDomainBigBitReversed(
+ pk,
+ evaluationBlindedZDomainBigBitReversed,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ beta,
+ gamma)
chConstraintOrdering <- nil
close(chConstraintOrdering)
}()
@@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
if err := <-chConstraintOrdering; err != nil {
return nil, err
}
+
<-chConstraintInd
+
// compute h in canonical form
- h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha)
+ h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha)
// compute kzg commitments of h1, h2 and h3
- if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
var wgZetaEvals sync.WaitGroup
wgZetaEvals.Add(3)
go func() {
- blzeta = bcl.Eval(&zeta)
+ blzeta = eval(blindedLCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- brzeta = bcr.Eval(&zeta)
+ brzeta = eval(blindedRCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- bozeta = bco.Eval(&zeta)
+ bozeta = eval(blindedOCanonical, zeta)
wgZetaEvals.Done()
}()
@@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
var zetaShifted fr.Element
zetaShifted.Mul(&zeta, &pk.Vk.Generator)
proof.ZShiftedOpening, err = kzg.Open(
- bz,
- &zetaShifted,
- &pk.DomainH,
+ blindedZCanonical,
+ zetaShifted,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
bzuzeta := proof.ZShiftedOpening.ClaimedValue
var (
- linearizedPolynomial polynomial.Polynomial
- linearizedPolynomialDigest curve.G1Affine
- errLPoly error
+ linearizedPolynomialCanonical []fr.Element
+ linearizedPolynomialDigest curve.G1Affine
+ errLPoly error
)
chLpoly := make(chan struct{}, 1)
go func() {
// compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k)
wgZetaEvals.Wait()
- linearizedPolynomial = computeLinearizedPolynomial(
+ linearizedPolynomialCanonical = computeLinearizedPolynomial(
blzeta,
brzeta,
bozeta,
alpha,
+ beta,
gamma,
zeta,
bzuzeta,
- bz,
+ blindedZCanonical,
pk,
)
// TODO this commitment is only necessary to derive the challenge, we should
// be able to avoid doing it and get the challenge in another way
- linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS)
+ linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS)
close(chLpoly)
}()
- // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3)
+ // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3)
var bZetaPowerm, bSize big.Int
- bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
+ bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
var zetaPowerm fr.Element
zetaPowerm.Exp(zeta, &bSize)
zetaPowerm.ToBigIntRegular(&bZetaPowerm)
foldedHDigest := proof.H[2]
foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm)
- foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3)
- foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2)
- foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3)
+ foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1)
- // foldedH = h1 + zeta*h2 + zeta**2*h3
+ // foldedH = h1 + ζ*h2 + ζ²*h3
foldedH := h3
utils.Parallelize(len(foldedH), func(start, end int) {
for i := start; i < end; i++ {
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3
- foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1)
- foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3
+ foldedH[i].Add(&foldedH[i], &h2[i]) // ζᵐ⁺²*h3+h2
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2
+ foldedH[i].Add(&foldedH[i], &h1[i]) // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2+h1
}
})
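
In the notation of the comments above (m = pk.Domain[0].Cardinality), the loop is a Horner evaluation of the fold, mirroring the commitment fold computed just before:

$$
\mathrm{foldedH}(X) \;=\; h_1(X) + \zeta^{\,m+2}\,h_2(X) + \zeta^{\,2(m+2)}\,h_3(X)
\;=\; \big(h_3(X)\,\zeta^{\,m+2} + h_2(X)\big)\,\zeta^{\,m+2} + h_1(X).
$$
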
@@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
// Batch open the first list of polynomials
proof.BatchedProof, err = kzg.BatchOpenSinglePoint(
- []polynomial.Polynomial{
+ [][]fr.Element{
foldedH,
- linearizedPolynomial,
- bcl,
- bcr,
- bco,
- pk.CS1,
- pk.CS2,
+ linearizedPolynomialCanonical,
+ blindedLCanonical,
+ blindedRCanonical,
+ blindedOCanonical,
+ pk.S1Canonical,
+ pk.S2Canonical,
},
[]kzg.Digest{
foldedHDigest,
@@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
pk.Vk.S[0],
pk.Vk.S[1],
},
- &zeta,
+ zeta,
hFunc,
- &pk.DomainH,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bn254witness.Witness,
}
+// eval evaluates c at p
+func eval(c []fr.Element, p fr.Element) fr.Element {
+ var r fr.Element
+ for i := len(c) - 1; i >= 0; i-- {
+ r.Mul(&r, &p).Add(&r, &c[i])
+ }
+ return r
+}
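
`eval` is plain Horner evaluation. A toy analogue over `int64` (illustrative only; the real code works on `fr.Element`):

```go
// Toy analogue of eval: Horner's rule evaluates c[0] + c[1]*p + ... + c[d]*p^d
// with d multiplications.
package main

import "fmt"

func evalInt(c []int64, p int64) int64 {
	var r int64
	for i := len(c) - 1; i >= 0; i-- {
		r = r*p + c[i]
	}
	return r
}

func main() {
	// 3 + 2x + x^2 at x = 5 -> 3 + 10 + 25 = 38
	fmt.Println(evalInt([]int64{3, 2, 1}, 5)) // 38
}
```
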
+
// fills proof.LRO with kzg commits of bcl, bcr and bco
-func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS
return err1
}
-func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err
return err1
}
-// computeBlindedLRO l, r, o in canonical basis with blinding
-func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) {
+// computeBlindedLROCanonical l, r, o in canonical basis with blinding
+func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) {
// note that bcl, bcr and bco reuse cl, cr and co memory
- cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
+ cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
chDone := make(chan error, 2)
go func() {
var err error
copy(cl, ll)
- domain.FFTInverse(cl, fft.DIF, 0)
+ domain.FFTInverse(cl, fft.DIF)
fft.BitReverse(cl)
bcl, err = blindPoly(cl, domain.Cardinality, 1)
chDone <- err
@@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
go func() {
var err error
copy(cr, lr)
- domain.FFTInverse(cr, fft.DIF, 0)
+ domain.FFTInverse(cr, fft.DIF)
fft.BitReverse(cr)
bcr, err = blindPoly(cr, domain.Cardinality, 1)
chDone <- err
}()
copy(co, lo)
- domain.FFTInverse(co, fft.DIF, 0)
+ domain.FFTInverse(co, fft.DIF)
fft.BitReverse(co)
if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil {
return
@@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
// * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1)
//
// WARNING:
-// pre condition degree(cp) <= rou + bo
-// pre condition cap(cp) >= int(totalDegree + 1)
-func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) {
+// pre condition degree(cp) ⩽ rou + bo
+// pre condition cap(cp) ⩾ int(totalDegree + 1)
+func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) {
// degree of the blinded polynomial is max(rou+order, cp.Degree)
totalDegree := rou + bo
@@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res := cp[:totalDegree+1]
// random polynomial
- blindingPoly := make(polynomial.Polynomial, bo+1)
+ blindingPoly := make([]fr.Element, bo+1)
for i := uint64(0); i < bo+1; i++ {
if _, err := blindingPoly[i].SetRandom(); err != nil {
return nil, err
@@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
}
return res, nil
+
}
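
As the comment above states, blinding adds a random multiple of the vanishing polynomial of the small domain; writing rou for the domain size and bo for the blinding order:

$$
\tilde p(X) \;=\; p(X) + q(X)\,\big(X^{\mathrm{rou}} - 1\big), \qquad \deg q = \mathrm{bo},
$$

and since $\omega^{\mathrm{rou}} = 1$ for every $\omega$ in the domain, $\tilde p(\omega) = p(\omega)$: the constraint checks are unchanged while the extra random coefficients hide $p$ outside the domain.
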
-// computeLRO extracts the solution l, r, o, and returns it in lagrange form.
+// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form.
// solution = [ public | secret | internal ]
-func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) {
+func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {
- s := int(pk.DomainNum.Cardinality)
+ s := int(pk.Domain[0].Cardinality)
- var l, r, o polynomial.Polynomial
+ var l, r, o []fr.Element
l = make([]fr.Element, s)
r = make([]fr.Element, s)
o = make([]fr.Element, s)
@@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly
//
// * Z of degree n (domainNum.Cardinality)
// * Z(1)=1
-// * for i>0: Z(u**i) = Pi_{k<i} (l_k+z**k+gamma)*(r_k+u*z**k+gamma)*(o_k+u**2*z**k+gamma) / ((l_k+s1+gamma)*(r_k+s2+gamma)*(o_k+s3+gamma))
+// * for i>0: Z(gⁱ) = Π_{k<i} (l_k+β*gᵏ+γ)*(r_k+β*u*gᵏ+γ)*(o_k+β*u²*gᵏ+γ) / ((l_k+β*s₁(gᵏ)+γ)*(r_k+β*s₂(gᵏ)+γ)*(o_k+β*s₃(gᵏ)+γ))
-  u[0].Mul(&u[0], &pk.DomainNum.Generator) // z**i -> z**i+1
- u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
}
})
@@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele
Mul(&z[i], &gInv[i])
}
- pk.DomainNum.FFTInverse(z, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(z, fft.DIF)
fft.BitReverse(z)
- return blindPoly(z, pk.DomainNum.Cardinality, 2)
+ return blindPoly(z, pk.Domain[0].Cardinality, 2)
}
-// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on
-// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of lL+qrR+qqmL.R+qoO+k on
+// the big domain coset.
//
// * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets
// * qk is the completed version of qk, in canonical version
-func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
- var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial
+func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
+ var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element
var wg sync.WaitGroup
wg.Add(4)
go func() {
- evalQl = evaluateHDomain(pk.Ql, &pk.DomainH)
+ evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQr = evaluateHDomain(pk.Qr, &pk.DomainH)
+ evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQm = evaluateHDomain(pk.Qm, &pk.DomainH)
+ evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQo = evaluateHDomain(pk.Qo, &pk.DomainH)
+ evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1])
wg.Done()
}()
- evalQk = evaluateHDomain(qk, &pk.DomainH)
+ evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1])
wg.Wait()
- // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets
- // of (Z/8mZ)/(Z/mZ)
+
+ // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain
utils.Parallelize(len(evalQk), func(start, end int) {
var t0, t1 fr.Element
for i := start; i < end; i++ {
@@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.
return evalQk
}
-// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ)
-func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) {
-
- id = make([]fr.Element, pk.DomainH.Cardinality)
-
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
- var acc fr.Element
- acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start)))
- for i := start; i < end; i++ {
- id[i].Mul(&acc, &pk.DomainH.FinerGenerator)
- acc.Mul(&acc, &pk.DomainH.Generator)
- }
- })
-
- return id
-}
-
-// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
-// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
+// cosets of the big domain.
//
-// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets
-// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets
+// * z evaluation of the blinded permutation accumulator polynomial on odd cosets
+// * l, r, o evaluation of the blinded solution vectors on odd cosets
// * gamma randomization
-func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial {
+func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element {
- // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ)
- evalID := evalIDCosets(pk)
+ nbElmts := int(pk.Domain[1].Cardinality)
- // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ)
- var wg sync.WaitGroup
- wg.Add(2)
- var evalS1, evalS2, evalS3 polynomial.Polynomial
- go func() {
- evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH)
- wg.Done()
- }()
- go func() {
- evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH)
- wg.Done()
- }()
- evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH)
- wg.Wait()
+ // computes z(uX)*(l(X)+s₁(X)*β+γ)*(r(X)+s₂(X)*β+γ)*(o(X)+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ)
+ // on the big domain (coset).
+ res := make([]fr.Element, pk.Domain[1].Cardinality)
- // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ)
- res := evalS1 // re use allocated memory for evalS1
- s := uint64(len(evalZ))
- nn := uint64(64 - bits.TrailingZeros64(uint64(s)))
+ nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts)))
// needed to shift evalZ
- toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality
+ toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality)
+
+ var cosetShift, cosetShiftSquare fr.Element
+ cosetShift.Set(&pk.Vk.CosetShift)
+ cosetShiftSquare.Square(&pk.Vk.CosetShift)
+
+ utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) {
+
+ var evaluationIDBigDomain fr.Element
+ evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))).
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen)
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
var f [3]fr.Element
var g [3]fr.Element
- var eID fr.Element
for i := start; i < end; i++ {
- // here we want to left shift evalZ by domainH/domainNum
- // however, evalZ is permuted
- // we take the non permuted index
- // compute the corresponding shift position
- // permute it again
- irev := bits.Reverse64(uint64(i)) >> nn
- eID = evalID[irev]
+ _i := bits.Reverse64(uint64(i)) >> nn
+ _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn
- shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn
- //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn
+ // in what follows gⁱ is understood as the generator of the chosen coset of domainBig
+ f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ
+ f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ
+ f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ
- f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma
- f[1].Mul(&eID, &pk.Vk.Shifter[0])
- f[2].Mul(&eID, &pk.Vk.Shifter[1])
- f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma
- f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma
+ g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ)+s1(gⁱ)*β+γ
+ g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ)+s2(gⁱ)*β+γ
+ g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ)+s3(gⁱ)*β+γ
- g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma
- g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma
- g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma
+ f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
+ g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ)
- f[0].Mul(&f[0], &f[1]).
- Mul(&f[0], &f[2]).
- Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma)
+ res[_i].Sub(&g[0], &f[0]) // z(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
- g[0].Mul(&g[0], &g[1]).
- Mul(&g[0], &g[2]).
- Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
-
- res[i].Sub(&g[0], &f[0])
+ evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g
}
})
return res
}
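
The `_i`/`_is` indexing works because the evaluation vectors are stored in bit-reversed order (the output order of `fft.DIF`): the value belonging to natural position `(i+toShift) % n` lives at index `reverse(i+toShift) >> nn`. A standalone demo of that index mapping (toy data, not gnark API):

```go
// Standalone demo of the bit-reversed index shift: for an array of size n
// (a power of 2) stored in bit-reversed order, the element at natural
// position (i+shift)%n is found at index reverse(i+shift)>>nn.
package main

import (
	"fmt"
	"math/bits"
)

func main() {
	const n = 8
	nn := uint64(64 - bits.TrailingZeros64(n))

	// natural[j] holds f(g^j); store the values in bit-reversed order, as after fft.DIF
	natural := []string{"f(g0)", "f(g1)", "f(g2)", "f(g3)", "f(g4)", "f(g5)", "f(g6)", "f(g7)"}
	bitReversed := make([]string, n)
	for j := uint64(0); j < n; j++ {
		bitReversed[bits.Reverse64(j)>>nn] = natural[j]
	}

	// read f(g^(i+shift)) while walking the bit-reversed array
	shift := uint64(2)
	for i := uint64(0); i < n; i++ {
		_i := bits.Reverse64(i) >> nn
		_is := bits.Reverse64((i+shift)%n) >> nn
		fmt.Printf("slot %d: current=%s shifted=%s\n", _i, bitReversed[_i], bitReversed[_is])
	}
}
```
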
-// evaluateHDomain evaluates poly (canonical form) of degree m<n on the odd cosets of (Z/8mZ)/(Z/mZ)
-  irev := bits.Reverse64(uint64(i)) >> nn
- // h[i].Mul(&h[i], &_u[irev%4])
- h[i].Mul(&h[i], &_u[irev%toShift])
+
+ _i := bits.Reverse64(i) >> nn
+
+ t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain
+ h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t).
+ Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]).
+ Mul(&h[_i], &alpha).
+ Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]).
+ Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio])
}
})
// put h in canonical form. h is of degree 3*(n+1)+2.
// using fft.DIT to revert the bit-reversed ordering of h
- pk.DomainH.FFTInverse(h, fft.DIT, 1)
- // fmt.Println("h:")
- // for i := 0; i < len(h); i++ {
- // fmt.Printf("%s\n", h[i].String())
- // }
- // fmt.Println("")
+ pk.Domain[1].FFTInverse(h, fft.DIT, true)
// degree of hi is n+2 because of the blinding
- h1 := h[:pk.DomainNum.Cardinality+2]
- h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)]
- h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)]
+ h1 := h[:pk.Domain[0].Cardinality+2]
+ h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)]
+ h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)]
return h1, h2, h3
@@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom
// computeLinearizedPolynomial computes the linearized polynomial in canonical basis.
// The purpose is to commit and open all in one ql, qr, qm, qo, qk.
-// * a, b, c are the evaluation of l, r, o at zeta
-// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z
+// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta
+// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z
// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk.
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial {
+//
+// The Linearized polynomial is:
+//
+// α²*L₁(ζ)*Z(X)
+// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ))
+// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X)
+func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element {
// first part: individual constraints
var rl fr.Element
- rl.Mul(&r, &l)
+ rl.Mul(&rZeta, &lZeta)
- // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
+ // second part:
+ // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)
var s1, s2 fr.Element
chS1 := make(chan struct{}, 1)
go func() {
- s1 = pk.CS1.Eval(&zeta)
- s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma)
+ s1 = eval(pk.S1Canonical, zeta) // s1(ζ)
+ s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ)
close(chS1)
}()
- t := pk.CS2.Eval(&zeta)
- t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma)
+ tmp := eval(pk.S2Canonical, zeta) // s2(ζ)
+ tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ)
<-chS1
- s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma)
- Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)
-
- s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma)
- t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)
- t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
- s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
-
- // third part L1(zeta)*alpha**2**Z
- var lagrange, one, den, frNbElmt fr.Element
+ s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ)
+
+ var uzeta, uuzeta fr.Element
+ uzeta.Mul(&zeta, &pk.Vk.CosetShift)
+ uuzeta.Mul(&uzeta, &pk.Vk.CosetShift)
+
+ s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ)
+ tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)
+ tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+ s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ // third part L₁(ζ)*α²*Z
+ var lagrangeZeta, one, den, frNbElmt fr.Element
one.SetOne()
- nbElmt := int64(pk.DomainNum.Cardinality)
- lagrange.Set(&zeta).
- Exp(lagrange, big.NewInt(nbElmt)).
- Sub(&lagrange, &one)
+ nbElmt := int64(pk.Domain[0].Cardinality)
+ lagrangeZeta.Set(&zeta).
+ Exp(lagrangeZeta, big.NewInt(nbElmt)).
+ Sub(&lagrangeZeta, &one)
frNbElmt.SetUint64(uint64(nbElmt))
den.Sub(&zeta, &one).
- Mul(&den, &frNbElmt).
Inverse(&den)
- lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1)
- Mul(&lagrange, &alpha).
- Mul(&lagrange, &alpha) // alpha**2*L_0
+ lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ-1)/(ζ-1)
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ)
- linPol := z.Clone()
+ linPol := make([]fr.Element, len(blindedZCanonical))
+ copy(linPol, blindedZCanonical)
utils.Parallelize(len(linPol), func(start, end int) {
+
var t0, t1 fr.Element
+
for i := start; i < end; i++ {
- linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
- if i < len(pk.CS3) {
- t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X)
+
+ linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ if i < len(pk.S3Canonical) {
+
+ t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X)
+
linPol[i].Add(&linPol[i], &t0)
}
- linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) )
+ linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ))
if i < len(pk.Qm) {
- t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm
- t0.Mul(&pk.Ql[i], &l)
+
+ t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X)
+ t0.Mul(&pk.Ql[i], &lZeta)
t0.Add(&t0, &t1)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X)
- t0.Mul(&pk.Qr[i], &r)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr
+ t0.Mul(&pk.Qr[i], &rZeta)
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X)
- t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i])
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk
+ t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i])
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X)
}
- t0.Mul(&z[i], &lagrange)
+ t0.Mul(&blindedZCanonical[i], &lagrangeZeta)
linPol[i].Add(&linPol[i], &t0) // finish the computation
}
})
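
The "third part" above relies on the closed form of the first Lagrange polynomial; with n = pk.Domain[0].Cardinality,

$$
L_1(\zeta) \;=\; \frac{1}{n}\cdot\frac{\zeta^{\,n}-1}{\zeta-1},
$$

which is exactly what `lagrangeZeta` accumulates: the numerator via `Exp`/`Sub`, the `1/(ζ-1)` factor via `den`, and the `1/n` factor last through `CardinalityInv`, before scaling by α².
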
diff --git a/internal/backend/bn254/plonk/setup.go b/internal/backend/bn254/plonk/setup.go
index b7b8d41869..421f035bd0 100644
--- a/internal/backend/bn254/plonk/setup.go
+++ b/internal/backend/bn254/plonk/setup.go
@@ -21,7 +21,6 @@ import (
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
"github.com/consensys/gnark-crypto/ecc/bn254/fr/fft"
"github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg"
- "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial"
"github.com/consensys/gnark/internal/backend/bn254/cs"
kzgg "github.com/consensys/gnark-crypto/kzg"
@@ -40,18 +39,21 @@ type ProvingKey struct {
Vk *VerifyingKey
// qr,ql,qm,qo (in canonical basis).
- Ql, Qr, Qm, Qo polynomial.Polynomial
+ Ql, Qr, Qm, Qo []fr.Element
// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
// Storing LQk in Lagrange basis saves a fft...
- CQk, LQk polynomial.Polynomial
+ CQk, LQk []fr.Element
- // Domains used for the FFTs
- DomainNum, DomainH fft.Domain
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
+ // Domain[0], Domain[1] fft.Domain
- // s1, s2, s3 (L=Lagrange basis, C=canonical basis)
- LS1, LS2, LS3 polynomial.Polynomial
- CS1, CS2, CS3 polynomial.Polynomial
+ // Permutation polynomials
+ EvaluationPermutationBigDomainBitReversed []fr.Element
+ S1Canonical, S2Canonical, S3Canonical []fr.Element
// position -> permuted position (position in [0,3*sizeSystem-1])
Permutation []int64
@@ -69,13 +71,12 @@ type VerifyingKey struct {
Generator fr.Element
NbPublicVariables uint64
- // shifters for extending the permutation set: from s=<1,z,..,z**n-1>,
- // extended domain = s || shifter[0].s || shifter[1].s
- Shifter [2]fr.Element
-
// Commitment scheme that is used for an instantiation of PLONK
KZGSRS *kzg.SRS
+ // cosetShift generator of the coset on the small domain
+ CosetShift fr.Element
+
// S commitments to S1, S2, S3
S [3]kzg.Digest
@@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
// fft domains
sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints
- pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false)
+ pk.Domain[0] = *fft.NewDomain(sizeSystem)
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
// h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space,
// the domain is the next power of 2 greater than 3(n+2). 4*domainNum is enough in all cases
// except when n<6.
if sizeSystem < 6 {
- pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(8 * sizeSystem)
} else {
- pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(4 * sizeSystem)
}
- vk.Size = pk.DomainNum.Cardinality
+ vk.Size = pk.Domain[0].Cardinality
vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
- vk.Generator.Set(&pk.DomainNum.Generator)
+ vk.Generator.Set(&pk.Domain[0].Generator)
vk.NbPublicVariables = uint64(spr.NbPublicVariables)
- // shifters
- vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator)
- vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator)
-
if err := pk.InitKZG(srs); err != nil {
return nil, nil, err
}
// public polynomials corresponding to constraints: [ placeholders | constraints | assertions ]
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return an error if size is inconsistent
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.Qm[i].SetZero()
pk.Qo[i].SetZero()
pk.CQk[i].SetZero()
- pk.LQk[i].SetZero() // --> to be completed by the prover
+ pk.LQk[i].SetZero() // → to be completed by the prover
}
offset := spr.NbPublicVariables
for i := 0; i < nbConstraints; i++ { // constraints
@@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K])
}
- pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(pk.Ql, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qr, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qm, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qo, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.CQk, fft.DIF)
fft.BitReverse(pk.Ql)
fft.BitReverse(pk.Qr)
fft.BitReverse(pk.Qm)
@@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
buildPermutation(spr, &pk)
// set s1, s2, s3
- computeLDE(&pk)
+ ccomputePermutationPolynomials(&pk)
// Commit to the polynomials to set up the verifying key
var err error
@@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil {
+ if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil {
+ if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil {
+ if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
@@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
//
// The permutation s is composed of cycles of maximum length such that
//
-// s. (l||r||o) = (l||r||o)
+// s. (l∥r∥o) = (l∥r∥o)
//
-//, where l||r||o is the concatenation of the indices of l, r, o in
+//, where l∥r∥o is the concatenation of the indices of l, r, o in
// ql.l+qr.r+qm.l.r+qo.O+k = 0.
//
// The permutation is encoded as a slice s of size 3*size(l), where the
-// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab
+// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab
// like this: for i in tab: tab[i] = tab[permutation[i]]
func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables
- sizeSolution := int(pk.DomainNum.Cardinality)
+ sizeSolution := int(pk.Domain[0].Cardinality)
// init permutation
pk.Permutation = make([]int64, 3*sizeSolution)
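
A toy illustration of the encoding described in the comment above (hypothetical sizes, not gnark data): slots of l∥r∥o that carry the same wire form a cycle of s, so applying `tab[i] = tab[s[i]]` to a vector that is constant on each cycle leaves it unchanged.

```go
// Toy illustration of the permutation encoding: positions of the concatenated
// vector l||r||o that hold the same wire form a cycle of s.
package main

import "fmt"

func main() {
	// two constraints, so l||r||o has 6 slots: l0 l1 | r0 r1 | o0 o1
	// suppose o0 and l1 are the same wire: slots 1 and 4 form a 2-cycle,
	// every other slot is a fixed point.
	s := []int64{0, 4, 2, 3, 1, 5}

	tab := []string{"l0", "w", "r0", "r1", "w", "o1"} // slots 1 and 4 carry the same wire value
	permuted := make([]string, len(tab))
	for i := range tab {
		permuted[i] = tab[s[i]]
	}
	fmt.Println(permuted) // identical to tab: s only permutes equal values
}
```
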
@@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
}
}
-// computeLDE computes the LDE (Lagrange basis) of the permutations
+// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations
// s1, s2, s3.
//
-// ex: z gen of Z/mZ, u gen of Z/8mZ, then
-//
//  1  z .. z**n-1  |  u  uz .. u*z**n-1  |  u**2  u**2*z .. u**2*z**n-1
//                  |
//                  | Permutation
//                  v
//  s11 s12 .. s1n     s21 s22 .. s2n        s31 s32 .. s3n
//  \-------------/    \--------------/      \-------------/
//      s1 (LDE)           s2 (LDE)              s3 (LDE)
-func computeLDE(pk *ProvingKey) {
+func ccomputePermutationPolynomials(pk *ProvingKey) {
- nbElmt := int(pk.DomainNum.Cardinality)
+ nbElmts := int(pk.Domain[0].Cardinality)
- // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1]
- sID := make([]fr.Element, 3*nbElmt)
- sID[0].SetOne()
- sID[nbElmt].Set(&pk.DomainNum.FinerGenerator)
- sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator)
-
- for i := 1; i < nbElmt; i++ {
- sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1
- sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
- }
+ // Lagrange form of ID
+ evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0])
// Lagrange form of S1, S2, S3
- pk.LS1 = make(polynomial.Polynomial, nbElmt)
- pk.LS2 = make(polynomial.Polynomial, nbElmt)
- pk.LS3 = make(polynomial.Polynomial, nbElmt)
- for i := 0; i < nbElmt; i++ {
- pk.LS1[i].Set(&sID[pk.Permutation[i]])
- pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]])
- pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]])
+ pk.S1Canonical = make([]fr.Element, nbElmts)
+ pk.S2Canonical = make([]fr.Element, nbElmts)
+ pk.S3Canonical = make([]fr.Element, nbElmts)
+ for i := 0; i < nbElmts; i++ {
+ pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]])
+ pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]])
+ pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]])
}
// Canonical form of S1, S2, S3
- pk.CS1 = make(polynomial.Polynomial, nbElmt)
- pk.CS2 = make(polynomial.Polynomial, nbElmt)
- pk.CS3 = make(polynomial.Polynomial, nbElmt)
- copy(pk.CS1, pk.LS1)
- copy(pk.CS2, pk.LS2)
- copy(pk.CS3, pk.LS3)
- pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0)
- fft.BitReverse(pk.CS1)
- fft.BitReverse(pk.CS2)
- fft.BitReverse(pk.CS3)
+ pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF)
+ fft.BitReverse(pk.S1Canonical)
+ fft.BitReverse(pk.S2Canonical)
+ fft.BitReverse(pk.S3Canonical)
+
+ // evaluation of permutation on the big domain
+ pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality)
+ copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true)
+
+}
+
+// getIDSmallDomain returns the Lagrange form of ID on the small domain
+func getIDSmallDomain(domain *fft.Domain) []fr.Element {
+
+ res := make([]fr.Element, 3*domain.Cardinality)
+
+ res[0].SetOne()
+ res[domain.Cardinality].Set(&domain.FrMultiplicativeGen)
+ res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen)
+
+ for i := uint64(1); i < domain.Cardinality; i++ {
+ res[i].Mul(&res[i-1], &domain.Generator)
+ res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator)
+ res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator)
+ }
+ return res
}
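
A toy version of the table built by `getIDSmallDomain`, over F₁₇ instead of the bn254 scalar field (hypothetical parameters: g = 4 of order n = 4, and u = 3 standing in for `FrMultiplicativeGen`):

```go
// Toy version of getIDSmallDomain over F_17: the result is the flat table
// [1, g, .., g^(n-1) | u, ug, .., ug^(n-1) | u^2, u^2 g, .., u^2 g^(n-1)].
package main

import "fmt"

func main() {
	const p, n = 17, 4
	g, u := uint64(4), uint64(3)

	res := make([]uint64, 3*n)
	res[0], res[n], res[2*n] = 1, u, (u*u)%p
	for i := 1; i < n; i++ {
		res[i] = (res[i-1] * g) % p
		res[n+i] = (res[n+i-1] * g) % p
		res[2*n+i] = (res[2*n+i-1] * g) % p
	}
	fmt.Println(res) // [1 4 16 13 3 12 14 5 9 2 8 15]
}
```
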
-// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS
+// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS
//
// This should be used after deserializing a ProvingKey
// as pk.Vk.KZG is NOT serialized
diff --git a/internal/backend/bn254/plonk/verify.go b/internal/backend/bn254/plonk/verify.go
index 0c5261a51c..564431c6c3 100644
--- a/internal/backend/bn254/plonk/verify.go
+++ b/internal/backend/bn254/plonk/verify.go
@@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness)
hFunc := sha256.New()
// transcript to derive the challenge
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// derive gamma from Comm(l), Comm(r), Comm(o)
gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2])
@@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness)
return err
}
+ // derive beta from Comm(l), Comm(r), Comm(o)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return err
+ }
+
// derive alpha from Comm(l), Comm(r), Comm(o), Com(Z)
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z)
if err != nil {
@@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness)
return err
}
- // evaluation of Z=X**m-1 at zeta
+ // evaluation of Z = Xⁿ-1 at ζ
var zetaPowerM, zzeta fr.Element
var bExpo big.Int
one := fr.One()
@@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bn254witness.Witness)
zetaPowerM.Exp(zeta, &bExpo)
zzeta.Sub(&zetaPowerM, &one)
- // ccompute PI = Sum_i maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
}
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
}
- }
- // sanity check; ensure all wires are marked as "instantiated"
- if !solution.isValid() {
- panic("solver didn't instantiate all wires")
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
}
- return solution.values, nil
+ return nil
}
// IsSolved returns nil if given witness solves the R1CS and error otherwise
@@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) {
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) {
+func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error {
// the index of the non zero entry shows if L, R or O has an uninstantiated wire
// the content is the ID of the wire non instantiated
@@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
return nil
}
- if err = processLExp(r.L.LinExp, &a, 1); err != nil {
- return
+ if err := processLExp(r.L.LinExp, a, 1); err != nil {
+ return err
}
- if err = processLExp(r.R.LinExp, &b, 2); err != nil {
- return
+ if err := processLExp(r.R.LinExp, b, 2); err != nil {
+ return err
}
- if err = processLExp(r.O.LinExp, &c, 3); err != nil {
- return
+ if err := processLExp(r.O.LinExp, c, 3); err != nil {
+ return err
}
if loc == 0 {
// there is nothing to solve, may happen if we have an assertion
// (ie a constraints that doesn't yield any output)
// or if we solved the unsolved wires with hint functions
- return
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
+ return nil
}
// we compute the wire value and instantiate it
- solved = true
- vID := termToCompute.WireID()
+ wID := termToCompute.WireID()
// solver result
var wire fr.Element
@@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
switch loc {
case 1:
if !b.IsZero() {
- wire.Div(&c, &b).
- Sub(&wire, &a)
- a.Add(&a, &wire)
+ wire.Div(c, b).
+ Sub(&wire, a)
+ a.Add(a, &wire)
} else {
// we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 2:
if !a.IsZero() {
- wire.Div(&c, &a).
- Sub(&wire, &b)
- b.Add(&b, &wire)
+ wire.Div(c, a).
+ Sub(&wire, b)
+ b.Add(b, &wire)
} else {
- // we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 3:
- wire.Mul(&a, &b).
- Sub(&wire, &c)
+ wire.Mul(a, b).
+ Sub(&wire, c)
- c.Add(&c, &wire)
+ c.Add(c, &wire)
}
// wire is the term (coeff * value)
// but in the solution we want to store the value only
// note that in gnark frontend, coeff here is always 1 or -1
cs.divByCoeff(&wire, termToCompute)
- solution.set(vID, wire)
+ solution.set(wID, wire)
- return
+ return nil
}
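
The three loc cases above share one arithmetic step: exactly one of a, b, c still contains an uninstantiated wire, and the R1C relation a·b = c is solved for it, with the division guarded by the IsZero checks and the wire's coefficient divided out afterwards via divByCoeff. A minimal field-level sketch of that step over math/big, with a hypothetical solveR1C helper that treats the unknown as the whole linear combination rather than subtracting the already-known terms as solveConstraint does:

    package main

    import (
        "errors"
        "fmt"
        "math/big"
    )

    // solveR1C returns the missing value in a*b = c (mod q), where exactly one of
    // a, b, c is nil (the uninstantiated side). Solving through a or b requires the
    // other factor to be invertible, mirroring the a.IsZero()/b.IsZero() guards above.
    func solveR1C(a, b, c, q *big.Int) (*big.Int, error) {
        mulMod := func(x, y *big.Int) *big.Int {
            return new(big.Int).Mod(new(big.Int).Mul(x, y), q)
        }
        switch {
        case a == nil: // a = c / b
            if b.Sign() == 0 {
                return nil, errors.New("b == 0: cannot isolate a")
            }
            return mulMod(c, new(big.Int).ModInverse(b, q)), nil
        case b == nil: // b = c / a
            if a.Sign() == 0 {
                return nil, errors.New("a == 0: cannot isolate b")
            }
            return mulMod(c, new(big.Int).ModInverse(a, q)), nil
        default: // c = a * b
            return mulMod(a, b), nil
        }
    }

    func main() {
        q := big.NewInt(97)
        a, _ := solveR1C(nil, big.NewInt(5), big.NewInt(30), q)
        fmt.Println(a) // 6, since 6*5 == 30 mod 97
    }
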
// GetConstraints return a list of constraint formatted as L⋅R == O
diff --git a/internal/backend/bw6-633/cs/r1cs_sparse.go b/internal/backend/bw6-633/cs/r1cs_sparse.go
index e6154b4227..7cc2ccfad0 100644
--- a/internal/backend/bw6-633/cs/r1cs_sparse.go
+++ b/internal/backend/bw6-633/cs/r1cs_sparse.go
@@ -21,9 +21,12 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/fxamacker/cbor/v2"
"io"
+ "math"
"math/big"
"os"
+ "runtime"
"strings"
+ "sync"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
@@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
return solution.values, err
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
// solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs
copy(solution.values, witness)
for i := 0; i < len(witness); i++ {
@@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness)
+ solution.nbSolved += uint64(len(witness))
// defer log printing once all solution.values are computed
defer solution.printLogs(opt.LoggerOut, cs.Logs)
@@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
coefficientsNegInv[i].Neg(&coefficientsNegInv[i])
}
- // loop through the constraints to solve the variables
- for i := 0; i < len(cs.Constraints); i++ {
- if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil {
- return solution.values, fmt.Errorf("constraint %d: %w", i, err)
- }
- if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil {
- errMsg := err.Error()
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
- }
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
- }
+ if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil {
+ return solution.values, err
}
// sanity check; ensure all wires are marked as "instantiated"
@@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
+func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error {
+ // minWorkPerCPU is the minimum target number of constraints a task should hold
+ // in other words, if a level has fewer than minWorkPerCPU constraints, it is not parallelized
+ // and is executed sequentially, without synchronization.
+ const minWorkPerCPU = 50.0
+
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
+
+ // start a worker pool
+ // each worker waits on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ // clean up pool go routines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
+
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ }
+ }
+ continue
+ }
+
+ // the number of tasks for this level is set to the number of CPUs,
+ // but if we don't have enough work for all our CPUs, it can be lower.
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
+ }
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
+ }
+
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
+ }
+
+ return nil
+}
+
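
The scheduling pattern above (a fixed pool of workers fed per-level chunks, with a WaitGroup barrier between levels so that level n+1 only starts once level n is fully solved) can be summarised by the following stripped-down sketch. It is a simplified illustration, not the backend code: processLevels and work are hypothetical names, the minWorkPerCPU sequential fast path is omitted, and error handling keeps the original's assumption that each worker pushes at most one error into the buffered channel.

    package main

    import (
        "fmt"
        "runtime"
        "sync"
    )

    // processLevels runs work(i) for every index of every level; indices within a
    // level are independent, so each level is split into at most NumCPU chunks and
    // handed to a worker pool, then wg.Wait() acts as a barrier before the next level.
    func processLevels(levels [][]int, work func(i int) error) error {
        nbCPU := runtime.NumCPU()
        var wg sync.WaitGroup
        chTasks := make(chan []int, nbCPU)
        chError := make(chan error, nbCPU)

        for w := 0; w < nbCPU; w++ {
            go func() {
                for task := range chTasks {
                    for _, i := range task {
                        if err := work(i); err != nil {
                            chError <- err
                            wg.Done()
                            return
                        }
                    }
                    wg.Done()
                }
            }()
        }
        defer close(chTasks)
        defer close(chError)

        for _, level := range levels {
            chunk := (len(level) + nbCPU - 1) / nbCPU // at most nbCPU chunks per level,
            for start := 0; start < len(level); start += chunk { // so chTasks never blocks
                end := start + chunk
                if end > len(level) {
                    end = len(level)
                }
                wg.Add(1)
                chTasks <- level[start:end]
            }
            wg.Wait() // barrier: the whole level is solved before moving on
            if len(chError) > 0 {
                return <-chError
            }
        }
        return nil
    }

    func main() {
        levels := [][]int{{0, 1, 2, 3}, {4, 5}}
        fmt.Println(processLevels(levels, func(i int) error { return nil }))
    }
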
// computeHints computes wires associated with a hint function, if any
// if there is no remaining wire to solve, returns -1
// else returns the wire position (L -> 0, R -> 1, O -> 2)
diff --git a/internal/backend/bw6-633/cs/solution.go b/internal/backend/bw6-633/cs/solution.go
index 3f7daa0122..4b611e7f07 100644
--- a/internal/backend/bw6-633/cs/solution.go
+++ b/internal/backend/bw6-633/cs/solution.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"math/big"
+ "sync/atomic"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend/schema"
@@ -32,14 +33,15 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bw6-633"
)
+var errUnsatisfiedConstraint = errors.New("unsatisfied")
+
// solution represents elements needed to compute
// a solution to a R1CS or SparseR1CS
type solution struct {
values, coefficients []fr.Element
solved []bool
- nbSolved int
+ nbSolved uint64
mHintsFunctions map[hint.ID]hint.Function
- tmpHintsIO []*big.Int
}
func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) {
@@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E
coefficients: coefficients,
solved: make([]bool, nbWires),
mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)),
- tmpHintsIO: make([]*big.Int, 0),
}
for _, h := range hintFunctions {
@@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) {
}
s.values[id] = value
s.solved[id] = true
- s.nbSolved++
+ atomic.AddUint64(&s.nbSolved, 1)
}
func (s *solution) isValid() bool {
- return s.nbSolved == len(s.values)
+ return int(s.nbSolved) == len(s.values)
}
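
With constraints now solved from several goroutines, solution.set can be called concurrently, which is why nbSolved became a uint64 bumped with atomic.AddUint64 (isValid only reads it after the per-level barrier, so the plain comparison there is safe in practice). A minimal sketch of that counter pattern, detached from the solver:

    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    func main() {
        var nbSolved uint64
        var wg sync.WaitGroup
        for i := 0; i < 1000; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                atomic.AddUint64(&nbSolved, 1) // safe under concurrent "wire solved" events
            }()
        }
        wg.Wait()
        fmt.Println(atomic.LoadUint64(&nbSolved) == 1000) // true
    }
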
// computeTerm computes coef*variable
@@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error {
// tmp IO big int memory
nbInputs := len(h.Inputs)
nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs))
- m := len(s.tmpHintsIO)
- if m < (nbInputs + nbOutputs) {
- s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...)
- for i := m; i < len(s.tmpHintsIO); i++ {
- s.tmpHintsIO[i] = big.NewInt(0)
- }
+ inputs := make([]*big.Int, nbInputs)
+ outputs := make([]*big.Int, nbOutputs)
+ for i := 0; i < nbInputs; i++ {
+ inputs[i] = big.NewInt(0)
+ }
+ for i := 0; i < nbOutputs; i++ {
+ outputs[i] = big.NewInt(0)
}
- inputs := s.tmpHintsIO[:nbInputs]
- outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs]
q := fr.Modulus()
diff --git a/internal/backend/bw6-633/groth16/marshal_test.go b/internal/backend/bw6-633/groth16/marshal_test.go
index 1fc47f35de..b2a3a63d4b 100644
--- a/internal/backend/bw6-633/groth16/marshal_test.go
+++ b/internal/backend/bw6-633/groth16/marshal_test.go
@@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) {
var pk, pkCompressed, pkRaw ProvingKey
// create a random pk
- domain := fft.NewDomain(8, 1, true)
+ domain := fft.NewDomain(8)
pk.Domain = *domain
nbWires := 6
diff --git a/internal/backend/bw6-633/groth16/prove.go b/internal/backend/bw6-633/groth16/prove.go
index 45c4688959..b72cf7b41f 100644
--- a/internal/backend/bw6-633/groth16/prove.go
+++ b/internal/backend/bw6-633/groth16/prove.go
@@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
c = append(c, padding...)
n = len(a)
- domain.FFTInverse(a, fft.DIF, 0)
- domain.FFTInverse(b, fft.DIF, 0)
- domain.FFTInverse(c, fft.DIF, 0)
+ domain.FFTInverse(a, fft.DIF)
+ domain.FFTInverse(b, fft.DIF)
+ domain.FFTInverse(c, fft.DIF)
- domain.FFT(a, fft.DIT, 1)
- domain.FFT(b, fft.DIT, 1)
- domain.FFT(c, fft.DIT, 1)
+ domain.FFT(a, fft.DIT, true)
+ domain.FFT(b, fft.DIT, true)
+ domain.FFT(c, fft.DIT, true)
- var minusTwoInv fr.Element
- minusTwoInv.SetUint64(2)
- minusTwoInv.Neg(&minusTwoInv).
- Inverse(&minusTwoInv)
+ var den, one fr.Element
+ one.SetOne()
+ den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
+ den.Sub(&den, &one).Inverse(&den)
// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unecessary memalloc
@@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
- Mul(&a[i], &minusTwoInv)
+ Mul(&a[i], &den)
}
})
// ifft_coset
- domain.FFTInverse(a, fft.DIF, 1)
+ domain.FFTInverse(a, fft.DIF, true)
utils.Parallelize(len(a), func(start, end int) {
for i := start; i < end; i++ {
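
The den used above replaces the old hard-coded minusTwoInv, and the change tracks the new coset convention: computeH divides ca∘cb − cc by the vanishing polynomial Zₙ(X) = Xⁿ − 1 evaluated on the coset u·H (u = domain.FrMultiplicativeGen). On that coset the vanishing polynomial is constant,

    Zₙ(u·gⁱ) = uⁿ·(gⁿ)ⁱ − 1 = uⁿ − 1,

so dividing by it amounts to one multiplication by den = (uⁿ − 1)⁻¹. With the previous API the coset generator was a 2n-th root of unity, giving uⁿ − 1 = −2 and hence the old −1/2 constant (this reading of the old constant is an inference from the removed code, not something stated in the diff).
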
diff --git a/internal/backend/bw6-633/groth16/setup.go b/internal/backend/bw6-633/groth16/setup.go
index 17caeecb6f..b26489abbc 100644
--- a/internal/backend/bw6-633/groth16/setup.go
+++ b/internal/backend/bw6-633/groth16/setup.go
@@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error {
nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables
// Setting group for fft
- domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true)
+ domain := fft.NewDomain(uint64(len(r1cs.Constraints)))
// samples toxic waste
toxicWaste, err := sampleToxicWaste()
@@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error {
nbConstraints := len(r1cs.Constraints)
// Setting group for fft
- domain := fft.NewDomain(uint64(nbConstraints), 1, true)
+ domain := fft.NewDomain(uint64(nbConstraints))
// count number of infinity points we would have had we a normal setup
// in pk.G1.A, pk.G1.B, and pk.G2.B
diff --git a/internal/backend/bw6-633/plonk/marshal.go b/internal/backend/bw6-633/plonk/marshal.go
index 756fa46bf9..8f2ad728cc 100644
--- a/internal/backend/bw6-633/plonk/marshal.go
+++ b/internal/backend/bw6-633/plonk/marshal.go
@@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
}
// fft domains
- n2, err := pk.DomainNum.WriteTo(w)
+ n2, err := pk.Domain[0].WriteTo(w)
if err != nil {
return
}
n += n2
- n2, err = pk.DomainH.WriteTo(w)
+ n2, err = pk.Domain[1].WriteTo(w)
if err != nil {
return
}
n += n2
- // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality)
- if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) {
+ // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality)
+ if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) {
return n, errors.New("invalid permutation size, expected 3*domain cardinality")
}
@@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
([]fr.Element)(pk.Qo),
([]fr.Element)(pk.CQk),
([]fr.Element)(pk.LQk),
- ([]fr.Element)(pk.LS1),
- ([]fr.Element)(pk.LS2),
- ([]fr.Element)(pk.LS3),
- ([]fr.Element)(pk.CS1),
- ([]fr.Element)(pk.CS2),
- ([]fr.Element)(pk.CS3),
+ ([]fr.Element)(pk.S1Canonical),
+ ([]fr.Element)(pk.S2Canonical),
+ ([]fr.Element)(pk.S3Canonical),
pk.Permutation,
}
@@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
- n2, err := pk.DomainNum.ReadFrom(r)
+ n2, err := pk.Domain[0].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- n2, err = pk.DomainH.ReadFrom(r)
+ n2, err = pk.Domain[1].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
dec := curve.NewDecoder(r)
toDecode := []interface{}{
@@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
(*[]fr.Element)(&pk.Qo),
(*[]fr.Element)(&pk.CQk),
(*[]fr.Element)(&pk.LQk),
- (*[]fr.Element)(&pk.LS1),
- (*[]fr.Element)(&pk.LS2),
- (*[]fr.Element)(&pk.LS3),
- (*[]fr.Element)(&pk.CS1),
- (*[]fr.Element)(&pk.CS2),
- (*[]fr.Element)(&pk.CS3),
+ (*[]fr.Element)(&pk.S1Canonical),
+ (*[]fr.Element)(&pk.S2Canonical),
+ (*[]fr.Element)(&pk.S3Canonical),
&pk.Permutation,
}
@@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
&vk.SizeInv,
&vk.Generator,
vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
&vk.SizeInv,
&vk.Generator,
&vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
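
The slices encoded by WriteTo/ReadFrom changed (S1Canonical, S2Canonical, S3Canonical instead of the six LS*/CS* slices, and no Shifter in the VerifyingKey), but the round-trip contract is unchanged. A usage sketch, with roundTrip a hypothetical helper and bytes imported; the KZG SRS is deliberately not part of the encoding, so InitKZG must be called again on the deserialized key, as its documentation in setup.go notes:

    // roundTrip serializes a ProvingKey and reads it back (sketch; error handling
    // is the only logic). The caller still needs pk2.InitKZG(srs) afterwards,
    // because pk.Vk.KZG is not serialized.
    func roundTrip(pk *ProvingKey) (*ProvingKey, error) {
        var buf bytes.Buffer
        if _, err := pk.WriteTo(&buf); err != nil {
            return nil, err
        }
        var pk2 ProvingKey
        if _, err := pk2.ReadFrom(&buf); err != nil {
            return nil, err
        }
        return &pk2, nil
    }
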
diff --git a/internal/backend/bw6-633/plonk/marshal_test.go b/internal/backend/bw6-633/plonk/marshal_test.go
index ebf373a463..7ace49aef7 100644
--- a/internal/backend/bw6-633/plonk/marshal_test.go
+++ b/internal/backend/bw6-633/plonk/marshal_test.go
@@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
@@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) {
// random pk
var pk ProvingKey
pk.Vk = &vk
- pk.DomainNum = *fft.NewDomain(42, 3, false)
- pk.DomainH = *fft.NewDomain(4*42, 1, false)
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Domain[0] = *fft.NewDomain(42)
+ pk.Domain[1] = *fft.NewDomain(4 * 42)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < 12; i++ {
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) {
pk.Qo[i].SetUint64(42)
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
pk.Permutation[0] = -12
pk.Permutation[len(pk.Permutation)-1] = 8888
@@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
diff --git a/internal/backend/bw6-633/plonk/prove.go b/internal/backend/bw6-633/plonk/prove.go
index e633795585..90ac35a7aa 100644
--- a/internal/backend/bw6-633/plonk/prove.go
+++ b/internal/backend/bw6-633/plonk/prove.go
@@ -27,8 +27,6 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bw6-633"
- "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial"
-
"github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg"
"github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft"
@@ -43,6 +41,7 @@ import (
)
type Proof struct {
+
// Commitments to the solution vectors
LRO [3]kzg.Digest
@@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
hFunc := sha256.New()
// create a transcript manager to apply Fiat Shamir
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// result
proof := &Proof{}
@@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
}
// query l, r, o in Lagrange basis, not blinded
- ll, lr, lo := computeLRO(spr, pk, solution)
+ evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution)
// save ll, lr, lo, and make a copy of them in canonical basis.
// note that we allocate more capacity to reuse for blinded polynomials
- bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum)
+ blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ &pk.Domain[0])
if err != nil {
return nil, err
}
// compute kzg commitments of bcl, bcr and bco
- if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
return nil, err
}
+ // Fiat Shamir this
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return nil, err
+ }
+
// compute Z, the permutation accumulator polynomial, in canonical basis
// ll, lr, lo are NOT blinded
- var bz polynomial.Polynomial
+ var blindedZCanonical []fr.Element
chZ := make(chan error, 1)
var alpha fr.Element
go func() {
var err error
- bz, err = computeBlindedZ(ll, lr, lo, pk, gamma)
+ blindedZCanonical, err = computeBlindedZCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ pk, beta, gamma)
if err != nil {
chZ <- err
close(chZ)
@@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
// this may add additional arithmetic operations, but with smaller tasks
// we ensure that this commitment is well parallelized, without having a "unbalanced task" making
// the rest of the code wait too long.
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
+ if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
chZ <- err
close(chZ)
return
@@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
}()
// evaluation of the blinded versions of l, r, o and bz
- // on the odd cosets of (Z/8mZ)/(Z/mZ)
- var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial
+ // on the coset of the big domain
+ var (
+ evaluationBlindedLDomainBigBitReversed []fr.Element
+ evaluationBlindedRDomainBigBitReversed []fr.Element
+ evaluationBlindedODomainBigBitReversed []fr.Element
+ evaluationBlindedZDomainBigBitReversed []fr.Element
+ )
chEvalBL := make(chan struct{}, 1)
chEvalBR := make(chan struct{}, 1)
chEvalBO := make(chan struct{}, 1)
go func() {
- evalBL = evaluateHDomain(bcl, &pk.DomainH)
+ evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1])
close(chEvalBL)
}()
go func() {
- evalBR = evaluateHDomain(bcr, &pk.DomainH)
+ evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1])
close(chEvalBR)
}()
go func() {
- evalBO = evaluateHDomain(bco, &pk.DomainH)
+ evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1])
close(chEvalBO)
}()
- var constraintsInd, constraintsOrdering polynomial.Polynomial
+ var constraintsInd, constraintsOrdering []fr.Element
chConstraintInd := make(chan struct{}, 1)
go func() {
// compute qk in canonical basis, completed with the public inputs
- qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality)
- copy(qk, fullWitness[:spr.NbPublicVariables])
- copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
- pk.DomainNum.FFTInverse(qk, fft.DIF, 0)
- fft.BitReverse(qk)
-
- // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ)
- // --> uses the blinded version of l, r, o
+ qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality)
+ copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables])
+ copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
+ pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF)
+ fft.BitReverse(qkCompletedCanonical)
+
+ // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain
+ // → uses the blinded version of l, r, o
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk)
+ constraintsInd = evaluateConstraintsDomainBigBitReversed(
+ pk,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ qkCompletedCanonical)
close(chConstraintInd)
}()
@@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
chConstraintOrdering <- err
return
}
- evalBZ = evaluateHDomain(bz, &pk.DomainH)
- // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ)
+
+ evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1])
+ // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain
// evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o.
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma)
+ constraintsOrdering = evaluateOrderingDomainBigBitReversed(
+ pk,
+ evaluationBlindedZDomainBigBitReversed,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ beta,
+ gamma)
chConstraintOrdering <- nil
close(chConstraintOrdering)
}()
@@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
if err := <-chConstraintOrdering; err != nil {
return nil, err
}
+
<-chConstraintInd
+
// compute h in canonical form
- h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha)
+ h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha)
// compute kzg commitments of h1, h2 and h3
- if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
var wgZetaEvals sync.WaitGroup
wgZetaEvals.Add(3)
go func() {
- blzeta = bcl.Eval(&zeta)
+ blzeta = eval(blindedLCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- brzeta = bcr.Eval(&zeta)
+ brzeta = eval(blindedRCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- bozeta = bco.Eval(&zeta)
+ bozeta = eval(blindedOCanonical, zeta)
wgZetaEvals.Done()
}()
@@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
var zetaShifted fr.Element
zetaShifted.Mul(&zeta, &pk.Vk.Generator)
proof.ZShiftedOpening, err = kzg.Open(
- bz,
- &zetaShifted,
- &pk.DomainH,
+ blindedZCanonical,
+ zetaShifted,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
bzuzeta := proof.ZShiftedOpening.ClaimedValue
var (
- linearizedPolynomial polynomial.Polynomial
- linearizedPolynomialDigest curve.G1Affine
- errLPoly error
+ linearizedPolynomialCanonical []fr.Element
+ linearizedPolynomialDigest curve.G1Affine
+ errLPoly error
)
chLpoly := make(chan struct{}, 1)
go func() {
// compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k)
wgZetaEvals.Wait()
- linearizedPolynomial = computeLinearizedPolynomial(
+ linearizedPolynomialCanonical = computeLinearizedPolynomial(
blzeta,
brzeta,
bozeta,
alpha,
+ beta,
gamma,
zeta,
bzuzeta,
- bz,
+ blindedZCanonical,
pk,
)
// TODO this commitment is only necessary to derive the challenge, we should
// be able to avoid doing it and get the challenge in another way
- linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS)
+ linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS)
close(chLpoly)
}()
- // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3)
+ // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3)
var bZetaPowerm, bSize big.Int
- bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
+ bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
var zetaPowerm fr.Element
zetaPowerm.Exp(zeta, &bSize)
zetaPowerm.ToBigIntRegular(&bZetaPowerm)
foldedHDigest := proof.H[2]
foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm)
- foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3)
- foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2)
- foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + Comm(h2)
+ foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1)
- // foldedH = h1 + zeta*h2 + zeta**2*h3
+ // foldedH = h1 + ζ*h2 + ζ²*h3
foldedH := h3
utils.Parallelize(len(foldedH), func(start, end int) {
for i := start; i < end; i++ {
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3
- foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1)
- foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3
+ foldedH[i].Add(&foldedH[i], &h2[i]) // ζᵐ⁺²*h3+h2
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2
+ foldedH[i].Add(&foldedH[i], &h1[i]) // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2+h1
}
})
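
Both folds above are the same Horner pattern: the coefficients compute

    foldedH(X) = h1(X) + ζᵐ⁺²·h2(X) + ζ²⁽ᵐ⁺²⁾·h3(X)

by multiplying by ζᵐ⁺² and adding twice, and because KZG commitments are additively homomorphic the two ScalarMultiplication/Add steps on the digests yield Comm(foldedH) = Comm(h1) + ζᵐ⁺²·Comm(h2) + ζ²⁽ᵐ⁺²⁾·Comm(h3), so a single opening of foldedH at ζ covers all three quotient pieces.
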
@@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
// Batch open the first list of polynomials
proof.BatchedProof, err = kzg.BatchOpenSinglePoint(
- []polynomial.Polynomial{
+ [][]fr.Element{
foldedH,
- linearizedPolynomial,
- bcl,
- bcr,
- bco,
- pk.CS1,
- pk.CS2,
+ linearizedPolynomialCanonical,
+ blindedLCanonical,
+ blindedRCanonical,
+ blindedOCanonical,
+ pk.S1Canonical,
+ pk.S2Canonical,
},
[]kzg.Digest{
foldedHDigest,
@@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
pk.Vk.S[0],
pk.Vk.S[1],
},
- &zeta,
+ zeta,
hFunc,
- &pk.DomainH,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_633witness.Witnes
}
+// eval evaluates c at p
+func eval(c []fr.Element, p fr.Element) fr.Element {
+ var r fr.Element
+ for i := len(c) - 1; i >= 0; i-- {
+ r.Mul(&r, &p).Add(&r, &c[i])
+ }
+ return r
+}
+
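
eval is plain Horner evaluation in the coefficient (canonical) basis and replaces the former polynomial.Polynomial.Eval calls. The same recurrence over ordinary integers, as a toy reference sketch (hypothetical horner helper, not the fr.Element code):

    // horner computes c[0] + c[1]*x + c[2]*x² + ... walking the coefficients from
    // the highest degree down, exactly like eval above.
    func horner(c []int64, x int64) int64 {
        var r int64
        for i := len(c) - 1; i >= 0; i-- {
            r = r*x + c[i]
        }
        return r
    }

    // horner([]int64{1, 2, 3}, 10) == 321
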
// fills proof.LRO with kzg commits of bcl, bcr and bco
-func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS
return err1
}
-func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err
return err1
}
-// computeBlindedLRO l, r, o in canonical basis with blinding
-func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) {
+// computeBlindedLROCanonical l, r, o in canonical basis with blinding
+func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) {
// note that bcl, bcr and bco reuses cl, cr and co memory
- cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
+ cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
chDone := make(chan error, 2)
go func() {
var err error
copy(cl, ll)
- domain.FFTInverse(cl, fft.DIF, 0)
+ domain.FFTInverse(cl, fft.DIF)
fft.BitReverse(cl)
bcl, err = blindPoly(cl, domain.Cardinality, 1)
chDone <- err
@@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
go func() {
var err error
copy(cr, lr)
- domain.FFTInverse(cr, fft.DIF, 0)
+ domain.FFTInverse(cr, fft.DIF)
fft.BitReverse(cr)
bcr, err = blindPoly(cr, domain.Cardinality, 1)
chDone <- err
}()
copy(co, lo)
- domain.FFTInverse(co, fft.DIF, 0)
+ domain.FFTInverse(co, fft.DIF)
fft.BitReverse(co)
if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil {
return
@@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
// * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1)
//
// WARNING:
-// pre condition degree(cp) <= rou + bo
-// pre condition cap(cp) >= int(totalDegree + 1)
-func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) {
+// pre condition degree(cp) ⩽ rou + bo
+// pre condition cap(cp) ⩾ int(totalDegree + 1)
+func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) {
// degree of the blinded polynomial is max(rou+order, cp.Degree)
totalDegree := rou + bo
@@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res := cp[:totalDegree+1]
// random polynomial
- blindingPoly := make(polynomial.Polynomial, bo+1)
+ blindingPoly := make([]fr.Element, bo+1)
for i := uint64(0); i < bo+1; i++ {
if _, err := blindingPoly[i].SetRandom(); err != nil {
return nil, err
@@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
}
return res, nil
+
}
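
Per the preconditions above, blindPoly returns (in the notation of its doc comment, the combination itself falling outside the lines shown here)

    res(X) = cp(X) + Q(X)·(Xʳᵒᵘ − 1),

with Q a random polynomial of degree bo. Since Xʳᵒᵘ − 1 vanishes on the size-rou evaluation domain, res agrees with cp everywhere the prover actually evaluates it while masking its other values; this is also why cl, cr, co are allocated with capacity domain.Cardinality+2, leaving room for the two extra coefficients a degree-1 blinding adds.
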
-// computeLRO extracts the solution l, r, o, and returns it in lagrange form.
+// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form.
// solution = [ public | secret | internal ]
-func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) {
+func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {
- s := int(pk.DomainNum.Cardinality)
+ s := int(pk.Domain[0].Cardinality)
- var l, r, o polynomial.Polynomial
+ var l, r, o []fr.Element
l = make([]fr.Element, s)
r = make([]fr.Element, s)
o = make([]fr.Element, s)
@@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly
//
// * Z of degree n (domainNum.Cardinality)
// * Z(1)=1
-// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma)
-// * for i>0: Z(u**i) = Pi_{k
+// * for i>0: Z(gⁱ) = Π_{k
- u[0].Mul(&u[0], &pk.DomainNum.Generator) // z**i -> z**i+1
- u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
}
})
@@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele
Mul(&z[i], &gInv[i])
}
- pk.DomainNum.FFTInverse(z, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(z, fft.DIF)
fft.BitReverse(z)
- return blindPoly(z, pk.DomainNum.Cardinality, 2)
+ return blindPoly(z, pk.Domain[0].Cardinality, 2)
}
-// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on
-// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of qlL+qrR+qmL.R+qoO+k on
+// the big domain coset.
//
// * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets
// * qk is the completed version of qk, in canonical version
-func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
- var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial
+func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
+ var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element
var wg sync.WaitGroup
wg.Add(4)
go func() {
- evalQl = evaluateHDomain(pk.Ql, &pk.DomainH)
+ evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQr = evaluateHDomain(pk.Qr, &pk.DomainH)
+ evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQm = evaluateHDomain(pk.Qm, &pk.DomainH)
+ evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQo = evaluateHDomain(pk.Qo, &pk.DomainH)
+ evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1])
wg.Done()
}()
- evalQk = evaluateHDomain(qk, &pk.DomainH)
+ evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1])
wg.Wait()
- // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets
- // of (Z/8mZ)/(Z/mZ)
+
+ // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain
utils.Parallelize(len(evalQk), func(start, end int) {
var t0, t1 fr.Element
for i := start; i < end; i++ {
@@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.
return evalQk
}
-// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ)
-func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) {
-
- id = make([]fr.Element, pk.DomainH.Cardinality)
-
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
- var acc fr.Element
- acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start)))
- for i := start; i < end; i++ {
- id[i].Mul(&acc, &pk.DomainH.FinerGenerator)
- acc.Mul(&acc, &pk.DomainH.Generator)
- }
- })
-
- return id
-}
-
-// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
-// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the
+// coset of the big domain.
//
-// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets
-// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets
+// * z evaluation of the blinded permutation accumulator polynomial on the coset of the big domain
+// * l, r, o evaluation of the blinded solution vectors on the coset of the big domain
// * gamma randomization
-func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial {
+func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element {
- // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ)
- evalID := evalIDCosets(pk)
+ nbElmts := int(pk.Domain[1].Cardinality)
- // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ)
- var wg sync.WaitGroup
- wg.Add(2)
- var evalS1, evalS2, evalS3 polynomial.Polynomial
- go func() {
- evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH)
- wg.Done()
- }()
- go func() {
- evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH)
- wg.Done()
- }()
- evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH)
- wg.Wait()
+ // computes z(uX)*(l(X)+s₁(X)*β+γ)*(r(X)+s₂(X)*β+γ)*(o(X)+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ)
+ // on the big domain (coset).
+ res := make([]fr.Element, pk.Domain[1].Cardinality)
- // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ)
- res := evalS1 // re use allocated memory for evalS1
- s := uint64(len(evalZ))
- nn := uint64(64 - bits.TrailingZeros64(uint64(s)))
+ nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts)))
// needed to shift evalZ
- toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality
+ toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality)
+
+ var cosetShift, cosetShiftSquare fr.Element
+ cosetShift.Set(&pk.Vk.CosetShift)
+ cosetShiftSquare.Square(&pk.Vk.CosetShift)
+
+ utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) {
+
+ var evaluationIDBigDomain fr.Element
+ evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))).
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen)
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
var f [3]fr.Element
var g [3]fr.Element
- var eID fr.Element
for i := start; i < end; i++ {
- // here we want to left shift evalZ by domainH/domainNum
- // however, evalZ is permuted
- // we take the non permuted index
- // compute the corresponding shift position
- // permute it again
- irev := bits.Reverse64(uint64(i)) >> nn
- eID = evalID[irev]
+ _i := bits.Reverse64(uint64(i)) >> nn
+ _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn
- shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn
- //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn
+ // in what follows gⁱ is understood as the generator of the chosen coset of domainBig
+ f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ
+ f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ
+ f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ
- f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma
- f[1].Mul(&eID, &pk.Vk.Shifter[0])
- f[2].Mul(&eID, &pk.Vk.Shifter[1])
- f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma
- f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma
+ g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ)+s1(gⁱ)*β+γ
+ g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ)+s2(gⁱ)*β+γ
+ g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ)+s3(gⁱ)*β+γ
- g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma
- g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma
- g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma
+ f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
+ g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ)
- f[0].Mul(&f[0], &f[1]).
- Mul(&f[0], &f[2]).
- Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma)
+ res[_i].Sub(&g[0], &f[0]) // z(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
- g[0].Mul(&g[0], &g[1]).
- Mul(&g[0], &g[2]).
- Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
-
- res[i].Sub(&g[0], &f[0])
+ evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g
}
})
return res
}
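
In the notation of the comments above, the accumulator committed as Z is defined by Z(1) = 1 and, for i > 0,

    Z(gⁱ) = Π_{k<i} (l(gᵏ)+β·gᵏ+γ)·(r(gᵏ)+β·u·gᵏ+γ)·(o(gᵏ)+β·u²·gᵏ+γ) / ((l(gᵏ)+β·s₁(gᵏ)+γ)·(r(gᵏ)+β·s₂(gᵏ)+γ)·(o(gᵏ)+β·s₃(gᵏ)+γ))

(β is new in this release; the id columns 1, u, u² come from getIDSmallDomain). For an honestly computed Z the difference assembled above vanishes on the whole small domain, and together with the L₁(X)·(Z(X)−1) term added in computeQuotientCanonical it forces the grand product to telescope back to 1, which is the copy-constraint check.
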
-// evaluateHDomain evaluates poly (canonical form) of degree m
- irev := bits.Reverse64(i) >> nn
- // h[i].Mul(&h[i], &_u[irev%4])
- h[i].Mul(&h[i], &_u[irev%toShift])
+
+ _i := bits.Reverse64(i) >> nn
+
+ t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain
+ h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t).
+ Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]).
+ Mul(&h[_i], &alpha).
+ Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]).
+ Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio])
}
})
// put h in canonical form. h is of degree 3*(n+1)+2.
// using fft.DIT put h revert bit reverse
- pk.DomainH.FFTInverse(h, fft.DIT, 1)
- // fmt.Println("h:")
- // for i := 0; i < len(h); i++ {
- // fmt.Printf("%s\n", h[i].String())
- // }
- // fmt.Println("")
+ pk.Domain[1].FFTInverse(h, fft.DIT, true)
// degree of hi is n+2 because of the blinding
- h1 := h[:pk.DomainNum.Cardinality+2]
- h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)]
- h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)]
+ h1 := h[:pk.Domain[0].Cardinality+2]
+ h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)]
+ h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)]
return h1, h2, h3
@@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom
// computeLinearizedPolynomial computes the linearized polynomial in canonical basis.
// The purpose is to commit and open all in one ql, qr, qm, qo, qk.
-// * a, b, c are the evaluation of l, r, o at zeta
-// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z
+// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta
+// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z
// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk.
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial {
+//
+// The Linearized polynomial is:
+//
+// α²*L₁(ζ)*Z(X)
+// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ))
+// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X)
+func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element {
// first part: individual constraints
var rl fr.Element
- rl.Mul(&r, &l)
+ rl.Mul(&rZeta, &lZeta)
- // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
+ // second part:
+ // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)
var s1, s2 fr.Element
chS1 := make(chan struct{}, 1)
go func() {
- s1 = pk.CS1.Eval(&zeta)
- s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma)
+ s1 = eval(pk.S1Canonical, zeta) // s1(ζ)
+ s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ)
close(chS1)
}()
- t := pk.CS2.Eval(&zeta)
- t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma)
+ tmp := eval(pk.S2Canonical, zeta) // s2(ζ)
+ tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ)
<-chS1
- s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma)
- Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)
-
- s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma)
- t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)
- t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
- s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
-
- // third part L1(zeta)*alpha**2**Z
- var lagrange, one, den, frNbElmt fr.Element
+ s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ)
+
+ var uzeta, uuzeta fr.Element
+ uzeta.Mul(&zeta, &pk.Vk.CosetShift)
+ uuzeta.Mul(&uzeta, &pk.Vk.CosetShift)
+
+ s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ)
+ tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)
+ tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+ s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ // third part L₁(ζ)*α²*Z
+ var lagrangeZeta, one, den, frNbElmt fr.Element
one.SetOne()
- nbElmt := int64(pk.DomainNum.Cardinality)
- lagrange.Set(&zeta).
- Exp(lagrange, big.NewInt(nbElmt)).
- Sub(&lagrange, &one)
+ nbElmt := int64(pk.Domain[0].Cardinality)
+ lagrangeZeta.Set(&zeta).
+ Exp(lagrangeZeta, big.NewInt(nbElmt)).
+ Sub(&lagrangeZeta, &one)
frNbElmt.SetUint64(uint64(nbElmt))
den.Sub(&zeta, &one).
- Mul(&den, &frNbElmt).
Inverse(&den)
- lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1)
- Mul(&lagrange, &alpha).
- Mul(&lagrange, &alpha) // alpha**2*L_0
+ lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ-1)/(ζ-1)
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ)
- linPol := z.Clone()
+ linPol := make([]fr.Element, len(blindedZCanonical))
+ copy(linPol, blindedZCanonical)
utils.Parallelize(len(linPol), func(start, end int) {
+
var t0, t1 fr.Element
+
for i := start; i < end; i++ {
- linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
- if i < len(pk.CS3) {
- t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X)
+
+ linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ if i < len(pk.S3Canonical) {
+
+ t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X)
+
linPol[i].Add(&linPol[i], &t0)
}
- linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) )
+ linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ))
if i < len(pk.Qm) {
- t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm
- t0.Mul(&pk.Ql[i], &l)
+
+ t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X)
+ t0.Mul(&pk.Ql[i], &lZeta)
t0.Add(&t0, &t1)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X)
- t0.Mul(&pk.Qr[i], &r)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr
+ t0.Mul(&pk.Qr[i], &rZeta)
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X)
- t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i])
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk
+ t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i])
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X)
}
- t0.Mul(&z[i], &lagrange)
+ t0.Mul(&blindedZCanonical[i], &lagrangeZeta)
linPol[i].Add(&linPol[i], &t0) // finish the computation
}
})
diff --git a/internal/backend/bw6-633/plonk/setup.go b/internal/backend/bw6-633/plonk/setup.go
index 06cfe86c1b..f7f6fa581c 100644
--- a/internal/backend/bw6-633/plonk/setup.go
+++ b/internal/backend/bw6-633/plonk/setup.go
@@ -21,7 +21,6 @@ import (
"github.com/consensys/gnark-crypto/ecc/bw6-633/fr"
"github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft"
"github.com/consensys/gnark-crypto/ecc/bw6-633/fr/kzg"
- "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial"
"github.com/consensys/gnark/internal/backend/bw6-633/cs"
kzgg "github.com/consensys/gnark-crypto/kzg"
@@ -40,18 +39,21 @@ type ProvingKey struct {
Vk *VerifyingKey
// qr,ql,qm,qo (in canonical basis).
- Ql, Qr, Qm, Qo polynomial.Polynomial
+ Ql, Qr, Qm, Qo []fr.Element
// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
// Storing LQk in Lagrange basis saves a fft...
- CQk, LQk polynomial.Polynomial
+ CQk, LQk []fr.Element
- // Domains used for the FFTs
- DomainNum, DomainH fft.Domain
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
- // s1, s2, s3 (L=Lagrange basis, C=canonical basis)
- LS1, LS2, LS3 polynomial.Polynomial
- CS1, CS2, CS3 polynomial.Polynomial
+ // Permutation polynomials
+ EvaluationPermutationBigDomainBitReversed []fr.Element
+ S1Canonical, S2Canonical, S3Canonical []fr.Element
// position -> permuted position (position in [0,3*sizeSystem-1])
Permutation []int64
@@ -69,13 +71,12 @@ type VerifyingKey struct {
Generator fr.Element
NbPublicVariables uint64
- // shifters for extending the permutation set: from s=<1,z,..,z**n-1>,
- // extended domain = s || shifter[0].s || shifter[1].s
- Shifter [2]fr.Element
-
// Commitment scheme that is used for an instantiation of PLONK
KZGSRS *kzg.SRS
+ // cosetShift generator of the coset on the small domain
+ CosetShift fr.Element
+
// S commitments to S1, S2, S3
S [3]kzg.Digest
@@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
// fft domains
sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints
- pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false)
+ pk.Domain[0] = *fft.NewDomain(sizeSystem)
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
// h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space,
// the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases
// except when n<6.
if sizeSystem < 6 {
- pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(8 * sizeSystem)
} else {
- pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(4 * sizeSystem)
}
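
The sizing rule kept here is driven by the quotient: per the comment above, h has degree 3(n+1)+2 and therefore lives in a vector space of dimension 3(n+2), and 3(n+2) ≤ 4n exactly when n ≥ 6 (3n+6 ≤ 4n ⟺ n ≥ 6), hence the 8·sizeSystem fallback for tiny systems. fft.NewDomain then rounds its argument up to the next power of two, so Domain[1].Cardinality is a power of two at least this large (the rounding behaviour is gnark-crypto's, assumed here rather than shown in this diff).
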
- vk.Size = pk.DomainNum.Cardinality
+ vk.Size = pk.Domain[0].Cardinality
vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
- vk.Generator.Set(&pk.DomainNum.Generator)
+ vk.Generator.Set(&pk.Domain[0].Generator)
vk.NbPublicVariables = uint64(spr.NbPublicVariables)
- // shifters
- vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator)
- vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator)
-
if err := pk.InitKZG(srs); err != nil {
return nil, nil, err
}
// public polynomials corresponding to constraints: [ placholders | constraints | assertions ]
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error if size is inconsistent
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.Qm[i].SetZero()
pk.Qo[i].SetZero()
pk.CQk[i].SetZero()
- pk.LQk[i].SetZero() // --> to be completed by the prover
+ pk.LQk[i].SetZero() // → to be completed by the prover
}
offset := spr.NbPublicVariables
for i := 0; i < nbConstraints; i++ { // constraints
@@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K])
}
- pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(pk.Ql, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qr, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qm, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qo, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.CQk, fft.DIF)
fft.BitReverse(pk.Ql)
fft.BitReverse(pk.Qr)
fft.BitReverse(pk.Qm)
@@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
buildPermutation(spr, &pk)
// set s1, s2, s3
- computeLDE(&pk)
+ computePermutationPolynomials(&pk)
// Commit to the polynomials to set up the verifying key
var err error
@@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil {
+ if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil {
+ if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil {
+ if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
@@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
//
// The permutation s is composed of cycles of maximum length such that
//
-// s. (l||r||o) = (l||r||o)
+// s. (l∥r∥o) = (l∥r∥o)
//
-//, where l||r||o is the concatenation of the indices of l, r, o in
+//, where l∥r∥o is the concatenation of the indices of l, r, o in
// ql.l+qr.r+qm.l.r+qo.O+k = 0.
//
// The permutation is encoded as a slice s of size 3*size(l), where the
-// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab
+// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab
// like this: for i in tab: tab[i] = tab[permutation[i]]
func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables
- sizeSolution := int(pk.DomainNum.Cardinality)
+ sizeSolution := int(pk.Domain[0].Cardinality)
// init permutation
pk.Permutation = make([]int64, 3*sizeSolution)
@@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
}
}
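
The encoding described in buildPermutation's doc comment ("the i-th entry of l∥r∥o is sent to the s[i]-th entry", acting as `tab[i] = tab[permutation[i]]`) can be illustrated with a toy standalone helper; `applyPermutation` below is hypothetical and only mirrors that convention.

```go
package main

import "fmt"

// applyPermutation mirrors the convention above:
// the new entry i takes the value of the old entry permutation[i].
func applyPermutation(tab []int, permutation []int64) []int {
	res := make([]int, len(tab))
	for i := range tab {
		res[i] = tab[permutation[i]]
	}
	return res
}

func main() {
	// toy l∥r∥o vector of size 3 and a 3-cycle acting on it
	tab := []int{10, 20, 30}
	permutation := []int64{1, 2, 0}
	fmt.Println(applyPermutation(tab, permutation)) // [20 30 10]
}
```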
-// computeLDE computes the LDE (Lagrange basis) of the permutations
+// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations
// s1, s2, s3.
//
-// ex: z gen of Z/mZ, u gen of Z/8mZ, then
-//
// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 |
// |
// | Permutation
// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v
// \---------------/ \--------------------/ \------------------------/
// s1 (LDE) s2 (LDE) s3 (LDE)
-func computeLDE(pk *ProvingKey) {
+func ccomputePermutationPolynomials(pk *ProvingKey) {
- nbElmt := int(pk.DomainNum.Cardinality)
+ nbElmts := int(pk.Domain[0].Cardinality)
- // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1]
- sID := make([]fr.Element, 3*nbElmt)
- sID[0].SetOne()
- sID[nbElmt].Set(&pk.DomainNum.FinerGenerator)
- sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator)
-
- for i := 1; i < nbElmt; i++ {
- sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1
- sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
- }
+ // Lagrange form of ID
+ evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0])
// Lagrange form of S1, S2, S3
- pk.LS1 = make(polynomial.Polynomial, nbElmt)
- pk.LS2 = make(polynomial.Polynomial, nbElmt)
- pk.LS3 = make(polynomial.Polynomial, nbElmt)
- for i := 0; i < nbElmt; i++ {
- pk.LS1[i].Set(&sID[pk.Permutation[i]])
- pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]])
- pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]])
+ pk.S1Canonical = make([]fr.Element, nbElmts)
+ pk.S2Canonical = make([]fr.Element, nbElmts)
+ pk.S3Canonical = make([]fr.Element, nbElmts)
+ for i := 0; i < nbElmts; i++ {
+ pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]])
+ pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]])
+ pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]])
}
// Canonical form of S1, S2, S3
- pk.CS1 = make(polynomial.Polynomial, nbElmt)
- pk.CS2 = make(polynomial.Polynomial, nbElmt)
- pk.CS3 = make(polynomial.Polynomial, nbElmt)
- copy(pk.CS1, pk.LS1)
- copy(pk.CS2, pk.LS2)
- copy(pk.CS3, pk.LS3)
- pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0)
- fft.BitReverse(pk.CS1)
- fft.BitReverse(pk.CS2)
- fft.BitReverse(pk.CS3)
+ pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF)
+ fft.BitReverse(pk.S1Canonical)
+ fft.BitReverse(pk.S2Canonical)
+ fft.BitReverse(pk.S3Canonical)
+
+ // evaluation of permutation on the big domain
+ pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality)
+ copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true)
+
+}
+
+// getIDSmallDomain returns the Lagrange form of ID on the small domain
+func getIDSmallDomain(domain *fft.Domain) []fr.Element {
+
+ res := make([]fr.Element, 3*domain.Cardinality)
+
+ res[0].SetOne()
+ res[domain.Cardinality].Set(&domain.FrMultiplicativeGen)
+ res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen)
+
+ for i := uint64(1); i < domain.Cardinality; i++ {
+ res[i].Mul(&res[i-1], &domain.Generator)
+ res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator)
+ res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator)
+ }
+ return res
}
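
For intuition, the table built by getIDSmallDomain is [1, g, …, gⁿ⁻¹ | u, u·g, …, u·gⁿ⁻¹ | u², u²·g, …, u²·gⁿ⁻¹], with g the small-domain generator and u the coset shift (FrMultiplicativeGen). A toy version over ℤ/17ℤ, with hypothetical values chosen only for illustration:

```go
package main

import "fmt"

// idSmallDomain builds [gⁱ | u·gⁱ | u²·gⁱ] for i in [0, n), working mod p.
func idSmallDomain(n, g, u, p uint64) []uint64 {
	res := make([]uint64, 3*n)
	res[0], res[n], res[2*n] = 1, u%p, u*u%p
	for i := uint64(1); i < n; i++ {
		res[i] = res[i-1] * g % p
		res[n+i] = res[n+i-1] * g % p
		res[2*n+i] = res[2*n+i-1] * g % p
	}
	return res
}

func main() {
	// mod 17: g=4 generates the 4th roots of unity {1,4,16,13}; u=3 lies outside that subgroup
	fmt.Println(idSmallDomain(4, 4, 3, 17))
}
```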
-// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS
+// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS
//
// This should be used after deserializing a ProvingKey
// as pk.Vk.KZG is NOT serialized
diff --git a/internal/backend/bw6-633/plonk/verify.go b/internal/backend/bw6-633/plonk/verify.go
index 49e6684f2c..1a7651695b 100644
--- a/internal/backend/bw6-633/plonk/verify.go
+++ b/internal/backend/bw6-633/plonk/verify.go
@@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness
hFunc := sha256.New()
// transcript to derive the challenge
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// derive gamma from Comm(l), Comm(r), Comm(o)
gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2])
@@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness
return err
}
+ // derive beta from Comm(l), Comm(r), Comm(o)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return err
+ }
+
// derive alpha from Comm(l), Comm(r), Comm(o), Com(Z)
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z)
if err != nil {
@@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness
return err
}
- // evaluation of Z=X**m-1 at zeta
+	// evaluation of Z=Xⁿ-1 at ζ
var zetaPowerM, zzeta fr.Element
var bExpo big.Int
one := fr.One()
@@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_633witness.Witness
zetaPowerM.Exp(zeta, &bExpo)
zzeta.Sub(&zetaPowerM, &one)
-	// compute PI = Sum_i …
+		if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
}
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
}
- }
- // sanity check; ensure all wires are marked as "instantiated"
- if !solution.isValid() {
- panic("solver didn't instantiate all wires")
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
}
- return solution.values, nil
+ return nil
}
// IsSolved returns nil if given witness solves the R1CS and error otherwise
@@ -183,7 +265,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) {
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a, b, c fr.Element, err error) {
+func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a, b, c *fr.Element) error {
// the index of the non zero entry shows if L, R or O has an uninstantiated wire
// the content is the ID of the wire non instantiated
@@ -220,28 +302,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
return nil
}
- if err = processLExp(r.L.LinExp, &a, 1); err != nil {
- return
+ if err := processLExp(r.L.LinExp, a, 1); err != nil {
+ return err
}
- if err = processLExp(r.R.LinExp, &b, 2); err != nil {
- return
+ if err := processLExp(r.R.LinExp, b, 2); err != nil {
+ return err
}
- if err = processLExp(r.O.LinExp, &c, 3); err != nil {
- return
+ if err := processLExp(r.O.LinExp, c, 3); err != nil {
+ return err
}
if loc == 0 {
// there is nothing to solve, may happen if we have an assertion
// (ie a constraints that doesn't yield any output)
// or if we solved the unsolved wires with hint functions
- return
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
+ return nil
}
// we compute the wire value and instantiate it
- solved = true
- vID := termToCompute.WireID()
+ wID := termToCompute.WireID()
// solver result
var wire fr.Element
@@ -249,36 +334,41 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
switch loc {
case 1:
if !b.IsZero() {
- wire.Div(&c, &b).
- Sub(&wire, &a)
- a.Add(&a, &wire)
+ wire.Div(c, b).
+ Sub(&wire, a)
+ a.Add(a, &wire)
} else {
// we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 2:
if !a.IsZero() {
- wire.Div(&c, &a).
- Sub(&wire, &b)
- b.Add(&b, &wire)
+ wire.Div(c, a).
+ Sub(&wire, b)
+ b.Add(b, &wire)
} else {
- // we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 3:
- wire.Mul(&a, &b).
- Sub(&wire, &c)
+ wire.Mul(a, b).
+ Sub(&wire, c)
- c.Add(&c, &wire)
+ c.Add(c, &wire)
}
// wire is the term (coeff * value)
// but in the solution we want to store the value only
// note that in gnark frontend, coeff here is always 1 or -1
cs.divByCoeff(&wire, termToCompute)
- solution.set(vID, wire)
+ solution.set(wID, wire)
- return
+ return nil
}
// GetConstraints return a list of constraint formatted as L⋅R == O
diff --git a/internal/backend/bw6-761/cs/r1cs_sparse.go b/internal/backend/bw6-761/cs/r1cs_sparse.go
index fbb30462da..8cc16bcf4c 100644
--- a/internal/backend/bw6-761/cs/r1cs_sparse.go
+++ b/internal/backend/bw6-761/cs/r1cs_sparse.go
@@ -21,9 +21,12 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"github.com/fxamacker/cbor/v2"
"io"
+ "math"
"math/big"
"os"
+ "runtime"
"strings"
+ "sync"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/backend/witness"
@@ -84,11 +87,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
return solution.values, err
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
// solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs
copy(solution.values, witness)
for i := 0; i < len(witness); i++ {
@@ -97,7 +95,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness)
+ solution.nbSolved += uint64(len(witness))
// defer log printing once all solution.values are computed
defer solution.printLogs(opt.LoggerOut, cs.Logs)
@@ -108,18 +106,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
coefficientsNegInv[i].Neg(&coefficientsNegInv[i])
}
- // loop through the constraints to solve the variables
- for i := 0; i < len(cs.Constraints); i++ {
- if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil {
- return solution.values, fmt.Errorf("constraint %d: %w", i, err)
- }
- if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil {
- errMsg := err.Error()
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
- }
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
- }
+ if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil {
+ return solution.values, err
}
// sanity check; ensure all wires are marked as "instantiated"
@@ -131,6 +119,120 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
+func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error {
+	// minWorkPerCPU is the minimum target number of constraints a task should hold;
+	// in other words, if a level has fewer than minWorkPerCPU constraints, it is not
+	// parallelized and is executed sequentially, without synchronization.
+ const minWorkPerCPU = 50.0
+
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
+
+ // start a worker pool
+ // each worker wait on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ // clean up pool go routines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
+
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ }
+ }
+ continue
+ }
+
+		// the number of tasks for this level defaults to the number of CPUs,
+		// but if we don't have enough work for all our CPUs, it can be lower.
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
+ }
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
+ }
+
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
+ }
+
+ return nil
+}
+
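
The task-splitting arithmetic above (`nbIterationsPerCpus` plus the `extraTasks`/`extraTasksOffset` bookkeeping) simply cuts a level into contiguous, nearly equal chunks, handing the remainder out one extra constraint at a time. A standalone sketch of just that arithmetic (`chunks` is a hypothetical helper, not gnark API):

```go
package main

import "fmt"

// chunks splits n items into nbTasks contiguous [start, end) ranges,
// distributing the remainder one extra item per leading task.
func chunks(n, nbTasks int) [][2]int {
	per := n / nbTasks
	extra := n - nbTasks*per
	offset := 0
	out := make([][2]int, 0, nbTasks)
	for i := 0; i < nbTasks; i++ {
		start := i*per + offset
		end := start + per
		if extra > 0 {
			end++
			extra--
			offset++
		}
		out = append(out, [2]int{start, end})
	}
	return out
}

func main() {
	fmt.Println(chunks(10, 4)) // [[0 3] [3 6] [6 8] [8 10]]
}
```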
// computeHints computes wires associated with a hint function, if any
// if there is no remaining wire to solve, returns -1
// else returns the wire position (L -> 0, R -> 1, O -> 2)
diff --git a/internal/backend/bw6-761/cs/solution.go b/internal/backend/bw6-761/cs/solution.go
index 560d95f7d5..fb1e1a19bd 100644
--- a/internal/backend/bw6-761/cs/solution.go
+++ b/internal/backend/bw6-761/cs/solution.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"math/big"
+ "sync/atomic"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend/schema"
@@ -32,14 +33,15 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bw6-761"
)
+var errUnsatisfiedConstraint = errors.New("unsatisfied")
+
// solution represents elements needed to compute
// a solution to a R1CS or SparseR1CS
type solution struct {
values, coefficients []fr.Element
solved []bool
- nbSolved int
+ nbSolved uint64
mHintsFunctions map[hint.ID]hint.Function
- tmpHintsIO []*big.Int
}
func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) {
@@ -49,7 +51,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E
coefficients: coefficients,
solved: make([]bool, nbWires),
mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)),
- tmpHintsIO: make([]*big.Int, 0),
}
for _, h := range hintFunctions {
@@ -68,11 +69,12 @@ func (s *solution) set(id int, value fr.Element) {
}
s.values[id] = value
s.solved[id] = true
- s.nbSolved++
+ atomic.AddUint64(&s.nbSolved, 1)
}
func (s *solution) isValid() bool {
- return s.nbSolved == len(s.values)
+ return int(s.nbSolved) == len(s.values)
}
// computeTerm computes coef*variable
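
Since several worker goroutines now call `solution.set` concurrently, the solved-wire counter is bumped with `sync/atomic` instead of a plain increment. A minimal illustration of the pattern, using a toy counter rather than the gnark type:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var nbSolved uint64
	var wg sync.WaitGroup

	// many goroutines incrementing the same counter: atomic.AddUint64 keeps the count
	// exact, whereas a plain nbSolved++ would race and could lose updates.
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 1000; j++ {
				atomic.AddUint64(&nbSolved, 1)
			}
		}()
	}
	wg.Wait()
	fmt.Println(atomic.LoadUint64(&nbSolved)) // always 8000
}
```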
@@ -147,15 +149,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error {
// tmp IO big int memory
nbInputs := len(h.Inputs)
nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs))
- m := len(s.tmpHintsIO)
- if m < (nbInputs + nbOutputs) {
- s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs+nbInputs)-m)...)
- for i := m; i < len(s.tmpHintsIO); i++ {
- s.tmpHintsIO[i] = big.NewInt(0)
- }
+ inputs := make([]*big.Int, nbInputs)
+ outputs := make([]*big.Int, nbOutputs)
+ for i := 0; i < nbInputs; i++ {
+ inputs[i] = big.NewInt(0)
+ }
+ for i := 0; i < nbOutputs; i++ {
+ outputs[i] = big.NewInt(0)
}
- inputs := s.tmpHintsIO[:nbInputs]
- outputs := s.tmpHintsIO[nbInputs : nbInputs+nbOutputs]
q := fr.Modulus()
diff --git a/internal/backend/bw6-761/groth16/marshal_test.go b/internal/backend/bw6-761/groth16/marshal_test.go
index 73054efb74..bccf077d51 100644
--- a/internal/backend/bw6-761/groth16/marshal_test.go
+++ b/internal/backend/bw6-761/groth16/marshal_test.go
@@ -177,7 +177,7 @@ func TestProvingKeySerialization(t *testing.T) {
var pk, pkCompressed, pkRaw ProvingKey
// create a random pk
- domain := fft.NewDomain(8, 1, true)
+ domain := fft.NewDomain(8)
pk.Domain = *domain
nbWires := 6
diff --git a/internal/backend/bw6-761/groth16/prove.go b/internal/backend/bw6-761/groth16/prove.go
index 8848e71c05..663bc522b1 100644
--- a/internal/backend/bw6-761/groth16/prove.go
+++ b/internal/backend/bw6-761/groth16/prove.go
@@ -281,18 +281,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
c = append(c, padding...)
n = len(a)
- domain.FFTInverse(a, fft.DIF, 0)
- domain.FFTInverse(b, fft.DIF, 0)
- domain.FFTInverse(c, fft.DIF, 0)
+ domain.FFTInverse(a, fft.DIF)
+ domain.FFTInverse(b, fft.DIF)
+ domain.FFTInverse(c, fft.DIF)
- domain.FFT(a, fft.DIT, 1)
- domain.FFT(b, fft.DIT, 1)
- domain.FFT(c, fft.DIT, 1)
+ domain.FFT(a, fft.DIT, true)
+ domain.FFT(b, fft.DIT, true)
+ domain.FFT(c, fft.DIT, true)
- var minusTwoInv fr.Element
- minusTwoInv.SetUint64(2)
- minusTwoInv.Neg(&minusTwoInv).
- Inverse(&minusTwoInv)
+ var den, one fr.Element
+ one.SetOne()
+ den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
+ den.Sub(&den, &one).Inverse(&den)
// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unecessary memalloc
@@ -300,12 +300,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
- Mul(&a[i], &minusTwoInv)
+ Mul(&a[i], &den)
}
})
// ifft_coset
- domain.FFTInverse(a, fft.DIF, 1)
+ domain.FFTInverse(a, fft.DIF, true)
utils.Parallelize(len(a), func(start, end int) {
for i := start; i < end; i++ {
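
The change from the hard-coded `minusTwoInv` to `den` in computeH above generalizes the coset used for the division by the vanishing polynomial. On a coset u·H of the n-th roots of unity H (with u = `domain.FrMultiplicativeGen`),

$$\big(u\,\omega^i\big)^n - 1 \;=\; u^n\,(\omega^i)^n - 1 \;=\; u^n - 1,$$

a nonzero constant, so dividing by $X^n-1$ on the coset is a single scalar multiplication by $\mathrm{den}=(u^n-1)^{-1}$. The previous code used the coset generated by a primitive 2n-th root of unity, for which $u^n=-1$ and the constant specializes to $-2$, hence the old $-1/2$ factor.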
diff --git a/internal/backend/bw6-761/groth16/setup.go b/internal/backend/bw6-761/groth16/setup.go
index 137b4df3fe..0e100d3319 100644
--- a/internal/backend/bw6-761/groth16/setup.go
+++ b/internal/backend/bw6-761/groth16/setup.go
@@ -95,7 +95,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error {
nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables
// Setting group for fft
- domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true)
+ domain := fft.NewDomain(uint64(len(r1cs.Constraints)))
// samples toxic waste
toxicWaste, err := sampleToxicWaste()
@@ -415,7 +415,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error {
nbConstraints := len(r1cs.Constraints)
// Setting group for fft
- domain := fft.NewDomain(uint64(nbConstraints), 1, true)
+ domain := fft.NewDomain(uint64(nbConstraints))
// count number of infinity points we would have had we a normal setup
// in pk.G1.A, pk.G1.B, and pk.G2.B
diff --git a/internal/backend/bw6-761/plonk/marshal.go b/internal/backend/bw6-761/plonk/marshal.go
index 6daaa3130f..fe81cc4e0d 100644
--- a/internal/backend/bw6-761/plonk/marshal.go
+++ b/internal/backend/bw6-761/plonk/marshal.go
@@ -89,20 +89,20 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
}
// fft domains
- n2, err := pk.DomainNum.WriteTo(w)
+ n2, err := pk.Domain[0].WriteTo(w)
if err != nil {
return
}
n += n2
- n2, err = pk.DomainH.WriteTo(w)
+ n2, err = pk.Domain[1].WriteTo(w)
if err != nil {
return
}
n += n2
- // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality)
- if len(pk.Permutation) != (3 * int(pk.DomainNum.Cardinality)) {
+ // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality)
+ if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) {
return n, errors.New("invalid permutation size, expected 3*domain cardinality")
}
@@ -117,12 +117,9 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
([]fr.Element)(pk.Qo),
([]fr.Element)(pk.CQk),
([]fr.Element)(pk.LQk),
- ([]fr.Element)(pk.LS1),
- ([]fr.Element)(pk.LS2),
- ([]fr.Element)(pk.LS3),
- ([]fr.Element)(pk.CS1),
- ([]fr.Element)(pk.CS2),
- ([]fr.Element)(pk.CS3),
+ ([]fr.Element)(pk.S1Canonical),
+ ([]fr.Element)(pk.S2Canonical),
+ ([]fr.Element)(pk.S3Canonical),
pk.Permutation,
}
@@ -143,19 +140,19 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
- n2, err := pk.DomainNum.ReadFrom(r)
+ n2, err := pk.Domain[0].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- n2, err = pk.DomainH.ReadFrom(r)
+ n2, err = pk.Domain[1].ReadFrom(r)
n += n2
if err != nil {
return n, err
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
dec := curve.NewDecoder(r)
toDecode := []interface{}{
@@ -165,12 +162,9 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
(*[]fr.Element)(&pk.Qo),
(*[]fr.Element)(&pk.CQk),
(*[]fr.Element)(&pk.LQk),
- (*[]fr.Element)(&pk.LS1),
- (*[]fr.Element)(&pk.LS2),
- (*[]fr.Element)(&pk.LS3),
- (*[]fr.Element)(&pk.CS1),
- (*[]fr.Element)(&pk.CS2),
- (*[]fr.Element)(&pk.CS3),
+ (*[]fr.Element)(&pk.S1Canonical),
+ (*[]fr.Element)(&pk.S2Canonical),
+ (*[]fr.Element)(&pk.S3Canonical),
&pk.Permutation,
}
@@ -193,8 +187,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
&vk.SizeInv,
&vk.Generator,
vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -222,8 +214,6 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
&vk.SizeInv,
&vk.Generator,
&vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
diff --git a/internal/backend/bw6-761/plonk/marshal_test.go b/internal/backend/bw6-761/plonk/marshal_test.go
index 4a18f7651f..2b47f5e50a 100644
--- a/internal/backend/bw6-761/plonk/marshal_test.go
+++ b/internal/backend/bw6-761/plonk/marshal_test.go
@@ -32,7 +32,6 @@ func TestProvingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
@@ -48,14 +47,14 @@ func TestProvingKeySerialization(t *testing.T) {
// random pk
var pk ProvingKey
pk.Vk = &vk
- pk.DomainNum = *fft.NewDomain(42, 3, false)
- pk.DomainH = *fft.NewDomain(4*42, 1, false)
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Domain[0] = *fft.NewDomain(42)
+ pk.Domain[1] = *fft.NewDomain(4 * 42)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
for i := 0; i < 12; i++ {
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -63,7 +62,7 @@ func TestProvingKeySerialization(t *testing.T) {
pk.Qo[i].SetUint64(42)
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
pk.Permutation[0] = -12
pk.Permutation[len(pk.Permutation)-1] = 8888
@@ -94,7 +93,6 @@ func TestVerifyingKeySerialization(t *testing.T) {
var vk VerifyingKey
vk.Size = 42
vk.SizeInv = fr.One()
- vk.Shifter[1].SetUint64(12)
_, _, g1gen, _ := curve.Generators()
vk.S[0] = g1gen
diff --git a/internal/backend/bw6-761/plonk/prove.go b/internal/backend/bw6-761/plonk/prove.go
index 3f2981607b..4ce2467212 100644
--- a/internal/backend/bw6-761/plonk/prove.go
+++ b/internal/backend/bw6-761/plonk/prove.go
@@ -27,8 +27,6 @@ import (
curve "github.com/consensys/gnark-crypto/ecc/bw6-761"
- "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial"
-
"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg"
"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft"
@@ -43,6 +41,7 @@ import (
)
type Proof struct {
+
// Commitments to the solution vectors
LRO [3]kzg.Digest
@@ -66,7 +65,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
hFunc := sha256.New()
// create a transcript manager to apply Fiat Shamir
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// result
proof := &Proof{}
@@ -89,17 +88,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
}
// query l, r, o in Lagrange basis, not blinded
- ll, lr, lo := computeLRO(spr, pk, solution)
+ evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution)
// save ll, lr, lo, and make a copy of them in canonical basis.
// note that we allocate more capacity to reuse for blinded polynomials
- bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum)
+ blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ &pk.Domain[0])
if err != nil {
return nil, err
}
// compute kzg commitments of bcl, bcr and bco
- if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -109,14 +112,24 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
return nil, err
}
+	// derive beta from the Fiat-Shamir transcript
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return nil, err
+ }
+
// compute Z, the permutation accumulator polynomial, in canonical basis
// ll, lr, lo are NOT blinded
- var bz polynomial.Polynomial
+ var blindedZCanonical []fr.Element
chZ := make(chan error, 1)
var alpha fr.Element
go func() {
var err error
- bz, err = computeBlindedZ(ll, lr, lo, pk, gamma)
+ blindedZCanonical, err = computeBlindedZCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ pk, beta, gamma)
if err != nil {
chZ <- err
close(chZ)
@@ -128,7 +141,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
// this may add additional arithmetic operations, but with smaller tasks
// we ensure that this commitment is well parallelized, without having a "unbalanced task" making
// the rest of the code wait too long.
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
+ if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
chZ <- err
close(chZ)
return
@@ -141,40 +154,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
}()
// evaluation of the blinded versions of l, r, o and bz
- // on the odd cosets of (Z/8mZ)/(Z/mZ)
- var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial
+ // on the coset of the big domain
+ var (
+ evaluationBlindedLDomainBigBitReversed []fr.Element
+ evaluationBlindedRDomainBigBitReversed []fr.Element
+ evaluationBlindedODomainBigBitReversed []fr.Element
+ evaluationBlindedZDomainBigBitReversed []fr.Element
+ )
chEvalBL := make(chan struct{}, 1)
chEvalBR := make(chan struct{}, 1)
chEvalBO := make(chan struct{}, 1)
go func() {
- evalBL = evaluateHDomain(bcl, &pk.DomainH)
+ evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1])
close(chEvalBL)
}()
go func() {
- evalBR = evaluateHDomain(bcr, &pk.DomainH)
+ evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1])
close(chEvalBR)
}()
go func() {
- evalBO = evaluateHDomain(bco, &pk.DomainH)
+ evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1])
close(chEvalBO)
}()
- var constraintsInd, constraintsOrdering polynomial.Polynomial
+ var constraintsInd, constraintsOrdering []fr.Element
chConstraintInd := make(chan struct{}, 1)
go func() {
// compute qk in canonical basis, completed with the public inputs
- qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality)
- copy(qk, fullWitness[:spr.NbPublicVariables])
- copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
- pk.DomainNum.FFTInverse(qk, fft.DIF, 0)
- fft.BitReverse(qk)
-
- // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ)
- // --> uses the blinded version of l, r, o
+ qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality)
+ copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables])
+ copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
+ pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF)
+ fft.BitReverse(qkCompletedCanonical)
+
+ // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain
+ // → uses the blinded version of l, r, o
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk)
+ constraintsInd = evaluateConstraintsDomainBigBitReversed(
+ pk,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ qkCompletedCanonical)
close(chConstraintInd)
}()
@@ -184,13 +207,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
chConstraintOrdering <- err
return
}
- evalBZ = evaluateHDomain(bz, &pk.DomainH)
- // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ)
+
+ evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1])
+ // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain
// evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o.
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma)
+ constraintsOrdering = evaluateOrderingDomainBigBitReversed(
+ pk,
+ evaluationBlindedZDomainBigBitReversed,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ beta,
+ gamma)
chConstraintOrdering <- nil
close(chConstraintOrdering)
}()
@@ -198,12 +229,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
if err := <-chConstraintOrdering; err != nil {
return nil, err
}
+
<-chConstraintInd
+
// compute h in canonical form
- h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha)
+ h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha)
// compute kzg commitments of h1, h2 and h3
- if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -218,15 +251,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
var wgZetaEvals sync.WaitGroup
wgZetaEvals.Add(3)
go func() {
- blzeta = bcl.Eval(&zeta)
+ blzeta = eval(blindedLCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- brzeta = bcr.Eval(&zeta)
+ brzeta = eval(blindedRCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- bozeta = bco.Eval(&zeta)
+ bozeta = eval(blindedOCanonical, zeta)
wgZetaEvals.Done()
}()
@@ -234,9 +267,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
var zetaShifted fr.Element
zetaShifted.Mul(&zeta, &pk.Vk.Generator)
proof.ZShiftedOpening, err = kzg.Open(
- bz,
- &zetaShifted,
- &pk.DomainH,
+ blindedZCanonical,
+ zetaShifted,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -247,53 +279,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
bzuzeta := proof.ZShiftedOpening.ClaimedValue
var (
- linearizedPolynomial polynomial.Polynomial
- linearizedPolynomialDigest curve.G1Affine
- errLPoly error
+ linearizedPolynomialCanonical []fr.Element
+ linearizedPolynomialDigest curve.G1Affine
+ errLPoly error
)
chLpoly := make(chan struct{}, 1)
go func() {
// compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k)
wgZetaEvals.Wait()
- linearizedPolynomial = computeLinearizedPolynomial(
+ linearizedPolynomialCanonical = computeLinearizedPolynomial(
blzeta,
brzeta,
bozeta,
alpha,
+ beta,
gamma,
zeta,
bzuzeta,
- bz,
+ blindedZCanonical,
pk,
)
// TODO this commitment is only necessary to derive the challenge, we should
// be able to avoid doing it and get the challenge in another way
- linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS)
+ linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS)
close(chLpoly)
}()
- // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3)
+ // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3)
var bZetaPowerm, bSize big.Int
- bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
+ bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
var zetaPowerm fr.Element
zetaPowerm.Exp(zeta, &bSize)
zetaPowerm.ToBigIntRegular(&bZetaPowerm)
foldedHDigest := proof.H[2]
foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm)
- foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3)
- foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2)
- foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // ζᵐ⁺²*Comm(h3) + Comm(h2)
+ foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1)
- // foldedH = h1 + zeta*h2 + zeta**2*h3
+ // foldedH = h1 + ζ*h2 + ζ²*h3
foldedH := h3
utils.Parallelize(len(foldedH), func(start, end int) {
for i := start; i < end; i++ {
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3
- foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1)
- foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3
+ foldedH[i].Add(&foldedH[i], &h2[i]) // ζᵐ⁺²*h3+h2
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺²
+ foldedH[i].Add(&foldedH[i], &h1[i]) // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2 + h1
}
})
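
The two ScalarMultiplication/Add steps (and the coefficient-wise loop above) are Horner's rule applied to the folding:

$$\mathrm{Comm}(h_1) + \zeta^{m+2}\,\mathrm{Comm}(h_2) + \zeta^{2(m+2)}\,\mathrm{Comm}(h_3) \;=\; \Big(\mathrm{Comm}(h_3)\,\zeta^{m+2} + \mathrm{Comm}(h_2)\Big)\,\zeta^{m+2} + \mathrm{Comm}(h_1),$$

and likewise $h_1 + \zeta^{m+2} h_2 + \zeta^{2(m+2)} h_3 = \big(h_3\,\zeta^{m+2} + h_2\big)\,\zeta^{m+2} + h_1$ coefficient by coefficient. The exponent is $m+2$ rather than $m$ because each $h_i$ holds $m+2$ coefficients after blinding.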
@@ -304,14 +337,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
// Batch open the first list of polynomials
proof.BatchedProof, err = kzg.BatchOpenSinglePoint(
- []polynomial.Polynomial{
+ [][]fr.Element{
foldedH,
- linearizedPolynomial,
- bcl,
- bcr,
- bco,
- pk.CS1,
- pk.CS2,
+ linearizedPolynomialCanonical,
+ blindedLCanonical,
+ blindedRCanonical,
+ blindedOCanonical,
+ pk.S1Canonical,
+ pk.S2Canonical,
},
[]kzg.Digest{
foldedHDigest,
@@ -322,9 +355,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
pk.Vk.S[0],
pk.Vk.S[1],
},
- &zeta,
+ zeta,
hFunc,
- &pk.DomainH,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -335,8 +367,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness bw6_761witness.Witnes
}
+// eval evaluates c at p
+func eval(c []fr.Element, p fr.Element) fr.Element {
+ var r fr.Element
+ for i := len(c) - 1; i >= 0; i-- {
+ r.Mul(&r, &p).Add(&r, &c[i])
+ }
+ return r
+}
+
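
The new `eval` helper above is plain Horner evaluation in the coefficient (canonical) basis. A self-contained usage sketch, duplicating the helper and using bn254's fr purely for illustration:

```go
package main

import (
	"fmt"

	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
)

// eval mirrors the helper above: Horner evaluation of c (canonical form) at p.
func eval(c []fr.Element, p fr.Element) fr.Element {
	var r fr.Element
	for i := len(c) - 1; i >= 0; i-- {
		r.Mul(&r, &p).Add(&r, &c[i])
	}
	return r
}

func main() {
	// c(X) = 1 + 2X + 3X², evaluated at X = 2 -> 17
	c := make([]fr.Element, 3)
	c[0].SetUint64(1)
	c[1].SetUint64(2)
	c[2].SetUint64(3)
	var p fr.Element
	p.SetUint64(2)
	v := eval(c, p)
	fmt.Println(v.String()) // 17
}
```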
// fills proof.LRO with kzg commits of bcl, bcr and bco
-func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -362,7 +403,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS
return err1
}
-func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -388,20 +429,20 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err
return err1
}
-// computeBlindedLRO l, r, o in canonical basis with blinding
-func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) {
+// computeBlindedLROCanonical l, r, o in canonical basis with blinding
+func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) {
// note that bcl, bcr and bco reuses cl, cr and co memory
- cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
- co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality+2)
+ cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
chDone := make(chan error, 2)
go func() {
var err error
copy(cl, ll)
- domain.FFTInverse(cl, fft.DIF, 0)
+ domain.FFTInverse(cl, fft.DIF)
fft.BitReverse(cl)
bcl, err = blindPoly(cl, domain.Cardinality, 1)
chDone <- err
@@ -409,13 +450,13 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
go func() {
var err error
copy(cr, lr)
- domain.FFTInverse(cr, fft.DIF, 0)
+ domain.FFTInverse(cr, fft.DIF)
fft.BitReverse(cr)
bcr, err = blindPoly(cr, domain.Cardinality, 1)
chDone <- err
}()
copy(co, lo)
- domain.FFTInverse(co, fft.DIF, 0)
+ domain.FFTInverse(co, fft.DIF)
fft.BitReverse(co)
if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil {
return
@@ -436,9 +477,9 @@ func computeBlindedLRO(ll, lr, lo polynomial.Polynomial, domain *fft.Domain) (bc
// * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1)
//
// WARNING:
-// pre condition degree(cp) <= rou + bo
-// pre condition cap(cp) >= int(totalDegree + 1)
-func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) {
+// pre condition degree(cp) ⩽ rou + bo
+// pre condition cap(cp) ⩾ int(totalDegree + 1)
+func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) {
// degree of the blinded polynomial is max(rou+order, cp.Degree)
totalDegree := rou + bo
@@ -447,7 +488,7 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res := cp[:totalDegree+1]
// random polynomial
- blindingPoly := make(polynomial.Polynomial, bo+1)
+ blindingPoly := make([]fr.Element, bo+1)
for i := uint64(0); i < bo+1; i++ {
if _, err := blindingPoly[i].SetRandom(); err != nil {
return nil, err
@@ -461,15 +502,16 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
}
return res, nil
+
}
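
In formula form, blindPoly returns

$$\tilde p(X) \;=\; p(X) + b(X)\,\big(X^{\mathrm{rou}} - 1\big), \qquad \deg b = \mathrm{bo},$$

which agrees with $p$ on every point of the evaluation domain (where $X^{\mathrm{rou}} = 1$) while masking the committed coefficients with $\mathrm{bo}+1$ random field elements. The usual rationale, not spelled out in the diff, is one blinding factor per evaluation that is revealed: hence bo = 1 for l, r, o and bo = 2 for Z.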
-// computeLRO extracts the solution l, r, o, and returns it in lagrange form.
+// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form.
// solution = [ public | secret | internal ]
-func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) {
+func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {
- s := int(pk.DomainNum.Cardinality)
+ s := int(pk.Domain[0].Cardinality)
- var l, r, o polynomial.Polynomial
+ var l, r, o []fr.Element
l = make([]fr.Element, s)
r = make([]fr.Element, s)
o = make([]fr.Element, s)
@@ -502,47 +544,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly
//
// * Z of degree n (domainNum.Cardinality)
// * Z(1)=1
-// (l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2z**i+gamma)
-// * for i>0: Z(u**i) = Pi_{k<i} …
+// * for i>0: Z(gⁱ) = Π_{k<i} …
-			u[0].Mul(&u[0], &pk.DomainNum.Generator) // z**i -> z**i+1
- u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
}
})
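
The doc comment above describes the permutation accumulator Z; with β now derived from the transcript, the grand-product recurrence it refers to is the standard PlonK one (textbook form; β, γ are the Fiat-Shamir challenges, u the coset shift, s₁, s₂, s₃ the permutation polynomials):

$$Z(1)=1,\qquad Z(g^{i+1}) \;=\; Z(g^i)\cdot \frac{\big(l_i+\beta g^i+\gamma\big)\big(r_i+\beta u g^i+\gamma\big)\big(o_i+\beta u^2 g^i+\gamma\big)}{\big(l_i+\beta s_1(g^i)+\gamma\big)\big(r_i+\beta s_2(g^i)+\gamma\big)\big(o_i+\beta s_3(g^i)+\gamma\big)}.$$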
@@ -552,43 +590,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele
Mul(&z[i], &gInv[i])
}
- pk.DomainNum.FFTInverse(z, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(z, fft.DIF)
fft.BitReverse(z)
- return blindPoly(z, pk.DomainNum.Cardinality, 2)
+ return blindPoly(z, pk.Domain[0].Cardinality, 2)
}
-// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on
-// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of qlL+qrR+qmL.R+qoO+k on
+// the coset of the big domain.
//
// * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets
// * qk is the completed version of qk, in canonical version
-func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
- var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial
+func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
+ var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element
var wg sync.WaitGroup
wg.Add(4)
go func() {
- evalQl = evaluateHDomain(pk.Ql, &pk.DomainH)
+ evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQr = evaluateHDomain(pk.Qr, &pk.DomainH)
+ evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQm = evaluateHDomain(pk.Qm, &pk.DomainH)
+ evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQo = evaluateHDomain(pk.Qo, &pk.DomainH)
+ evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1])
wg.Done()
}()
- evalQk = evaluateHDomain(qk, &pk.DomainH)
+ evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1])
wg.Wait()
- // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets
- // of (Z/8mZ)/(Z/mZ)
+
+ // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain
utils.Parallelize(len(evalQk), func(start, end int) {
var t0, t1 fr.Element
for i := start; i < end; i++ {
@@ -608,211 +646,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.
return evalQk
}
-// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ)
-func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) {
-
- id = make([]fr.Element, pk.DomainH.Cardinality)
-
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
- var acc fr.Element
- acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start)))
- for i := start; i < end; i++ {
- id[i].Mul(&acc, &pk.DomainH.FinerGenerator)
- acc.Mul(&acc, &pk.DomainH.Generator)
- }
- })
-
- return id
-}
-
-// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
-// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the
+// coset of the big domain.
//
-// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets
-// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets
+// * z evaluation of the blinded permutation accumulator polynomial on the coset of the big domain
+// * l, r, o evaluations of the blinded solution vectors on the coset of the big domain
// * gamma randomization
-func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial {
+func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element {
- // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ)
- evalID := evalIDCosets(pk)
+ nbElmts := int(pk.Domain[1].Cardinality)
- // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ)
- var wg sync.WaitGroup
- wg.Add(2)
- var evalS1, evalS2, evalS3 polynomial.Polynomial
- go func() {
- evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH)
- wg.Done()
- }()
- go func() {
- evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH)
- wg.Done()
- }()
- evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH)
- wg.Wait()
+ // computes z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X)+s₂(X)*β+γ)*(o(X)+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ)
+ // on the big domain (coset).
+ res := make([]fr.Element, pk.Domain[1].Cardinality)
- // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ)
- res := evalS1 // re use allocated memory for evalS1
- s := uint64(len(evalZ))
- nn := uint64(64 - bits.TrailingZeros64(uint64(s)))
+ nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts)))
// needed to shift evalZ
- toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality
+ toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality)
+
+ var cosetShift, cosetShiftSquare fr.Element
+ cosetShift.Set(&pk.Vk.CosetShift)
+ cosetShiftSquare.Square(&pk.Vk.CosetShift)
+
+ utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) {
+
+ var evaluationIDBigDomain fr.Element
+ evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))).
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen)
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
var f [3]fr.Element
var g [3]fr.Element
- var eID fr.Element
for i := start; i < end; i++ {
- // here we want to left shift evalZ by domainH/domainNum
- // however, evalZ is permuted
- // we take the non permuted index
- // compute the corresponding shift position
- // permute it again
- irev := bits.Reverse64(uint64(i)) >> nn
- eID = evalID[irev]
+ _i := bits.Reverse64(uint64(i)) >> nn
+ _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn
- shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn
- //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn
+ // in what follows gⁱ is understood as the generator of the chosen coset of domainBig
+ f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ
+ f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ
+ f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ
- f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma
- f[1].Mul(&eID, &pk.Vk.Shifter[0])
- f[2].Mul(&eID, &pk.Vk.Shifter[1])
- f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma
- f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma
+ g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma) //l(gⁱ)+s1(gⁱ)*β+γ
+ g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma) //r(gⁱ)+s2(gⁱ)*β+γ
+ g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ)+s3(gⁱ)*β+γ
- g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma
- g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma
- g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma
+ f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i]) // z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
+ g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ)
- f[0].Mul(&f[0], &f[1]).
- Mul(&f[0], &f[2]).
- Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma)
+ res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
- g[0].Mul(&g[0], &g[1]).
- Mul(&g[0], &g[2]).
- Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
-
- res[i].Sub(&g[0], &f[0])
+ evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g
}
})
return res
}
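
Several loops above index slices stored in bit-reversed order via `_i := bits.Reverse64(i) >> nn`, with `nn = 64 - log₂(n)`. A standalone sketch of that index mapping:

```go
package main

import (
	"fmt"
	"math/bits"
)

func main() {
	const n = 8 // must be a power of two
	nn := uint64(64 - bits.TrailingZeros64(n))
	for i := uint64(0); i < n; i++ {
		// Reverse64 flips all 64 bits; shifting by nn keeps only the log₂(n) relevant ones.
		fmt.Printf("bit-reversed slot %d <-> natural index %d\n", i, bits.Reverse64(i)>>nn)
	}
}
```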
-// evaluateHDomain evaluates poly (canonical form) of degree m …
-			irev := bits.Reverse64(i) >> nn
- // h[i].Mul(&h[i], &_u[irev%4])
- h[i].Mul(&h[i], &_u[irev%toShift])
+
+ _i := bits.Reverse64(i) >> nn
+
+ t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain
+ h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t).
+ Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]).
+ Mul(&h[_i], &alpha).
+ Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]).
+ Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio])
}
})
// put h in canonical form. h is of degree 3*(n+1)+2.
	// using fft.DIT, which takes the bit-reversed h back to natural order
- pk.DomainH.FFTInverse(h, fft.DIT, 1)
- // fmt.Println("h:")
- // for i := 0; i < len(h); i++ {
- // fmt.Printf("%s\n", h[i].String())
- // }
- // fmt.Println("")
+ pk.Domain[1].FFTInverse(h, fft.DIT, true)
// degree of hi is n+2 because of the blinding
- h1 := h[:pk.DomainNum.Cardinality+2]
- h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)]
- h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)]
+ h1 := h[:pk.Domain[0].Cardinality+2]
+ h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)]
+ h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)]
return h1, h2, h3
@@ -820,78 +801,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom
// computeLinearizedPolynomial computes the linearized polynomial in canonical basis.
// The purpose is to commit and open all in one ql, qr, qm, qo, qk.
-// * a, b, c are the evaluation of l, r, o at zeta
-// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z
+// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta
+// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z
// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk.
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial {
+//
+// The Linearized polynomial is:
+//
+// α²*L₁(ζ)*Z(X)
+// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ))
+// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X)
+func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element {
// first part: individual constraints
var rl fr.Element
- rl.Mul(&r, &l)
+ rl.Mul(&rZeta, &lZeta)
- // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
+ // second part:
+ // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)
var s1, s2 fr.Element
chS1 := make(chan struct{}, 1)
go func() {
- s1 = pk.CS1.Eval(&zeta)
- s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma)
+ s1 = eval(pk.S1Canonical, zeta) // s1(ζ)
+ s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ)
close(chS1)
}()
- t := pk.CS2.Eval(&zeta)
- t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma)
+ tmp := eval(pk.S2Canonical, zeta) // s2(ζ)
+ tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ)
<-chS1
- s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma)
- Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)
-
- s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma)
- t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)
- t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
- s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
-
- // third part L1(zeta)*alpha**2**Z
- var lagrange, one, den, frNbElmt fr.Element
+ s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ)
+
+ var uzeta, uuzeta fr.Element
+ uzeta.Mul(&zeta, &pk.Vk.CosetShift)
+ uuzeta.Mul(&uzeta, &pk.Vk.CosetShift)
+
+ s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ)
+ tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)
+ tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+ s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ // third part L₁(ζ)*α²*Z
+ var lagrangeZeta, one, den, frNbElmt fr.Element
one.SetOne()
- nbElmt := int64(pk.DomainNum.Cardinality)
- lagrange.Set(&zeta).
- Exp(lagrange, big.NewInt(nbElmt)).
- Sub(&lagrange, &one)
+ nbElmt := int64(pk.Domain[0].Cardinality)
+ lagrangeZeta.Set(&zeta).
+ Exp(lagrangeZeta, big.NewInt(nbElmt)).
+ Sub(&lagrangeZeta, &one)
frNbElmt.SetUint64(uint64(nbElmt))
den.Sub(&zeta, &one).
- Mul(&den, &frNbElmt).
Inverse(&den)
- lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1)
- Mul(&lagrange, &alpha).
- Mul(&lagrange, &alpha) // alpha**2*L_0
+ lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ-1)/(ζ-1)
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ)
- linPol := z.Clone()
+ linPol := make([]fr.Element, len(blindedZCanonical))
+ copy(linPol, blindedZCanonical)
utils.Parallelize(len(linPol), func(start, end int) {
+
var t0, t1 fr.Element
+
for i := start; i < end; i++ {
- linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
- if i < len(pk.CS3) {
- t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X)
+
+ linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ if i < len(pk.S3Canonical) {
+
+ t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X)
+
linPol[i].Add(&linPol[i], &t0)
}
- linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) )
+ linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ))
if i < len(pk.Qm) {
- t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm
- t0.Mul(&pk.Ql[i], &l)
+
+ t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X)
+ t0.Mul(&pk.Ql[i], &lZeta)
t0.Add(&t0, &t1)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X)
- t0.Mul(&pk.Qr[i], &r)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr
+ t0.Mul(&pk.Qr[i], &rZeta)
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X)
- t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i])
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk
+ t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i])
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X)
}
- t0.Mul(&z[i], &lagrange)
+ t0.Mul(&blindedZCanonical[i], &lagrangeZeta)
linPol[i].Add(&linPol[i], &t0) // finish the computation
}
})
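
For reference, the closed form used for the Lagrange term above, L₁(ζ) = (ζⁿ-1)/(n·(ζ-1)), can be sanity-checked outside the prover. A minimal stand-alone sketch with math/big over a toy field (the prime 17, generator 4 and point 7 are illustrative values, not taken from the patch):

```go
package main

import (
	"fmt"
	"math/big"
)

func main() {
	p := big.NewInt(17)   // toy prime
	g := big.NewInt(4)    // generator of the order-4 subgroup {1,4,16,13}
	n := int64(4)
	zeta := big.NewInt(7) // evaluation point outside the subgroup

	// closed form: (ζⁿ-1) * (n*(ζ-1))⁻¹ mod p
	num := new(big.Int).Exp(zeta, big.NewInt(n), p)
	num.Sub(num, big.NewInt(1)).Mod(num, p)
	den := new(big.Int).Sub(zeta, big.NewInt(1))
	den.Mul(den, big.NewInt(n)).Mod(den, p).ModInverse(den, p)
	closed := new(big.Int).Mul(num, den)
	closed.Mod(closed, p)

	// definition of the first Lagrange polynomial: Π_{k=1..n-1} (ζ-gᵏ)/(1-gᵏ)
	def := big.NewInt(1)
	gk := new(big.Int).Set(g)
	for k := int64(1); k < n; k++ {
		t := new(big.Int).Sub(zeta, gk)
		b := new(big.Int).Sub(big.NewInt(1), gk)
		b.Mod(b, p).ModInverse(b, p)
		def.Mul(def, t).Mul(def, b).Mod(def, p)
		gk.Mul(gk, g).Mod(gk, p)
	}

	fmt.Println(closed, def) // both print 15: the closed form matches the definition
}
```
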
diff --git a/internal/backend/bw6-761/plonk/setup.go b/internal/backend/bw6-761/plonk/setup.go
index 5a0741b31e..946d153c3f 100644
--- a/internal/backend/bw6-761/plonk/setup.go
+++ b/internal/backend/bw6-761/plonk/setup.go
@@ -21,7 +21,6 @@ import (
"github.com/consensys/gnark-crypto/ecc/bw6-761/fr"
"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft"
"github.com/consensys/gnark-crypto/ecc/bw6-761/fr/kzg"
- "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial"
"github.com/consensys/gnark/internal/backend/bw6-761/cs"
kzgg "github.com/consensys/gnark-crypto/kzg"
@@ -40,18 +39,21 @@ type ProvingKey struct {
Vk *VerifyingKey
// qr,ql,qm,qo (in canonical basis).
- Ql, Qr, Qm, Qo polynomial.Polynomial
+ Ql, Qr, Qm, Qo []fr.Element
// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
// Storing LQk in Lagrange basis saves a fft...
- CQk, LQk polynomial.Polynomial
+ CQk, LQk []fr.Element
- // Domains used for the FFTs
- DomainNum, DomainH fft.Domain
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
- // s1, s2, s3 (L=Lagrange basis, C=canonical basis)
- LS1, LS2, LS3 polynomial.Polynomial
- CS1, CS2, CS3 polynomial.Polynomial
+ // Permutation polynomials
+ EvaluationPermutationBigDomainBitReversed []fr.Element
+ S1Canonical, S2Canonical, S3Canonical []fr.Element
// position -> permuted position (position in [0,3*sizeSystem-1])
Permutation []int64
@@ -69,13 +71,12 @@ type VerifyingKey struct {
Generator fr.Element
NbPublicVariables uint64
- // shifters for extending the permutation set: from s=<1,z,..,z**n-1>,
- // extended domain = s || shifter[0].s || shifter[1].s
- Shifter [2]fr.Element
-
// Commitment scheme that is used for an instantiation of PLONK
KZGSRS *kzg.SRS
+ // cosetShift generator of the coset on the small domain
+ CosetShift fr.Element
+
// S commitments to S1, S2, S3
S [3]kzg.Digest
@@ -96,37 +97,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
// fft domains
sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints
- pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false)
+ pk.Domain[0] = *fft.NewDomain(sizeSystem)
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
// h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space,
// the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases
// except when n<6.
if sizeSystem < 6 {
- pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(8 * sizeSystem)
} else {
- pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(4 * sizeSystem)
}
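
The size rule for the big domain can be checked directly: the quotient h has degree 3(n+1)+2, so the big domain needs at least 3(n+2) points, and 4n suffices once n ≥ 6. A small illustrative check (values not from the patch):

```go
package main

import "fmt"

func main() {
	// n = Domain[0].Cardinality; the quotient needs ≥ 3(n+2) points,
	// and 4n is enough exactly when 4n ≥ 3(n+2), i.e. n ≥ 6.
	for _, n := range []uint64{1, 2, 4, 8, 16, 32} {
		need := 3 * (n + 2)
		fmt.Printf("n=%2d need=%3d 4n=%3d enough=%v\n", n, need, 4*n, 4*n >= need)
	}
}
```
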
- vk.Size = pk.DomainNum.Cardinality
+ vk.Size = pk.Domain[0].Cardinality
vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
- vk.Generator.Set(&pk.DomainNum.Generator)
+ vk.Generator.Set(&pk.Domain[0].Generator)
vk.NbPublicVariables = uint64(spr.NbPublicVariables)
- // shifters
- vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator)
- vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator)
-
if err := pk.InitKZG(srs); err != nil {
return nil, nil, err
}
// public polynomials corresponding to constraints: [ placholders | constraints | assertions ]
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
 	for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error if size is inconsistent
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -134,7 +132,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.Qm[i].SetZero()
pk.Qo[i].SetZero()
pk.CQk[i].SetZero()
- pk.LQk[i].SetZero() // --> to be completed by the prover
+ pk.LQk[i].SetZero() // → to be completed by the prover
}
offset := spr.NbPublicVariables
for i := 0; i < nbConstraints; i++ { // constraints
@@ -148,11 +146,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K])
}
- pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(pk.Ql, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qr, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qm, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qo, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.CQk, fft.DIF)
fft.BitReverse(pk.Ql)
fft.BitReverse(pk.Qr)
fft.BitReverse(pk.Qm)
@@ -163,7 +161,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
buildPermutation(spr, &pk)
// set s1, s2, s3
- computeLDE(&pk)
+	computePermutationPolynomials(&pk)
// Commit to the polynomials to set up the verifying key
var err error
@@ -182,13 +180,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil {
+ if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil {
+ if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil {
+ if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
@@ -200,18 +198,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
//
// The permutation s is composed of cycles of maximum length such that
//
-// s. (l||r||o) = (l||r||o)
+// s. (l∥r∥o) = (l∥r∥o)
//
-//, where l||r||o is the concatenation of the indices of l, r, o in
+//, where l∥r∥o is the concatenation of the indices of l, r, o in
// ql.l+qr.r+qm.l.r+qo.O+k = 0.
//
// The permutation is encoded as a slice s of size 3*size(l), where the
-// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab
+// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab
// like this: for i in tab: tab[i] = tab[permutation[i]]
func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables
- sizeSolution := int(pk.DomainNum.Cardinality)
+ sizeSolution := int(pk.Domain[0].Cardinality)
// init permutation
pk.Permutation = make([]int64, 3*sizeSolution)
@@ -256,60 +254,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
}
}
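
The encoding described in the comment above ("the i-th entry of l∥r∥o is sent to the s[i]-th entry") can be illustrated with a toy slice; the labels and the permutation below are made up for the example:

```go
package main

import "fmt"

func main() {
	// toy stand-in for the concatenation l‖r‖o (values are labels, not wires)
	lro := []string{"l0", "l1", "r0", "r1", "o0", "o1"}
	// one copy-cycle (0 2 4), everything else fixed
	permutation := []int64{2, 1, 4, 3, 0, 5}

	permuted := make([]string, len(lro))
	for i := range lro {
		permuted[i] = lro[permutation[i]] // tab[i] = tab[permutation[i]]
	}
	fmt.Println(permuted) // [r0 l1 o0 r1 l0 o1]
}
```
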
-// computeLDE computes the LDE (Lagrange basis) of the permutations
+// computePermutationPolynomials computes the LDE (Lagrange basis) of the permutations
// s1, s2, s3.
//
-// ex: z gen of Z/mZ, u gen of Z/8mZ, then
-//
// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 |
// |
// | Permutation
// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v
// \---------------/ \--------------------/ \------------------------/
// s1 (LDE) s2 (LDE) s3 (LDE)
-func computeLDE(pk *ProvingKey) {
+func computePermutationPolynomials(pk *ProvingKey) {
- nbElmt := int(pk.DomainNum.Cardinality)
+ nbElmts := int(pk.Domain[0].Cardinality)
- // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1]
- sID := make([]fr.Element, 3*nbElmt)
- sID[0].SetOne()
- sID[nbElmt].Set(&pk.DomainNum.FinerGenerator)
- sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator)
-
- for i := 1; i < nbElmt; i++ {
- sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1
- sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
- }
+ // Lagrange form of ID
+ evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0])
// Lagrange form of S1, S2, S3
- pk.LS1 = make(polynomial.Polynomial, nbElmt)
- pk.LS2 = make(polynomial.Polynomial, nbElmt)
- pk.LS3 = make(polynomial.Polynomial, nbElmt)
- for i := 0; i < nbElmt; i++ {
- pk.LS1[i].Set(&sID[pk.Permutation[i]])
- pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]])
- pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]])
+ pk.S1Canonical = make([]fr.Element, nbElmts)
+ pk.S2Canonical = make([]fr.Element, nbElmts)
+ pk.S3Canonical = make([]fr.Element, nbElmts)
+ for i := 0; i < nbElmts; i++ {
+ pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]])
+ pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]])
+ pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]])
}
// Canonical form of S1, S2, S3
- pk.CS1 = make(polynomial.Polynomial, nbElmt)
- pk.CS2 = make(polynomial.Polynomial, nbElmt)
- pk.CS3 = make(polynomial.Polynomial, nbElmt)
- copy(pk.CS1, pk.LS1)
- copy(pk.CS2, pk.LS2)
- copy(pk.CS3, pk.LS3)
- pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0)
- fft.BitReverse(pk.CS1)
- fft.BitReverse(pk.CS2)
- fft.BitReverse(pk.CS3)
+ pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF)
+ fft.BitReverse(pk.S1Canonical)
+ fft.BitReverse(pk.S2Canonical)
+ fft.BitReverse(pk.S3Canonical)
+
+ // evaluation of permutation on the big domain
+ pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality)
+ copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true)
+
+}
+
+// getIDSmallDomain returns the Lagrange form of ID on the small domain
+func getIDSmallDomain(domain *fft.Domain) []fr.Element {
+
+ res := make([]fr.Element, 3*domain.Cardinality)
+
+ res[0].SetOne()
+ res[domain.Cardinality].Set(&domain.FrMultiplicativeGen)
+ res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen)
+
+ for i := uint64(1); i < domain.Cardinality; i++ {
+ res[i].Mul(&res[i-1], &domain.Generator)
+ res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator)
+ res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator)
+ }
+ return res
}
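
The construction above (identity table on the three cosets, then read off through the permutation) in miniature, with integers mod 17 standing in for field elements; generator, coset shift and permutation are illustrative only:

```go
package main

import "fmt"

func main() {
	const p, g, u, n = 17, 4, 3, 4 // toy field, subgroup generator, coset shift

	// identity table [1, g, .., gⁿ⁻¹ | u·(..) | u²·(..)]
	id := make([]int, 3*n)
	id[0], id[n], id[2*n] = 1, u, (u*u)%p
	for i := 1; i < n; i++ {
		id[i] = id[i-1] * g % p
		id[n+i] = id[n+i-1] * g % p
		id[2*n+i] = id[2*n+i-1] * g % p
	}

	// identity permutation except wires 0 and n share a copy cycle
	perm := make([]int64, 3*n)
	for i := range perm {
		perm[i] = int64(i)
	}
	perm[0], perm[n] = int64(n), 0

	// Lagrange form of s1, s2, s3: read the identity table through the permutation
	s1, s2, s3 := make([]int, n), make([]int, n), make([]int, n)
	for i := 0; i < n; i++ {
		s1[i] = id[perm[i]]
		s2[i] = id[perm[n+i]]
		s3[i] = id[perm[2*n+i]]
	}
	fmt.Println(id)
	fmt.Println(s1, s2, s3) // s1[0] and s2[0] are swapped relative to the identity
}
```
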
-// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS
+// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS
//
// This should be used after deserializing a ProvingKey
// as pk.Vk.KZG is NOT serialized
diff --git a/internal/backend/bw6-761/plonk/verify.go b/internal/backend/bw6-761/plonk/verify.go
index 9cba6efb9a..266c8859bc 100644
--- a/internal/backend/bw6-761/plonk/verify.go
+++ b/internal/backend/bw6-761/plonk/verify.go
@@ -43,7 +43,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness
hFunc := sha256.New()
// transcript to derive the challenge
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// derive gamma from Comm(l), Comm(r), Comm(o)
gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2])
@@ -51,6 +51,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness
return err
}
+ // derive beta from Comm(l), Comm(r), Comm(o)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return err
+ }
+
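
What matters is that prover and verifier bind the same data and derive the challenges in the same order (gamma, beta, alpha, zeta). A minimal emulation of that schedule with sha256, not the gnark-crypto fiat-shamir package:

```go
package main

import (
	"crypto/sha256"
	"fmt"
)

// toy transcript: each challenge hashes its name plus everything absorbed so far
type transcript struct{ state []byte }

func (t *transcript) bind(data []byte) { t.state = append(t.state, data...) }

func (t *transcript) challenge(name string) []byte {
	h := sha256.New()
	h.Write([]byte(name))
	h.Write(t.state)
	c := h.Sum(nil)
	t.bind(c) // later challenges depend on earlier ones
	return c
}

func main() {
	var fs transcript
	fs.bind([]byte("commitments to l, r, o go here"))
	gamma := fs.challenge("gamma")
	beta := fs.challenge("beta") // beta is now its own challenge, between gamma and alpha
	fs.bind([]byte("commitment to Z goes here"))
	alpha := fs.challenge("alpha")
	zeta := fs.challenge("zeta")
	fmt.Printf("%x %x %x %x\n", gamma[:4], beta[:4], alpha[:4], zeta[:4])
}
```
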
// derive alpha from Comm(l), Comm(r), Comm(o), Com(Z)
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z)
if err != nil {
@@ -63,7 +69,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness
return err
}
- // evaluation of Z=X**m-1 at zeta
+	// evaluation of Z=Xⁿ-1 at ζ
var zetaPowerM, zzeta fr.Element
var bExpo big.Int
one := fr.One()
@@ -71,20 +77,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness bw6_761witness.Witness
zetaPowerM.Exp(zeta, &bExpo)
zzeta.Sub(&zetaPowerM, &one)
-	// ccompute PI = Sum_i<n L_i*w_i
+		if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
}
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
}
- }
+
+ // wait for the level to be done
+ wg.Wait()
- // sanity check; ensure all wires are marked as "instantiated"
- if !solution.isValid() {
- panic("solver didn't instantiate all wires")
+ if len(chError) > 0 {
+ return <-chError
+ }
}
- return solution.values, nil
+ return nil
}
// IsSolved returns nil if given witness solves the R1CS and error otherwise
@@ -167,7 +255,7 @@ func (cs *R1CS) divByCoeff(res *fr.Element, t compiled.Term) {
// returns false, nil if there was no wire to solve
// returns true, nil if exactly one wire was solved. In that case, it is redundant to check that
// the constraint is satisfied later.
-func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool, a,b,c fr.Element, err error) {
+func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution, a,b,c *fr.Element) error {
// the index of the non zero entry shows if L, R or O has an uninstantiated wire
// the content is the ID of the wire non instantiated
@@ -204,28 +292,31 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
return nil
}
- if err = processLExp(r.L.LinExp, &a, 1); err != nil {
- return
+ if err := processLExp(r.L.LinExp, a, 1); err != nil {
+ return err
}
- if err = processLExp(r.R.LinExp, &b, 2); err != nil {
- return
+ if err := processLExp(r.R.LinExp, b, 2); err != nil {
+ return err
}
- if err = processLExp(r.O.LinExp, &c, 3); err != nil {
- return
+ if err := processLExp(r.O.LinExp, c, 3); err != nil {
+ return err
}
if loc == 0 {
// there is nothing to solve, may happen if we have an assertion
// (ie a constraints that doesn't yield any output)
// or if we solved the unsolved wires with hint functions
- return
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
+ return nil
}
// we compute the wire value and instantiate it
- solved = true
- vID := termToCompute.WireID()
+ wID := termToCompute.WireID()
// solver result
var wire fr.Element
@@ -234,36 +325,42 @@ func (cs *R1CS) solveConstraint(r compiled.R1C, solution *solution) (solved bool
switch loc {
case 1:
if !b.IsZero() {
- wire.Div(&c, &b).
- Sub(&wire, &a)
- a.Add(&a, &wire)
+ wire.Div(c, b).
+ Sub(&wire, a)
+ a.Add(a, &wire)
} else {
// we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 2:
if !a.IsZero() {
- wire.Div(&c, &a).
- Sub(&wire, &b)
- b.Add(&b, &wire)
+ wire.Div(c, a).
+ Sub(&wire, b)
+ b.Add(b, &wire)
} else {
- // we didn't actually ensure that a * b == c
- solved = false
+ var check fr.Element
+ if !check.Mul(a, b).Equal(c) {
+ return errUnsatisfiedConstraint
+ }
}
case 3:
- wire.Mul(&a, &b).
- Sub(&wire, &c)
+ wire.Mul(a, b).
+ Sub(&wire, c)
- c.Add(&c, &wire)
+ c.Add(c, &wire)
}
// wire is the term (coeff * value)
// but in the solution we want to store the value only
// note that in gnark frontend, coeff here is always 1 or -1
cs.divByCoeff(&wire, termToCompute)
- solution.set(vID, wire)
+ solution.set(wID, wire)
+
- return
+ return nil
}
// GetConstraints return a list of constraint formatted as L⋅R == O
diff --git a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl
index c572bfa1d4..dbc9915d2c 100644
--- a/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl
+++ b/internal/generator/backend/template/representations/r1cs.sparse.go.tmpl
@@ -6,6 +6,9 @@ import (
"github.com/consensys/gnark-crypto/ecc"
"strings"
"os"
+ "sync"
+ "runtime"
+ "math"
"github.com/consensys/gnark/internal/backend/ioutils"
"github.com/consensys/gnark/internal/backend/compiled"
@@ -71,11 +74,6 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
- defer func() {
- // release memory
- solution.tmpHintsIO = nil
- }()
-
// solution.values = [publicInputs | secretInputs | internalVariables ] -> we fill publicInputs | secretInputs
copy(solution.values, witness)
for i := 0; i < len(witness); i++ {
@@ -84,7 +82,7 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
// keep track of the number of wire instantiations we do, for a sanity check to ensure
// we instantiated all wires
- solution.nbSolved += len(witness)
+ solution.nbSolved += uint64(len(witness))
// defer log printing once all solution.values are computed
defer solution.printLogs(opt.LoggerOut, cs.Logs)
@@ -95,19 +93,8 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
coefficientsNegInv[i].Neg(&coefficientsNegInv[i])
}
-
- // loop through the constraints to solve the variables
- for i := 0; i < len(cs.Constraints); i++ {
- if err := cs.solveConstraint(cs.Constraints[i], &solution, coefficientsNegInv); err != nil {
- return solution.values, fmt.Errorf("constraint %d: %w", i, err)
- }
- if err := cs.checkConstraint(cs.Constraints[i], &solution); err != nil {
- errMsg := err.Error()
- if dID, ok := cs.MDebug[i]; ok {
- errMsg = solution.logValue(cs.DebugInfo[dID])
- }
- return solution.values, fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
- }
+ if err := cs.parallelSolve(&solution, coefficientsNegInv); err != nil {
+ return solution.values, err
}
// sanity check; ensure all wires are marked as "instantiated"
@@ -120,6 +107,122 @@ func (cs *SparseR1CS) Solve(witness []fr.Element, opt backend.ProverConfig) ([]f
}
+func (cs *SparseR1CS) parallelSolve(solution *solution, coefficientsNegInv []fr.Element) error {
+	// minWorkPerCPU is the minimum target number of constraints a task should hold
+	// in other words, if a level has fewer than minWorkPerCPU constraints, it is not parallelized
+	// and is executed sequentially, without sync.
+ const minWorkPerCPU = 50.0
+
+ // cs.Levels has a list of levels, where all constraints in a level l(n) are independent
+ // and may only have dependencies on previous levels
+
+ var wg sync.WaitGroup
+ chTasks := make(chan []int, runtime.NumCPU())
+ chError := make(chan error, runtime.NumCPU())
+
+ // start a worker pool
+	// each worker waits on chTasks
+ // a task is a slice of constraint indexes to be solved
+ for i := 0; i < runtime.NumCPU(); i++ {
+ go func() {
+ for t := range chTasks {
+ for _, i := range t {
+ // for each constraint in the task, solve it.
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ wg.Done()
+ return
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ chError <- fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ wg.Done()
+ return
+ }
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+	// clean up pool goroutines
+ defer func() {
+ close(chTasks)
+ close(chError)
+ }()
+
+ // for each level, we push the tasks
+ for _, level := range cs.Levels {
+
+ // max CPU to use
+ maxCPU := float64(len(level)) / minWorkPerCPU
+
+ if maxCPU <= 1.0 {
+ // we do it sequentially
+ for _, i := range level {
+ if err := cs.solveConstraint(cs.Constraints[i], solution, coefficientsNegInv); err != nil {
+ return fmt.Errorf("constraint #%d is not satisfied: %w", i, err)
+ }
+ if err := cs.checkConstraint(cs.Constraints[i], solution); err != nil {
+ errMsg := err.Error()
+ if dID, ok := cs.MDebug[i]; ok {
+ errMsg = solution.logValue(cs.DebugInfo[dID])
+ }
+ return fmt.Errorf("constraint #%d is not satisfied: %s", i, errMsg)
+ }
+ }
+ continue
+ }
+
+ // number of tasks for this level is set to num cpus
+ // but if we don't have enough work for all our CPUS, it can be lower.
+ nbTasks := runtime.NumCPU()
+ maxTasks := int(math.Ceil(maxCPU))
+ if nbTasks > maxTasks {
+ nbTasks = maxTasks
+ }
+ nbIterationsPerCpus := len(level) / nbTasks
+
+ // more CPUs than tasks: a CPU will work on exactly one iteration
+ // note: this depends on minWorkPerCPU constant
+ if nbIterationsPerCpus < 1 {
+ nbIterationsPerCpus = 1
+ nbTasks = len(level)
+ }
+
+
+ extraTasks := len(level) - (nbTasks * nbIterationsPerCpus)
+ extraTasksOffset := 0
+
+ for i := 0; i < nbTasks; i++ {
+ wg.Add(1)
+ _start := i*nbIterationsPerCpus + extraTasksOffset
+ _end := _start + nbIterationsPerCpus
+ if extraTasks > 0 {
+ _end++
+ extraTasks--
+ extraTasksOffset++
+ }
+ // since we're never pushing more than num CPU tasks
+ // we will never be blocked here
+ chTasks <- level[_start:_end]
+ }
+
+ // wait for the level to be done
+ wg.Wait()
+
+ if len(chError) > 0 {
+ return <-chError
+ }
+ }
+
+ return nil
+}
+
+
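
The per-level chunking (nbIterationsPerCpus plus extraTasks/extraTasksOffset) is easiest to see in isolation. A stand-alone sketch of the same splitting with toy constraint indexes:

```go
package main

import "fmt"

// split divides a level into nbTasks contiguous chunks; the first
// len(level)%nbTasks chunks get one extra constraint, mirroring
// extraTasks/extraTasksOffset above.
func split(level []int, nbTasks int) [][]int {
	per := len(level) / nbTasks
	extra := len(level) - nbTasks*per
	offset := 0
	out := make([][]int, 0, nbTasks)
	for i := 0; i < nbTasks; i++ {
		start := i*per + offset
		end := start + per
		if extra > 0 {
			end++
			extra--
			offset++
		}
		out = append(out, level[start:end])
	}
	return out
}

func main() {
	level := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	fmt.Println(split(level, 4)) // [[0 1 2] [3 4 5] [6 7 8] [9 10]]
}
```
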
// computeHints computes wires associated with a hint function, if any
// if there is no remaining wire to solve, returns -1
diff --git a/internal/generator/backend/template/representations/solution.go.tmpl b/internal/generator/backend/template/representations/solution.go.tmpl
index 43aaca6ee7..dba02f7cb6 100644
--- a/internal/generator/backend/template/representations/solution.go.tmpl
+++ b/internal/generator/backend/template/representations/solution.go.tmpl
@@ -3,6 +3,7 @@ import (
"errors"
"fmt"
"math/big"
+ "sync/atomic"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/internal/backend/compiled"
@@ -13,14 +14,15 @@ import (
{{ template "import_curve" . }}
)
+var errUnsatisfiedConstraint = errors.New("unsatisfied")
+
// solution represents elements needed to compute
// a solution to a R1CS or SparseR1CS
type solution struct {
values, coefficients []fr.Element
solved []bool
- nbSolved int
+ nbSolved uint64
mHintsFunctions map[hint.ID]hint.Function
- tmpHintsIO []*big.Int
}
func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.Element) (solution, error) {
@@ -30,7 +32,6 @@ func newSolution(nbWires int, hintFunctions []hint.Function, coefficients []fr.E
coefficients: coefficients,
solved: make([]bool, nbWires),
mHintsFunctions: make(map[hint.ID]hint.Function, len(hintFunctions)),
- tmpHintsIO: make([]*big.Int, 0),
}
for _, h := range hintFunctions {
@@ -49,11 +50,12 @@ func (s *solution) set(id int, value fr.Element) {
}
s.values[id] = value
s.solved[id] = true
- s.nbSolved++
+ atomic.AddUint64(&s.nbSolved, 1)
}
func (s *solution) isValid() bool {
- return s.nbSolved == len(s.values)
+ return int(s.nbSolved) == len(s.values)
}
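
set() switches from nbSolved++ to atomic.AddUint64 because several solver goroutines may now instantiate wires concurrently. A toy reproduction of why the atomic increment is needed:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

func main() {
	var nbSolved uint64
	var wg sync.WaitGroup
	for w := 0; w < 8; w++ { // 8 workers, as in the solver's worker pool
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < 1000; i++ {
				atomic.AddUint64(&nbSolved, 1) // a plain nbSolved++ here would be a data race
			}
		}()
	}
	wg.Wait()
	fmt.Println(atomic.LoadUint64(&nbSolved) == 8000) // true
}
```
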
// computeTerm computes coef*variable
@@ -128,15 +130,21 @@ func (s *solution) solveWithHint(vID int, h *compiled.Hint) error {
// tmp IO big int memory
nbInputs := len(h.Inputs)
nbOutputs := f.NbOutputs(curve.ID, len(h.Inputs))
- m := len(s.tmpHintsIO)
- if m < (nbInputs + nbOutputs) {
- s.tmpHintsIO = append(s.tmpHintsIO, make([]*big.Int, (nbOutputs + nbInputs) - m)...)
- for i := m; i < len(s.tmpHintsIO); i++ {
- s.tmpHintsIO[i] = big.NewInt(0)
- }
+ inputs := make([]*big.Int, nbInputs)
+ outputs := make([]*big.Int, nbOutputs)
+	for i := 0; i < nbInputs; i++ {
+ inputs[i] = big.NewInt(0)
+ }
+	for i := 0; i < nbOutputs; i++ {
+ outputs[i] = big.NewInt(0)
}
- inputs := s.tmpHintsIO[:nbInputs]
- outputs := s.tmpHintsIO[nbInputs:nbInputs+nbOutputs]
q := fr.Modulus()
diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl
index 58f9cf6dd7..485fcf437a 100644
--- a/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl
+++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl
@@ -259,18 +259,18 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
c = append(c, padding...)
n = len(a)
- domain.FFTInverse(a, fft.DIF, 0)
- domain.FFTInverse(b, fft.DIF, 0)
- domain.FFTInverse(c, fft.DIF, 0)
+ domain.FFTInverse(a, fft.DIF)
+ domain.FFTInverse(b, fft.DIF)
+ domain.FFTInverse(c, fft.DIF)
- domain.FFT(a, fft.DIT, 1)
- domain.FFT(b, fft.DIT, 1)
- domain.FFT(c, fft.DIT, 1)
+ domain.FFT(a, fft.DIT, true)
+ domain.FFT(b, fft.DIT, true)
+ domain.FFT(c, fft.DIT, true)
- var minusTwoInv fr.Element
- minusTwoInv.SetUint64(2)
- minusTwoInv.Neg(&minusTwoInv).
- Inverse(&minusTwoInv)
+ var den, one fr.Element
+ one.SetOne()
+ den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
+ den.Sub(&den, &one).Inverse(&den)
// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unecessary memalloc
@@ -278,12 +278,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
- Mul(&a[i], &minusTwoInv)
+ Mul(&a[i], &den)
}
})
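
The single constant den = 1/(uⁿ-1) works because xⁿ-1 is constant on the coset u·H. A toy check mod 17 (n = 4, subgroup generator 4, shift u = 3; illustrative values only):

```go
package main

import "fmt"

func main() {
	const p, n, u = 17, 4, 3 // toy field, domain size, coset shift
	pow := func(b, e uint64) uint64 {
		r := uint64(1)
		for i := uint64(0); i < e; i++ {
			r = r * b % p
		}
		return r
	}
	// on the coset u·H, xⁿ-1 is the same nonzero constant at every point
	for w, k := uint64(1), 0; k < n; w, k = w*4%p, k+1 {
		x := u * w % p
		fmt.Println((pow(x, n) + p - 1) % p) // always 12 = uⁿ-1
	}
}
```
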
// ifft_coset
- domain.FFTInverse(a, fft.DIF, 1)
+ domain.FFTInverse(a, fft.DIF, true)
utils.Parallelize(len(a), func(start, end int) {
for i := start; i < end; i++ {
diff --git a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl
index 867e2aee83..14933f3931 100644
--- a/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl
+++ b/internal/generator/backend/template/zkpschemes/groth16/groth16.setup.go.tmpl
@@ -74,7 +74,7 @@ func Setup(r1cs *cs.R1CS, pk *ProvingKey, vk *VerifyingKey) error {
nbPrivateWires := r1cs.NbSecretVariables + r1cs.NbInternalVariables
// Setting group for fft
- domain := fft.NewDomain(uint64(len(r1cs.Constraints)), 1, true)
+ domain := fft.NewDomain(uint64(len(r1cs.Constraints)))
// samples toxic waste
toxicWaste, err := sampleToxicWaste()
@@ -394,7 +394,7 @@ func DummySetup(r1cs *cs.R1CS, pk *ProvingKey) error {
nbConstraints := len(r1cs.Constraints)
// Setting group for fft
- domain := fft.NewDomain(uint64(nbConstraints), 1, true)
+ domain := fft.NewDomain(uint64(nbConstraints))
// count number of infinity points we would have had we a normal setup
// in pk.G1.A, pk.G1.B, and pk.G2.B
diff --git a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl
index 67ed2b1e93..e5ad79d370 100644
--- a/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl
+++ b/internal/generator/backend/template/zkpschemes/groth16/tests/groth16.marshal.go.tmpl
@@ -166,7 +166,7 @@ func TestProvingKeySerialization(t *testing.T) {
var pk, pkCompressed, pkRaw ProvingKey
// create a random pk
- domain := fft.NewDomain(8, 1, true)
+ domain := fft.NewDomain(8)
pk.Domain = *domain
nbWires := 6
diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl
index ebf5f20ae1..5e03ff6fb2 100644
--- a/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl
+++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.marshal.go.tmpl
@@ -5,11 +5,11 @@ import (
"errors"
)
-// WriteTo writes binary encoding of Proof to w
+// WriteTo writes binary encoding of Proof to w
func (proof *Proof) WriteTo(w io.Writer) (int64, error) {
enc := curve.NewEncoder(w)
- toEncode := []interface{} {
+ toEncode := []interface{}{
&proof.LRO[0],
&proof.LRO[1],
&proof.LRO[2],
@@ -27,10 +27,10 @@ func (proof *Proof) WriteTo(w io.Writer) (int64, error) {
n, err := proof.BatchedProof.WriteTo(w)
if err != nil {
- return n+enc.BytesWritten(), err
+ return n + enc.BytesWritten(), err
}
n2, err := proof.ZShiftedOpening.WriteTo(w)
-
+
return n + n2 + enc.BytesWritten(), err
}
@@ -55,13 +55,11 @@ func (proof *Proof) ReadFrom(r io.Reader) (int64, error) {
n, err := proof.BatchedProof.ReadFrom(r)
if err != nil {
- return n+dec.BytesRead(), err
+ return n + dec.BytesRead(), err
}
n2, err := proof.ZShiftedOpening.ReadFrom(r)
- return n+n2+dec.BytesRead(), err
-}
-
-
+ return n + n2 + dec.BytesRead(), err
+}
// WriteTo writes binary encoding of ProvingKey to w
func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
@@ -72,40 +70,37 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
}
// fft domains
- n2, err := pk.DomainNum.WriteTo(w)
+ n2, err := pk.Domain[0].WriteTo(w)
if err != nil {
- return
+ return
}
- n+=n2
+ n += n2
- n2, err = pk.DomainH.WriteTo(w)
+ n2, err = pk.Domain[1].WriteTo(w)
if err != nil {
- return
+ return
}
- n+=n2
+ n += n2
- // sanity check len(Permutation) == 3*int(pk.DomainNum.Cardinality)
- if len(pk.Permutation) != (3*int(pk.DomainNum.Cardinality)) {
+ // sanity check len(Permutation) == 3*int(pk.Domain[0].Cardinality)
+ if len(pk.Permutation) != (3 * int(pk.Domain[0].Cardinality)) {
return n, errors.New("invalid permutation size, expected 3*domain cardinality")
}
enc := curve.NewEncoder(w)
- // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't
- // encode the size (nor does it convert from Montgomery to Regular form)
+ // note: type Polynomial, which is handled by default binary.Write(...) op and doesn't
+ // encode the size (nor does it convert from Montgomery to Regular form)
// so we explicitly transmit []fr.Element
- toEncode := []interface{} {
- ([]fr.Element)(pk.Ql),
+ toEncode := []interface{}{
+ ([]fr.Element)(pk.Ql),
([]fr.Element)(pk.Qr),
([]fr.Element)(pk.Qm),
([]fr.Element)(pk.Qo),
([]fr.Element)(pk.CQk),
([]fr.Element)(pk.LQk),
- ([]fr.Element)(pk.LS1),
- ([]fr.Element)(pk.LS2),
- ([]fr.Element)(pk.LS3),
- ([]fr.Element)(pk.CS1),
- ([]fr.Element)(pk.CS2),
- ([]fr.Element)(pk.CS3),
+ ([]fr.Element)(pk.S1Canonical),
+ ([]fr.Element)(pk.S2Canonical),
+ ([]fr.Element)(pk.S3Canonical),
pk.Permutation,
}
@@ -118,7 +113,6 @@ func (pk *ProvingKey) WriteTo(w io.Writer) (n int64, err error) {
return n + enc.BytesWritten(), nil
}
-
// ReadFrom reads from binary representation in r into ProvingKey
func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
pk.Vk = &VerifyingKey{}
@@ -127,59 +121,53 @@ func (pk *ProvingKey) ReadFrom(r io.Reader) (int64, error) {
return n, err
}
- n2, err := pk.DomainNum.ReadFrom(r)
- n+=n2
- if err != nil {
- return n, err
+ n2, err := pk.Domain[0].ReadFrom(r)
+ n += n2
+ if err != nil {
+ return n, err
}
- n2, err = pk.DomainH.ReadFrom(r)
- n+=n2
- if err != nil {
- return n, err
+ n2, err = pk.Domain[1].ReadFrom(r)
+ n += n2
+ if err != nil {
+ return n, err
}
- pk.Permutation = make([]int64, 3*pk.DomainNum.Cardinality)
+ pk.Permutation = make([]int64, 3*pk.Domain[0].Cardinality)
dec := curve.NewDecoder(r)
toDecode := []interface{}{
- (*[]fr.Element)(&pk.Ql),
+ (*[]fr.Element)(&pk.Ql),
(*[]fr.Element)(&pk.Qr),
(*[]fr.Element)(&pk.Qm),
(*[]fr.Element)(&pk.Qo),
(*[]fr.Element)(&pk.CQk),
(*[]fr.Element)(&pk.LQk),
- (*[]fr.Element)(&pk.LS1),
- (*[]fr.Element)(&pk.LS2),
- (*[]fr.Element)(&pk.LS3),
- (*[]fr.Element)(&pk.CS1),
- (*[]fr.Element)(&pk.CS2),
- (*[]fr.Element)(&pk.CS3),
+ (*[]fr.Element)(&pk.S1Canonical),
+ (*[]fr.Element)(&pk.S2Canonical),
+ (*[]fr.Element)(&pk.S3Canonical),
&pk.Permutation,
}
for _, v := range toDecode {
if err := dec.Decode(v); err != nil {
- return n +dec.BytesRead(), err
+ return n + dec.BytesRead(), err
}
}
- return n +dec.BytesRead(), nil
+ return n + dec.BytesRead(), nil
}
-
// WriteTo writes binary encoding of VerifyingKey to w
func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
enc := curve.NewEncoder(w)
- toEncode := []interface{} {
- vk.Size,
+ toEncode := []interface{}{
+ vk.Size,
&vk.SizeInv,
&vk.Generator,
vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -196,7 +184,6 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
}
}
-
return enc.BytesWritten(), nil
}
@@ -204,12 +191,10 @@ func (vk *VerifyingKey) WriteTo(w io.Writer) (n int64, err error) {
func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
dec := curve.NewDecoder(r)
toDecode := []interface{}{
- &vk.Size,
+ &vk.Size,
&vk.SizeInv,
&vk.Generator,
&vk.NbPublicVariables,
- &vk.Shifter[0],
- &vk.Shifter[1],
&vk.S[0],
&vk.S[1],
&vk.S[2],
@@ -227,4 +212,4 @@ func (vk *VerifyingKey) ReadFrom(r io.Reader) (int64, error) {
}
return dec.BytesRead(), nil
-}
\ No newline at end of file
+}
\ No newline at end of file
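
Both Proof and the keys expose the io.WriterTo / io.ReaderFrom pair, so a serialization round trip fits in a few lines. A minimal generic helper (the bytes.Buffer in main is only a stand-in for a value produced by Setup or Prove):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
)

// roundTrip serializes src into a buffer and reads it back into dst,
// using only the io.WriterTo / io.ReaderFrom pair implemented above.
func roundTrip(src io.WriterTo, dst io.ReaderFrom) error {
	var buf bytes.Buffer
	written, err := src.WriteTo(&buf)
	if err != nil {
		return err
	}
	read, err := dst.ReadFrom(&buf)
	if err != nil {
		return err
	}
	if written != read {
		return fmt.Errorf("wrote %d bytes but read back %d", written, read)
	}
	return nil
}

func main() {
	// bytes.Buffer implements both interfaces, so it stands in here for a
	// Proof / ProvingKey / VerifyingKey produced by Prove or Setup.
	src := bytes.NewBufferString("stand-in for a serialized object")
	dst := new(bytes.Buffer)
	fmt.Println(roundTrip(src, dst), dst.Len())
}
```
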
diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl
index a48210d151..d56a8c8b24 100644
--- a/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl
+++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.prove.go.tmpl
@@ -7,7 +7,6 @@ import (
{{ template "import_fr" . }}
{{ template "import_curve" . }}
- {{ template "import_polynomial" . }}
{{ template "import_kzg" . }}
{{ template "import_fft" . }}
{{ template "import_witness" . }}
@@ -18,7 +17,9 @@ import (
"github.com/consensys/gnark-crypto/fiat-shamir"
)
+
type Proof struct {
+
// Commitments to the solution vectors
LRO [3]kzg.Digest
@@ -42,14 +43,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
hFunc := sha256.New()
// create a transcript manager to apply Fiat Shamir
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// result
proof := &Proof{}
// compute the constraint system solution
var solution []fr.Element
- var err error
+ var err error
if solution, err = spr.Solve(fullWitness, opt); err != nil {
if !opt.Force {
return nil, err
@@ -65,17 +66,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
}
// query l, r, o in Lagrange basis, not blinded
- ll, lr, lo := computeLRO(spr, pk, solution)
+ evaluationLDomainSmall, evaluationRDomainSmall, evaluationODomainSmall := evaluateLROSmallDomain(spr, pk, solution)
// save ll, lr, lo, and make a copy of them in canonical basis.
// note that we allocate more capacity to reuse for blinded polynomials
- bcl, bcr, bco, err := computeBlindedLRO(ll, lr, lo, &pk.DomainNum)
+ blindedLCanonical, blindedRCanonical, blindedOCanonical, err := computeBlindedLROCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ &pk.Domain[0])
if err != nil {
return nil, err
}
// compute kzg commitments of bcl, bcr and bco
- if err := commitToLRO(bcl, bcr, bco, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToLRO(blindedLCanonical, blindedRCanonical, blindedOCanonical, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -85,18 +90,28 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
return nil, err
}
+ // Fiat Shamir this
+	// derive beta from Comm(l), Comm(r), Comm(o)
+ if err != nil {
+ return nil, err
+ }
+
// compute Z, the permutation accumulator polynomial, in canonical basis
// ll, lr, lo are NOT blinded
- var bz polynomial.Polynomial
+ var blindedZCanonical []fr.Element
chZ := make(chan error, 1)
var alpha fr.Element
go func() {
- var err error
- bz, err = computeBlindedZ(ll, lr, lo, pk, gamma)
+ var err error
+ blindedZCanonical, err = computeBlindedZCanonical(
+ evaluationLDomainSmall,
+ evaluationRDomainSmall,
+ evaluationODomainSmall,
+ pk, beta, gamma)
if err != nil {
- chZ <- err
+ chZ <- err
close(chZ)
- return
+ return
}
// commit to the blinded version of z
@@ -104,7 +119,7 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
// this may add additional arithmetic operations, but with smaller tasks
 	// we ensure that this commitment is well parallelized, without having an "unbalanced task" making
// the rest of the code wait too long.
- if proof.Z, err = kzg.Commit(bz, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
+ if proof.Z, err = kzg.Commit(blindedZCanonical, pk.Vk.KZGSRS, runtime.NumCPU()*2); err != nil {
chZ <- err
close(chZ)
return
@@ -117,40 +132,50 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
}()
// evaluation of the blinded versions of l, r, o and bz
- // on the odd cosets of (Z/8mZ)/(Z/mZ)
- var evalBL, evalBR, evalBO, evalBZ polynomial.Polynomial
+ // on the coset of the big domain
+ var (
+ evaluationBlindedLDomainBigBitReversed []fr.Element
+ evaluationBlindedRDomainBigBitReversed []fr.Element
+ evaluationBlindedODomainBigBitReversed []fr.Element
+ evaluationBlindedZDomainBigBitReversed []fr.Element
+ )
chEvalBL := make(chan struct{}, 1)
chEvalBR := make(chan struct{}, 1)
chEvalBO := make(chan struct{}, 1)
go func() {
- evalBL = evaluateHDomain(bcl, &pk.DomainH)
+ evaluationBlindedLDomainBigBitReversed = evaluateDomainBigBitReversed(blindedLCanonical, &pk.Domain[1])
close(chEvalBL)
}()
go func() {
- evalBR = evaluateHDomain(bcr, &pk.DomainH)
+ evaluationBlindedRDomainBigBitReversed = evaluateDomainBigBitReversed(blindedRCanonical, &pk.Domain[1])
close(chEvalBR)
}()
go func() {
- evalBO = evaluateHDomain(bco, &pk.DomainH)
+ evaluationBlindedODomainBigBitReversed = evaluateDomainBigBitReversed(blindedOCanonical, &pk.Domain[1])
close(chEvalBO)
}()
- var constraintsInd, constraintsOrdering polynomial.Polynomial
+ var constraintsInd, constraintsOrdering []fr.Element
chConstraintInd := make(chan struct{}, 1)
go func() {
// compute qk in canonical basis, completed with the public inputs
- qk := make(polynomial.Polynomial, pk.DomainNum.Cardinality)
- copy(qk, fullWitness[:spr.NbPublicVariables])
- copy(qk[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
- pk.DomainNum.FFTInverse(qk, fft.DIF, 0)
- fft.BitReverse(qk)
-
- // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the odd cosets of (Z/8mZ)/(Z/mZ)
- // --> uses the blinded version of l, r, o
+ qkCompletedCanonical := make([]fr.Element, pk.Domain[0].Cardinality)
+ copy(qkCompletedCanonical, fullWitness[:spr.NbPublicVariables])
+ copy(qkCompletedCanonical[spr.NbPublicVariables:], pk.LQk[spr.NbPublicVariables:])
+ pk.Domain[0].FFTInverse(qkCompletedCanonical, fft.DIF)
+ fft.BitReverse(qkCompletedCanonical)
+
+ // compute the evaluation of qlL+qrR+qmL.R+qoO+k on the coset of the big domain
+ // → uses the blinded version of l, r, o
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsInd = evalConstraints(pk, evalBL, evalBR, evalBO, qk)
+ constraintsInd = evaluateConstraintsDomainBigBitReversed(
+ pk,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ qkCompletedCanonical)
close(chConstraintInd)
}()
@@ -160,13 +185,21 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
chConstraintOrdering <- err
return
}
- evalBZ = evaluateHDomain(bz, &pk.DomainH)
- // compute zu*g1*g2*g3-z*f1*f2*f3 on the odd cosets of (Z/8mZ)/(Z/mZ)
+
+ evaluationBlindedZDomainBigBitReversed = evaluateDomainBigBitReversed(blindedZCanonical, &pk.Domain[1])
+ // compute zu*g1*g2*g3-z*f1*f2*f3 on the coset of the big domain
// evalL, evalO, evalR are the evaluations of the blinded versions of l, r, o.
<-chEvalBL
<-chEvalBR
<-chEvalBO
- constraintsOrdering = evalConstraintOrdering(pk, evalBZ, evalBL, evalBR, evalBO, gamma)
+ constraintsOrdering = evaluateOrderingDomainBigBitReversed(
+ pk,
+ evaluationBlindedZDomainBigBitReversed,
+ evaluationBlindedLDomainBigBitReversed,
+ evaluationBlindedRDomainBigBitReversed,
+ evaluationBlindedODomainBigBitReversed,
+ beta,
+ gamma)
chConstraintOrdering <- nil
close(chConstraintOrdering)
}()
@@ -174,12 +207,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
if err := <-chConstraintOrdering; err != nil {
return nil, err
}
+
<-chConstraintInd
+
// compute h in canonical form
- h1, h2, h3 := computeH(pk, constraintsInd, constraintsOrdering, evalBZ, alpha)
+ h1, h2, h3 := computeQuotientCanonical(pk, constraintsInd, constraintsOrdering, evaluationBlindedZDomainBigBitReversed, alpha)
// compute kzg commitments of h1, h2 and h3
- if err := commitToH(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
+ if err := commitToQuotient(h1, h2, h3, proof, pk.Vk.KZGSRS); err != nil {
return nil, err
}
@@ -194,15 +229,15 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
var wgZetaEvals sync.WaitGroup
wgZetaEvals.Add(3)
go func() {
- blzeta = bcl.Eval(&zeta)
+ blzeta = eval(blindedLCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- brzeta = bcr.Eval(&zeta)
+ brzeta = eval(blindedRCanonical, zeta)
wgZetaEvals.Done()
}()
go func() {
- bozeta = bco.Eval(&zeta)
+ bozeta = eval(blindedOCanonical, zeta)
wgZetaEvals.Done()
}()
@@ -210,9 +245,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
var zetaShifted fr.Element
zetaShifted.Mul(&zeta, &pk.Vk.Generator)
proof.ZShiftedOpening, err = kzg.Open(
- bz,
- &zetaShifted,
- &pk.DomainH,
+ blindedZCanonical,
+ zetaShifted,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -223,53 +257,54 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
bzuzeta := proof.ZShiftedOpening.ClaimedValue
var (
- linearizedPolynomial polynomial.Polynomial
- linearizedPolynomialDigest curve.G1Affine
- errLPoly error
+ linearizedPolynomialCanonical []fr.Element
+ linearizedPolynomialDigest curve.G1Affine
+ errLPoly error
)
chLpoly := make(chan struct{}, 1)
go func() {
// compute the linearization polynomial r at zeta (goal: save committing separately to z, ql, qr, qm, qo, k)
wgZetaEvals.Wait()
- linearizedPolynomial = computeLinearizedPolynomial(
+ linearizedPolynomialCanonical = computeLinearizedPolynomial(
blzeta,
brzeta,
bozeta,
alpha,
+ beta,
gamma,
zeta,
bzuzeta,
- bz,
+ blindedZCanonical,
pk,
)
// TODO this commitment is only necessary to derive the challenge, we should
// be able to avoid doing it and get the challenge in another way
- linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomial, pk.Vk.KZGSRS)
+ linearizedPolynomialDigest, errLPoly = kzg.Commit(linearizedPolynomialCanonical, pk.Vk.KZGSRS)
close(chLpoly)
}()
- // foldedHDigest = Comm(h1) + zeta**m*Comm(h2) + zeta**2m*Comm(h3)
+ // foldedHDigest = Comm(h1) + ζᵐ⁺²*Comm(h2) + ζ²⁽ᵐ⁺²⁾*Comm(h3)
var bZetaPowerm, bSize big.Int
- bSize.SetUint64(pk.DomainNum.Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
+ bSize.SetUint64(pk.Domain[0].Cardinality + 2) // +2 because of the masking (h of degree 3(n+2)-1)
var zetaPowerm fr.Element
zetaPowerm.Exp(zeta, &bSize)
zetaPowerm.ToBigIntRegular(&bZetaPowerm)
foldedHDigest := proof.H[2]
foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm)
- foldedHDigest.Add(&foldedHDigest, &proof.H[1]) // zeta**(m+1)*Comm(h3)
- foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2)
- foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // zeta**2(m+1)*Comm(h3) + zeta**(m+1)*Comm(h2) + Comm(h1)
+	foldedHDigest.Add(&foldedHDigest, &proof.H[1])                   // ζᵐ⁺²*Comm(h3) + Comm(h2)
+ foldedHDigest.ScalarMultiplication(&foldedHDigest, &bZetaPowerm) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2)
+ foldedHDigest.Add(&foldedHDigest, &proof.H[0]) // ζ²⁽ᵐ⁺²⁾*Comm(h3) + ζᵐ⁺²*Comm(h2) + Comm(h1)
- // foldedH = h1 + zeta*h2 + zeta**2*h3
+ // foldedH = h1 + ζ*h2 + ζ²*h3
foldedH := h3
utils.Parallelize(len(foldedH), func(start, end int) {
for i := start; i < end; i++ {
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**(m+1)*h3
- foldedH[i].Add(&foldedH[i], &h2[i]) // zeta**(m+1)*h3
- foldedH[i].Mul(&foldedH[i], &zetaPowerm) // zeta**2(m+1)*h3+h2*zeta**(m+1)
- foldedH[i].Add(&foldedH[i], &h1[i]) // zeta**2(m+1)*h3+zeta**(m+1)*h2 + h1
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζᵐ⁺²*h3
+			foldedH[i].Add(&foldedH[i], &h2[i])      // ζᵐ⁺²*h3+h2
+ foldedH[i].Mul(&foldedH[i], &zetaPowerm) // ζ²⁽ᵐ⁺²⁾*h3+h2*ζᵐ⁺²
+			foldedH[i].Add(&foldedH[i], &h1[i])      // ζ²⁽ᵐ⁺²⁾*h3+ζᵐ⁺²*h2 + h1
}
})
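
The two multiply-add passes above fold the quotient pieces as ((h3·x + h2)·x + h1) with x = ζᵐ⁺². A toy coefficient-wise check mod 17 (values are arbitrary):

```go
package main

import "fmt"

func main() {
	const p = 17
	h1 := []uint64{1, 2, 3} // toy coefficients
	h2 := []uint64{4, 5, 6}
	h3 := []uint64{7, 8, 9}
	x := uint64(5) // plays the role of ζᵐ⁺²

	for i := range h1 {
		folded := h3[i]
		folded = (folded*x + h2[i]) % p // x*h3 + h2
		folded = (folded*x + h1[i]) % p // x²*h3 + x*h2 + h1
		direct := (h1[i] + x*h2[i] + x*x*h3[i]) % p
		fmt.Println(folded == direct) // true for every coefficient
	}
}
```
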
@@ -280,14 +315,14 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
// Batch open the first list of polynomials
proof.BatchedProof, err = kzg.BatchOpenSinglePoint(
- []polynomial.Polynomial{
+ [][]fr.Element{
foldedH,
- linearizedPolynomial,
- bcl,
- bcr,
- bco,
- pk.CS1,
- pk.CS2,
+ linearizedPolynomialCanonical,
+ blindedLCanonical,
+ blindedRCanonical,
+ blindedOCanonical,
+ pk.S1Canonical,
+ pk.S2Canonical,
},
[]kzg.Digest{
foldedHDigest,
@@ -298,9 +333,8 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
pk.Vk.S[0],
pk.Vk.S[1],
},
- &zeta,
+ zeta,
hFunc,
- &pk.DomainH,
pk.Vk.KZGSRS,
)
if err != nil {
@@ -311,8 +345,17 @@ func Prove(spr *cs.SparseR1CS, pk *ProvingKey, fullWitness {{ toLower .CurveID }
}
+// eval evaluates c at p
+func eval(c []fr.Element, p fr.Element) fr.Element {
+ var r fr.Element
+ for i := len(c) - 1; i >= 0; i-- {
+ r.Mul(&r, &p).Add(&r, &c[i])
+ }
+ return r
+}
+
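
eval is plain Horner evaluation over the coefficient slice (index i holds the coefficient of Xⁱ). The same scheme with uint64 arithmetic mod 17 and toy coefficients:

```go
package main

import "fmt"

func main() {
	const p = 17
	c := []uint64{3, 0, 2} // 3 + 2X², toy coefficients
	x := uint64(5)

	var r uint64
	for i := len(c) - 1; i >= 0; i-- {
		r = (r*x + c[i]) % p // same Horner step as eval above
	}
	fmt.Println(r, (3+2*x*x)%p) // both print 2
}
```
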
// fills proof.LRO with kzg commits of bcl, bcr and bco
-func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToLRO(bcl, bcr, bco []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -338,7 +381,7 @@ func commitToLRO(bcl, bcr, bco polynomial.Polynomial, proof *Proof, srs *kzg.SRS
return err1
}
-func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) error {
+func commitToQuotient(h1, h2, h3 []fr.Element, proof *Proof, srs *kzg.SRS) error {
n := runtime.NumCPU() / 2
var err0, err1, err2 error
chCommit0 := make(chan struct{}, 1)
@@ -364,44 +407,44 @@ func commitToH(h1, h2, h3 polynomial.Polynomial, proof *Proof, srs *kzg.SRS) err
return err1
}
-// computeBlindedLRO l, r, o in canonical basis with blinding
-func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl, bcr, bco polynomial.Polynomial, err error) {
-
+// computeBlindedLROCanonical l, r, o in canonical basis with blinding
+func computeBlindedLROCanonical(ll, lr, lo []fr.Element, domain *fft.Domain) (bcl, bcr, bco []fr.Element, err error) {
+
// note that bcl, bcr and bco reuses cl, cr and co memory
- cl := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2)
- cr := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2)
- co := make(polynomial.Polynomial, domain.Cardinality, domain.Cardinality + 2)
-
+ cl := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ cr := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+ co := make([]fr.Element, domain.Cardinality, domain.Cardinality+2)
+
chDone := make(chan error, 2)
go func() {
- var err error
+ var err error
copy(cl, ll)
- domain.FFTInverse(cl, fft.DIF, 0)
+ domain.FFTInverse(cl, fft.DIF)
fft.BitReverse(cl)
bcl, err = blindPoly(cl, domain.Cardinality, 1)
- chDone <- err
+ chDone <- err
}()
go func() {
var err error
copy(cr, lr)
- domain.FFTInverse(cr, fft.DIF, 0)
+ domain.FFTInverse(cr, fft.DIF)
fft.BitReverse(cr)
bcr, err = blindPoly(cr, domain.Cardinality, 1)
- chDone <- err
+ chDone <- err
}()
copy(co, lo)
- domain.FFTInverse(co, fft.DIF, 0)
+ domain.FFTInverse(co, fft.DIF)
fft.BitReverse(co)
if bco, err = blindPoly(co, domain.Cardinality, 1); err != nil {
- return
+ return
}
- err = <-chDone
+ err = <-chDone
if err != nil {
return
}
- err = <-chDone
- return
+ err = <-chDone
+ return
}
@@ -412,9 +455,9 @@ func computeBlindedLRO(ll,lr,lo polynomial.Polynomial, domain *fft.Domain) (bcl,
// * bo blinding order, it's the degree of Q, where the blinding is Q(X)*(X**degree-1)
//
// WARNING:
-// pre condition degree(cp) <= rou + bo
-// pre condition cap(cp) >= int(totalDegree + 1)
-func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial, error) {
+// pre condition degree(cp) ⩽ rou + bo
+// pre condition cap(cp) ⩾ int(totalDegree + 1)
+func blindPoly(cp []fr.Element, rou, bo uint64) ([]fr.Element, error) {
// degree of the blinded polynomial is max(rou+order, cp.Degree)
totalDegree := rou + bo
@@ -423,10 +466,10 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res := cp[:totalDegree+1]
// random polynomial
- blindingPoly := make(polynomial.Polynomial, bo+1)
+ blindingPoly := make([]fr.Element, bo+1)
for i := uint64(0); i < bo+1; i++ {
if _, err := blindingPoly[i].SetRandom(); err != nil {
- return nil, err
+ return nil, err
}
}
@@ -436,16 +479,17 @@ func blindPoly(cp polynomial.Polynomial, rou, bo uint64) (polynomial.Polynomial,
res[rou+i].Add(&res[rou+i], &blindingPoly[i])
}
- return res, nil
+ return res, nil
+
}
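
blindPoly adds q(X)·(Xⁿ-1), which vanishes on the FFT domain, so the Lagrange values of the polynomial are untouched. A toy check mod 17 with n = 4 and the 4-th root of unity 4 (illustrative values):

```go
package main

import "fmt"

func main() {
	const p, n = 17, 4
	evalAt := func(coeffs []uint64, x uint64) uint64 {
		var r uint64
		for i := len(coeffs) - 1; i >= 0; i-- {
			r = (r*x + coeffs[i]) % p
		}
		return r
	}

	poly := []uint64{6, 1, 0, 3} // toy polynomial of degree n-1
	blind := []uint64{9, 11}     // "random" q(X) = 9 + 11X

	// blinded = poly + q(X)*(Xⁿ-1): subtract q from the low coefficients,
	// add it shifted by n (same shape as blindPoly above)
	blinded := make([]uint64, n+len(blind))
	copy(blinded, poly)
	for i, b := range blind {
		blinded[i] = (blinded[i] + p - b) % p
		blinded[n+i] = (blinded[n+i] + b) % p
	}

	// values agree on the whole domain {1, 4, 16, 13} (4-th roots of unity mod 17)
	for w, k := uint64(1), 0; k < n; w, k = w*4%p, k+1 {
		fmt.Println(evalAt(poly, w) == evalAt(blinded, w)) // true
	}
}
```
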
-// computeLRO extracts the solution l, r, o, and returns it in lagrange form.
+// evaluateLROSmallDomain extracts the solution l, r, o, and returns it in lagrange form.
// solution = [ public | secret | internal ]
-func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (polynomial.Polynomial, polynomial.Polynomial, polynomial.Polynomial) {
+func evaluateLROSmallDomain(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) ([]fr.Element, []fr.Element, []fr.Element) {
- s := int(pk.DomainNum.Cardinality)
+ s := int(pk.Domain[0].Cardinality)
- var l, r, o polynomial.Polynomial
+ var l, r, o []fr.Element
l = make([]fr.Element, s)
r = make([]fr.Element, s)
o = make([]fr.Element, s)
@@ -478,47 +522,43 @@ func computeLRO(spr *cs.SparseR1CS, pk *ProvingKey, solution []fr.Element) (poly
//
// * Z of degree n (domainNum.Cardinality)
// * Z(1)=1
-// * for i>0: Z(u**i) = Pi_{k<i} (l_k+z**k+gamma)*(r_k+u*z**k+gamma)*(o_k+u**2*z**k+gamma) / ((l_k+s1+gamma)*(r_k+s2+gamma)*(o_k+s3+gamma))
+// * for i>0: Z(gⁱ) = Π_{k<i} (l_k+β*gᵏ+γ)*(r_k+β*u*gᵏ+γ)*(o_k+β*u²*gᵏ+γ) / ((l_k+β*s₁(gᵏ)+γ)*(r_k+β*s₂(gᵏ)+γ)*(o_k+β*s₃(gᵏ)+γ))
-	u[0].Mul(&u[0], &pk.DomainNum.Generator) // z**i -> z**i+1
- u[1].Mul(&u[1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- u[2].Mul(&u[2], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
}
})
@@ -528,43 +568,43 @@ func computeBlindedZ(l, r, o polynomial.Polynomial, pk *ProvingKey, gamma fr.Ele
Mul(&z[i], &gInv[i])
}
- pk.DomainNum.FFTInverse(z, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(z, fft.DIF)
fft.BitReverse(z)
- return blindPoly(z, pk.DomainNum.Cardinality, 2)
+ return blindPoly(z, pk.Domain[0].Cardinality, 2)
}
-// evalConstraints computes the evaluation of lL+qrR+qqmL.R+qoO+k on
-// the odd cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateConstraintsDomainBigBitReversed computes the evaluation of qlL+qrR+qmL.R+qoO+k on
+// the big domain coset.
//
// * evalL, evalR, evalO are the evaluation of the blinded solution vectors on odd cosets
// * qk is the completed version of qk, in canonical version
-func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
- var evalQl, evalQr, evalQm, evalQo, evalQk polynomial.Polynomial
+func evaluateConstraintsDomainBigBitReversed(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.Element {
+ var evalQl, evalQr, evalQm, evalQo, evalQk []fr.Element
var wg sync.WaitGroup
wg.Add(4)
go func() {
- evalQl = evaluateHDomain(pk.Ql, &pk.DomainH)
+ evalQl = evaluateDomainBigBitReversed(pk.Ql, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQr = evaluateHDomain(pk.Qr, &pk.DomainH)
+ evalQr = evaluateDomainBigBitReversed(pk.Qr, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQm = evaluateHDomain(pk.Qm, &pk.DomainH)
+ evalQm = evaluateDomainBigBitReversed(pk.Qm, &pk.Domain[1])
wg.Done()
}()
go func() {
- evalQo = evaluateHDomain(pk.Qo, &pk.DomainH)
+ evalQo = evaluateDomainBigBitReversed(pk.Qo, &pk.Domain[1])
wg.Done()
}()
- evalQk = evaluateHDomain(qk, &pk.DomainH)
+ evalQk = evaluateDomainBigBitReversed(qk, &pk.Domain[1])
wg.Wait()
- // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the odd cosets
- // of (Z/8mZ)/(Z/mZ)
+
+ // computes the evaluation of qrR+qlL+qmL.R+qoO+k on the coset of the big domain
utils.Parallelize(len(evalQk), func(start, end int) {
var t0, t1 fr.Element
for i := start; i < end; i++ {
@@ -584,211 +624,154 @@ func evalConstraints(pk *ProvingKey, evalL, evalR, evalO, qk []fr.Element) []fr.
return evalQk
}
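
Pointwise, the accumulated expression is ql·l + qr·r + qm·l·r + qo·o + qk at every evaluation point. A toy gate mod 17 (an addition gate, with made-up witness values):

```go
package main

import "fmt"

func main() {
	const p = 17
	// an addition gate: 1*l + 1*r + 0*l*r - 1*o + 0 = 0
	ql, qr, qm, qo, qk := uint64(1), uint64(1), uint64(0), uint64(p-1), uint64(0)
	l, r := uint64(5), uint64(9)
	o := (l + r) % p // satisfying witness

	gate := (ql*l + qr*r + qm*l*r + qo*o + qk) % p
	fmt.Println(gate == 0) // true: the gate is satisfied
}
```
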
-// evalIDCosets id, uid, u**2id on the odd cosets of (Z/8mZ)/(Z/mZ)
-func evalIDCosets(pk *ProvingKey) (id polynomial.Polynomial) {
-
- id = make([]fr.Element, pk.DomainH.Cardinality)
-
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
- var acc fr.Element
- acc.Exp(pk.DomainH.Generator, new(big.Int).SetInt64(int64(start)))
- for i := start; i < end; i++ {
- id[i].Mul(&acc, &pk.DomainH.FinerGenerator)
- acc.Mul(&acc, &pk.DomainH.Generator)
- }
- })
-
- return id
-}
-
-// evalConstraintOrdering computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the odd
-// cosets of (Z/8mZ)/(Z/mZ), where m=nbConstraints+nbAssertions.
+// evaluateOrderingDomainBigBitReversed computes the evaluation of Z(uX)g1g2g3-Z(X)f1f2f3 on the
+// coset of the big domain.
//
-// * evalZ evaluation of the blinded permutation accumulator polynomial on odd cosets
-// * evalL, evalR, evalO evaluation of the blinded solution vectors on odd cosets
+// * z evaluation of the blinded permutation accumulator polynomial on odd cosets
+// * l, r, o evaluation of the blinded solution vectors on odd cosets
// * gamma randomization
-func evalConstraintOrdering(pk *ProvingKey, evalZ, evalL, evalR, evalO polynomial.Polynomial, gamma fr.Element) polynomial.Polynomial {
+func evaluateOrderingDomainBigBitReversed(pk *ProvingKey, z, l, r, o []fr.Element, beta, gamma fr.Element) []fr.Element {
- // evalutation of ID the odd cosets of (Z/8mZ)/(Z/mZ)
- evalID := evalIDCosets(pk)
+ nbElmts := int(pk.Domain[1].Cardinality)
- // evaluation of z, zu, s1, s2, s3, on the odd cosets of (Z/8mZ)/(Z/mZ)
- var wg sync.WaitGroup
- wg.Add(2)
- var evalS1, evalS2, evalS3 polynomial.Polynomial
- go func() {
- evalS1 = evaluateHDomain(pk.CS1, &pk.DomainH)
- wg.Done()
- }()
- go func() {
- evalS2 = evaluateHDomain(pk.CS2, &pk.DomainH)
- wg.Done()
- }()
- evalS3 = evaluateHDomain(pk.CS3, &pk.DomainH)
- wg.Wait()
+	// computes  z_(uX)*(l(X)+s₁(X)*β+γ)*(r(X)+s₂(X)*β+γ)*(o(X)+s₃(X)*β+γ) - z(X)*(l(X)+X*β+γ)*(r(X)+u*X*β+γ)*(o(X)+u²*X*β+γ)
+ // on the big domain (coset).
+ res := make([]fr.Element, pk.Domain[1].Cardinality)
- // computes Z(uX)g1g2g3l-Z(X)f1f2f3l on the odd cosets of (Z/8mZ)/(Z/mZ)
- res := evalS1 // re use allocated memory for evalS1
- s := uint64(len(evalZ))
- nn := uint64(64 - bits.TrailingZeros64(uint64(s)))
+ nn := uint64(64 - bits.TrailingZeros64(uint64(nbElmts)))
// needed to shift evalZ
- toShift := pk.DomainH.Cardinality / pk.DomainNum.Cardinality
+ toShift := int(pk.Domain[1].Cardinality / pk.Domain[0].Cardinality)
+
+ var cosetShift, cosetShiftSquare fr.Element
+ cosetShift.Set(&pk.Vk.CosetShift)
+ cosetShiftSquare.Square(&pk.Vk.CosetShift)
+
+ utils.Parallelize(int(pk.Domain[1].Cardinality), func(start, end int) {
+
+ var evaluationIDBigDomain fr.Element
+ evaluationIDBigDomain.Exp(pk.Domain[1].Generator, big.NewInt(int64(start))).
+ Mul(&evaluationIDBigDomain, &pk.Domain[1].FrMultiplicativeGen)
- utils.Parallelize(int(pk.DomainH.Cardinality), func(start, end int) {
var f [3]fr.Element
var g [3]fr.Element
- var eID fr.Element
for i := start; i < end; i++ {
- // here we want to left shift evalZ by domainH/domainNum
- // however, evalZ is permuted
- // we take the non permuted index
- // compute the corresponding shift position
- // permute it again
- irev := bits.Reverse64(uint64(i)) >> nn
- eID = evalID[irev]
-
- shiftedZ := bits.Reverse64(uint64((irev+toShift)%s)) >> nn
- //shiftedZ := bits.Reverse64(uint64((irev+4)%s)) >> nn
+ _i := bits.Reverse64(uint64(i)) >> nn
+ _is := bits.Reverse64(uint64((i+toShift)%nbElmts)) >> nn
- f[0].Add(&eID, &evalL[i]).Add(&f[0], &gamma) //l_i+z**i+gamma
- f[1].Mul(&eID, &pk.Vk.Shifter[0])
- f[2].Mul(&eID, &pk.Vk.Shifter[1])
- f[1].Add(&f[1], &evalR[i]).Add(&f[1], &gamma) //r_i+u*z**i+gamma
- f[2].Add(&f[2], &evalO[i]).Add(&f[2], &gamma) //o_i+u**2*z**i+gamma
+			// in what follows gⁱ denotes the i-th element of the chosen coset of the big domain
+ f[0].Mul(&evaluationIDBigDomain, &beta).Add(&f[0], &l[_i]).Add(&f[0], &gamma) //l(gⁱ)+gⁱ*β+γ
+ f[1].Mul(&evaluationIDBigDomain, &cosetShift).Mul(&f[1], &beta).Add(&f[1], &r[_i]).Add(&f[1], &gamma) //r(gⁱ)+u*gⁱ*β+γ
+ f[2].Mul(&evaluationIDBigDomain, &cosetShiftSquare).Mul(&f[2], &beta).Add(&f[2], &o[_i]).Add(&f[2], &gamma) //o(gⁱ)+u²*gⁱ*β+γ
- g[0].Add(&evalL[i], &evalS1[i]).Add(&g[0], &gamma) //l_i+s1+gamma
- g[1].Add(&evalR[i], &evalS2[i]).Add(&g[1], &gamma) //r_i+s2+gamma
- g[2].Add(&evalO[i], &evalS3[i]).Add(&g[2], &gamma) //o_i+s3+gamma
+			g[0].Mul(&pk.EvaluationPermutationBigDomainBitReversed[_i], &beta).Add(&g[0], &l[_i]).Add(&g[0], &gamma)                //l(gⁱ)+s1(gⁱ)*β+γ
+			g[1].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+nbElmts], &beta).Add(&g[1], &r[_i]).Add(&g[1], &gamma)   //r(gⁱ)+s2(gⁱ)*β+γ
+			g[2].Mul(&pk.EvaluationPermutationBigDomainBitReversed[int(_i)+2*nbElmts], &beta).Add(&g[2], &o[_i]).Add(&g[2], &gamma) //o(gⁱ)+s3(gⁱ)*β+γ
- f[0].Mul(&f[0], &f[1]).
- Mul(&f[0], &f[2]).
- Mul(&f[0], &evalZ[i]) // z_i*(l_i+z**i+gamma)*(r_i+u*z**i+gamma)*(o_i+u**2*z**i+gamma)
+			f[0].Mul(&f[0], &f[1]).Mul(&f[0], &f[2]).Mul(&f[0], &z[_i])  // z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
+			g[0].Mul(&g[0], &g[1]).Mul(&g[0], &g[2]).Mul(&g[0], &z[_is]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ)
- g[0].Mul(&g[0], &g[1]).
- Mul(&g[0], &g[2]).
- Mul(&g[0], &evalZ[shiftedZ]) // u*z_i*(l_i+s1+gamma)*(r_i+s2+gamma)*(o_i+s3+gamma)
+			res[_i].Sub(&g[0], &f[0]) // z_(ugⁱ)*(l(gⁱ)+s₁(gⁱ)*β+γ)*(r(gⁱ)+s₂(gⁱ)*β+γ)*(o(gⁱ)+s₃(gⁱ)*β+γ) - z(gⁱ)*(l(gⁱ)+gⁱ*β+γ)*(r(gⁱ)+u*gⁱ*β+γ)*(o(gⁱ)+u²*gⁱ*β+γ)
- res[i].Sub(&g[0], &f[0])
+ evaluationIDBigDomain.Mul(&evaluationIDBigDomain, &pk.Domain[1].Generator) // gⁱ*g
}
})
return res
}
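
The _i and _is computations above are plain index arithmetic on bit-reversed storage: _i is where the value at the i-th coset point is stored, and _is is where the value at that point multiplied by the small-domain generator is stored. A small illustrative helper (not part of the diff, uses math/bits) capturing the rule:

```go
// bitReversedIndex returns the storage index, inside a bit-reversed slice of length n
// (n a power of two), of the evaluation at the point sitting `shift` positions after
// natural-order position i. shift = 0 reproduces _i above; shift = toShift reproduces _is.
func bitReversedIndex(i, shift, n int) uint64 {
	nn := uint64(64 - bits.TrailingZeros64(uint64(n)))
	return bits.Reverse64(uint64((i+shift)%n)) >> nn
}
```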
-// evaluateHDomain evaluates poly (canonical form) of degree m> nn
- // h[i].Mul(&h[i], &_u[irev%4])
- h[i].Mul(&h[i], &_u[irev%toShift])
+
+ _i := bits.Reverse64(i) >> nn
+
+ t.Sub(&evaluationBlindedZDomainBigBitReversed[_i], &one) // evaluates L₁(X)*(Z(X)-1) on a coset of the big domain
+ h[_i].Mul(&startsAtOne[_i], &alpha).Mul(&h[_i], &t).
+ Add(&h[_i], &evaluationConstraintOrderingBitReversed[_i]).
+ Mul(&h[_i], &alpha).
+ Add(&h[_i], &evaluationConstraintsIndBitReversed[_i]).
+ Mul(&h[_i], &evaluationXnMinusOneInverse[i%ratio])
}
})
// put h in canonical form. h is of degree 3*(n+1)+2.
	// using fft.DIT to revert the bit-reverse ordering of h
- pk.DomainH.FFTInverse(h, fft.DIT, 1)
- // fmt.Println("h:")
- // for i := 0; i < len(h); i++ {
- // fmt.Printf("%s\n", h[i].String())
- // }
- // fmt.Println("")
+ pk.Domain[1].FFTInverse(h, fft.DIT, true)
// degree of hi is n+2 because of the blinding
- h1 := h[:pk.DomainNum.Cardinality+2]
- h2 := h[pk.DomainNum.Cardinality+2 : 2*(pk.DomainNum.Cardinality+2)]
- h3 := h[2*(pk.DomainNum.Cardinality+2) : 3*(pk.DomainNum.Cardinality+2)]
+ h1 := h[:pk.Domain[0].Cardinality+2]
+ h2 := h[pk.Domain[0].Cardinality+2 : 2*(pk.Domain[0].Cardinality+2)]
+ h3 := h[2*(pk.Domain[0].Cardinality+2) : 3*(pk.Domain[0].Cardinality+2)]
return h1, h2, h3
@@ -796,78 +779,96 @@ func computeH(pk *ProvingKey, constraintsInd, constraintOrdering, evalBZ polynom
// computeLinearizedPolynomial computes the linearized polynomial in canonical basis.
// The purpose is to commit and open all in one ql, qr, qm, qo, qk.
-// * a, b, c are the evaluation of l, r, o at zeta
-// * z is the permutation polynomial, zu is Z(uX), the shifted version of Z
+// * lZeta, rZeta, oZeta are the evaluation of l, r, o at zeta
+// * z is the permutation polynomial, zu is Z(μX), the shifted version of Z
// * pk is the proving key: the linearized polynomial is a linear combination of ql, qr, qm, qo, qk.
-func computeLinearizedPolynomial(l, r, o, alpha, gamma, zeta, zu fr.Element, z polynomial.Polynomial, pk *ProvingKey) polynomial.Polynomial {
+//
+// The Linearized polynomial is:
+//
+// α²*L₁(ζ)*Z(X)
+// + α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ))
+// + l(ζ)*Ql(X) + l(ζ)r(ζ)*Qm(X) + r(ζ)*Qr(X) + o(ζ)*Qo(X) + Qk(X)
+func computeLinearizedPolynomial(lZeta, rZeta, oZeta, alpha, beta, gamma, zeta, zu fr.Element, blindedZCanonical []fr.Element, pk *ProvingKey) []fr.Element {
// first part: individual constraints
var rl fr.Element
- rl.Mul(&r, &l)
+ rl.Mul(&rZeta, &lZeta)
- // second part: Z(uzeta)(a+s1+gamma)*(b+s2+gamma)*s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
+ // second part:
+ // Z(μζ)(l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*s3(X)-Z(X)(l(ζ)+β*id1(ζ)+γ)*(r(ζ)+β*id2(ζ)+γ)*(o(ζ)+β*id3(ζ)+γ)
var s1, s2 fr.Element
chS1 := make(chan struct{}, 1)
go func() {
- s1 = pk.CS1.Eval(&zeta)
- s1.Add(&s1, &l).Add(&s1, &gamma) // (a+s1+gamma)
+ s1 = eval(pk.S1Canonical, zeta) // s1(ζ)
+ s1.Mul(&s1, &beta).Add(&s1, &lZeta).Add(&s1, &gamma) // (l(ζ)+β*s1(ζ)+γ)
close(chS1)
}()
- t := pk.CS2.Eval(&zeta)
- t.Add(&t, &r).Add(&t, &gamma) // (b+s2+gamma)
+ tmp := eval(pk.S2Canonical, zeta) // s2(ζ)
+ tmp.Mul(&tmp, &beta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*s2(ζ)+γ)
<-chS1
- s1.Mul(&s1, &t). // (a+s1+gamma)*(b+s2+gamma)
- Mul(&s1, &zu) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)
-
- s2.Add(&l, &zeta).Add(&s2, &gamma) // (a+z+gamma)
- t.Mul(&pk.Vk.Shifter[0], &zeta).Add(&t, &r).Add(&t, &gamma) // (b+uz+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)
- t.Mul(&pk.Vk.Shifter[1], &zeta).Add(&t, &o).Add(&t, &gamma) // (o+u**2z+gamma)
- s2.Mul(&s2, &t) // (a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
- s2.Neg(&s2) // -(a+z+gamma)*(b+uz+gamma)*(c+u**2*z+gamma)
-
- // third part L1(zeta)*alpha**2**Z
- var lagrange, one, den, frNbElmt fr.Element
+	s1.Mul(&s1, &tmp).Mul(&s1, &zu).Mul(&s1, &beta)                     // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*β*Z(μζ)
+
+ var uzeta, uuzeta fr.Element
+ uzeta.Mul(&zeta, &pk.Vk.CosetShift)
+ uuzeta.Mul(&uzeta, &pk.Vk.CosetShift)
+
+ s2.Mul(&beta, &zeta).Add(&s2, &lZeta).Add(&s2, &gamma) // (l(ζ)+β*ζ+γ)
+ tmp.Mul(&beta, &uzeta).Add(&tmp, &rZeta).Add(&tmp, &gamma) // (r(ζ)+β*u*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)
+ tmp.Mul(&beta, &uuzeta).Add(&tmp, &oZeta).Add(&tmp, &gamma) // (o(ζ)+β*u²*ζ+γ)
+ s2.Mul(&s2, &tmp) // (l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+ s2.Neg(&s2) // -(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ // third part L₁(ζ)*α²*Z
+ var lagrangeZeta, one, den, frNbElmt fr.Element
one.SetOne()
- nbElmt := int64(pk.DomainNum.Cardinality)
- lagrange.Set(&zeta).
- Exp(lagrange, big.NewInt(nbElmt)).
- Sub(&lagrange, &one)
+ nbElmt := int64(pk.Domain[0].Cardinality)
+ lagrangeZeta.Set(&zeta).
+ Exp(lagrangeZeta, big.NewInt(nbElmt)).
+ Sub(&lagrangeZeta, &one)
frNbElmt.SetUint64(uint64(nbElmt))
den.Sub(&zeta, &one).
- Mul(&den, &frNbElmt).
Inverse(&den)
- lagrange.Mul(&lagrange, &den). // L_0 = 1/m*(zeta**n-1)/(zeta-1)
- Mul(&lagrange, &alpha).
- Mul(&lagrange, &alpha) // alpha**2*L_0
+	lagrangeZeta.Mul(&lagrangeZeta, &den). // L₁ = (ζⁿ-1)/(ζ-1)
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &alpha).
+ Mul(&lagrangeZeta, &pk.Domain[0].CardinalityInv) // (1/n)*α²*L₁(ζ)
- linPol := z.Clone()
+ linPol := make([]fr.Element, len(blindedZCanonical))
+ copy(linPol, blindedZCanonical)
utils.Parallelize(len(linPol), func(start, end int) {
+
var t0, t1 fr.Element
+
for i := start; i < end; i++ {
- linPol[i].Mul(&linPol[i], &s2) // -Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma)
- if i < len(pk.CS3) {
- t0.Mul(&pk.CS3[i], &s1) // (a+s1+gamma)*(b+s2+gamma)*Z(uzeta)*s3(X)
+
+ linPol[i].Mul(&linPol[i], &s2) // -Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ)
+
+ if i < len(pk.S3Canonical) {
+
+ t0.Mul(&pk.S3Canonical[i], &s1) // (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*β*s3(X)
+
linPol[i].Add(&linPol[i], &t0)
}
- linPol[i].Mul(&linPol[i], &alpha) // alpha*( Z(uzeta)*(a+s1+gamma)*(b+s2+gamma)s3(X)-Z(X)(a+zeta+gamma)*(b+uzeta+gamma)*(c+u**2*zeta+gamma) )
+ linPol[i].Mul(&linPol[i], &alpha) // α*( (l(ζ)+β*s1(ζ)+γ)*(r(ζ)+β*s2(ζ)+γ)*Z(μζ)*s3(X) - Z(X)*(l(ζ)+β*ζ+γ)*(r(ζ)+β*u*ζ+γ)*(o(ζ)+β*u²*ζ+γ))
if i < len(pk.Qm) {
- t1.Mul(&pk.Qm[i], &rl) // linPol = lr*Qm
- t0.Mul(&pk.Ql[i], &l)
+
+ t1.Mul(&pk.Qm[i], &rl) // linPol = linPol + l(ζ)r(ζ)*Qm(X)
+ t0.Mul(&pk.Ql[i], &lZeta)
t0.Add(&t0, &t1)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + l(ζ)*Ql(X)
- t0.Mul(&pk.Qr[i], &r)
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr
+ t0.Mul(&pk.Qr[i], &rZeta)
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + r(ζ)*Qr(X)
- t0.Mul(&pk.Qo[i], &o).Add(&t0, &pk.CQk[i])
- linPol[i].Add(&linPol[i], &t0) // linPol = lr*Qm + l*Ql + r*Qr + o*Qo + Qk
+ t0.Mul(&pk.Qo[i], &oZeta).Add(&t0, &pk.CQk[i])
+ linPol[i].Add(&linPol[i], &t0) // linPol = linPol + o(ζ)*Qo(X) + Qk(X)
}
- t0.Mul(&z[i], &lagrange)
+ t0.Mul(&blindedZCanonical[i], &lagrangeZeta)
linPol[i].Add(&linPol[i], &t0) // finish the computation
}
})
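
For reference, the closed form behind lagrangeZeta above is the standard expression of the first Lagrange polynomial of the size-n subgroup, with the 1/n factor supplied by pk.Domain[0].CardinalityInv:

```latex
L_1(X) = \frac{1}{n}\cdot\frac{X^{n}-1}{X-1}
\qquad\Longrightarrow\qquad
\texttt{lagrangeZeta} = \alpha^{2} L_1(\zeta)
  = \alpha^{2}\cdot\frac{1}{n}\cdot\frac{\zeta^{n}-1}{\zeta-1}.
```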
diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl
index 0c326d73a8..25458538e9 100644
--- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl
+++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl
@@ -1,6 +1,5 @@
import (
"errors"
- {{- template "import_polynomial" . }}
{{- template "import_kzg" . }}
{{- template "import_fr" . }}
{{- template "import_fft" . }}
@@ -22,18 +21,21 @@ type ProvingKey struct {
Vk *VerifyingKey
// qr,ql,qm,qo (in canonical basis).
- Ql, Qr, Qm, Qo polynomial.Polynomial
+ Ql, Qr, Qm, Qo []fr.Element
// LQk (CQk) qk in Lagrange basis (canonical basis), prepended with as many zeroes as public inputs.
// Storing LQk in Lagrange basis saves a fft...
- CQk, LQk polynomial.Polynomial
+ CQk, LQk []fr.Element
- // Domains used for the FFTs
- DomainNum, DomainH fft.Domain
+ // Domains used for the FFTs.
+ // Domain[0] = small Domain
+ // Domain[1] = big Domain
+ Domain [2]fft.Domain
+ // Domain[0], Domain[1] fft.Domain
- // s1, s2, s3 (L=Lagrange basis, C=canonical basis)
- LS1, LS2, LS3 polynomial.Polynomial
- CS1, CS2, CS3 polynomial.Polynomial
+ // Permutation polynomials
+ EvaluationPermutationBigDomainBitReversed []fr.Element
+ S1Canonical, S2Canonical, S3Canonical []fr.Element
// position -> permuted position (position in [0,3*sizeSystem-1])
Permutation []int64
@@ -51,13 +53,12 @@ type VerifyingKey struct {
Generator fr.Element
NbPublicVariables uint64
- // shifters for extending the permutation set: from s=<1,z,..,z**n-1>,
- // extended domain = s || shifter[0].s || shifter[1].s
- Shifter [2]fr.Element
-
// Commitment scheme that is used for an instantiation of PLONK
KZGSRS *kzg.SRS
+ // cosetShift generator of the coset on the small domain
+ CosetShift fr.Element
+
// S commitments to S1, S2, S3
S [3]kzg.Digest
@@ -78,37 +79,34 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
// fft domains
sizeSystem := uint64(nbConstraints + spr.NbPublicVariables) // spr.NbPublicVariables is for the placeholder constraints
- pk.DomainNum = *fft.NewDomain(sizeSystem, 0, false)
+ pk.Domain[0] = *fft.NewDomain(sizeSystem)
+ pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen)
// h, the quotient polynomial is of degree 3(n+1)+2, so it's in a 3(n+2) dim vector space,
// the domain is the next power of 2 superior to 3(n+2). 4*domainNum is enough in all cases
// except when n<6.
if sizeSystem < 6 {
- pk.DomainH = *fft.NewDomain(8*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(8 * sizeSystem)
} else {
- pk.DomainH = *fft.NewDomain(4*sizeSystem, 1, false)
+ pk.Domain[1] = *fft.NewDomain(4 * sizeSystem)
}
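
A one-line check of the sizing rule above, with n the (power-of-two) cardinality of the small domain: h has degree 3(n+1)+2 = 3n+5, so the big domain needs at least 3n+6 points, and

```latex
4n \ge 3n + 6 \iff n \ge 6,
```

hence the fallback to 8*sizeSystem for very small systems.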
- vk.Size = pk.DomainNum.Cardinality
+ vk.Size = pk.Domain[0].Cardinality
vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv)
- vk.Generator.Set(&pk.DomainNum.Generator)
+ vk.Generator.Set(&pk.Domain[0].Generator)
vk.NbPublicVariables = uint64(spr.NbPublicVariables)
- // shifters
- vk.Shifter[0].Set(&pk.DomainNum.FinerGenerator)
- vk.Shifter[1].Square(&pk.DomainNum.FinerGenerator)
-
if err := pk.InitKZG(srs); err != nil {
return nil, nil, err
}
	// public polynomials corresponding to constraints: [ placeholders | constraints | assertions ]
- pk.Ql = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qr = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qm = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.Qo = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.CQk = make([]fr.Element, pk.DomainNum.Cardinality)
- pk.LQk = make([]fr.Element, pk.DomainNum.Cardinality)
+ pk.Ql = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qr = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qm = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.Qo = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.CQk = make([]fr.Element, pk.Domain[0].Cardinality)
+ pk.LQk = make([]fr.Element, pk.Domain[0].Cardinality)
	for i := 0; i < spr.NbPublicVariables; i++ { // placeholders (-PUB_INPUT_i + qk_i = 0) TODO should return error if size is inconsistent
pk.Ql[i].SetOne().Neg(&pk.Ql[i])
@@ -116,7 +114,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.Qm[i].SetZero()
pk.Qo[i].SetZero()
pk.CQk[i].SetZero()
- pk.LQk[i].SetZero() // --> to be completed by the prover
+ pk.LQk[i].SetZero() // → to be completed by the prover
}
offset := spr.NbPublicVariables
for i := 0; i < nbConstraints; i++ { // constraints
@@ -130,11 +128,11 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
pk.LQk[offset+i].Set(&spr.Coefficients[spr.Constraints[i].K])
}
- pk.DomainNum.FFTInverse(pk.Ql, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qr, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qm, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.Qo, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CQk, fft.DIF, 0)
+ pk.Domain[0].FFTInverse(pk.Ql, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qr, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qm, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.Qo, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.CQk, fft.DIF)
fft.BitReverse(pk.Ql)
fft.BitReverse(pk.Qr)
fft.BitReverse(pk.Qm)
@@ -145,7 +143,7 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
buildPermutation(spr, &pk)
// set s1, s2, s3
- computeLDE(&pk)
+ ccomputePermutationPolynomials(&pk)
// Commit to the polynomials to set up the verifying key
var err error
@@ -164,13 +162,13 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
if vk.Qk, err = kzg.Commit(pk.CQk, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[0], err = kzg.Commit(pk.CS1, vk.KZGSRS); err != nil {
+ if vk.S[0], err = kzg.Commit(pk.S1Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[1], err = kzg.Commit(pk.CS2, vk.KZGSRS); err != nil {
+ if vk.S[1], err = kzg.Commit(pk.S2Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
- if vk.S[2], err = kzg.Commit(pk.CS3, vk.KZGSRS); err != nil {
+ if vk.S[2], err = kzg.Commit(pk.S3Canonical, vk.KZGSRS); err != nil {
return nil, nil, err
}
@@ -182,18 +180,18 @@ func Setup(spr *cs.SparseR1CS, srs *kzg.SRS) (*ProvingKey, *VerifyingKey, error)
//
// The permutation s is composed of cycles of maximum length such that
//
-// s. (l||r||o) = (l||r||o)
+// s. (l∥r∥o) = (l∥r∥o)
//
-//, where l||r||o is the concatenation of the indices of l, r, o in
+//, where l∥r∥o is the concatenation of the indices of l, r, o in
// ql.l+qr.r+qm.l.r+qo.O+k = 0.
//
// The permutation is encoded as a slice s of size 3*size(l), where the
-// i-th entry of l||r||o is sent to the s[i]-th entry, so it acts on a tab
+// i-th entry of l∥r∥o is sent to the s[i]-th entry, so it acts on a tab
// like this: for i in tab: tab[i] = tab[permutation[i]]
func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
nbVariables := spr.NbInternalVariables + spr.NbPublicVariables + spr.NbSecretVariables
- sizeSolution := int(pk.DomainNum.Cardinality)
+ sizeSolution := int(pk.Domain[0].Cardinality)
// init permutation
pk.Permutation = make([]int64, 3*sizeSolution)
@@ -238,60 +236,70 @@ func buildPermutation(spr *cs.SparseR1CS, pk *ProvingKey) {
}
}
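
A purely illustrative example of the encoding documented above (the wire layout is hypothetical, not taken from the diff): with two constraints whose wires are l = [a, b], r = [b, a], o = [a, b], the concatenation l∥r∥o = [a, b, b, a, a, b] splits into the cycles a: 0 → 3 → 4 → 0 and b: 1 → 2 → 5 → 1, and one valid encoding, where entry i points to the next position of the same wire, is:

```go
tab := []string{"a", "b", "b", "a", "a", "b"} // l∥r∥o
permutation := []int64{3, 2, 5, 4, 0, 1}
for i := range tab {
	if tab[i] != tab[permutation[i]] { // s.(l∥r∥o) = (l∥r∥o)
		panic("not a wire-preserving permutation")
	}
}
```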
-// computeLDE computes the LDE (Lagrange basis) of the permutations
+// ccomputePermutationPolynomials computes the LDE (Lagrange basis) of the permutations
// s1, s2, s3.
//
-// ex: z gen of Z/mZ, u gen of Z/8mZ, then
-//
// 1 z .. z**n-1 | u uz .. u*z**n-1 | u**2 u**2*z .. u**2*z**n-1 |
// |
// | Permutation
// s11 s12 .. s1n s21 s22 .. s2n s31 s32 .. s3n v
// \---------------/ \--------------------/ \------------------------/
// s1 (LDE) s2 (LDE) s3 (LDE)
-func computeLDE(pk *ProvingKey) {
+func ccomputePermutationPolynomials(pk *ProvingKey) {
- nbElmt := int(pk.DomainNum.Cardinality)
+ nbElmts := int(pk.Domain[0].Cardinality)
- // sID = [1,z,..,z**n-1,u,uz,..,uz**n-1,u**2,u**2.z,..,u**2.z**n-1]
- sID := make([]fr.Element, 3*nbElmt)
- sID[0].SetOne()
- sID[nbElmt].Set(&pk.DomainNum.FinerGenerator)
- sID[2*nbElmt].Square(&pk.DomainNum.FinerGenerator)
-
- for i := 1; i < nbElmt; i++ {
- sID[i].Mul(&sID[i-1], &pk.DomainNum.Generator) // z**i -> z**i+1
- sID[i+nbElmt].Mul(&sID[nbElmt+i-1], &pk.DomainNum.Generator) // u*z**i -> u*z**i+1
- sID[i+2*nbElmt].Mul(&sID[2*nbElmt+i-1], &pk.DomainNum.Generator) // u**2*z**i -> u**2*z**i+1
- }
+ // Lagrange form of ID
+ evaluationIDSmallDomain := getIDSmallDomain(&pk.Domain[0])
// Lagrange form of S1, S2, S3
- pk.LS1 = make(polynomial.Polynomial, nbElmt)
- pk.LS2 = make(polynomial.Polynomial, nbElmt)
- pk.LS3 = make(polynomial.Polynomial, nbElmt)
- for i := 0; i < nbElmt; i++ {
- pk.LS1[i].Set(&sID[pk.Permutation[i]])
- pk.LS2[i].Set(&sID[pk.Permutation[nbElmt+i]])
- pk.LS3[i].Set(&sID[pk.Permutation[2*nbElmt+i]])
+ pk.S1Canonical = make([]fr.Element, nbElmts)
+ pk.S2Canonical = make([]fr.Element, nbElmts)
+ pk.S3Canonical = make([]fr.Element, nbElmts)
+ for i := 0; i < nbElmts; i++ {
+ pk.S1Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[i]])
+ pk.S2Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[nbElmts+i]])
+ pk.S3Canonical[i].Set(&evaluationIDSmallDomain[pk.Permutation[2*nbElmts+i]])
}
// Canonical form of S1, S2, S3
- pk.CS1 = make(polynomial.Polynomial, nbElmt)
- pk.CS2 = make(polynomial.Polynomial, nbElmt)
- pk.CS3 = make(polynomial.Polynomial, nbElmt)
- copy(pk.CS1, pk.LS1)
- copy(pk.CS2, pk.LS2)
- copy(pk.CS3, pk.LS3)
- pk.DomainNum.FFTInverse(pk.CS1, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS2, fft.DIF, 0)
- pk.DomainNum.FFTInverse(pk.CS3, fft.DIF, 0)
- fft.BitReverse(pk.CS1)
- fft.BitReverse(pk.CS2)
- fft.BitReverse(pk.CS3)
+ pk.Domain[0].FFTInverse(pk.S1Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S2Canonical, fft.DIF)
+ pk.Domain[0].FFTInverse(pk.S3Canonical, fft.DIF)
+ fft.BitReverse(pk.S1Canonical)
+ fft.BitReverse(pk.S2Canonical)
+ fft.BitReverse(pk.S3Canonical)
+
+ // evaluation of permutation on the big domain
+ pk.EvaluationPermutationBigDomainBitReversed = make([]fr.Element, 3*pk.Domain[1].Cardinality)
+ copy(pk.EvaluationPermutationBigDomainBitReversed, pk.S1Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:], pk.S2Canonical)
+ copy(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], pk.S3Canonical)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[:pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[pk.Domain[1].Cardinality:2*pk.Domain[1].Cardinality], fft.DIF, true)
+ pk.Domain[1].FFT(pk.EvaluationPermutationBigDomainBitReversed[2*pk.Domain[1].Cardinality:], fft.DIF, true)
+
+}
+
+// getIDSmallDomain returns the Lagrange form of ID on the small domain
+func getIDSmallDomain(domain *fft.Domain) []fr.Element {
+
+ res := make([]fr.Element, 3*domain.Cardinality)
+
+ res[0].SetOne()
+ res[domain.Cardinality].Set(&domain.FrMultiplicativeGen)
+ res[2*domain.Cardinality].Square(&domain.FrMultiplicativeGen)
+ for i := uint64(1); i < domain.Cardinality; i++ {
+ res[i].Mul(&res[i-1], &domain.Generator)
+ res[domain.Cardinality+i].Mul(&res[domain.Cardinality+i-1], &domain.Generator)
+ res[2*domain.Cardinality+i].Mul(&res[2*domain.Cardinality+i-1], &domain.Generator)
+ }
+
+ return res
}
-// InitKZG inits pk.Vk.KZG using pk.DomainNum cardinality and provided SRS
+// InitKZG inits pk.Vk.KZG using pk.Domain[0] cardinality and provided SRS
//
// This should be used after deserializing a ProvingKey
// as pk.Vk.KZG is NOT serialized
@@ -324,4 +332,4 @@ func (vk *VerifyingKey) NbPublicWitness() int {
// VerifyingKey returns pk.Vk
func (pk *ProvingKey) VerifyingKey() interface{} {
return pk.Vk
-}
\ No newline at end of file
+}
diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl
index a38645edc3..09880d83ee 100644
--- a/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl
+++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.verify.go.tmpl
@@ -22,7 +22,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }}
hFunc := sha256.New()
// transcript to derive the challenge
- fs := fiatshamir.NewTranscript(hFunc, "gamma", "alpha", "zeta")
+ fs := fiatshamir.NewTranscript(hFunc, "gamma", "beta", "alpha", "zeta")
// derive gamma from Comm(l), Comm(r), Comm(o)
gamma, err := deriveRandomness(&fs, "gamma", &proof.LRO[0], &proof.LRO[1], &proof.LRO[2])
@@ -30,6 +30,12 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }}
return err
}
+ // derive beta from Comm(l), Comm(r), Comm(o)
+ beta, err := deriveRandomness(&fs, "beta")
+ if err != nil {
+ return err
+ }
+
// derive alpha from Comm(l), Comm(r), Comm(o), Com(Z)
alpha, err := deriveRandomness(&fs, "alpha", &proof.Z)
if err != nil {
@@ -42,7 +48,7 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }}
return err
}
- // evaluation of Z=X**m-1 at zeta
+	// evaluation of Z=Xⁿ-1 at ζ
var zetaPowerM, zzeta fr.Element
var bExpo big.Int
one := fr.One()
@@ -50,20 +56,20 @@ func Verify(proof *Proof, vk *VerifyingKey, publicWitness {{ toLower .CurveID }}
zetaPowerM.Exp(zeta, &bExpo)
zzeta.Sub(&zetaPowerM, &one)
- // ccompute PI = Sum_i if the circuits contains maps or slices
+// this is actually a shallow copy → if the circuits contains maps or slices
// only the reference is copied.
func ShallowClone(circuit frontend.Circuit) frontend.Circuit {
diff --git a/std/algebra/fields_bls12377/e12.go b/std/algebra/fields_bls12377/e12.go
index d63f963f64..af387437e3 100644
--- a/std/algebra/fields_bls12377/e12.go
+++ b/std/algebra/fields_bls12377/e12.go
@@ -163,25 +163,25 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12, ext Extension)
var t [7]E2
- // t0 = g1^2
+ // t0 = g1²
t[0].Square(api, x.C0.B1, ext)
- // t1 = g5^2
+ // t1 = g5²
t[1].Square(api, x.C1.B2, ext)
// t5 = g1 + g5
t[5].Add(api, x.C0.B1, x.C1.B2)
- // t2 = (g1 + g5)^2
+ // t2 = (g1 + g5)²
t[2].Square(api, t[5], ext)
- // t3 = g1^2 + g5^2
+ // t3 = g1² + g5²
t[3].Add(api, t[0], t[1])
// t5 = 2 * g1 * g5
t[5].Sub(api, t[2], t[3])
// t6 = g3 + g2
t[6].Add(api, x.C1.B0, x.C0.B2)
- // t3 = (g3 + g2)^2
+ // t3 = (g3 + g2)²
t[3].Square(api, t[6], ext)
- // t2 = g3^2
+ // t2 = g3²
t[2].Square(api, x.C1.B0, ext)
// t6 = 2 * nr * g1 * g5
@@ -192,33 +192,33 @@ func (e *E12) CyclotomicSquareCompressed(api frontend.API, x E12, ext Extension)
// z3 = 6 * nr * g1 * g5 + 2 * g3
e.C1.B0.Add(api, t[5], t[6])
- // t4 = nr * g5^2
+ // t4 = nr * g5²
t[4].MulByNonResidue(api, t[1], ext)
- // t5 = nr * g5^2 + g1^2
+ // t5 = nr * g5² + g1²
t[5].Add(api, t[0], t[4])
- // t6 = nr * g5^2 + g1^2 - g2
+ // t6 = nr * g5² + g1² - g2
t[6].Sub(api, t[5], x.C0.B2)
- // t1 = g2^2
+ // t1 = g2²
t[1].Square(api, x.C0.B2, ext)
- // t6 = 2 * nr * g5^2 + 2 * g1^2 - 2*g2
+ // t6 = 2 * nr * g5² + 2 * g1² - 2*g2
t[6].Double(api, t[6])
- // z2 = 3 * nr * g5^2 + 3 * g1^2 - 2*g2
+ // z2 = 3 * nr * g5² + 3 * g1² - 2*g2
e.C0.B2.Add(api, t[6], t[5])
- // t4 = nr * g2^2
+ // t4 = nr * g2²
t[4].MulByNonResidue(api, t[1], ext)
- // t5 = g3^2 + nr * g2^2
+ // t5 = g3² + nr * g2²
t[5].Add(api, t[2], t[4])
- // t6 = g3^2 + nr * g2^2 - g1
+ // t6 = g3² + nr * g2² - g1
t[6].Sub(api, t[5], x.C0.B1)
- // t6 = 2 * g3^2 + 2 * nr * g2^2 - 2 * g1
+ // t6 = 2 * g3² + 2 * nr * g2² - 2 * g1
t[6].Double(api, t[6])
- // z1 = 3 * g3^2 + 3 * nr * g2^2 - 2 * g1
+ // z1 = 3 * g3² + 3 * nr * g2² - 2 * g1
e.C0.B1.Add(api, t[6], t[5])
- // t0 = g2^2 + g3^2
+ // t0 = g2² + g3²
t[0].Add(api, t[2], t[1])
// t5 = 2 * g3 * g2
t[5].Sub(api, t[3], t[0])
@@ -239,13 +239,13 @@ func (e *E12) Decompress(api frontend.API, x E12, ext Extension) *E12 {
var one E2
one.SetOne(api)
- // t0 = g1^2
+ // t0 = g1²
t[0].Square(api, x.C0.B1, ext)
- // t1 = 3 * g1^2 - 2 * g2
+ // t1 = 3 * g1² - 2 * g2
t[1].Sub(api, t[0], x.C0.B2).
Double(api, t[1]).
Add(api, t[1], t[0])
- // t0 = E * g5^2 + t1
+ // t0 = E * g5² + t1
t[2].Square(api, x.C1.B2, ext)
t[0].MulByNonResidue(api, t[2], ext).
Add(api, t[0], t[1])
@@ -258,14 +258,14 @@ func (e *E12) Decompress(api frontend.API, x E12, ext Extension) *E12 {
// t1 = g2 * g1
t[1].Mul(api, x.C0.B2, x.C0.B1, ext)
- // t2 = 2 * g4^2 - 3 * g2 * g1
+ // t2 = 2 * g4² - 3 * g2 * g1
t[2].Square(api, e.C1.B1, ext).
Sub(api, t[2], t[1]).
Double(api, t[2]).
Sub(api, t[2], t[1])
// t1 = g3 * g5
t[1].Mul(api, x.C1.B0, x.C1.B2, ext)
- // c_0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1
+ // c₀ = E * (2 * g4² + g3 * g5 - 3 * g2 * g1) + 1
t[2].Add(api, t[2], t[1])
e.C0.B0.MulByNonResidue(api, t[2], ext).
Add(api, e.C0.B0, one)
@@ -296,9 +296,9 @@ func (e *E12) CyclotomicSquare(api frontend.API, x E12, ext Extension) *E12 {
t[5].Square(api, x.C0.B1, ext)
t[8].Add(api, x.C1.B2, x.C0.B1).Square(api, t[8], ext).Sub(api, t[8], t[4]).Sub(api, t[8], t[5]).MulByNonResidue(api, t[8], ext) // 2*x5*x1*u
- t[0].MulByNonResidue(api, t[0], ext).Add(api, t[0], t[1]) // x4^2*u + x0^2
- t[2].MulByNonResidue(api, t[2], ext).Add(api, t[2], t[3]) // x2^2*u + x3^2
- t[4].MulByNonResidue(api, t[4], ext).Add(api, t[4], t[5]) // x5^2*u + x1^2
+ t[0].MulByNonResidue(api, t[0], ext).Add(api, t[0], t[1]) // x4²*u + x0²
+ t[2].MulByNonResidue(api, t[2], ext).Add(api, t[2], t[3]) // x2²*u + x3²
+ t[4].MulByNonResidue(api, t[4], ext).Add(api, t[4], t[5]) // x5²*u + x1²
e.C0.B0.Sub(api, t[0], x.C0.B0).Add(api, e.C0.B0, e.C0.B0).Add(api, e.C0.B0, t[0])
e.C0.B1.Sub(api, t[2], x.C0.B1).Add(api, e.C0.B1, e.C0.B1).Add(api, e.C0.B1, t[2])
diff --git a/std/algebra/fields_bls12377/e6.go b/std/algebra/fields_bls12377/e6.go
index 610caec82e..caa171bbbb 100644
--- a/std/algebra/fields_bls12377/e6.go
+++ b/std/algebra/fields_bls12377/e6.go
@@ -109,7 +109,7 @@ func (e *E6) MulByFp2(api frontend.API, e1 E6, e2 E2, ext Extension) *E6 {
return e
}
-// MulByNonResidue multiplies e by the imaginary elmt of Fp6 (noted a+bV+cV where V**3 in F^2)
+// MulByNonResidue multiplies e by the imaginary elmt of Fp6 (noted a+bV+cV where V**3 in F²)
func (e *E6) MulByNonResidue(api frontend.API, e1 E6, ext Extension) *E6 {
res := E6{}
res.B0.MulByNonResidue(api, e1.B2, ext)
diff --git a/std/algebra/fields_bls24315/e12.go b/std/algebra/fields_bls24315/e12.go
index a0d087eafe..08b8ea75a7 100644
--- a/std/algebra/fields_bls24315/e12.go
+++ b/std/algebra/fields_bls24315/e12.go
@@ -109,7 +109,7 @@ func (e *E12) MulByFp2(api frontend.API, e1 E12, e2 E4, ext Extension) *E12 {
return e
}
-// MulByNonResidue multiplies e by the imaginary elmt of Fp12 (noted a+bV+cV where V**3 in F^2)
+// MulByNonResidue multiplies e by the imaginary elmt of Fp12 (noted a+bV+cV where V**3 in F²)
func (e *E12) MulByNonResidue(api frontend.API, e1 E12, ext Extension) *E12 {
res := E12{}
res.C0.MulByNonResidue(api, e1.C2, ext)
diff --git a/std/algebra/fields_bls24315/e24.go b/std/algebra/fields_bls24315/e24.go
index 6b12d58c6b..aa13c3ddde 100644
--- a/std/algebra/fields_bls24315/e24.go
+++ b/std/algebra/fields_bls24315/e24.go
@@ -163,25 +163,25 @@ func (e *E24) Square(api frontend.API, x E24, ext Extension) *E24 {
func (e *E24) CyclotomicSquareCompressed(api frontend.API, x E24, ext Extension) *E24 {
var t [7]E4
- // t0 = g1^2
+ // t0 = g1²
t[0].Square(api, x.D0.C1, ext)
- // t1 = g5^2
+ // t1 = g5²
t[1].Square(api, x.D1.C2, ext)
// t5 = g1 + g5
t[5].Add(api, x.D0.C1, x.D1.C2)
- // t2 = (g1 + g5)^2
+ // t2 = (g1 + g5)²
t[2].Square(api, t[5], ext)
- // t3 = g1^2 + g5^2
+ // t3 = g1² + g5²
t[3].Add(api, t[0], t[1])
// t5 = 2 * g1 * g5
t[5].Sub(api, t[2], t[3])
// t6 = g3 + g2
t[6].Add(api, x.D1.C0, x.D0.C2)
- // t3 = (g3 + g2)^2
+ // t3 = (g3 + g2)²
t[3].Square(api, t[6], ext)
- // t2 = g3^2
+ // t2 = g3²
t[2].Square(api, x.D1.C0, ext)
// t6 = 2 * nr * g1 * g5
@@ -192,33 +192,33 @@ func (e *E24) CyclotomicSquareCompressed(api frontend.API, x E24, ext Extension)
// z3 = 6 * nr * g1 * g5 + 2 * g3
e.D1.C0.Add(api, t[5], t[6])
- // t4 = nr * g5^2
+ // t4 = nr * g5²
t[4].MulByNonResidue(api, t[1], ext)
- // t5 = nr * g5^2 + g1^2
+ // t5 = nr * g5² + g1²
t[5].Add(api, t[0], t[4])
- // t6 = nr * g5^2 + g1^2 - g2
+ // t6 = nr * g5² + g1² - g2
t[6].Sub(api, t[5], x.D0.C2)
- // t1 = g2^2
+ // t1 = g2²
t[1].Square(api, x.D0.C2, ext)
- // t6 = 2 * nr * g5^2 + 2 * g1^2 - 2*g2
+ // t6 = 2 * nr * g5² + 2 * g1² - 2*g2
t[6].Double(api, t[6])
- // z2 = 3 * nr * g5^2 + 3 * g1^2 - 2*g2
+ // z2 = 3 * nr * g5² + 3 * g1² - 2*g2
e.D0.C2.Add(api, t[6], t[5])
- // t4 = nr * g2^2
+ // t4 = nr * g2²
t[4].MulByNonResidue(api, t[1], ext)
- // t5 = g3^2 + nr * g2^2
+ // t5 = g3² + nr * g2²
t[5].Add(api, t[2], t[4])
- // t6 = g3^2 + nr * g2^2 - g1
+ // t6 = g3² + nr * g2² - g1
t[6].Sub(api, t[5], x.D0.C1)
- // t6 = 2 * g3^2 + 2 * nr * g2^2 - 2 * g1
+ // t6 = 2 * g3² + 2 * nr * g2² - 2 * g1
t[6].Double(api, t[6])
- // z1 = 3 * g3^2 + 3 * nr * g2^2 - 2 * g1
+ // z1 = 3 * g3² + 3 * nr * g2² - 2 * g1
e.D0.C1.Add(api, t[6], t[5])
- // t0 = g2^2 + g3^2
+ // t0 = g2² + g3²
t[0].Add(api, t[2], t[1])
// t5 = 2 * g3 * g2
t[5].Sub(api, t[3], t[0])
@@ -239,13 +239,13 @@ func (e *E24) Decompress(api frontend.API, x E24, ext Extension) *E24 {
var one E4
one.SetOne(api)
- // t0 = g1^2
+ // t0 = g1²
t[0].Square(api, x.D0.C1, ext)
- // t1 = 3 * g1^2 - 2 * g2
+ // t1 = 3 * g1² - 2 * g2
t[1].Sub(api, t[0], x.D0.C2).
Double(api, t[1]).
Add(api, t[1], t[0])
- // t0 = E * g5^2 + t1
+ // t0 = E * g5² + t1
t[2].Square(api, x.D1.C2, ext)
t[0].MulByNonResidue(api, t[2], ext).
Add(api, t[0], t[1])
@@ -258,14 +258,14 @@ func (e *E24) Decompress(api frontend.API, x E24, ext Extension) *E24 {
// t1 = g2 * g1
t[1].Mul(api, x.D0.C2, x.D0.C1, ext)
- // t2 = 2 * g4^2 - 3 * g2 * g1
+ // t2 = 2 * g4² - 3 * g2 * g1
t[2].Square(api, e.D1.C1, ext).
Sub(api, t[2], t[1]).
Double(api, t[2]).
Sub(api, t[2], t[1])
// t1 = g3 * g5
t[1].Mul(api, x.D1.C0, x.D1.C2, ext)
- // c_0 = E * (2 * g4^2 + g3 * g5 - 3 * g2 * g1) + 1
+ // c₀ = E * (2 * g4² + g3 * g5 - 3 * g2 * g1) + 1
t[2].Add(api, t[2], t[1])
e.D0.C0.MulByNonResidue(api, t[2], ext).
Add(api, e.D0.C0, one)
@@ -474,7 +474,7 @@ func (e *E24) FinalExponentiation(api frontend.API, e1 E24, genT uint64, ext Ext
// Daiki Hayashida and Kenichiro Hayasaka
// and Tadanori Teruya
// https://eprint.iacr.org/2020/875.pdf
- // 3*Phi_24(api, p)/r = (api, u-1)^2 * (api, u+p) * (api, u^2+p^2) * (api, u^4+p^4-1) + 3
+	// 3*Phi_24(p)/r = (u-1)² * (u+p) * (u²+p²) * (u⁴+p⁴-1) + 3
t[0].CyclotomicSquare(api, result, ext)
t[1].Expt(api, result, genT, ext)
t[2].Conjugate(api, result)
diff --git a/std/algebra/sw_bls12377/g1.go b/std/algebra/sw_bls12377/g1.go
index b38735f94d..0e33212b96 100644
--- a/std/algebra/sw_bls12377/g1.go
+++ b/std/algebra/sw_bls12377/g1.go
@@ -137,7 +137,7 @@ func (p *G1Jac) DoubleAssign(api frontend.API) *G1Jac {
S = api.Sub(S, XX)
S = api.Sub(S, YYYY)
S = api.Add(S, S)
- M = api.Mul(XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0)
+ M = api.Mul(XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0)
p.Z = api.Add(p.Z, p.Y)
p.Z = api.Mul(p.Z, p.Z)
p.Z = api.Sub(p.Z, YY)
diff --git a/std/algebra/sw_bls12377/g2.go b/std/algebra/sw_bls12377/g2.go
index 0aaa4ac03f..8014f9494c 100644
--- a/std/algebra/sw_bls12377/g2.go
+++ b/std/algebra/sw_bls12377/g2.go
@@ -183,7 +183,7 @@ func (p *G2Jac) Double(api frontend.API, p1 *G2Jac, ext fields_bls12377.Extensio
S.Sub(api, S, XX)
S.Sub(api, S, YYYY)
S.Add(api, S, S)
- M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0)
+ M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0)
p.Z.Add(api, p.Z, p.Y)
p.Z.Square(api, p.Z, ext)
p.Z.Sub(api, p.Z, YY)
diff --git a/std/algebra/sw_bls24315/g1.go b/std/algebra/sw_bls24315/g1.go
index 10bb09cdab..27560ee675 100644
--- a/std/algebra/sw_bls24315/g1.go
+++ b/std/algebra/sw_bls24315/g1.go
@@ -137,7 +137,7 @@ func (p *G1Jac) DoubleAssign(api frontend.API) *G1Jac {
S = api.Sub(S, XX)
S = api.Sub(S, YYYY)
S = api.Add(S, S)
- M = api.Mul(XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0)
+ M = api.Mul(XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0)
p.Z = api.Add(p.Z, p.Y)
p.Z = api.Mul(p.Z, p.Z)
p.Z = api.Sub(p.Z, YY)
diff --git a/std/algebra/sw_bls24315/g2.go b/std/algebra/sw_bls24315/g2.go
index 1dd7154686..07e020e1b4 100644
--- a/std/algebra/sw_bls24315/g2.go
+++ b/std/algebra/sw_bls24315/g2.go
@@ -183,7 +183,7 @@ func (p *G2Jac) Double(api frontend.API, p1 *G2Jac, ext fields_bls24315.Extensio
S.Sub(api, S, XX)
S.Sub(api, S, YYYY)
S.Add(api, S, S)
- M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ^2, here a=0 (we suppose sw has j invariant 0)
+ M.MulByFp(api, XX, 3) // M = 3*XX+a*ZZ², here a=0 (we suppose sw has j invariant 0)
p.Z.Add(api, p.Z, p.Y)
p.Z.Square(api, p.Z, ext)
p.Z.Sub(api, p.Z, YY)
diff --git a/std/algebra/twistededwards/bandersnatch/curve.go b/std/algebra/twistededwards/bandersnatch/curve.go
index 15caf30980..0cfb3c071e 100644
--- a/std/algebra/twistededwards/bandersnatch/curve.go
+++ b/std/algebra/twistededwards/bandersnatch/curve.go
@@ -25,10 +25,16 @@ import (
"github.com/consensys/gnark/internal/utils"
)
+// Coordinates of a point on a twisted Edwards curve
+type Coord struct {
+ X, Y big.Int
+}
+
// EdCurve stores the info on the chosen edwards curve
type EdCurve struct {
- A, D, Cofactor, Order, BaseX, BaseY big.Int
- ID ecc.ID
+ A, D, Cofactor, Order big.Int
+ Base Coord
+ ID ecc.ID
}
var constructors map[ecc.ID]func() EdCurve
@@ -60,9 +66,11 @@ func newBandersnatch() EdCurve {
D: utils.FromInterface(edcurve.D),
Cofactor: utils.FromInterface(edcurve.Cofactor),
Order: utils.FromInterface(edcurve.Order),
- BaseX: utils.FromInterface(edcurve.Base.X),
- BaseY: utils.FromInterface(edcurve.Base.Y),
- ID: ecc.BLS12_381,
+ Base: Coord{
+ X: utils.FromInterface(edcurve.Base.X),
+ Y: utils.FromInterface(edcurve.Base.Y),
+ },
+ ID: ecc.BLS12_381,
}
}
diff --git a/std/algebra/twistededwards/bandersnatch/point.go b/std/algebra/twistededwards/bandersnatch/point.go
index ddfb8ac52a..071f2762de 100644
--- a/std/algebra/twistededwards/bandersnatch/point.go
+++ b/std/algebra/twistededwards/bandersnatch/point.go
@@ -27,8 +27,15 @@ type Point struct {
X, Y frontend.Variable
}
+// Neg computes the negative of a point in SNARK coordinates
+func (p *Point) Neg(api frontend.API, p1 *Point) *Point {
+ p.X = api.Neg(p1.X)
+ p.Y = p1.Y
+ return p
+}
+
// MustBeOnCurve checks if a point is on the reduced twisted Edwards curve
-// a*x^2 + y^2 = 1 + d*x^2*y^2.
+// a*x² + y² = 1 + d*x²*y².
func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) {
one := big.NewInt(1)
@@ -46,34 +53,9 @@ func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) {
}
-// AddFixedPoint Adds two points, among which is one fixed point (the base), on a twisted edwards curve (eg jubjub)
-// p1, base, ecurve are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve
-func (p *Point) AddFixedPoint(api frontend.API, p1 *Point /*basex*/, x /*basey*/, y interface{}, curve EdCurve) *Point {
-
- // https://eprint.iacr.org/2008/013.pdf
-
- n11 := api.Mul(p1.X, y)
- n12 := api.Mul(p1.Y, x)
- n1 := api.Add(n11, n12)
-
- n21 := api.Mul(p1.Y, y)
- n22 := api.Mul(p1.X, x)
- an22 := api.Mul(n22, &curve.A)
- n2 := api.Sub(n21, an22)
-
- d11 := api.Mul(curve.D, n11, n12)
- d1 := api.Add(1, d11)
- d2 := api.Sub(1, d11)
-
- p.X = api.DivUnchecked(n1, d1)
- p.Y = api.DivUnchecked(n2, d2)
-
- return p
-}
-
-// AddGeneric Adds two points on a twisted edwards curve (eg jubjub)
+// Add adds two points on a twisted edwards curve (eg jubjub)
// p1, p2, curve are respectively: the two points to add and the parameters of the twisted edwards curve
-func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Point {
+func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point {
// https://eprint.iacr.org/2008/013.pdf
@@ -103,14 +85,12 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point {
u := api.Mul(p1.X, p1.Y)
v := api.Mul(p1.X, p1.X)
w := api.Mul(p1.Y, p1.Y)
- z := api.Mul(v, w)
n1 := api.Mul(2, u)
av := api.Mul(v, &curve.A)
n2 := api.Sub(w, av)
- d := api.Mul(z, curve.D)
- d1 := api.Add(1, d)
- d2 := api.Sub(1, d)
+ d1 := api.Add(w, av)
+ d2 := api.Sub(2, d1)
p.X = api.DivUnchecked(n1, d1)
p.Y = api.DivUnchecked(n2, d2)
@@ -118,55 +98,72 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point {
return p
}
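
The rewritten Double above no longer uses the d coefficient: on a twisted Edwards curve a*x² + y² = 1 + d*x²*y², the denominators of the doubling formulas can be expressed without d, which is what d1 = w + a*v and d2 = 2 - d1 implement:

```latex
x_3 = \frac{2xy}{a x^{2} + y^{2}},
\qquad
y_3 = \frac{y^{2} - a x^{2}}{2 - (a x^{2} + y^{2})},
\qquad\text{since } 1 + d x^{2} y^{2} = a x^{2} + y^{2} \text{ on the curve.}
```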
-// ScalarMulNonFixedBase computes the scalar multiplication of a point on a twisted Edwards curve
+// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve
// p1: base point (as snark point)
// curve: parameters of the Edwards curve
// scal: scalar as a SNARK constraint
// Standard left to right double and add
-func (p *Point) ScalarMulNonFixedBase(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point {
+func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point {
// first unpack the scalar
b := api.ToBinary(scalar)
- res := Point{
- 0,
- 1,
+ res := Point{}
+ tmp := Point{}
+ A := Point{}
+ B := Point{}
+
+ A.Double(api, p1, curve)
+ B.Add(api, &A, p1, curve)
+
+ n := len(b) - 1
+ res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X)
+ res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y)
+
+ for i := n - 2; i >= 1; i -= 2 {
+ res.Double(api, &res, curve).
+ Double(api, &res, curve)
+ tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X)
+ tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y)
+ res.Add(api, &res, &tmp, curve)
}
- for i := len(b) - 1; i >= 0; i-- {
+ if n%2 == 0 {
res.Double(api, &res, curve)
- tmp := Point{}
- tmp.AddGeneric(api, &res, p1, curve)
- res.X = api.Select(b[i], tmp.X, res.X)
- res.Y = api.Select(b[i], tmp.Y, res.Y)
+ tmp.Add(api, &res, p1, curve)
+ res.X = api.Select(b[0], tmp.X, res.X)
+ res.Y = api.Select(b[0], tmp.Y, res.Y)
}
p.X = res.X
p.Y = res.Y
+
return p
}
-// ScalarMulFixedBase computes the scalar multiplication of a point on a twisted Edwards curve
-// x, y: coordinates of the base point
-// curve: parameters of the Edwards curve
-// scal: scalar as a SNARK constraint
-// Standard left to right double and add
-func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar frontend.Variable, curve EdCurve) *Point {
+// DoubleBaseScalarMul computes s1*P1+s2*P2
+// where P1 and P2 are points on a twisted Edwards curve
+// and s1, s2 scalars.
+func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point {
- // first unpack the scalar
- b := api.ToBinary(scalar)
+ // first unpack the scalars
+ b1 := api.ToBinary(s1)
+ b2 := api.ToBinary(s2)
- res := Point{
- 0,
- 1,
- }
+ res := Point{}
+ tmp := Point{}
+ sum := Point{}
+ sum.Add(api, p1, p2, curve)
+
+ n := len(b1)
+ res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X)
+ res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y)
- for i := len(b) - 1; i >= 0; i-- {
+ for i := n - 2; i >= 0; i-- {
res.Double(api, &res, curve)
- tmp := Point{}
- tmp.AddFixedPoint(api, &res, x, y, curve)
- res.X = api.Select(b[i], tmp.X, res.X)
- res.Y = api.Select(b[i], tmp.Y, res.Y)
+ tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X)
+ tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y)
+ res.Add(api, &res, &tmp, curve)
}
p.X = res.X
@@ -174,10 +171,3 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr
return p
}
-
-// Neg computes the negative of a point in SNARK coordinates
-func (p *Point) Neg(api frontend.API, p1 *Point) *Point {
- p.X = api.Neg(p1.X)
- p.Y = p1.Y
- return p
-}
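
A hedged usage sketch of the new DoubleBaseScalarMul (p1, p2, s1, s2 and curve below are placeholders, not identifiers from this diff): it computes s1*P1 + s2*P2 in a single pass, so the doublings are shared between the two scalars instead of being paid twice.

```go
// illustrative only: q1 and q2 hold the same point, the second form uses fewer constraints
var q1, q2, t Point
q1.ScalarMul(api, &p1, s1, curve) // s1*P1
t.ScalarMul(api, &p2, s2, curve)  // s2*P2
q1.Add(api, &q1, &t, curve)       // s1*P1 + s2*P2, three separate gadgets

q2.DoubleBaseScalarMul(api, &p1, &p2, s1, s2, curve) // same result in one gadget
```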
diff --git a/std/algebra/twistededwards/bandersnatch/point_test.go b/std/algebra/twistededwards/bandersnatch/point_test.go
index ba883d8930..ff879c2fa2 100644
--- a/std/algebra/twistededwards/bandersnatch/point_test.go
+++ b/std/algebra/twistededwards/bandersnatch/point_test.go
@@ -55,8 +55,8 @@ func TestIsOnCurve(t *testing.T) {
t.Fatal(err)
}
- witness.P.X = (params.BaseX)
- witness.P.Y = (params.BaseY)
+ witness.P.X = (params.Base.X)
+ witness.P.Y = (params.Base.Y)
assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BLS12_381))
@@ -74,7 +74,10 @@ func (circuit *add) Define(api frontend.API) error {
return err
}
- res := circuit.P.AddFixedPoint(api, &circuit.P, params.BaseX, params.BaseY, params)
+ p := Point{}
+ p.X = params.Base.X
+ p.Y = params.Base.Y
+ res := circuit.P.Add(api, &circuit.P, &p, params)
api.AssertIsEqual(res.X, circuit.E.X)
api.AssertIsEqual(res.Y, circuit.E.Y)
@@ -94,8 +97,8 @@ func TestAddFixedPoint(t *testing.T) {
t.Fatal(err)
}
var base, point, expected bandersnatch.PointAffine
-	base.X.SetBigInt(&params.BaseX)
-	base.Y.SetBigInt(&params.BaseY)
+	base.X.SetBigInt(&params.Base.X)
+	base.Y.SetBigInt(&params.Base.Y)
point.Set(&base)
r := big.NewInt(5)
point.ScalarMul(&point, r)
@@ -112,6 +115,9 @@ func TestAddFixedPoint(t *testing.T) {
}
+//-------------------------------------------------------------
+// addGeneric
+
type addGeneric struct {
P1, P2, E Point
}
@@ -124,7 +130,7 @@ func (circuit *addGeneric) Define(api frontend.API) error {
return err
}
- res := circuit.P1.AddGeneric(api, &circuit.P1, &circuit.P2, params)
+ res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params)
api.AssertIsEqual(res.X, circuit.E.X)
api.AssertIsEqual(res.Y, circuit.E.Y)
@@ -143,8 +149,8 @@ func TestAddGeneric(t *testing.T) {
t.Fatal(err)
}
var point1, point2, expected bandersnatch.PointAffine
-	point1.X.SetBigInt(&params.BaseX)
-	point1.Y.SetBigInt(&params.BaseY)
+	point1.X.SetBigInt(&params.Base.X)
+	point1.Y.SetBigInt(&params.Base.Y)
point2.Set(&point1)
r1 := big.NewInt(5)
r2 := big.NewInt(12)
@@ -165,6 +171,8 @@ func TestAddGeneric(t *testing.T) {
}
+//-------------------------------------------------------------
+// Double
type double struct {
P, E Point
}
@@ -197,8 +205,8 @@ func TestDouble(t *testing.T) {
t.Fatal(err)
}
var base, expected bandersnatch.PointAffine
-	base.X.SetBigInt(&params.BaseX)
-	base.Y.SetBigInt(&params.BaseY)
+	base.X.SetBigInt(&params.Base.X)
+	base.Y.SetBigInt(&params.Base.Y)
expected.Double(&base)
// populate witness
@@ -225,8 +233,10 @@ func (circuit *scalarMulFixed) Define(api frontend.API) error {
return err
}
- var resFixed Point
- resFixed.ScalarMulFixedBase(api, params.BaseX, params.BaseY, circuit.S, params)
+ var resFixed, p Point
+ p.X = params.Base.X
+ p.Y = params.Base.Y
+ resFixed.ScalarMul(api, &p, circuit.S, params)
api.AssertIsEqual(resFixed.X, circuit.E.X)
api.AssertIsEqual(resFixed.Y, circuit.E.Y)
@@ -246,8 +256,8 @@ func TestScalarMulFixed(t *testing.T) {
t.Fatal(err)
}
var base, expected bandersnatch.PointAffine
-	base.X.SetBigInt(&params.BaseX)
-	base.Y.SetBigInt(&params.BaseY)
+	base.X.SetBigInt(&params.Base.X)
+	base.Y.SetBigInt(&params.Base.Y)
r := big.NewInt(928323002)
expected.ScalarMul(&base, r)
@@ -274,7 +284,7 @@ func (circuit *scalarMulGeneric) Define(api frontend.API) error {
return err
}
- resGeneric := circuit.P.ScalarMulNonFixedBase(api, &circuit.P, circuit.S, params)
+ resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params)
api.AssertIsEqual(resGeneric.X, circuit.E.X)
api.AssertIsEqual(resGeneric.Y, circuit.E.Y)
@@ -294,8 +304,8 @@ func TestScalarMulGeneric(t *testing.T) {
t.Fatal(err)
}
var base, point, expected bandersnatch.PointAffine
-	base.X.SetBigInt(&params.BaseX)
-	base.Y.SetBigInt(&params.BaseY)
+	base.X.SetBigInt(&params.Base.X)
+	base.Y.SetBigInt(&params.Base.Y)
s := big.NewInt(902)
point.ScalarMul(&base, s) // random point
r := big.NewInt(230928302)
@@ -336,8 +346,8 @@ func TestNeg(t *testing.T) {
t.Fatal(err)
}
var base, expected bandersnatch.PointAffine
-	base.X.SetBigInt(&params.BaseX)
-	base.Y.SetBigInt(&params.BaseY)
+	base.X.SetBigInt(&params.Base.X)
+	base.Y.SetBigInt(&params.Base.Y)
expected.Neg(&base)
// generate witness
@@ -351,25 +361,39 @@ func TestNeg(t *testing.T) {
}
-// benches
+// Bench
+func BenchmarkDouble(b *testing.B) {
+ var c double
+ ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c)
+ b.Log("groth16", ccsBench.GetNbConstraints())
+}
-var ccsBench frontend.CompiledConstraintSystem
+func BenchmarkAddGeneric(b *testing.B) {
+ var c addGeneric
+ ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c)
+ b.Log("groth16", ccsBench.GetNbConstraints())
+}
-func BenchmarkScalarMulG1(b *testing.B) {
- var c scalarMulGeneric
- b.Run("groth16", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- ccsBench, _ = frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c)
- }
+func BenchmarkAddFixedPoint(b *testing.B) {
+ var c add
+ ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c)
+ b.Log("groth16", ccsBench.GetNbConstraints())
+}
- })
+func BenchmarkMustBeOnCurve(b *testing.B) {
+ var c mustBeOnCurve
+ ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c)
b.Log("groth16", ccsBench.GetNbConstraints())
- b.Run("plonk", func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- ccsBench, _ = frontend.Compile(ecc.BLS12_381, backend.PLONK, &c)
- }
+}
- })
- b.Log("plonk", ccsBench.GetNbConstraints())
+func BenchmarkScalarMulGeneric(b *testing.B) {
+ var c scalarMulGeneric
+ ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c)
+ b.Log("groth16", ccsBench.GetNbConstraints())
+}
+func BenchmarkScalarMulFixed(b *testing.B) {
+ var c scalarMulFixed
+ ccsBench, _ := frontend.Compile(ecc.BLS12_381, backend.GROTH16, &c)
+ b.Log("groth16", ccsBench.GetNbConstraints())
}
diff --git a/std/algebra/twistededwards/curve.go b/std/algebra/twistededwards/curve.go
index 798b01271f..432210c588 100644
--- a/std/algebra/twistededwards/curve.go
+++ b/std/algebra/twistededwards/curve.go
@@ -30,11 +30,17 @@ import (
"github.com/consensys/gnark/internal/utils"
)
+// Coordinates of a point on a twisted Edwards curve
+type Coord struct {
+ X, Y big.Int
+}
+
// EdCurve stores the info on the chosen edwards curve
// note that all curves implemented in gnark-crypto have A = -1
type EdCurve struct {
- A, D, Cofactor, Order, BaseX, BaseY big.Int
- ID ecc.ID
+ A, D, Cofactor, Order big.Int
+ Base Coord
+ ID ecc.ID
}
var constructors map[ecc.ID]func() EdCurve
@@ -71,9 +77,11 @@ func newEdBN254() EdCurve {
D: utils.FromInterface(edcurve.D),
Cofactor: utils.FromInterface(edcurve.Cofactor),
Order: utils.FromInterface(edcurve.Order),
- BaseX: utils.FromInterface(edcurve.Base.X),
- BaseY: utils.FromInterface(edcurve.Base.Y),
- ID: ecc.BN254,
+ Base: Coord{
+ X: utils.FromInterface(edcurve.Base.X),
+ Y: utils.FromInterface(edcurve.Base.Y),
+ },
+ ID: ecc.BN254,
}
}
@@ -88,9 +96,11 @@ func newEdBLS381() EdCurve {
D: utils.FromInterface(edcurve.D),
Cofactor: utils.FromInterface(edcurve.Cofactor),
Order: utils.FromInterface(edcurve.Order),
- BaseX: utils.FromInterface(edcurve.Base.X),
- BaseY: utils.FromInterface(edcurve.Base.Y),
- ID: ecc.BLS12_381,
+ Base: Coord{
+ X: utils.FromInterface(edcurve.Base.X),
+ Y: utils.FromInterface(edcurve.Base.Y),
+ },
+ ID: ecc.BLS12_381,
}
}
@@ -105,9 +115,11 @@ func newEdBLS377() EdCurve {
D: utils.FromInterface(edcurve.D),
Cofactor: utils.FromInterface(edcurve.Cofactor),
Order: utils.FromInterface(edcurve.Order),
- BaseX: utils.FromInterface(edcurve.Base.X),
- BaseY: utils.FromInterface(edcurve.Base.Y),
- ID: ecc.BLS12_377,
+ Base: Coord{
+ X: utils.FromInterface(edcurve.Base.X),
+ Y: utils.FromInterface(edcurve.Base.Y),
+ },
+ ID: ecc.BLS12_377,
}
}
@@ -122,9 +134,11 @@ func newEdBW633() EdCurve {
D: utils.FromInterface(edcurve.D),
Cofactor: utils.FromInterface(edcurve.Cofactor),
Order: utils.FromInterface(edcurve.Order),
- BaseX: utils.FromInterface(edcurve.Base.X),
- BaseY: utils.FromInterface(edcurve.Base.Y),
- ID: ecc.BW6_633,
+ Base: Coord{
+ X: utils.FromInterface(edcurve.Base.X),
+ Y: utils.FromInterface(edcurve.Base.Y),
+ },
+ ID: ecc.BW6_633,
}
}
@@ -139,9 +153,11 @@ func newEdBW761() EdCurve {
D: utils.FromInterface(edcurve.D),
Cofactor: utils.FromInterface(edcurve.Cofactor),
Order: utils.FromInterface(edcurve.Order),
- BaseX: utils.FromInterface(edcurve.Base.X),
- BaseY: utils.FromInterface(edcurve.Base.Y),
- ID: ecc.BW6_761,
+ Base: Coord{
+ X: utils.FromInterface(edcurve.Base.X),
+ Y: utils.FromInterface(edcurve.Base.Y),
+ },
+ ID: ecc.BW6_761,
}
}
@@ -156,9 +172,11 @@ func newEdBLS315() EdCurve {
D: utils.FromInterface(edcurve.D),
Cofactor: utils.FromInterface(edcurve.Cofactor),
Order: utils.FromInterface(edcurve.Order),
- BaseX: utils.FromInterface(edcurve.Base.X),
- BaseY: utils.FromInterface(edcurve.Base.Y),
- ID: ecc.BLS24_315,
+ Base: Coord{
+ X: utils.FromInterface(edcurve.Base.X),
+ Y: utils.FromInterface(edcurve.Base.Y),
+ },
+ ID: ecc.BLS24_315,
}
}
diff --git a/std/algebra/twistededwards/point.go b/std/algebra/twistededwards/point.go
index 7faf85faf4..72a712e517 100644
--- a/std/algebra/twistededwards/point.go
+++ b/std/algebra/twistededwards/point.go
@@ -27,8 +27,15 @@ type Point struct {
X, Y frontend.Variable
}
+// Neg computes the negative of a point in SNARK coordinates
+func (p *Point) Neg(api frontend.API, p1 *Point) *Point {
+ p.X = api.Neg(p1.X)
+ p.Y = p1.Y
+ return p
+}
+
// MustBeOnCurve checks if a point is on the reduced twisted Edwards curve
-// a*x^2 + y^2 = 1 + d*x^2*y^2.
+// a*x² + y² = 1 + d*x²*y².
func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) {
one := big.NewInt(1)
@@ -46,34 +53,9 @@ func (p *Point) MustBeOnCurve(api frontend.API, curve EdCurve) {
}
-// AddFixedPoint Adds two points, among which is one fixed point (the base), on a twisted edwards curve (eg jubjub)
-// p1, base, ecurve are respectively: the point to add, a known base point, and the parameters of the twisted edwards curve
-func (p *Point) AddFixedPoint(api frontend.API, p1 *Point /*basex*/, x /*basey*/, y interface{}, curve EdCurve) *Point {
-
- // https://eprint.iacr.org/2008/013.pdf
-
- n11 := api.Mul(p1.X, y)
- n12 := api.Mul(p1.Y, x)
- n1 := api.Add(n11, n12)
-
- n21 := api.Mul(p1.Y, y)
- n22 := api.Mul(p1.X, x)
- an22 := api.Mul(n22, &curve.A)
- n2 := api.Sub(n21, an22)
-
- d11 := api.Mul(curve.D, n11, n12)
- d1 := api.Add(1, d11)
- d2 := api.Sub(1, d11)
-
- p.X = api.DivUnchecked(n1, d1)
- p.Y = api.DivUnchecked(n2, d2)
-
- return p
-}
-
-// AddGeneric Adds two points on a twisted edwards curve (eg jubjub)
+// Add adds two points on a twisted edwards curve (eg jubjub)
// p1, p2, curve are respectively: the two points to add and the parameters of the twisted edwards curve
-func (p *Point) AddGeneric(api frontend.API, p1, p2 *Point, curve EdCurve) *Point {
+func (p *Point) Add(api frontend.API, p1, p2 *Point, curve EdCurve) *Point {
// https://eprint.iacr.org/2008/013.pdf
@@ -103,14 +85,12 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point {
u := api.Mul(p1.X, p1.Y)
v := api.Mul(p1.X, p1.X)
w := api.Mul(p1.Y, p1.Y)
- z := api.Mul(v, w)
n1 := api.Mul(2, u)
av := api.Mul(v, &curve.A)
n2 := api.Sub(w, av)
- d := api.Mul(z, curve.D)
- d1 := api.Add(1, d)
- d2 := api.Sub(1, d)
+ d1 := api.Add(w, av)
+ d2 := api.Sub(2, d1)
p.X = api.DivUnchecked(n1, d1)
p.Y = api.DivUnchecked(n2, d2)
@@ -118,55 +98,72 @@ func (p *Point) Double(api frontend.API, p1 *Point, curve EdCurve) *Point {
return p
}
-// ScalarMulNonFixedBase computes the scalar multiplication of a point on a twisted Edwards curve
+// ScalarMul computes the scalar multiplication of a point on a twisted Edwards curve
// p1: base point (as snark point)
// curve: parameters of the Edwards curve
// scal: scalar as a SNARK constraint
// Standard left to right double and add
-func (p *Point) ScalarMulNonFixedBase(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point {
+func (p *Point) ScalarMul(api frontend.API, p1 *Point, scalar frontend.Variable, curve EdCurve) *Point {
// first unpack the scalar
b := api.ToBinary(scalar)
- res := Point{
- 0,
- 1,
+ res := Point{}
+ tmp := Point{}
+ A := Point{}
+ B := Point{}
+
+ A.Double(api, p1, curve)
+ B.Add(api, &A, p1, curve)
+
+ n := len(b) - 1
+ res.X = api.Lookup2(b[n], b[n-1], 0, A.X, p1.X, B.X)
+ res.Y = api.Lookup2(b[n], b[n-1], 1, A.Y, p1.Y, B.Y)
+
+ for i := n - 2; i >= 1; i -= 2 {
+ res.Double(api, &res, curve).
+ Double(api, &res, curve)
+ tmp.X = api.Lookup2(b[i], b[i-1], 0, A.X, p1.X, B.X)
+ tmp.Y = api.Lookup2(b[i], b[i-1], 1, A.Y, p1.Y, B.Y)
+ res.Add(api, &res, &tmp, curve)
}
- for i := len(b) - 1; i >= 0; i-- {
+ if n%2 == 0 {
res.Double(api, &res, curve)
- tmp := Point{}
- tmp.AddGeneric(api, &res, p1, curve)
- res.X = api.Select(b[i], tmp.X, res.X)
- res.Y = api.Select(b[i], tmp.Y, res.Y)
+ tmp.Add(api, &res, p1, curve)
+ res.X = api.Select(b[0], tmp.X, res.X)
+ res.Y = api.Select(b[0], tmp.Y, res.Y)
}
p.X = res.X
p.Y = res.Y
+
return p
}
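
To make the rewritten loop above easier to follow, here is a plain-integer transcription of the 2-bit windowed left-to-right double-and-add it implements (illustrative only; windowedScalarMul is not part of gnark, and api.Lookup2(b0, b1, c0, c1, c2, c3) in the circuit selects the entry of index 2*b1 + b0 from the precomputed table (0, 2P, P, 3P)):

```go
// windowedScalarMul assumes the scalar has at least two bits.
func windowedScalarMul(p, scalar int) int {
	// little-endian bits, like api.ToBinary
	var b []int
	for s := scalar; s > 0; s >>= 1 {
		b = append(b, s&1)
	}
	table := [4]int{0, 2 * p, p, 3 * p} // mirrors Lookup2(b[i], b[i-1], 0, A, P, B) with A = 2P, B = 3P
	n := len(b) - 1
	acc := table[2*b[n-1]+b[n]] // top window (b[n], b[n-1])
	for i := n - 2; i >= 1; i -= 2 {
		acc = 4 * acc               // two doublings
		acc += table[2*b[i-1]+b[i]] // add the window (b[i], b[i-1])
	}
	if n%2 == 0 { // one bit left over
		acc = 2 * acc
		if b[0] == 1 {
			acc += p
		}
	}
	return acc
}
```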
-// ScalarMulFixedBase computes the scalar multiplication of a point on a twisted Edwards curve
-// x, y: coordinates of the base point
-// curve: parameters of the Edwards curve
-// scal: scalar as a SNARK constraint
-// Standard left to right double and add
-func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar frontend.Variable, curve EdCurve) *Point {
+// DoubleBaseScalarMul computes s1*P1+s2*P2
+// where P1 and P2 are points on a twisted Edwards curve
+// and s1, s2 scalars.
+func (p *Point) DoubleBaseScalarMul(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve EdCurve) *Point {
- // first unpack the scalar
- b := api.ToBinary(scalar)
+ // first unpack the scalars
+ b1 := api.ToBinary(s1)
+ b2 := api.ToBinary(s2)
- res := Point{
- 0,
- 1,
- }
+ res := Point{}
+ tmp := Point{}
+ sum := Point{}
+ sum.Add(api, p1, p2, curve)
+
+ n := len(b1)
+ res.X = api.Lookup2(b1[n-1], b2[n-1], 0, p1.X, p2.X, sum.X)
+ res.Y = api.Lookup2(b1[n-1], b2[n-1], 1, p1.Y, p2.Y, sum.Y)
- for i := len(b) - 1; i >= 0; i-- {
+ for i := n - 2; i >= 0; i-- {
res.Double(api, &res, curve)
- tmp := Point{}
- tmp.AddFixedPoint(api, &res, x, y, curve)
- res.X = api.Select(b[i], tmp.X, res.X)
- res.Y = api.Select(b[i], tmp.Y, res.Y)
+ tmp.X = api.Lookup2(b1[i], b2[i], 0, p1.X, p2.X, sum.X)
+ tmp.Y = api.Lookup2(b1[i], b2[i], 1, p1.Y, p2.Y, sum.Y)
+ res.Add(api, &res, &tmp, curve)
}
p.X = res.X
@@ -174,10 +171,3 @@ func (p *Point) ScalarMulFixedBase(api frontend.API, x, y interface{}, scalar fr
return p
}
-
-// Neg computes the negative of a point in SNARK coordinates
-func (p *Point) Neg(api frontend.API, p1 *Point) *Point {
- p.X = api.Neg(p1.X)
- p.Y = p1.Y
- return p
-}
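Editorial note, not part of the patch: the new `ScalarMul` consumes the scalar two bits at a time, precomputing 2·P and 3·P and selecting among {O, P, 2P, 3P} with `api.Lookup2`, while `DoubleBaseScalarMul` uses the Straus/Shamir trick so both scalars share a single doubling chain with a lookup among {O, P1, P2, P1+P2}. The sketch below illustrates the same scan order out of circuit, with plain integers standing in for curve points (integer `+` for point addition, `*2` for doubling, `0` for the identity); the function names are illustrative only, not gnark APIs, and this is not the circuit code above.

```go
package main

import (
	"fmt"
	"math/bits"
)

// windowedMul computes k*p with a 2-bit window, scanning k from the most
// significant bit pair downwards: per window, double twice, then add the
// precomputed multiple selected by the current two bits (0, p, 2p or 3p).
func windowedMul(k uint, p int) int {
	table := [4]int{0, p, 2 * p, 3 * p}
	n := bits.Len(k)
	if n%2 != 0 {
		n++ // pad so the bits split into an even number of 2-bit windows
	}
	res := 0
	for i := n - 2; i >= 0; i -= 2 {
		res = 4 * res          // two doublings of the accumulator
		res += table[(k>>i)&3] // add the multiple selected by bits i+1, i
	}
	return res
}

// doubleBaseMul computes s1*p1 + s2*p2 with the Straus/Shamir trick: one
// doubling per bit shared by both scalars, plus a lookup in {0, p1, p2, p1+p2}.
func doubleBaseMul(s1, s2 uint, p1, p2 int) int {
	table := [4]int{0, p1, p2, p1 + p2}
	n := bits.Len(s1)
	if m := bits.Len(s2); m > n {
		n = m
	}
	res := 0
	for i := n - 1; i >= 0; i-- {
		res *= 2 // shared doubling
		idx := ((s2>>i)&1)<<1 | (s1>>i)&1
		res += table[idx] // 0, p1, p2 or p1+p2
	}
	return res
}

func main() {
	fmt.Println(windowedMul(230928302, 7) == 230928302*7) // true
	fmt.Println(doubleBaseMul(5, 12, 3, 11) == 5*3+12*11)  // true
}
```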
diff --git a/std/algebra/twistededwards/point_test.go b/std/algebra/twistededwards/point_test.go
index b8f9d1959b..d9387e94fd 100644
--- a/std/algebra/twistededwards/point_test.go
+++ b/std/algebra/twistededwards/point_test.go
@@ -28,7 +28,6 @@ import (
tbw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/twistededwards"
tbw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/twistededwards"
- "github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/test"
)
@@ -61,8 +60,8 @@ func TestIsOnCurve(t *testing.T) {
t.Fatal(err)
}
- witness.P.X = (params.BaseX)
- witness.P.Y = (params.BaseY)
+ witness.P.X = (params.Base.X)
+ witness.P.Y = (params.Base.Y)
assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254))
@@ -80,7 +79,10 @@ func (circuit *add) Define(api frontend.API) error {
return err
}
- res := circuit.P.AddFixedPoint(api, &circuit.P, params.BaseX, params.BaseY, params)
+ p := Point{}
+ p.X = params.Base.X
+ p.Y = params.Base.Y
+ res := circuit.P.Add(api, &circuit.P, &p, params)
api.AssertIsEqual(res.X, circuit.E.X)
api.AssertIsEqual(res.Y, circuit.E.Y)
@@ -100,8 +102,8 @@ func TestAddFixedPoint(t *testing.T) {
t.Fatal(err)
}
var base, point, expected tbn254.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
point.Set(&base)
r := big.NewInt(5)
point.ScalarMul(&point, r)
@@ -133,7 +135,7 @@ func (circuit *addGeneric) Define(api frontend.API) error {
return err
}
- res := circuit.P1.AddGeneric(api, &circuit.P1, &circuit.P2, params)
+ res := circuit.P1.Add(api, &circuit.P1, &circuit.P2, params)
api.AssertIsEqual(res.X, circuit.E.X)
api.AssertIsEqual(res.Y, circuit.E.Y)
@@ -157,8 +159,8 @@ func TestAddGeneric(t *testing.T) {
switch id {
case ecc.BN254:
var op1, op2, expected tbn254.PointAffine
- op1.X.SetBigInt(&params.BaseX)
- op1.Y.SetBigInt(&params.BaseY)
+ op1.X.SetBigInt(&params.Base.X)
+ op1.Y.SetBigInt(&params.Base.Y)
op2.Set(&op1)
r1 := big.NewInt(5)
r2 := big.NewInt(12)
@@ -173,8 +175,8 @@ func TestAddGeneric(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BLS12_381:
var op1, op2, expected tbls12381.PointAffine
- op1.X.SetBigInt(&params.BaseX)
- op1.Y.SetBigInt(&params.BaseY)
+ op1.X.SetBigInt(&params.Base.X)
+ op1.Y.SetBigInt(&params.Base.Y)
op2.Set(&op1)
r1 := big.NewInt(5)
r2 := big.NewInt(12)
@@ -189,8 +191,8 @@ func TestAddGeneric(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BLS12_377:
var op1, op2, expected tbls12377.PointAffine
- op1.X.SetBigInt(&params.BaseX)
- op1.Y.SetBigInt(&params.BaseY)
+ op1.X.SetBigInt(&params.Base.X)
+ op1.Y.SetBigInt(&params.Base.Y)
op2.Set(&op1)
r1 := big.NewInt(5)
r2 := big.NewInt(12)
@@ -205,8 +207,8 @@ func TestAddGeneric(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BLS24_315:
var op1, op2, expected tbls24315.PointAffine
- op1.X.SetBigInt(&params.BaseX)
- op1.Y.SetBigInt(&params.BaseY)
+ op1.X.SetBigInt(&params.Base.X)
+ op1.Y.SetBigInt(&params.Base.Y)
op2.Set(&op1)
r1 := big.NewInt(5)
r2 := big.NewInt(12)
@@ -221,8 +223,8 @@ func TestAddGeneric(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BW6_633:
var op1, op2, expected tbw6633.PointAffine
- op1.X.SetBigInt(&params.BaseX)
- op1.Y.SetBigInt(&params.BaseY)
+ op1.X.SetBigInt(&params.Base.X)
+ op1.Y.SetBigInt(&params.Base.Y)
op2.Set(&op1)
r1 := big.NewInt(5)
r2 := big.NewInt(12)
@@ -237,8 +239,8 @@ func TestAddGeneric(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BW6_761:
var op1, op2, expected tbw6761.PointAffine
- op1.X.SetBigInt(&params.BaseX)
- op1.Y.SetBigInt(&params.BaseY)
+ op1.X.SetBigInt(&params.Base.X)
+ op1.Y.SetBigInt(&params.Base.Y)
op2.Set(&op1)
r1 := big.NewInt(5)
r2 := big.NewInt(12)
@@ -299,8 +301,8 @@ func TestDouble(t *testing.T) {
switch id {
case ecc.BN254:
var base, expected tbn254.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
expected.Double(&base)
witness.P.X = (base.X.String())
witness.P.Y = (base.Y.String())
@@ -308,8 +310,8 @@ func TestDouble(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BLS12_381:
var base, expected tbls12381.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
expected.Double(&base)
witness.P.X = (base.X.String())
witness.P.Y = (base.Y.String())
@@ -317,8 +319,8 @@ func TestDouble(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BLS12_377:
var base, expected tbls12377.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
expected.Double(&base)
witness.P.X = (base.X.String())
witness.P.Y = (base.Y.String())
@@ -326,8 +328,8 @@ func TestDouble(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BLS24_315:
var base, expected tbls24315.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
expected.Double(&base)
witness.P.X = (base.X.String())
witness.P.Y = (base.Y.String())
@@ -335,8 +337,8 @@ func TestDouble(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BW6_633:
var base, expected tbw6633.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
expected.Double(&base)
witness.P.X = (base.X.String())
witness.P.Y = (base.Y.String())
@@ -344,8 +346,8 @@ func TestDouble(t *testing.T) {
witness.E.Y = (expected.Y.String())
case ecc.BW6_761:
var base, expected tbw6761.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
expected.Double(&base)
witness.P.X = (base.X.String())
witness.P.Y = (base.Y.String())
@@ -375,8 +377,10 @@ func (circuit *scalarMulFixed) Define(api frontend.API) error {
return err
}
- var resFixed Point
- resFixed.ScalarMulFixedBase(api, params.BaseX, params.BaseY, circuit.S, params)
+ var resFixed, p Point
+ p.X = params.Base.X
+ p.Y = params.Base.Y
+ resFixed.ScalarMul(api, &p, circuit.S, params)
api.AssertIsEqual(resFixed.X, circuit.E.X)
api.AssertIsEqual(resFixed.Y, circuit.E.Y)
@@ -391,8 +395,7 @@ func TestScalarMulFixed(t *testing.T) {
var circuit, witness scalarMulFixed
// generate witness data
- //for _, id := range ecc.Implemented() {
- for _, id := range []ecc.ID{ecc.BLS12_377} {
+ for _, id := range ecc.Implemented() {
params, err := NewEdCurve(id)
if err != nil {
@@ -402,8 +405,8 @@ func TestScalarMulFixed(t *testing.T) {
switch id {
case ecc.BN254:
var base, expected tbn254.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
r := big.NewInt(928323002)
expected.ScalarMul(&base, r)
witness.E.X = (expected.X.String())
@@ -411,8 +414,8 @@ func TestScalarMulFixed(t *testing.T) {
witness.S = (r)
case ecc.BLS12_381:
var base, expected tbls12381.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
r := big.NewInt(928323002)
expected.ScalarMul(&base, r)
witness.E.X = (expected.X.String())
@@ -420,8 +423,8 @@ func TestScalarMulFixed(t *testing.T) {
witness.S = (r)
case ecc.BLS12_377:
var base, expected tbls12377.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
r := big.NewInt(928323002)
expected.ScalarMul(&base, r)
witness.E.X = (expected.X.String())
@@ -429,8 +432,8 @@ func TestScalarMulFixed(t *testing.T) {
witness.S = (r)
case ecc.BLS24_315:
var base, expected tbls24315.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
r := big.NewInt(928323002)
expected.ScalarMul(&base, r)
witness.E.X = (expected.X.String())
@@ -438,8 +441,8 @@ func TestScalarMulFixed(t *testing.T) {
witness.S = (r)
case ecc.BW6_633:
var base, expected tbw6633.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
r := big.NewInt(928323002)
expected.ScalarMul(&base, r)
witness.E.X = (expected.X.String())
@@ -447,8 +450,8 @@ func TestScalarMulFixed(t *testing.T) {
witness.S = (r)
case ecc.BW6_761:
var base, expected tbw6761.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
r := big.NewInt(928323002)
expected.ScalarMul(&base, r)
witness.E.X = (expected.X.String())
@@ -457,7 +460,7 @@ func TestScalarMulFixed(t *testing.T) {
}
// creates r1cs
- assert.SolvingSucceeded(&circuit, &witness, test.WithBackends(backend.PLONK), test.WithCurves(id))
+ assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id))
}
}
@@ -475,7 +478,7 @@ func (circuit *scalarMulGeneric) Define(api frontend.API) error {
return err
}
- resGeneric := circuit.P.ScalarMulNonFixedBase(api, &circuit.P, circuit.S, params)
+ resGeneric := circuit.P.ScalarMul(api, &circuit.P, circuit.S, params)
api.AssertIsEqual(resGeneric.X, circuit.E.X)
api.AssertIsEqual(resGeneric.Y, circuit.E.Y)
@@ -490,28 +493,292 @@ func TestScalarMulGeneric(t *testing.T) {
var circuit, witness scalarMulGeneric
// generate witness data
- params, err := NewEdCurve(ecc.BN254)
+ for _, id := range ecc.Implemented() {
+
+ params, err := NewEdCurve(id)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ switch id {
+ case ecc.BN254:
+ var base, point, expected tbn254.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s := big.NewInt(902)
+ point.ScalarMul(&base, s) // random point
+ r := big.NewInt(230928302)
+ expected.ScalarMul(&point, r)
+
+ // populate witness
+ witness.P.X = (point.X.String())
+ witness.P.Y = (point.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S = (r)
+ case ecc.BLS12_377:
+ var base, point, expected tbls12377.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s := big.NewInt(902)
+ point.ScalarMul(&base, s) // random point
+ r := big.NewInt(230928302)
+ expected.ScalarMul(&point, r)
+
+ // populate witness
+ witness.P.X = (point.X.String())
+ witness.P.Y = (point.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S = (r)
+ case ecc.BLS12_381:
+ var base, point, expected tbls12381.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s := big.NewInt(902)
+ point.ScalarMul(&base, s) // random point
+ r := big.NewInt(230928302)
+ expected.ScalarMul(&point, r)
+
+ // populate witness
+ witness.P.X = (point.X.String())
+ witness.P.Y = (point.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S = (r)
+ case ecc.BLS24_315:
+ var base, point, expected tbls24315.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s := big.NewInt(902)
+ point.ScalarMul(&base, s) // random point
+ r := big.NewInt(230928302)
+ expected.ScalarMul(&point, r)
+
+ // populate witness
+ witness.P.X = (point.X.String())
+ witness.P.Y = (point.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S = (r)
+ case ecc.BW6_761:
+ var base, point, expected tbw6761.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s := big.NewInt(902)
+ point.ScalarMul(&base, s) // random point
+ r := big.NewInt(230928302)
+ expected.ScalarMul(&point, r)
+
+ // populate witness
+ witness.P.X = (point.X.String())
+ witness.P.Y = (point.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S = (r)
+ case ecc.BW6_633:
+ var base, point, expected tbw6633.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s := big.NewInt(902)
+ point.ScalarMul(&base, s) // random point
+ r := big.NewInt(230928302)
+ expected.ScalarMul(&point, r)
+
+ // populate witness
+ witness.P.X = (point.X.String())
+ witness.P.Y = (point.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S = (r)
+ }
+
+ // creates r1cs
+ assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id))
+ }
+}
+
+//
+
+type doubleScalarMulGeneric struct {
+ P1, P2, E Point
+ S1, S2 frontend.Variable
+}
+
+func (circuit *doubleScalarMulGeneric) Define(api frontend.API) error {
+
+ // get edwards curve params
+ params, err := NewEdCurve(api.Curve())
if err != nil {
- t.Fatal(err)
+ return err
}
- var base, point, expected tbn254.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
- s := big.NewInt(902)
- point.ScalarMul(&base, s) // random point
- r := big.NewInt(230928302)
- expected.ScalarMul(&point, r)
- // populate witness
- witness.P.X = (point.X.String())
- witness.P.Y = (point.Y.String())
- witness.E.X = (expected.X.String())
- witness.E.Y = (expected.Y.String())
- witness.S = (r)
+ resGeneric := circuit.P1.DoubleBaseScalarMul(api, &circuit.P1, &circuit.P2, circuit.S1, circuit.S2, params)
- // creates r1cs
- assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254))
+ api.AssertIsEqual(resGeneric.X, circuit.E.X)
+ api.AssertIsEqual(resGeneric.Y, circuit.E.Y)
+
+ return nil
+}
+
+func TestDoubleScalarMulGeneric(t *testing.T) {
+
+ assert := test.NewAssert(t)
+
+ var circuit, witness doubleScalarMulGeneric
+
+ // generate witness data
+ for _, id := range ecc.Implemented() {
+
+ params, err := NewEdCurve(id)
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch id {
+ case ecc.BN254:
+ var base, point1, point2, tmp, expected tbn254.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s1 := big.NewInt(902)
+ s2 := big.NewInt(891)
+ point1.ScalarMul(&base, s1) // random point
+ point2.ScalarMul(&base, s2) // random point
+ r1 := big.NewInt(230928303)
+ r2 := big.NewInt(2830309)
+ tmp.ScalarMul(&point1, r1)
+ expected.ScalarMul(&point2, r2).
+ Add(&expected, &tmp)
+
+ // populate witness
+ witness.P1.X = (point1.X.String())
+ witness.P1.Y = (point1.Y.String())
+ witness.P2.X = (point2.X.String())
+ witness.P2.Y = (point2.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S1 = (r1)
+ witness.S2 = (r2)
+ case ecc.BLS12_377:
+ var base, point1, point2, tmp, expected tbls12377.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s1 := big.NewInt(902)
+ s2 := big.NewInt(891)
+ point1.ScalarMul(&base, s1) // random point
+ point2.ScalarMul(&base, s2) // random point
+ r1 := big.NewInt(230928303)
+ r2 := big.NewInt(2830309)
+ tmp.ScalarMul(&point1, r1)
+ expected.ScalarMul(&point2, r2).
+ Add(&expected, &tmp)
+
+ // populate witness
+ witness.P1.X = (point1.X.String())
+ witness.P1.Y = (point1.Y.String())
+ witness.P2.X = (point2.X.String())
+ witness.P2.Y = (point2.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S1 = (r1)
+ witness.S2 = (r2)
+ case ecc.BLS12_381:
+ var base, point1, point2, tmp, expected tbls12381.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s1 := big.NewInt(902)
+ s2 := big.NewInt(891)
+ point1.ScalarMul(&base, s1) // random point
+ point2.ScalarMul(&base, s2) // random point
+ r1 := big.NewInt(230928303)
+ r2 := big.NewInt(2830309)
+ tmp.ScalarMul(&point1, r1)
+ expected.ScalarMul(&point2, r2).
+ Add(&expected, &tmp)
+
+ // populate witness
+ witness.P1.X = (point1.X.String())
+ witness.P1.Y = (point1.Y.String())
+ witness.P2.X = (point2.X.String())
+ witness.P2.Y = (point2.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S1 = (r1)
+ witness.S2 = (r2)
+ case ecc.BLS24_315:
+ var base, point1, point2, tmp, expected tbls24315.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s1 := big.NewInt(902)
+ s2 := big.NewInt(891)
+ point1.ScalarMul(&base, s1) // random point
+ point2.ScalarMul(&base, s2) // random point
+ r1 := big.NewInt(230928303)
+ r2 := big.NewInt(2830309)
+ tmp.ScalarMul(&point1, r1)
+ expected.ScalarMul(&point2, r2).
+ Add(&expected, &tmp)
+
+ // populate witness
+ witness.P1.X = (point1.X.String())
+ witness.P1.Y = (point1.Y.String())
+ witness.P2.X = (point2.X.String())
+ witness.P2.Y = (point2.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S1 = (r1)
+ witness.S2 = (r2)
+ case ecc.BW6_761:
+ var base, point1, point2, tmp, expected tbw6761.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s1 := big.NewInt(902)
+ s2 := big.NewInt(891)
+ point1.ScalarMul(&base, s1) // random point
+ point2.ScalarMul(&base, s2) // random point
+ r1 := big.NewInt(230928303)
+ r2 := big.NewInt(2830309)
+ tmp.ScalarMul(&point1, r1)
+ expected.ScalarMul(&point2, r2).
+ Add(&expected, &tmp)
+
+ // populate witness
+ witness.P1.X = (point1.X.String())
+ witness.P1.Y = (point1.Y.String())
+ witness.P2.X = (point2.X.String())
+ witness.P2.Y = (point2.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S1 = (r1)
+ witness.S2 = (r2)
+ case ecc.BW6_633:
+ var base, point1, point2, tmp, expected tbw6633.PointAffine
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
+ s1 := big.NewInt(902)
+ s2 := big.NewInt(891)
+ point1.ScalarMul(&base, s1) // random point
+ point2.ScalarMul(&base, s2) // random point
+ r1 := big.NewInt(230928303)
+ r2 := big.NewInt(2830309)
+ tmp.ScalarMul(&point1, r1)
+ expected.ScalarMul(&point2, r2).
+ Add(&expected, &tmp)
+
+ // populate witness
+ witness.P1.X = (point1.X.String())
+ witness.P1.Y = (point1.Y.String())
+ witness.P2.X = (point2.X.String())
+ witness.P2.Y = (point2.Y.String())
+ witness.E.X = (expected.X.String())
+ witness.E.Y = (expected.Y.String())
+ witness.S1 = (r1)
+ witness.S2 = (r2)
+ }
+
+ // creates r1cs
+ assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(id))
+ }
}
type neg struct {
@@ -537,8 +804,8 @@ func TestNeg(t *testing.T) {
t.Fatal(err)
}
var base, expected tbn254.PointAffine
- base.X.SetBigInt(&params.BaseX)
- base.Y.SetBigInt(&params.BaseY)
+ base.X.SetBigInt(&params.Base.X)
+ base.Y.SetBigInt(&params.Base.Y)
expected.Neg(&base)
// generate witness
diff --git a/std/fiat-shamir/transcript.go b/std/fiat-shamir/transcript.go
index 3491a548a5..97a35ca62a 100644
--- a/std/fiat-shamir/transcript.go
+++ b/std/fiat-shamir/transcript.go
@@ -91,8 +91,8 @@ func (t *Transcript) Bind(challengeID string, values []frontend.Variable) error
// ComputeChallenge computes the challenge corresponding to the given name.
// The resulting variable is:
-// * H(name || previous_challenge || binded_values...) if the challenge is not the first one
-// * H(name || binded_values... ) if it's is the first challenge
+// * H(name ∥ previous_challenge ∥ binded_values...) if the challenge is not the first one
+// * H(name ∥ binded_values... ) if it is the first challenge
func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, error) {
challenge, ok := t.challenges[challengeID]
diff --git a/std/signature/eddsa/eddsa.go b/std/signature/eddsa/eddsa.go
index c46b0273c3..c266b33a7e 100644
--- a/std/signature/eddsa/eddsa.go
+++ b/std/signature/eddsa/eddsa.go
@@ -58,37 +58,35 @@ func Verify(api frontend.API, sig Signature, msg frontend.Variable, pubKey Publi
return err
}
hash.Write(data...)
- //hramConstant := hash.Sum(data...)
hramConstant := hash.Sum()
- // lhs = [S]G
- cofactor := pubKey.Curve.Cofactor.Uint64()
- lhs := twistededwards.Point{}
- lhs.ScalarMulFixedBase(api, pubKey.Curve.BaseX, pubKey.Curve.BaseY, sig.S, pubKey.Curve)
- lhs.MustBeOnCurve(api, pubKey.Curve)
+ base := twistededwards.Point{}
+ base.X = pubKey.Curve.Base.X
+ base.Y = pubKey.Curve.Base.Y
- // rhs = R+[H(R,A,M)]*A
- rhs := twistededwards.Point{}
- rhs.ScalarMulNonFixedBase(api, &pubKey.A, hramConstant, pubKey.Curve).
- AddGeneric(api, &rhs, &sig.R, pubKey.Curve)
- rhs.MustBeOnCurve(api, pubKey.Curve)
+ //[S]G-[H(R,A,M)]*A
+ cofactor := pubKey.Curve.Cofactor.Uint64()
+ Q := twistededwards.Point{}
+ _A := twistededwards.Point{}
+ _A.Neg(api, &pubKey.A)
+ Q.DoubleBaseScalarMul(api, &base, &_A, sig.S, hramConstant, pubKey.Curve)
+ Q.MustBeOnCurve(api, pubKey.Curve)
- // lhs-rhs
- rhs.Neg(api, &rhs).AddGeneric(api, &lhs, &rhs, pubKey.Curve)
+ //[S]G-[H(R,A,M)]*A-R
+ Q.Neg(api, &Q).Add(api, &Q, &sig.R, pubKey.Curve)
- // [cofactor](lhs-rhs)
+ // [cofactor]*(lhs-rhs)
switch cofactor {
case 4:
- rhs.Double(api, &rhs, pubKey.Curve).
- Double(api, &rhs, pubKey.Curve)
+ Q.Double(api, &Q, pubKey.Curve).
+ Double(api, &Q, pubKey.Curve)
case 8:
- rhs.Double(api, &rhs, pubKey.Curve).
- Double(api, &rhs, pubKey.Curve).Double(api, &rhs, pubKey.Curve)
+ Q.Double(api, &Q, pubKey.Curve).
+ Double(api, &Q, pubKey.Curve).Double(api, &Q, pubKey.Curve)
}
- //rhs.MustBeOnCurve(api, pubKey.Curve)
- api.AssertIsEqual(rhs.X, 0)
- api.AssertIsEqual(rhs.Y, 1)
+ api.AssertIsEqual(Q.X, 0)
+ api.AssertIsEqual(Q.Y, 1)
return nil
}
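Editorial note, not part of the patch: the rewritten `Verify` checks the usual EdDSA equation in rearranged form. Writing B for the curve base point, c for the cofactor and H = H(R, A, M), the circuit now asserts

    [c] * ( [H]A + R - [S]B ) == (0, 1)

(the identity of the twisted Edwards group; the overall sign of the combination does not matter for an identity check), which is equivalent to the textbook check [c][S]B == [c]R + [c][H]A. Computing [S]B and [H](-A) through a single `DoubleBaseScalarMul` lets the two scalar multiplications share one doubling chain instead of running two separate double-and-add loops.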
diff --git a/test/kzg_srs.go b/test/kzg_srs.go
index 7f96d5bb03..4e8225ce80 100644
--- a/test/kzg_srs.go
+++ b/test/kzg_srs.go
@@ -34,7 +34,7 @@ import (
const srsCachedSize = (1 << 14) + 3
// NewKZGSRS uses ccs nb variables and nb constraints to initialize a kzg srs
-// for sizes < 2^15, returns a pre-computed cached SRS
+// for sizes < 2¹⁵, returns a pre-computed cached SRS
//
// /!\ warning /!\: this method is here for convenience only: in production, a SRS generated through MPC should be used.
func NewKZGSRS(ccs frontend.CompiledConstraintSystem) (kzg.SRS, error) {
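Editorial note, not part of the patch: a minimal, hedged sketch of how this test helper is typically called. The wrapper name `srsFor` is illustrative only, the `kzg` import path is assumed to match the one already used by this file, and obtaining `ccs` via `frontend.Compile` is deliberately left out.

```go
package example

import (
	"github.com/consensys/gnark-crypto/kzg" // assumed import path for kzg.SRS
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/test"
)

// srsFor returns a KZG SRS sized for the given compiled constraint system:
// a cached SRS for small circuits (below the 2^15 threshold documented above),
// a freshly generated, test-only SRS otherwise.
func srsFor(ccs frontend.CompiledConstraintSystem) (kzg.SRS, error) {
	return test.NewKZGSRS(ccs)
}
```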