Skip to content

Commit 4802fbd

Browse files
committed
Change integer indexing mode dispatching
1 parent e40112d commit 4802fbd

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

+16-14
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@
4646

4747
#include "integer_advanced_indexing.hpp"
4848

49-
#define INDEXING_MODES 2
50-
#define WRAP_MODE 0
51-
#define CLIP_MODE 1
52-
5349
namespace dpctl
5450
{
5551
namespace tensor
@@ -62,11 +58,15 @@ namespace td_ns = dpctl::tensor::type_dispatch;
6258
using dpctl::tensor::kernels::indexing::put_fn_ptr_t;
6359
using dpctl::tensor::kernels::indexing::take_fn_ptr_t;
6460

65-
static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types]
66-
[td_ns::num_types];
61+
static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types]
62+
[td_ns::num_types];
63+
64+
static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types]
65+
[td_ns::num_types];
66+
67+
static put_fn_ptr_t put_wrap_dispatch_table[td_ns::num_types][td_ns::num_types];
6768

68-
static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][td_ns::num_types]
69-
[td_ns::num_types];
69+
static put_fn_ptr_t put_clip_dispatch_table[td_ns::num_types][td_ns::num_types];
7070

7171
namespace py = pybind11;
7272

@@ -486,7 +486,8 @@ py_take(const dpctl::tensor::usm_ndarray &src,
486486
std::end(pack_deps));
487487
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));
488488

489-
auto fn = take_dispatch_table[mode][src_type_id][ind_type_id];
489+
auto fn = mode ? take_wrap_dispatch_table[src_type_id][ind_type_id]
490+
: take_clip_dispatch_table[src_type_id][ind_type_id];
490491

491492
if (fn == nullptr) {
492493
sycl::event::wait(host_task_events);
@@ -755,7 +756,8 @@ py_put(const dpctl::tensor::usm_ndarray &dst,
755756
std::end(pack_deps));
756757
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));
757758

758-
auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id];
759+
auto fn = mode ? put_wrap_dispatch_table[src_type_id][ind_type_id]
760+
: put_clip_dispatch_table[src_type_id][ind_type_id];
759761

760762
if (fn == nullptr) {
761763
sycl::event::wait(host_task_events);
@@ -790,20 +792,20 @@ void init_advanced_indexing_dispatch_tables(void)
790792
using dpctl::tensor::kernels::indexing::TakeClipFactory;
791793
DispatchTableBuilder<take_fn_ptr_t, TakeClipFactory, num_types>
792794
dtb_takeclip;
793-
dtb_takeclip.populate_dispatch_table(take_dispatch_table[CLIP_MODE]);
795+
dtb_takeclip.populate_dispatch_table(take_clip_dispatch_table);
794796

795797
using dpctl::tensor::kernels::indexing::TakeWrapFactory;
796798
DispatchTableBuilder<take_fn_ptr_t, TakeWrapFactory, num_types>
797799
dtb_takewrap;
798-
dtb_takewrap.populate_dispatch_table(take_dispatch_table[WRAP_MODE]);
800+
dtb_takewrap.populate_dispatch_table(take_wrap_dispatch_table);
799801

800802
using dpctl::tensor::kernels::indexing::PutClipFactory;
801803
DispatchTableBuilder<put_fn_ptr_t, PutClipFactory, num_types> dtb_putclip;
802-
dtb_putclip.populate_dispatch_table(put_dispatch_table[CLIP_MODE]);
804+
dtb_putclip.populate_dispatch_table(put_clip_dispatch_table);
803805

804806
using dpctl::tensor::kernels::indexing::PutWrapFactory;
805807
DispatchTableBuilder<put_fn_ptr_t, PutWrapFactory, num_types> dtb_putwrap;
806-
dtb_putwrap.populate_dispatch_table(put_dispatch_table[WRAP_MODE]);
808+
dtb_putwrap.populate_dispatch_table(put_wrap_dispatch_table);
807809
}
808810

809811
} // namespace py_internal

0 commit comments

Comments
 (0)