Skip to content

Commit 7c81df9

Browse files
authored
fix modmul 256-bit perf (#156)
Not sure about other lengths, but this makes 256-bit `modmul` roughly 100x faster on the given trivial benchmark and fixes abysmally slow EVM performance. Before: ``` Modmul (stint): 856300 ms ``` After: ``` Modmul (stint): 8850 ms ```
1 parent 9a3348b commit 7c81df9

File tree

2 files changed

+42
-69
lines changed

2 files changed

+42
-69
lines changed

benchmarks/bench.nim

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ let a = [123'u64, 123'u64, 123'u64, 123'u64]
2323
let m = [456'u64, 456'u64, 456'u64, 45'u64]
2424

2525
proc add_stint(a, m: array[4, uint64]) =
26-
let aU256 = cast[Stuint[256]](a)
27-
let mU256 = cast[Stuint[256]](m)
26+
let aU256 = cast[StUint[256]](a)
27+
let mU256 = cast[StUint[256]](m)
2828

2929
bench "Add (stint)":
3030
var foo = aU256
@@ -33,26 +33,37 @@ proc add_stint(a, m: array[4, uint64]) =
3333
foo += aU256
3434

3535
proc mul_stint(a, m: array[4, uint64]) =
36-
let aU256 = cast[Stuint[256]](a)
37-
let mU256 = cast[Stuint[256]](m)
36+
let aU256 = cast[StUint[256]](a)
37+
let mU256 = cast[StUint[256]](m)
3838

3939
bench "Mul (stint)":
4040
var foo = aU256
4141
for i in 0 ..< 100_000_000:
4242
foo += (foo * foo)
4343

4444
proc mod_stint(a, m: array[4, uint64]) =
45-
let aU256 = cast[Stuint[256]](a)
46-
let mU256 = cast[Stuint[256]](m)
45+
let aU256 = cast[StUint[256]](a)
46+
let mU256 = cast[StUint[256]](m)
4747

4848
bench "Mod (stint)":
4949
var foo = aU256
5050
for i in 0 ..< 100_000_000:
5151
foo += (foo * foo) mod mU256
5252

53-
add_stint(a, m)
54-
mul_stint(a, m)
55-
mod_stint(a, m)
53+
proc mulmod_stint(a, m: array[4, uint64]) =
54+
let aU256 = cast[StUint[256]](a)
55+
let mU256 = cast[StUint[256]](m)
56+
57+
bench "Modmul (stint)":
58+
var foo = aU256
59+
for i in 0 ..< 100_000_000:
60+
foo += mulmod(aU256, aU256, mU256)
61+
62+
# add_stint(a, m)
63+
# mul_stint(a, m)
64+
# mod_stint(a, m)
65+
66+
mulmod_stint(a, m)
5667

5768
when defined(bench_ttmath):
5869
# need C++
@@ -88,4 +99,4 @@ when defined(bench_ttmath):
8899

89100
add_ttmath(a, m)
90101
mul_ttmath(a, m)
91-
mod_ttmath(a, m)
102+
mod_ttmath(a, m)

stint/modular_arithmetic.nim

Lines changed: 21 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ func addmod_internal(a, b, m: StUint): StUint {.inline.}=
2222
let b_from_m = m - b
2323

2424
if a >= b_from_m:
25-
return a - b_from_m
26-
return m - b_from_m + a
25+
a - b_from_m
26+
else:
27+
m - b_from_m + a
2728

2829
func submod_internal(a, b, m: StUint): StUint {.inline.}=
2930
## Modular subtraction
@@ -34,53 +35,9 @@ func submod_internal(a, b, m: StUint): StUint {.inline.}=
3435

3536
# We don't do a_m - b_m directly to avoid underflows
3637
if a >= b:
37-
return a - b
38-
return m - b + a
39-
40-
41-
func doublemod_internal(a, m: StUint): StUint {.inline.}=
42-
## Double a modulo m. Assume a < m
43-
## Internal proc - used in mulmod
44-
45-
doAssert a < m
46-
47-
result = a
48-
if a >= m - a:
49-
result -= m
50-
result += a
51-
52-
func mulmod_internal(a, b, m: StUint): StUint {.inline.}=
53-
## Does (a * b) mod m. Assume a < m and b < m
54-
## Internal proc - used in powmod
55-
56-
doAssert a < m
57-
doAssert b < m
58-
59-
var (a, b) = (a, b)
60-
61-
if b > a:
62-
swap(a, b)
63-
64-
while not b.isZero:
65-
if b.isOdd:
66-
result = result.addmod_internal(a, m)
67-
a = doublemod_internal(a, m)
68-
b = b shr 1
69-
70-
func powmod_internal(a, b, m: StUint): StUint {.inline.}=
71-
## Compute ``(a ^ b) mod m``, assume a < m
72-
## Internal proc
73-
74-
doAssert a < m
75-
76-
var (a, b) = (a, b)
77-
result = one(type a)
78-
79-
while not b.isZero:
80-
if b.isOdd:
81-
result = result.mulmod_internal(a, m)
82-
b = b shr 1
83-
a = mulmod_internal(a, a, m)
38+
a - b
39+
else:
40+
m - b + a
8441

8542
func addmod*(a, b, m: StUint): StUint =
8643
## Modular addition
@@ -90,7 +47,7 @@ func addmod*(a, b, m: StUint): StUint =
9047
let b_m = if b < m: b
9148
else: b mod m
9249

93-
result = addmod_internal(a_m, b_m, m)
50+
addmod_internal(a_m, b_m, m)
9451

9552
func submod*(a, b, m: StUint): StUint =
9653
## Modular subtraction
@@ -100,24 +57,29 @@ func submod*(a, b, m: StUint): StUint =
10057
let b_m = if b < m: b
10158
else: b mod m
10259

103-
result = submod_internal(a_m, b_m, m)
60+
submod_internal(a_m, b_m, m)
10461

10562
func mulmod*(a, b, m: StUint): StUint =
10663
## Modular multiplication
10764

108-
let a_m = if a < m: a
109-
else: a mod m
110-
let b_m = if b < m: b
111-
else: b mod m
65+
let
66+
ax = a.stuint(a.bits * 2)
67+
bx = b.stuint(b.bits * 2)
68+
mx = m.stuint(m.bits * 2)
69+
px = ax * bx
11270

113-
result = mulmod_internal(a_m, b_m, m)
71+
divmod(px, mx).rem.stuint(a.bits)
11472

11573
func powmod*(a, b, m: StUint): StUint =
11674
## Modular exponentiation
11775

118-
let a_m = if a < m: a
119-
else: a mod m
76+
var (a, b) = (a, b)
77+
result = one(type a)
12078

121-
result = powmod_internal(a_m, b, m)
79+
while not b.isZero:
80+
if b.isOdd:
81+
result = result.mulmod(a, m)
82+
b = b shr 1
83+
a = mulmod(a, a, m)
12284

12385
{.pop.}

0 commit comments

Comments
 (0)