Skip to content

Commit 951add8

Browse files
[SYCL] Add one more test for virtual functions (#14705)
Test plan can be found in #10540. This PR introduces a basic test case for a scenario where object of a polymorphic class is constructed in one kernel, but used in another kernel.
1 parent 74a7907 commit 951add8

File tree

2 files changed

+119
-7
lines changed

2 files changed

+119
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// UNSUPPORTED: cuda, hip, acc
2+
// FIXME: replace unsupported with an aspect check once we have it
3+
//
4+
// RUN: %{build} -o %t.out %helper-includes
5+
// RUN: %{run} %t.out
6+
7+
#include <sycl/detail/core.hpp>
8+
9+
#include "helpers.hpp"
10+
11+
#include <iostream>
12+
13+
namespace oneapi = sycl::ext::oneapi::experimental;
14+
15+
class BaseIncrement {
16+
public:
17+
BaseIncrement(int Mod, int /* unused */ = 42) : Mod(Mod) {}
18+
19+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(oneapi::indirectly_callable<>)
20+
virtual void increment(int *Data) { *Data += 1 + Mod; }
21+
22+
protected:
23+
int Mod = 0;
24+
};
25+
26+
class IncrementBy2 : public BaseIncrement {
27+
public:
28+
IncrementBy2(int Mod, int /* unused */) : BaseIncrement(Mod) {}
29+
30+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(oneapi::indirectly_callable<>)
31+
void increment(int *Data) override { *Data += 2 + Mod; }
32+
};
33+
34+
class IncrementBy4 : public BaseIncrement {
35+
public:
36+
IncrementBy4(int Mod, int ExtraMod)
37+
: BaseIncrement(Mod), ExtraMod(ExtraMod) {}
38+
39+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(oneapi::indirectly_callable<>)
40+
void increment(int *Data) override { *Data += 4 + Mod + ExtraMod; }
41+
42+
private:
43+
int ExtraMod = 0;
44+
};
45+
46+
class IncrementBy8 : public BaseIncrement {
47+
public:
48+
IncrementBy8(int Mod, int /* unused */) : BaseIncrement(Mod) {}
49+
50+
SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(oneapi::indirectly_callable<>)
51+
void increment(int *Data) override { *Data += 8 + Mod; }
52+
};
53+
54+
int main() try {
55+
using storage_t =
56+
obj_storage_t<BaseIncrement, IncrementBy2, IncrementBy4, IncrementBy8>;
57+
58+
storage_t HostStorage;
59+
sycl::buffer<storage_t> DeviceStorage(sycl::range{1});
60+
61+
auto asyncHandler = [](sycl::exception_list list) {
62+
for (auto &e : list)
63+
std::rethrow_exception(e);
64+
};
65+
66+
sycl::queue q(asyncHandler);
67+
68+
// TODO: cover uses case when objects are passed through USM
69+
constexpr oneapi::properties props{oneapi::calls_indirectly<>};
70+
for (unsigned TestCase = 0; TestCase < 4; ++TestCase) {
71+
int HostData = 42;
72+
int Data = HostData;
73+
sycl::buffer<int> DataStorage(&Data, sycl::range{1});
74+
75+
q.submit([&](sycl::handler &CGH) {
76+
sycl::accessor StorageAcc(DeviceStorage, CGH, sycl::write_only);
77+
CGH.single_task([=]() {
78+
StorageAcc[0].construct</* ret type = */ BaseIncrement>(TestCase, 19,
79+
23);
80+
});
81+
});
82+
83+
q.submit([&](sycl::handler &CGH) {
84+
sycl::accessor StorageAcc(DeviceStorage, CGH, sycl::read_write);
85+
sycl::accessor DataAcc(DataStorage, CGH, sycl::write_only);
86+
CGH.single_task(props, [=]() {
87+
auto *Ptr = StorageAcc[0].getAs<BaseIncrement>();
88+
Ptr->increment(
89+
DataAcc.get_multi_ptr<sycl::access::decorated::no>().get());
90+
});
91+
});
92+
93+
auto *Ptr =
94+
HostStorage.construct</* ret type = */ BaseIncrement>(TestCase, 19, 23);
95+
Ptr->increment(&HostData);
96+
97+
sycl::host_accessor HostAcc(DataStorage);
98+
assert(HostAcc[0] == HostData);
99+
}
100+
101+
return 0;
102+
} catch (sycl::exception &e) {
103+
std::cout << "Unexpected exception was thrown: " << e.what() << std::endl;
104+
return 1;
105+
}

sycl/test-e2e/VirtualFunctions/helpers.hpp

+14-7
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,36 @@ template <typename... T> struct obj_storage_t {
2525

2626
type storage;
2727

28-
template <typename RetT> RetT *construct(const unsigned int TypeIndex) {
28+
template <typename RetT, typename... Args>
29+
RetT *construct(const unsigned int TypeIndex, Args... args) {
2930
if (TypeIndex >= sizeof...(T)) {
3031
#ifndef __SYCL_DEVICE_ONLY__
3132
assert(false && "Type index is invalid");
3233
#endif
3334
return nullptr;
3435
}
3536

36-
return constructHelper<RetT, T...>(TypeIndex, 0);
37+
return constructHelper<RetT, T...>(TypeIndex, 0, args...);
38+
}
39+
40+
template <typename RetT> RetT *getAs() {
41+
return reinterpret_cast<RetT *>(&storage);
3742
}
3843

3944
private:
40-
template <typename RetT> RetT *constructHelper(const int, const int) {
45+
template <typename RetT, typename... Args>
46+
RetT *constructHelper(const int, const int, Args...) {
4147
// Won't be ever called, but required to compile
4248
return nullptr;
4349
}
4450

45-
template <typename RetT, typename Type, typename... Rest>
46-
RetT *constructHelper(const int TargetIndex, const int CurIndex) {
51+
template <typename RetT, typename Type, typename... Rest, typename... Args>
52+
RetT *constructHelper(const int TargetIndex, const int CurIndex,
53+
Args... args) {
4754
if (TargetIndex != CurIndex)
48-
return constructHelper<RetT, Rest...>(TargetIndex, CurIndex + 1);
55+
return constructHelper<RetT, Rest...>(TargetIndex, CurIndex + 1, args...);
4956

50-
RetT *Ptr = new (reinterpret_cast<Type *>(&storage)) Type;
57+
RetT *Ptr = new (reinterpret_cast<Type *>(&storage)) Type(args...);
5158
return Ptr;
5259
}
5360
};

0 commit comments

Comments
 (0)