Skip to content

Commit

Permalink
Fix scalar contraction
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jan 15, 2025
1 parent 840975a commit 0d031c4
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 20 deletions.
55 changes: 49 additions & 6 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function output_axes(
biperm1::BlockedPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
α::Number=true,
α::Number=one(Bool),
)
axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1)
axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2)
Expand All @@ -27,7 +27,7 @@ function output_axes(
perm1::BlockedPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{1},
α::Number=true,
α::Number=one(Bool),
)
axes_contracted = blockpermute(axes(a1), perm1)
axes_contracted′ = blockpermute(axes(a2), perm2)
Expand All @@ -43,7 +43,7 @@ function output_axes(
perm1::BlockedPermutation{1},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
α::Number=true,
α::Number=one(Bool),
)
(axes_contracted,) = blockpermute(axes(a1), perm1)
axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
Expand All @@ -59,7 +59,7 @@ function output_axes(
perm1::BlockedPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{1},
α::Number=true,
α::Number=one(Bool),
)
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
(axes_contracted′,) = blockpermute(axes(a2), biperm2)
Expand All @@ -75,14 +75,57 @@ function output_axes(
perm1::BlockedPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{1},
α::Number=true,
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm1))
@assert istrivialperm(Tuple(perm2))
axes_dest = (axes(a1)..., axes(a2)...)
return genperm(axes_dest, invperm(Tuple(biperm_dest)))
end

# Array-scalar contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{0},
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm1))
axes_dest = axes(a1)
return genperm(axes_dest, invperm(Tuple(perm_dest)))
end

# Scalar-array contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{0},
a2::AbstractArray,
perm2::BlockedPermutation{1},
α::Number=one(Bool),
)
@assert istrivialperm(Tuple(perm2))
axes_dest = axes(a2)
return genperm(axes_dest, invperm(Tuple(perm_dest)))
end

# Scalar-scalar contraction.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{0},
a1::AbstractArray,
perm1::BlockedPermutation{0},
a2::AbstractArray,
perm2::BlockedPermutation{0},
α::Number=one(Bool),
)
return ()
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function allocate_output(
Expand All @@ -92,7 +135,7 @@ function allocate_output(
biperm1::BlockedPermutation,
a2::AbstractArray,
biperm2::BlockedPermutation,
α::Number=true,
α::Number=one(Bool),
)
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
Expand Down
39 changes: 39 additions & 0 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,42 @@ function _mul!(
mul!(a_dest, a1, transpose(a2), α, β)
return a_dest
end

# Array-scalar contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractVector,
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
α′ = a2[] * α
a_dest .= a1 .* α′ .+ a_dest .* β
return a_dest
end

# Scalar-array contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractArray{<:Any,0},
a2::AbstractVector,
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest .= a1[] .* a2 .* α .+ a_dest .* β
return a_dest
end

# Scalar-scalar contraction.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractArray{<:Any,0},
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest[] = a1[] * a2[] * α + a_dest[] * β
return a_dest
end
91 changes: 77 additions & 14 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using EllipsisNotation: var".."
using LinearAlgebra: norm, qr
using TensorAlgebra: TensorAlgebra, fusedims, splitdims
default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
using StableRNGs: StableRNG
using TensorAlgebra: contract, contract!, fusedims, splitdims
using TensorOperations: TensorOperations
using Test: @test, @test_broken, @testset

default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@testset "TensorAlgebra" begin
Expand Down Expand Up @@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
labels_dest = map(i -> labels[i], d_dests)

# Don't specify destination labels
a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest′, a1, labels1, a2, labels2
)
@test a_dest a_dest_tensoroperations

# Specify destination labels
a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
a_dest = contract(labels_dest, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
Expand All @@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
β = elt_dest(2.4) # randn(elt_dest)
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
a_dest = copy(a_dest_init)
TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
Expand All @@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
elt2 in elts

a1 = randn(elt1, 2, 3)
a2 = randn(elt2, 4, 5)

elt_dest = promote_type(elt1, elt2)

a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l"))
rng = StableRNG(123)
a1 = randn(rng, elt1, 2, 3)
a2 = randn(rng, elt2, 4, 5)

a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l"))
@test labels == ("i", "j", "k", "l")
@test eltype(a_dest) === elt_dest
@test a_dest reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))

a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
@test eltype(a_dest) === elt_dest
@test a_dest permutedims(
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4)
)

a_dest = zeros(elt_dest, 2, 5, 3, 4)
TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
@test a_dest permutedims(
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3)
)
end
@testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
elt2 in elts

elt_dest = promote_type(elt1, elt2)

rng = StableRNG(123)
a = randn(rng, elt1, (2, 3, 4, 5))
s = randn(rng, elt2, ())
t = randn(rng, elt2, ())

labels_a = ("i", "j", "k", "l")

# Array-scalar contraction.
a_dest, labels_dest = contract(a, labels_a, s, ())
@test labels_dest == labels_a
@test a_dest a * s[]

# Scalar-array contraction.
a_dest, labels_dest = contract(s, (), a, labels_a)
@test labels_dest == labels_a
@test a_dest a * s[]

# Scalar-scalar contraction.
a_dest, labels_dest = contract(s, (), t, ())
@test labels_dest == ()
@test a_dest[] s[] * t[]

# Specify output labels.
labels_dest_example = ("j", "l", "i", "k")
size_dest_example = (3, 5, 2, 4)

# Array-scalar contraction.
a_dest = contract(labels_dest_example, a, labels_a, s, ())
@test size(a_dest) == size_dest_example
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-array contraction.
a_dest = contract(labels_dest_example, s, (), a, labels_a)
@test size(a_dest) == size_dest_example
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-scalar contraction.
a_dest = contract((), s, (), t, ())
@test size(a_dest) == ()
@test a_dest[] s[] * t[]

# Array-scalar contraction.
a_dest = zeros(elt_dest, size_dest_example)
contract!(a_dest, labels_dest_example, a, labels_a, s, ())
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-array contraction.
a_dest = zeros(elt_dest, size_dest_example)
contract!(a_dest, labels_dest_example, s, (), a, labels_a)
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]

# Scalar-scalar contraction.
a_dest = zeros(elt_dest, ())
contract!(a_dest, (), s, (), t, ())
@test a_dest[] s[] * t[]
end
end
@testset "qr (eltype=$elt)" for elt in elts
a = randn(elt, 5, 4, 3, 2)
Expand All @@ -154,8 +219,6 @@ end
labels_r = (:d, :c)
q, r = qr(a, labels_a, labels_q, labels_r)
label_qr = :qr
a′ = TensorAlgebra.contract(
labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)
)
a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
@test a a′
end

0 comments on commit 0d031c4

Please sign in to comment.