46
46
47
47
#include " integer_advanced_indexing.hpp"
48
48
49
- #define INDEXING_MODES 2
50
- #define WRAP_MODE 0
51
- #define CLIP_MODE 1
52
-
53
49
namespace dpctl
54
50
{
55
51
namespace tensor
@@ -62,11 +58,15 @@ namespace td_ns = dpctl::tensor::type_dispatch;
62
58
using dpctl::tensor::kernels::indexing::put_fn_ptr_t ;
63
59
using dpctl::tensor::kernels::indexing::take_fn_ptr_t ;
64
60
65
- static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types]
66
- [td_ns::num_types];
61
+ static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types]
62
+ [td_ns::num_types];
63
+
64
+ static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types]
65
+ [td_ns::num_types];
66
+
67
+ static put_fn_ptr_t put_wrap_dispatch_table[td_ns::num_types][td_ns::num_types];
67
68
68
- static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][td_ns::num_types]
69
- [td_ns::num_types];
69
+ static put_fn_ptr_t put_clip_dispatch_table[td_ns::num_types][td_ns::num_types];
70
70
71
71
namespace py = pybind11;
72
72
@@ -486,7 +486,8 @@ py_take(const dpctl::tensor::usm_ndarray &src,
486
486
std::end (pack_deps));
487
487
all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
488
488
489
- auto fn = take_dispatch_table[mode][src_type_id][ind_type_id];
489
+ auto fn = mode ? take_wrap_dispatch_table[src_type_id][ind_type_id]
490
+ : take_clip_dispatch_table[src_type_id][ind_type_id];
490
491
491
492
if (fn == nullptr ) {
492
493
sycl::event::wait (host_task_events);
@@ -755,7 +756,8 @@ py_put(const dpctl::tensor::usm_ndarray &dst,
755
756
std::end (pack_deps));
756
757
all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
757
758
758
- auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id];
759
+ auto fn = mode ? put_wrap_dispatch_table[src_type_id][ind_type_id]
760
+ : put_clip_dispatch_table[src_type_id][ind_type_id];
759
761
760
762
if (fn == nullptr ) {
761
763
sycl::event::wait (host_task_events);
@@ -790,20 +792,20 @@ void init_advanced_indexing_dispatch_tables(void)
790
792
using dpctl::tensor::kernels::indexing::TakeClipFactory;
791
793
DispatchTableBuilder<take_fn_ptr_t , TakeClipFactory, num_types>
792
794
dtb_takeclip;
793
- dtb_takeclip.populate_dispatch_table (take_dispatch_table[CLIP_MODE] );
795
+ dtb_takeclip.populate_dispatch_table (take_clip_dispatch_table );
794
796
795
797
using dpctl::tensor::kernels::indexing::TakeWrapFactory;
796
798
DispatchTableBuilder<take_fn_ptr_t , TakeWrapFactory, num_types>
797
799
dtb_takewrap;
798
- dtb_takewrap.populate_dispatch_table (take_dispatch_table[WRAP_MODE] );
800
+ dtb_takewrap.populate_dispatch_table (take_wrap_dispatch_table );
799
801
800
802
using dpctl::tensor::kernels::indexing::PutClipFactory;
801
803
DispatchTableBuilder<put_fn_ptr_t , PutClipFactory, num_types> dtb_putclip;
802
- dtb_putclip.populate_dispatch_table (put_dispatch_table[CLIP_MODE] );
804
+ dtb_putclip.populate_dispatch_table (put_clip_dispatch_table );
803
805
804
806
using dpctl::tensor::kernels::indexing::PutWrapFactory;
805
807
DispatchTableBuilder<put_fn_ptr_t , PutWrapFactory, num_types> dtb_putwrap;
806
- dtb_putwrap.populate_dispatch_table (put_dispatch_table[WRAP_MODE] );
808
+ dtb_putwrap.populate_dispatch_table (put_wrap_dispatch_table );
807
809
}
808
810
809
811
} // namespace py_internal
0 commit comments