Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimise lexicographic gadget bitwidth #722

Merged
merged 4 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ package gadgets

import (
"fmt"
"math/big"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/air"
Expand Down Expand Up @@ -43,30 +44,33 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
// number of bits. This is implemented using a *byte decomposition* which adds
// n columns and a vanishing constraint (where n*8 >= nbits).
func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
if nbits%8 != 0 {
panic("asymmetric bitwidth constraints not yet supported")
} else if nbits == 0 {
var (
// Determine ranges required for the give bitwidth
ranges = splitColumnRanges(nbits)
// Identify number of columns required.
n = uint(len(ranges))
)
// Sanity check
if nbits == 0 {
panic("zero bitwidth constraint encountered")
}
// Identify target column
column := schema.Columns().Nth(col)
// Calculate how many bytes required.
n := nbits / 8
es := make([]air.Expr, n)
fr256 := fr.NewElement(256)
name := column.Name
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(
assignment.NewByteDecomposition(name, column.Context, col, n))
assignment.NewByteDecomposition(name, column.Context, col, nbits))
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
es[i] = air.NewColumnAccess(index+i, 0).Mul(air.NewConst(coefficient))

schema.AddRangeConstraint(index+i, 0, fr256)
schema.AddRangeConstraint(index+i, 0, ranges[i])
// Update coefficient
coefficient.Mul(&coefficient, &fr256)
coefficient.Mul(&coefficient, &ranges[i])
}
// Construct (X:0 * 1) + ... + (X:n * 2^n)
sum := air.Sum(es...)
Expand All @@ -76,3 +80,30 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), 0, column.Context, util.None[int](), eq)
}

