Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: strict ModReduce in emulated fields #1224

Merged
merged 12 commits into from
Jul 25, 2024
20 changes: 20 additions & 0 deletions std/algebra/algopts/algopts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type algebraCfg struct {
NbScalarBits int
FoldMulti bool
CompleteArithmetic bool
ToBitsCanonical bool
}

// AlgebraOption allows modifying algebraic operation behaviour.
Expand Down Expand Up @@ -57,6 +58,25 @@ func WithCompleteArithmetic() AlgebraOption {
}
}

// WithCanonicalBitRepresentation enforces the marshalling methods to assert
// that the bit representation is in canonical form. For field elements this
// means that the bits represent a number less than the modulus.
//
// This option is useful when performing direct comparison between the bit form
// of two elements. It can be avoided when the bit representation is used in
// other cases, such as computing a challenge using a hash function, where
// non-canonical bit representation leads to incorrect challenge (which in turn
// makes the verification fail).
func WithCanonicalBitRepresentation() AlgebraOption {
return func(ac *algebraCfg) error {
if ac.ToBitsCanonical {
return fmt.Errorf("WithCanonicalBitRepresentation already set")
}
ac.ToBitsCanonical = true
return nil
}
}

// NewConfig applies all given options and returns a configuration to be used.
func NewConfig(opts ...AlgebraOption) (*algebraCfg, error) {
ret := new(algebraCfg)
Expand Down
37 changes: 26 additions & 11 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package sw_emulated
import (
"fmt"
"math/big"
"slices"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/algopts"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
"golang.org/x/exp/slices"
)

// New returns a new [Curve] instance over the base field Base and scalar field
Expand Down Expand Up @@ -101,26 +101,41 @@ type AffinePoint[Base emulated.FieldParams] struct {

// MarshalScalar marshals the scalar into bits. Compatible with scalar
// marshalling in gnark-crypto.
func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S]) []frontend.Variable {
func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S], opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
var fr S
nbBits := 8 * ((fr.Modulus().BitLen() + 7) / 8)
sReduced := c.scalarApi.Reduce(&s)
res := c.scalarApi.ToBits(sReduced)[:nbBits]
for i, j := 0, nbBits-1; i < j; {
res[i], res[j] = res[j], res[i]
i++
j--
var sReduced *emulated.Element[S]
if cfg.ToBitsCanonical {
sReduced = c.scalarApi.ReduceStrict(&s)
} else {
sReduced = c.scalarApi.Reduce(&s)
}
res := c.scalarApi.ToBits(sReduced)[:nbBits]
slices.Reverse(res)
return res
}

