Skip to content

Commit

Permalink
fix and add test for gemv_batched! when the matrix is transposed
Browse files Browse the repository at this point in the history
  • Loading branch information
kose-y committed Feb 3, 2025
1 parent 69f3a76 commit 0dd2249
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
15 changes: 8 additions & 7 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,15 +424,16 @@ for (fname, fname_64, eltyin, eltyout) in (
if length(A) != length(x) || length(A) != length(y)
throw(DimensionMismatch("Lengths of inputs must be the same"))
end
m = size(A[1], 1)
n = size(A[1], 2)
for (i, (As,xs,ys)) in enumerate(zip(A,x,y))
m,n = size(As)
if size(As) != (m, n)
throw(DimensionMismatch("A[$i] has different dimension from A[1]. Dimensions between A's should be identical."))
end
if length(xs) != (trans == 'N' ? n : m) || length(ys) != (trans == 'N' ? m : n)
throw(DimensionMismatch("Input $i: A has dimension $(size(As)), x has dimension $(size(xs)), y has dimension $(size(ys))"))
end
end

m = size(A[1], trans == 'N' ? 1 : 2)
n = size(A[1], trans == 'N' ? 2 : 1)
lda = max(1,stride(A[1],2))
incx = stride(x[1],1)
incy = stride(y[1],1)
Expand Down Expand Up @@ -470,9 +471,9 @@ for (fname, fname_64, eltyin, eltyout) in (
if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2)
throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
end
m = size(A, trans == 'N' ? 1 : 2)
n = size(A, trans == 'N' ? 2 : 1)
if m != size(y, 1) || n != size(x, 1)
m = size(A, 1)
n = size(A, 2)
if size(y, 1) != (trans == 'N' ? m : n) || size(x, 1) != (trans == 'N' ? n : m)
throw(DimensionMismatch("A has dimension $(size(A)), x has dimension $(size(x)), y has dimension $(size(y))"))
end

Expand Down
25 changes: 23 additions & 2 deletions test/libraries/cublas/level2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ k = 13
@test testf(*, transpose(rand(elty, m, n)), rand(elty, m))
@test testf(*, rand(elty, m, n)', rand(elty, m))
x = rand(elty, m)
A = rand(elty, m, m + 1 )
y = rand(elty, m)
A = rand(elty, m, m + 1)
y = rand(elty, n)
dx = CuArray(x)
dA = CuArray(A)
dy = CuArray(y)
Expand All @@ -44,6 +44,10 @@ k = 13
dy = CUBLAS.gemv('N', dA, dx)
hy = collect(dy)
@test hy A * x
dy = CuArray(y)
dx = CUBLAS.gemv(elty <: Real ? 'T' : 'C', alpha, dA, dy)
hx = collect(dx)
@test hx alpha * A' * y
end

if CUBLAS.version() >= v"11.9"
Expand Down Expand Up @@ -72,6 +76,16 @@ k = 13
y[i] = alpha * A[i] * x[i] + beta * y[i]
@test y[i] hy
end
dy = CuArray{elty, 1}[]
for i=1:length(A)
push!(dy, CuArray(y[i]))
end
CUBLAS.gemv_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx)
for i=1:length(A)
hx = collect(dx[i])
x[i] = alpha * A[i]' * y[i] + beta * x[i]
@test x[i] hx
end
end
end

Expand All @@ -97,6 +111,13 @@ k = 13
y[:, i] = alpha * A[:, :, i] * x[:, i] + beta * y[:, i]
@test y[:, i] hy
end
dy = CuArray(y)
CUBLAS.gemv_strided_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx)
for i=1:size(A, 3)
hx = collect(dx[:, i])
x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i]
@test x[:, i] hx
end
end
end

Expand Down

0 comments on commit 0dd2249

Please sign in to comment.