Skip to content

Commit 063a107

Browse files
authored
Rename contract! to contractadd! when α, β are specified (#82)
1 parent 5873db1 commit 063a107

File tree

12 files changed

+70
-50
lines changed

12 files changed

+70
-50
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jobs:
1818
matrix:
1919
pkg:
2020
- 'BlockSparseArrays'
21+
- 'FusionTensors'
2122
- 'NamedDimsArrays'
2223
uses: "ITensor/ITensorActions/.github/workflows/IntegrationTest.yml@main"
2324
with:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.16"
4+
version = "0.4.0"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
66
[compat]
77
Documenter = "1.8.1"
88
Literate = "2.20.1"
9-
TensorAlgebra = "0.3.0"
9+
TensorAlgebra = "0.4.0"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33

44
[compat]
5-
TensorAlgebra = "0.3.0"
5+
TensorAlgebra = "0.4.0"

ext/TensorAlgebraTensorOperationsExt.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
module TensorAlgebraTensorOperationsExt
22

3-
using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm
4-
using TupleTools
5-
using TensorOperations
6-
using TensorOperations: AbstractBackend, DefaultBackend
3+
using TensorAlgebra: TensorAlgebra, BlockedPermutation, Algorithm, blocklengths
4+
using TupleTools: TupleTools
5+
using TensorOperations: TensorOperations, AbstractBackend, DefaultBackend, Index2Tuple
76

87
"""
98
TensorOperationsAlgorithm(backend::AbstractBackend)
@@ -44,8 +43,9 @@ function TensorAlgebra.contract(
4443
pA = _index2tuple(bipermA)
4544
pB = _index2tuple(bipermB)
4645
pAB = _index2tuple(bipermAB)
47-
48-
return tensorcontract(A, pA, false, B, pB, false, pAB, α, algorithm.backend)
46+
return TensorOperations.tensorcontract(
47+
A, pA, false, B, pB, false, pAB, α, algorithm.backend
48+
)
4949
end
5050

5151
function TensorAlgebra.contract(
@@ -62,7 +62,7 @@ function TensorAlgebra.contract(
6262
end
6363

6464
# in-place
65-
function TensorAlgebra.contract!(
65+
function TensorAlgebra.contractadd!(
6666
algorithm::TensorOperationsAlgorithm,
6767
C::AbstractArray,
6868
bipermAB::BlockedPermutation,
@@ -76,10 +76,12 @@ function TensorAlgebra.contract!(
7676
pA = _index2tuple(bipermA)
7777
pB = _index2tuple(bipermB)
7878
pAB = _index2tuple(bipermAB)
79-
return tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend)
79+
return TensorOperations.tensorcontract!(
80+
C, A, pA, false, B, pB, false, pAB, α, β, algorithm.backend
81+
)
8082
end
8183

82-
function TensorAlgebra.contract!(
84+
function TensorAlgebra.contractadd!(
8385
algorithm::TensorOperationsAlgorithm,
8486
C::AbstractArray,
8587
labelsC,
@@ -117,7 +119,7 @@ function TensorOperations.tensorcontract!(
117119
bipermAB = _blockedpermutation(pAB)
118120
A′ = conjA ? conj(A) : A
119121
B′ = conjB ? conj(B) : B
120-
return TensorAlgebra.contract!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
122+
return TensorAlgebra.contractadd!(backend, C, bipermAB, A′, bipermA, B′, bipermB, α, β)
121123
end
122124

123125
# For now no trace/add is supported, so simply reselect default backend from TensorOperations

src/contract/allocate_output.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ function output_axes(
2222
biperm1::AbstractBlockPermutation{2},
2323
a2::AbstractArray,
2424
biperm2::AbstractBlockPermutation{2},
25-
α::Number=one(Bool),
2625
)
2726
axes_codomain, axes_contracted = blocks(axes(a1)[biperm1])
2827
axes_contracted2, axes_domain = blocks(axes(a2)[biperm2])
@@ -40,9 +39,8 @@ function allocate_output(
4039
biperm1::AbstractBlockPermutation,
4140
a2::AbstractArray,
4241
biperm2::AbstractBlockPermutation,
43-
α::Number=one(Bool),
4442
)
4543
check_input(contract, a1, biperm1, a2, biperm2)
46-
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
47-
return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
44+
axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2)
45+
return similar(a1, promote_type(eltype(a1), eltype(a2)), axes_dest)
4846
end

src/contract/contract.jl

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ default_contract_alg() = Matricize()
1010

1111
# Required interface if not using
1212
# matricized contraction.
13-
function contract!(
13+
function contractadd!(
1414
alg::Algorithm,
1515
a_dest::AbstractArray,
1616
biperm_dest::AbstractBlockPermutation,
@@ -28,53 +28,59 @@ function contract(
2828
a1::AbstractArray,
2929
labels1,
3030
a2::AbstractArray,
31-
labels2,
32-
α::Number=one(Bool);
31+
labels2;
3332
alg=default_contract_alg(),
3433
kwargs...,
3534
)
36-
return contract(Algorithm(alg), a1, labels1, a2, labels2, α; kwargs...)
35+
return contract(Algorithm(alg), a1, labels1, a2, labels2; kwargs...)
3736
end
3837

3938
function contract(
40-
alg::Algorithm,
39+
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
40+
)
41+
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2; kwargs...)
42+
return contract(alg, labels_dest, a1, labels1, a2, labels2; kwargs...), labels_dest
43+
end
44+
45+
function contract(
46+
labels_dest,
4147
a1::AbstractArray,
4248
labels1,
4349
a2::AbstractArray,
44-
labels2,
45-
α::Number=one(Bool);
50+
labels2;
51+
alg=default_contract_alg(),
4652
kwargs...,
4753
)
48-
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α; kwargs...)
49-
return contract(alg, labels_dest, a1, labels1, a2, labels2, α; kwargs...), labels_dest
54+
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2; kwargs...)
5055
end
5156

52-
function contract(
57+
function contract!(
58+
a_dest::AbstractArray,
5359
labels_dest,
5460
a1::AbstractArray,
5561
labels1,
5662
a2::AbstractArray,
57-
labels2,
58-
α::Number=one(Bool);
59-
alg=default_contract_alg(),
63+
labels2;
6064
kwargs...,
6165
)
62-
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α; kwargs...)
66+
return contractadd!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
6367
end
6468

65-
function contract!(
69+
function contractadd!(
6670
a_dest::AbstractArray,
6771
labels_dest,
6872
a1::AbstractArray,
6973
labels1,
7074
a2::AbstractArray,
7175
labels2,
72-
α::Number=one(Bool),
73-
β::Number=zero(Bool);
76+
α::Number,
77+
β::Number;
7478
alg=default_contract_alg(),
7579
kwargs...,
7680
)
77-
contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...)
81+
contractadd!(
82+
Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...
83+
)
7884
return a_dest
7985
end
8086

@@ -84,16 +90,30 @@ function contract(
8490
a1::AbstractArray,
8591
labels1,
8692
a2::AbstractArray,
87-
labels2,
88-
α::Number=one(Bool);
93+
labels2;
8994
kwargs...,
9095
)
9196
check_input(contract, a1, labels1, a2, labels2)
9297
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
93-
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2, α; kwargs...)
98+
return contract(alg, biperm_dest, a1, biperm1, a2, biperm2; kwargs...)
9499
end
95100

96101
function contract!(
102+
alg::Algorithm,
103+
a_dest::AbstractArray,
104+
labels_dest,
105+
a1::AbstractArray,
106+
labels1,
107+
a2::AbstractArray,
108+
labels2;
109+
kwargs...,
110+
)
111+
return contractadd!(
112+
alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...
113+
)
114+
end
115+
116+
function contractadd!(
97117
alg::Algorithm,
98118
a_dest::AbstractArray,
99119
labels_dest,
@@ -107,7 +127,7 @@ function contract!(
107127
)
108128
check_input(contract, a_dest, labels_dest, a1, labels1, a2, labels2)
109129
biperm_dest, biperm1, biperm2 = blockedperms(contract, labels_dest, labels1, labels2)
110-
return contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
130+
return contractadd!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
111131
end
112132

113133
function contract(
@@ -116,12 +136,11 @@ function contract(
116136
a1::AbstractArray,
117137
biperm1::AbstractBlockPermutation,
118138
a2::AbstractArray,
119-
biperm2::AbstractBlockPermutation,
120-
α::Number;
139+
biperm2::AbstractBlockPermutation;
121140
kwargs...,
122141
)
123142
check_input(contract, a1, biperm1, a2, biperm2)
124-
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
125-
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, zero(Bool); kwargs...)
143+
a_dest = allocate_output(contract, biperm_dest, a1, biperm1, a2, biperm2)
144+
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2; kwargs...)
126145
return a_dest
127146
end

src/contract/contract_matricize/contract.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearAlgebra: mul!
22

3-
function contract!(
3+
function contractadd!(
44
::Matricize,
55
a_dest::AbstractArray,
66
biperm_dest::AbstractBlockPermutation{2},
@@ -17,6 +17,6 @@ function contract!(
1717
a1_mat = matricize(a1, biperm1)
1818
a2_mat = matricize(a2, biperm2)
1919
a_dest_mat = a1_mat * a2_mat
20-
unmatricize_add!(a_dest, a_dest_mat, invbiperm, α, β)
20+
unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β)
2121
return a_dest
2222
end

src/contract/output_labels.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ function output_labels(
55
labels1,
66
a2::AbstractArray,
77
labels2,
8-
α,
98
)
109
return output_labels(f, alg, labels1, labels2)
1110
end

src/matricize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function unmatricize!(a_dest, m::AbstractMatrix, invbiperm::AbstractBlockPermuta
122122
return permuteblockeddims!(a_dest, a_perm, biperm_dest)
123123
end
124124

125-
function unmatricize_add!(a_dest, a_dest_mat, invbiperm, α, β)
125+
function unmatricizeadd!(a_dest, a_dest_mat, invbiperm, α, β)
126126
a12 = unmatricize(a_dest_mat, axes(a_dest), invbiperm)
127127
a_dest .= α .* a12 .+ β .* a_dest
128128
return a_dest

0 commit comments

Comments
 (0)