From ea7794143914794b73baa8c5d5b6618591d5656e Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Mon, 18 Mar 2024 15:01:28 +0000 Subject: [PATCH] XXX sumcheck scalarmul working --- std/recursion/sumcheck/scalarmul_test.go | 45 +++++++++++++----------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/std/recursion/sumcheck/scalarmul_test.go b/std/recursion/sumcheck/scalarmul_test.go index 08837cdfbc..d326436921 100644 --- a/std/recursion/sumcheck/scalarmul_test.go +++ b/std/recursion/sumcheck/scalarmul_test.go @@ -118,7 +118,6 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { if err != nil { return fmt.Errorf("eval multilinear: %w", err) } - fmt.Printf("claim: %s\n", baseApi.String(evaluation)) inputs := make([][]*emulated.Element[B], 7) for i := range inputs { @@ -126,13 +125,20 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { } for i := 0; i < nbInputs; i++ { scalarBts := scalarApi.ToBits(&c.Scalars[i]) - for j := 0; j < c.nbScalarBits; j++ { - inputs[0][i*c.nbScalarBits+j] = &acc[i][j].X - inputs[1][i*c.nbScalarBits+j] = &acc[i][j].Y - inputs[2][i*c.nbScalarBits+j] = &acc[i][j].Z - inputs[3][i*c.nbScalarBits+j] = &result[i][j].X - inputs[4][i*c.nbScalarBits+j] = &result[i][j].Y - inputs[5][i*c.nbScalarBits+j] = &result[i][j].Z + inputs[0][i*c.nbScalarBits] = &c.Points[i].X + inputs[1][i*c.nbScalarBits] = &c.Points[i].Y + inputs[2][i*c.nbScalarBits] = baseApi.One() + inputs[3][i*c.nbScalarBits] = baseApi.Zero() + inputs[4][i*c.nbScalarBits] = baseApi.One() + inputs[5][i*c.nbScalarBits] = baseApi.Zero() + inputs[6][i*c.nbScalarBits] = baseApi.NewElement(scalarBts[0]) + for j := 1; j < c.nbScalarBits; j++ { + inputs[0][i*c.nbScalarBits+j] = &acc[i][j-1].X + inputs[1][i*c.nbScalarBits+j] = &acc[i][j-1].Y + inputs[2][i*c.nbScalarBits+j] = &acc[i][j-1].Z + inputs[3][i*c.nbScalarBits+j] = &result[i][j-1].X + inputs[4][i*c.nbScalarBits+j] = &result[i][j-1].Y + inputs[5][i*c.nbScalarBits+j] = &result[i][j-1].Z inputs[6][i*c.nbScalarBits+j] = baseApi.NewElement(scalarBts[j]) } } @@ -145,7 +151,6 @@ func (c *ScalarMulCircuit[B, S]) Define(api frontend.API) error { if err = v.Verify(claim, proof); err != nil { return fmt.Errorf("verify sumcheck: %w", err) } - _ = evaluation return nil } @@ -157,7 +162,7 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, var fp B var fr S nbInputs := len(points) - inputs := []frontend.Variable{nbInputs, fp.BitsPerLimb(), fp.NbLimbs(), fr.BitsPerLimb(), fr.NbLimbs()} + inputs := []frontend.Variable{nbInputs, nbScalarBits, fp.BitsPerLimb(), fp.NbLimbs(), fr.BitsPerLimb(), fr.NbLimbs()} inputs = append(inputs, baseApi.Modulus().Limbs...) inputs = append(inputs, scalarApi.Modulus().Limbs...) for i := range points { @@ -215,12 +220,13 @@ func callHintScalarMulSteps[B, S emulated.FieldParams](api frontend.API, func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { nbInputs := int(inputs[0].Int64()) - nbBits := int(inputs[1].Int64()) - nbLimbs := int(inputs[2].Int64()) - nbScalarBits := int(inputs[3].Int64()) - nbScalarLimbs := int(inputs[4].Int64()) - fpLimbs := inputs[5 : 5+nbLimbs] - frLimbs := inputs[5+nbLimbs : 5+nbLimbs+nbScalarLimbs] + scalarLength := int(inputs[1].Int64()) + nbBits := int(inputs[2].Int64()) + nbLimbs := int(inputs[3].Int64()) + nbScalarBits := int(inputs[4].Int64()) + nbScalarLimbs := int(inputs[5].Int64()) + fpLimbs := inputs[6 : 6+nbLimbs] + frLimbs := inputs[6+nbLimbs : 6+nbLimbs+nbScalarLimbs] fp := new(big.Int) fr := new(big.Int) if err := recompose(fpLimbs, uint(nbBits), fp); err != nil { @@ -229,7 +235,7 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err if err := recompose(frLimbs, uint(nbScalarBits), fr); err != nil { return fmt.Errorf("recompose fr: %w", err) } - ptr := 5 + nbLimbs + nbScalarLimbs + ptr := 6 + nbLimbs + nbScalarLimbs xs := make([]*big.Int, nbInputs) ys := make([]*big.Int, nbInputs) scalars := make([]*big.Int, nbInputs) @@ -261,7 +267,6 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err // in the final round of the sumcheck and outputs by requiring the verifier // to construct the claim itself), we can ensure that the final step is the // actual scalar multiplication result. - scalarLength := len(outputs) / (6 * nbLimbs * nbInputs) api := newBigIntEngine(fp) selector := new(big.Int) outPtr := 0 @@ -382,7 +387,8 @@ func hintScalarMulSteps(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err outPtr += nbLimbs } } - // verifier computes the evaluation itself for consistency + // verifier computes the evaluation itself for consistency. We do not pass + // it through the hint. Explicitly ignore. _ = evals return nil } @@ -438,7 +444,6 @@ func TestScalarMul(t *testing.T) { s.SetRandom() P.ScalarMultiplicationBase(s.BigInt(new(big.Int))) sc, _ := rand.Int(rand.Reader, scalarBound) - // t.Log(P.X.String(), P.Y.String(), sc.String()) points[i] = sw_emulated.AffinePoint[B]{ X: emulated.ValueOf[B](P.X), Y: emulated.ValueOf[B](P.Y),