func splitColumnRanges(nbits uint) []fr.Element {
var (
n = nbits / 8
m = int64(nbits % 8)
ranges []fr.Element
fr256 = fr.NewElement(256)
)
//
if m == 0 {
ranges = make([]fr.Element, n)
} else {
var last fr.Element
// Most significant column has smaller range.
ranges = make([]fr.Element, n+1)
// Determine final range
last.Exp(fr.NewElement(2), big.NewInt(m))
//
ranges[n] = last
}
//
for i := uint(0); i < n; i++ {
ranges[i] = fr256
}
//
return ranges
}
6 changes: 6 additions & 0 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ var checkCmd = &cobra.Command{
cfg.mir = GetFlag(cmd, "mir")
cfg.hir = GetFlag(cmd, "hir")
cfg.defensive = GetFlag(cmd, "defensive")
cfg.validate = GetFlag(cmd, "validate")
cfg.expand = !GetFlag(cmd, "raw")
cfg.report = GetFlag(cmd, "report")
cfg.reportPadding = GetUint(cmd, "report-context")
Expand Down Expand Up @@ -148,6 +149,9 @@ type checkConfig struct {
// not required when a "raw" trace is given which already includes all
// implied columns.
expand bool
// Specifies whether or not to perform trace validation. That is, to check
// all input values are within expected bounds.
validate bool
// Specifies whether or not to include the standard library. The default is
// to include it.
stdlib bool
Expand Down Expand Up @@ -207,6 +211,7 @@ func checkTrace[K sc.Metric[K]](ir string, traces [][]tr.RawColumn, schema sc.Sc
//
coverage := sc.NewBranchCoverage()
builder := sc.NewTraceBuilder(schema).
Validate(cfg.validate).
Defensive(cfg.defensive).
Expand(cfg.expand).
Parallel(cfg.parallel).
Expand Down Expand Up @@ -341,6 +346,7 @@ func init() {
checkCmd.Flags().BoolP("quiet", "q", false, "suppress output (e.g. warnings)")
checkCmd.Flags().Bool("sequential", false, "perform sequential trace expansion")
checkCmd.Flags().Bool("defensive", true, "automatically apply defensive padding to every module")
checkCmd.Flags().Bool("validate", true, "apply trace validation")
checkCmd.Flags().String("coverage", "", "write JSON coverage data to file")
checkCmd.Flags().Uint("padding", 0, "specify amount of (front) padding to apply")
checkCmd.Flags().UintP("batch", "b", math.MaxUint, "specify batch size for constraint checking")
Expand Down
12 changes: 7 additions & 5 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ var traceCmd = &cobra.Command{
columns := GetFlag(cmd, "columns")
modules := GetFlag(cmd, "modules")
defensive := GetFlag(cmd, "defensive")
validate := GetFlag(cmd, "validate")
stats := GetFlag(cmd, "stats")
stdlib := !GetFlag(cmd, "no-stdlib")
includes := GetStringArray(cmd, "include")
Expand Down Expand Up @@ -85,7 +86,7 @@ var traceCmd = &cobra.Command{
} else if expand {
level := determineAbstractionLevel(air, mir, hir)
for i, cols := range traces {
traces[i] = expandWithConstraints(level, cols, stdlib, defensive, args[1:], optConfig)
traces[i] = expandWithConstraints(level, cols, stdlib, validate, defensive, args[1:], optConfig)
}
} else if defensive {
fmt.Println("cannot apply defensive padding without trace expansion")
Expand Down Expand Up @@ -131,6 +132,7 @@ func init() {
traceCmd.Flags().BoolP("print", "p", false, "print entire trace file")
traceCmd.Flags().BoolP("expand", "e", false, "perform trace expansion (schema required)")
traceCmd.Flags().Bool("defensive", false, "perform defensive padding (schema required)")
traceCmd.Flags().Bool("validate", true, "apply trace validation")
traceCmd.Flags().Uint("start", 0, "filter out rows below this")
traceCmd.Flags().Uint("end", math.MaxUint, "filter out this and all following rows")
traceCmd.Flags().Uint("max-width", 32, "specify maximum display width for a column")
Expand Down Expand Up @@ -165,7 +167,7 @@ func determineAbstractionLevel(air, mir, hir bool) int {
panic("unreachable")
}

func expandWithConstraints(level int, cols []trace.RawColumn, stdlib bool, defensive bool,
func expandWithConstraints(level int, cols []trace.RawColumn, stdlib bool, validate bool, defensive bool,
filenames []string, optConfig mir.OptimisationConfig) []trace.RawColumn {
//
var schema sc.Schema
Expand All @@ -183,11 +185,11 @@ func expandWithConstraints(level int, cols []trace.RawColumn, stdlib bool, defen
panic("unreachable")
}
// Done
return expandColumns(cols, schema, defensive)
return expandColumns(cols, schema, validate, defensive)
}

func expandColumns(cols []trace.RawColumn, schema sc.Schema, defensive bool) []trace.RawColumn {
builder := sc.NewTraceBuilder(schema).Expand(true).Defensive(defensive)
func expandColumns(cols []trace.RawColumn, schema sc.Schema, validate bool, defensive bool) []trace.RawColumn {
builder := sc.NewTraceBuilder(schema).Expand(true).Validate(validate).Defensive(defensive)
tr, errs := builder.Build(cols)
//
if len(errs) > 0 {
Expand Down
19 changes: 18 additions & 1 deletion pkg/corset/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,9 @@ func (t *translator) translateDefSorted(decl *ast.DefSorted, module util.Path) [
context := t.env.ContextOf(ast.ContextOfExpressions(decl.Sources))
// Clone the signs
signs := slices.Clone(decl.Signs)
bitwidth := determineMaxBitwidth(t.schema, sources[:len(signs)])
// Add translated constraint
t.schema.AddSortedConstraint(decl.Handle, context, sources, signs)
t.schema.AddSortedConstraint(decl.Handle, context, bitwidth, sources, signs)
}
// Done
return errors
Expand Down Expand Up @@ -685,3 +686,19 @@ func (t *translator) registerOfArrayAccess(expr *ast.ArrayAccess) (uint, []Synta
// Lookup underlying column info
return t.env.RegisterOf(path), errors
}

func determineMaxBitwidth(schema sc.Schema, sources []hir.UnitExpr) uint {
// Sanity check bitwidth
bitwidth := uint(0)

for _, e := range sources {
// Determine bitwidth of nth term
ith := e.BitWidth(schema)
//
if ith > bitwidth {
bitwidth = ith
}
}
//
return bitwidth
}
12 changes: 12 additions & 0 deletions pkg/hir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ func (e Expr) Branches() uint {
// direction (right).
func (e Expr) Bounds() util.Bounds { return e.Term.Bounds() }

// BitWidth determines bitwidth required to hold the result of evaluating this expression.
func (e Expr) BitWidth(schema sc.Schema) []uint {
switch e := e.Term.(type) {
case *ColumnAccess:
bitwidth := schema.Columns().Nth(e.Column).DataType.BitWidth()
return []uint{bitwidth}
default:
// For now, we only supports simple column accesses.
panic("bitwidth calculation only supported for column accesses")
}
}

// Lisp converts this schema element into a simple S-Termession, for example
// so it can be printed.
func (e Expr) Lisp(schema sc.Schema) sexp.SExp {
Expand Down
2 changes: 1 addition & 1 deletion pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func lowerSortedConstraint(c SortedConstraint, schema *mir.Schema) {
sources[i] = lowerUnitTo(c.Sources[i], schema)
}
//
schema.AddSortedConstraint(c.Handle, c.Context, sources, c.Signs)
schema.AddSortedConstraint(c.Handle, c.Context, c.BitWidth, sources, c.Signs)
}

// Lower an expression which is expected to lower into a single expression.
Expand Down
5 changes: 3 additions & 2 deletions pkg/hir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,11 @@ func (p *Schema) AddRangeConstraint(handle string, context trace.Context, expr E
}

// AddSortedConstraint appends a new sorted constraint.
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, sources []UnitExpr, signs []bool) {
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, bitwidth uint, sources []UnitExpr,
signs []bool) {
// Finally add constraint
p.constraints = append(p.constraints,
constraint.NewSortedConstraint(handle, context, sources, signs))
constraint.NewSortedConstraint(handle, context, bitwidth, sources, signs))
}

// AddPropertyAssertion appends a new property assertion.
Expand Down
5 changes: 5 additions & 0 deletions pkg/hir/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func (e UnitExpr) Bounds() util.Bounds {
return e.Expr.Bounds()
}

// BitWidth determines bitwidth required to hold the result of evaluating this expression.
func (e UnitExpr) BitWidth(schema sc.Schema) uint {
return e.Expr.BitWidth(schema)[0]
}

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (e UnitExpr) Context(schema sc.Schema) tr.Context {
Expand Down
46 changes: 26 additions & 20 deletions pkg/mir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ func lowerLookupConstraintToAir(c LookupConstraint, mirSchema *Schema, airSchema
// is not concept of sorting constraints at the AIR level. Instead, we have to
// generate the necessary machinery to enforce the sorting constraint.
func lowerSortedConstraintToAir(c SortedConstraint, mirSchema *Schema, airSchema *air.Schema, cfg OptimisationConfig) {
ncols := len(c.Sources)
sources := make([]uint, ncols)
sources := make([]uint, len(c.Sources))
//
for i := 0; i < len(sources); i++ {
sourceBitwidth := c.Sources[i].IntRange(mirSchema).BitWidth()
Expand All @@ -214,30 +213,36 @@ func lowerSortedConstraintToAir(c SortedConstraint, mirSchema *Schema, airSchema
// Expand them
sources[i] = air_gadgets.Expand(c.Context, sourceBitwidth, source, airSchema)
}
// Determine number of ordered columns
numSignedCols := len(c.Signs)
// finally add the constraint
if ncols == 1 {
if numSignedCols == 1 {
// For a single column sort, its actually a bit easier because we don't
// need to implement a multiplexor (i.e. to determine which column is
// differs, etc). Instead, we just need a delta column which ensures
// there is a non-negative difference between consecutive rows. This
// also requires bitwidth constraints.
bitwidth := mirSchema.Columns().Nth(sources[0]).DataType.AsUint().BitWidth()
// Add column sorting constraints
air_gadgets.ApplyColumnSortGadget(c.Handle, sources[0], c.Signs[0], bitwidth, airSchema)
air_gadgets.ApplyColumnSortGadget(c.Handle, sources[0], c.Signs[0], c.BitWidth, airSchema)
} else {
// For a multi column sort, its a bit harder as we need additional
// logicl to ensure the target columns are lexicographally sorted.
bitwidth := uint(0)
air_gadgets.ApplyLexicographicSortingGadget(c.Handle, sources, c.Signs, c.BitWidth, airSchema)
}
// Sanity check bitwidth
bitwidth := uint(0)

for i := 0; i < ncols; i++ {
// Extract bitwidth of ith column
ith := mirSchema.Columns().Nth(sources[i]).DataType.AsUint().BitWidth()
if ith > bitwidth {
bitwidth = ith
}
for i := 0; i < numSignedCols; i++ {
// Extract bitwidth of ith column
ith := mirSchema.Columns().Nth(sources[i]).DataType.AsUint().BitWidth()
if ith > bitwidth {
bitwidth = ith
}
// Add lexicographically sorted constraints
air_gadgets.ApplyLexicographicSortingGadget(c.Handle, sources, c.Signs, bitwidth, airSchema)
}
//
if bitwidth != c.BitWidth {
// Should be unreachable.
msg := fmt.Sprintf("incompatible bitwidths (%d vs %d)", bitwidth, c.BitWidth)
panic(msg)
}
}

Expand All @@ -250,12 +255,11 @@ func lowerSortedConstraintToAir(c SortedConstraint, mirSchema *Schema, airSchema
func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Schema) {
builder := strings.Builder{}
c_targets := c.Targets
ncols := len(c_targets)
targets := make([]uint, ncols)
targets := make([]uint, len(c_targets))
//
builder.WriteString("permutation")
// Add individual permutation constraints
for i := 0; i < ncols; i++ {
for i := 0; i < len(c_targets); i++ {
var ok bool
// TODO: how best to avoid this lookup?
targets[i], ok = sc.ColumnIndexOf(airSchema, c.Module(), c_targets[i].Name)
Expand All @@ -268,8 +272,10 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche
}
//
airSchema.AddPermutationConstraint(builder.String(), c.Context(), targets, c.Sources)
// Determine number of ordered columns
numSignedCols := len(c.Signs)
// Add sorting constraints + computed columns as necessary.
if ncols == 1 {
if numSignedCols == 1 {
// For a single column sort, its actually a bit easier because we don't
// need to implement a multiplexor (i.e. to determine which column is
// differs, etc). Instead, we just need a delta column which ensures
Expand All @@ -285,7 +291,7 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche
// logicl to ensure the target columns are lexicographally sorted.
bitwidth := uint(0)

for i := 0; i < ncols; i++ {
for i := 0; i < numSignedCols; i++ {
// Extract bitwidth of ith column
ith := mirSchema.Columns().Nth(c.Sources[i]).DataType.AsUint().BitWidth()
if ith > bitwidth {
Expand Down
5 changes: 3 additions & 2 deletions pkg/mir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ func (p *Schema) AddRangeConstraint(handle string, casenum uint, context trace.C
}

// AddSortedConstraint appends a new sorted constraint.
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, sources []Expr, signs []bool) {
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, bitwidth uint, sources []Expr,
signs []bool) {
// Finally add constraint
p.constraints = append(p.constraints,
constraint.NewSortedConstraint(handle, context, sources, signs))
constraint.NewSortedConstraint(handle, context, bitwidth, sources, signs))
}

// AddPropertyAssertion appends a new property assertion.
Expand Down
Loading