Skip to content

Commit d714a56

Browse files
authored
oneMKL: Add support for iamax and iamin. (JuliaGPU#235)
1 parent e6b8c01 commit d714a56

File tree

5 files changed

+142
-2
lines changed

5 files changed

+142
-2
lines changed

deps/src/onemkl.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,52 @@ extern "C" void onemklCcopy(syclQueue_t device_queue, int64_t n, const float _Co
105105
reinterpret_cast<std::complex<float> *>(y), incy);
106106
}
107107

108+
extern "C" void onemklDamax(syclQueue_t device_queue, int64_t n, const double *x,
109+
int64_t incx, int64_t *result){
110+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
111+
status.wait();
112+
}
113+
extern "C" void onemklSamax(syclQueue_t device_queue, int64_t n, const float *x,
114+
int64_t incx, int64_t *result){
115+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n, x, incx, result);
116+
status.wait();
117+
}
118+
extern "C" void onemklZamax(syclQueue_t device_queue, int64_t n, const double _Complex *x,
119+
int64_t incx, int64_t *result){
120+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n,
121+
reinterpret_cast<const std::complex<double> *>(x), incx, result);
122+
status.wait();
123+
}
124+
extern "C" void onemklCamax(syclQueue_t device_queue, int64_t n, const float _Complex *x,
125+
int64_t incx, int64_t *result){
126+
auto status = oneapi::mkl::blas::column_major::iamax(device_queue->val, n,
127+
reinterpret_cast<const std::complex<float> *>(x), incx, result);
128+
status.wait();
129+
}
130+
131+
extern "C" void onemklDamin(syclQueue_t device_queue, int64_t n, const double *x,
132+
int64_t incx, int64_t *result){
133+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
134+
status.wait();
135+
}
136+
extern "C" void onemklSamin(syclQueue_t device_queue, int64_t n, const float *x,
137+
int64_t incx, int64_t *result){
138+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n, x, incx, result);
139+
status.wait();
140+
}
141+
extern "C" void onemklZamin(syclQueue_t device_queue, int64_t n, const double _Complex *x,
142+
int64_t incx, int64_t *result){
143+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n,
144+
reinterpret_cast<const std::complex<double> *>(x), incx, result);
145+
status.wait();
146+
}
147+
extern "C" void onemklCamin(syclQueue_t device_queue, int64_t n, const float _Complex *x,
148+
int64_t incx, int64_t *result){
149+
auto status = oneapi::mkl::blas::column_major::iamin(device_queue->val, n,
150+
reinterpret_cast<const std::complex<float> *>(x), incx, result);
151+
status.wait();
152+
}
153+
108154
// other
109155

110156
// oneMKL keeps a cache of SYCL queues and tries to destroy them when unloading the library.

deps/src/onemkl.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,24 @@ void onemklZcopy(syclQueue_t device_queue, int64_t n, const double _Complex *x,
4848
void onemklCcopy(syclQueue_t device_queue, int64_t n, const float _Complex *x,
4949
int64_t incx, float _Complex *y, int64_t incy);
5050

51+
void onemklDamax(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx,
52+
int64_t *result);
53+
void onemklSamax(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx,
54+
int64_t *result);
55+
void onemklZamax(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx,
56+
int64_t *result);
57+
void onemklCamax(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx,
58+
int64_t *result);
59+
60+
void onemklDamin(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx,
61+
int64_t *result);
62+
void onemklSamin(syclQueue_t device_queue, int64_t n, const float *x, int64_t incx,
63+
int64_t *result);
64+
void onemklZamin(syclQueue_t device_queue, int64_t n, const double _Complex *x, int64_t incx,
65+
int64_t *result);
66+
void onemklCamin(syclQueue_t device_queue, int64_t n, const float _Complex *x, int64_t incx,
67+
int64_t *result);
68+
5169
void onemklDestroy();
5270
#ifdef __cplusplus
5371
}

lib/mkl/libonemkl.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,42 @@ function onemklCcopy(device_queue, n, x, incx, y, incy)
6666
y::ZePtr{ComplexF32}, incy::Int64)::Cvoid
6767
end
6868

