diff --git a/src/array_interface.jl b/src/array_interface.jl index ffb7e93f..b1ffd5ec 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -59,7 +59,6 @@ function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat) return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)), getaxes(x)[2:end]...) end end -Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1])) Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...) Base.vcat(x::ComponentVector, args::Union{Number, UniformScaling, AbstractVecOrMat}...) = vcat(getdata(x), getdata.(args)...) Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...) diff --git a/src/broadcasting.jl b/src/broadcasting.jl index 7a59f3e6..1c21045b 100644 --- a/src/broadcasting.jl +++ b/src/broadcasting.jl @@ -1,9 +1,5 @@ Base.BroadcastStyle(::Type{<:ComponentArray{T, N, A, Axes}}) where {T, N, A, Axes} = Broadcast.BroadcastStyle(A) -# Need special case here for adjoint vectors in order to avoid type instability in axistype -Broadcast.combine_axes(a::ComponentArray, b::AdjOrTransComponentVector) = (axes(a)[1], axes(b)[2]) -Broadcast.combine_axes(a::AdjOrTransComponentVector, b::ComponentArray) = (axes(b)[2], axes(a)[1]) - Broadcast.axistype(a::CombinedAxis, b::AbstractUnitRange) = a Broadcast.axistype(a::AbstractUnitRange, b::CombinedAxis) = b Broadcast.axistype(a::CombinedAxis, b::CombinedAxis) = CombinedAxis(FlatAxis(), Base.Broadcast.axistype(_array_axis(a), _array_axis(b))) diff --git a/src/compat/gpuarrays.jl b/src/compat/gpuarrays.jl index ace99e6a..8fe201d8 100644 --- a/src/compat/gpuarrays.jl +++ b/src/compat/gpuarrays.jl @@ -1,6 +1,7 @@ -const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax} -const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax} -const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax} +const AbstractGPUArrayOrAdj = Union{<:GPUArrays.AbstractGPUArray{T, N}, Adjoint{T, <:GPUArrays.AbstractGPUArray{T, N}}, Transpose{T, <:GPUArrays.AbstractGPUArray{T, N}}} where {T, N} +const GPUComponentArray = ComponentArray{T,N,<:AbstractGPUArrayOrAdj{T, N},Ax} where {T,N,Ax} +const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:AbstractGPUArrayOrAdj{T, 1},Ax} +const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:AbstractGPUArrayOrAdj{T, 2},Ax} const GPUComponentVecorMat{T,Ax} = Union{GPUComponentVector{T,Ax},GPUComponentMatrix{T,Ax}} GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x)) @@ -25,7 +26,10 @@ end LinearAlgebra.dot(x::GPUComponentArray, y::GPUComponentArray) = dot(getdata(x), getdata(y)) LinearAlgebra.norm(ca::GPUComponentArray, p::Real) = norm(getdata(ca), p) -LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number) = GPUArrays.generic_rmul!(ca, b) +function LinearAlgebra.rmul!(ca::GPUComponentArray, b::Number) + GPUArrays.generic_rmul!(getdata(ca), b) + return ca +end function Base.map(f, x::GPUComponentArray, args...) data = map(f, getdata(x), getdata.(args)...) @@ -78,196 +82,23 @@ end function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::GPUComponentVecorMat, B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) + return GPUArrays.generic_matmatmul!(C, getdata(A), getdata(B), a, b) end function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, + A::AbstractGPUArrayOrAdj, B::GPUComponentVecorMat, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, B::GPUComponentVecorMat, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Number, b::Number) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) + return GPUArrays.generic_matmatmul!(C, A, getdata(B), a, b) end function LinearAlgebra.mul!(C::GPUComponentVecorMat, A::GPUComponentVecorMat, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::GPUComponentVecorMat, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::GPUComponentVecorMat, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, B::GPUComponentVecorMat, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, a::Real, - b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - B::LinearAlgebra.Adjoint{<:Any,<:GPUComponentVecorMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end -function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - B::LinearAlgebra.Transpose{<:Any,<:GPUArrays.AbstractGPUVecOrMat}, - a::Real, b::Real) - return GPUArrays.generic_matmatmul!(C, A, B, a, b) + B::AbstractGPUArrayOrAdj, a::Number, b::Number) + return GPUArrays.generic_matmatmul!(C, getdata(A), B, a, b) end + function LinearAlgebra.mul!(C::GPUComponentVecorMat, - A::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, - B::LinearAlgebra.Transpose{<:Any,<:GPUComponentVecorMat - }, a::Real, b::Real) + A::AbstractGPUArrayOrAdj, + B::AbstractGPUArrayOrAdj, a::Number, b::Number) return GPUArrays.generic_matmatmul!(C, A, B, a, b) -end +end \ No newline at end of file diff --git a/src/componentarray.jl b/src/componentarray.jl index 453dc7ee..1fc581ec 100644 --- a/src/componentarray.jl +++ b/src/componentarray.jl @@ -118,17 +118,11 @@ const CArray = ComponentArray const CVector = ComponentVector const CMatrix = ComponentMatrix -const AdjOrTrans{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} -const AdjOrTransComponentArray{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentArray -const AdjOrTransComponentVector{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentVector -const AdjOrTransComponentMatrix{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentMatrix - const ComponentVecOrMat = Union{ComponentVector, ComponentMatrix} -const AdjOrTransComponentVecOrMat = AdjOrTrans{T, <:ComponentVecOrMat} where T -const AbstractComponentArray = Union{ComponentArray, AdjOrTransComponentArray} -const AbstractComponentVecOrMat = Union{ComponentVecOrMat, AdjOrTransComponentVecOrMat} -const AbstractComponentVector = Union{ComponentVector, AdjOrTransComponentVector} -const AbstractComponentMatrix = Union{ComponentMatrix, AdjOrTransComponentMatrix} +const AbstractComponentArray = ComponentArray +const AbstractComponentVecOrMat = ComponentVecOrMat +const AbstractComponentVector = ComponentVector +const AbstractComponentMatrix = ComponentMatrix ## Constructor helpers @@ -288,12 +282,8 @@ julia> getaxes(ca) ``` """ @inline getaxes(x::ComponentArray) = getfield(x, :axes) -@inline getaxes(x::AdjOrTrans{T, <:ComponentVector}) where T = (FlatAxis(), getaxes(x.parent)[1]) -@inline getaxes(x::AdjOrTrans{T, <:ComponentMatrix}) where T = reverse(getaxes(x.parent)) @inline getaxes(::Type{<:ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = map(x->x(), (Axes.types...,)) -@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentVector} = (FlatAxis(), getaxes(CA)[1]) |> typeof -@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentMatrix} = reverse(getaxes(CA)) |> typeof ## Field access through these functions to reserve dot-getting for keys @inline getaxes(x::VarAxes) = getaxes(typeof(x)) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 5d888e7f..0fa527db 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -13,7 +13,6 @@ _first_axis(x::AbstractComponentVecOrMat) = getaxes(x)[1] _second_axis(x::AbstractMatrix) = FlatAxis() _second_axis(x::ComponentMatrix) = getaxes(x)[2] -_second_axis(x::AdjOrTransComponentVecOrMat) = getaxes(x)[2] _out_axes(::typeof(*), a, b::AbstractVector) = (_first_axis(a), ) _out_axes(::typeof(*), a, b::AbstractMatrix) = (_first_axis(a), _second_axis(b)) @@ -27,19 +26,21 @@ for op in [:*, :\, :/] function Base.$op(A::AbstractComponentVecOrMat, B::AbstractComponentVecOrMat) C = $op(getdata(A), getdata(B)) ax = _out_axes($op, A, B) - return ComponentArray(C, ax) + return ComponentArray(C, ax...) end end - for (adj, Adj) in zip([:adjoint, :transpose], [:Adjoint, :Transpose]) - @eval begin - function Base.$op(aᵀ::$Adj{T,<:ComponentVector}, B::AbstractComponentMatrix) where {T} - cᵀ = $op(getdata(aᵀ), getdata(B)) - ax2 = _out_axes($op, aᵀ, B)[2] - return $adj(ComponentArray(cᵀ', ax2)) - end - function Base.$op(A::$Adj{T,<:CV}, B::CV) where {T<:Real, CV<:ComponentVector{T}} - return $op(getdata(A), getdata(B)) - end +end + + +for op in [:adjoint, :transpose] + @eval begin + function LinearAlgebra.$op(M::ComponentMatrix{T,A,Tuple{Ax1,Ax2}}) where {T,A,Ax1,Ax2} + data = $op(getdata(M)) + return ComponentArray(data, (Ax2(), Ax1())[1:ndims(data)]...) + end + + function LinearAlgebra.$op(M::ComponentVector{T,A,Tuple{Ax1}}) where {T,A,Ax1} + return ComponentMatrix($op(getdata(M)), FlatAxis(), Ax1()) end end end \ No newline at end of file diff --git a/src/show.jl b/src/show.jl index 2a08a47f..dc5d1613 100644 --- a/src/show.jl +++ b/src/show.jl @@ -79,4 +79,4 @@ function Base.show(io::IO, ::MIME"text/plain", x::ComponentMatrix{T,A,Axes}) whe println(io, " with axes $(axs[1]) × $(axs[2])") Base.print_matrix(io, getdata(x)) return nothing -end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c1b58c49..b13dfe00 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,8 +28,8 @@ ca_composed = ComponentArray(a = 1, b = ca) ca2 = ComponentArray(nt2) -cmat = ComponentArray(a .* a', ax, ax) -cmat2 = ca2 .* ca2' +cmat = ComponentArray(a * a', ax, ax) +cmat2 = ca2 * ca2' caa = ComponentArray(a = ca, b = sq_mat) @@ -142,13 +142,13 @@ end @test hash(ca) != hash(getdata(ca)) @test hash(ca, zero(UInt)) != hash(getdata(ca), zero(UInt)) - ab = ComponentArray(a = 1, b = 2) - xy = ComponentArray(x = 1, y = 2) + ab = ComponentArray(a=1, b=2) + xy = ComponentArray(x=1, y=2) @test ab != xy @test hash(ab) != hash(xy) @test hash(ab, zero(UInt)) != hash(xy, zero(UInt)) - @test ab == LVector(a = 1, b = 2) + @test ab == LVector(a=1, b=2) # Issue #117 kw_fun(; a, b) = a // b @@ -369,11 +369,11 @@ end @testset "Broadcasting" begin temp = deepcopy(ca) @test eltype(Float32.(ca)) == Float32 - @test ca .* ca' == cmat + @test ca * ca' == cmat @test 1 .* (ca .+ ca) == ComponentArray(a .+ a, getaxes(ca)) @test typeof(ca .+ cmat) == typeof(cmat) - @test getaxes(false .* ca .* ca') == (ax, ax) - @test getaxes(false .* ca' .* ca) == (ax, ax) + @test getaxes(false .* ca * ca') == (ax, ax) + @test isa(ca' * ca, Float64) @test (vec(temp) .= vec(ca_Float32)) isa ComponentArray @test_broken getdata(ca_MVector .* ca_MVector) isa MArray @@ -393,8 +393,8 @@ end x1 = ComponentArray(a = [1.1, 2.1], b = [0.1]) x2 = ComponentArray(a = [1.1, 2.1], b = 0.1) x3 = ComponentArray(a = [1.1, 2.1], c = [0.1]) - xmat = x1 .* x2' - x1mat = x1 .* x1' + xmat = x1 * x2' + x1mat = x1 * x1' @test x1 + x2 isa Vector @test x1 + x3 isa Vector @test x2 + x3 isa Vector @@ -459,7 +459,7 @@ end @test ca * transpose(ca) == collect(cmat) @test ca * transpose(ca) == a * transpose(a) @test transpose(ca) * ca == transpose(a) * a - @test ca' * cmat == ComponentArray(a' * getdata(cmat), getaxes(ca)) + @test ca' * cmat == ComponentArray(a' * getdata(cmat), FlatAxis(), getaxes(ca)...) @test transpose(transpose(cmat)) == cmat @test transpose(transpose(ca)) == ca @test transpose(ca.c) * cmat[:c, :c] * ca.c isa Number