Skip to content

Commit ab9d6e6

Browse files
authored
Merge pull request #3561 from autonomys/pot-aarch64-aes
Implement PoT proving and verification optimized for AES (aarch64)
2 parents adadb17 + 8ce21d4 commit ab9d6e6

File tree

5 files changed

+223
-22
lines changed

5 files changed

+223
-22
lines changed

crates/subspace-proof-of-time/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ aes.workspace = true
1919
subspace-core-primitives.workspace = true
2020
thiserror.workspace = true
2121

22-
[target.'cfg(target_arch = "x86_64")'.dependencies]
22+
[target.'cfg(any(target_arch = "aarch64", target_arch = "x86_64"))'.dependencies]
2323
cpufeatures = { workspace = true }
2424

2525
[dev-dependencies]

crates/subspace-proof-of-time/src/aes.rs

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
//! AES related functionality.
22
3+
#[cfg(target_arch = "aarch64")]
4+
mod aarch64;
35
#[cfg(target_arch = "x86_64")]
46
mod x86_64;
57

@@ -19,6 +21,14 @@ pub(crate) fn create(seed: PotSeed, key: PotKey, checkpoint_iterations: u32) ->
1921
return unsafe { x86_64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) };
2022
}
2123
}
24+
#[cfg(target_arch = "aarch64")]
25+
{
26+
cpufeatures::new!(has_aes, "aes");
27+
if has_aes::get() {
28+
// SAFETY: Checked `aes` feature
29+
return unsafe { aarch64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) };
30+
}
31+
}
2232

2333
create_generic(seed, key, checkpoint_iterations)
2434
}
@@ -83,6 +93,16 @@ pub(crate) fn verify_sequential(
8393
};
8494
}
8595
}
96+
#[cfg(target_arch = "aarch64")]
97+
{
98+
cpufeatures::new!(has_aes, "aes");
99+
if has_aes::get() {
100+
// SAFETY: Checked `aes` feature
101+
return unsafe {
102+
aarch64::verify_sequential_aes(&seed, &key, checkpoints, checkpoint_iterations)
103+
};
104+
}
105+
}
86106

87107
verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations)
88108
}
@@ -143,9 +163,8 @@ mod tests {
143163
checkpoint_iterations: u32,
144164
) -> bool {
145165
let sequential = verify_sequential(seed, key, checkpoints, checkpoint_iterations);
146-
let sequential_generic =
147-
verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations);
148-
assert_eq!(sequential, sequential_generic);
166+
let generic = verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations);
167+
assert_eq!(sequential, generic);
149168

