Skip to content

Commit

Permalink
feat: add ripemd160 hash function with permutation (#1120)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivokub authored Dec 6, 2024
1 parent 96baf03 commit 5a545ea
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 4 deletions.
82 changes: 82 additions & 0 deletions std/hash/ripemd160/ripemd160.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Package ripemd160 implements in-circuit ripemd160 hash function.
//
// This package extends the permutation function [ripemd160.Permute] into a full
// hash function with padding computation and [hash.BinaryHasher] interface
// implementation.
package ripemd160

import (
"encoding/binary"
"fmt"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/std/permutation/ripemd160"
)

var _seed = uints.NewU32Array([]uint32{
0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476, 0xc3d2e1f0,
})

type digest struct {
uapi *uints.BinaryField[uints.U32]
in []uints.U8
}

// New returns a new ripemd160 hasher.
func New(api frontend.API) (hash.BinaryHasher, error) {
uapi, err := uints.New[uints.U32](api)
if err != nil {
return nil, fmt.Errorf("new uapi: %w", err)
}
return &digest{uapi: uapi}, nil
}

func (d *digest) Write(data []uints.U8) {
d.in = append(d.in, data...)
}

func (d *digest) padded(bytesLen int) []uints.U8 {
zeroPadLen := 55 - bytesLen%64
if zeroPadLen < 0 {
zeroPadLen += 64
}
if cap(d.in) < len(d.in)+9+zeroPadLen {
// in case this is the first time this method is called increase the
// capacity of the slice to fit the padding.
d.in = append(d.in, make([]uints.U8, 9+zeroPadLen)...)
d.in = d.in[:len(d.in)-9-zeroPadLen]
}
buf := d.in
buf = append(buf, uints.NewU8(0x80))
buf = append(buf, uints.NewU8Array(make([]uint8, zeroPadLen))...)
lenbuf := make([]uint8, 8)
binary.LittleEndian.PutUint64(lenbuf, uint64(8*bytesLen))
buf = append(buf, uints.NewU8Array(lenbuf)...)
return buf
}

func (d *digest) Sum() []uints.U8 {
var runningDigest [5]uints.U32
var buf [64]uints.U8
copy(runningDigest[:], _seed)
padded := d.padded(len(d.in))
for i := 0; i < len(padded)/64; i++ {
copy(buf[:], padded[i*64:(i+1)*64])
runningDigest = ripemd160.Permute(d.uapi, runningDigest, buf)
}
var ret []uints.U8
for i := range runningDigest {
ret = append(ret, d.uapi.UnpackLSB(runningDigest[i])...)
}
return ret
}

func (d *digest) Reset() {
d.in = nil
}

func (d *digest) Size() int {
return 20
}
52 changes: 52 additions & 0 deletions std/hash/ripemd160/ripemd160_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package ripemd160

import (
"fmt"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/uints"
"github.com/consensys/gnark/test"
"golang.org/x/crypto/ripemd160" //nolint staticcheck, backwards compatiblity
)

type ripemd160Circuit struct {
In []uints.U8
Expected [20]uints.U8
}

func (c *ripemd160Circuit) Define(api frontend.API) error {
h, err := New(api) //nolint G406, false positive, the current package name collides with "golang.org/x/crypto/ripemd160"
if err != nil {
return err
}
uapi, err := uints.New[uints.U32](api)
if err != nil {
return err
}
h.Write(c.In)
res := h.Sum()
if len(res) != len(c.Expected) {
return fmt.Errorf("not 20 bytes")
}
for i := range c.Expected {
uapi.ByteAssertEq(c.Expected[i], res[i])
}
return nil
}

func TestRipemd160(t *testing.T) {
bts := make([]byte, 310)
h := ripemd160.New() //nolint G406, false positive, we implement it for EVM compatibility
h.Write(bts)
dgst := h.Sum(nil)
witness := ripemd160Circuit{
In: uints.NewU8Array(bts),
}
copy(witness.Expected[:], uints.NewU8Array(dgst[:]))
err := test.IsSolved(&ripemd160Circuit{In: make([]uints.U8, len(bts))}, &witness, ecc.BN254.ScalarField())
if err != nil {
t.Fatal(err)
}
}
6 changes: 6 additions & 0 deletions std/math/uints/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func GetHints() []solver.Hint {
return []solver.Hint{
andHint,
xorHint,
orHint,
toBytes,
}
}
Expand All @@ -29,6 +30,11 @@ func andHint(_ *big.Int, inputs, outputs []*big.Int) error {
return nil
}

func orHint(_ *big.Int, inputs, outputs []*big.Int) error {
outputs[0].Or(inputs[0], inputs[1])
return nil
}

func toBytes(m *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 2 {
return fmt.Errorf("input must be 2 elements")
Expand Down
14 changes: 10 additions & 4 deletions std/math/uints/uint8.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ type U32 [4]U8
type Long interface{ U32 | U64 }

type BinaryField[T U32 | U64] struct {
api frontend.API
xorT, andT *logderivprecomp.Precomputed
rchecker frontend.Rangechecker
allOne U8
api frontend.API
xorT, andT, orT *logderivprecomp.Precomputed
rchecker frontend.Rangechecker
allOne U8
}

func New[T Long](api frontend.API) (*BinaryField[T], error) {
Expand All @@ -99,11 +99,16 @@ func New[T Long](api frontend.API) (*BinaryField[T], error) {
if err != nil {
return nil, fmt.Errorf("new and table: %w", err)
}
orT, err := logderivprecomp.New(api, orHint, []uint{8})
if err != nil {
return nil, fmt.Errorf("new or table: %w", err)
}
rchecker := rangecheck.New(api)
bf := &BinaryField[T]{
api: api,
xorT: xorT,
andT: andT,
orT: orT,
rchecker: rchecker,
}
// TODO: this is const. add way to init constants
Expand Down Expand Up @@ -244,6 +249,7 @@ func (bf *BinaryField[T]) twoArgWideFn(tbl *logderivprecomp.Precomputed, a ...T)

func (bf *BinaryField[T]) And(a ...T) T { return bf.twoArgWideFn(bf.andT, a...) }
func (bf *BinaryField[T]) Xor(a ...T) T { return bf.twoArgWideFn(bf.xorT, a...) }
func (bf *BinaryField[T]) Or(a ...T) T { return bf.twoArgWideFn(bf.orT, a...) }

func (bf *BinaryField[T]) not(a U8) U8 {
ret := bf.xorT.Query(a.Val, bf.allOne.Val)
Expand Down
137 changes: 137 additions & 0 deletions std/permutation/ripemd160/ripemd160block.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Package ripemd160 implements the permutation used in the ripemd160 hash function.
package ripemd160

import (
"github.com/consensys/gnark/std/math/uints"
)

var rLeft = [80]uint{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
7, 4, 13, 1, 10, 6, 15, 3, 12, 0, 9, 5, 2, 14, 11, 8,
3, 10, 14, 4, 9, 15, 8, 1, 2, 7, 0, 6, 13, 11, 5, 12,
1, 9, 11, 10, 0, 8, 12, 4, 13, 3, 7, 15, 14, 5, 6, 2,
4, 0, 5, 9, 7, 12, 2, 10, 14, 1, 3, 8, 11, 6, 15, 13,
}

var rRight = [80]uint{
5, 14, 7, 0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12,
6, 11, 3, 7, 0, 13, 5, 10, 14, 15, 8, 12, 4, 9, 1, 2,
15, 5, 1, 3, 7, 14, 6, 9, 11, 8, 12, 2, 10, 0, 4, 13,
8, 6, 4, 1, 3, 11, 15, 0, 5, 12, 2, 13, 9, 7, 10, 14,
12, 15, 10, 4, 1, 5, 8, 7, 6, 2, 13, 14, 0, 3, 9, 11,
}

var sLeft = [80]uint{
11, 14, 15, 12, 5, 8, 7, 9, 11, 13, 14, 15, 6, 7, 9, 8,
7, 6, 8, 13, 11, 9, 7, 15, 7, 12, 15, 9, 11, 7, 13, 12,
11, 13, 6, 7, 14, 9, 13, 15, 14, 8, 13, 6, 5, 12, 7, 5,
11, 12, 14, 15, 14, 15, 9, 8, 9, 14, 5, 6, 8, 6, 5, 12,
9, 15, 5, 11, 6, 8, 13, 12, 5, 12, 13, 14, 11, 8, 5, 6,
}

var sRight = [80]uint{
8, 9, 9, 11, 13, 15, 15, 5, 7, 7, 8, 11, 14, 14, 12, 6,
9, 13, 15, 7, 12, 8, 9, 11, 7, 7, 12, 7, 6, 15, 13, 11,
9, 7, 15, 11, 8, 6, 6, 14, 12, 13, 5, 14, 13, 13, 7, 5,
15, 5, 8, 11, 14, 14, 6, 14, 6, 9, 12, 9, 12, 5, 15, 8,
8, 5, 12, 9, 12, 5, 14, 6, 8, 13, 6, 5, 15, 13, 11, 11,
}

var kLeft = [4]uints.U32(uints.NewU32Array([]uint32{
0x5a827999,
0x6ed9eba1,
0x8f1bbcdc,
0xa953fd4e,
}))

var kRight = [4]uints.U32(uints.NewU32Array([]uint32{
0x50a28be6,
0x5c4dd124,
0x6d703ef3,
0x7a6d76e9,
}))

func Permute(uapi *uints.BinaryField[uints.U32], currentHash [5]uints.U32, p [64]uints.U8) (newHash [5]uints.U32) {
var x [16]uints.U32
a, b, c, d, e := currentHash[0], currentHash[1], currentHash[2], currentHash[3], currentHash[4]
aa, bb, cc, dd, ee := a, b, c, d, e
for i := 0; i < 16; i++ {
x[i] = uapi.PackLSB(p[4*i], p[4*i+1], p[4*i+2], p[4*i+3])
}
for j := 0; j < 80; j++ {
a, b, c, d, e = round(uapi, j, true, a, b, c, d, e, x, rLeft, sLeft, kLeft)
aa, bb, cc, dd, ee = round(uapi, j, false, aa, bb, cc, dd, ee, x, rRight, sRight, kRight)
}
newHash[0] = uapi.Add(currentHash[1], c, dd)
newHash[1] = uapi.Add(currentHash[2], d, ee)
newHash[2] = uapi.Add(currentHash[3], e, aa)
newHash[3] = uapi.Add(currentHash[4], a, bb)
newHash[4] = uapi.Add(currentHash[0], b, cc)
return
}

func f(uapi *uints.BinaryField[uints.U32], j int, x, y, z uints.U32) uints.U32 {
if j < 16 {
// x ^ y ^ z
return uapi.Xor(x, y, z)
}
if j < 32 {
// (x & y) | (~x & z)
return uapi.Or(uapi.And(x, y), uapi.And(uapi.Not(x), z))
}
if j < 48 {
// (x | ~y) ^ z
return uapi.Xor(
uapi.Or(x, uapi.Not(y)),
z,
)
}
if j < 64 {
// (x & z) | (y & ~z)
return uapi.Or(
uapi.And(x, z),
uapi.And(y, uapi.Not(z)),
)
}
// x ^ (y | ~z)
return uapi.Xor(
x,
uapi.Or(y, uapi.Not(z)),
)
}

func round(uapi *uints.BinaryField[uints.U32], j int, isLeft bool, A, B, C, D, E uints.U32, X_i [16]uints.U32, r, s [80]uint, K [4]uints.U32) (AA, BB, CC, DD, EE uints.U32) {
var tmp1 uints.U32
jj := j
jjj := j / 16
if !isLeft {
jj = 79 - j
} else {
jjj = jjj - 1
}
ff := f(uapi, jj, B, C, D)
if (isLeft && j < 16) || (!isLeft && j >= 64) {
tmp1 = uapi.Add(
A,
ff,
X_i[r[j]],
)
} else {
tmp1 = uapi.Add(
A,
ff,
X_i[r[j]],
K[jjj],
)
}
T := uapi.Add(
uapi.Lrot(tmp1, int(s[j])),
E,
)
AA = E
BB = T
CC = B
DD = uapi.Lrot(C, 10)
EE = D
return
}
Loading

0 comments on commit 5a545ea

Please sign in to comment.