From 0d031c45fdff73ba7a6f8737c5537b1af90d0de0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 15 Jan 2025 10:37:30 -0500 Subject: [PATCH] Fix scalar contraction --- src/contract/allocate_output.jl | 55 +++++++++++-- src/contract/contract_matricize/contract.jl | 39 +++++++++ test/test_basics.jl | 91 +++++++++++++++++---- 3 files changed, 165 insertions(+), 20 deletions(-) diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 3d9efa3..610ff87 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -9,7 +9,7 @@ function output_axes( biperm1::BlockedPermutation{2}, a2::AbstractArray, biperm2::BlockedPermutation{2}, - α::Number=true, + α::Number=one(Bool), ) axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1) axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2) @@ -27,7 +27,7 @@ function output_axes( perm1::BlockedPermutation{1}, a2::AbstractArray, perm2::BlockedPermutation{1}, - α::Number=true, + α::Number=one(Bool), ) axes_contracted = blockpermute(axes(a1), perm1) axes_contracted′ = blockpermute(axes(a2), perm2) @@ -43,7 +43,7 @@ function output_axes( perm1::BlockedPermutation{1}, a2::AbstractArray, biperm2::BlockedPermutation{2}, - α::Number=true, + α::Number=one(Bool), ) (axes_contracted,) = blockpermute(axes(a1), perm1) axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2) @@ -59,7 +59,7 @@ function output_axes( perm1::BlockedPermutation{2}, a2::AbstractArray, biperm2::BlockedPermutation{1}, - α::Number=true, + α::Number=one(Bool), ) axes_dest, axes_contracted = blockpermute(axes(a1), perm1) (axes_contracted′,) = blockpermute(axes(a2), biperm2) @@ -75,7 +75,7 @@ function output_axes( perm1::BlockedPermutation{1}, a2::AbstractArray, perm2::BlockedPermutation{1}, - α::Number=true, + α::Number=one(Bool), ) @assert istrivialperm(Tuple(perm1)) @assert istrivialperm(Tuple(perm2)) @@ -83,6 +83,49 @@ function output_axes( return genperm(axes_dest, invperm(Tuple(biperm_dest))) end +# Array-scalar contraction. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{1}, + a1::AbstractArray, + perm1::BlockedPermutation{1}, + a2::AbstractArray, + perm2::BlockedPermutation{0}, + α::Number=one(Bool), +) + @assert istrivialperm(Tuple(perm1)) + axes_dest = axes(a1) + return genperm(axes_dest, invperm(Tuple(perm_dest))) +end + +# Scalar-array contraction. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{1}, + a1::AbstractArray, + perm1::BlockedPermutation{0}, + a2::AbstractArray, + perm2::BlockedPermutation{1}, + α::Number=one(Bool), +) + @assert istrivialperm(Tuple(perm2)) + axes_dest = axes(a2) + return genperm(axes_dest, invperm(Tuple(perm_dest))) +end + +# Scalar-scalar contraction. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{0}, + a1::AbstractArray, + perm1::BlockedPermutation{0}, + a2::AbstractArray, + perm2::BlockedPermutation{0}, + α::Number=one(Bool), +) + return () +end + # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( @@ -92,7 +135,7 @@ function allocate_output( biperm1::BlockedPermutation, a2::AbstractArray, biperm2::BlockedPermutation, - α::Number=true, + α::Number=one(Bool), ) axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α) return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest) diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 5fffd2b..1750ae2 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -62,3 +62,42 @@ function _mul!( mul!(a_dest, a1, transpose(a2), α, β) return a_dest end + +# Array-scalar contraction. +function _mul!( + a_dest::AbstractVector, + a1::AbstractVector, + a2::AbstractArray{<:Any,0}, + α::Number, + β::Number, +) + α′ = a2[] * α + a_dest .= a1 .* α′ .+ a_dest .* β + return a_dest +end + +# Scalar-array contraction. +function _mul!( + a_dest::AbstractVector, + a1::AbstractArray{<:Any,0}, + a2::AbstractVector, + α::Number, + β::Number, +) + # Preserve the ordering in case of non-commutative algebra. + a_dest .= a1[] .* a2 .* α .+ a_dest .* β + return a_dest +end + +# Scalar-scalar contraction. +function _mul!( + a_dest::AbstractArray{<:Any,0}, + a1::AbstractArray{<:Any,0}, + a2::AbstractArray{<:Any,0}, + α::Number, + β::Number, +) + # Preserve the ordering in case of non-commutative algebra. + a_dest[] = a1[] * a2[] * α + a_dest[] * β + return a_dest +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 09fb26a..dde9bcc 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,8 +1,11 @@ using EllipsisNotation: var".." using LinearAlgebra: norm, qr -using TensorAlgebra: TensorAlgebra, fusedims, splitdims -default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) +using StableRNGs: StableRNG +using TensorAlgebra: contract, contract!, fusedims, splitdims +using TensorOperations: TensorOperations using Test: @test, @test_broken, @testset + +default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin @@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) labels_dest = map(i -> labels[i], d_dests) # Don't specify destination labels - a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2) + a_dest, labels_dest′ = contract(a1, labels1, a2, labels2) a_dest_tensoroperations = TensorOperations.tensorcontract( labels_dest′, a1, labels1, a2, labels2 ) @test a_dest ≈ a_dest_tensoroperations # Specify destination labels - a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2) + a_dest = contract(labels_dest, a1, labels1, a2, labels2) a_dest_tensoroperations = TensorOperations.tensorcontract( labels_dest, a1, labels1, a2, labels2 ) @@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) β = elt_dest(2.4) # randn(elt_dest) a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) a_dest = copy(a_dest_init) - TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) + contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) a_dest_tensoroperations = TensorOperations.tensorcontract( labels_dest, a1, labels1, a2, labels2 ) @@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts - a1 = randn(elt1, 2, 3) - a2 = randn(elt2, 4, 5) - elt_dest = promote_type(elt1, elt2) - a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l")) + rng = StableRNG(123) + a1 = randn(rng, elt1, 2, 3) + a2 = randn(rng, elt2, 4, 5) + + a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l")) @test labels == ("i", "j", "k", "l") @test eltype(a_dest) === elt_dest @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) - a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) + a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) @test eltype(a_dest) === elt_dest @test a_dest ≈ permutedims( reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4) ) a_dest = zeros(elt_dest, 2, 5, 3, 4) - TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) + contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) @test a_dest ≈ permutedims( reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) ) end + @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, + elt2 in elts + + elt_dest = promote_type(elt1, elt2) + + rng = StableRNG(123) + a = randn(rng, elt1, (2, 3, 4, 5)) + s = randn(rng, elt2, ()) + t = randn(rng, elt2, ()) + + labels_a = ("i", "j", "k", "l") + + # Array-scalar contraction. + a_dest, labels_dest = contract(a, labels_a, s, ()) + @test labels_dest == labels_a + @test a_dest ≈ a * s[] + + # Scalar-array contraction. + a_dest, labels_dest = contract(s, (), a, labels_a) + @test labels_dest == labels_a + @test a_dest ≈ a * s[] + + # Scalar-scalar contraction. + a_dest, labels_dest = contract(s, (), t, ()) + @test labels_dest == () + @test a_dest[] ≈ s[] * t[] + + # Specify output labels. + labels_dest_example = ("j", "l", "i", "k") + size_dest_example = (3, 5, 2, 4) + + # Array-scalar contraction. + a_dest = contract(labels_dest_example, a, labels_a, s, ()) + @test size(a_dest) == size_dest_example + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-array contraction. + a_dest = contract(labels_dest_example, s, (), a, labels_a) + @test size(a_dest) == size_dest_example + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-scalar contraction. + a_dest = contract((), s, (), t, ()) + @test size(a_dest) == () + @test a_dest[] ≈ s[] * t[] + + # Array-scalar contraction. + a_dest = zeros(elt_dest, size_dest_example) + contract!(a_dest, labels_dest_example, a, labels_a, s, ()) + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-array contraction. + a_dest = zeros(elt_dest, size_dest_example) + contract!(a_dest, labels_dest_example, s, (), a, labels_a) + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-scalar contraction. + a_dest = zeros(elt_dest, ()) + contract!(a_dest, (), s, (), t, ()) + @test a_dest[] ≈ s[] * t[] + end end @testset "qr (eltype=$elt)" for elt in elts a = randn(elt, 5, 4, 3, 2) @@ -154,8 +219,6 @@ end labels_r = (:d, :c) q, r = qr(a, labels_a, labels_q, labels_r) label_qr = :qr - a′ = TensorAlgebra.contract( - labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...) - ) + a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)) @test a ≈ a′ end