Test batched matmul (#158)
tgymnich authored Nov 25, 2023
1 parent 0fe2294 commit 9f1d53e
Showing 2 changed files with 41 additions and 2 deletions.
5 changes: 3 additions & 2 deletions lib/mps/matrix.jl
@@ -173,6 +173,7 @@ function encode!(cmdbuf::MTLCommandBuffer, matmul::MPSMatrixMultiplication, left
         resultMatrix:result::id{MPSMatrix}]::Nothing
 end
 
+
 """
     matmul!(c::MtlArray, a::MtlArray, b::MtlArray, alpha=1, beta=1,
             transpose_a=false, transpose_b=false)
@@ -182,9 +183,9 @@ A `MPSMatrixMultiplication` kernel that computes:
 This function should not typically be used. Rather, use the normal `LinearAlgebra` interface
 with any `MtlArray` and it should be accelerated using Metal Performance Shaders.
 """
-function matmul!(c::MtlMatrix, a::MtlMatrix, b::MtlMatrix,
+function matmul!(c::MtlArray{T1,N}, a::MtlArray{T2,N}, b::MtlArray{T3,N},
                  alpha::Number=true, beta::Number=true,
-                 transpose_a=false, transpose_b=false)
+                 transpose_a=false, transpose_b=false) where {T1, T2, T3, N}
     # NOTE: MPS uses row major, while Julia is col-major. Instead of transposing
     #       the inputs (by passing !transpose_[ab]) and afterwards transposing
     #       the output, we use the property that (AB)ᵀ = BᵀAᵀ
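
The NOTE above relies on the identity (AB)ᵀ = BᵀAᵀ: a buffer that is column-major to Julia is, when read row-major by MPS, the transpose of the same matrix. A minimal plain-Julia sketch of why this avoids any explicit transposes (illustrative only, not part of the diff):

using LinearAlgebra

A = rand(Float32, 4, 3)
B = rand(Float32, 3, 5)

# Reinterpreting a col-major buffer as row-major yields the transposed matrix
# for free, so a row-major kernel handed Julia's buffers sees Aᵀ and Bᵀ.
# Asking it for Bᵀ * Aᵀ produces (A*B)ᵀ, and that row-major result, read back
# col-major, is exactly A * B.
@assert transpose(transpose(B) * transpose(A)) ≈ A * B
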
38 changes: 38 additions & 0 deletions test/mps.jl
@@ -35,6 +35,44 @@ if MPS.is_supported(current_device())
end
end

@testset "batched matrix matrix multiplication" begin
N = 10
batch_size = 3

rows_a = N
cols_a = N

rows_b = N
cols_b = N

rows_c = rows_a
cols_c = cols_b

alpha = Float64(1)
beta = Float64(1)

for (input_jl_type, accum_jl_type) in MPS.MPS_VALID_MATMUL_TYPES
@testset "$(input_jl_type) => $accum_jl_type" begin
arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size))
arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size))
arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))

buf_a = MtlArray{input_jl_type}(arr_a)
buf_b = MtlArray{input_jl_type}(arr_b)
buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))

truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
for i in 1:batch_size
@views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
end

MPS.matmul!(buf_c, buf_a, buf_b, alpha, beta)

@test all(Array(buf_c) .≈ truth_c)
end
end
end
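
For context, here is a hedged usage sketch of the batched path this test exercises (assumptions: Metal.jl is installed and the internal MPS submodule is reachable as `Metal.MPS`; per the docstring above, the supported route is the normal `LinearAlgebra` interface, so this is illustrative only):

using Metal

# Three 10×10 Float32 slices per operand; matmul! multiplies slice-by-slice.
a = MtlArray(rand(Float32, 10, 10, 3))
b = MtlArray(rand(Float32, 10, 10, 3))
c = MtlArray(zeros(Float32, 10, 10, 3))

# Computes c[:, :, i] = 1.0 * a[:, :, i] * b[:, :, i] + 0.0 * c[:, :, i] on the GPU.
Metal.MPS.matmul!(c, a, b, 1.0, 0.0)
Array(c)  # copy the result back to host memory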

@testset "test matrix vector multiplication of views" begin
N = 20
a = rand(Float32, N,N)
