Skip to content

[WIP] TensorOperations extension #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

[extensions]
TensorAlgebraGradedUnitRangesExt = "GradedUnitRanges"
TensorAlgebraTensorOperationsExt = "TensorOperations"

[compat]
ArrayLayouts = "1.10.4"
BlockArrays = "1.2.0"
EllipsisNotation = "1.8.0"
GradedUnitRanges = "0.1.0"
LinearAlgebra = "1.10"
TensorOperations = "5"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.2.1, 0.3"
julia = "1.10"
83 changes: 83 additions & 0 deletions ext/TensorAlgebraTensorOperationsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
module TensorAlgebraTensorOperationsExt

using TensorAlgebra: TensorAlgebra, BlockedPermutation
using TupleTools
using TensorOperations
using TensorOperations: AbstractBackend as TOAlgorithm

TensorAlgebra.Algorithm(backend::TOAlgorithm) = backend

Check warning on line 8 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L8

Added line #L8 was not covered by tests

trivtuple(n) = ntuple(identity, n)

Check warning on line 10 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L10

Added line #L10 was not covered by tests

function _index2tuple(p::BlockedPermutation{2})
N₁, N₂ = blocklengths(p)
return (

Check warning on line 14 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L12-L14

Added lines #L12 - L14 were not covered by tests
TupleTools.getindices(Tuple(p), trivtuple(N₁)),
TupleTools.getindices(Tuple(p), N₁ .+ trivtuple(N₂)),
)
end

# not in-place
# ------------
function TensorAlgebra.contract(

Check warning on line 22 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L22

Added line #L22 was not covered by tests
backend::TOAlgorithm,
pAB::BlockedPermutation,
A::AbstractArray,
pA::BlockedPermutation,
B::AbstractArray,
pB::BlockedPermutation,
α::Number,
)
pA′ = _index2tuple(pA)
pB′ = _index2tuple(pB)
pAB′ = _index2tuple(pAB)
return tensorcontract(A, pA′, false, B, pB′, false, pAB′, α, backend)

Check warning on line 34 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L31-L34

Added lines #L31 - L34 were not covered by tests
end

function TensorAlgebra.contract(

Check warning on line 37 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L37

Added line #L37 was not covered by tests
backend::TOAlgorithm,
labelsC,
A::AbstractArray,
labelsA,
B::AbstractArray,
labelsB,
α::Number,
)
return tensorcontract(labelsC, A, labelsA, B, labelsB, α; backend)

Check warning on line 46 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L46

Added line #L46 was not covered by tests
end

# in-place
# --------
function TensorAlgebra.contract!(

Check warning on line 51 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L51

Added line #L51 was not covered by tests
backend::TOAlgorithm,
C::AbstractArray,
pAB::BlockedPermutation,
A::AbstractArray,
pA::BlockedPermutation,
B::AbstractArray,
pB::BlockedPermutation,
α::Number,
β::Number,
)
pA′ = _index2tuple(pA)
pB′ = _index2tuple(pB)
pAB′ = _index2tuple(pAB)
return tensorcontract!(C, A, pA′, false, B, pB′, false, pAB′, α, β, backend)

Check warning on line 65 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L62-L65

Added lines #L62 - L65 were not covered by tests
end

function TensorAlgebra.contract!(

Check warning on line 68 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L68

Added line #L68 was not covered by tests
backend::TOAlgorithm,
C::AbstractArray,
labelsC,
A::AbstractArray,
labelsA,
B::AbstractArray,
labelsB,
α::Number,
β::Number,
)
pA, pB, pAB = TensorOperations.contract_indices(labelsA, labelsB, labelsC)
return TensorOperations.tensorcontract!(C, A, pA, false, B, pB, false, pAB, α, β, backend)

Check warning on line 80 in ext/TensorAlgebraTensorOperationsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraTensorOperationsExt.jl#L79-L80

Added lines #L79 - L80 were not covered by tests
end

end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ StableRNGs = "1.0.2"
Suppressor = "0.2"
SymmetrySectors = "0.1"
TensorAlgebra = "0.2.0"
TensorOperations = "5.1.3"
TensorOperations = "5"
Test = "1.10"
TestExtras = "0.3.1"
Loading