diff --git a/src/host/linalg.jl b/src/host/linalg.jl index bc599968..1558dc5a 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -683,6 +683,7 @@ function generic_rmul!(X::AbstractArray, s::Number) end LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b) +LinearAlgebra.rmul!(A::Diagonal{T, <:AbstractGPUArray}, b::Number) where {T} = generic_rmul!(A.diag, b) function generic_lmul!(s::Number, X::AbstractArray) @kernel function lmul_kernel!(X, s) @@ -694,6 +695,7 @@ function generic_lmul!(s::Number, X::AbstractArray) end LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B) +LinearAlgebra.lmul!(a::Number, B::Diagonal{T, <:AbstractGPUArray}) where {T} = generic_lmul!(a, B.diag) ## permutedims diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index fa701084..31637977 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -437,6 +437,21 @@ end A_empty = randn(Float32, 0, 0) @test compare(f, AT, A_empty, d) end + + @testset "rmul!/lmul! with diagonal and number" begin + n = 32 + h_d = rand(Float32, n) + h_D = Diagonal(h_d) + d = AT(h_d) + D = Diagonal(d) + a = rand(Float32) + rmul!(D, a) + rmul!(h_D, a) + @test collect(D) ≈ h_D + lmul!(a, D) + lmul!(a, h_D) + @test collect(D) ≈ h_D + end end @testsuite "linalg/mul!/vector-matrix" (AT, eltypes)->begin