Skip to content

Commit 40b1904

Browse files
committed
Faster PoT verification for CPUs that support AVX512F+VAES
1 parent 310ba30 commit 40b1904

File tree

3 files changed

+163
-29
lines changed

3 files changed

+163
-29
lines changed

crates/subspace-proof-of-time/src/aes.rs

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
#[cfg(target_arch = "x86_64")]
44
mod x86_64;
55

6+
use subspace_core_primitives::pot::{PotCheckpoints, PotKey, PotOutput, PotSeed};
7+
use aes::Aes128;
68
use aes::cipher::array::Array;
79
use aes::cipher::{BlockCipherDecrypt, BlockCipherEncrypt, KeyInit};
8-
use aes::Aes128;
9-
use subspace_core_primitives::pot::{PotCheckpoints, PotKey, PotOutput, PotSeed};
1010

1111
/// Creates the AES based proof.
1212
#[inline(always)]
@@ -51,6 +51,25 @@ pub(crate) fn verify_sequential(
5151
) -> bool {
5252
assert_eq!(checkpoint_iterations % 2, 0);
5353

54+
#[cfg(target_arch = "x86_64")]
55+
{
56+
cpufeatures::new!(has_aes, "avx512f", "vaes");
57+
if has_aes::get() {
58+
return unsafe {
59+
x86_64::verify_sequential_avx512f(&seed, &key, checkpoints, checkpoint_iterations)
60+
};
61+
}
62+
}
63+
64+
verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations)
65+
}
66+
67+
fn verify_sequential_generic(
68+
seed: PotSeed,
69+
key: PotKey,
70+
checkpoints: &PotCheckpoints,
71+
checkpoint_iterations: u32,
72+
) -> bool {
5473
let key = Array::from(*key);
5574
let cipher = Aes128::new(&key);
5675

@@ -113,6 +132,12 @@ mod tests {
113132
&checkpoints,
114133
checkpoint_iterations,
115134
));
135+
assert!(verify_sequential_generic(
136+
seed,
137+
key,
138+
&checkpoints,
139+
checkpoint_iterations,
140+
));
116141

117142
// Decryption of invalid cipher text fails.
118143
let mut checkpoints_1 = checkpoints;
@@ -123,6 +148,12 @@ mod tests {
123148
&checkpoints_1,
124149
checkpoint_iterations,
125150
));
151+
assert!(!verify_sequential_generic(
152+
seed,
153+
key,
154+
&checkpoints_1,
155+
checkpoint_iterations,
156+
));
126157

127158
// Decryption with wrong number of iterations fails.
128159
assert!(!verify_sequential(
@@ -131,12 +162,24 @@ mod tests {
131162
&checkpoints,
132163
checkpoint_iterations + 2,
133164
));
165+
assert!(!verify_sequential_generic(
166+
seed,
167+
key,
168+
&checkpoints,
169+
checkpoint_iterations + 2,
170+
));
134171
assert!(!verify_sequential(
135172
seed,
136173
key,
137174
&checkpoints,
138175
checkpoint_iterations - 2,
139176
));
177+
assert!(!verify_sequential_generic(
178+
seed,
179+
key,
180+
&checkpoints,
181+
checkpoint_iterations - 2,
182+
));
140183

141184
// Decryption with wrong seed fails.
142185
assert!(!verify_sequential(
@@ -145,6 +188,12 @@ mod tests {
145188
&checkpoints,
146189
checkpoint_iterations,
147190
));
191+
assert!(!verify_sequential_generic(
192+
PotSeed::from(SEED_1),
193+
key,
194+
&checkpoints,
195+
checkpoint_iterations,
196+
));
148197

149198
// Decryption with wrong key fails.
150199
assert!(!verify_sequential(
@@ -153,5 +202,11 @@ mod tests {
153202
&checkpoints,
154203
checkpoint_iterations,
155204
));
205+
assert!(!verify_sequential_generic(
206+
seed,
207+
PotKey::from(KEY_1),
208+
&checkpoints,
209+
checkpoint_iterations,
210+
));
156211
}
157212
}

crates/subspace-proof-of-time/src/aes/x86_64.rs

Lines changed: 105 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
12
use core::arch::x86_64::*;
2-
use core::mem;
3-
use subspace_core_primitives::pot::PotCheckpoints;
3+
use core::{array, mem};
44

55
/// Create PoT proof with checkpoints
66
#[target_feature(enable = "aes")]
@@ -12,40 +12,117 @@ pub(super) unsafe fn create(
1212
) -> PotCheckpoints {
1313
let mut checkpoints = PotCheckpoints::default();
1414

15-
let keys_reg = expand_key(key);
16-
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
17-
let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i);
18-
seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
19-
for checkpoint in checkpoints.iter_mut() {
20-
for _ in 0..checkpoint_iterations {
21-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]);
22-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]);
23-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]);
24-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]);
25-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]);
26-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]);
27-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]);
28-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]);
29-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]);
30-
seed_reg = _mm_aesenclast_si128(seed_reg, xor_key);
31-
}
15+
unsafe {
16+
let keys_reg = expand_key(key);
17+
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
18+
let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i);
19+
seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
20+
for checkpoint in checkpoints.iter_mut() {
21+
for _ in 0..checkpoint_iterations {
22+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]);
23+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]);
24+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]);
25+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]);
26+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]);
27+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]);
28+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]);
29+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]);
30+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]);
31+
seed_reg = _mm_aesenclast_si128(seed_reg, xor_key);
32+
}
3233

