Skip to content

Commit 11f8157

Browse files
authored
Fix scalar contraction (#18)
1 parent 840975a commit 11f8157

File tree

5 files changed

+167
-21
lines changed

5 files changed

+167
-21
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.4"
4+
version = "0.1.5"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/contract/allocate_output.jl

+49-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function output_axes(
99
biperm1::BlockedPermutation{2},
1010
a2::AbstractArray,
1111
biperm2::BlockedPermutation{2},
12-
α::Number=true,
12+
α::Number=one(Bool),
1313
)
1414
axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1)
1515
axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2)
@@ -27,7 +27,7 @@ function output_axes(
2727
perm1::BlockedPermutation{1},
2828
a2::AbstractArray,
2929
perm2::BlockedPermutation{1},
30-
α::Number=true,
30+
α::Number=one(Bool),
3131
)
3232
axes_contracted = blockpermute(axes(a1), perm1)
3333
axes_contracted′ = blockpermute(axes(a2), perm2)
@@ -43,7 +43,7 @@ function output_axes(
4343
perm1::BlockedPermutation{1},
4444
a2::AbstractArray,
4545
biperm2::BlockedPermutation{2},
46-
α::Number=true,
46+
α::Number=one(Bool),
4747
)
4848
(axes_contracted,) = blockpermute(axes(a1), perm1)
4949
axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
@@ -59,7 +59,7 @@ function output_axes(
5959
perm1::BlockedPermutation{2},
6060
a2::AbstractArray,
6161
biperm2::BlockedPermutation{1},
62-
α::Number=true,
62+
α::Number=one(Bool),
6363
)
6464
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
6565
(axes_contracted′,) = blockpermute(axes(a2), biperm2)
@@ -75,14 +75,57 @@ function output_axes(
7575
perm1::BlockedPermutation{1},
7676
a2::AbstractArray,
7777
perm2::BlockedPermutation{1},
78-
α::Number=true,
78+
α::Number=one(Bool),
7979
)
8080
@assert istrivialperm(Tuple(perm1))
8181
@assert istrivialperm(Tuple(perm2))
8282
axes_dest = (axes(a1)..., axes(a2)...)
8383
return genperm(axes_dest, invperm(Tuple(biperm_dest)))
8484
end
8585

86+
# Array-scalar contraction.
87+
function output_axes(
88+
::typeof(contract),
89+
perm_dest::BlockedPermutation{1},
90+
a1::AbstractArray,
91+
perm1::BlockedPermutation{1},
92+
a2::AbstractArray,
93+
perm2::BlockedPermutation{0},
94+
α::Number=one(Bool),
95+
)
96+
@assert istrivialperm(Tuple(perm1))
97+
axes_dest = axes(a1)
98+
return genperm(axes_dest, invperm(Tuple(perm_dest)))
99+
end
100+
101+
# Scalar-array contraction.
102+
function output_axes(
103+
::typeof(contract),
104+
perm_dest::BlockedPermutation{1},
105+
a1::AbstractArray,
106+
perm1::BlockedPermutation{0},
107+
a2::AbstractArray,
108+
perm2::BlockedPermutation{1},
109+
α::Number=one(Bool),
110+
)
111+
@assert istrivialperm(Tuple(perm2))
112+
axes_dest = axes(a2)
113+
return genperm(axes_dest, invperm(Tuple(perm_dest)))
114+
end
115+
116+
# Scalar-scalar contraction.
117+
function output_axes(
118+
::typeof(contract),
119+
perm_dest::BlockedPermutation{0},
120+
a1::AbstractArray,
121+
perm1::BlockedPermutation{0},
122+
a2::AbstractArray,
123+
perm2::BlockedPermutation{0},
124+
α::Number=one(Bool),
125+
)
126+
return ()
127+
end
128+
86129
# TODO: Use `ArrayLayouts`-like `MulAdd` object,
87130
# i.e. `ContractAdd`?
88131
function allocate_output(
@@ -92,7 +135,7 @@ function allocate_output(
92135
biperm1::BlockedPermutation,
93136
a2::AbstractArray,
94137
biperm2::BlockedPermutation,
95-
α::Number=true,
138+
α::Number=one(Bool),
96139
)
97140
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
98141
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)

src/contract/contract_matricize/contract.jl