150169
#[cfg(target_arch = "x86_64")]
151170
{
@@ -180,14 +199,25 @@ mod tests {
180199
cpufeatures::new!(has_aes_sse41, "aes", "sse4.1");
181200
if has_aes_sse41::get() {
182201
// SAFETY: Checked `aes` and `sse4.1` features
183-
let aes = unsafe {
202+
let aes_sse41 = unsafe {
184203
x86_64::verify_sequential_aes_sse41(
185204
&seed,
186205
&key,
187206
checkpoints,
188207
checkpoint_iterations,
189208
)
190209
};
210+
assert_eq!(sequential, aes_sse41);
211+
}
212+
}
213+
#[cfg(target_arch = "aarch64")]
214+
{
215+
cpufeatures::new!(has_aes, "aes");
216+
if has_aes::get() {
217+
// SAFETY: Checked `aes` feature
218+
let aes = unsafe {
219+
aarch64::verify_sequential_aes(&seed, &key, checkpoints, checkpoint_iterations)
220+
};
191221
assert_eq!(sequential, aes);
192222
}
193223
}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
use core::arch::aarch64::*;
2+
use core::simd::u8x16;
3+
use core::slice;
4+
use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
5+
6+
const NUM_ROUND_KEYS: usize = 11;
7+
8+
/// Create PoT proof with checkpoints
9+
#[target_feature(enable = "aes")]
10+
#[inline]
11+
pub(super) fn create(
12+
seed: &[u8; 16],
13+
key: &[u8; 16],
14+
checkpoint_iterations: u32,
15+
) -> PotCheckpoints {
16+
let mut checkpoints = PotCheckpoints::default();
17+
18+
let keys = expand_key(key);
19+
let xor_key = veorq_u8(keys[10], keys[0]);
20+
let mut seed = uint8x16_t::from(u8x16::from(*seed));
21+
seed = veorq_u8(seed, keys[10]);
22+
for checkpoint in checkpoints.iter_mut() {
23+
for _ in 0..checkpoint_iterations {
24+
seed = vaesmcq_u8(vaeseq_u8(seed, xor_key));
25+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[1]));
26+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[2]));
27+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[3]));
28+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[4]));
29+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[5]));
30+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[6]));
31+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[7]));
32+
seed = vaesmcq_u8(vaeseq_u8(seed, keys[8]));
33+
seed = vaeseq_u8(seed, keys[9]);
34+
}
35+
36+
let checkpoint_reg = veorq_u8(seed, keys[10]);
37+
**checkpoint = u8x16::from(checkpoint_reg).to_array();
38+
}
39+
40+
checkpoints
41+
}
42+
43+
/// Verification mimics `create` function, but also has decryption half for better performance
44+
#[target_feature(enable = "aes")]
45+
#[inline]
46+
pub(super) fn verify_sequential_aes(
47+
seed: &[u8; 16],
48+
key: &[u8; 16],
49+
checkpoints: &PotCheckpoints,
50+
checkpoint_iterations: u32,
51+
) -> bool {
52+
let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
53+
54+
let keys = expand_key(key);
55+
let xor_key = veorq_u8(keys[10], keys[0]);
56+
57+
// Invert keys for decryption, the first and last element is not used below, hence they are
58+
// copied as is from encryption keys (otherwise the first and last element would need to be
59+
// swapped)
60+
let mut inv_keys = keys;
61+
for i in 1..10 {
62+
inv_keys[i] = vaesimcq_u8(keys[10 - i]);
63+
}
64+
65+
let mut inputs: [uint8x16_t; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
66+
uint8x16_t::from(u8x16::from(*seed)),
67+
uint8x16_t::from(u8x16::from(checkpoints[0])),
68+
uint8x16_t::from(u8x16::from(checkpoints[1])),
69+
uint8x16_t::from(u8x16::from(checkpoints[2])),
70+
uint8x16_t::from(u8x16::from(checkpoints[3])),
71+
uint8x16_t::from(u8x16::from(checkpoints[4])),
72+
uint8x16_t::from(u8x16::from(checkpoints[5])),
73+
uint8x16_t::from(u8x16::from(checkpoints[6])),
74+
];
75+
76+
let mut outputs: [uint8x16_t; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
77+
uint8x16_t::from(u8x16::from(checkpoints[0])),
78+
uint8x16_t::from(u8x16::from(checkpoints[1])),
79+
uint8x16_t::from(u8x16::from(checkpoints[2])),
80+
uint8x16_t::from(u8x16::from(checkpoints[3])),
81+
uint8x16_t::from(u8x16::from(checkpoints[4])),
82+
uint8x16_t::from(u8x16::from(checkpoints[5])),
83+
uint8x16_t::from(u8x16::from(checkpoints[6])),
84+
uint8x16_t::from(u8x16::from(checkpoints[7])),
85+
];
86+
87+
inputs = inputs.map(|input| veorq_u8(input, keys[10]));
88+
outputs = outputs.map(|output| veorq_u8(output, keys[0]));
89+
90+
for _ in 0..checkpoint_iterations / 2 {
91+
inputs = inputs.map(|input| vaesmcq_u8(vaeseq_u8(input, xor_key)));
92+
outputs = outputs.map(|output| vaesimcq_u8(vaesdq_u8(output, xor_key)));
93+
94+
for i in 1..9 {
95+
inputs = inputs.map(|input| vaesmcq_u8(vaeseq_u8(input, keys[i])));
96+
outputs = outputs.map(|output| vaesimcq_u8(vaesdq_u8(output, inv_keys[i])));
97+
}
98+
99+
inputs = inputs.map(|input| vaeseq_u8(input, keys[9]));
100+
outputs = outputs.map(|output| vaesdq_u8(output, inv_keys[9]));
101+
}
102+
103+
inputs.into_iter().zip(outputs).all(|(input, output)| {
104+
let diff = veorq_u8(input, output);
105+
let cmp = vceqq_u8(diff, xor_key);
106+
vminvq_u8(cmp) == u8::MAX
107+
})
108+
}
109+
110+
// Below code copied with minor changes from the following place under MIT/Apache-2.0 license by
111+
// Artyom Pavlov:
112+
// https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f/aes/src/armv8/expand.rs
113+
114+
/// There are 4 AES words in a block.
115+
const BLOCK_WORDS: usize = 4;
116+
117+
/// The AES (nee Rijndael) notion of a word is always 32-bits, or 4-bytes.
118+
const WORD_SIZE: usize = 4;
119+
120+
/// AES round constants.
121+
const ROUND_CONSTS: [u32; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36];
122+
123+
/// AES key expansion.
124+
#[target_feature(enable = "aes")]
125+
fn expand_key(key: &[u8; 16]) -> [uint8x16_t; NUM_ROUND_KEYS] {
126+
let mut expanded_keys = [uint8x16_t::from(u8x16::default()); NUM_ROUND_KEYS];
127+
128+
// Sanity check, as this is required in order for the subsequent conversion to be sound.
129+
const _: () = assert!(align_of::<uint8x16_t>() >= align_of::<u32>());
130+
let columns = unsafe {
131+
slice::from_raw_parts_mut(
132+
expanded_keys.as_mut_ptr().cast::<u32>(),
133+
NUM_ROUND_KEYS * BLOCK_WORDS,
134+
)
135+
};
136+
137+
for (i, chunk) in key.array_chunks::<WORD_SIZE>().enumerate() {
138+
columns[i] = u32::from_ne_bytes(*chunk);
139+
}
140+
141+
// From "The Rijndael Block Cipher" Section 4.1:
142+
// > The number of columns of the Cipher Key is denoted by `Nk` and is
143+
// > equal to the key length divided by 32 [bits].
144+
let nk = 16 / WORD_SIZE;
145+
146+
for i in nk..NUM_ROUND_KEYS * BLOCK_WORDS {
147+
let mut word = columns[i - 1];
148+
149+
if i % nk == 0 {
150+
word = sub_word(word).rotate_right(8) ^ ROUND_CONSTS[i / nk - 1];
151+
} else if nk > 6 && i % nk == 4 {
152+
word = sub_word(word);
153+
}
154+
155+
columns[i] = columns[i - nk] ^ word;
156+
}
157+
158+
expanded_keys
159+
}
160+
161+
/// Sub bytes for a single AES word: used for key expansion
162+
#[target_feature(enable = "aes")]
163+
fn sub_word(input: u32) -> u32 {
164+
let input = vreinterpretq_u8_u32(vdupq_n_u32(input));
165+
166+
// AES single round encryption (with a "round" key of all zeros)
167+
let sub_input = vaeseq_u8(input, vdupq_n_u8(0));
168+
169+
vgetq_lane_u32::<0>(vreinterpretq_u32_u8(sub_input))
170+
}

