From b08792364d03f7b7c19a9b7fed1235b6ae8d06e8 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 12 Mar 2024 14:35:28 +0000 Subject: [PATCH 1/9] XXX: full sumcheck-scalarmul --- std/recursion/sumcheck/scalarmul_test.go | 183 +++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 std/recursion/sumcheck/scalarmul_test.go diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go new file mode 100644 index 0000000000..cb59619c07 --- /dev/null +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -0,0 +1,183 @@ +package sumcheck + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/secp256k1" + fr_secp256k1 "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/test" +) + +type ScalarMulCircuit[Base, Scalars emulated.FieldParams] struct { + Points []sw_emulated.AffinePoint[Base] + Scalars []emulated.Element[Scalars] + + nbScalarBits int +} + +func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { + if len(c.Points) != len(c.Scalars) { + return fmt.Errorf("len(inputs) != len(scalars)") + } + baseApi, err := emulated.NewField[B](api) + if err != nil { + return fmt.Errorf("new base field: %w", err) + } + scalarApi, err := emulated.NewField[S](api) + if err != nil { + return fmt.Errorf("new scalar field: %w", err) + } + for i := range c.Points { + step, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i]) + if err != nil { + return fmt.Errorf("hint scalar mul steps: %w", err) + } + _ = step + } + return nil +} + +func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, + baseApi *emulated.Field[B], scalarApi *emulated.Field[S], + nbScalarBits int, + point sw_emulated.AffinePoint[B], scalar emulated.Element[S]) ([][6]*emulated.Element[B], error) { + var fp B + var fr S + inputs := []frontend.Variable{fp.BitsPerLimb(), fp.NbLimbs()} + inputs = append(inputs, baseApi.Modulus().Limbs...) + inputs = append(inputs, point.X.Limbs...) + inputs = append(inputs, point.Y.Limbs...) + inputs = append(inputs, fr.BitsPerLimb(), fr.NbLimbs()) + inputs = append(inputs, scalarApi.Modulus().Limbs...) + inputs = append(inputs, scalar.Limbs...) + nbRes := nbScalarBits * int(fp.NbLimbs()) * 6 + hintRes, err := api.Compiler().NewHint(hintScalarMulSteps, nbRes, inputs...) + if err != nil { + return nil, fmt.Errorf("new hint: %w", err) + } + res := make([][6]*emulated.Element[B], nbScalarBits) + for i := range res { + for j := 0; j < 6; j++ { + limbs := hintRes[i*(6*int(fp.NbLimbs()))+j*int(fp.NbLimbs()) : i*(6*int(fp.NbLimbs()))+(j+1)*int(fp.NbLimbs())] + res[i][j] = baseApi.NewElement(limbs) + } + } + return res, nil +} + +func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { + nbBits := int(inputs[0].Int64()) + nbLimbs := int(inputs[1].Int64()) + fpLimbs := inputs[2 : 2+nbLimbs] + xLimbs := inputs[2+nbLimbs : 2+2*nbLimbs] + yLimbs := inputs[2+2*nbLimbs : 2+3*nbLimbs] + nbScalarBits := int(inputs[2+3*nbLimbs].Int64()) + nbScalarLimbs := int(inputs[3+3*nbLimbs].Int64()) + frLimbs := inputs[4+3*nbLimbs : 4+3*nbLimbs+nbScalarLimbs] + scalarLimbs := inputs[4+3*nbLimbs+nbScalarLimbs : 4+3*nbLimbs+2*nbScalarLimbs] + + x := new(big.Int) + y := new(big.Int) + fp := new(big.Int) + fr := new(big.Int) + scalar := new(big.Int) + if err := recompose(fpLimbs, uint(nbBits), fp); err != nil { + return fmt.Errorf("recompose fp: %w", err) + } + if err := recompose(frLimbs, uint(nbScalarBits), fr); err != nil { + return fmt.Errorf("recompose fr: %w", err) + } + if err := recompose(xLimbs, uint(nbBits), x); err != nil { + return fmt.Errorf("recompose x: %w", err) + } + if err := recompose(yLimbs, uint(nbBits), y); err != nil { + return fmt.Errorf("recompose y: %w", err) + } + if err := recompose(scalarLimbs, uint(nbScalarBits), scalar); err != nil { + return fmt.Errorf("recompose scalar: %w", err) + } + fmt.Println(fp, fr, x, y, scalar) + + scalarLength := len(outputs) / (6 * nbLimbs) + return nil +} + +func recompose(inputs []*big.Int, nbBits uint, res *big.Int) error { + if len(inputs) == 0 { + return fmt.Errorf("zero length slice input") + } + if res == nil { + return fmt.Errorf("result not initialized") + } + res.SetUint64(0) + for i := range inputs { + res.Lsh(res, nbBits) + res.Add(res, inputs[len(inputs)-i-1]) + } + // TODO @gbotrel mod reduce ? + return nil +} + +func decompose(input *big.Int, nbBits uint, res []*big.Int) error { + // limb modulus + if input.BitLen() > len(res)*int(nbBits) { + return fmt.Errorf("decomposed integer does not fit into res") + } + for _, r := range res { + if r == nil { + return fmt.Errorf("result slice element uninitalized") + } + } + base := new(big.Int).Lsh(big.NewInt(1), nbBits) + tmp := new(big.Int).Set(input) + for i := 0; i < len(res); i++ { + res[i].Mod(tmp, base) + tmp.Rsh(tmp, nbBits) + } + return nil +} + +func TestScalarMul(t *testing.T) { + assert := test.NewAssert(t) + type B = emparams.Secp256k1Fp + type S = emparams.Secp256k1Fr + t.Log(B{}.Modulus(), S{}.Modulus()) + var P secp256k1.G1Affine + var s fr_secp256k1.Element + nbInputs := 1 << 0 + nbScalarBits := 2 + scalarBound := new(big.Int).Lsh(big.NewInt(1), uint(nbScalarBits)) + points := make([]sw_emulated.AffinePoint[B], nbInputs) + scalars := make([]emulated.Element[S], nbInputs) + for i := range points { + s.SetRandom() + P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) + sc, _ := rand.Int(rand.Reader, scalarBound) + t.Log(P.X.String(), P.Y.String(), sc.String()) + points[i] = sw_emulated.AffinePoint[B]{ + X: emulated.ValueOf[B](P.X), + Y: emulated.ValueOf[B](P.Y), + } + scalars[i] = emulated.ValueOf[S](sc) + } + circuit := ScalarMulCircuit[B, S]{ + Points: make([]sw_emulated.AffinePoint[B], nbInputs), + Scalars: make([]emulated.Element[S], nbInputs), + nbScalarBits: nbScalarBits, + } + witness := ScalarMulCircuit[B, S]{ + Points: points, + Scalars: scalars, + } + err := test.IsSolved(&circuit, &witness, ecc.BLS12_377.ScalarField()) + assert.NoError(err) +} From 1ed7dd7befb18044a7cbf7e8627c3c3c7667298d Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Tue, 12 Mar 2024 17:52:56 +0000 Subject: [PATCH 2/9] XXX: full sumcheck-scalarmul --- std/recursion/sumcheck/scalarmul_test.go | 79 ++++++++++++++++++++---- 1 file changed, 66 insertions(+), 13 deletions(-) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index cb59619c07..4a023c0190 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -11,12 +11,17 @@ import ( fr_secp256k1 "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/emulated/emparams" "github.com/consensys/gnark/test" ) +type ProjectivePoint[Base emulated.FieldParams] struct { + X, Y, Z emulated.Element[Base] +} + type ScalarMulCircuit[Base, Scalars emulated.FieldParams] struct { Points []sw_emulated.AffinePoint[Base] Scalars []emulated.Element[Scalars] @@ -37,11 +42,11 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { return fmt.Errorf("new scalar field: %w", err) } for i := range c.Points { - step, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i]) + results, accs, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i]) if err != nil { return fmt.Errorf("hint scalar mul steps: %w", err) } - _ = step + _, _ = results, accs } return nil } @@ -49,7 +54,7 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, baseApi *emulated.Field[B], scalarApi *emulated.Field[S], nbScalarBits int, - point sw_emulated.AffinePoint[B], scalar emulated.Element[S]) ([][6]*emulated.Element[B], error) { + point sw_emulated.AffinePoint[B], scalar emulated.Element[S]) (results []ProjectivePoint[B], accumulators []ProjectivePoint[B], err error) { var fp B var fr S inputs := []frontend.Variable{fp.BitsPerLimb(), fp.NbLimbs()} @@ -62,16 +67,28 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, nbRes := nbScalarBits * int(fp.NbLimbs()) * 6 hintRes, err := api.Compiler().NewHint(hintScalarMulSteps, nbRes, inputs...) if err != nil { - return nil, fmt.Errorf("new hint: %w", err) + return nil, nil, fmt.Errorf("new hint: %w", err) } - res := make([][6]*emulated.Element[B], nbScalarBits) + res := make([]ProjectivePoint[B], nbScalarBits) + acc := make([]ProjectivePoint[B], nbScalarBits) for i := range res { - for j := 0; j < 6; j++ { + coords := make([]*emulated.Element[B], 6) + for j := range coords { limbs := hintRes[i*(6*int(fp.NbLimbs()))+j*int(fp.NbLimbs()) : i*(6*int(fp.NbLimbs()))+(j+1)*int(fp.NbLimbs())] - res[i][j] = baseApi.NewElement(limbs) + coords[j] = baseApi.NewElement(limbs) + } + res[i] = ProjectivePoint[B]{ + X: *coords[0], + Y: *coords[1], + Z: *coords[2], + } + acc[i] = ProjectivePoint[B]{ + X: *coords[3], + Y: *coords[4], + Z: *coords[5], } } - return res, nil + return res, acc, nil } func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { @@ -105,9 +122,44 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err if err := recompose(scalarLimbs, uint(nbScalarBits), scalar); err != nil { return fmt.Errorf("recompose scalar: %w", err) } - fmt.Println(fp, fr, x, y, scalar) scalarLength := len(outputs) / (6 * nbLimbs) + accX := new(big.Int).Set(x) + accY := new(big.Int).Set(y) + accZ := big.NewInt(1) + resultX := big.NewInt(0) + resultY := big.NewInt(1) + resultZ := big.NewInt(0) + api := newBigIntEngine(fp) + selector := new(big.Int) + + for i := 0; i < scalarLength; i++ { + // selector := scalar.And() + selector.And(scalar, big.NewInt(1)) + scalar.Rsh(scalar, 1) + tmpX, tmpY, tmpZ := projAdd(api, accX, accY, accZ, resultX, resultY, resultZ) + resultX, resultY, resultZ = projSelect(api, selector, tmpX, tmpY, tmpZ, resultX, resultY, resultZ) + accX, accY, accZ = projDbl(api, accX, accY, accZ) + if err := decompose(resultX, uint(nbBits), outputs[i*6*nbLimbs:i*6*nbLimbs+nbLimbs]); err != nil { + return fmt.Errorf("decompose resultX: %w", err) + } + if err := decompose(resultY, uint(nbBits), outputs[i*6*nbLimbs+nbLimbs:i*6*nbLimbs+2*nbLimbs]); err != nil { + return fmt.Errorf("decompose resultY: %w", err) + } + if err := decompose(resultZ, uint(nbBits), outputs[i*6*nbLimbs+2*nbLimbs:i*6*nbLimbs+3*nbLimbs]); err != nil { + return fmt.Errorf("decompose resultZ: %w", err) + } + if err := decompose(accX, uint(nbBits), outputs[i*6*nbLimbs+3*nbLimbs:i*6*nbLimbs+4*nbLimbs]); err != nil { + return fmt.Errorf("decompose accX: %w", err) + } + if err := decompose(accY, uint(nbBits), outputs[i*6*nbLimbs+4*nbLimbs:i*6*nbLimbs+5*nbLimbs]); err != nil { + return fmt.Errorf("decompose accY: %w", err) + } + if err := decompose(accZ, uint(nbBits), outputs[i*6*nbLimbs+5*nbLimbs:(i+1)*6*nbLimbs]); err != nil { + return fmt.Errorf("decompose accZ: %w", err) + } + } + return nil } @@ -150,19 +202,19 @@ func TestScalarMul(t *testing.T) { assert := test.NewAssert(t) type B = emparams.Secp256k1Fp type S = emparams.Secp256k1Fr - t.Log(B{}.Modulus(), S{}.Modulus()) var P secp256k1.G1Affine var s fr_secp256k1.Element - nbInputs := 1 << 0 - nbScalarBits := 2 + nbInputs := 1 << 2 + nbScalarBits := 256 scalarBound := new(big.Int).Lsh(big.NewInt(1), uint(nbScalarBits)) points := make([]sw_emulated.AffinePoint[B], nbInputs) scalars := make([]emulated.Element[S], nbInputs) for i := range points { + P.ScalarMultiplicationBase(big.NewInt(1)) s.SetRandom() P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) sc, _ := rand.Int(rand.Reader, scalarBound) - t.Log(P.X.String(), P.Y.String(), sc.String()) + // t.Log(P.X.String(), P.Y.String(), sc.String()) points[i] = sw_emulated.AffinePoint[B]{ X: emulated.ValueOf[B](P.X), Y: emulated.ValueOf[B](P.Y), @@ -180,4 +232,5 @@ func TestScalarMul(t *testing.T) { } err := test.IsSolved(&circuit, &witness, ecc.BLS12_377.ScalarField()) assert.NoError(err) + frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) } From 9d5b61d8bb9c121172c1bde62509831dfd77afa7 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Fri, 15 Mar 2024 12:40:44 +0000 Subject: [PATCH 3/9] XXX: full sumcheck-scalarmul --- std/recursion/sumcheck/scalarmul_test.go | 75 +++++++++++++++++++++++- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index 4a023c0190..3af5918d59 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "fmt" "math/big" + stdbits "math/bits" "testing" "github.com/consensys/gnark-crypto/ecc" @@ -12,9 +13,13 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/algebra" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/emulated/emparams" + "github.com/consensys/gnark/std/math/polynomial" + "github.com/consensys/gnark/std/recursion" "github.com/consensys/gnark/test" ) @@ -30,6 +35,7 @@ type ScalarMulCircuit[Base, Scalars emulated.FieldParams] struct { } func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { + var fp B if len(c.Points) != len(c.Scalars) { return fmt.Errorf("len(inputs) != len(scalars)") } @@ -41,13 +47,78 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { if err != nil { return fmt.Errorf("new scalar field: %w", err) } + poly, err := polynomial.New[B](api) + if err != nil { + return fmt.Errorf("new polynomial: %w", err) + } + // we use curve for marshaling points and scalars + curve, err := algebra.GetCurve[S, sw_emulated.AffinePoint[B]](api) + if err != nil { + return fmt.Errorf("get curve: %w", err) + } + fs, err := recursion.NewTranscript(api, fp.Modulus(), []string{"alpha", "beta"}) + if err != nil { + return fmt.Errorf("new transcript: %w", err) + } + // compute the all double-and-add steps for each scalar multiplication + var results, accs []ProjectivePoint[B] for i := range c.Points { - results, accs, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i]) + if err := fs.Bind("alpha", curve.MarshalScalar(c.Scalars[i])); err != nil { + return fmt.Errorf("bind scalar %d alpha: %w", i, err) + } + if err := fs.Bind("alpha", curve.MarshalG1(c.Points[i])); err != nil { + return fmt.Errorf("bind point %d alpha: %w", i, err) + } + result, acc, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i]) if err != nil { return fmt.Errorf("hint scalar mul steps: %w", err) } - _, _ = results, accs + results = append(results, result...) + accs = append(accs, acc...) } + // derive the randomness for random linear combination + alphaNative, err := fs.ComputeChallenge("alpha") + if err != nil { + return fmt.Errorf("compute challenge alpha: %w", err) + } + alphaBts := bits.ToBinary(api, alphaNative, bits.WithNbDigits(fp.Modulus().BitLen())) + alpha1 := baseApi.FromBits(alphaBts...) + alpha2 := baseApi.Mul(alpha1, alpha1) + alpha3 := baseApi.Mul(alpha1, alpha2) + alpha4 := baseApi.Mul(alpha1, alpha3) + alpha5 := baseApi.Mul(alpha1, alpha4) + claimed := make([]*emulated.Element[B], len(results)) + // compute the random linear combinations of the intermediate results provided by the hint + for i := range results { + claimed[i] = baseApi.Sum( + &accs[i].X, + baseApi.MulNoReduce(alpha1, &accs[i].Y), + baseApi.MulNoReduce(alpha2, &accs[i].Z), + baseApi.MulNoReduce(alpha3, &results[i].X), + baseApi.MulNoReduce(alpha4, &results[i].Y), + baseApi.MulNoReduce(alpha5, &results[i].Z), + ) + } + // derive the randomness for folding + betaNative, err := fs.ComputeChallenge("beta") + if err != nil { + return fmt.Errorf("compute challenge alpha: %w", err) + } + betaBts := bits.ToBinary(api, betaNative, bits.WithNbDigits(fp.Modulus().BitLen())) + evalPoints := make([]*emulated.Element[B], stdbits.Len(uint(len(claimed)))-1) + evalPoints[0] = baseApi.FromBits(betaBts...) + for i := 1; i < len(evalPoints); i++ { + evalPoints[i] = baseApi.Mul(evalPoints[i-1], evalPoints[0]) + } + // compute the polynomial evaluation + claimedPoly := polynomial.FromSliceReferences(claimed) + claim, err := poly.EvalMultilinear(evalPoints, claimedPoly) + if err != nil { + return fmt.Errorf("eval multilinear: %w", err) + } + + _ = claim + return nil } From dcb153f4ea380fa1e6143bbb266de86420821d76 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Fri, 15 Mar 2024 14:27:00 +0000 Subject: [PATCH 4/9] XXX: sumcheck scalarmul return all points --- std/recursion/sumcheck/scalarmul_test.go | 229 +++++++++++++---------- 1 file changed, 133 insertions(+), 96 deletions(-) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index 3af5918d59..2a33689e43 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -10,9 +10,8 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/secp256k1" fr_secp256k1 "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" - + cryptofs "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/algebra" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/bits" @@ -36,6 +35,7 @@ type ScalarMulCircuit[Base, Scalars emulated.FieldParams] struct { func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { var fp B + nbInputs := len(c.Points) if len(c.Points) != len(c.Scalars) { return fmt.Errorf("len(inputs) != len(scalars)") } @@ -61,7 +61,7 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { return fmt.Errorf("new transcript: %w", err) } // compute the all double-and-add steps for each scalar multiplication - var results, accs []ProjectivePoint[B] + // var results, accs []ProjectivePoint[B] for i := range c.Points { if err := fs.Bind("alpha", curve.MarshalScalar(c.Scalars[i])); err != nil { return fmt.Errorf("bind scalar %d alpha: %w", i, err) @@ -69,13 +69,12 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { if err := fs.Bind("alpha", curve.MarshalG1(c.Points[i])); err != nil { return fmt.Errorf("bind point %d alpha: %w", i, err) } - result, acc, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points[i], c.Scalars[i]) - if err != nil { - return fmt.Errorf("hint scalar mul steps: %w", err) - } - results = append(results, result...) - accs = append(accs, acc...) } + result, acc, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points, c.Scalars) + if err != nil { + return fmt.Errorf("hint scalar mul steps: %w", err) + } + // derive the randomness for random linear combination alphaNative, err := fs.ComputeChallenge("alpha") if err != nil { @@ -87,17 +86,19 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { alpha3 := baseApi.Mul(alpha1, alpha2) alpha4 := baseApi.Mul(alpha1, alpha3) alpha5 := baseApi.Mul(alpha1, alpha4) - claimed := make([]*emulated.Element[B], len(results)) + claimed := make([]*emulated.Element[B], nbInputs*c.nbScalarBits) // compute the random linear combinations of the intermediate results provided by the hint - for i := range results { - claimed[i] = baseApi.Sum( - &accs[i].X, - baseApi.MulNoReduce(alpha1, &accs[i].Y), - baseApi.MulNoReduce(alpha2, &accs[i].Z), - baseApi.MulNoReduce(alpha3, &results[i].X), - baseApi.MulNoReduce(alpha4, &results[i].Y), - baseApi.MulNoReduce(alpha5, &results[i].Z), - ) + for i := 0; i < nbInputs; i++ { + for j := 0; j < c.nbScalarBits; j++ { + claimed[i*c.nbScalarBits+j] = baseApi.Sum( + &acc[i][j].X, + baseApi.MulNoReduce(alpha1, &acc[i][j].Y), + baseApi.MulNoReduce(alpha2, &acc[i][j].Z), + baseApi.MulNoReduce(alpha3, &result[i][j].X), + baseApi.MulNoReduce(alpha4, &result[i][j].Y), + baseApi.MulNoReduce(alpha5, &result[i][j].Z), + ) + } } // derive the randomness for folding betaNative, err := fs.ComputeChallenge("beta") @@ -125,112 +126,148 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, baseApi *emulated.Field[B], scalarApi *emulated.Field[S], nbScalarBits int, - point sw_emulated.AffinePoint[B], scalar emulated.Element[S]) (results []ProjectivePoint[B], accumulators []ProjectivePoint[B], err error) { + points []sw_emulated.AffinePoint[B], scalars []emulated.Element[S]) (results [][]ProjectivePoint[B], accumulators [][]ProjectivePoint[B], err error) { var fp B var fr S - inputs := []frontend.Variable{fp.BitsPerLimb(), fp.NbLimbs()} + nbInputs := len(points) + inputs := []frontend.Variable{nbInputs, fp.BitsPerLimb(), fp.NbLimbs(), fr.BitsPerLimb(), fr.NbLimbs()} inputs = append(inputs, baseApi.Modulus().Limbs...) - inputs = append(inputs, point.X.Limbs...) - inputs = append(inputs, point.Y.Limbs...) - inputs = append(inputs, fr.BitsPerLimb(), fr.NbLimbs()) inputs = append(inputs, scalarApi.Modulus().Limbs...) - inputs = append(inputs, scalar.Limbs...) - nbRes := nbScalarBits * int(fp.NbLimbs()) * 6 + for i := range points { + inputs = append(inputs, points[i].X.Limbs...) + inputs = append(inputs, points[i].Y.Limbs...) + inputs = append(inputs, scalars[i].Limbs...) + } + nbRes := nbScalarBits * int(fp.NbLimbs()) * 6 * nbInputs hintRes, err := api.Compiler().NewHint(hintScalarMulSteps, nbRes, inputs...) if err != nil { return nil, nil, fmt.Errorf("new hint: %w", err) } - res := make([]ProjectivePoint[B], nbScalarBits) - acc := make([]ProjectivePoint[B], nbScalarBits) - for i := range res { - coords := make([]*emulated.Element[B], 6) - for j := range coords { - limbs := hintRes[i*(6*int(fp.NbLimbs()))+j*int(fp.NbLimbs()) : i*(6*int(fp.NbLimbs()))+(j+1)*int(fp.NbLimbs())] - coords[j] = baseApi.NewElement(limbs) - } - res[i] = ProjectivePoint[B]{ - X: *coords[0], - Y: *coords[1], - Z: *coords[2], - } - acc[i] = ProjectivePoint[B]{ - X: *coords[3], - Y: *coords[4], - Z: *coords[5], + res := make([][]ProjectivePoint[B], nbInputs) + acc := make([][]ProjectivePoint[B], nbInputs) + for i := 0; i < nbInputs; i++ { + res[i] = make([]ProjectivePoint[B], nbScalarBits) + acc[i] = make([]ProjectivePoint[B], nbScalarBits) + } + for i := 0; i < nbInputs; i++ { + inputRes := hintRes[i*(6*int(fp.NbLimbs())*nbScalarBits) : (i+1)*(6*int(fp.NbLimbs())*nbScalarBits)] + for j := 0; j < nbScalarBits; j++ { + coords := make([]*emulated.Element[B], 6) + for k := range coords { + limbs := inputRes[j*(6*int(fp.NbLimbs()))+k*int(fp.NbLimbs()) : j*(6*int(fp.NbLimbs()))+(k+1)*int(fp.NbLimbs())] + coords[k] = baseApi.NewElement(limbs) + } + res[i][j] = ProjectivePoint[B]{ + X: *coords[0], + Y: *coords[1], + Z: *coords[2], + } + acc[i][j] = ProjectivePoint[B]{ + X: *coords[3], + Y: *coords[4], + Z: *coords[5], + } } } return res, acc, nil } func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { - nbBits := int(inputs[0].Int64()) - nbLimbs := int(inputs[1].Int64()) - fpLimbs := inputs[2 : 2+nbLimbs] - xLimbs := inputs[2+nbLimbs : 2+2*nbLimbs] - yLimbs := inputs[2+2*nbLimbs : 2+3*nbLimbs] - nbScalarBits := int(inputs[2+3*nbLimbs].Int64()) - nbScalarLimbs := int(inputs[3+3*nbLimbs].Int64()) - frLimbs := inputs[4+3*nbLimbs : 4+3*nbLimbs+nbScalarLimbs] - scalarLimbs := inputs[4+3*nbLimbs+nbScalarLimbs : 4+3*nbLimbs+2*nbScalarLimbs] - - x := new(big.Int) - y := new(big.Int) + nbInputs := int(inputs[0].Int64()) + nbBits := int(inputs[1].Int64()) + nbLimbs := int(inputs[2].Int64()) + nbScalarBits := int(inputs[3].Int64()) + nbScalarLimbs := int(inputs[4].Int64()) + fpLimbs := inputs[5 : 5+nbLimbs] + frLimbs := inputs[5+nbLimbs : 5+nbLimbs+nbScalarLimbs] fp := new(big.Int) fr := new(big.Int) - scalar := new(big.Int) if err := recompose(fpLimbs, uint(nbBits), fp); err != nil { return fmt.Errorf("recompose fp: %w", err) } if err := recompose(frLimbs, uint(nbScalarBits), fr); err != nil { return fmt.Errorf("recompose fr: %w", err) } - if err := recompose(xLimbs, uint(nbBits), x); err != nil { - return fmt.Errorf("recompose x: %w", err) - } - if err := recompose(yLimbs, uint(nbBits), y); err != nil { - return fmt.Errorf("recompose y: %w", err) - } - if err := recompose(scalarLimbs, uint(nbScalarBits), scalar); err != nil { - return fmt.Errorf("recompose scalar: %w", err) + ptr := 5 + nbLimbs + nbScalarLimbs + xs := make([]*big.Int, nbInputs) + ys := make([]*big.Int, nbInputs) + scalars := make([]*big.Int, nbInputs) + for i := 0; i < nbInputs; i++ { + xLimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + yLimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + scalarLimbs := inputs[ptr : ptr+nbScalarLimbs] + ptr += nbScalarLimbs + xs[i] = new(big.Int) + ys[i] = new(big.Int) + scalars[i] = new(big.Int) + if err := recompose(xLimbs, uint(nbBits), xs[i]); err != nil { + return fmt.Errorf("recompose x: %w", err) + } + if err := recompose(yLimbs, uint(nbBits), ys[i]); err != nil { + return fmt.Errorf("recompose y: %w", err) + } + if err := recompose(scalarLimbs, uint(nbScalarBits), scalars[i]); err != nil { + return fmt.Errorf("recompose scalar: %w", err) + } } - scalarLength := len(outputs) / (6 * nbLimbs) - accX := new(big.Int).Set(x) - accY := new(big.Int).Set(y) - accZ := big.NewInt(1) - resultX := big.NewInt(0) - resultY := big.NewInt(1) - resultZ := big.NewInt(0) + scalarLength := len(outputs) / (6 * nbLimbs * nbInputs) api := newBigIntEngine(fp) selector := new(big.Int) - - for i := 0; i < scalarLength; i++ { - // selector := scalar.And() - selector.And(scalar, big.NewInt(1)) - scalar.Rsh(scalar, 1) - tmpX, tmpY, tmpZ := projAdd(api, accX, accY, accZ, resultX, resultY, resultZ) - resultX, resultY, resultZ = projSelect(api, selector, tmpX, tmpY, tmpZ, resultX, resultY, resultZ) - accX, accY, accZ = projDbl(api, accX, accY, accZ) - if err := decompose(resultX, uint(nbBits), outputs[i*6*nbLimbs:i*6*nbLimbs+nbLimbs]); err != nil { - return fmt.Errorf("decompose resultX: %w", err) - } - if err := decompose(resultY, uint(nbBits), outputs[i*6*nbLimbs+nbLimbs:i*6*nbLimbs+2*nbLimbs]); err != nil { - return fmt.Errorf("decompose resultY: %w", err) - } - if err := decompose(resultZ, uint(nbBits), outputs[i*6*nbLimbs+2*nbLimbs:i*6*nbLimbs+3*nbLimbs]); err != nil { - return fmt.Errorf("decompose resultZ: %w", err) - } - if err := decompose(accX, uint(nbBits), outputs[i*6*nbLimbs+3*nbLimbs:i*6*nbLimbs+4*nbLimbs]); err != nil { - return fmt.Errorf("decompose accX: %w", err) - } - if err := decompose(accY, uint(nbBits), outputs[i*6*nbLimbs+4*nbLimbs:i*6*nbLimbs+5*nbLimbs]); err != nil { - return fmt.Errorf("decompose accY: %w", err) - } - if err := decompose(accZ, uint(nbBits), outputs[i*6*nbLimbs+5*nbLimbs:(i+1)*6*nbLimbs]); err != nil { - return fmt.Errorf("decompose accZ: %w", err) + outPtr := 0 + for i := 0; i < nbInputs; i++ { + scalar := scalars[i] + x := xs[i] + y := ys[i] + accX := new(big.Int).Set(x) + accY := new(big.Int).Set(y) + accZ := big.NewInt(1) + resultX := big.NewInt(0) + resultY := big.NewInt(1) + resultZ := big.NewInt(0) + for j := 0; j < scalarLength; j++ { + selector.And(scalar, big.NewInt(1)) + scalar.Rsh(scalar, 1) + tmpX, tmpY, tmpZ := projAdd(api, accX, accY, accZ, resultX, resultY, resultZ) + resultX, resultY, resultZ = projSelect(api, selector, tmpX, tmpY, tmpZ, resultX, resultY, resultZ) + accX, accY, accZ = projDbl(api, accX, accY, accZ) + if err := decompose(resultX, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose resultX: %w", err) + } + outPtr += nbLimbs + if err := decompose(resultY, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose resultY: %w", err) + } + outPtr += nbLimbs + if err := decompose(resultZ, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose resultZ: %w", err) + } + outPtr += nbLimbs + if err := decompose(accX, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose accX: %w", err) + } + outPtr += nbLimbs + if err := decompose(accY, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose accY: %w", err) + } + outPtr += nbLimbs + if err := decompose(accZ, uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose accZ: %w", err) + } + outPtr += nbLimbs } } + // now, we construct the sumcheck proof + h, err := recursion.NewShort(mod, fp) + if err != nil { + return fmt.Errorf("new short hash: %w", err) + } + fs := cryptofs.NewTranscript(h, "alpha", "beta") + _ = fs + return nil } From 4e6ca989c5ed8f336f18dc9c45238045292c9166 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Fri, 15 Mar 2024 15:36:29 +0000 Subject: [PATCH 5/9] XXX: sumcheck scalarmul --- std/recursion/sumcheck/scalarmul_test.go | 38 ++++++++++++++++++++---- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index 2a33689e43..e37f16d3f5 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -63,12 +63,12 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { // compute the all double-and-add steps for each scalar multiplication // var results, accs []ProjectivePoint[B] for i := range c.Points { - if err := fs.Bind("alpha", curve.MarshalScalar(c.Scalars[i])); err != nil { - return fmt.Errorf("bind scalar %d alpha: %w", i, err) - } if err := fs.Bind("alpha", curve.MarshalG1(c.Points[i])); err != nil { return fmt.Errorf("bind point %d alpha: %w", i, err) } + if err := fs.Bind("alpha", curve.MarshalScalar(c.Scalars[i])); err != nil { + return fmt.Errorf("bind scalar %d alpha: %w", i, err) + } } result, acc, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points, c.Scalars) if err != nil { @@ -213,12 +213,19 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err } } + // first, we need to provide the steps of the scalar multiplication to the + // verifier. As the output of one step is an input of the next step, we need + // to provide the results and the accumulators. By checking the consistency + // of the inputs related to the outputs (inputs using multilinear evaluation + // in the final round of the sumcheck and outputs by requiring the verifier + // to construct the claim itself), we can ensure that the final step is the + // actual scalar multiplication result. scalarLength := len(outputs) / (6 * nbLimbs * nbInputs) api := newBigIntEngine(fp) selector := new(big.Int) outPtr := 0 for i := 0; i < nbInputs; i++ { - scalar := scalars[i] + scalar := new(big.Int).Set(scalars[i]) x := xs[i] y := ys[i] accX := new(big.Int).Set(x) @@ -260,13 +267,32 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err } } - // now, we construct the sumcheck proof + // now, we construct the sumcheck proof. For that we first need to compute + // the challenges for computing the random linear combination of the + // double-and-add outputs and for the claim polynomial evaluation. h, err := recursion.NewShort(mod, fp) if err != nil { return fmt.Errorf("new short hash: %w", err) } fs := cryptofs.NewTranscript(h, "alpha", "beta") - _ = fs + for i := range xs { + var P secp256k1.G1Affine + var s fr_secp256k1.Element + P.X.SetBigInt(xs[i]) + P.Y.SetBigInt(ys[i]) + raw := P.RawBytes() + if err := fs.Bind("alpha", raw[:]); err != nil { + return fmt.Errorf("bind alpha point: %w", err) + } + s.SetBigInt(scalars[i]) + if err := fs.Bind("alpha", s.Marshal()); err != nil { + return fmt.Errorf("bind alpha scalar: %w", err) + } + } + alpha, err := fs.ComputeChallenge("alpha") + if err != nil { + return fmt.Errorf("compute challenge alpha: %w", err) + } return nil } From fdf9bcb42e4ccee7bb72028f807d179270e02f23 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Sun, 17 Mar 2024 23:44:09 +0000 Subject: [PATCH 6/9] XXX scalarmul end to end not working --- std/recursion/sumcheck/scalarmul_test.go | 122 ++++++++++++++++++++--- 1 file changed, 106 insertions(+), 16 deletions(-) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index e37f16d3f5..08837cdfbc 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -70,7 +70,7 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { return fmt.Errorf("bind scalar %d alpha: %w", i, err) } } - result, acc, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points, c.Scalars) + result, acc, proof, err := callHintScalarMulSteps[B, S](api, baseApi, scalarApi, c.nbScalarBits, c.Points, c.Scalars) if err != nil { return fmt.Errorf("hint scalar mul steps: %w", err) } @@ -81,22 +81,23 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { return fmt.Errorf("compute challenge alpha: %w", err) } alphaBts := bits.ToBinary(api, alphaNative, bits.WithNbDigits(fp.Modulus().BitLen())) - alpha1 := baseApi.FromBits(alphaBts...) - alpha2 := baseApi.Mul(alpha1, alpha1) - alpha3 := baseApi.Mul(alpha1, alpha2) - alpha4 := baseApi.Mul(alpha1, alpha3) - alpha5 := baseApi.Mul(alpha1, alpha4) + alphas := make([]*emulated.Element[B], 6) + alphas[0] = baseApi.One() + alphas[1] = baseApi.FromBits(alphaBts...) + for i := 2; i < len(alphas); i++ { + alphas[i] = baseApi.Mul(alphas[i-1], alphas[1]) + } claimed := make([]*emulated.Element[B], nbInputs*c.nbScalarBits) // compute the random linear combinations of the intermediate results provided by the hint for i := 0; i < nbInputs; i++ { for j := 0; j < c.nbScalarBits; j++ { claimed[i*c.nbScalarBits+j] = baseApi.Sum( &acc[i][j].X, - baseApi.MulNoReduce(alpha1, &acc[i][j].Y), - baseApi.MulNoReduce(alpha2, &acc[i][j].Z), - baseApi.MulNoReduce(alpha3, &result[i][j].X), - baseApi.MulNoReduce(alpha4, &result[i][j].Y), - baseApi.MulNoReduce(alpha5, &result[i][j].Z), + baseApi.MulNoReduce(alphas[1], &acc[i][j].Y), + baseApi.MulNoReduce(alphas[2], &acc[i][j].Z), + baseApi.MulNoReduce(alphas[3], &result[i][j].X), + baseApi.MulNoReduce(alphas[4], &result[i][j].Y), + baseApi.MulNoReduce(alphas[5], &result[i][j].Z), ) } } @@ -113,12 +114,38 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { } // compute the polynomial evaluation claimedPoly := polynomial.FromSliceReferences(claimed) - claim, err := poly.EvalMultilinear(evalPoints, claimedPoly) + evaluation, err := poly.EvalMultilinear(evalPoints, claimedPoly) if err != nil { return fmt.Errorf("eval multilinear: %w", err) } + fmt.Printf("claim: %s\n", baseApi.String(evaluation)) - _ = claim + inputs := make([][]*emulated.Element[B], 7) + for i := range inputs { + inputs[i] = make([]*emulated.Element[B], nbInputs*c.nbScalarBits) + } + for i := 0; i < nbInputs; i++ { + scalarBts := scalarApi.ToBits(&c.Scalars[i]) + for j := 0; j < c.nbScalarBits; j++ { + inputs[0][i*c.nbScalarBits+j] = &acc[i][j].X + inputs[1][i*c.nbScalarBits+j] = &acc[i][j].Y + inputs[2][i*c.nbScalarBits+j] = &acc[i][j].Z + inputs[3][i*c.nbScalarBits+j] = &result[i][j].X + inputs[4][i*c.nbScalarBits+j] = &result[i][j].Y + inputs[5][i*c.nbScalarBits+j] = &result[i][j].Z + inputs[6][i*c.nbScalarBits+j] = baseApi.NewElement(scalarBts[j]) + } + } + gate := dblAddSelectGate[*emuEngine[B], *emulated.Element[B]]{folding: alphas} + claim, err := newGate[B](api, gate, inputs, [][]*emulated.Element[B]{evalPoints}, []*emulated.Element[B]{evaluation}) + v, err := NewVerifier[B](api) + if err != nil { + return fmt.Errorf("new sumcheck verifier: %w", err) + } + if err = v.Verify(claim, proof); err != nil { + return fmt.Errorf("verify sumcheck: %w", err) + } + _ = evaluation return nil } @@ -126,7 +153,7 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, baseApi *emulated.Field[B], scalarApi *emulated.Field[S], nbScalarBits int, - points []sw_emulated.AffinePoint[B], scalars []emulated.Element[S]) (results [][]ProjectivePoint[B], accumulators [][]ProjectivePoint[B], err error) { + points []sw_emulated.AffinePoint[B], scalars []emulated.Element[S]) (results [][]ProjectivePoint[B], accumulators [][]ProjectivePoint[B], proof Proof[B], err error) { var fp B var fr S nbInputs := len(points) @@ -138,10 +165,13 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, inputs = append(inputs, points[i].Y.Limbs...) inputs = append(inputs, scalars[i].Limbs...) } + // steps part nbRes := nbScalarBits * int(fp.NbLimbs()) * 6 * nbInputs + // proof part + nbRes += int(fp.NbLimbs()) * (stdbits.Len(uint(nbInputs*nbScalarBits)) - 1) * (dblAddSelectGate[*noopEngine, element]{}.Degree() + 1) hintRes, err := api.Compiler().NewHint(hintScalarMulSteps, nbRes, inputs...) if err != nil { - return nil, nil, fmt.Errorf("new hint: %w", err) + return nil, nil, proof, fmt.Errorf("new hint: %w", err) } res := make([][]ProjectivePoint[B], nbInputs) acc := make([][]ProjectivePoint[B], nbInputs) @@ -169,7 +199,18 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, } } } - return res, acc, nil + proof.RoundPolyEvaluations = make([]polynomial.Univariate[B], stdbits.Len(uint(nbInputs*nbScalarBits))-1) + ptr := nbInputs * 6 * int(fp.NbLimbs()) * nbScalarBits + for i := range proof.RoundPolyEvaluations { + proof.RoundPolyEvaluations[i] = make(polynomial.Univariate[B], dblAddSelectGate[*noopEngine, element]{}.Degree()+1) + for j := range proof.RoundPolyEvaluations[i] { + limbs := hintRes[ptr : ptr+int(fp.NbLimbs())] + el := baseApi.NewElement(limbs) + proof.RoundPolyEvaluations[i][j] = *el + ptr += int(fp.NbLimbs()) + } + } + return res, acc, proof, nil } func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { @@ -224,6 +265,10 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err api := newBigIntEngine(fp) selector := new(big.Int) outPtr := 0 + proofInput := make([][]*big.Int, 7) + for i := range proofInput { + proofInput[i] = make([]*big.Int, nbInputs*scalarLength) + } for i := 0; i < nbInputs; i++ { scalar := new(big.Int).Set(scalars[i]) x := xs[i] @@ -237,6 +282,13 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err for j := 0; j < scalarLength; j++ { selector.And(scalar, big.NewInt(1)) scalar.Rsh(scalar, 1) + proofInput[0][i*scalarLength+j] = new(big.Int).Set(accX) + proofInput[1][i*scalarLength+j] = new(big.Int).Set(accY) + proofInput[2][i*scalarLength+j] = new(big.Int).Set(accZ) + proofInput[3][i*scalarLength+j] = new(big.Int).Set(resultX) + proofInput[4][i*scalarLength+j] = new(big.Int).Set(resultY) + proofInput[5][i*scalarLength+j] = new(big.Int).Set(resultZ) + proofInput[6][i*scalarLength+j] = new(big.Int).Set(selector) tmpX, tmpY, tmpZ := projAdd(api, accX, accY, accZ, resultX, resultY, resultZ) resultX, resultY, resultZ = projSelect(api, selector, tmpX, tmpY, tmpZ, resultX, resultY, resultZ) accX, accY, accZ = projDbl(api, accX, accY, accZ) @@ -289,11 +341,49 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err return fmt.Errorf("bind alpha scalar: %w", err) } } + // challenges. + // alpha is used for the random linear combination of the double-and-add alpha, err := fs.ComputeChallenge("alpha") if err != nil { return fmt.Errorf("compute challenge alpha: %w", err) } + alphas := make([]*big.Int, 6) + alphas[0] = big.NewInt(1) + alphas[1] = new(big.Int).SetBytes(alpha) + for i := 2; i < len(alphas); i++ { + alphas[i] = new(big.Int).Mul(alphas[i-1], alphas[1]) + } + // beta is used for the claim polynomial evaluation + beta, err := fs.ComputeChallenge("beta") + if err != nil { + return fmt.Errorf("compute challenge beta: %w", err) + } + betas := make([]*big.Int, stdbits.Len(uint(nbInputs*scalarLength))-1) + betas[0] = new(big.Int).SetBytes(beta) + for i := 1; i < len(betas); i++ { + betas[i] = new(big.Int).Mul(betas[i-1], betas[0]) + } + + nativeGate := dblAddSelectGate[*bigIntEngine, *big.Int]{folding: alphas} + claim, evals, err := newNativeGate(fp, nativeGate, proofInput, [][]*big.Int{betas}) + if err != nil { + return fmt.Errorf("new native gate: %w", err) + } + proof, err := prove(mod, fp, claim) + if err != nil { + return fmt.Errorf("prove: %w", err) + } + for _, pl := range proof.RoundPolyEvaluations { + for j := range pl { + if err := decompose(pl[j], uint(nbBits), outputs[outPtr:outPtr+nbLimbs]); err != nil { + return fmt.Errorf("decompose claim: %w", err) + } + outPtr += nbLimbs + } + } + // verifier computes the evaluation itself for consistency + _ = evals return nil } From e9a845c86d500176b3f8f52bf963ffd2f6f7d5fc Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 18 Mar 2024 15:01:28 +0000 Subject: [PATCH 7/9] XXX sumcheck scalarmul working --- std/recursion/sumcheck/scalarmul_test.go | 45 +++++++++++++----------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index 08837cdfbc..d326436921 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -118,7 +118,6 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { if err != nil { return fmt.Errorf("eval multilinear: %w", err) } - fmt.Printf("claim: %s\n", baseApi.String(evaluation)) inputs := make([][]*emulated.Element[B], 7) for i := range inputs { @@ -126,13 +125,20 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { } for i := 0; i < nbInputs; i++ { scalarBts := scalarApi.ToBits(&c.Scalars[i]) - for j := 0; j < c.nbScalarBits; j++ { - inputs[0][i*c.nbScalarBits+j] = &acc[i][j].X - inputs[1][i*c.nbScalarBits+j] = &acc[i][j].Y - inputs[2][i*c.nbScalarBits+j] = &acc[i][j].Z - inputs[3][i*c.nbScalarBits+j] = &result[i][j].X - inputs[4][i*c.nbScalarBits+j] = &result[i][j].Y - inputs[5][i*c.nbScalarBits+j] = &result[i][j].Z + inputs[0][i*c.nbScalarBits] = &c.Points[i].X + inputs[1][i*c.nbScalarBits] = &c.Points[i].Y + inputs[2][i*c.nbScalarBits] = baseApi.One() + inputs[3][i*c.nbScalarBits] = baseApi.Zero() + inputs[4][i*c.nbScalarBits] = baseApi.One() + inputs[5][i*c.nbScalarBits] = baseApi.Zero() + inputs[6][i*c.nbScalarBits] = baseApi.NewElement(scalarBts[0]) + for j := 1; j < c.nbScalarBits; j++ { + inputs[0][i*c.nbScalarBits+j] = &acc[i][j-1].X + inputs[1][i*c.nbScalarBits+j] = &acc[i][j-1].Y + inputs[2][i*c.nbScalarBits+j] = &acc[i][j-1].Z + inputs[3][i*c.nbScalarBits+j] = &result[i][j-1].X + inputs[4][i*c.nbScalarBits+j] = &result[i][j-1].Y + inputs[5][i*c.nbScalarBits+j] = &result[i][j-1].Z inputs[6][i*c.nbScalarBits+j] = baseApi.NewElement(scalarBts[j]) } } @@ -145,7 +151,6 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { if err = v.Verify(claim, proof); err != nil { return fmt.Errorf("verify sumcheck: %w", err) } - _ = evaluation return nil } @@ -157,7 +162,7 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, var fp B var fr S nbInputs := len(points) - inputs := []frontend.Variable{nbInputs, fp.BitsPerLimb(), fp.NbLimbs(), fr.BitsPerLimb(), fr.NbLimbs()} + inputs := []frontend.Variable{nbInputs, nbScalarBits, fp.BitsPerLimb(), fp.NbLimbs(), fr.BitsPerLimb(), fr.NbLimbs()} inputs = append(inputs, baseApi.Modulus().Limbs...) inputs = append(inputs, scalarApi.Modulus().Limbs...) for i := range points { @@ -215,12 +220,13 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { nbInputs := int(inputs[0].Int64()) - nbBits := int(inputs[1].Int64()) - nbLimbs := int(inputs[2].Int64()) - nbScalarBits := int(inputs[3].Int64()) - nbScalarLimbs := int(inputs[4].Int64()) - fpLimbs := inputs[5 : 5+nbLimbs] - frLimbs := inputs[5+nbLimbs : 5+nbLimbs+nbScalarLimbs] + scalarLength := int(inputs[1].Int64()) + nbBits := int(inputs[2].Int64()) + nbLimbs := int(inputs[3].Int64()) + nbScalarBits := int(inputs[4].Int64()) + nbScalarLimbs := int(inputs[5].Int64()) + fpLimbs := inputs[6 : 6+nbLimbs] + frLimbs := inputs[6+nbLimbs : 6+nbLimbs+nbScalarLimbs] fp := new(big.Int) fr := new(big.Int) if err := recompose(fpLimbs, uint(nbBits), fp); err != nil { @@ -229,7 +235,7 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err if err := recompose(frLimbs, uint(nbScalarBits), fr); err != nil { return fmt.Errorf("recompose fr: %w", err) } - ptr := 5 + nbLimbs + nbScalarLimbs + ptr := 6 + nbLimbs + nbScalarLimbs xs := make([]*big.Int, nbInputs) ys := make([]*big.Int, nbInputs) scalars := make([]*big.Int, nbInputs) @@ -261,7 +267,6 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err // in the final round of the sumcheck and outputs by requiring the verifier // to construct the claim itself), we can ensure that the final step is the // actual scalar multiplication result. - scalarLength := len(outputs) / (6 * nbLimbs * nbInputs) api := newBigIntEngine(fp) selector := new(big.Int) outPtr := 0 @@ -382,7 +387,8 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err outPtr += nbLimbs } } - // verifier computes the evaluation itself for consistency + // verifier computes the evaluation itself for consistency. We do not pass + // it through the hint. Explicitly ignore. _ = evals return nil } @@ -438,7 +444,6 @@ func TestScalarMul(t *testing.T) { s.SetRandom() P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) sc, _ := rand.Int(rand.Reader, scalarBound) - // t.Log(P.X.String(), P.Y.String(), sc.String()) points[i] = sw_emulated.AffinePoint[B]{ X: emulated.ValueOf[B](P.X), Y: emulated.ValueOf[B](P.Y), From 3af14c5cf116847fd91ea64b1ebf12c81d01360a Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 18 Mar 2024 23:26:40 +0000 Subject: [PATCH 8/9] feat: add noop-engine --- std/recursion/sumcheck/arithengine.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/std/recursion/sumcheck/arithengine.go b/std/recursion/sumcheck/arithengine.go index e4de69ba0a..2df009e9f9 100644 --- a/std/recursion/sumcheck/arithengine.go +++ b/std/recursion/sumcheck/arithengine.go @@ -95,3 +95,12 @@ func newEmulatedEngine[FR emulated.FieldParams](api frontend.API) (*emuEngine[FR } return &emuEngine[FR]{f: f}, nil } + +// noopEngine is a no-operation arithmetic engine. Can be used to access methods of the gates without performing any computation. +type noopEngine struct{} + +func (ne *noopEngine) Add(a, b element) element { panic("noop engine: Add called") } +func (ne *noopEngine) Mul(a, b element) element { panic("noop engine: Mul called") } +func (ne *noopEngine) Sub(a, b element) element { panic("noop engine: Sub called") } +func (ne *noopEngine) One() element { panic("noop engine: One called") } +func (ne *noopEngine) Const(i *big.Int) element { panic("noop engine: Const called") } From 06d03635eacb06f78ff97885f56f22fc6994f601 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Thu, 4 Jul 2024 15:23:45 +0000 Subject: [PATCH 9/9] import scs --- std/recursion/sumcheck/scalarmul_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index d326436921..b9677287c9 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -12,6 +12,7 @@ import ( fr_secp256k1 "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" cryptofs "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/algebra" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/math/bits"