Skip to content

Commit

Permalink
feat: optimise lexicographic gadget bitwidth (#722)
Browse files Browse the repository at this point in the history
This optimises delta bitwidth for sorting constraints to reduce the
number of overall columns.
  • Loading branch information
DavePearce authored Feb 24, 2025
1 parent a899fbd commit 9b5cbbb
Show file tree
Hide file tree
Showing 24 changed files with 4,940 additions and 952 deletions.
47 changes: 39 additions & 8 deletions pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ package gadgets

import (
"fmt"
"math/big"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/air"
Expand Down Expand Up @@ -43,30 +44,33 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
// number of bits. This is implemented using a *byte decomposition* which adds
// n columns and a vanishing constraint (where n*8 >= nbits).
func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
if nbits%8 != 0 {
panic("asymmetric bitwidth constraints not yet supported")
} else if nbits == 0 {
var (
// Determine ranges required for the given bitwidth
ranges = splitColumnRanges(nbits)
// Identify number of columns required.
n = uint(len(ranges))
)
// Sanity check
if nbits == 0 {
panic("zero bitwidth constraint encountered")
}
// Identify target column
column := schema.Columns().Nth(col)
// Calculate how many bytes required.
n := nbits / 8
es := make([]air.Expr, n)
fr256 := fr.NewElement(256)
name := column.Name
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(
assignment.NewByteDecomposition(name, column.Context, col, n))
assignment.NewByteDecomposition(name, column.Context, col, nbits))
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
es[i] = air.NewColumnAccess(index+i, 0).Mul(air.NewConst(coefficient))

schema.AddRangeConstraint(index+i, 0, fr256)
schema.AddRangeConstraint(index+i, 0, ranges[i])
// Update coefficient
coefficient.Mul(&coefficient, &fr256)
coefficient.Mul(&coefficient, &ranges[i])
}
// Construct (X:0 * 1) + ... + (X:n * 2^n)
sum := air.Sum(es...)
Expand All @@ -76,3 +80,30 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), 0, column.Context, util.None[int](), eq)
}

