1
1
using EllipsisNotation: var".."
2
2
using LinearAlgebra: norm, qr
3
- using TensorAlgebra: TensorAlgebra, fusedims, splitdims
4
- default_rtol (elt:: Type ) = 10 ^ (0.75 * log10 (eps (real (elt))))
3
+ using StableRNGs: StableRNG
4
+ using TensorAlgebra: contract, contract!, fusedims, splitdims
5
+ using TensorOperations: TensorOperations
5
6
using Test: @test , @test_broken , @testset
7
+
8
+ default_rtol (elt:: Type ) = 10 ^ (0.75 * log10 (eps (real (elt))))
6
9
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
7
10
8
11
@testset " TensorAlgebra" begin
@@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
90
93
labels_dest = map (i -> labels[i], d_dests)
91
94
92
95
# Don't specify destination labels
93
- a_dest, labels_dest′ = TensorAlgebra . contract (a1, labels1, a2, labels2)
96
+ a_dest, labels_dest′ = contract (a1, labels1, a2, labels2)
94
97
a_dest_tensoroperations = TensorOperations. tensorcontract (
95
98
labels_dest′, a1, labels1, a2, labels2
96
99
)
97
100
@test a_dest ≈ a_dest_tensoroperations
98
101
99
102
# Specify destination labels
100
- a_dest = TensorAlgebra . contract (labels_dest, a1, labels1, a2, labels2)
103
+ a_dest = contract (labels_dest, a1, labels1, a2, labels2)
101
104
a_dest_tensoroperations = TensorOperations. tensorcontract (
102
105
labels_dest, a1, labels1, a2, labels2
103
106
)
@@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
111
114
β = elt_dest (2.4 ) # randn(elt_dest)
112
115
a_dest_init = randn (elt_dest, map (i -> dims[i], d_dests))
113
116
a_dest = copy (a_dest_init)
114
- TensorAlgebra . contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
117
+ contract! (a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
115
118
a_dest_tensoroperations = TensorOperations. tensorcontract (
116
119
labels_dest, a1, labels1, a2, labels2
117
120
)
@@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
124
127
@testset " outer product contraction (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts,
125
128
elt2 in elts
126
129
127
- a1 = randn (elt1, 2 , 3 )
128
- a2 = randn (elt2, 4 , 5 )
129
-
130
130
elt_dest = promote_type (elt1, elt2)
131
131
132
- a_dest, labels = TensorAlgebra. contract (a1, (" i" , " j" ), a2, (" k" , " l" ))
132
+ rng = StableRNG (123 )
133
+ a1 = randn (rng, elt1, 2 , 3 )
134
+ a2 = randn (rng, elt2, 4 , 5 )
135
+
136
+ a_dest, labels = contract (a1, (" i" , " j" ), a2, (" k" , " l" ))
133
137
@test labels == (" i" , " j" , " k" , " l" )
134
138
@test eltype (a_dest) === elt_dest
135
139
@test a_dest ≈ reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... ))
136
140
137
- a_dest = TensorAlgebra . contract ((" i" , " k" , " j" , " l" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
141
+ a_dest = contract ((" i" , " k" , " j" , " l" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
138
142
@test eltype (a_dest) === elt_dest
139
143
@test a_dest ≈ permutedims (
140
144
reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 3 , 2 , 4 )
141
145
)
142
146
143
147
a_dest = zeros (elt_dest, 2 , 5 , 3 , 4 )
144
- TensorAlgebra . contract! (a_dest, (" i" , " l" , " j" , " k" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
148
+ contract! (a_dest, (" i" , " l" , " j" , " k" ), a1, (" i" , " j" ), a2, (" k" , " l" ))
145
149
@test a_dest ≈ permutedims (
146
150
reshape (vec (a1) * transpose (vec (a2)), (size (a1)... , size (a2)... )), (1 , 4 , 2 , 3 )
147
151
)
148
152
end
153
+ @testset " scalar contraction (eltype1=$elt1 , eltype2=$elt2 )" for elt1 in elts,
154
+ elt2 in elts
155
+
156
+ elt_dest = promote_type (elt1, elt2)
157
+
158
+ rng = StableRNG (123 )
159
+ a = randn (rng, elt1, (2 , 3 , 4 , 5 ))
160
+ s = randn (rng, elt2, ())
161
+ t = randn (rng, elt2, ())
162
+
163
+ labels_a = (" i" , " j" , " k" , " l" )
164
+
165
+ # Array-scalar contraction.
166
+ a_dest, labels_dest = contract (a, labels_a, s, ())
167
+ @test labels_dest == labels_a
168
+ @test a_dest ≈ a * s[]
169
+
170
+ # Scalar-array contraction.
171
+ a_dest, labels_dest = contract (s, (), a, labels_a)
172
+ @test labels_dest == labels_a
173
+ @test a_dest ≈ a * s[]
174
+
175
+ # Scalar-scalar contraction.
176
+ a_dest, labels_dest = contract (s, (), t, ())
177
+ @test labels_dest == ()
178
+ @test a_dest[] ≈ s[] * t[]
179
+
180
+ # Specify output labels.
181
+ labels_dest_example = (" j" , " l" , " i" , " k" )
182
+ size_dest_example = (3 , 5 , 2 , 4 )
183
+
184
+ # Array-scalar contraction.
185
+ a_dest = contract (labels_dest_example, a, labels_a, s, ())
186
+ @test size (a_dest) == size_dest_example
187
+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
188
+
189
+ # Scalar-array contraction.
190
+ a_dest = contract (labels_dest_example, s, (), a, labels_a)
191
+ @test size (a_dest) == size_dest_example
192
+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
193
+
194
+ # Scalar-scalar contraction.
195
+ a_dest = contract ((), s, (), t, ())
196
+ @test size (a_dest) == ()
197
+ @test a_dest[] ≈ s[] * t[]
198
+
199
+ # Array-scalar contraction.
200
+ a_dest = zeros (elt_dest, size_dest_example)
201
+ contract! (a_dest, labels_dest_example, a, labels_a, s, ())
202
+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
203
+
204
+ # Scalar-array contraction.
205
+ a_dest = zeros (elt_dest, size_dest_example)
206
+ contract! (a_dest, labels_dest_example, s, (), a, labels_a)
207
+ @test a_dest ≈ permutedims (a, (2 , 4 , 1 , 3 )) * s[]
208
+
209
+ # Scalar-scalar contraction.
210
+ a_dest = zeros (elt_dest, ())
211
+ contract! (a_dest, (), s, (), t, ())
212
+ @test a_dest[] ≈ s[] * t[]
213
+ end
149
214
end
150
215
@testset " qr (eltype=$elt )" for elt in elts
151
216
a = randn (elt, 5 , 4 , 3 , 2 )
154
219
labels_r = (:d , :c )
155
220
q, r = qr (a, labels_a, labels_q, labels_r)
156
221
label_qr = :qr
157
- a′ = TensorAlgebra. contract (
158
- labels_a, q, (labels_q... , label_qr), r, (label_qr, labels_r... )
159
- )
222
+ a′ = contract (labels_a, q, (labels_q... , label_qr), r, (label_qr, labels_r... ))
160
223
@test a ≈ a′
161
224
end
0 commit comments