Skip to content

Commit 6d3bb27

Browse files
authored
Merge pull request #400 from EYBlockchain/julian@dsl-update
Zokrates DSL update
2 parents 777de64 + 278a3dc commit 6d3bb27

45 files changed

Lines changed: 969 additions & 825 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from "EMBED" import u32_to_bits
1+
from "EMBED" import u32_to_bits;
22

3-
def main<N>(u32[N] i) -> (field):
4-
field res = 0
5-
for u32 k in 0..N do
6-
for u32 j in 0..32 do
7-
bool[32] bits = u32_to_bits(i[k])
8-
u32 exponent = (N - k - 1) * 32 + (32 - j - 1)
9-
res = res + if bits[j] then 2 ** exponent else 0 fi
10-
endfor
11-
endfor
12-
return res
3+
def main<N>(u32[N] i) -> field {
4+
field mut res = 0;
5+
for u32 k in 0..N {
6+
for u32 j in 0..32 {
7+
bool[32] bits = u32_to_bits(i[k]);
8+
u32 exponent = (N - k - 1) * 32 + (32 - j - 1);
9+
res = res + if bits[j] { 2 ** exponent } else { 0 };
10+
}
11+
}
12+
return res;
13+
}
Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from "EMBED" import u8_to_bits
1+
from "EMBED" import u8_to_bits;
22

3-
def main<N>(u8[N] i) -> (field):
4-
field res = 0
5-
for u32 k in 0..N do
6-
for u32 j in 0..8 do
7-
bool[8] bits = u8_to_bits(i[k])
8-
u32 exponent = (N - k - 1) * 8 + (8 - j - 1)
9-
res = res + if bits[j] then 2 ** exponent else 0 fi
10-
endfor
11-
endfor
12-
return res
3+
def main<N>(u8[N] i) -> field {
4+
field res = 0;
5+
for u32 k in 0..N {
6+
for u32 j in 0..8 {
7+
bool[8] bits = u8_to_bits(i[k]);
8+
u32 exponent = (N - k - 1) * 8 + (8 - j - 1);
9+
res = res + if bits[j] { 2 ** exponent } else { 0 };
10+
}
11+
}
12+
return res;
13+
}

circuits/common/concatenate/order-left-right-1x1.zok

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
// If order = 0: input1 is on the left.
44
// If order = 1: input1 is on the right.
55

6-
def main(bool order, field input1, field input2) -> (field[2]):
7-
field left = if order == false then input1 else input2 fi
8-
field right = if order == false then input2 else input1 fi
9-
return [left, right]
6+
def main(bool order, field input1, field input2) -> field[2] {
7+
field left = if order == false { input1 } else { input2 };
8+
field right = if order == false { input2 } else { input1 };
9+
return [left, right];
10+
}
Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,42 @@
1-
from "ecc/babyjubjubParams" import BabyJubJubParams
2-
from "ecc/babyjubjubParams" import main as curveParams
3-
from "ecc/edwardsScalarMult" import main as scalarMult
4-
from "utils/casts/u32_to_field" import main as u32_to_field
5-
from "hashes/poseidon/poseidon.zok" import main as poseidon
1+
from "ecc/babyjubjubParams" import BabyJubJubParams;
2+
from "ecc/babyjubjubParams" import main as curveParams;
3+
from "ecc/edwardsScalarMult" import main as scalarMult;
4+
from "utils/casts/u32_to_field" import main as u32_to_field;
5+
from "hashes/poseidon/poseidon.zok" import main as poseidon;
66

77
struct EncryptedMsgs<N> {
8-
field[N] cipherText
9-
field[2] ephemeralPublicKey
8+
field[N] cipherText;
9+
field[2] ephemeralPublicKey;
1010
}
1111

12+
const field DOMAIN_KEM = 10f;
1213

13-
const field DOMAIN_KEM = 10f
14+
const field DOMAIN_DEM = 20f;
1415

15-
const field DOMAIN_DEM = 20f
16-
17-
def kem(bool[256] ephemeralKey, field[2] recipientPub) -> field:
18-
BabyJubJubParams context = curveParams()
19-
field[2] g = [context.Gu, context.Gv]
20-
field[2] sharedSecret = scalarMult(ephemeralKey, recipientPub, context)
21-
field encryptionKey = poseidon([sharedSecret[0], sharedSecret[1], DOMAIN_KEM])
22-
return encryptionKey
16+
def kem(bool[256] mut ephemeralKey, field[2] recipientPub) -> field {
17+
BabyJubJubParams mut context = curveParams();
18+
field[2] mut g = [context.Gu, context.Gv];
19+
field[2] mut sharedSecret = scalarMult(ephemeralKey, recipientPub, context);
20+
field mut encryptionKey = poseidon([sharedSecret[0], sharedSecret[1], DOMAIN_KEM]);
21+
return encryptionKey;
22+
}
2323

