Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: settable hasher for MiMC #1345

Merged
merged 9 commits into from
Dec 17, 2024
15 changes: 15 additions & 0 deletions std/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ type FieldHasher interface {
Reset()
}

// StateStorer allows to store and retrieve the state of a hash function.
type StateStorer interface {
FieldHasher
// State retrieves the current state of the hash function. Calling this
// method should not destroy the current state and allow continue the use of
// the current hasher. The number of returned Variable is implementation
// dependent.
State() []frontend.Variable
// SetState sets the state of the hash function from a previously stored
// state retrieved using [StateStorer.State] method. The implementation
// returns an error if the number of supplied Variable does not match the
// number of Variable expected.
SetState(state []frontend.Variable) error
}

var (
builderRegistry = make(map[string]func(api frontend.API) (FieldHasher, error))
lock sync.RWMutex
Expand Down
25 changes: 25 additions & 0 deletions std/hash/mimc/mimc.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,31 @@ func (h *MiMC) Reset() {
h.h = 0
}

// SetState manually sets the state of the hasher to the provided value. In the
// case of MiMC only a single frontend variable is expected to represent the
// state.
func (h *MiMC) SetState(newState []frontend.Variable) error {

if len(h.data) > 0 {
return errors.New("the hasher is not in an initial state")
}

if len(newState) != 1 {
return errors.New("the MiMC hasher expects a single field element to represent the state")
}

h.h = newState[0]
h.data = nil
return nil
}

// State returns the inner-state of the hasher. In the context of MiMC only a
// single field element is returned.
func (h *MiMC) State() []frontend.Variable {
h.Sum() // this flushes the unsummed data
return []frontend.Variable{h.h}
}

// Sum hash using [Miyaguchi–Preneel] where the XOR operation is replaced by
// field addition.
//
Expand Down
128 changes: 128 additions & 0 deletions std/hash/mimc/mimc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package mimc

import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"testing"

Expand Down Expand Up @@ -80,3 +83,128 @@ func TestMimcAll(t *testing.T) {
}

}

// stateStoreCircuit checks that SetState works as expected. The circuit, however
// does not check the correctness of the hashes returned by the MiMC function
// as there is another test already testing this property.
type stateStoreTestCircuit struct {
X frontend.Variable
}

func (s *stateStoreTestCircuit) Define(api frontend.API) error {

hsh1, err1 := NewMiMC(api)
hsh2, err2 := NewMiMC(api)

if err1 != nil || err2 != nil {
return fmt.Errorf("could not instantiate the MIMC hasher: %w", errors.Join(err1, err2))
}

// This pre-shuffle the hasher state so that the test does not start from
// a zero state.
hsh1.Write(s.X)

state := hsh1.State()
hsh2.SetState(state)

hsh1.Write(s.X)
hsh2.Write(s.X)

var (
dig1 = hsh1.Sum()
dig2 = hsh2.Sum()
newState1 = hsh1.State()
newState2 = hsh2.State()
)

api.AssertIsEqual(dig1, dig2)

for i := range newState1 {
api.AssertIsEqual(newState1[i], newState2[i])
}

return nil
}

func TestStateStoreMiMC(t *testing.T) {

assert := test.NewAssert(t)

curves := map[ecc.ID]hash.Hash{
ecc.BN254: hash.MIMC_BN254,
ecc.BLS12_381: hash.MIMC_BLS12_381,
ecc.BLS12_377: hash.MIMC_BLS12_377,
ecc.BW6_761: hash.MIMC_BW6_761,
ecc.BW6_633: hash.MIMC_BW6_633,
ecc.BLS24_315: hash.MIMC_BLS24_315,
ecc.BLS24_317: hash.MIMC_BLS24_317,
}

for curve := range curves {

// minimal cs res = hash(data)
var (
circuit = &stateStoreTestCircuit{}
assignment = &stateStoreTestCircuit{X: 2}
)

assert.CheckCircuit(circuit,
test.WithValidAssignment(assignment),
test.WithCurves(curve))
}
}

type recoveredStateTestCircuit struct {
State []frontend.Variable
Input frontend.Variable
Expected frontend.Variable `gnark:",public"`
}

func (c *recoveredStateTestCircuit) Define(api frontend.API) error {
h, err := NewMiMC(api)
if err != nil {
return fmt.Errorf("initialize hash: %w", err)
}
if err = h.SetState(c.State); err != nil {
return fmt.Errorf("set state: %w", err)
}
h.Write(c.Input)
res := h.Sum()
api.AssertIsEqual(res, c.Expected)
return nil
}

func TestHasherFromState(t *testing.T) {
assert := test.NewAssert(t)

hashes := map[ecc.ID]hash.Hash{
ecc.BN254: hash.MIMC_BN254,
ecc.BLS12_381: hash.MIMC_BLS12_381,
ecc.BLS12_377: hash.MIMC_BLS12_377,
ecc.BW6_761: hash.MIMC_BW6_761,
ecc.BW6_633: hash.MIMC_BW6_633,
ecc.BLS24_315: hash.MIMC_BLS24_315,
ecc.BLS24_317: hash.MIMC_BLS24_317,
}

for cc, hh := range hashes {
hasher := hh.New()
ss, ok := hasher.(hash.StateStorer)
assert.True(ok)
_, err := ss.Write([]byte("hello world"))
assert.NoError(err)
state := ss.State()
nbBytes := cc.ScalarField().BitLen() / 8
buf := make([]byte, nbBytes)
_, err = rand.Read(buf)
assert.NoError(err)
ss.Write(buf)
expected := ss.Sum(nil)
bstate := new(big.Int).SetBytes(state)
binput := new(big.Int).SetBytes(buf)
assert.CheckCircuit(
&recoveredStateTestCircuit{State: make([]frontend.Variable, 1)},
test.WithValidAssignment(&recoveredStateTestCircuit{State: []frontend.Variable{bstate}, Input: binput, Expected: expected}),
test.WithCurves(cc))
}
}
Loading