// MarshalG1 marshals the affine point into bits. The output is compatible with
// the point marshalling in gnark-crypto.
func (c *Curve[B, S]) MarshalG1(p AffinePoint[B]) []frontend.Variable {
func (c *Curve[B, S]) MarshalG1(p AffinePoint[B], opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
var fp B
nbBits := 8 * ((fp.Modulus().BitLen() + 7) / 8)
x := c.baseApi.Reduce(&p.X)
y := c.baseApi.Reduce(&p.Y)
var x, y *emulated.Element[B]
if cfg.ToBitsCanonical {
x = c.baseApi.ReduceStrict(&p.X)
y = c.baseApi.ReduceStrict(&p.Y)
} else {
x = c.baseApi.Reduce(&p.X)
y = c.baseApi.Reduce(&p.Y)
}
bx := c.baseApi.ToBits(x)[:nbBits]
by := c.baseApi.ToBits(y)[:nbBits]
slices.Reverse(bx)
Expand Down
4 changes: 2 additions & 2 deletions std/algebra/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ type Curve[FR emulated.FieldParams, G1El G1ElementT] interface {

// MarshalG1 returns the binary decomposition G1.X || G1.Y. It matches the
// output of gnark-crypto's Marshal method on G1 points.
MarshalG1(G1El) []frontend.Variable
MarshalG1(G1El, ...algopts.AlgebraOption) []frontend.Variable

// MarshalScalar returns the binary decomposition of the argument.
MarshalScalar(emulated.Element[FR]) []frontend.Variable
MarshalScalar(emulated.Element[FR], ...algopts.AlgebraOption) []frontend.Variable

// Select sets p1 if b=1, p2 if b=0, and returns it. b must be boolean constrained
Select(b frontend.Variable, p1 *G1El, p2 *G1El) *G1El
Expand Down
34 changes: 24 additions & 10 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sw_bls12377
import (
"fmt"
"math/big"
"slices"

"github.com/consensys/gnark-crypto/ecc"
bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377"
Expand Down Expand Up @@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) {
}

// MarshalScalar returns
func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable {
func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8)
ss := c.fr.Reduce(&s)
x := c.fr.ToBits(ss)
for i, j := 0, nbBits-1; i < j; {
x[i], x[j] = x[j], x[i]
i++
j--
var ss *emulated.Element[ScalarField]
if cfg.ToBitsCanonical {
ss = c.fr.ReduceStrict(&s)
} else {
ss = c.fr.Reduce(&s)
}
x := c.fr.ToBits(ss)[:nbBits]
slices.Reverse(x)
return x
}

// MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are
// in little endian.
func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ecc.BLS12_377.BaseField().BitLen() + 7) / 8)
bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)}
if !cfg.ToBitsCanonical {
bOpts = append(bOpts, bits.OmitModulusCheck())
}
res := make([]frontend.Variable, 2*nbBits)
x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits))
y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits))
x := bits.ToBinary(c.api, P.X, bOpts...)
y := bits.ToBinary(c.api, P.Y, bOpts...)
for i := 0; i < nbBits; i++ {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
Expand Down
34 changes: 24 additions & 10 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sw_bls24315
import (
"fmt"
"math/big"
"slices"

"github.com/consensys/gnark-crypto/ecc"
bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315"
Expand Down Expand Up @@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) {
}

// MarshalScalar returns
func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable {
func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8)
ss := c.fr.Reduce(&s)
x := c.fr.ToBits(ss)
for i, j := 0, nbBits-1; i < j; {
x[i], x[j] = x[j], x[i]
i++
j--
var ss *emulated.Element[ScalarField]
if cfg.ToBitsCanonical {
ss = c.fr.ReduceStrict(&s)
} else {
ss = c.fr.Reduce(&s)
}
x := c.fr.ToBits(ss)[:nbBits]
slices.Reverse(x)
return x
}

// MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are
// in little endian.
func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable {
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
nbBits := 8 * ((ecc.BLS24_315.BaseField().BitLen() + 7) / 8)
bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)}
if !cfg.ToBitsCanonical {
bOpts = append(bOpts, bits.OmitModulusCheck())
}
res := make([]frontend.Variable, 2*nbBits)
x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits))
y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits))
x := bits.ToBinary(c.api, P.X, bOpts...)
y := bits.ToBinary(c.api, P.Y, bOpts...)
for i := 0; i < nbBits; i++ {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
Expand Down
12 changes: 12 additions & 0 deletions std/math/emulated/element.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type Element[T FieldParams] struct {
// enforcement info in the Element to prevent modifying the witness.
internal bool

// modReduced indicates that the element has been reduced modulo the modulus
// and we have asserted that the integer value of the element is strictly
// less than the modulus. This is required for some operations which depend
// on the bit-representation of the element (ToBits, exponentiation etc.).
modReduced bool

isEvaluated bool
evaluation frontend.Variable `gnark:"-"`
}
Expand Down Expand Up @@ -95,6 +101,11 @@ func (e *Element[T]) GnarkInitHook() {
*e = ValueOf[T](0)
e.internal = false // we need to constrain in later.
}
// set modReduced to false - in case the circuit is compiled we may change
// the value for an existing element. If we don't reset it here, then during
// second compilation we may take a shortPath where we assume that modReduce
// flag is set.
e.modReduced = false
}

// copy makes a deep copy of the element.
Expand All @@ -104,5 +115,6 @@ func (e *Element[T]) copy() *Element[T] {
copy(r.Limbs, e.Limbs)
r.overflow = e.overflow
r.internal = e.internal
r.modReduced = e.modReduced
return &r
}
Loading
Loading