From 0d031c45fdff73ba7a6f8737c5537b1af90d0de0 Mon Sep 17 00:00:00 2001 From: mtfishman <mfishman@flatironinstitute.org> Date: Wed, 15 Jan 2025 10:37:30 -0500 Subject: [PATCH 1/3] Fix scalar contraction --- src/contract/allocate_output.jl | 55 +++++++++++-- src/contract/contract_matricize/contract.jl | 39 +++++++++ test/test_basics.jl | 91 +++++++++++++++++---- 3 files changed, 165 insertions(+), 20 deletions(-) diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl index 3d9efa3..610ff87 100644 --- a/src/contract/allocate_output.jl +++ b/src/contract/allocate_output.jl @@ -9,7 +9,7 @@ function output_axes( biperm1::BlockedPermutation{2}, a2::AbstractArray, biperm2::BlockedPermutation{2}, - α::Number=true, + α::Number=one(Bool), ) axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1) axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2) @@ -27,7 +27,7 @@ function output_axes( perm1::BlockedPermutation{1}, a2::AbstractArray, perm2::BlockedPermutation{1}, - α::Number=true, + α::Number=one(Bool), ) axes_contracted = blockpermute(axes(a1), perm1) axes_contracted′ = blockpermute(axes(a2), perm2) @@ -43,7 +43,7 @@ function output_axes( perm1::BlockedPermutation{1}, a2::AbstractArray, biperm2::BlockedPermutation{2}, - α::Number=true, + α::Number=one(Bool), ) (axes_contracted,) = blockpermute(axes(a1), perm1) axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2) @@ -59,7 +59,7 @@ function output_axes( perm1::BlockedPermutation{2}, a2::AbstractArray, biperm2::BlockedPermutation{1}, - α::Number=true, + α::Number=one(Bool), ) axes_dest, axes_contracted = blockpermute(axes(a1), perm1) (axes_contracted′,) = blockpermute(axes(a2), biperm2) @@ -75,7 +75,7 @@ function output_axes( perm1::BlockedPermutation{1}, a2::AbstractArray, perm2::BlockedPermutation{1}, - α::Number=true, + α::Number=one(Bool), ) @assert istrivialperm(Tuple(perm1)) @assert istrivialperm(Tuple(perm2)) @@ -83,6 +83,49 @@ function output_axes( return genperm(axes_dest, invperm(Tuple(biperm_dest))) end +# Array-scalar contraction. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{1}, + a1::AbstractArray, + perm1::BlockedPermutation{1}, + a2::AbstractArray, + perm2::BlockedPermutation{0}, + α::Number=one(Bool), +) + @assert istrivialperm(Tuple(perm1)) + axes_dest = axes(a1) + return genperm(axes_dest, invperm(Tuple(perm_dest))) +end + +# Scalar-array contraction. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{1}, + a1::AbstractArray, + perm1::BlockedPermutation{0}, + a2::AbstractArray, + perm2::BlockedPermutation{1}, + α::Number=one(Bool), +) + @assert istrivialperm(Tuple(perm2)) + axes_dest = axes(a2) + return genperm(axes_dest, invperm(Tuple(perm_dest))) +end + +# Scalar-scalar contraction. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{0}, + a1::AbstractArray, + perm1::BlockedPermutation{0}, + a2::AbstractArray, + perm2::BlockedPermutation{0}, + α::Number=one(Bool), +) + return () +end + # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( @@ -92,7 +135,7 @@ function allocate_output( biperm1::BlockedPermutation, a2::AbstractArray, biperm2::BlockedPermutation, - α::Number=true, + α::Number=one(Bool), ) axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α) return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest) diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl index 5fffd2b..1750ae2 100644 --- a/src/contract/contract_matricize/contract.jl +++ b/src/contract/contract_matricize/contract.jl @@ -62,3 +62,42 @@ function _mul!( mul!(a_dest, a1, transpose(a2), α, β) return a_dest end + +# Array-scalar contraction. +function _mul!( + a_dest::AbstractVector, + a1::AbstractVector, + a2::AbstractArray{<:Any,0}, + α::Number, + β::Number, +) + α′ = a2[] * α + a_dest .= a1 .* α′ .+ a_dest .* β + return a_dest +end + +# Scalar-array contraction. +function _mul!( + a_dest::AbstractVector, + a1::AbstractArray{<:Any,0}, + a2::AbstractVector, + α::Number, + β::Number, +) + # Preserve the ordering in case of non-commutative algebra. + a_dest .= a1[] .* a2 .* α .+ a_dest .* β + return a_dest +end + +# Scalar-scalar contraction. +function _mul!( + a_dest::AbstractArray{<:Any,0}, + a1::AbstractArray{<:Any,0}, + a2::AbstractArray{<:Any,0}, + α::Number, + β::Number, +) + # Preserve the ordering in case of non-commutative algebra. + a_dest[] = a1[] * a2[] * α + a_dest[] * β + return a_dest +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 09fb26a..dde9bcc 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,8 +1,11 @@ using EllipsisNotation: var".." using LinearAlgebra: norm, qr -using TensorAlgebra: TensorAlgebra, fusedims, splitdims -default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) +using StableRNGs: StableRNG +using TensorAlgebra: contract, contract!, fusedims, splitdims +using TensorOperations: TensorOperations using Test: @test, @test_broken, @testset + +default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin @@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) labels_dest = map(i -> labels[i], d_dests) # Don't specify destination labels - a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2) + a_dest, labels_dest′ = contract(a1, labels1, a2, labels2) a_dest_tensoroperations = TensorOperations.tensorcontract( labels_dest′, a1, labels1, a2, labels2 ) @test a_dest ≈ a_dest_tensoroperations # Specify destination labels - a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2) + a_dest = contract(labels_dest, a1, labels1, a2, labels2) a_dest_tensoroperations = TensorOperations.tensorcontract( labels_dest, a1, labels1, a2, labels2 ) @@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) β = elt_dest(2.4) # randn(elt_dest) a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) a_dest = copy(a_dest_init) - TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) + contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) a_dest_tensoroperations = TensorOperations.tensorcontract( labels_dest, a1, labels1, a2, labels2 ) @@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts - a1 = randn(elt1, 2, 3) - a2 = randn(elt2, 4, 5) - elt_dest = promote_type(elt1, elt2) - a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l")) + rng = StableRNG(123) + a1 = randn(rng, elt1, 2, 3) + a2 = randn(rng, elt2, 4, 5) + + a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l")) @test labels == ("i", "j", "k", "l") @test eltype(a_dest) === elt_dest @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)) - a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) + a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l")) @test eltype(a_dest) === elt_dest @test a_dest ≈ permutedims( reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4) ) a_dest = zeros(elt_dest, 2, 5, 3, 4) - TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) + contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l")) @test a_dest ≈ permutedims( reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3) ) end + @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, + elt2 in elts + + elt_dest = promote_type(elt1, elt2) + + rng = StableRNG(123) + a = randn(rng, elt1, (2, 3, 4, 5)) + s = randn(rng, elt2, ()) + t = randn(rng, elt2, ()) + + labels_a = ("i", "j", "k", "l") + + # Array-scalar contraction. + a_dest, labels_dest = contract(a, labels_a, s, ()) + @test labels_dest == labels_a + @test a_dest ≈ a * s[] + + # Scalar-array contraction. + a_dest, labels_dest = contract(s, (), a, labels_a) + @test labels_dest == labels_a + @test a_dest ≈ a * s[] + + # Scalar-scalar contraction. + a_dest, labels_dest = contract(s, (), t, ()) + @test labels_dest == () + @test a_dest[] ≈ s[] * t[] + + # Specify output labels. + labels_dest_example = ("j", "l", "i", "k") + size_dest_example = (3, 5, 2, 4) + + # Array-scalar contraction. + a_dest = contract(labels_dest_example, a, labels_a, s, ()) + @test size(a_dest) == size_dest_example + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-array contraction. + a_dest = contract(labels_dest_example, s, (), a, labels_a) + @test size(a_dest) == size_dest_example + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-scalar contraction. + a_dest = contract((), s, (), t, ()) + @test size(a_dest) == () + @test a_dest[] ≈ s[] * t[] + + # Array-scalar contraction. + a_dest = zeros(elt_dest, size_dest_example) + contract!(a_dest, labels_dest_example, a, labels_a, s, ()) + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-array contraction. + a_dest = zeros(elt_dest, size_dest_example) + contract!(a_dest, labels_dest_example, s, (), a, labels_a) + @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[] + + # Scalar-scalar contraction. + a_dest = zeros(elt_dest, ()) + contract!(a_dest, (), s, (), t, ()) + @test a_dest[] ≈ s[] * t[] + end end @testset "qr (eltype=$elt)" for elt in elts a = randn(elt, 5, 4, 3, 2) @@ -154,8 +219,6 @@ end labels_r = (:d, :c) q, r = qr(a, labels_a, labels_q, labels_r) label_qr = :qr - a′ = TensorAlgebra.contract( - labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...) - ) + a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)) @test a ≈ a′ end From b996497d2ba05083e3ed5bd04f390519616d1cf4 Mon Sep 17 00:00:00 2001 From: mtfishman <mfishman@flatironinstitute.org> Date: Wed, 15 Jan 2025 10:38:58 -0500 Subject: [PATCH 2/3] Add missing test dep --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index 60a147b..09a07c1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" From 0061180f7f5373c075573a8ef07c97c5f23d1701 Mon Sep 17 00:00:00 2001 From: mtfishman <mfishman@flatironinstitute.org> Date: Wed, 15 Jan 2025 11:48:10 -0500 Subject: [PATCH 3/3] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 19abea2..cf45b57 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers <support@itensor.org> and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"