diff --git a/Project.toml b/Project.toml index 99aa433..6e36442 100644 --- a/Project.toml +++ b/Project.toml @@ -13,9 +13,11 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [weakdeps] GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" [extensions] TensorAlgebraGradedUnitRangesExt = "GradedUnitRanges" +TensorAlgebraTensorOperationsExt = "TensorOperations" [compat] ArrayLayouts = "1.10.4" @@ -23,6 +25,7 @@ BlockArrays = "1.2.0" EllipsisNotation = "1.8.0" GradedUnitRanges = "0.1.0" LinearAlgebra = "1.10" +TensorOperations = "5" TupleTools = "1.6.0" TypeParameterAccessors = "0.2.1, 0.3" julia = "1.10" diff --git a/ext/TensorAlgebraTensorOperationsExt.jl b/ext/TensorAlgebraTensorOperationsExt.jl new file mode 100644 index 0000000..ba9f587 --- /dev/null +++ b/ext/TensorAlgebraTensorOperationsExt.jl @@ -0,0 +1,83 @@ +module TensorAlgebraTensorOperationsExt + +using TensorAlgebra: TensorAlgebra, BlockedPermutation +using TupleTools +using TensorOperations +using TensorOperations: AbstractBackend as TOAlgorithm + +TensorAlgebra.Algorithm(backend::TOAlgorithm) = backend + +trivtuple(n) = ntuple(identity, n) + +function _index2tuple(p::BlockedPermutation{2}) + N₁, N₂ = blocklengths(p) + return ( + TupleTools.getindices(Tuple(p), trivtuple(N₁)), + TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)), + ) +end + +# not in-place +# ------------ +function TensorAlgebra.contract( + backend::TOAlgorithm, + pAB::BlockedPermutation, + A::AbstractArray, + pA::BlockedPermutation, + B::AbstractArray, + pB::BlockedPermutation, + α::Number, +) + pA′ = _index2tuple(pA) + pB′ = _index2tuple(pB) + pAB′ = _index2tuple(pAB) + return tensorcontract(A, pA′, false, B, pB′, false, pAB′, α, backend) +end + +function TensorAlgebra.contract( + backend::TOAlgorithm, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, +) + return tensorcontract(labelsC, A, labelsA, B, labelsB, α; backend) +end + +# in-place +# -------- +function TensorAlgebra.contract!( + backend::TOAlgorithm, + C::AbstractArray, + pAB::BlockedPermutation, + A::AbstractArray, + pA::BlockedPermutation, + B::AbstractArray, + pB::BlockedPermutation, + α::Number, + β::Number, +) + pA′ = _index2tuple(pA) + pB′ = _index2tuple(pB) + pAB′ = _index2tuple(pAB) + return tensorcontract!(C, A, pA′, false, B, pB′, false, pAB′, α, β, backend) +end + +function TensorAlgebra.contract!( + backend::TOAlgorithm, + C::AbstractArray, + labelsC, + A::AbstractArray, + labelsA, + B::AbstractArray, + labelsB, + α::Number, + β::Number, +) + pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC) + return TensorOperations.tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend) +end + +end diff --git a/test/Project.toml b/test/Project.toml index 7258ca4..e00095e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,6 +32,6 @@ StableRNGs = "1.0.2" Suppressor = "0.2" SymmetrySectors = "0.1" TensorAlgebra = "0.2.0" -TensorOperations = "5.1.3" +TensorOperations = "5" Test = "1.10" TestExtras = "0.3.1"