@@ -8,7 +8,7 @@ unconstrained fn __sort_field_as_u32(lhs: Field, rhs: Field) -> bool {
8
8
9
9
fn assert_sorted (lhs : Field , rhs : Field ) {
10
10
let result = (rhs - lhs - 1 );
11
- result .assert_max_bit_size ( 32 );
11
+ result .assert_max_bit_size ::< 32 >( );
12
12
}
13
13
14
14
/**
@@ -22,21 +22,18 @@ fn assert_sorted(lhs: Field, rhs: Field) {
22
22
* 1. keys[i] maps to values[i+1]
23
23
* 2. values[0] is an empty object. when calling `get(idx)`, if `idx` is not in `keys` we will return `values[0]`
24
24
**/
25
- struct MutSparseArrayBase <let N : u32 , T , ComparisonFuncs >
26
- {
25
+ struct MutSparseArrayBase <let N : u32 , T , ComparisonFuncs > {
27
26
values : [T ; N + 3 ],
28
27
keys : [Field ; N + 2 ],
29
28
linked_keys : [Field ; N + 2 ],
30
29
tail_ptr : Field ,
31
- maximum : Field
30
+ maximum : Field ,
32
31
}
33
32
34
- struct U32RangeTraits {
35
- }
33
+ struct U32RangeTraits {}
36
34
37
- struct MutSparseArray <let N : u32 , T >
38
- {
39
- inner : MutSparseArrayBase <N , T , U32RangeTraits >
35
+ pub struct MutSparseArray <let N : u32 , T > {
36
+ inner : MutSparseArrayBase <N , T , U32RangeTraits >,
40
37
}
41
38
/**
42
39
* @brief SparseArray, stores a sparse array of up to size 2^32 with `N` nonzero entries
@@ -49,19 +46,23 @@ struct MutSparseArray<let N: u32, T>
49
46
* 1. keys[i] maps to values[i+1]
50
47
* 2. values[0] is an empty object. when calling `get(idx)`, if `idx` is not in `keys` we will return `values[0]`
51
48
**/
52
- struct SparseArray <let N : u32 , T > {
49
+ pub struct SparseArray <let N : u32 , T > {
53
50
keys : [Field ; N + 2 ],
54
51
values : [T ; N + 3 ],
55
- maximum : Field // can be up to 2^32
52
+ maximum : Field , // can be up to 2^32
56
53
}
57
- impl <let N : u32 , T > SparseArray <N , T > where T : std::default::Default {
54
+ impl <let N : u32 , T > SparseArray <N , T >
55
+ where
56
+ T : std::default::Default ,
57
+ {
58
58
59
59
/**
60
60
* @brief construct a SparseArray
61
61
**/
62
62
fn create (_keys : [Field ; N ], _values : [T ; N ], size : Field ) -> Self {
63
63
let _maximum = size - 1 ;
64
- let mut r : Self = SparseArray { keys : [0 ; N + 2 ], values : [T ::default (); N + 3 ], maximum : _maximum };
64
+ let mut r : Self =
65
+ SparseArray { keys : [0 ; N + 2 ], values : [T ::default (); N + 3 ], maximum : _maximum };
65
66
66
67
// for any valid index, we want to ensure the following is satified:
67
68
// self.keys[X] <= index <= self.keys[X+1]
@@ -71,16 +72,16 @@ impl<let N: u32, T> SparseArray<N, T> where T : std::default::Default {
71
72
// insert start and endpoints
72
73
r .keys [0 ] = 0 ;
73
74
for i in 0 ..N {
74
- r .keys [i + 1 ] = sorted_keys .sorted [i ];
75
+ r .keys [i + 1 ] = sorted_keys .sorted [i ];
75
76
}
76
- r .keys [N + 1 ] = _maximum ;
77
+ r .keys [N + 1 ] = _maximum ;
77
78
78
79
// populate values based on the sorted keys
79
80
// note: self.keys[i] maps to self.values[i+1]
80
81
// self.values[0] does not map to any key. we use it to store the default empty value,
81
82
// which is returned when `get(idx)` is called and `idx` does not exist in `self.keys`
82
83
for i in 0 ..N {
83
- r .values [i + 2 ] = _values [sorted_keys .sort_indices [i ]];
84
+ r .values [i + 2 ] = _values [sorted_keys .sort_indices [i ]];
84
85
}
85
86
// insert values that map to our key start and endpoints
86
87
// if _keys[0] = 0 then values[0] must equal _values[0], so some conditional logic is required
@@ -91,20 +92,20 @@ impl<let N: u32, T> SparseArray<N, T> where T : std::default::Default {
91
92
}
92
93
let mut final_value = T ::default ();
93
94
if (_keys [N - 1 ] == _maximum ) {
94
- final_value = _values [N - 1 ];
95
+ final_value = _values [N - 1 ];
95
96
}
96
97
r .values [1 ] = initial_value ;
97
- r .values [N + 2 ] = final_value ;
98
+ r .values [N + 2 ] = final_value ;
98
99
99
100
// perform boundary checks!
100
101
// the maximum size of the sparse array is 2^32
101
102
// we need to check that every element in `self.keys` is less than 2^32
102
103
// because `self.keys` is sorted, we can simply validate that
103
104
// sorted_keys.sorted[0] < 2^32
104
105
// sorted_keys.sorted[N-1] < maximum
105
- sorted_keys .sorted [0 ].assert_max_bit_size ( 32 );
106
- _maximum .assert_max_bit_size ( 32 );
107
- (_maximum - sorted_keys .sorted [N - 1 ]).assert_max_bit_size ( 32 );
106
+ sorted_keys .sorted [0 ].assert_max_bit_size ::< 32 >( );
107
+ _maximum .assert_max_bit_size ::< 32 >( );
108
+ (_maximum - sorted_keys .sorted [N - 1 ]).assert_max_bit_size ::< 32 >( );
108
109
r
109
110
}
110
111
@@ -138,9 +139,7 @@ impl<let N: u32, T> SparseArray<N, T> where T : std::default::Default {
138
139
* @details cost is 14.5 gates per lookup
139
140
**/
140
141
fn get (self , idx : Field ) -> T {
141
- let (found , found_index ) = unsafe {
142
- self .search_for_key (idx )
143
- };
142
+ let (found , found_index ) = unsafe { self .search_for_key (idx ) };
144
143
// bool check. 0.25 gates cheaper than a raw `bool` type. need to fix at some point
145
144
assert (found * found == found );
146
145
@@ -156,8 +155,8 @@ impl<let N: u32, T> SparseArray<N, T> where T : std::default::Default {
156
155
let rhs = self .keys [found_index + 1 - found ];
157
156
let lhs_condition = idx - lhs - 1 + found ;
158
157
let rhs_condition = rhs - 1 + found - idx ;
159
- lhs_condition .assert_max_bit_size ( 32 );
160
- rhs_condition .assert_max_bit_size ( 32 );
158
+ lhs_condition .assert_max_bit_size ::< 32 >( );
159
+ rhs_condition .assert_max_bit_size ::< 32 >( );
161
160
162
161
// self.keys[i] maps to self.values[i+1]
163
162
// however...if we did not find a non-sparse entry, we want to return self.values[0] (the default value)
@@ -170,7 +169,7 @@ mod test {
170
169
171
170
use crate::SparseArray ;
172
171
#[test]
173
- fn test_sparse_lookup () {
172
+ fn test_sparse_lookup () {
174
173
let example = SparseArray ::create ([1 , 99 , 7 , 5 ], [123 , 101112 , 789 , 456 ], 100 );
175
174
176
175
assert (example .get (1 ) == 123 );
@@ -186,12 +185,12 @@ fn test_sparse_lookup() {
186
185
}
187
186
188
187
#[test]
189
- fn test_sparse_lookup_boundary_cases () {
188
+ fn test_sparse_lookup_boundary_cases () {
190
189
// what about when keys[0] = 0 and keys[N-1] = 2^32 - 1?
191
190
let example = SparseArray ::create (
192
191
[0 , 99999 , 7 , 0xffffffff ],
193
192
[123 , 101112 , 789 , 456 ],
194
- 0x100000000
193
+ 0x100000000 ,
195
194
);
196
195
197
196
assert (example .get (0 ) == 123 );
@@ -202,30 +201,32 @@ fn test_sparse_lookup_boundary_cases() {
202
201
}
203
202
204
203
#[test(should_fail_with = "call to assert_max_bit_size")]
205
- fn test_sparse_lookup_overflow () {
204
+ fn test_sparse_lookup_overflow () {
206
205
let example = SparseArray ::create ([1 , 5 , 7 , 99999 ], [123 , 456 , 789 , 101112 ], 100000 );
207
206
208
207
assert (example .get (100000 ) == 0 );
209
208
}
210
209
211
210
#[test(should_fail_with = "call to assert_max_bit_size")]
212
- fn test_sparse_lookup_boundary_case_overflow () {
213
- let example = SparseArray ::create ([0 , 5 , 7 , 0xffffffff ], [123 , 456 , 789 , 101112 ], 0x100000000 );
211
+ fn test_sparse_lookup_boundary_case_overflow () {
212
+ let example =
213
+ SparseArray ::create ([0 , 5 , 7 , 0xffffffff ], [123 , 456 , 789 , 101112 ], 0x100000000 );
214
214
215
215
assert (example .get (0x100000000 ) == 0 );
216
216
}
217
217
218
218
#[test(should_fail_with = "call to assert_max_bit_size")]
219
- fn test_sparse_lookup_key_exceeds_maximum () {
220
- let example = SparseArray ::create ([0 , 5 , 7 , 0xffffffff ], [123 , 456 , 789 , 101112 ], 0xffffffff );
219
+ fn test_sparse_lookup_key_exceeds_maximum () {
220
+ let example =
221
+ SparseArray ::create ([0 , 5 , 7 , 0xffffffff ], [123 , 456 , 789 , 101112 ], 0xffffffff );
221
222
assert (example .maximum == 0xffffffff );
222
223
}
223
224
#[test]
224
- fn test_sparse_lookup_u32 () {
225
+ fn test_sparse_lookup_u32 () {
225
226
let example = SparseArray ::create (
226
227
[1 , 99 , 7 , 5 ],
227
228
[123 as u32 , 101112 as u32 , 789 as u32 , 456 as u32 ],
228
- 100
229
+ 100 ,
229
230
);
230
231
231
232
assert (example .get (1 ) == 123 );
@@ -241,8 +242,8 @@ fn test_sparse_lookup_u32() {
241
242
}
242
243
243
244
struct F {
244
- foo: [Field; 3]
245
- }
245
+ foo: [Field; 3],
246
+ }
246
247
impl std::cmp::Eq for F {
247
248
fn eq (self , other : Self ) -> bool {
248
249
self .foo == other .foo
@@ -256,8 +257,13 @@ fn test_sparse_lookup_u32() {
256
257
}
257
258
258
259
#[test]
259
- fn test_sparse_lookup_struct () {
260
- let values = [F { foo : [1 , 2 , 3 ] }, F { foo : [4 , 5 , 6 ] }, F { foo : [7 , 8 , 9 ] }, F { foo : [10 , 11 , 12 ] }];
260
+ fn test_sparse_lookup_struct () {
261
+ let values = [
262
+ F { foo : [1 , 2 , 3 ] },
263
+ F { foo : [4 , 5 , 6 ] },
264
+ F { foo : [7 , 8 , 9 ] },
265
+ F { foo : [10 , 11 , 12 ] },
266
+ ];
261
267
let example = SparseArray ::create ([1 , 99 , 7 , 5 ], values , 100000 );
262
268
263
269
assert (example .get (1 ) == values [0 ]);
0 commit comments