Skip to content

Commit a154553

Browse files
authored
[SYCL][Matrix spec] keep deletion of assign op and copy ctor but change signature of joint_matrix_mad (#11007)
@gmlueck, in #7964 we explicitly deleted the copy ctor and assign op because we added `joint_matrix_copy`, which is what should be used to actually copy matrix data. However, deleting them in the implementation causes failures because `joint_matrix_mad` uses them. This PR proposes changing the signature of `joint_matrix_mad` so that the copy ctor and assign op are no longer used.
1 parent fb368a1 commit a154553

File tree

2 files changed

+29
-23
lines changed

2 files changed

+29
-23
lines changed

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -304,15 +304,20 @@ q.submit([&](sycl::handler& cgh) {
304304
joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
305305
joint_matrix_fill(sg, tC, 0);
306306
for (int k = 0; k < K; k += tK) {
307-
joint_matrix_load(sg, tA, accA + sg_startx * tM * K + k, K);
308-
joint_matrix_load(sg, tB, accB + k * N*4 + sg_starty/SG_SIZE*tN*4, N*4);
309-
tC = joint_matrix_mad(sg, tA, tB, tC);
307+
joint_matrix_load(sg, tA,
308+
accA.template get_multi_ptr<sycl::access::decorated::no>() +
309+
sg_startx * tM * K + k, K);
310+
joint_matrix_load(sg, tB,
311+
accB.template get_multi_ptr<sycl::access::decorated::no>() +
312+
k * N*4 + sg_starty/SG_SIZE*tN*4, N*4);
313+
joint_matrix_mad(sg, tC, tA, tB, tC);
310314
}
311-
auto wi_data_c = ext::intel::experimental::matrix::get_wi_data(sg, tC);
312-
for (int i = 0; i < wi_data_c.length(); i++)
313-
wi_data_c[i] *= alpha;
315+
joint_matrix_apply(sg, tC, [=](int8_t x) {
316+
x *= alpha;
317+
});
314318
joint_matrix_store(sg, tC,
315-
accC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
319+
accC.template get_multi_ptr<sycl::access::decorated::no>()
320+
+ sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
316321
});
317322
});
318323
q.wait();

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -274,11 +274,11 @@ rows for the row major layout, or between columns for the column major layout.
274274
```c++
275275
namespace sycl::ext::oneapi::experimental::matrix {
276276

277-
template <typename Group, typename Ta, typename Tb, typename Tc,
278-
std::size_t M, std::size_t K, std::size_t N, layout LayoutA, layout
279-
LayoutB, typename Td = Tc>
280-
joint_matrix<Group, Td, use::accumulator, M, N, layout::dynamic>
281-
joint_matrix_mad(Group g,
277+
template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
278+
std::size_t M, std::size_t K, std::size_t N,
279+
layout LayoutA, layout LayoutB>
280+
void joint_matrix_mad(Group g,
281+
joint_matrix<Group, Td, use::accumulator, M, N, layout::dynamic> &D,
282282
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
283283
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
284284
const joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> &C);
@@ -287,7 +287,7 @@ joint_matrix_mad(Group g,
287287
```
288288
The matrix multiply and add function performs the multiply operation
289289
on the matrices `A` and `B`, accumulates the result with `C` and returns
290-
the result.
290+
the result in the matrix `D`.
291291

292292
Each device supports only certain combinations of types for the `A`,
293293
`B`, and `C` matrices. The application must use the query operations
@@ -505,6 +505,12 @@ range<2> L = {1, SG_SIZE};
505505
int8_t *memA = malloc_shared<int8_t>(M*K, q);
506506
int8_t *memB = malloc_shared<int8_t>(K*N, q);
507507
int32_t *memC = malloc_shared<int32_t>(M*N, q);
508+
auto pA = address_space_cast<sycl::access::address_space::global_space,
509+
sycl::access::decorated::no>(memA);
510+
auto pB = address_space_cast<sycl::access::address_space::global_space,
511+
sycl::access::decorated::no>(memB);
512+
auto pC = address_space_cast<sycl::access::address_space::global_space,
513+
sycl::access::decorated::no>(memC);
508514
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
509515
[[sycl::reqd_sub_group_size(SG_SIZE)]] {
510516
const auto global_idx = item.get_global_id(0);
@@ -517,20 +523,15 @@ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
517523
joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
518524
joint_matrix_fill(sg, tC, 0);
519525
for (int k = 0; k < K; k += tK) {
520-
joint_matrix_load(sg, tA,
521-
multi_ptr<int8_t, sycl::access::address_space::global_space>(memA) +
522-
sg_startx * tM * K + k, K);
523-
joint_matrix_load(sg, tB,
524-
multi_ptr<int8_t, sycl::access::address_space::global_space>(memB) +
525-
k * N + sg_starty/SG_SIZE*tN, N);
526-
tC = joint_matrix_mad(sg, tA, tB, tC);
526+
joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K);
527+
joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N);
528+
joint_matrix_mad(sg, tC, tA, tB, tC);
527529
}
528530
joint_matrix_apply(sg, tC, [=](int8_t x) {
529531
x *= alpha;
530532
});
531-
joint_matrix_store(sg, tC,
532-
multi_ptr<int32_t, sycl::access::address_space::global_space>(memC) +
533-
sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
533+
joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN,
534+
N, layout::row_major);
534535
}).wait();
535536
```
536537

0 commit comments

Comments
 (0)