1
+ use subspace_core_primitives:: pot:: { PotCheckpoints , PotOutput } ;
1
2
use core:: arch:: x86_64:: * ;
2
- use core:: mem;
3
- use subspace_core_primitives:: pot:: PotCheckpoints ;
3
+ use core:: { array, mem} ;
4
4
5
5
/// Create PoT proof with checkpoints
6
6
#[ target_feature( enable = "aes" ) ]
@@ -12,40 +12,117 @@ pub(super) unsafe fn create(
12
12
) -> PotCheckpoints {
13
13
let mut checkpoints = PotCheckpoints :: default ( ) ;
14
14
15
- let keys_reg = expand_key ( key) ;
16
- let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
17
- let mut seed_reg = _mm_loadu_si128 ( seed. as_ptr ( ) as * const __m128i ) ;
18
- seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
19
- for checkpoint in checkpoints. iter_mut ( ) {
20
- for _ in 0 ..checkpoint_iterations {
21
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
22
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
23
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
24
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
25
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
26
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
27
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
28
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
29
- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
30
- seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
31
- }
15
+ unsafe {
16
+ let keys_reg = expand_key ( key) ;
17
+ let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
18
+ let mut seed_reg = _mm_loadu_si128 ( seed. as_ptr ( ) as * const __m128i ) ;
19
+ seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
20
+ for checkpoint in checkpoints. iter_mut ( ) {
21
+ for _ in 0 ..checkpoint_iterations {
22
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
23
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
24
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
25
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
26
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
27
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
28
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
29
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
30
+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
31
+ seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
32
+ }
32
33
33
- let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
34
- _mm_storeu_si128 (
35
- checkpoint. as_mut ( ) . as_mut_ptr ( ) as * mut __m128i ,
36
- checkpoint_reg,
37
- ) ;
34
+ let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
35
+ _mm_storeu_si128 ( checkpoint. as_mut_ptr ( ) as * mut __m128i , checkpoint_reg) ;
36
+ }
38
37
}
39
38
40
39
checkpoints
41
40
}
42
41
42
+ /// Verification mimics `create` function, but also has decryption half for better performance
43
+ #[ target_feature( enable = "avx512f,vaes" ) ]
44
+ #[ inline]
45
+ pub ( super ) unsafe fn verify_sequential_avx512f (
46
+ seed : & [ u8 ; 16 ] ,
47
+ key : & [ u8 ; 16 ] ,
48
+ checkpoints : & PotCheckpoints ,
49
+ checkpoint_iterations : u32 ,
50
+ ) -> bool {
51
+ let checkpoints = PotOutput :: repr_from_slice ( checkpoints. as_slice ( ) ) ;
52
+
53
+ unsafe {
54
+ let keys_reg = expand_key ( key) ;
55
+ let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
56
+ let xor_key_512 = _mm512_broadcast_i32x4 ( xor_key) ;
57
+
58
+ // Invert keys for decryption
59
+ let mut inv_keys = keys_reg;
60
+ for i in 1 ..10 {
61
+ inv_keys[ i] = _mm_aesimc_si128 ( keys_reg[ 10 - i] ) ;
62
+ }
63
+
64
+ let keys_512 = array:: from_fn :: < _ , NUM_ROUNDS , _ > ( |i| _mm512_broadcast_i32x4 ( keys_reg[ i] ) ) ;
65
+ let inv_keys_512 =
66
+ array:: from_fn :: < _ , NUM_ROUNDS , _ > ( |i| _mm512_broadcast_i32x4 ( inv_keys[ i] ) ) ;
67
+
68
+ let mut input_0 = [ [ 0u8 ; 16 ] ; 4 ] ;
69
+ input_0[ 0 ] = * seed;
70
+ input_0[ 1 ..] . copy_from_slice ( & checkpoints[ ..3 ] ) ;
71
+ let mut input_0 = _mm512_loadu_si512 ( input_0. as_ptr ( ) as * const __m512i ) ;
72
+ let mut input_1 = _mm512_loadu_si512 ( checkpoints[ 3 ..7 ] . as_ptr ( ) as * const __m512i ) ;
73
+
74
+ let mut output_0 = _mm512_loadu_si512 ( checkpoints[ 0 ..4 ] . as_ptr ( ) as * const __m512i ) ;
75
+ let mut output_1 = _mm512_loadu_si512 ( checkpoints[ 4 ..8 ] . as_ptr ( ) as * const __m512i ) ;
76
+
77
+ input_0 = _mm512_xor_si512 ( input_0, keys_512[ 0 ] ) ;
78
+ input_1 = _mm512_xor_si512 ( input_1, keys_512[ 0 ] ) ;
79
+
80
+ output_0 = _mm512_xor_si512 ( output_0, keys_512[ 10 ] ) ;
81
+ output_1 = _mm512_xor_si512 ( output_1, keys_512[ 10 ] ) ;
82
+
83
+ for _ in 0 ..checkpoint_iterations / 2 {
84
+ for i in 1 ..10 {
85
+ input_0 = _mm512_aesenc_epi128 ( input_0, keys_512[ i] ) ;
86
+ input_1 = _mm512_aesenc_epi128 ( input_1, keys_512[ i] ) ;
87
+
88
+ output_0 = _mm512_aesdec_epi128 ( output_0, inv_keys_512[ i] ) ;
89
+ output_1 = _mm512_aesdec_epi128 ( output_1, inv_keys_512[ i] ) ;
90
+ }
91
+
92
+ input_0 = _mm512_aesenclast_epi128 ( input_0, xor_key_512) ;
93
+ input_1 = _mm512_aesenclast_epi128 ( input_1, xor_key_512) ;
94
+
95
+ output_0 = _mm512_aesdeclast_epi128 ( output_0, xor_key_512) ;
96
+ output_1 = _mm512_aesdeclast_epi128 ( output_1, xor_key_512) ;
97
+ }
98
+
99
+ // Code below is a more efficient version of this:
100
+ // input_0 = _mm512_xor_si512(input_0, keys_512[0]);
101
+ // input_1 = _mm512_xor_si512(input_1, keys_512[0]);
102
+ // output_0 = _mm512_xor_si512(output_0, keys_512[10]);
103
+ // output_1 = _mm512_xor_si512(output_1, keys_512[10]);
104
+ //
105
+ // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
106
+ // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
107
+
108
+ let diff_0 = _mm512_xor_si512 ( input_0, output_0) ;
109
+ let diff_1 = _mm512_xor_si512 ( input_1, output_1) ;
110
+
111
+ let mask0 = _mm512_cmpeq_epu64_mask ( diff_0, xor_key_512) ;
112
+ let mask1 = _mm512_cmpeq_epu64_mask ( diff_1, xor_key_512) ;
113
+
114
+ // All inputs match outputs
115
+ ( mask0 & mask1) == u8:: MAX
116
+ }
117
+ }
118
+
43
119
// The code below was copied with minor changes from the following MIT/Apache-2.0 licensed source by Artyom
44
120
// Pavlov:
45
121
// https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs
46
122
123
+ const NUM_ROUNDS : usize = 11 ;
47
124
/// AES-128 round keys
48
- type RoundKeys = [ __m128i ; 11 ] ;
125
+ type RoundKeys = [ __m128i ; NUM_ROUNDS ] ;
49
126
50
127
macro_rules! expand_round {
51
128
( $keys: expr, $pos: expr, $round: expr) => {
@@ -72,9 +149,10 @@ macro_rules! expand_round {
72
149
unsafe fn expand_key ( key : & [ u8 ; 16 ] ) -> RoundKeys {
73
150
// SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized
74
151
// with all zeroes.
75
- let mut keys: RoundKeys = mem:: zeroed ( ) ;
152
+ let mut keys: RoundKeys = unsafe { mem:: zeroed ( ) } ;
76
153
77
- let k = _mm_loadu_si128 ( key. as_ptr ( ) as * const __m128i ) ;
154
+ // SAFETY: No alignment requirement in `_mm_loadu_si128`
155
+ let k = unsafe { _mm_loadu_si128 ( key. as_ptr ( ) as * const __m128i ) } ;
78
156
keys[ 0 ] = k;
79
157
80
158
expand_round ! ( keys, 1 , 0x01 ) ;
0 commit comments