diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 150a01f66e..c17f40efd7 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -462,9 +462,9 @@ for (fname, fname_64, eltyin, eltyout) in ( if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2) throw(DimensionMismatch("Batch sizes must be equal for all inputs")) end - m = size(A, trans == 'N' ? 1 : 2) - n = size(A, trans == 'N' ? 2 : 1) - if m != size(y, 1) || n != size(x, 1) + m = size(A, 1) + n = size(A, 2) + if size(y, 1) != (trans == 'N' ? m : n) || size(x, 1) != (trans == 'N' ? n : m) throw(DimensionMismatch("A has dimension $(size(A)), x has dimension $(size(x)), y has dimension $(size(y))")) end