Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions deps/src/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
return 0;
}

extern "C" void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha, const float *x, std::int64_t incx, float *y, int64_t incy) {
oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x, incx, y, incy);
}

extern "C" void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha, const double *x, std::int64_t incx, double *y, int64_t incy) {
oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x, incx, y, incy);
}

extern "C" void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex alpha, const float _Complex *x, std::int64_t incx, float _Complex *y, int64_t incy) {
oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, reinterpret_cast<const std::complex<float> *>(x), incx, reinterpret_cast<std::complex<float> *>(y), incy);
}

extern "C" void onemklZaxpy(syclQueue_t device_queue, int64_t n, double _Complex alpha, const double _Complex *x, std::int64_t incx, double _Complex *y, int64_t incy) {
oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, reinterpret_cast<const std::complex<double> *>(x), incx, reinterpret_cast<std::complex<double> *>(y), incy);
}

extern "C" void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, double *y, int64_t incy) {
oneapi::mkl::blas::column_major::copy(device_queue->val, n, x, incx, y, incy);
Expand Down
5 changes: 5 additions & 0 deletions deps/src/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
const double _Complex *B, int64_t ldb, double _Complex beta,
double _Complex *C, int64_t ldc);

void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha, const float *x, int64_t incx, float *y, int64_t incy);
void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha, const double *x, int64_t incx, double *y, int64_t incy);
void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex alpha, const float _Complex *x, int64_t incx, float _Complex *y, int64_t incy);
void onemklZaxpy(syclQueue_t device_queue, int64_t n, double _Complex alpha, const double _Complex *x, int64_t incx, double _Complex *y, int64_t incy);

void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x,
int64_t incx, double *y, int64_t incy);
void onemklScopy(syclQueue_t device_queue, int64_t n, const float *x,
Expand Down
17 changes: 16 additions & 1 deletion lib/mkl/libonemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
C::ZePtr{ComplexF64}, ldc::Int64)::Cint
end

function onemklSaxpy(device_queue, n, alpha, x, incx, y, incy)
@ccall liboneapi_support.onemklSaxpy(device_queue::syclQueue_t, n::Int64, alpha::Cfloat, x::ZePtr{Cfloat}, incx::Int64, y::ZePtr{Cfloat}, incy::Int64)::Cvoid
end

function onemklDaxpy(device_queue, n, alpha, x, incx, y, incy)
@ccall liboneapi_support.onemklDaxpy(device_queue::syclQueue_t, n::Int64, alpha::Cdouble, x::ZePtr{Cdouble}, incx::Int64, y::ZePtr{Cdouble}, incy::Int64)::Cvoid
end

function onemklCaxpy(device_queue, n, alpha, x, incx, y, incy)
@ccall liboneapi_support.onemklCaxpy(device_queue::syclQueue_t, n::Int64, alpha::ComplexF32, x::ZePtr{ComplexF32}, incx::Int64, y::ZePtr{ComplexF32}, incy::Int64)::Cvoid
end

function onemklZaxpy(device_queue, n, alpha, x, incx, y, incy)
@ccall liboneapi_support.onemklZaxpy(device_queue::syclQueue_t, n::Int64, alpha::ComplexF64, x::ZePtr{ComplexF64}, incx::Int64, y::ZePtr{ComplexF64}, incy::Int64)::Cvoid
end

function onemklDcopy(device_queue, n, x, incx, y, incy)
@ccall liboneapi_support.onemklDcopy(device_queue::syclQueue_t, n::Int64,
x::ZePtr{Cdouble}, incx::Int64,
Expand All @@ -65,4 +81,3 @@ function onemklCcopy(device_queue, n, x, incx, y, incy)
x::ZePtr{ComplexF32}, incx::Int64,
y::ZePtr{ComplexF32}, incy::Int64)::Cvoid
end

5 changes: 5 additions & 0 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ function gemm_dispatch!(C::oneStridedVecOrMat, A, B, alpha::Number=true, beta::N
end
end

function LinearAlgebra.axpy!(alpha::Number, x::oneStridedVecOrMat{<:onemklFloat}, y::oneStridedVecOrMat{<:onemklFloat}) where T<:Union{onemklFloat}
length(x)==length(y) || throw(DimensionMismatch("axpy arguments have lengths $(length(x)) and $(length(y))"))
oneMKL.axpy!(length(x), alpha, x, y)
end

for NT in (Number, Real)
# NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
@eval begin
Expand Down
22 changes: 20 additions & 2 deletions lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,26 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
end
end



# level 1
## axpy primitive
for (fname, elty) in
((:onemklDaxpy,:Float64),
(:onemklSaxpy,:Float32),
(:onemklZaxpy,:ComplexF64),
(:onemklCaxpy,:ComplexF32))
@eval begin
function axpy!(n::Integer,
alpha::Number,
x::oneStridedArray{$elty},
y::oneStridedArray{$elty}
)
queue = global_queue(context(x), device(x))
alpha = $elty(alpha)
$fname(sycl_queue(queue), n, alpha, x, stride(x,1), y, stride(y,1))
y
end
end
end
#
# BLAS
#
Expand Down
16 changes: 12 additions & 4 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ k = 13
############################################################################################
@testset "level 1" begin
@testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64])
A = oneArray(rand(T, m))
B = oneArray{T}(undef, m)
oneMKL.copy!(m,A,B)
@test Array(A) == Array(B)
@testset "copy" begin
A = oneArray(rand(T, m))
B = oneArray{T}(undef, m)
oneMKL.copy!(m,A,B)
@test Array(A) == Array(B)
end

@testset "axpy" begin
# Test axpy primitive
alpha = rand(T,1)
@test testf(axpy!, alpha[1], rand(T,m), rand(T,m))
end
end # level 1 testset
end