Skip to content

Commit

Permalink
add MtlMatrixBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Sep 6, 2023
1 parent 68d1c9f commit 151ad82
Showing 1 changed file with 4 additions and 31 deletions.
35 changes: 4 additions & 31 deletions lib/mps/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ function encode!(cmdbuf::MTLCommandBuffer, matmul::MPSMatrixMultiplication, left
resultMatrix:result::id{MPSMatrix}]::Nothing
end

# Alias for a three-dimensional `MtlArray` with element type `T` and third type
# parameter `S`; `matmul!` accepts it alongside `MtlMatrix`.
# NOTE(review): presumably the third dimension indexes the matrices of a batch — confirm
# against the `MPSMatrix` constructor used for 3-D arrays.
const MtlMatrixBatch{T,S} = MtlArray{T,3,S}

"""
    matmul!(c::MtlMatrix, a::MtlMatrix, b::MtlMatrix, alpha=1, beta=1,
            transpose_a=false, transpose_b=false)
Expand All @@ -182,9 +184,9 @@ A `MPSMatrixMultiplication` kernel that computes:
This function should not typically be used. Rather, use the normal `LinearAlgebra` interface
with any `MtlArray` and it should be accelerated using Metal Performance Shaders.
"""
function matmul!(c::MtlMatrix, a::MtlMatrix, b::MtlMatrix,
function matmul!(c::T, a::T, b::T,
alpha::Number=true, beta::Number=true,
transpose_a=false, transpose_b=false)
transpose_a=false, transpose_b=false) where {T <: Union{MtlMatrix, MtlMatrixBatch}}
# NOTE: MPS uses row major, while Julia is col-major. Instead of transposing
# the inputs (by passing !transpose_[ab]) and afterwards transposing
# the output, we use the property that (AB)ᵀ = BᵀAᵀ
Expand Down Expand Up @@ -301,32 +303,3 @@ function topk(A::MtlMatrix{T,S}, k) where {T<:MtlFloat,S}

return _topk!(A, I, V, k)
end


"""
    matmul!(c::MtlArray{T,3}, a::MtlArray{T,3}, b::MtlArray{T,3},
            alpha=true, beta=true, transpose_a=false, transpose_b=false)

Encode and commit an `MPSMatrixMultiplication` kernel that multiplies the 3-D arrays
`a` and `b` and writes the result into `c`.

# Arguments
- `c`: destination array; it is also returned.
- `a`, `b`: left and right operands.
- `alpha`, `beta`: scaling factors forwarded to `MPSMatrixMultiplication`.
- `transpose_a`, `transpose_b`: whether `a` / `b` should be treated as transposed.

# Returns
`c`. The command buffer is committed, but this function does not wait for it to
complete.

NOTE(review): presumably every slice along the third dimension is one matrix of the
batch — confirm against the `MPSMatrix` constructor for 3-D arrays.
"""
function matmul!(c::MtlArray{T,3}, a::MtlArray{T,3}, b::MtlArray{T,3},
                 alpha::Number=true, beta::Number=true,
                 transpose_a=false, transpose_b=false) where {T}
    # NOTE: MPS uses row major, while Julia is col-major. Instead of transposing
    #       the inputs (by passing !transpose_[ab]) and afterwards transposing
    #       the output, we use the property that (AB)ᵀ = BᵀAᵀ

    # The interior (contraction) dimension is the column count of op(A). When `a` is
    # used transposed, that is its first dimension rather than its second.
    cols_a = size(a, transpose_a ? 1 : 2)
    # Only the first two dimensions of `c` are needed; they are named from the
    # row-major (MPS) point of view, hence the swapped order.
    cols_c, rows_c = size(c)

    # Create MPS-compatible matrices from the MtlArrays
    mps_a = MPSMatrix(a)
    mps_b = MPSMatrix(b)
    mps_c = MPSMatrix(c)

    # Operands and their transpose flags are swapped on purpose (see NOTE above).
    mat_mul_kernel = MPSMatrixMultiplication(current_device(),
                                             transpose_b, transpose_a,
                                             rows_c, cols_c, cols_a,
                                             alpha, beta)

    # Encode and commit matmul kernel
    cmdbuf = MTLCommandBuffer(global_queue(current_device()))
    encode!(cmdbuf, mat_mul_kernel, mps_b, mps_a, mps_c)
    commit!(cmdbuf)

    return c
end

0 comments on commit 151ad82

Please sign in to comment.