Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix scalar contraction #18

Merged
merged 3 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.4"
version = "0.1.5"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
55 changes: 49 additions & 6 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function output_axes(
biperm1::BlockedPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
α::Number=true,
α::Number=one(Bool),
)
axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1)
axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2)
Expand All @@ -27,7 +27,7 @@ function output_axes(
perm1::BlockedPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{1},
α::Number=true,
α::Number=one(Bool),
)
axes_contracted = blockpermute(axes(a1), perm1)
axes_contracted′ = blockpermute(axes(a2), perm2)
Expand All @@ -43,7 +43,7 @@ function output_axes(
perm1::BlockedPermutation{1},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
α::Number=true,
α::Number=one(Bool),
)
(axes_contracted,) = blockpermute(axes(a1), perm1)
axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
Expand All @@ -59,7 +59,7 @@ function output_axes(
perm1::BlockedPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{1},
α::Number=true,
α::Number=one(Bool),
)
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
(axes_contracted′,) = blockpermute(axes(a2), biperm2)
Expand All @@ -75,14 +75,57 @@ function output_axes(
perm1::BlockedPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{1},
α::Number=true,
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm1))
@assert istrivialperm(Tuple(perm2))
axes_dest = (axes(a1)..., axes(a2)...)
return genperm(axes_dest, invperm(Tuple(biperm_dest)))
end

# Array-scalar contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{0},
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm1))
axes_dest = axes(a1)
return genperm(axes_dest, invperm(Tuple(perm_dest)))
end

# Scalar-array contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{0},
a2::AbstractArray,
perm2::BlockedPermutation{1},
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm2))
axes_dest = axes(a2)
return genperm(axes_dest, invperm(Tuple(perm_dest)))
end

# Scalar-scalar contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{0},
a1::AbstractArray,
perm1::BlockedPermutation{0},
a2::AbstractArray,
perm2::BlockedPermutation{0},
α::Number=one(Bool),
)
return ()
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function allocate_output(
Expand All @@ -92,7 +135,7 @@ function allocate_output(
biperm1::BlockedPermutation,
a2::AbstractArray,
biperm2::BlockedPermutation,
α::Number=true,
α::Number=one(Bool),
)
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
Expand Down
39 changes: 39 additions & 0 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,42 @@ function _mul!(
mul!(a_dest, a1, transpose(a2), α, β)
return a_dest
end

# Array-scalar contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractVector,
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
α′ = a2[] * α
a_dest .= a1 .* α′ .+ a_dest .* β
return a_dest
end

# Scalar-array contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractArray{<:Any,0},
a2::AbstractVector,
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest .= a1[] .* a2 .* α .+ a_dest .* β
return a_dest
end

# Scalar-scalar contraction.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractArray{<:Any,0},
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest[] = a1[] * a2[] * α + a_dest[] * β
return a_dest
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Expand Down
91 changes: 77 additions & 14 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using EllipsisNotation: var".."
using LinearAlgebra: norm, qr
using TensorAlgebra: TensorAlgebra, fusedims, splitdims
default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
using StableRNGs: StableRNG
using TensorAlgebra: contract, contract!, fusedims, splitdims
using TensorOperations: TensorOperations
using Test: @test, @test_broken, @testset

default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@testset "TensorAlgebra" begin
Expand Down Expand Up @@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
labels_dest = map(i -> labels[i], d_dests)

# Don't specify destination labels
a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest′, a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

# Specify destination labels
a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
a_dest = contract(labels_dest, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
Expand All @@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
β = elt_dest(2.4) # randn(elt_dest)
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
a_dest = copy(a_dest_init)
TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
Expand All @@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
elt2 in elts

a1 = randn(elt1, 2, 3)
a2 = randn(elt2, 4, 5)

elt_dest = promote_type(elt1, elt2)

a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l"))
rng = StableRNG(123)
a1 = randn(rng, elt1, 2, 3)
a2 = randn(rng, elt2, 4, 5)

a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l"))
@test labels == ("i", "j", "k", "l")
@test eltype(a_dest) === elt_dest
@test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))

a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
@test eltype(a_dest) === elt_dest
@test a_dest ≈ permutedims(
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4)
)

a_dest = zeros(elt_dest, 2, 5, 3, 4)
TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
@test a_dest ≈ permutedims(
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3)
)
end
@testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
elt2 in elts

elt_dest = promote_type(elt1, elt2)

rng = StableRNG(123)
a = randn(rng, elt1, (2, 3, 4, 5))
s = randn(rng, elt2, ())
t = randn(rng, elt2, ())

labels_a = ("i", "j", "k", "l")

# Array-scalar contraction.
a_dest, labels_dest = contract(a, labels_a, s, ())
@test labels_dest == labels_a
@test a_dest ≈ a * s[]

# Scalar-array contraction.
a_dest, labels_dest = contract(s, (), a, labels_a)
@test labels_dest == labels_a
@test a_dest ≈ a * s[]

# Scalar-scalar contraction.
a_dest, labels_dest = contract(s, (), t, ())
@test labels_dest == ()
@test a_dest[] ≈ s[] * t[]

# Specify output labels.
labels_dest_example = ("j", "l", "i", "k")
size_dest_example = (3, 5, 2, 4)

# Array-scalar contraction.
a_dest = contract(labels_dest_example, a, labels_a, s, ())
@test size(a_dest) == size_dest_example
@test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-array contraction.
a_dest = contract(labels_dest_example, s, (), a, labels_a)
@test size(a_dest) == size_dest_example
@test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-scalar contraction.
a_dest = contract((), s, (), t, ())
@test size(a_dest) == ()
@test a_dest[] ≈ s[] * t[]

# Array-scalar contraction.
a_dest = zeros(elt_dest, size_dest_example)
contract!(a_dest, labels_dest_example, a, labels_a, s, ())
@test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-array contraction.
a_dest = zeros(elt_dest, size_dest_example)
contract!(a_dest, labels_dest_example, s, (), a, labels_a)
@test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-scalar contraction.
a_dest = zeros(elt_dest, ())
contract!(a_dest, (), s, (), t, ())
@test a_dest[] ≈ s[] * t[]
end
end
@testset "qr (eltype=$elt)" for elt in elts
a = randn(elt, 5, 4, 3, 2)
Expand All @@ -154,8 +219,6 @@ end
labels_r = (:d, :c)
q, r = qr(a, labels_a, labels_q, labels_r)
label_qr = :qr
a′ = TensorAlgebra.contract(
labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)
)
a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
@test a ≈ a′
end
Loading