69+
function onemklSamax(device_queue, n, x, incx, result)
70+
@ccall liboneapi_support.onemklSamax(device_queue::syclQueue_t, n::Int64,
71+
x::ZePtr{Cfloat}, incx::Int64, result::ZePtr{Int64})::Cvoid
72+
end
73+
74+
function onemklDamax(device_queue, n, x, incx, result)
75+
@ccall liboneapi_support.onemklDamax(device_queue::syclQueue_t, n::Int64,
76+
x::ZePtr{Cdouble}, incx::Int64, result::ZePtr{Int64})::Cvoid
77+
end
78+
79+
function onemklCamax(device_queue, n, x, incx, result)
80+
@ccall liboneapi_support.onemklCamax(device_queue::syclQueue_t, n::Int64,
81+
x::ZePtr{ComplexF32}, incx::Int64,result::ZePtr{Int64})::Cvoid
82+
end
83+
84+
function onemklZamax(device_queue, n, x, incx, result)
85+
@ccall liboneapi_support.onemklZamax(device_queue::syclQueue_t, n::Int64,
86+
x::ZePtr{ComplexF64}, incx::Int64, result::ZePtr{Int64})::Cvoid
87+
end
88+
89+
function onemklSamin(device_queue, n, x, incx, result)
90+
@ccall liboneapi_support.onemklSamin(device_queue::syclQueue_t, n::Int64,
91+
x::ZePtr{Cfloat}, incx::Int64, result::ZePtr{Int64})::Cvoid
92+
end
93+
94+
function onemklDamin(device_queue, n, x, incx, result)
95+
@ccall liboneapi_support.onemklDamin(device_queue::syclQueue_t, n::Int64,
96+
x::ZePtr{Cdouble}, incx::Int64, result::ZePtr{Int64})::Cvoid
97+
end
98+
99+
function onemklCamin(device_queue, n, x, incx, result)
100+
@ccall liboneapi_support.onemklCamin(device_queue::syclQueue_t, n::Int64,
101+
x::ZePtr{ComplexF32}, incx::Int64,result::ZePtr{Int64})::Cvoid
102+
end
103+
104+
function onemklZamin(device_queue, n, x, incx, result)
105+
@ccall liboneapi_support.onemklZamin(device_queue::syclQueue_t, n::Int64,
106+
x::ZePtr{ComplexF64}, incx::Int64, result::ZePtr{Int64})::Cvoid
107+
end

lib/mkl/wrappers.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,39 @@ for (fname, elty) in
3838
end
3939
end
4040

41+
## iamax
42+
for (fname, elty) in
43+
((:onemklDamax,:Float64),
44+
(:onemklSamax,:Float32),
45+
(:onemklZamax,:ComplexF64),
46+
(:onemklCamax,:ComplexF32))
47+
@eval begin
48+
function iamax(x::oneStridedArray{$elty})
49+
n = length(x)
50+
queue = global_queue(context(x), device(x))
51+
result = oneArray{Int64}([0]);
52+
$fname(sycl_queue(queue), n, x, stride(x, 1), result)
53+
return Array(result)[1]+1
54+
end
55+
end
56+
end
57+
58+
## iamin
59+
for (fname, elty) in
60+
((:onemklDamin,:Float64),
61+
(:onemklSamin,:Float32),
62+
(:onemklZamin,:ComplexF64),
63+
(:onemklCamin,:ComplexF32))
64+
@eval begin
65+
function iamin(x::StridedArray{$elty})
66+
n = length(x)
67+
result = oneArray{Int64}([0]);
68+
queue = global_queue(context(x), device(x))
69+
$fname(sycl_queue(queue),n, x, stride(x, 1), result)
70+
return Array(result)[1]+1
71+
end
72+
end
73+
end
4174

4275
# level 3
4376

test/onemkl.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ using oneAPI.oneMKL
44
using LinearAlgebra
55

66
m = 20
7-
n = 35
8-
k = 13
97

108
############################################################################################
119
@testset "level 1" begin
@@ -14,5 +12,11 @@ k = 13
1412
B = oneArray{T}(undef, m)
1513
oneMKL.copy!(m,A,B)
1614
@test Array(A) == Array(B)
15+
16+
# testing oneMKL max and min
17+
a = convert.(T, [1.0, 2.0, -0.8, 5.0, 3.0])
18+
ca = oneArray(a)
19+
@test BLAS.iamax(a) == oneMKL.iamax(ca)
20+
@test oneMKL.iamin(ca) == 3
1721
end # level 1 testset
1822
end

0 commit comments

Comments
 (0)