+39
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,42 @@ function _mul!(
6262
mul!(a_dest, a1, transpose(a2), α, β)
6363
return a_dest
6464
end
65+
66+
# Array-scalar contraction.
67+
function _mul!(
68+
a_dest::AbstractVector,
69+
a1::AbstractVector,
70+
a2::AbstractArray{<:Any,0},
71+
α::Number,
72+
β::Number,
73+
)
74+
α′ = a2[] * α
75+
a_dest .= a1 .* α′ .+ a_dest .* β
76+
return a_dest
77+
end
78+
79+
# Scalar-array contraction.
80+
function _mul!(
81+
a_dest::AbstractVector,
82+
a1::AbstractArray{<:Any,0},
83+
a2::AbstractVector,
84+
α::Number,
85+
β::Number,
86+
)
87+
# Preserve the ordering in case of non-commutative algebra.
88+
a_dest .= a1[] .* a2 .* α .+ a_dest .* β
89+
return a_dest
90+
end
91+
92+
# Scalar-scalar contraction.
93+
function _mul!(
94+
a_dest::AbstractArray{<:Any,0},
95+
a1::AbstractArray{<:Any,0},
96+
a2::AbstractArray{<:Any,0},
97+
α::Number,
98+
β::Number,
99+
)
100+
# Preserve the ordering in case of non-commutative algebra.
101+
a_dest[] = a1[] * a2[] * α + a_dest[] * β
102+
return a_dest
103+
end

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1010
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
11+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1112
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1213
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1314
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

test/test_basics.jl

+77-14
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
using EllipsisNotation: var".."
22
using LinearAlgebra: norm, qr
3-
using TensorAlgebra: TensorAlgebra, fusedims, splitdims
4-
default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
3+
using StableRNGs: StableRNG
4+
using TensorAlgebra: contract, contract!, fusedims, splitdims
5+
using TensorOperations: TensorOperations
56
using Test: @test, @test_broken, @testset
7+
8+
default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
69
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
710

811
@testset "TensorAlgebra" begin
@@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
9093
labels_dest = map(i -> labels[i], d_dests)
9194

9295
# Don't specify destination labels
93-
a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
96+
a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
9497
a_dest_tensoroperations = TensorOperations.tensorcontract(
9598
labels_dest′, a1, labels1, a2, labels2
9699
)
97100
@test a_dest a_dest_tensoroperations
98101

99102
# Specify destination labels
100-
a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
103+
a_dest = contract(labels_dest, a1, labels1, a2, labels2)
101104
a_dest_tensoroperations = TensorOperations.tensorcontract(
102105
labels_dest, a1, labels1, a2, labels2
103106
)
@@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
111114
β = elt_dest(2.4) # randn(elt_dest)
112115
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
113116
a_dest = copy(a_dest_init)
114-
TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
117+
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
115118
a_dest_tensoroperations = TensorOperations.tensorcontract(
116119
labels_dest, a1, labels1, a2, labels2
117120
)
@@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
124127
@testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
125128
elt2 in elts
126129

127-
a1 = randn(elt1, 2, 3)
128-
a2 = randn(elt2, 4, 5)
129-
130130
elt_dest = promote_type(elt1, elt2)
131131

132-
a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l"))
132+
rng = StableRNG(123)
133+
a1 = randn(rng, elt1, 2, 3)
134+
a2 = randn(rng, elt2, 4, 5)
135+
136+
a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l"))
133137
@test labels == ("i", "j", "k", "l")
134138
@test eltype(a_dest) === elt_dest
135139
@test a_dest reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))
136140

137-
a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
141+
a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
138142
@test eltype(a_dest) === elt_dest
139143
@test a_dest permutedims(
140144
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4)
141145
)
142146

143147
a_dest = zeros(elt_dest, 2, 5, 3, 4)
144-
TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
148+
contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
145149
@test a_dest permutedims(
146150
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3)
147151
)
148152
end
153+
@testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
154+
elt2 in elts
155+
156+
elt_dest = promote_type(elt1, elt2)
157+
158+
rng = StableRNG(123)
159+
a = randn(rng, elt1, (2, 3, 4, 5))
160+
s = randn(rng, elt2, ())
161+
t = randn(rng, elt2, ())
162+
163+
labels_a = ("i", "j", "k", "l")
164+
165+
# Array-scalar contraction.
166+
a_dest, labels_dest = contract(a, labels_a, s, ())
167+
@test labels_dest == labels_a
168+
@test a_dest a * s[]
169+
170+
# Scalar-array contraction.
171+
a_dest, labels_dest = contract(s, (), a, labels_a)
172+
@test labels_dest == labels_a
173+
@test a_dest a * s[]
174+
175+
# Scalar-scalar contraction.
176+
a_dest, labels_dest = contract(s, (), t, ())
177+
@test labels_dest == ()
178+
@test a_dest[] s[] * t[]
179+
180+
# Specify output labels.
181+
labels_dest_example = ("j", "l", "i", "k")
182+
size_dest_example = (3, 5, 2, 4)
183+
184+
# Array-scalar contraction.
185+
a_dest = contract(labels_dest_example, a, labels_a, s, ())
186+
@test size(a_dest) == size_dest_example
187+
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
188+
189+
# Scalar-array contraction.
190+
a_dest = contract(labels_dest_example, s, (), a, labels_a)
191+
@test size(a_dest) == size_dest_example
192+
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
193+
194+
# Scalar-scalar contraction.
195+
a_dest = contract((), s, (), t, ())
196+
@test size(a_dest) == ()
197+
@test a_dest[] s[] * t[]
198+
199+
# Array-scalar contraction.
200+
a_dest = zeros(elt_dest, size_dest_example)
201+
contract!(a_dest, labels_dest_example, a, labels_a, s, ())
202+
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
203+
204+
# Scalar-array contraction.
205+
a_dest = zeros(elt_dest, size_dest_example)
206+
contract!(a_dest, labels_dest_example, s, (), a, labels_a)
207+
@test a_dest permutedims(a, (2, 4, 1, 3)) * s[]
208+
209+
# Scalar-scalar contraction.
210+
a_dest = zeros(elt_dest, ())
211+
contract!(a_dest, (), s, (), t, ())
212+
@test a_dest[] s[] * t[]
213+
end
149214
end
150215
@testset "qr (eltype=$elt)" for elt in elts
151216
a = randn(elt, 5, 4, 3, 2)
@@ -154,8 +219,6 @@ end
154219
labels_r = (:d, :c)
155220
q, r = qr(a, labels_a, labels_q, labels_r)
156221
label_qr = :qr
157-
a′ = TensorAlgebra.contract(
158-
labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)
159-
)
222+
a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
160223
@test a a′
161224
end

0 commit comments

Comments
 (0)