Skip to content

Commit 4b59675

Browse files
authored
Use wrapping arithmetic for Frame of Reference (#178)
fixes #54
1 parent 4cb4fc1 commit 4b59675

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

vortex-array/src/scalar/primitive.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ macro_rules! pscalar {
211211
..
212212
}) => match pscalar {
213213
PScalar::$ptype(v) => Ok(*v),
214-
_ => Err(vortex_err!(MismatchedTypes: "$T", pscalar.ptype())),
214+
_ => Err(vortex_err!(MismatchedTypes: any::type_name::<Self>(), pscalar.ptype())),
215215
},
216-
_ => Err(vortex_err!("can't extract $T from scalar: {}", value)),
216+
_ => Err(vortex_err!("can't extract {} from scalar: {}", any::type_name::<Self>(), value)),
217217
}
218218
}
219219
}

vortex-fastlanes/src/for/compress.rs

+41-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use itertools::Itertools;
2-
use num_traits::PrimInt;
2+
use num_traits::{PrimInt, WrappingAdd, WrappingSub};
33

44
use vortex::array::constant::ConstantArray;
55
use vortex::array::downcast::DowncastArrayBuiltin;
@@ -70,7 +70,7 @@ impl EncodingCompression for FoREncoding {
7070
}
7171
}
7272

73-
fn compress_primitive<T: NativePType + PrimInt>(
73+
fn compress_primitive<T: NativePType + WrappingSub + PrimInt>(
7474
parray: &PrimitiveArray,
7575
shift: u8,
7676
) -> PrimitiveArray {
@@ -86,13 +86,13 @@ fn compress_primitive<T: NativePType + PrimInt>(
8686
.typed_data::<T>()
8787
.iter()
8888
.map(|&v| v >> shift as usize)
89-
.map(|v| v - shifted_min)
89+
.map(|v| v.wrapping_sub(&shifted_min))
9090
.collect_vec()
9191
} else {
9292
parray
9393
.typed_data::<T>()
9494
.iter()
95-
.map(|&v| v - min)
95+
.map(|&v| v.wrapping_sub(&min))
9696
.collect_vec()
9797
};
9898

@@ -112,16 +112,23 @@ pub fn decompress(array: &FoRArray) -> VortexResult<PrimitiveArray> {
112112
}))
113113
}
114114

115-
fn decompress_primitive<T: NativePType + PrimInt>(values: &[T], reference: T, shift: u8) -> Vec<T> {
115+
fn decompress_primitive<T: NativePType + WrappingAdd + PrimInt>(
116+
values: &[T],
117+
reference: T,
118+
shift: u8,
119+
) -> Vec<T> {
116120
if shift > 0 {
117121
let shifted_reference = reference << shift as usize;
118122
values
119123
.iter()
120124
.map(|&v| v << shift as usize)
121-
.map(|v| v + shifted_reference)
125+
.map(|v| v.wrapping_add(&shifted_reference))
122126
.collect_vec()
123127
} else {
124-
values.iter().map(|&v| v + reference).collect_vec()
128+
values
129+
.iter()
130+
.map(|&v| v.wrapping_add(&reference))
131+
.collect_vec()
125132
}
126133
}
127134

@@ -144,6 +151,7 @@ mod test {
144151
use std::sync::Arc;
145152

146153
use vortex::array::{Encoding, EncodingRef};
154+
use vortex::compute::scalar_at::ScalarAtFn;
147155

148156
use crate::BitPackedEncoding;
149157

@@ -183,4 +191,30 @@ mod test {
183191
let decompressed = flatten_primitive(compressed.as_ref()).unwrap();
184192
assert_eq!(decompressed.typed_data::<u32>(), array.typed_data::<u32>());
185193
}
194+
195+
#[test]
196+
fn test_overflow() {
197+
let ctx = compress_ctx();
198+
199+
// Create a range offset by a million
200+
let array = PrimitiveArray::from((i8::MIN..i8::MAX).collect_vec());
201+
let compressed = FoREncoding {}.compress(&array, None, ctx).unwrap();
202+
let compressed = compressed.as_for();
203+
assert_eq!(i8::MIN, compressed.reference().try_into().unwrap());
204+
205+
let encoded = flatten_primitive(compressed.encoded()).unwrap();
206+
let bitcast: &[u8] = unsafe { std::mem::transmute(encoded.typed_data::<i8>()) };
207+
let unsigned: Vec<u8> = (0..u8::MAX).collect_vec();
208+
assert_eq!(bitcast, unsigned.as_slice());
209+
210+
let decompressed = flatten_primitive(compressed).unwrap();
211+
assert_eq!(decompressed.typed_data::<i8>(), array.typed_data::<i8>());
212+
array
213+
.typed_data::<i8>()
214+
.iter()
215+
.enumerate()
216+
.for_each(|(i, v)| {
217+
assert_eq!(*v, compressed.scalar_at(i).unwrap().try_into().unwrap());
218+
});
219+
}
186220
}

vortex-fastlanes/src/for/compute.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ impl ScalarAtFn for FoRArray {
4848
(Scalar::Primitive(p), Scalar::Primitive(r)) => match p.value() {
4949
None => Ok(encoded_scalar),
5050
Some(pv) => match_each_integer_ptype!(pv.ptype(), |$P| {
51+
use num_traits::WrappingAdd;
5152
Ok(PrimitiveScalar::try_new::<$P>(
52-
Some((p.typed_value::<$P>().unwrap() << self.shift()) + r.typed_value::<$P>().unwrap()),
53+
Some((p.typed_value::<$P>().unwrap() << self.shift()).wrapping_add(r.typed_value::<$P>().unwrap())),
5354
p.dtype().nullability()
5455
).unwrap().into())
5556
}),

0 commit comments

Comments
 (0)