33-
let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
34-
_mm_storeu_si128(
35-
checkpoint.as_mut().as_mut_ptr() as *mut __m128i,
36-
checkpoint_reg,
37-
);
34+
let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
35+
_mm_storeu_si128(checkpoint.as_mut_ptr() as *mut __m128i, checkpoint_reg);
36+
}
3837
}
3938

4039
checkpoints
4140
}
4241

42+
/// Verification mimics `create` function, but also has decryption half for better performance
43+
#[target_feature(enable = "avx512f,vaes")]
44+
#[inline]
45+
pub(super) unsafe fn verify_sequential_avx512f(
46+
seed: &[u8; 16],
47+
key: &[u8; 16],
48+
checkpoints: &PotCheckpoints,
49+
checkpoint_iterations: u32,
50+
) -> bool {
51+
let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
52+
53+
unsafe {
54+
let keys_reg = expand_key(key);
55+
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
56+
let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
57+
58+
// Invert keys for decryption
59+
let mut inv_keys = keys_reg;
60+
for i in 1..10 {
61+
inv_keys[i] = _mm_aesimc_si128(keys_reg[10 - i]);
62+
}
63+
64+
let keys_512 = array::from_fn::<_, NUM_ROUNDS, _>(|i| _mm512_broadcast_i32x4(keys_reg[i]));
65+
let inv_keys_512 =
66+
array::from_fn::<_, NUM_ROUNDS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
67+
68+
let mut input_0 = [[0u8; 16]; 4];
69+
input_0[0] = *seed;
70+
input_0[1..].copy_from_slice(&checkpoints[..3]);
71+
let mut input_0 = _mm512_loadu_si512(input_0.as_ptr() as *const __m512i);
72+
let mut input_1 = _mm512_loadu_si512(checkpoints[3..7].as_ptr() as *const __m512i);
73+
74+
let mut output_0 = _mm512_loadu_si512(checkpoints[0..4].as_ptr() as *const __m512i);
75+
let mut output_1 = _mm512_loadu_si512(checkpoints[4..8].as_ptr() as *const __m512i);
76+
77+
input_0 = _mm512_xor_si512(input_0, keys_512[0]);
78+
input_1 = _mm512_xor_si512(input_1, keys_512[0]);
79+
80+
output_0 = _mm512_xor_si512(output_0, keys_512[10]);
81+
output_1 = _mm512_xor_si512(output_1, keys_512[10]);
82+
83+
for _ in 0..checkpoint_iterations / 2 {
84+
for i in 1..10 {
85+
input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]);
86+
input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]);
87+
88+
output_0 = _mm512_aesdec_epi128(output_0, inv_keys_512[i]);
89+
output_1 = _mm512_aesdec_epi128(output_1, inv_keys_512[i]);
90+
}
91+
92+
input_0 = _mm512_aesenclast_epi128(input_0, xor_key_512);
93+
input_1 = _mm512_aesenclast_epi128(input_1, xor_key_512);
94+
95+
output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512);
96+
output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512);
97+
}
98+
99+
// Code below is a more efficient version of this:
100+
// input_0 = _mm512_xor_si512(input_0, keys_512[0]);
101+
// input_1 = _mm512_xor_si512(input_1, keys_512[0]);
102+
// output_0 = _mm512_xor_si512(output_0, keys_512[10]);
103+
// output_1 = _mm512_xor_si512(output_1, keys_512[10]);
104+
//
105+
// let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
106+
// let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
107+
108+
let diff_0 = _mm512_xor_si512(input_0, output_0);
109+
let diff_1 = _mm512_xor_si512(input_1, output_1);
110+
111+
let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
112+
let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
113+
114+
// All inputs match outputs
115+
(mask0 & mask1) == u8::MAX
116+
}
117+
}
118+
43119
// Below code copied with minor changes from following place under MIT/Apache-2.0 license by Artyom
44120
// Pavlov:
45121
// https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs
46122

123+
const NUM_ROUNDS: usize = 11;
47124
/// AES-128 round keys
48-
type RoundKeys = [__m128i; 11];
125+
type RoundKeys = [__m128i; NUM_ROUNDS];
49126

50127
macro_rules! expand_round {
51128
($keys:expr, $pos:expr, $round:expr) => {
@@ -72,9 +149,10 @@ macro_rules! expand_round {
72149
unsafe fn expand_key(key: &[u8; 16]) -> RoundKeys {
73150
// SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized
74151
// with all zeroes.
75-
let mut keys: RoundKeys = mem::zeroed();
152+
let mut keys: RoundKeys = unsafe { mem::zeroed() };
76153

77-
let k = _mm_loadu_si128(key.as_ptr() as *const __m128i);
154+
// SAFETY: No alignment requirement in `_mm_loadu_si128`
155+
let k = unsafe { _mm_loadu_si128(key.as_ptr() as *const __m128i) };
78156
keys[0] = k;
79157

80158
expand_round!(keys, 1, 0x01);

crates/subspace-proof-of-time/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Proof of time implementation.
22
3+
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
34
#![no_std]
45

56
mod aes;

0 commit comments

Comments
 (0)