1
- use vortex:: array:: Array ;
2
- use vortex:: compute:: flatten:: { flatten, FlattenFn , FlattenedArray } ;
1
+ use vortex:: array:: primitive:: PrimitiveArray ;
2
+ use vortex:: array:: { Array , ArrayRef } ;
3
+ use vortex:: compute:: flatten:: { flatten, flatten_primitive, FlattenFn , FlattenedArray } ;
3
4
use vortex:: compute:: scalar_at:: { scalar_at, ScalarAtFn } ;
5
+ use vortex:: compute:: take:: { take, TakeFn } ;
4
6
use vortex:: compute:: ArrayCompute ;
7
+ use vortex:: match_each_integer_ptype;
5
8
use vortex:: scalar:: Scalar ;
6
9
use vortex:: validity:: ArrayValidity ;
7
10
use vortex_error:: { VortexError , VortexResult } ;
@@ -17,6 +20,10 @@ impl ArrayCompute for REEArray {
17
20
fn scalar_at ( & self ) -> Option < & dyn ScalarAtFn > {
18
21
Some ( self )
19
22
}
23
+
24
+ fn take ( & self ) -> Option < & dyn TakeFn > {
25
+ Some ( self )
26
+ }
20
27
}
21
28
22
29
impl FlattenFn for REEArray {
@@ -45,3 +52,39 @@ impl ScalarAtFn for REEArray {
45
52
scalar_at ( self . values ( ) , self . find_physical_index ( index) ?)
46
53
}
47
54
}
55
+
56
+ impl TakeFn for REEArray {
57
+ fn take ( & self , indices : & dyn Array ) -> VortexResult < ArrayRef > {
58
+ let primitive_indices = flatten_primitive ( indices) ?;
59
+ let physical_indices = match_each_integer_ptype ! ( primitive_indices. ptype( ) , |$P | {
60
+ primitive_indices
61
+ . typed_data:: <$P >( )
62
+ . iter( )
63
+ . map( |idx| {
64
+ self . find_physical_index( * idx as usize )
65
+ . map( |loc| loc as u64 )
66
+ } )
67
+ . collect:: <VortexResult <Vec <_>>>( ) ?
68
+ } ) ;
69
+ take ( self . values ( ) , & PrimitiveArray :: from ( physical_indices) )
70
+ }
71
+ }
72
+
73
+ #[ cfg( test) ]
74
+ mod test {
75
+ use vortex:: array:: downcast:: DowncastArrayBuiltin ;
76
+ use vortex:: array:: primitive:: PrimitiveArray ;
77
+ use vortex:: compute:: take:: take;
78
+
79
+ use crate :: REEArray ;
80
+
81
+ #[ test]
82
+ fn ree_take ( ) {
83
+ let ree = REEArray :: encode ( & PrimitiveArray :: from ( vec ! [
84
+ 1 , 1 , 1 , 4 , 4 , 4 , 2 , 2 , 5 , 5 , 5 , 5 ,
85
+ ] ) )
86
+ . unwrap ( ) ;
87
+ let taken = take ( & ree, & PrimitiveArray :: from ( vec ! [ 9 , 8 , 1 , 3 ] ) ) . unwrap ( ) ;
88
+ assert_eq ! ( taken. as_primitive( ) . typed_data:: <i32 >( ) , & [ 5 , 5 , 1 , 4 ] ) ;
89
+ }
90
+ }
0 commit comments