// splitColumnRanges determines the range (i.e. exclusive upper bound) required
// for each column in the byte decomposition of a column holding nbits bits.
// Every column covers a full byte (range 256) except that, when nbits is not a
// multiple of 8, one extra most-significant column is added whose range is
// 2^(nbits%8).
func splitColumnRanges(nbits uint) []fr.Element {
	var (
		fullBytes = nbits / 8
		spare     = int64(nbits % 8)
		fr256     = fr.NewElement(256)
	)
	// Exact multiple of 8: every column spans a full byte.
	if spare == 0 {
		ranges := make([]fr.Element, fullBytes)
		for i := range ranges {
			ranges[i] = fr256
		}
		return ranges
	}
	// Otherwise, append one extra column with the smaller range 2^spare for
	// the most significant bits.
	ranges := make([]fr.Element, fullBytes+1)
	for i := uint(0); i < fullBytes; i++ {
		ranges[i] = fr256
	}
	ranges[fullBytes].Exp(fr.NewElement(2), big.NewInt(spare))
	//
	return ranges
}
6 changes: 6 additions & 0 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ var checkCmd = &cobra.Command{
cfg.mir = GetFlag(cmd, "mir")
cfg.hir = GetFlag(cmd, "hir")
cfg.defensive = GetFlag(cmd, "defensive")
cfg.validate = GetFlag(cmd, "validate")
cfg.expand = !GetFlag(cmd, "raw")
cfg.report = GetFlag(cmd, "report")
cfg.reportPadding = GetUint(cmd, "report-context")
Expand Down Expand Up @@ -148,6 +149,9 @@ type checkConfig struct {
// not required when a "raw" trace is given which already includes all
// implied columns.
expand bool
// Specifies whether or not to perform trace validation. That is, to check
// all input values are within expected bounds.
validate bool
// Specifies whether or not to include the standard library. The default is
// to include it.
stdlib bool
Expand Down Expand Up @@ -207,6 +211,7 @@ func checkTrace[K sc.Metric[K]](ir string, traces [][]tr.RawColumn, schema sc.Sc
//
coverage := sc.NewBranchCoverage()
builder := sc.NewTraceBuilder(schema).
Validate(cfg.validate).
Defensive(cfg.defensive).
Expand(cfg.expand).
Parallel(cfg.parallel).
Expand Down Expand Up @@ -341,6 +346,7 @@ func init() {
checkCmd.Flags().BoolP("quiet", "q", false, "suppress output (e.g. warnings)")
checkCmd.Flags().Bool("sequential", false, "perform sequential trace expansion")
checkCmd.Flags().Bool("defensive", true, "automatically apply defensive padding to every module")
checkCmd.Flags().Bool("validate", true, "apply trace validation")
checkCmd.Flags().String("coverage", "", "write JSON coverage data to file")
checkCmd.Flags().Uint("padding", 0, "specify amount of (front) padding to apply")
checkCmd.Flags().UintP("batch", "b", math.MaxUint, "specify batch size for constraint checking")
Expand Down
12 changes: 7 additions & 5 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ var traceCmd = &cobra.Command{
columns := GetFlag(cmd, "columns")
modules := GetFlag(cmd, "modules")
defensive := GetFlag(cmd, "defensive")
validate := GetFlag(cmd, "validate")
stats := GetFlag(cmd, "stats")
stdlib := !GetFlag(cmd, "no-stdlib")
includes := GetStringArray(cmd, "include")
Expand Down Expand Up @@ -85,7 +86,7 @@ var traceCmd = &cobra.Command{
} else if expand {
level := determineAbstractionLevel(air, mir, hir)
for i, cols := range traces {
traces[i] = expandWithConstraints(level, cols, stdlib, defensive, args[1:], optConfig)
traces[i] = expandWithConstraints(level, cols, stdlib, validate, defensive, args[1:], optConfig)
}
} else if defensive {
fmt.Println("cannot apply defensive padding without trace expansion")
Expand Down Expand Up @@ -131,6 +132,7 @@ func init() {
traceCmd.Flags().BoolP("print", "p", false, "print entire trace file")
traceCmd.Flags().BoolP("expand", "e", false, "perform trace expansion (schema required)")
traceCmd.Flags().Bool("defensive", false, "perform defensive padding (schema required)")
traceCmd.Flags().Bool("validate", true, "apply trace validation")
traceCmd.Flags().Uint("start", 0, "filter out rows below this")
traceCmd.Flags().Uint("end", math.MaxUint, "filter out this and all following rows")
traceCmd.Flags().Uint("max-width", 32, "specify maximum display width for a column")
Expand Down Expand Up @@ -165,7 +167,7 @@ func determineAbstractionLevel(air, mir, hir bool) int {
panic("unreachable")
}

func expandWithConstraints(level int, cols []trace.RawColumn, stdlib bool, defensive bool,
func expandWithConstraints(level int, cols []trace.RawColumn, stdlib bool, validate bool, defensive bool,
filenames []string, optConfig mir.OptimisationConfig) []trace.RawColumn {
//
var schema sc.Schema
Expand All @@ -183,11 +185,11 @@ func expandWithConstraints(level int, cols []trace.RawColumn, stdlib bool, defen
panic("unreachable")
}
// Done
return expandColumns(cols, schema, defensive)
return expandColumns(cols, schema, validate, defensive)
}

func expandColumns(cols []trace.RawColumn, schema sc.Schema, defensive bool) []trace.RawColumn {
builder := sc.NewTraceBuilder(schema).Expand(true).Defensive(defensive)
func expandColumns(cols []trace.RawColumn, schema sc.Schema, validate bool, defensive bool) []trace.RawColumn {
builder := sc.NewTraceBuilder(schema).Expand(true).Validate(validate).Defensive(defensive)
tr, errs := builder.Build(cols)
//
if len(errs) > 0 {
Expand Down
19 changes: 18 additions & 1 deletion pkg/corset/compiler/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,9 @@ func (t *translator) translateDefSorted(decl *ast.DefSorted, module util.Path) [
context := t.env.ContextOf(ast.ContextOfExpressions(decl.Sources))
// Clone the signs
signs := slices.Clone(decl.Signs)
bitwidth := determineMaxBitwidth(t.schema, sources[:len(signs)])
// Add translated constraint
t.schema.AddSortedConstraint(decl.Handle, context, sources, signs)
t.schema.AddSortedConstraint(decl.Handle, context, bitwidth, sources, signs)
}
// Done
return errors
Expand Down Expand Up @@ -685,3 +686,19 @@ func (t *translator) registerOfArrayAccess(expr *ast.ArrayAccess) (uint, []Synta
// Lookup underlying column info
return t.env.RegisterOf(path), errors
}

// determineMaxBitwidth returns the largest bitwidth required by any of the
// given source expressions, as reported by their BitWidth method.  An empty
// source list yields zero.
func determineMaxBitwidth(schema sc.Schema, sources []hir.UnitExpr) uint {
	var bitwidth uint
	//
	for _, source := range sources {
		if w := source.BitWidth(schema); w > bitwidth {
			bitwidth = w
		}
	}
	//
	return bitwidth
}
12 changes: 12 additions & 0 deletions pkg/hir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ func (e Expr) Branches() uint {
// direction (right).
func (e Expr) Bounds() util.Bounds { return e.Term.Bounds() }

// BitWidth determines bitwidth required to hold the result of evaluating this expression.
func (e Expr) BitWidth(schema sc.Schema) []uint {
	// Currently, only direct column accesses are supported; determining the
	// bitwidth of a compound expression would require proper range analysis.
	ca, ok := e.Term.(*ColumnAccess)
	if !ok {
		panic("bitwidth calculation only supported for column accesses")
	}
	// Width is that of the underlying column's declared datatype.
	return []uint{schema.Columns().Nth(ca.Column).DataType.BitWidth()}
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e Expr) Lisp(schema sc.Schema) sexp.SExp {
Expand Down
2 changes: 1 addition & 1 deletion pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func lowerSortedConstraint(c SortedConstraint, schema *mir.Schema) {
sources[i] = lowerUnitTo(c.Sources[i], schema)
}
//
schema.AddSortedConstraint(c.Handle, c.Context, sources, c.Signs)
schema.AddSortedConstraint(c.Handle, c.Context, c.BitWidth, sources, c.Signs)
}

// Lower an expression which is expected to lower into a single expression.
Expand Down
5 changes: 3 additions & 2 deletions pkg/hir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,11 @@ func (p *Schema) AddRangeConstraint(handle string, context trace.Context, expr E
}

// AddSortedConstraint appends a new sorted constraint.
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, sources []UnitExpr, signs []bool) {
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, bitwidth uint, sources []UnitExpr,
signs []bool) {
// Finally add constraint
p.constraints = append(p.constraints,
constraint.NewSortedConstraint(handle, context, sources, signs))
constraint.NewSortedConstraint(handle, context, bitwidth, sources, signs))
}

// AddPropertyAssertion appends a new property assertion.
Expand Down
5 changes: 5 additions & 0 deletions pkg/hir/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ func (e UnitExpr) Bounds() util.Bounds {
return e.Expr.Bounds()
}

// BitWidth determines bitwidth required to hold the result of evaluating this expression.
func (e UnitExpr) BitWidth(schema sc.Schema) uint {
	// A unitary expression evaluates to exactly one value, hence only the
	// first (and only) reported width is relevant.
	widths := e.Expr.BitWidth(schema)
	//
	return widths[0]
}

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (e UnitExpr) Context(schema sc.Schema) tr.Context {
Expand Down
46 changes: 26 additions & 20 deletions pkg/mir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ func lowerLookupConstraintToAir(c LookupConstraint, mirSchema *Schema, airSchema
// is no concept of sorting constraints at the AIR level. Instead, we have to
// generate the necessary machinery to enforce the sorting constraint.
func lowerSortedConstraintToAir(c SortedConstraint, mirSchema *Schema, airSchema *air.Schema, cfg OptimisationConfig) {
ncols := len(c.Sources)
sources := make([]uint, ncols)
sources := make([]uint, len(c.Sources))
//
for i := 0; i < len(sources); i++ {
sourceBitwidth := c.Sources[i].IntRange(mirSchema).BitWidth()
Expand All @@ -214,30 +213,36 @@ func lowerSortedConstraintToAir(c SortedConstraint, mirSchema *Schema, airSchema
// Expand them
sources[i] = air_gadgets.Expand(c.Context, sourceBitwidth, source, airSchema)
}
// Determine number of ordered columns
numSignedCols := len(c.Signs)
// finally add the constraint
if ncols == 1 {
if numSignedCols == 1 {
// For a single column sort, its actually a bit easier because we don't
// need to implement a multiplexor (i.e. to determine which column is
// differs, etc). Instead, we just need a delta column which ensures
// there is a non-negative difference between consecutive rows. This
// also requires bitwidth constraints.
bitwidth := mirSchema.Columns().Nth(sources[0]).DataType.AsUint().BitWidth()
// Add column sorting constraints
air_gadgets.ApplyColumnSortGadget(c.Handle, sources[0], c.Signs[0], bitwidth, airSchema)
air_gadgets.ApplyColumnSortGadget(c.Handle, sources[0], c.Signs[0], c.BitWidth, airSchema)
} else {
// For a multi column sort, it's a bit harder as we need additional
// logic to ensure the target columns are lexicographically sorted.
bitwidth := uint(0)
air_gadgets.ApplyLexicographicSortingGadget(c.Handle, sources, c.Signs, c.BitWidth, airSchema)
}
// Sanity check bitwidth
bitwidth := uint(0)

for i := 0; i < ncols; i++ {
// Extract bitwidth of ith column
ith := mirSchema.Columns().Nth(sources[i]).DataType.AsUint().BitWidth()
if ith > bitwidth {
bitwidth = ith
}
for i := 0; i < numSignedCols; i++ {
// Extract bitwidth of ith column
ith := mirSchema.Columns().Nth(sources[i]).DataType.AsUint().BitWidth()
if ith > bitwidth {
bitwidth = ith
}
// Add lexicographically sorted constraints
air_gadgets.ApplyLexicographicSortingGadget(c.Handle, sources, c.Signs, bitwidth, airSchema)
}
//
if bitwidth != c.BitWidth {
// Should be unreachable.
msg := fmt.Sprintf("incompatible bitwidths (%d vs %d)", bitwidth, c.BitWidth)
panic(msg)
}
}

Expand All @@ -250,12 +255,11 @@ func lowerSortedConstraintToAir(c SortedConstraint, mirSchema *Schema, airSchema
func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Schema) {
builder := strings.Builder{}
c_targets := c.Targets
ncols := len(c_targets)
targets := make([]uint, ncols)
targets := make([]uint, len(c_targets))
//
builder.WriteString("permutation")
// Add individual permutation constraints
for i := 0; i < ncols; i++ {
for i := 0; i < len(c_targets); i++ {
var ok bool
// TODO: how best to avoid this lookup?
targets[i], ok = sc.ColumnIndexOf(airSchema, c.Module(), c_targets[i].Name)
Expand All @@ -268,8 +272,10 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche
}
//
airSchema.AddPermutationConstraint(builder.String(), c.Context(), targets, c.Sources)
// Determine number of ordered columns
numSignedCols := len(c.Signs)
// Add sorting constraints + computed columns as necessary.
if ncols == 1 {
if numSignedCols == 1 {
// For a single column sort, its actually a bit easier because we don't
// need to implement a multiplexor (i.e. to determine which column is
// differs, etc). Instead, we just need a delta column which ensures
Expand All @@ -285,7 +291,7 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche
// logic to ensure the target columns are lexicographically sorted.
bitwidth := uint(0)

for i := 0; i < ncols; i++ {
for i := 0; i < numSignedCols; i++ {
// Extract bitwidth of ith column
ith := mirSchema.Columns().Nth(c.Sources[i]).DataType.AsUint().BitWidth()
if ith > bitwidth {
Expand Down
5 changes: 3 additions & 2 deletions pkg/mir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,11 @@ func (p *Schema) AddRangeConstraint(handle string, casenum uint, context trace.C
}

// AddSortedConstraint appends a new sorted constraint.
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, sources []Expr, signs []bool) {
func (p *Schema) AddSortedConstraint(handle string, context trace.Context, bitwidth uint, sources []Expr,
signs []bool) {
// Finally add constraint
p.constraints = append(p.constraints,
constraint.NewSortedConstraint(handle, context, sources, signs))
constraint.NewSortedConstraint(handle, context, bitwidth, sources, signs))
}

// AddPropertyAssertion appends a new property assertion.
Expand Down
Loading

0 comments on commit 9b5cbbb

Please sign in to comment.