diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 150a01f66e..c09a456000 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -416,15 +416,16 @@ for (fname, fname_64, eltyin, eltyout) in ( if length(A) != length(x) || length(A) != length(y) throw(DimensionMismatch("Lengths of inputs must be the same")) end + m = size(A[1], 1) + n = size(A[1], 2) for (i, (As,xs,ys)) in enumerate(zip(A,x,y)) - m,n = size(As) + if size(As) != (m, n) + throw(DimensionMismatch("A[$i] has different dimension from A[1]. Dimensions between A's should be identical.")) + end if length(xs) != (trans == 'N' ? n : m) || length(ys) != (trans == 'N' ? m : n) throw(DimensionMismatch("Input $i: A has dimension $(size(As)), x has dimension $(size(xs)), y has dimension $(size(ys))")) end end - - m = size(A[1], trans == 'N' ? 1 : 2) - n = size(A[1], trans == 'N' ? 2 : 1) lda = max(1,stride(A[1],2)) incx = stride(x[1],1) incy = stride(y[1],1) @@ -462,9 +463,9 @@ for (fname, fname_64, eltyin, eltyout) in ( if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2) throw(DimensionMismatch("Batch sizes must be equal for all inputs")) end - m = size(A, trans == 'N' ? 1 : 2) - n = size(A, trans == 'N' ? 2 : 1) - if m != size(y, 1) || n != size(x, 1) + m = size(A, 1) + n = size(A, 2) + if size(y, 1) != (trans == 'N' ? m : n) || size(x, 1) != (trans == 'N' ? n : m) throw(DimensionMismatch("A has dimension $(size(A)), x has dimension $(size(x)), y has dimension $(size(y))")) end diff --git a/test/libraries/cublas.jl b/test/libraries/cublas.jl index cab9271dc7..1a8bfb0934 100644 --- a/test/libraries/cublas.jl +++ b/test/libraries/cublas.jl @@ -105,7 +105,7 @@ end @test testf(*, rand(elty, m, n)', rand(elty, m)) x = rand(elty, m) A = rand(elty, m, m + 1 ) - y = rand(elty, m) + y = rand(elty, n) dx = CuArray(x) dA = CuArray(A) dy = CuArray(y) @@ -124,6 +124,10 @@ end dy = CUBLAS.gemv('N', dA, dx) hy = collect(dy) @test hy ≈ A * x + dy = CuArray(y) + dx = CUBLAS.gemv(elty <: Real ? 
'T' : 'C', alpha, dA, dy) + hx = collect(dx) + @test hx ≈ alpha * A' * y end if CUBLAS.version() >= v"11.9" @@ -150,6 +154,16 @@ end y[i] = alpha * A[i] * x[i] + beta * y[i] @test y[i] ≈ hy end + dy = CuArray{elty, 1}[] + for i=1:length(A) + push!(dy, CuArray(y[i])) + end + CUBLAS.gemv_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx) + for i=1:length(A) + hx = collect(dx[i]) + x[i] = alpha * A[i]' * y[i] + beta * x[i] + @test x[i] ≈ hx + end end end @@ -173,6 +187,13 @@ end y[:, i] = alpha * A[:, :, i] * x[:, i] + beta * y[:, i] @test y[:, i] ≈ hy end + dy = CuArray(y) + CUBLAS.gemv_strided_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx) + for i=1:size(A, 3) + hx = collect(dx[:, i]) + x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i] + @test x[:, i] ≈ hx + end end end