24-
def dem<N>(field encryptionKey, field[N] plainText) -> field[N]:
25-
field[N] output = [0; N]
26-
for u32 i in 0..N do
27-
output[i] = poseidon([encryptionKey, DOMAIN_DEM, u32_to_field(i)]) + plainText[i]
28-
endfor
29-
return output
24+
def dem<N>(field mut encryptionKey, field[N] plainText) -> field[N] {
25+
field[N] mut output = [0; N];
26+
for u32 i in 0..N {
27+
output[i] = poseidon([encryptionKey, DOMAIN_DEM, u32_to_field(i)]) + plainText[i];
28+
}
29+
return output;
30+
}
3031

31-
def main<N>(bool[256] ephemeralKey,field[2] recipientPub,field[N] plainText) -> EncryptedMsgs<N>:
32-
BabyJubJubParams context = curveParams()
33-
field[2] g = [context.Gu, context.Gv]
34-
field[2] ephemeralPub = scalarMult(ephemeralKey, g, context)
32+
def main<N>(bool[256] mut ephemeralKey, field[2] recipientPub, field[N] plainText) -> EncryptedMsgs<N> {
33+
BabyJubJubParams mut context = curveParams();
34+
field[2] mut g = [context.Gu, context.Gv];
35+
field[2] mut ephemeralPub = scalarMult(ephemeralKey, g, context);
3536

36-
field encryptionKey = kem(ephemeralKey, recipientPub)
37+
field mut encryptionKey = kem(ephemeralKey, recipientPub);
3738

38-
field[N] cipherText = dem(encryptionKey, plainText)
39-
EncryptedMsgs<N> e = EncryptedMsgs { cipherText: cipherText, ephemeralPublicKey: ephemeralPub}
40-
return e
39+
field[N] mut cipherText = dem(encryptionKey, plainText);
40+
EncryptedMsgs<N> mut e = EncryptedMsgs { cipherText: cipherText, ephemeralPublicKey: ephemeralPub };
41+
return e;
42+
}

circuits/common/hashes/mimc/altbn254/mimc-constants.zok

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
def main()->(field[91]):
1+
def main() -> field[91] {
22
return [
33
20888961410941983456478427210666206549300505294776164667214940546594746570981,
44
15265126113435022738560151911929040668591755459209400716467504685752745317193,
@@ -91,4 +91,5 @@ def main()->(field[91]):
9191
18979889247746272055963929241596362599320706910852082477600815822482192194401,
9292
13602139229813231349386885113156901793661719180900395818909719758150455500533,
9393
13952667105157556595308191233585255581771936717523666104281454907150877850313
94-
]
94+
];
95+
}

circuits/common/hashes/mimc/altbn254/mimc.zok

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
// we use an exponent of 7 and 91 rounds
44

5-
from "./mimc-constants.zok" import main as constants
5+
from "./mimc-constants.zok" import main as constants;
66

7-
def main(field x, field k)->(field):
8-
field[91] c = constants()
9-
for u32 i in 0..91 do
10-
field t = x + c[i] + k
11-
x = t**7 // t^7 because 7th power is bijective in this field
12-
endfor
13-
return x + k
7+
def main(field mut x, field mut k) -> field {
8+
field[91] c = constants();
9+
for u32 mut i in 0..91 {
10+
field mut t = x + c[i] + k;
11+
x = t**7; // t^7 because 7th power is bijective in this field
12+
}
13+
return x + k;
14+
}
Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
// MiMC hashing function for five input fields
22

3-
from "./mimc.zok" import main as mimcpe7
3+
from "./mimc.zok" import main as mimcpe7;
44

5-
def main(field[2] a)->(field):
6-
field r = 0
7-
for u32 i in 0..2 do
8-
r = r + a[i] + mimcpe7(a[i], r)
9-
endfor
10-
return r
5+
def main(field[2] a) -> field {
6+
field mut r = 0;
7+
for u32 i in 0..2 {
8+
r = r + a[i] + mimcpe7(a[i], r);
9+
}
10+
return r;
11+
}

circuits/common/hashes/poseidon/constants.zok

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10904,7 +10904,7 @@ const field[16][1292] POSEIDON_C = [
1090410904
15306300298273142257702357120212730128497075786589008381550108606914393296015,
1090510905
19116371381269652319147699604019975103087973589614811479290794650138683901396
1090610906
]
10907-
]
10907+
];
1090810908

1090910909
const field[16][17][17] POSEIDON_M = [
1091010910
[
@@ -13842,4 +13842,4 @@ const field[16][17][17] POSEIDON_M = [
1384213842
13228220894074693515947418568115512670466893414535562052872530653586084906533
1384313843
]
1384413844
]
13845-
]
13845+
];
Lines changed: 53 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,62 @@
11
// circuit code (with more inputs) with reference from zokrates core library
22
// https://github.com/Zokrates/ZoKrates/blob/develop/zokrates_stdlib/stdlib/hashes/poseidon/constants.zok
33

