Skip to content

Commit f3ce3ac

Browse files
authored
Add Take for REEArray (#162)
1 parent c56ba98 commit f3ce3ac

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

vortex-ree/src/compute.rs

+45-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
use vortex::array::Array;
2-
use vortex::compute::flatten::{flatten, FlattenFn, FlattenedArray};
1+
use vortex::array::primitive::PrimitiveArray;
2+
use vortex::array::{Array, ArrayRef};
3+
use vortex::compute::flatten::{flatten, flatten_primitive, FlattenFn, FlattenedArray};
34
use vortex::compute::scalar_at::{scalar_at, ScalarAtFn};
5+
use vortex::compute::take::{take, TakeFn};
46
use vortex::compute::ArrayCompute;
7+
use vortex::match_each_integer_ptype;
58
use vortex::scalar::Scalar;
69
use vortex::validity::ArrayValidity;
710
use vortex_error::{VortexError, VortexResult};
@@ -17,6 +20,10 @@ impl ArrayCompute for REEArray {
1720
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
1821
Some(self)
1922
}
23+
24+
fn take(&self) -> Option<&dyn TakeFn> {
25+
Some(self)
26+
}
2027
}
2128

2229
impl FlattenFn for REEArray {
@@ -45,3 +52,39 @@ impl ScalarAtFn for REEArray {
4552
scalar_at(self.values(), self.find_physical_index(index)?)
4653
}
4754
}
55+
56+
impl TakeFn for REEArray {
57+
fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
58+
let primitive_indices = flatten_primitive(indices)?;
59+
let physical_indices = match_each_integer_ptype!(primitive_indices.ptype(), |$P| {
60+
primitive_indices
61+
.typed_data::<$P>()
62+
.iter()
63+
.map(|idx| {
64+
self.find_physical_index(*idx as usize)
65+
.map(|loc| loc as u64)
66+
})
67+
.collect::<VortexResult<Vec<_>>>()?
68+
});
69+
take(self.values(), &PrimitiveArray::from(physical_indices))
70+
}
71+
}
72+
73+
#[cfg(test)]
74+
mod test {
75+
use vortex::array::downcast::DowncastArrayBuiltin;
76+
use vortex::array::primitive::PrimitiveArray;
77+
use vortex::compute::take::take;
78+
79+
use crate::REEArray;
80+
81+
#[test]
82+
fn ree_take() {
83+
let ree = REEArray::encode(&PrimitiveArray::from(vec![
84+
1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5,
85+
]))
86+
.unwrap();
87+
let taken = take(&ree, &PrimitiveArray::from(vec![9, 8, 1, 3])).unwrap();
88+
assert_eq!(taken.as_primitive().typed_data::<i32>(), &[5, 5, 1, 4]);
89+
}
90+
}

0 commit comments

Comments
 (0)