1
1
use itertools:: Itertools ;
2
- use num_traits:: PrimInt ;
2
+ use num_traits:: { PrimInt , WrappingAdd , WrappingSub } ;
3
3
4
4
use vortex:: array:: constant:: ConstantArray ;
5
5
use vortex:: array:: downcast:: DowncastArrayBuiltin ;
@@ -70,7 +70,7 @@ impl EncodingCompression for FoREncoding {
70
70
}
71
71
}
72
72
73
- fn compress_primitive < T : NativePType + PrimInt > (
73
+ fn compress_primitive < T : NativePType + WrappingSub + PrimInt > (
74
74
parray : & PrimitiveArray ,
75
75
shift : u8 ,
76
76
) -> PrimitiveArray {
@@ -86,13 +86,13 @@ fn compress_primitive<T: NativePType + PrimInt>(
86
86
. typed_data :: < T > ( )
87
87
. iter ( )
88
88
. map ( |& v| v >> shift as usize )
89
- . map ( |v| v - shifted_min)
89
+ . map ( |v| v. wrapping_sub ( & shifted_min) )
90
90
. collect_vec ( )
91
91
} else {
92
92
parray
93
93
. typed_data :: < T > ( )
94
94
. iter ( )
95
- . map ( |& v| v - min)
95
+ . map ( |& v| v. wrapping_sub ( & min) )
96
96
. collect_vec ( )
97
97
} ;
98
98
@@ -112,16 +112,23 @@ pub fn decompress(array: &FoRArray) -> VortexResult<PrimitiveArray> {
112
112
} ) )
113
113
}
114
114
115
- fn decompress_primitive < T : NativePType + PrimInt > ( values : & [ T ] , reference : T , shift : u8 ) -> Vec < T > {
115
+ fn decompress_primitive < T : NativePType + WrappingAdd + PrimInt > (
116
+ values : & [ T ] ,
117
+ reference : T ,
118
+ shift : u8 ,
119
+ ) -> Vec < T > {
116
120
if shift > 0 {
117
121
let shifted_reference = reference << shift as usize ;
118
122
values
119
123
. iter ( )
120
124
. map ( |& v| v << shift as usize )
121
- . map ( |v| v + shifted_reference)
125
+ . map ( |v| v. wrapping_add ( & shifted_reference) )
122
126
. collect_vec ( )
123
127
} else {
124
- values. iter ( ) . map ( |& v| v + reference) . collect_vec ( )
128
+ values
129
+ . iter ( )
130
+ . map ( |& v| v. wrapping_add ( & reference) )
131
+ . collect_vec ( )
125
132
}
126
133
}
127
134
@@ -144,6 +151,7 @@ mod test {
144
151
use std:: sync:: Arc ;
145
152
146
153
use vortex:: array:: { Encoding , EncodingRef } ;
154
+ use vortex:: compute:: scalar_at:: ScalarAtFn ;
147
155
148
156
use crate :: BitPackedEncoding ;
149
157
@@ -183,4 +191,30 @@ mod test {
183
191
let decompressed = flatten_primitive ( compressed. as_ref ( ) ) . unwrap ( ) ;
184
192
assert_eq ! ( decompressed. typed_data:: <u32 >( ) , array. typed_data:: <u32 >( ) ) ;
185
193
}
194
+
195
+ #[ test]
196
+ fn test_overflow ( ) {
197
+ let ctx = compress_ctx ( ) ;
198
+
199
+ // Create a range offset by a million
200
+ let array = PrimitiveArray :: from ( ( i8:: MIN ..i8:: MAX ) . collect_vec ( ) ) ;
201
+ let compressed = FoREncoding { } . compress ( & array, None , ctx) . unwrap ( ) ;
202
+ let compressed = compressed. as_for ( ) ;
203
+ assert_eq ! ( i8 :: MIN , compressed. reference( ) . try_into( ) . unwrap( ) ) ;
204
+
205
+ let encoded = flatten_primitive ( compressed. encoded ( ) ) . unwrap ( ) ;
206
+ let bitcast: & [ u8 ] = unsafe { std:: mem:: transmute ( encoded. typed_data :: < i8 > ( ) ) } ;
207
+ let unsigned: Vec < u8 > = ( 0 ..u8:: MAX ) . collect_vec ( ) ;
208
+ assert_eq ! ( bitcast, unsigned. as_slice( ) ) ;
209
+
210
+ let decompressed = flatten_primitive ( compressed) . unwrap ( ) ;
211
+ assert_eq ! ( decompressed. typed_data:: <i8 >( ) , array. typed_data:: <i8 >( ) ) ;
212
+ array
213
+ . typed_data :: < i8 > ( )
214
+ . iter ( )
215
+ . enumerate ( )
216
+ . for_each ( |( i, v) | {
217
+ assert_eq ! ( * v, compressed. scalar_at( i) . unwrap( ) . try_into( ) . unwrap( ) ) ;
218
+ } ) ;
219
+ }
186
220
}
0 commit comments