crates/subspace-proof-of-time/src/aes/x86_64.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,25 @@ pub(super) fn create(
1515
) -> PotCheckpoints {
1616
let mut checkpoints = PotCheckpoints::default();
1717

18-
let keys_reg = expand_key(key);
19-
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
20-
let mut seed_reg = __m128i::from(u8x16::from_array(*seed));
21-
seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
18+
let keys = expand_key(key);
19+
let xor_key = _mm_xor_si128(keys[10], keys[0]);
20+
let mut seed = __m128i::from(u8x16::from_array(*seed));
21+
seed = _mm_xor_si128(seed, keys[0]);
2222
for checkpoint in checkpoints.iter_mut() {
2323
for _ in 0..checkpoint_iterations {
24-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]);
25-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]);
26-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]);
27-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]);
28-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]);
29-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]);
30-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]);
31-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]);
32-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]);
33-
seed_reg = _mm_aesenclast_si128(seed_reg, xor_key);
24+
seed = _mm_aesenc_si128(seed, keys[1]);
25+
seed = _mm_aesenc_si128(seed, keys[2]);
26+
seed = _mm_aesenc_si128(seed, keys[3]);
27+
seed = _mm_aesenc_si128(seed, keys[4]);
28+
seed = _mm_aesenc_si128(seed, keys[5]);
29+
seed = _mm_aesenc_si128(seed, keys[6]);
30+
seed = _mm_aesenc_si128(seed, keys[7]);
31+
seed = _mm_aesenc_si128(seed, keys[8]);
32+
seed = _mm_aesenc_si128(seed, keys[9]);
33+
seed = _mm_aesenclast_si128(seed, xor_key);
3434
}
3535

36-
let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
36+
let checkpoint_reg = _mm_xor_si128(seed, keys[0]);
3737
**checkpoint = u8x16::from(checkpoint_reg).to_array();
3838
}
3939

@@ -62,7 +62,7 @@ pub(super) fn verify_sequential_aes_sse41(
6262
inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
6363
}
6464

65-
let mut inputs: [__m128i; 8] = [
65+
let mut inputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
6666
__m128i::from(u8x16::from(*seed)),
6767
__m128i::from(u8x16::from(checkpoints[0])),
6868
__m128i::from(u8x16::from(checkpoints[1])),
@@ -73,7 +73,7 @@ pub(super) fn verify_sequential_aes_sse41(
7373
__m128i::from(u8x16::from(checkpoints[6])),
7474
];
7575

76-
let mut outputs: [__m128i; 8] = [
76+
let mut outputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [
7777
__m128i::from(u8x16::from(checkpoints[0])),
7878
__m128i::from(u8x16::from(checkpoints[1])),
7979
__m128i::from(u8x16::from(checkpoints[2])),

crates/subspace-proof-of-time/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Proof of time implementation.
22
3+
#![cfg_attr(target_arch = "aarch64", feature(array_chunks))]
34
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
45
#![feature(portable_simd)]
56
#![no_std]

0 commit comments

Comments
 (0)