4-
from "./constants.zok" import POSEIDON_C, POSEIDON_M
5-
6-
def ark<N>(field[N] state, field[1292] c, u32 it) -> field[N]:
7-
for u32 i in 0..N do
8-
state[i] = state[i] + c[it + i]
9-
endfor
10-
return state
11-
12-
def sbox<N>(field[N] state, u32 f, u32 p, u32 r) -> field[N]:
13-
state[0] = state[0]**5
14-
for u32 i in 1..N do
15-
state[i] = if ((r < f/2) || (r >= f/2 + p)) then state[i]**5 else state[i] fi
16-
endfor
17-
return state
18-
19-
def mix<N>(field[N] state, field[17][17] m) -> field[N]:
20-
field[N] out = [0; N]
21-
for u32 i in 0..N do
22-
field acc = 0
23-
for u32 j in 0..N do
24-
acc = acc + (state[j] * m[i][j])
25-
endfor
26-
out[i] = acc
27-
endfor
28-
return out
29-
30-
def main<N>(field[N] inputs) -> field:
31-
assert(N > 0 && N <= 16) // max 16 inputs
32-
33-
u32 t = N + 1
34-
u32[16] rounds_p = [56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68]
35-
36-
u32 f = 8
37-
u32 p = rounds_p[(t - 2)]
4+
from "./constants.zok" import POSEIDON_C, POSEIDON_M;
5+
6+
def ark<N>(field[N] mut state, field[1292] c, u32 it) -> field[N] {
7+
for u32 i in 0..N {
8+
state[i] = state[i] + c[it + i];
9+
}
10+
return state;
11+
}
12+
13+
def sbox<N>(field[N] mut state, u32 f, u32 p, u32 r) -> field[N] {
14+
state[0] = state[0]**5;
15+
for u32 i in 1..N {
16+
state[i] = if ((r < f/2) || (r >= f/2 + p)) { state[i]**5 } else { state[i] };
17+
}
18+
return state;
19+
}
20+
21+
def mix<N>(field[N] mut state, field[17][17] m) -> field[N] {
22+
field[N] mut out = [0; N];
23+
for u32 i in 0..N {
24+
field mut acc = 0;
25+
for u32 j in 0..N {
26+
acc = acc + (state[j] * m[i][j]);
27+
}
28+
out[i] = acc;
29+
}
30+
return out;
31+
}
32+
33+
def main<N>(field[N] inputs) -> field {
34+
assert(N > 0 && N <= 16); // max 16 inputs
35+
36+
u32 t = N + 1;
37+
u32[16] rounds_p = [56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68];
38+
39+
u32 f = 8;
40+
u32 p = rounds_p[(t - 2)];
3841

3942
// Constants are padded with zeroes to the maximum value calculated by
4043
// t * (f + p) = 497, where `t` (number of inputs + 1) is a max of 7.
4144
// This is done to keep the function generic, as resulting array size depends on `t`
4245
// and we do not want callers passing down constants.
4346
// This should be revisited once compiler limitations are gone.
44-
field[1292] c = POSEIDON_C[t - 2]
45-
field[17][17] m = POSEIDON_M[t - 2]
46-
47-
field[t] state = [0; t]
48-
for u32 i in 1..t do
49-
state[i] = inputs[i - 1]
50-
endfor
51-
52-
for u32 r in 0..f+p do
53-
state = ark(state, c, r * t)
54-
state = sbox(state, f, p, r)
55-
state = mix(state, m)
56-
endfor
57-
58-
return state[0]
59-
60-
47+
field[1292] c = POSEIDON_C[t - 2];
48+
field[17][17] m = POSEIDON_M[t - 2];
49+
50+
field[t] mut state = [0; t];
51+
for u32 i in 1..t {
52+
state[i] = inputs[i - 1];
53+
}
54+
55+
for u32 r in 0..f+p {
56+
state = ark(state, c, r * t);
57+
state = sbox(state, f, p, r);
58+
state = mix(state, m);
59+
}
60+
61+
return state[0];
62+
}
Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from "hashes/sha256/embed/1024bit.zok" import main as sha256
2-
from "utils/casts/u32_8_to_bool_256.zok" import main as u32_8_to_bool_256
3-
from "utils/casts/bool_256_to_u32_8.zok" import main as bool_256_to_u32_8
1+
from "hashes/sha256/embed/1024bit.zok" import main as sha256;
2+
from "utils/casts/u32_8_to_bool_256.zok" import main as u32_8_to_bool_256;
3+
from "utils/casts/bool_256_to_u32_8.zok" import main as bool_256_to_u32_8;
44

5-
def main(u32[32] a) -> (u32[8]):
6-
return bool_256_to_u32_8(sha256(\
7-
u32_8_to_bool_256(a[0..8]),\
8-
u32_8_to_bool_256(a[8..16]),\
9-
u32_8_to_bool_256(a[16..24]),\
10-
u32_8_to_bool_256(a[24..32])\
11-
))
5+
def main(u32[32] a) -> u32[8] {
6+
return bool_256_to_u32_8(sha256(
7+
u32_8_to_bool_256(a[0..8]),
8+
u32_8_to_bool_256(a[8..16]),
9+
u32_8_to_bool_256(a[16..24]),
10+
u32_8_to_bool_256(a[24..32])
11+
));
12+
}

0 commit comments

Comments
 (0)