@@ -274,11 +274,11 @@ rows for the row major layout, or between columns for the column major layout.
274
274
```c++
275
275
namespace sycl::ext::oneapi::experimental::matrix {
276
276
277
- template <typename Group, typename Ta, typename Tb, typename Tc,
278
- std::size_t M, std::size_t K, std::size_t N, layout LayoutA, layout
279
- LayoutB, typename Td = Tc >
280
- joint_matrix< Group, Td, use::accumulator, M, N, layout::dynamic>
281
- joint_matrix_mad( Group g ,
277
+ template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
278
+ std::size_t M, std::size_t K, std::size_t N,
279
+ layout LayoutA, layout LayoutB >
280
+ void joint_matrix_mad( Group g,
281
+ joint_matrix< Group, Td, use::accumulator, M, N, layout::dynamic> &D ,
282
282
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
283
283
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
284
284
const joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> &C);
@@ -287,7 +287,7 @@ joint_matrix_mad(Group g,
287
287
```
288
288
The matrix multiply and add function performs the multiply operation
289
289
on the matrices `A` and `B`, accumulates the result with `C` and returns
290
- the result.
290
+ the result in the matrix `D`.
291
291
292
292
Each device supports only certain combinations of types for the `A`,
293
293
`B`, and `C` matrices. The application must use the query operations
@@ -505,6 +505,12 @@ range<2> L = {1, SG_SIZE};
505
505
int8_t *memA = malloc_shared<int8_t>(M*K, q);
506
506
int8_t *memB = malloc_shared<int8_t>(K*N, q);
507
507
int32_t *memC = malloc_shared<int32_t>(M*N, q);
508
+ auto pA = address_space_cast<sycl::access::address_space::global_space,
509
+ sycl::access::decorated::no>(memA);
510
+ auto pB = address_space_cast<sycl::access::address_space::global_space,
511
+ sycl::access::decorated::no>(memB);
512
+ auto pC = address_space_cast<sycl::access::address_space::global_space,
513
+ sycl::access::decorated::no>(memC);
508
514
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
509
515
[[sycl::reqd_sub_group_size(SG_SIZE)]] {
510
516
const auto global_idx = item.get_global_id(0);
@@ -517,20 +523,15 @@ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
517
523
joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
518
524
joint_matrix_fill(sg, tC, 0);
519
525
for (int k = 0; k < K; k += tK) {
520
- joint_matrix_load(sg, tA,
521
- multi_ptr<int8_t, sycl::access::address_space::global_space>(memA) +
522
- sg_startx * tM * K + k, K);
523
- joint_matrix_load(sg, tB,
524
- multi_ptr<int8_t, sycl::access::address_space::global_space>(memB) +
525
- k * N + sg_starty/SG_SIZE*tN, N);
526
- tC = joint_matrix_mad(sg, tA, tB, tC);
526
+ joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K);
527
+ joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N);
528
+ joint_matrix_mad(sg, tC, tA, tB, tC);
527
529
}
528
530
joint_matrix_apply(sg, tC, [=](int8_t x) {
529
531
x *= alpha;
530
532
});
531
- joint_matrix_store(sg, tC,
532
- multi_ptr<int32_t, sycl::access::address_space::global_space>(memC) +
533
- sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
533
+ joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN,
534
+ N, layout::row_major);
534
535
}).wait();
535
536
```
536
537
0 commit comments