diff --git a/benchmark/bench_mat_mul.jl b/benchmark/bench_mat_mul.jl new file mode 100644 index 00000000..97e00523 --- /dev/null +++ b/benchmark/bench_mat_mul.jl @@ -0,0 +1,212 @@ +module BenchmarkMatMul + +using StaticArrays +using BenchmarkTools +using LinearAlgebra +using Printf + +suite = BenchmarkGroup() + +mul_wrappers = [ + (m -> m, "ident "), + (m -> Symmetric(m, :U), "sym-u "), + (m -> Hermitian(m, :U), "herm-u "), + (m -> UpperTriangular(m), "up-tri "), + (m -> LowerTriangular(m), "lo-tri "), + (m -> UnitUpperTriangular(m), "uup-tri"), + (m -> UnitLowerTriangular(m), "ulo-tri"), + (m -> Adjoint(m), "adjoint"), + (m -> Transpose(m), "transpo"), + (m -> Diagonal(m), "diag ")] + +mul_wrappers_reduced = [ + (m -> m, "ident "), + (m -> Symmetric(m, :U), "sym-u "), + (m -> UpperTriangular(m), "up-tri "), + (m -> Transpose(m), "transpo"), + (m -> Diagonal(m), "diag ")] + +for N in [2, 4, 8, 10, 16] + + matvecstr = @sprintf("mat-vec %2d", N) + matmatstr = @sprintf("mat-mat %2d", N) + matvec_mut_str = @sprintf("mat-vec! %2d", N) + matmat_mut_str = @sprintf("mat-mat! %2d", N) + + suite[matvecstr] = BenchmarkGroup() + suite[matmatstr] = BenchmarkGroup() + suite[matvec_mut_str] = BenchmarkGroup() + suite[matmat_mut_str] = BenchmarkGroup() + + + A = randn(SMatrix{N,N,Float64}) + B = randn(SMatrix{N,N,Float64}) + bv = randn(SVector{N,Float64}) + for (wrapper_a, wrapper_name) in mul_wrappers_reduced + thrown = false + try + wrapper_a(A) * bv + catch e + thrown = true + end + if !thrown + suite[matvecstr][wrapper_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(bv))[] + end + end + + for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers + thrown = false + try + wrapper_a(A) * wrapper_b(B) + catch e + thrown = true + end + if !thrown + suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(wrapper_b(B)))[] + end + end + + C = randn(MMatrix{N,N,Float64}) + cv = randn(MVector{N,Float64}) + + for (wrapper_a, wrapper_name) in mul_wrappers + thrown = false + try + mul!(cv, wrapper_a(A), bv) + catch e + thrown = true + end + if !thrown + suite[matvec_mut_str][wrapper_name] = @benchmarkable mul!($cv, $(Ref(wrapper_a(A)))[], $(Ref(bv))[]) + end + end + + for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers + thrown = false + try + mul!(C, wrapper_a(A), wrapper_b(B)) + catch e + thrown = true + end + if !thrown + suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul!($C, $(Ref(wrapper_a(A)))[], $(Ref(wrapper_b(B)))[]) + end + end +end + +function run_and_save(fname, make_params = true) + if make_params + tune!(suite) + BenchmarkTools.save("params.json", params(suite)) + else + loadparams!(suite, BenchmarkTools.load("params.json")[1], :evals, :samples) + end + results = run(suite, verbose = true) + BenchmarkTools.save(fname, results) +end + +function judge_results(m1, m2) + results = Any[] + for key1 in keys(m1) + if !haskey(m2, key1) + continue + end + for key2 in keys(m1[key1]) + if !haskey(m2[key1], key2) + continue + end + push!(results, (key1, key2, judge(median(m1[key1][key2]), median(m2[key1][key2])))) + end + end + return results +end + +function full_benchmark(mul_wrappers, size_iter = 1:4, T = Float64) + suite_full = BenchmarkGroup() + for N in size_iter + for M in size_iter + a = randn(SMatrix{N,M,T}) + wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]] + sa = Size(a) + for K in size_iter + b = randn(SMatrix{M,K,T}) + wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]] + sb = Size(b) + for (w_a, w_a_name) in wrappers_a + for (w_b, w_b_name) in wrappers_b + cur_str = @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + suite_full[cur_str] = @benchmarkable StaticArrays.mul_generic($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) + cur_str = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + suite_full[cur_str] = @benchmarkable StaticArrays._mul($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) + cur_str = @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + suite_full[cur_str] = @benchmarkable StaticArrays.mul_unrolled($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) + if w_a_name != "diag " && w_b_name != "diag " + cur_str = @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + suite_full[cur_str] = @benchmarkable StaticArrays.mul_unrolled_chunks($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) + end + if w_a_name == "ident " && w_b_name == "ident " + cur_str = @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + suite_full[cur_str] = @benchmarkable StaticArrays.mul_loop($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) + end + end + end + end + end + end + results = run(suite_full, verbose = true) + results_median = map(collect(results)) do res + return (res[1], median(res[2]).time) + end + return results_median +end + +function judge_this(new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which) + if new_time*tol < old_time + msg = @sprintf("better for %s %s (%2d, %2d) x (%2d, %2d): %s", w_a_name, w_b_name, N, M, M, K, which) + println(msg) + println(">> ", new_time, " | ", old_time) + end +end + +function pick_best(results, mul_wrappers, size_iter; tol = 1.2) + for N in size_iter + for M in size_iter + wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]] + for K in size_iter + wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]] + for (w_a, w_a_name) in wrappers_a + for (w_b, w_b_name) in wrappers_b + cur_default = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + default_time = results[cur_default] + + cur_generic = @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + generic_time = results[cur_generic] + judge_this(generic_time, default_time, tol, w_a_name, w_b_name, N, M, K, "generic") + + cur_unrolled = @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + unrolled_time = results[cur_unrolled] + judge_this(unrolled_time, default_time, tol, w_a_name, w_b_name, N, M, K, "unrolled") + + if w_a_name != "diag " && w_b_name != "diag " + cur_chunks = @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + chunk_time = results[cur_chunks] + judge_this(chunk_time, default_time, tol, w_a_name, w_b_name, N, M, K, "chunks") + end + if w_a_name == "ident " && w_b_name == "ident " + cur_loop = @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) + loop_time = results[cur_loop] + judge_this(loop_time, default_time, tol, w_a_name, w_b_name, N, M, K, "loop") + end + end + end + end + end + end +end + +function run_1() + return full_benchmark(mul_wrappers_reduced, [2, 3, 4, 5, 8, 9, 14, 16]) +end + +end #module +BenchmarkMatMul.suite diff --git a/src/SDiagonal.jl b/src/SDiagonal.jl index 37fef9a2..472d8351 100644 --- a/src/SDiagonal.jl +++ b/src/SDiagonal.jl @@ -18,8 +18,6 @@ size(::Type{<:SDiagonal{N}}) where {N} = (N,N) size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N # define specific methods to avoid allocating mutable arrays -*(A::StaticMatrix, D::SDiagonal) = A .* transpose(D.diag) -*(D::SDiagonal, A::StaticMatrix) = D.diag .* A \(D::SDiagonal, b::AbstractVector) = D.diag .\ b \(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity diff --git a/src/matrix_multiply.jl b/src/matrix_multiply.jl index b9bbb284..767e5eca 100644 --- a/src/matrix_multiply.jl +++ b/src/matrix_multiply.jl @@ -4,23 +4,312 @@ import LinearAlgebra: BlasFloat, matprod, mul! # Manage dispatch of * and mul! # TODO Adjoint? (Inner product?) +# *(A::StaticMatMulLike, B::AbstractVector) causes an ambiguity with SparseArrays @inline *(A::StaticMatrix, B::AbstractVector) = _mul(Size(A), A, B) +@inline *(A::StaticMatMulLike, B::StaticVector) = _mul(Size(A), Size(B), A, B) @inline *(A::StaticMatrix, B::StaticVector) = _mul(Size(A), Size(B), A, B) -@inline *(A::StaticMatrix, B::StaticMatrix) = _mul(Size(A), Size(B), A, B) -@inline *(A::StaticVector, B::StaticMatrix) = *(reshape(A, Size(Size(A)[1], 1)), B) +@inline *(A::StaticMatMulLike, B::StaticMatMulLike) = _mul(Size(A), Size(B), A, B) +@inline *(A::StaticVector, B::StaticMatMulLike) = *(reshape(A, Size(Size(A)[1], 1)), B) @inline *(A::StaticVector, B::Transpose{<:Any, <:StaticVector}) = _mul(Size(A), Size(B), A, B) @inline *(A::StaticVector, B::Adjoint{<:Any, <:StaticVector}) = _mul(Size(A), Size(B), A, B) @inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B @inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B +""" + gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a) + +Statically generate outer code for fully unrolled multiplication loops. +Returned code does wrapper-specific tests (for example if a symmetric matrix view is +`U` or `L`) and the body of the if expression is then generated by function `expr_gen`. +The function `expr_gen` receives access pattern description symbol as its argument +and this symbol is then consumed by uplo_access to generate the right code for matrix +element access. + +The name of the matrix to test is indicated by `asym`. +""" +function gen_by_access(expr_gen, a::Type{<:StaticVecOrMat}, asym = :wrapped_a) + return expr_gen(:any) +end +function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return quote + if $(asym).uplo == 'U' + $(expr_gen(:up)) + else + $(expr_gen(:lo)) + end + end +end +function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return quote + if $(asym).uplo == 'U' + $(expr_gen(:up_herm)) + else + $(expr_gen(:lo_herm)) + end + end +end +function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:upper_triangular) +end +function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:lower_triangular) +end +function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:unit_upper_triangular) +end +function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a) + return expr_gen(:unit_lower_triangular) +end +function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a) + return expr_gen(:transpose) +end +function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a) + return expr_gen(:adjoint) +end +function gen_by_access(expr_gen, a::Type{<:SDiagonal}, asym = :wrapped_a) + return expr_gen(:diagonal) +end +""" + gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray}) + +Simiar to gen_by_access with only one type argument. The difference is that tests for both +arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments, +first for matrix `a` and the second for matrix `b`. +""" +function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:any, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type) + return quote + if wrapped_a.uplo == 'U' + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:up, access_b) + end) + else + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:lo, access_b) + end) + end + end +end +function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type) + return quote + if wrapped_a.uplo == 'U' + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:up_herm, access_b) + end) + else + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:lo_herm, access_b) + end) + end + end +end +function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:upper_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:lower_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:unit_upper_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:unit_lower_triangular, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:transpose, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:adjoint, access_b) + end) + end +end +function gen_by_access(expr_gen, a::Type{<:SDiagonal}, b::Type) + return quote + return $(gen_by_access(b, :wrapped_b) do access_b + expr_gen(:diagonal, access_b) + end) + end +end + +""" + mul_result_structure(a::Type, b::Type) + +Get a structure wrapper that should be applied to the result of multiplication of matrices +of given types (a*b). +""" +function mul_result_structure(a, b) + return identity +end +function mul_result_structure(::UpperTriangular{<:Any, <:StaticMatrix}, ::UpperTriangular{<:Any, <:StaticMatrix}) + return UpperTriangular +end +function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerTriangular{<:Any, <:StaticMatrix}) + return LowerTriangular +end +function mul_result_structure(::UpperTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) + return UpperTriangular +end +function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) + return LowerTriangular +end +function mul_result_structure(::SDiagonal, ::UpperTriangular{<:Any, <:StaticMatrix}) + return UpperTriangular +end +function mul_result_structure(::SDiagonal, ::LowerTriangular{<:Any, <:StaticMatrix}) + return LowerTriangular +end +function mul_result_structure(::UnitUpperTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) + return UpperTriangular +end +function mul_result_structure(::UnitLowerTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) + return LowerTriangular +end +function mul_result_structure(::SDiagonal, ::UnitUpperTriangular{<:Any, <:StaticMatrix}) + return UpperTriangular +end +function mul_result_structure(::SDiagonal, ::UnitLowerTriangular{<:Any, <:StaticMatrix}) + return LowerTriangular +end +function mul_result_structure(::SDiagonal, ::SDiagonal) + return Diagonal +end + +""" + uplo_access(sa, asym, k, j, uplo) + +Generate code for matrix element access, for a matrix of size `sa` locally referred to +as `asym` in the context where the result will be used. Both indices `k` and `j` need to be +statically known for this function to work. `uplo` is the access pattern mode generated +by the `gen_by_access` function. +""" +function uplo_access(sa, asym, k, j, uplo) + TAsym = Symbol("T"*string(asym)) + if uplo == :any + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif uplo == :up + if k < j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :U)) + else + return :(transpose($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :lo + if k > j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :L)) + else + return :(transpose($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :up_herm + if k < j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :U)) + else + return :(adjoint($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :lo_herm + if k > j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :L)) + else + return :(adjoint($asym[$(LinearIndices(sa)[j, k])])) + end + elseif uplo == :upper_triangular + if k <= j + return :($asym[$(LinearIndices(sa)[k, j])]) + else + return :(zero($TAsym)) + end + elseif uplo == :lower_triangular + if k >= j + return :($asym[$(LinearIndices(sa)[k, j])]) + else + return :(zero($TAsym)) + end + elseif uplo == :unit_upper_triangular + if k < j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(oneunit($TAsym)) + else + return :(zero($TAsym)) + end + elseif uplo == :unit_lower_triangular + if k > j + return :($asym[$(LinearIndices(sa)[k, j])]) + elseif k == j + return :(oneunit($TAsym)) + else + return :(zero($TAsym)) + end + elseif uplo == :upper_hessenberg + if k <= j+1 + return :($asym[$(LinearIndices(sa)[k, j])]) + else + return :(zero($TAsym)) + end + elseif uplo == :transpose + return :(transpose($asym[$(LinearIndices(reverse(sa))[j, k])])) + elseif uplo == :adjoint + return :(adjoint($asym[$(LinearIndices(reverse(sa))[j, k])])) + elseif uplo == :diagonal + if k == j + return :($asym[$k]) + else + return :(zero($TAsym)) + end + else + error("Unknown uplo: $uplo") + end +end # Implementations -@generated function _mul(::Size{sa}, a::StaticMatrix{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb} +function mul_smat_vec_exprs(sa, access_a) + return [combine_products([:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]] +end + +@generated function _mul(::Size{sa}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb} if sa[2] != 0 - exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]] + retexpr = gen_by_access(wrapped_a) do access_a + exprs = mul_smat_vec_exprs(sa, access_a) + return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) + end else exprs = [:(zero(T)) for k = 1:sa[1]] + retexpr = :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) end return quote @@ -29,25 +318,31 @@ import LinearAlgebra: BlasFloat, matprod, mul! throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))")) end T = promote_op(matprod,Ta,Tb) - @inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))) + a = mul_parent(wrapped_a) + $retexpr end end -@generated function _mul(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb} +@generated function _mul(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb} if sb[1] != sa[2] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) end if sa[2] != 0 - exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]] + retexpr = gen_by_access(wrapped_a) do access_a + exprs = mul_smat_vec_exprs(sa, access_a) + return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) + end else exprs = [:(zero(T)) for k = 1:sa[1]] + retexpr = :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) end return quote @_inline_meta T = promote_op(matprod,Ta,Tb) - @inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))) + a = mul_parent(wrapped_a) + $retexpr end end @@ -55,8 +350,13 @@ end @generated function _mul(::Size{sa}, ::Size{sb}, a::StaticVector{<: Any, Ta}, b::Union{Transpose{Tb, <:StaticVector}, Adjoint{Tb, <:StaticVector}}) where {sa, sb, Ta, Tb} newsize = (sa[1], sb[2]) - exprs = [:(a[$i]*b[$j]) for i = 1:sa[1], j = 1:sb[2]] - + conjugate_b = b <: Adjoint + if conjugate_b + exprs = [:(a[$i] * adjoint(b[$j])) for i = 1:sa[1], j = 1:sb[2]] + else + exprs = [:(a[$i] * transpose(b[$j])) for i = 1:sa[1], j = 1:sb[2]] + end + return quote @_inline_meta T = promote_op(*, Ta, Tb) @@ -64,60 +364,69 @@ end end end -@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} - # Heuristic choice for amount of codegen - if sa[1]*sa[2]*sb[2] <= 8*8*8 - return quote - @_inline_meta - return mul_unrolled(Sa, Sb, a, b) +_unstatic_array(::Type{TSA}) where {S, T, N, TSA<:StaticArray{S,T,N}} = AbstractArray{T,N} +for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular, UnitUpperTriangular, UnitLowerTriangular, Diagonal] + @eval _unstatic_array(::Type{$TWR{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}} = $TWR{T,<:AbstractArray{T,N}} +end + +function combine_products(expr_list) + filtered = filter(expr_list) do expr + if expr.head != :call || expr.args[1] != :* + error("expected call to *") end - elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14 - return quote - @_inline_meta - return mul_unrolled_chunks(Sa, Sb, a, b) + for arg in expr.args[2:end] + if isa(arg, Expr) && arg.head == :call && arg.args[1] == :zero + return false + end end + return true + end + if isempty(filtered) + return :(zero(T)) else - return quote - @_inline_meta - return mul_loop(Sa, Sb, a, b) + return reduce(filtered) do ex1, ex2 + if ex2.head != :call || ex2.args[1] != :* + error("expected call to *") + end + + return :(muladd($(ex2.args[2]), $(ex2.args[3]), $ex1)) end end end -@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::Union{SizedMatrix{T}, MMatrix{T}, MArray{T}}, b::Union{SizedMatrix{T}, MMatrix{T}, MArray{T}}) where {sa, sb, T <: BlasFloat} +@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} S = Size(sa[1], sb[2]) - - # Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling) - if sa[1]*sa[2]*sb[2] >= 14*14*14 - Sa = TSize{size(S),false}() - Sb = TSize{sa,false}() - Sc = TSize{sb,false}() - _add = MulAddMul(true,false) + # Heuristic choice for amount of codegen + a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 4 : 1 + b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 4 : 1 + ab_tri_mul = (a == 4 && b == 4) ? 2 : 1 + if a <: StaticMatrix && b <: StaticMatrix + # Julia unrolls these loops pretty well return quote @_inline_meta - C = similar(a, T, $S) - mul_blas!($Sa, C, $Sa, $Sb, a, b, $_add) - return C + return mul_loop(Sa, Sb, a, b) end - elseif sa[1]*sa[2]*sb[2] < 8*8*8 + elseif sa[1]*sa[2]*sb[2] <= 4*8*8*8*a_tri_mul*b_tri_mul*ab_tri_mul || a <: Diagonal || b <: Diagonal return quote @_inline_meta return mul_unrolled(Sa, Sb, a, b) end - elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14 + elseif (sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14) || !(a <: StaticMatrix) || !(b <: StaticMatrix) return quote @_inline_meta - return similar_type(a, T, $S)(mul_unrolled_chunks(Sa, Sb, a, b)) + return mul_unrolled_chunks(Sa, Sb, a, b) end else + # we don't have any special code for handling this case so let's fall back to + # the generic implementation of matrix multiplication return quote @_inline_meta - return mul_loop(Sa, Sb, a, b) + return mul_generic(Sa, Sb, a, b) end end end -@generated function mul_unrolled(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} +@generated function mul_unrolled(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, wrapped_b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} if sb[1] != sa[2] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) end @@ -125,19 +434,25 @@ end S = Size(sa[1], sb[2]) if sa[2] != 0 - exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k1, j])]*b[$(LinearIndices(sb)[j, k2])]) for j = 1:sa[2]]) for k1 = 1:sa[1], k2 = 1:sb[2]] + retexpr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b + exprs = [combine_products([:($(uplo_access(sa, :a, k1, j, access_a))*$(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]] + ) for k1 = 1:sa[1], k2 = 1:sb[2]] + return :((mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(tuple($(exprs...))))) + end else exprs = [:(zero(T)) for k1 = 1:sa[1], k2 = 1:sb[2]] + retexpr = :(return (mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(tuple($(exprs...))))) end return quote @_inline_meta T = promote_op(matprod,Ta,Tb) - @inbounds return similar_type(a, T, $S)(tuple($(exprs...))) + a = mul_parent(wrapped_a) + b = mul_parent(wrapped_b) + @inbounds $retexpr end end - @generated function mul_loop(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} if sb[1] != sa[2] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) @@ -145,25 +460,67 @@ end S = Size(sa[1], sb[2]) - tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:sa[1], k2 = 1:sb[2]] - exprs_init = [:($(tmps[k1,k2]) = a[$k1] * b[1 + $((k2-1) * sb[1])]) for k1 = 1:sa[1], k2 = 1:sb[2]] - exprs_loop = [:($(tmps[k1,k2]) += a[$(k1-sa[1]) + $(sa[1])*j] * b[j + $((k2-1) * sb[1])]) for k1 = 1:sa[1], k2 = 1:sb[2]] - + # optimal for AVX2 with `Float64 + # AVX512 would want something more like 16x14 or 24x9 with `Float64` + M_r, N_r = 8, 6 + n = 0 + M, K = sa + N = sb[2] + q = Expr(:block) + atemps = [Symbol(:a_, k1) for k1 = 1:M] + tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:M, k2 = 1:N] + while n < N + nu = min(N, n + N_r) + nrange = n+1:nu + m = 0 + while m < M + mu = min(M, m + M_r) + mrange = m+1:mu + + atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange] + exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange] + atemps_loop_init = [:($(atemps[k1]) = a[$(k1-sa[1]) + $(sa[1])*j]) for k1 = mrange] + exprs_loop = [:($(tmps[k1,k2]) = muladd($(atemps[k1]), b[j + $((k2-1) * sb[1])], $(tmps[k1,k2]))) for k1 = mrange, k2 = nrange] + qblock = quote + @inbounds $(Expr(:block, atemps_init...)) + @inbounds $(Expr(:block, exprs_init...)) + for j = 2:$(sa[2]) + @inbounds $(Expr(:block, atemps_loop_init...)) + @inbounds $(Expr(:block, exprs_loop...)) + end + end + push!(q.args, qblock) + m = mu + end + n = nu + end return quote @_inline_meta T = promote_op(matprod,Ta,Tb) - - @inbounds $(Expr(:block, exprs_init...)) - for j = 2:$(sa[2]) - @inbounds $(Expr(:block, exprs_loop...)) - end + $q @inbounds return similar_type(a, T, $S)(tuple($(tmps...))) end end +@generated function mul_generic(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, wrapped_b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} + if sb[1] != sa[2] + throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) + end + + S = Size(sa[1], sb[2]) + + return quote + @_inline_meta + T = promote_op(matprod, Ta, Tb) + a = mul_parent(wrapped_a) + b = mul_parent(wrapped_b) + return (mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(invoke(*, Tuple{$(_unstatic_array(a)),$(_unstatic_array(b))}, a, b))) + end +end + # Concatenate a series of matrix-vector multiplications # Each function is N^2 not N^3 - aids in compile time. -@generated function mul_unrolled_chunks(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} +@generated function mul_unrolled_chunks(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, wrapped_b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} if sb[1] != sa[2] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) end @@ -173,19 +530,74 @@ end # Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than (possibly) a mutable type. Avoids allocation == faster tmp_type_in = :(SVector{$(sb[1]), T}) tmp_type_out = :(SVector{$(sa[1]), T}) - vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(TSize(a), TSize($(sb[1])), a, - $(Expr(:call, tmp_type_in, [Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)))::$tmp_type_out) - for k2 = 1:sb[2]] - exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]] + retexpr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b + vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()), + a, $(Expr(:call, tmp_type_in, [uplo_access(sb, :b, i, k2, access_b) for i = 1:sb[1]]...)), $(Val(access_a)))::$tmp_type_out) for k2 = 1:sb[2]] + + exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]] + + return quote + @inbounds $(Expr(:block, vect_exprs...)) + $(Expr(:block, + :(@inbounds return (mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(tuple($(exprs...))))) + )) + end + end + return quote + @_inline_meta + T = promote_op(matprod, Ta, Tb) + a = mul_parent(wrapped_a) + b = mul_parent(wrapped_b) + $retexpr + end +end + +# a special version for plain matrices +@generated function mul_unrolled_chunks(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} + if sb[1] != sa[2] + throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) + end + + S = Size(sa[1], sb[2]) + + # optimal for AVX2 with `Float64 + # AVX512 would want something more like 16x14 or 24x9 with `Float64` + M_r, N_r = 8, 6 + n = 0 + M, K = sa + N = sb[2] + q = Expr(:block) + atemps = [Symbol(:a_, k1) for k1 = 1:M] + tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:M, k2 = 1:N] + while n < N + nu = min(N, n + N_r) + nrange = n+1:nu + m = 0 + while m < M + mu = min(M, m + M_r) + mrange = m+1:mu + atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange] + exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange] + push!(q.args, :(@inbounds $(Expr(:block, atemps_init...)))) + push!(q.args, :(@inbounds $(Expr(:block, exprs_init...)))) + + for j in 2:K + atemps_loop_init = [:($(atemps[k1]) = a[$(LinearIndices(sa)[k1,j])]) for k1 = mrange] + exprs_loop = [:($(tmps[k1,k2]) = muladd($(atemps[k1]), b[$(LinearIndices(sb)[j,k2])], $(tmps[k1,k2]))) for k1 = mrange, k2 = nrange] + push!(q.args, :(@inbounds $(Expr(:block, atemps_loop_init...)))) + push!(q.args, :(@inbounds $(Expr(:block, exprs_loop...)))) + end + m = mu + end + n = nu + end return quote @_inline_meta T = promote_op(matprod,Ta,Tb) - $(Expr(:block, - vect_exprs..., - :(@inbounds return similar_type(a, T, $S)(tuple($(exprs...)))) - )) + $q + @inbounds return similar_type(a, T, $S)(tuple($(tmps...))) end end diff --git a/src/matrix_multiply_add.jl b/src/matrix_multiply_add.jl index dc25ae58..e3eac95e 100644 --- a/src/matrix_multiply_add.jl +++ b/src/matrix_multiply_add.jl @@ -1,21 +1,35 @@ # import LinearAlgebra.MulAddMul -abstract type MulAddMul{T} end +abstract type MulAddMul{TA,TB} end -struct AlphaBeta{T} <: MulAddMul{T} - α::T - β::T - function AlphaBeta{T}(α,β) where T <: Real - new{T}(α,β) - end +struct AlphaBeta{TA,TB} <: MulAddMul{TA,TB} + α::TA + β::TB end -@inline AlphaBeta(α::A,β::B) where {A,B} = AlphaBeta{promote_type(A,B)}(α,β) @inline alpha(ab::AlphaBeta) = ab.α @inline beta(ab::AlphaBeta) = ab.β -struct NoMulAdd{T} <: MulAddMul{T} end -@inline alpha(ma::NoMulAdd{T}) where T = one(T) -@inline beta(ma::NoMulAdd{T}) where T = zero(T) +struct NoMulAdd{TA,TB} <: MulAddMul{TA,TB} end +@inline alpha(ma::NoMulAdd{TA,TB}) where {TA,TB} = one(TA) +@inline beta(ma::NoMulAdd{TA,TB}) where {TA,TB} = zero(TB) + +""" + StaticMatMulLike + +Static wrappers used for multiplication dispatch. +""" +const StaticMatMulLike{s1, s2, T} = Union{ + StaticMatrix{s1, s2, T}, + Symmetric{T, <:StaticMatrix{s1, s2, T}}, + Hermitian{T, <:StaticMatrix{s1, s2, T}}, + LowerTriangular{T, <:StaticMatrix{s1, s2, T}}, + UpperTriangular{T, <:StaticMatrix{s1, s2, T}}, + UnitLowerTriangular{T, <:StaticMatrix{s1, s2, T}}, + UnitUpperTriangular{T, <:StaticMatrix{s1, s2, T}}, + Adjoint{T, <:StaticMatrix{s1, s2, T}}, + Transpose{T, <:StaticMatrix{s1, s2, T}}, + SDiagonal{s1, T}} + """ Size that stores whether a Matrix is a Transpose Useful when selecting multiplication methods, and avoiding allocations when dealing with @@ -24,34 +38,39 @@ Should pair with `parent`. """ struct TSize{S,T} function TSize{S,T}() where {S,T} - new{S::Tuple{Vararg{StaticDimension}},T::Bool}() + new{S::Tuple{Vararg{StaticDimension}},T::Symbol}() end end -TSize(A::Type{<:Transpose{<:Any,<:StaticArray}}) = TSize{size(A),true}() -TSize(A::Type{<:Adjoint{<:Real,<:StaticArray}}) = TSize{size(A),true}() # can't handle complex adjoints yet -TSize(A::Type{<:StaticArray}) = TSize{size(A),false}() +TSize(A::Type{<:StaticArrayLike}) = TSize{size(A), gen_by_access(identity, A)}() TSize(A::StaticArrayLike) = TSize(typeof(A)) -TSize(S::Size{s}, T=false) where s = TSize{s,T}() +TSize(S::Size{s}, T=:any) where s = TSize{s,T}() TSize(s::Number) = TSize(Size(s)) -istranpose(::TSize{<:Any,T}) where T = T +istranspose(::TSize{<:Any,T}) where T = (T === :transpose) size(::TSize{S}) where S = S Size(::TSize{S}) where S = Size{S}() -Base.transpose(::TSize{S,T}) where {S,T} = TSize{reverse(S),!T}() +access_type(::TSize{<:Any,T}) where T = T +Base.transpose(::TSize{S,:transpose}) where {S,T} = TSize{reverse(S),:any}() +Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}() # Get the parent of transposed arrays, or the array itself if it has no parent -# QUESTION: maybe call this something else? -mul_parent(A) = parent(A) -mul_parent(A::StaticArray) = A +# Different from Base.parent because we only want to get rid of Transpose and Adjoint +# The two last methods can't be combined into one for StaticVecOrMat because then dispatch +# goes wrong for SizedArray +@inline mul_parent(A::Union{StaticMatMulLike, Adjoint{<:Any,<:StaticVector}, Transpose{<:Any,<:StaticVector}}) = Base.parent(A) +@inline mul_parent(A::StaticMatrix) = A +@inline mul_parent(A::StaticVector) = A # 5-argument matrix multiplication # To avoid allocations, strip away Transpose type and store tranpose info in Size @inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike, B::StaticVecOrMatLike, - α::Real, β::Real) = _mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B), + α::Real, β::Real) = _mul!(TSize(dest), mul_parent(dest), Size(A), Size(B), A, B, AlphaBeta(α,β)) -@inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike{T}, - B::StaticVecOrMatLike{T}) where T = - _mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B), NoMulAdd{T}()) +@inline function LinearAlgebra.mul!(dest::StaticVecOrMatLike{TDest}, A::StaticVecOrMatLike{TA}, + B::StaticVecOrMatLike{TB}) where {TDest,TA,TB} + TMul = promote_op(matprod, TA, TB) + return _mul!(TSize(dest), mul_parent(dest), Size(A), Size(B), A, B, NoMulAdd{TMul, TDest}()) +end "Calculate the product of the dimensions being multiplied. Useful as a heuristic for unrolling." @@ -95,52 +114,64 @@ end end "Obtain an expression for the linear index of var[k,j], taking transposes into account" -@inline _lind(A::Type{<:TSize}, k::Int, j::Int) = _lind(:a, A, k, j) function _lind(var::Symbol, A::Type{TSize{sa,tA}}, k::Int, j::Int) where {sa,tA} - if tA - return :($var[$(LinearIndices(reverse(sa))[j, k])]) - else - return :($var[$(LinearIndices(sa)[k, j])]) + ula = uplo_access(sa, var, k, j, tA) + if ula.head == :call && ula.args[1] == :transpose + # TODO: can this be properly fixed at all? + return ula.args[2] end + return ula end + + # Matrix-vector multiplication -@generated function _mul!(Sc::TSize{sc}, c::StaticVecOrMatLike, Sa::TSize{sa}, Sb::TSize{sb}, - a::StaticMatrix, b::StaticVector, _add::MulAddMul, - ::Val{col}=Val(1)) where {sa, sb, sc, col} +@generated function _mul!(Sc::TSize{sc}, c::StaticVecOrMatLike, Sa::Size{sa}, Sb::Size{sb}, + wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}, _add::MulAddMul, + ::Val{col}=Val(1)) where {sa, sb, sc, col, Ta, Tb} if sa[2] != sb[1] || sc[1] != sa[1] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc")) end if sa[2] != 0 - lhs = [:($(_lind(:c,Sc,k,col))) for k = 1:sa[1]] - ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)), - [:($(_lind(Sa,k,j))*b[$j]) for j = 1:sa[2]]))) for k = 1:sa[1]] - exprs = _muladd_expr(lhs, ab, _add) + assign_expr = gen_by_access(wrapped_a) do access_a + lhs = [_lind(:c,Sc,k,col) for k = 1:sa[1]] + ab = [combine_products([:($(uplo_access(sa, :a, k, j, access_a)) * b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]] + exprs = _muladd_expr(lhs, ab, _add) + + return :(@inbounds $(Expr(:block, exprs...))) + end else exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sa[1]] + assign_expr = :(@inbounds $(Expr(:block, exprs...))) end return quote # @_inline_meta - # α = _add.alpha - # β = _add.beta α = alpha(_add) β = beta(_add) - @inbounds $(Expr(:block, exprs...)) + a = mul_parent(wrapped_a) + $assign_expr return c end end # Outer product -@generated function _mul!(::TSize{sc}, c::StaticMatrix, ::TSize{sa,false}, ::TSize{sb,true}, - a::StaticVector, b::StaticVector, _add::MulAddMul) where {sa, sb, sc} +@generated function _mul!(::TSize{sc}, c::StaticMatrix, tsa::Size{sa}, tsb::Size{sb}, + a::StaticVector, b::Union{Transpose{<:Any, <:StaticVector}, Adjoint{<:Any, <:StaticVector}}, _add::MulAddMul) where {sa, sb, sc} if sc[1] != sa[1] || sc[2] != sb[2] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc")) end + conjugate_b = b <: Adjoint + lhs = [:(c[$(LinearIndices(sc)[i,j])]) for i = 1:sa[1], j = 1:sb[2]] - ab = [:(a[$i] * b[$j]) for i = 1:sa[1], j = 1:sb[2]] + if conjugate_b + ab = [:(a[$i] * adjoint(b[$j])) for i = 1:sa[1], j = 1:sb[2]] + else + ab = [:(a[$i] * transpose(b[$j])) for i = 1:sa[1], j = 1:sb[2]] + end + exprs = _muladd_expr(lhs, ab, _add) return quote @@ -153,15 +184,18 @@ end end # Matrix-matrix multiplication -@generated function _mul!(Sc::TSize{sc}, c::StaticMatrixLike, - Sa::TSize{sa}, Sb::TSize{sb}, - a::StaticMatrixLike, b::StaticMatrixLike, +@generated function _mul!(Sc::TSize{sc}, c::StaticMatMulLike, + Sa::Size{sa}, Sb::Size{sb}, + a::StaticMatMulLike, b::StaticMatMulLike, _add::MulAddMul) where {sa, sb, sc} Ta,Tb,Tc = eltype(a), eltype(b), eltype(c) - can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat + can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat && a <: Union{StaticMatrix,Transpose} && b <: Union{StaticMatrix,Transpose} mult_dim = multiplied_dimension(a,b) - if mult_dim < 4*4*4 + a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 2 : 1 + b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 2 : 1 + ab_tri_mul = (a == 2 && b == 2) ? 2 : 1 + if mult_dim < 4*4*4*a_tri_mul*b_tri_mul*ab_tri_mul || a <: Diagonal || b <: Diagonal return quote @_inline_meta muladd_unrolled_all!(Sc, c, Sa, Sb, a, b, _add) @@ -177,7 +211,7 @@ end if can_blas return quote @_inline_meta - mul_blas!(Sc, c, Sa, Sb, a, b, _add) + mul_blas!(Sc, c, TSize(a), TSize(b), mul_parent(a), mul_parent(b), _add) return c end else @@ -191,18 +225,26 @@ end end -@generated function muladd_unrolled_all!(Sc::TSize{sc}, c::StaticMatrixLike, Sa::TSize{sa}, Sb::TSize{sb}, - a::StaticMatrixLike, b::StaticMatrixLike, _add::MulAddMul) where {sa, sb, sc} +@generated function muladd_unrolled_all!(Sc::TSize{sc}, wrapped_c::StaticMatMulLike, Sa::Size{sa}, Sb::Size{sb}, + wrapped_a::StaticMatMulLike{<:Any,<:Any,Ta}, wrapped_b::StaticMatMulLike{<:Any,<:Any,Tb}, _add::MulAddMul) where {sa, sb, sc, Ta, Tb} if !check_dims(Size(sc),Size(sa),Size(sb)) throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc")) end if sa[2] != 0 - lhs = [:($(_lind(:c, Sc, k1, k2))) for k1 = 1:sa[1], k2 = 1:sb[2]] - ab = [:($(reduce((ex1,ex2) -> :(+($ex1,$ex2)), - [:($(_lind(:a, Sa, k1, j)) * $(_lind(:b, Sb, j, k2))) for j = 1:sa[2]] - ))) for k1 = 1:sa[1], k2 = 1:sb[2]] - exprs = _muladd_expr(lhs, ab, _add) + lhs = [_lind(:c, Sc, k1, k2) for k1 = 1:sa[1], k2 = 1:sb[2]] + + assign_expr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b + + ab = [combine_products([:($(uplo_access(sa, :a, k1, j, access_a)) * $(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]] + ) for k1 = 1:sa[1], k2 = 1:sb[2]] + + exprs = _muladd_expr(lhs, ab, _add) + return :(@inbounds $(Expr(:block, exprs...))) + end + else + exprs = [:(c[$k] = zero(eltype(c))) for k = 1:sc[1]*sc[2]] + assign_expr = :(@inbounds $(Expr(:block, exprs...))) end return quote @@ -211,49 +253,64 @@ end # β = _add.beta α = alpha(_add) β = beta(_add) - @inbounds $(Expr(:block, exprs...)) + c = mul_parent(wrapped_c) + a = mul_parent(wrapped_a) + b = mul_parent(wrapped_b) + T = promote_op(matprod,Ta,Tb) + $assign_expr + return c end end -@generated function muladd_unrolled_chunks!(Sc::TSize{sc}, c::StaticMatrix, ::TSize{sa,tA}, Sb::TSize{sb,tB}, - a::StaticMatrix, b::StaticMatrix, _add::MulAddMul) where {sa, sb, sc, tA, tB} +@generated function muladd_unrolled_chunks!(Sc::TSize{sc}, wrapped_c::StaticMatMulLike, ::Size{sa}, Sb::Size{sb}, + wrapped_a::StaticMatMulLike{<:Any,<:Any,Ta}, wrapped_b::StaticMatMulLike{<:Any,<:Any,Tb}, _add::MulAddMul) where {sa, sb, sc, Ta, Tb} if sb[1] != sa[2] || sa[1] != sc[1] || sb[2] != sc[2] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb and assign to array of size $sc")) end + # This will not work for Symmetric and Hermitian wrappers of c + lhs = [_lind(:c, Sc, k1, k2) for k1 = 1:sa[1], k2 = 1:sb[2]] + #vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply(A, B[:, $k2])) for k2 = 1:sB[2]] # Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than a mutable type. Avoids allocation == faster - tmp_type = SVector{sb[1], eltype(c)} - vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(TSize{sa,tA}()), $(TSize{(sb[1],),tB}()), - a, $(Expr(:call, tmp_type, [:($(_lind(:b, Sb, i, k2))) for i = 1:sb[1]]...)))) for k2 = 1:sb[2]] + tmp_type = SVector{sb[1], eltype(wrapped_c)} + + assign_expr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b + vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()), + a, $(Expr(:call, tmp_type, [uplo_access(sb, :b, i, k2, access_b) for i = 1:sb[1]]...)), $(Val(access_a)))) for k2 = 1:sb[2]] - lhs = [:($(_lind(:c, Sc, k1, k2))) for k1 = 1:sa[1], k2 = 1:sb[2]] - # exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]] - rhs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]] - exprs = _muladd_expr(lhs, rhs, _add) + # exprs = [:(c[$(LinearIndices(sc)[k1, k2])] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]] + rhs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]] + exprs = _muladd_expr(lhs, rhs, _add) + return quote + @inbounds $(Expr(:block, vect_exprs...)) + @inbounds $(Expr(:block, exprs...)) + end + end + return quote @_inline_meta - # α = _add.alpha - # β = _add.beta α = alpha(_add) β = beta(_add) - @inbounds $(Expr(:block, vect_exprs...)) - @inbounds $(Expr(:block, exprs...)) + c = mul_parent(wrapped_c) + a = mul_parent(wrapped_a) + b = mul_parent(wrapped_b) + $assign_expr end end # @inline partly_unrolled_multiply(Sa::Size, Sb::Size, a::StaticMatrix, b::StaticArray) where {sa, sb, Ta, Tb} = # partly_unrolled_multiply(TSize(Sa), TSize(Sb), a, b) -@generated function partly_unrolled_multiply(Sa::TSize{sa}, ::TSize{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticArray{<:Tuple, Tb}) where {sa, sb, Ta, Tb} +@generated function partly_unrolled_multiply(Sa::Size{sa}, ::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticArray{<:Tuple, Tb}, ::Val{access_a}) where {sa, sb, Ta, Tb, access_a} if sa[2] != sb[1] throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb")) end if sa[2] != 0 - exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:($(_lind(:a,Sa,k,j))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]] + exprs = [combine_products([:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]] else exprs = [:(zero(promote_op(matprod,Ta,Tb))) for k = 1:sa[1]] end @@ -266,18 +323,21 @@ end @inline _get_raw_data(A::SizedArray) = A.data @inline _get_raw_data(A::StaticArray) = A +# we need something heap-allocated to make sure BLAS calls are safe +@inline _get_raw_data(A::SArray) = MArray(A) -function mul_blas!(::TSize{<:Any,false}, c::StaticMatrix, ::TSize{<:Any,tA}, ::TSize{<:Any,tB}, - a::StaticMatrix, b::StaticMatrix, _add::MulAddMul) where {tA,tB} - mat_char(tA) = tA ? 'T' : 'N' +function mul_blas!(::TSize{<:Any,:any}, c::StaticMatrix, + Sa::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}}, Sb::Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}}, + a::StaticMatrix, b::StaticMatrix, _add::MulAddMul) + mat_char(s) = istranspose(s) ? 'T' : 'N' T = eltype(a) A = _get_raw_data(a) B = _get_raw_data(b) C = _get_raw_data(c) - BLAS.gemm!(mat_char(tA), mat_char(tB), T(alpha(_add)), A, B, T(beta(_add)), C) + BLAS.gemm!(mat_char(Sa), mat_char(Sb), T(alpha(_add)), A, B, T(beta(_add)), C) end # if C is transposed, transpose the entire expression -@inline mul_blas!(Sc::TSize{<:Any,true}, c::StaticMatrix, Sa::TSize, Sb::TSize, +@inline mul_blas!(Sc::TSize{<:Any,:transpose}, c::StaticMatrix, Sa::TSize, Sb::TSize, a::StaticMatrix, b::StaticMatrix, _add::MulAddMul) = mul_blas!(transpose(Sc), c, transpose(Sb), transpose(Sa), b, a, _add) diff --git a/src/triangular.jl b/src/triangular.jl index 499ca66b..49470ae4 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -6,373 +6,19 @@ LinearAlgebra.LowerTriangular(transpose(A.data)) @inline adjoint(A::LinearAlgebra.UpperTriangular{<:Any,<:StaticMatrix}) = LinearAlgebra.LowerTriangular(adjoint(A.data)) -@inline Base.:*(A::Adjoint{<:Any,<:StaticVecOrMat}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = +@inline Base.:*(A::Adjoint{<:Any,<:StaticVector}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = adjoint(adjoint(B) * adjoint(A)) -@inline Base.:*(A::Transpose{<:Any,<:StaticVecOrMat}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = +@inline Base.:*(A::Transpose{<:Any,<:StaticVector}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = transpose(transpose(B) * transpose(A)) -@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Adjoint{<:Any,<:StaticVecOrMat}) = +@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Adjoint{<:Any,<:StaticVector}) = adjoint(adjoint(B) * adjoint(A)) -@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVecOrMat}) = +@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVector}) = transpose(transpose(B) * transpose(A)) const StaticULT = Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}} -@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::StaticVecOrMat) = _A_mul_B(Size(A), Size(B), A, B) -@inline Base.:*(A::StaticVecOrMat, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) = _A_mul_B(Size(A), Size(B), A, B) -@inline Base.:*(A::StaticULT, B::StaticULT) = _A_mul_B(Size(A), Size(B), A, B) @inline Base.:\(A::StaticULT, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B) - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} - m = sb[1] - n = length(sb) > 1 ? sb[2] : 1 - if m != sa[1] - throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = 1:m - ex = :(A.data[$(LinearIndices(sa)[i, i])]*B[$(LinearIndices(sb)[i, j])]) - for k = i+1:m - ex = :($ex + A.data[$(LinearIndices(sa)[i, k])]*B[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(B, TAB)(tuple($(X...))) - end -end - -@generated function _Ac_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} - m = sb[1] - n = length(sb) > 1 ? sb[2] : 1 - if m != sa[1] - throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = m:-1:1 - ex = :(A.data[$(LinearIndices(sa)[i, i])]'*B[$(LinearIndices(sb)[i, j])]) - for k = 1:i-1 - ex = :($ex + A.data[$(LinearIndices(sa)[k, i])]'*B[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(B, TAB)(tuple($(X...))) - end -end - -@generated function _At_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} - m = sb[1] - n = length(sb) > 1 ? sb[2] : 1 - if m != sa[1] - throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = m:-1:1 - ex = :(transpose(A.data[$(LinearIndices(sa)[i, i])])*B[$(LinearIndices(sb)[i, j])]) - for k = 1:i-1 - ex = :($ex + transpose(A.data[$(LinearIndices(sa)[k, i])])*B[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(B, TAB)(tuple($(X...))) - end -end - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} - m = sb[1] - n = length(sb) > 1 ? sb[2] : 1 - if m != sa[1] - throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = m:-1:1 - ex = :(A.data[$(LinearIndices(sa)[i, i])]*B[$(LinearIndices(sb)[i, j])]) - for k = 1:i-1 - ex = :($ex + A.data[$(LinearIndices(sa)[i, k])]*B[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(B, TAB)(tuple($(X...))) - end -end - -@generated function _Ac_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} - m = sb[1] - n = length(sb) > 1 ? sb[2] : 1 - if m != sa[1] - throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = 1:m - ex = :(A.data[$(LinearIndices(sa)[i, i])]'*B[$(LinearIndices(sb)[i, j])]) - for k = i+1:m - ex = :($ex + A.data[$(LinearIndices(sa)[k, i])]'*B[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(B, TAB)(tuple($(X...))) - end -end - -@generated function _At_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} - m = sb[1] - n = length(sb) > 1 ? sb[2] : 1 - if m != sa[1] - throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = 1:m - ex = :(transpose(A.data[$(LinearIndices(sa)[i, i])])*B[$(LinearIndices(sb)[i, j])]) - for k = i+1:m - ex = :($ex + transpose(A.data[$(LinearIndices(sa)[k, i])])*B[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(B, TAB)(tuple($(X...))) - end -end - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::UpperTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB} - m = sa[1] - if length(sa) == 1 - n = 1 - else - n = sa[2] - end - if sb[1] != n - throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for i = 1:m - for j = n:-1:1 - ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]) - for k = 1:j-1 - ex = :($ex + A[$(LinearIndices(sa)[i, k])]*B.data[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(A, TAB, Size($m,$n))(tuple($(X...))) - end -end - -@generated function _A_mul_Bc(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::UpperTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB} - m = sa[1] - if length(sa) == 1 - n = 1 - else - n = sa[2] - end - if sb[1] != n - throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for i = 1:m - for j = 1:n - ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]') - for k = j+1:n - ex = :($ex + A[$(LinearIndices(sa)[i, k])]*B.data[$(LinearIndices(sb)[j, k])]') - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(A, TAB, Size($m, $n))(tuple($(X...))) - end -end - -@generated function _A_mul_Bt(::Size{sa}, ::Size{sb}, A::StaticMatrix{<:Any,<:Any,TA}, B::UpperTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB} - m, n = sa[1], sa[2] - if sb[1] != n - throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for i = 1:m - for j = 1:n - ex = :(A[$(LinearIndices(sa)[i, j])]*transpose(B[$(LinearIndices(sb)[j, j])])) - for k = j+1:n - ex = :($ex + A[$(LinearIndices(sa)[i, k])]*transpose(B.data[$(LinearIndices(sb)[j, k])])) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(A, TAB)(tuple($(X...))) - end -end - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::LowerTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB} - m = sa[1] - if length(sa) == 1 - n = 1 - else - n = sa[2] - end - if sb[1] != n - throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for i = 1:m - for j = 1:n - ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]) - for k = j+1:n - ex = :($ex + A[$(LinearIndices(sa)[i, k])]*B.data[$(LinearIndices(sb)[k, j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(A, TAB, Size($m,$n))(tuple($(X...))) - end -end - -@generated function _A_mul_Bc(::Size{sa}, ::Size{sb}, A::StaticArray{<:Tuple,TA}, B::LowerTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB} - m = sa[1] - if length(sa) == 1 - n = 1 - else - n = sa[2] - end - if sb[1] != n - throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for i = 1:m - for j = n:-1:1 - ex = :(A[$(LinearIndices(sa)[i, j])]*B[$(LinearIndices(sb)[j, j])]') - for k = 1:j-1 - ex = :($ex + A[$(LinearIndices(sa)[i, k])]*B.data[$(LinearIndices(sb)[j, k])]') - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(A, TAB, Size($m,$n))(tuple($(X...))) - end -end - -@generated function _A_mul_Bt(::Size{sa}, ::Size{sb}, A::StaticMatrix{<:Any,<:Any,TA}, B::LowerTriangular{TB,<:StaticMatrix}) where {sa,sb,TA,TB} - m, n = sa[1], sa[2] - if sb[1] != n - throw(DimensionMismatch("right hand side B needs first dimension of size $n, has size $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n] - - code = Expr(:block) - for i = 1:m - for j = n:-1:1 - ex = :(A[$(LinearIndices(sa)[i, j])]*transpose(B[$(LinearIndices(sb)[j, j])])) - for k = 1:j-1 - ex = :($ex + A[$(LinearIndices(sa)[i, k])]*transpose(B.data[$(LinearIndices(sb)[j, k])])) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(matprod, TA, TB) - return similar_type(A, TAB)(tuple($(X...))) - end -end - @generated function _A_ldiv_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB} m = sb[1] n = length(sb) > 1 ? sb[2] : 1 @@ -560,129 +206,3 @@ end @inbounds return similar_type(B, TAB)(tuple($(X...))) end end - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} - n = sa[1] - if n != sb[1] - throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] - - TAB = promote_op(*, eltype(TA), eltype(TB)) - z = zero(TAB) - - code = Expr(:block) - for j = 1:n - for i = 1:n - if i > j - push!(code.args, :($(X[i,j]) = $z)) - else - ex = :(A.data[$(LinearIndices(sa)[i,i])] * B.data[$(LinearIndices(sb)[i,j])]) - for k = i+1:j - ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - end - - return quote - @_inline_meta - @inbounds $code - return UpperTriangular(similar_type(B.data, $TAB)(tuple($(X...)))) - end - -end - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} - n = sa[1] - if n != sb[1] - throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] - - TAB = promote_op(*, eltype(TA), eltype(TB)) - z = zero(TAB) - - code = Expr(:block) - for j = 1:n - for i = 1:n - if i < j - push!(code.args, :($(X[i,j]) = $z)) - else - ex = :(A.data[$(LinearIndices(sa)[i,j])] * B.data[$(LinearIndices(sb)[j,j])]) - for k = j+1:i - ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - end - - return quote - @_inline_meta - @inbounds $code - return LowerTriangular(similar_type(B.data, $TAB)(tuple($(X...)))) - end - -end - - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::LowerTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} - n = sa[1] - if n != sb[1] - throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = 1:n - k1 = max(i,j) - ex = :(A.data[$(LinearIndices(sa)[i,k1])] * B.data[$(LinearIndices(sb)[k1,j])]) - for k = k1+1:n - ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(*, eltype(TA), eltype(TB)) - return similar_type(B.data, TAB)(tuple($(X...))) - end - -end - -@generated function _A_mul_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::UpperTriangular{<:TB,<:StaticMatrix}) where {sa,sb,TA,TB} - n = sa[1] - if n != sb[1] - throw(DimensionMismatch("left and right-hand must have same sizes, got $(n) and $(sb[1])")) - end - - X = [Symbol("X_$(i)_$(j)") for i = 1:n, j = 1:n] - - code = Expr(:block) - for j = 1:n - for i = 1:n - ex = :(A.data[$(LinearIndices(sa)[i,1])] * B.data[$(LinearIndices(sb)[1,j])]) - for k = 2:min(i,j) - ex = :($ex + A.data[$(LinearIndices(sa)[i,k])] * B.data[$(LinearIndices(sb)[k,j])]) - end - push!(code.args, :($(X[i,j]) = $ex)) - end - end - - return quote - @_inline_meta - @inbounds $code - TAB = promote_op(*, eltype(TA), eltype(TB)) - return similar_type(B.data, TAB)(tuple($(X...))) - end - -end diff --git a/test/matrix_multiply.jl b/test/matrix_multiply.jl index f3b6f83d..82d5f051 100644 --- a/test/matrix_multiply.jl +++ b/test/matrix_multiply.jl @@ -1,11 +1,28 @@ using StaticArrays, Test, LinearAlgebra +mul_wrappers = [ + m -> m, + m -> Symmetric(m, :U), + m -> Symmetric(m, :L), + m -> Hermitian(m, :U), + m -> Hermitian(m, :L), + m -> UpperTriangular(m), + m -> LowerTriangular(m), + m -> UnitUpperTriangular(m), + m -> UnitLowerTriangular(m), + m -> Adjoint(m), + m -> Transpose(m), + m -> Diagonal(m)] + @testset "Matrix multiplication" begin @testset "Matrix-vector" begin m = @SMatrix [1 2; 3 4] v = @SVector [1, 2] v_bad = @SVector [1, 2, 3] @test m*v === @SVector [5, 11] + for wrapper in mul_wrappers + @test (@inferred wrapper(m)*v)::SVector{2} == wrapper(Array(m))*Array(v) + end @test_throws DimensionMismatch m*v_bad # More complicated eltype inference v2 = @SVector [CartesianIndex((1,3)), CartesianIndex((3,1))] @@ -17,6 +34,13 @@ using StaticArrays, Test, LinearAlgebra bm = @SMatrix [m m; m m] bv = @SVector [v,v] @test (bm*bv)::SVector{2,SVector{2,Int}} == @SVector [[10,22],[10,22]] + for wrapper in mul_wrappers + # there may be some problems with inferring the result type of symmetric block matrices + # and for some reason setindex! in Julia LinearAlgebra has == 0 and == 1 tests + if !any(x->isa(wrapper(bm), x), [UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular]) + @test wrapper(bm)*bv == wrapper(Array(bm))*Array(bv) + end + end # inner product @test @inferred(v'*v) === 5 @@ -128,6 +152,36 @@ using StaticArrays, Test, LinearAlgebra @test m*transpose(n) === @SMatrix [8 14; 18 32] @test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28] + # check different sizes because there are multiple implementations for matrices of different sizes + for (mm, nn) in [ + (m, n), + (SMatrix{10, 10}(collect(1:100)), SMatrix{10, 10}(collect(1:100))), + (SMatrix{15, 15}(collect(1:225)), SMatrix{15, 15}(collect(1:225))) + ] + for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers + wm = wrapper_m(mm) + wn = wrapper_n(nn) + if length(mm) >= 255 && (!isa(wm, StaticArray) || !isa(wn, StaticArray)) + continue + end + res_structure = StaticArrays.mul_result_structure(wm, wn) + expected_type = if length(m) >= 100 + Matrix{Int} + elseif res_structure == identity || length(m) >= 100 + typeof(mm) + elseif res_structure == LowerTriangular + LowerTriangular{Int,typeof(mm)} + elseif res_structure == UpperTriangular + UpperTriangular{Int,typeof(mm)} + elseif res_structure == Diagonal + Diagonal{Int,<:SVector} + else + error("Unknown structure: ", res_structure) + end + @test (@inferred wm * wn)::expected_type == wrapper_m(Array(mm)) * wrapper_n(Array(nn)) + end + end + m = @MMatrix [1 2; 3 4] n = @MMatrix [2 3; 4 5] @test (m*n) == @MMatrix [10 13; 22 29] @@ -289,6 +343,10 @@ using StaticArrays, Test, LinearAlgebra @test a::MMatrix{2,2,Int,4} == @MMatrix [8 14; 18 32] mul!(a, transpose(m), transpose(n)) @test a::MMatrix{2,2,Int,4} == @MMatrix [11 19; 16 28] + for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers + mul!(a, wrapper_m(m), wrapper_n(n)) + @test a::MMatrix{2,2,Int,4} == wrapper_m(Array(m))*wrapper_n(Array(n)) + end a2 = MArray{Tuple{2,2},Int,2,4}(undef) mul!(a2, m, n) diff --git a/test/matrix_multiply_add.jl b/test/matrix_multiply_add.jl index c039a44c..413df9d4 100644 --- a/test/matrix_multiply_add.jl +++ b/test/matrix_multiply_add.jl @@ -5,11 +5,28 @@ using Test macro test_noalloc(ex) esc(quote - $ex - @test(@allocated($ex) == 0) + if VERSION < v"1.5" + $ex + @test(@allocated($ex) == 0) + end end) end +mul_add_wrappers = [ + m -> m, + m -> Symmetric(m, :U), + m -> Symmetric(m, :L), + m -> Hermitian(m, :U), + m -> Hermitian(m, :L), + m -> UpperTriangular(m), + m -> LowerTriangular(m), + m -> UnitUpperTriangular(m), + m -> UnitLowerTriangular(m), + m -> Adjoint(m), + m -> Transpose(m), + m -> Diagonal(m)] + + # check_dims @test StaticArrays.check_dims(Size(4,), Size(4,3), Size(3,)) @test !StaticArrays.check_dims(Size(4,), Size(4,3), Size(4,)) @@ -49,17 +66,17 @@ function test_multiply_add(N1,N2,ArrayType=MArray) # TSize ta = StaticArrays.TSize(A) - @test !StaticArrays.istranpose(ta) + @test !StaticArrays.istranspose(ta) @test size(ta) == (N1,N2) @test Size(ta) == Size(N1,N2) ta = StaticArrays.TSize(At) - @test StaticArrays.istranpose(ta) + @test StaticArrays.istranspose(ta) @test size(ta) == (N2,N1) @test Size(ta) == Size(N2,N1) tb = StaticArrays.TSize(b') - @test StaticArrays.istranpose(tb) + @test StaticArrays.access_type(tb) === :adjoint ta = transpose(ta) - @test !StaticArrays.istranpose(ta) + @test !StaticArrays.istranspose(ta) @test size(ta) == (N1,N2) @test Size(ta) == Size(N1,N2) @@ -89,14 +106,15 @@ function test_multiply_add(N1,N2,ArrayType=MArray) @test_broken(@allocated(mul!(c,A,b)) == 0) end end + expected_transpose_allocs = VERSION < v"1.5" ? 1 : 0 bmark = @benchmark mul!($c,$A,$b,$α,$β) samples=10 evals=10 @test minimum(bmark).allocs == 0 # @test_noalloc mul!(c, A, b, α, β) # records 32 bytes bmark = @benchmark mul!($b,Transpose($A),$c) samples=10 evals=10 - @test minimum(bmark).allocs == 0 + @test minimum(bmark).allocs <= expected_transpose_allocs # @test_noalloc mul!(b, Transpose(A), c) # records 16 bytes bmark = @benchmark mul!($b,Transpose($A),$c,$α,$β) samples=10 evals=10 - @test minimum(bmark).allocs == 0 + @test minimum(bmark).allocs <= expected_transpose_allocs # @test_noalloc mul!(b, Transpose(A), c, α, β) # records 48 bytes # outer product @@ -111,7 +129,7 @@ function test_multiply_add(N1,N2,ArrayType=MArray) @test C ≈ 3a*b' b = @benchmark mul!($C,$a,$b') samples=10 evals=10 - @test minimum(b).allocs == 0 + @test minimum(b).allocs <= expected_transpose_allocs # @test_noalloc mul!(C, a, b') # records 16 bytes # A × B @@ -139,7 +157,7 @@ function test_multiply_add(N1,N2,ArrayType=MArray) @test B ≈ 4A'C b = @benchmark mul!($B,Transpose($A),$C,$α,$β) samples=10 evals=10 - @test minimum(b).allocs == 0 + @test minimum(b).allocs <= expected_transpose_allocs # @test_noalloc mul!(B, Transpose(A), C, α, β) # records 48 bytes # A*B' @@ -152,7 +170,7 @@ function test_multiply_add(N1,N2,ArrayType=MArray) @test C ≈ 4A*B' b = @benchmark mul!($C,$A,Transpose($B),$α,$β) samples=10 evals=10 - @test minimum(b).allocs == 0 + @test minimum(b).allocs <= expected_transpose_allocs # @test_noalloc mul!(C, A, Transpose(B), α, β) # records 48 bytes # A'B' @@ -166,7 +184,7 @@ function test_multiply_add(N1,N2,ArrayType=MArray) @test C ≈ 4A'B' b = @benchmark mul!($C,Transpose($A),Transpose($B),$α,$β) samples=10 evals=10 - @test minimum(b).allocs == 0 + @test minimum(b).allocs <= 2*expected_transpose_allocs # @test_noalloc mul!(C, Transpose(A), Transpose(B), α, β) # records 64 bytes # Transpose Output @@ -174,7 +192,7 @@ function test_multiply_add(N1,N2,ArrayType=MArray) mul!(Transpose(C),Transpose(A),Transpose(B)) @test C' ≈ A'B' b = @benchmark mul!(Transpose($C),Transpose($A),Transpose($B),$α,$β) samples=10 evals=10 - @test minimum(b).allocs == 0 + @test minimum(b).allocs <= expected_transpose_allocs*3 # @test_noalloc mul!(Transpose(C), Transpose(A), Transpose(B), α, β) # records 80 bytes end @@ -187,3 +205,56 @@ end test_multiply_add(5,6,SizedArray) test_multiply_add(15,16,SizedArray) end + +function test_wrappers_for_size(N, test_block) + C = rand(MMatrix{N,N,Int}) + Cv = rand(MVector{N,Int}) + A = rand(SMatrix{N,N,Int}) + B = rand(SMatrix{N,N,Int}) + bv = rand(SVector{N,Int}) + # matrix-vector + for wrapper in mul_add_wrappers + mul!(Cv, wrapper(A), bv) + @test Cv == wrapper(Array(A))*Array(bv) + end + + # matrix-matrix + for wrapper_c in [identity, Transpose], wrapper_a in mul_add_wrappers, wrapper_b in mul_add_wrappers + mul!(wrapper_c(C), wrapper_a(A), wrapper_b(B)) + @test wrapper_c(C) == wrapper_a(Array(A))*wrapper_b(Array(B)) + end + + # block matrices + if test_block + + C_block = rand(MMatrix{N,N,SMatrix{2,2,Int,4}}) + Cv_block = rand(MVector{N,SMatrix{2,2,Int,4}}) + A_block = rand(SMatrix{N,N,SMatrix{2,2,Int,4}}) + B_block = rand(SMatrix{N,N,SMatrix{2,2,Int,4}}) + bv_block = rand(SVector{N,SMatrix{2,2,Int,4}}) + + # matrix-vector + for wrapper in mul_add_wrappers + # LinearAlgebra can't handle these + if all(T -> !isa(wrapper([1 2; 3 4]), T), [Symmetric, Hermitian, Diagonal]) + mul!(Cv_block, wrapper(A_block), bv_block) + @test Cv_block == wrapper(Array(A_block))*Array(bv_block) + end + end + + # matrix-matrix + for wrapper_a in mul_add_wrappers, wrapper_b in mul_add_wrappers + if all(T -> !isa(wrapper_a([1 2; 3 4]), T) && !isa(wrapper_b([1 2; 3 4]), T), [Symmetric, Hermitian, Diagonal]) + mul!(C_block, wrapper_a(A_block), wrapper_b(B_block)) + @test C_block == wrapper_a(Array(A_block))*wrapper_b(Array(B_block)) + end + end + end + +end + +@testset "Testing different wrappers" begin + test_wrappers_for_size(2, true) + test_wrappers_for_size(8, false) + test_wrappers_for_size(16, false) +end