Skip to content

Commit 6eb616f

Browse files
authored
Define tensor_product (#2)
1 parent 9fe1fe9 commit 6eb616f

File tree

11 files changed

+130
-10
lines changed

11 files changed

+130
-10
lines changed

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,12 @@ uuid = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
33
authors = ["ITensor developers <[email protected]> and contributors"]
44
version = "0.1.0"
55

6+
[weakdeps]
7+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
8+
9+
[extensions]
10+
TensorProductsBlockArraysExt = "BlockArrays"
11+
612
[compat]
13+
BlockArrays = "1.2.0"
714
julia = "1.10"

docs/make.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
using TensorProducts: TensorProducts
22
using Documenter: Documenter, DocMeta, deploydocs, makedocs
33

4-
DocMeta.setdocmeta!(
5-
TensorProducts, :DocTestSetup, :(using TensorProducts); recursive=true
6-
)
4+
DocMeta.setdocmeta!(TensorProducts, :DocTestSetup, :(using TensorProducts); recursive=true)
75

86
include("make_index.jl")
97

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module TensorProductsBlockArraysExt
2+
3+
using BlockArrays:
4+
AbstractBlockedUnitRange,
5+
Block,
6+
BlockArrays,
7+
blockaxes,
8+
blockedrange,
9+
blocklengths,
10+
blocks
11+
12+
using TensorProducts: OneToOne, TensorProducts
13+
14+
function TensorProducts.tensor_product(
15+
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
16+
)
17+
new_blocklengths = mapreduce(vcat, Iterators.product(blocks(a1), blocks(a2))) do (x, y)
18+
return length(x) * length(y)
19+
end
20+
return blockedrange(new_blocklengths)
21+
end
22+
23+
end

src/TensorProducts.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module TensorProducts
22

3-
# Write your package code here.
3+
export , OneToOne, tensor_product
4+
5+
include("onetoone.jl")
6+
include("tensor_product.jl")
47

58
end

src/onetoone.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# This files defines the struct OneToOne
2+
# OneToOne represents the range `1:1` or `Base.OneTo(1)`.
3+
4+
struct OneToOne{T} <: AbstractUnitRange{T} end
5+
OneToOne() = OneToOne{Int}()
6+
Base.first(a::OneToOne) = one(eltype(a))
7+
Base.last(a::OneToOne) = one(eltype(a))

src/tensor_product.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# This files defines an interface for the tensor product of two axes
2+
# https://en.wikipedia.org/wiki/Tensor_product
3+
4+
# ================================== misc ================================================
5+
is_offset_axis(a::AbstractUnitRange) = !isone(first(a))
6+
7+
function require_one_based_axis(a::AbstractUnitRange)
8+
return is_offset_axis(a) && throw(ArgumentError("Range must be one-based"))
9+
end
10+
11+
# ============================== tensor product ==========================================
12+
() = tensor_product()
13+
(a) = tensor_product(a)
14+
15+
# default. No type restriction to allow sectors as input
16+
(a1, a2) = tensor_product(a1, a2)
17+
18+
# allow to specialize ⊗(a1, a2) to fusion_product
19+
(a1, a2, as...) = ((a1, a2), as...)
20+
21+
tensor_product() = OneToOne()
22+
tensor_product(a) = a
23+
tensor_product(a1, a2, as...) = tensor_product(tensor_product(a1, a2), as...)
24+
25+
# default
26+
function tensor_product(a1::AbstractUnitRange, a2::AbstractUnitRange)
27+
require_one_based_axis(a1) || require_one_based_axis(a2)
28+
return Base.OneTo(length(a1) * length(a2))
29+
end
30+
31+
tensor_product(::OneToOne, ::OneToOne) = OneToOne()

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
34
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
45
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
56
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
67

78
[compat]
89
Aqua = "0.8.9"
10+
BlockArrays = "1.2.0"
911
SafeTestsets = "0.1"
1012
Suppressor = "0.2"
1113
Test = "1.10"

test/basics/test_basics.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

test/test_basics.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Test: @test, @testset
2+
3+
using BlockArrays: BlockRange, blockaxes
4+
5+
using TensorProducts: OneToOne
6+
7+
@testset "OneToOne" begin
8+
a0 = OneToOne()
9+
@test a0 isa OneToOne{Int}
10+
@test a0 isa AbstractUnitRange{Int}
11+
@test eltype(a0) == Int
12+
@test length(a0) == 1
13+
14+
@test blockaxes(OneToOne()) == (BlockRange(OneToOne()),)
15+
end

test/test_exports.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using Test: @test, @testset
2+
3+
using TensorProducts: TensorProducts
4+
5+
@testset "Test exports" begin
6+
exports = [:, :TensorProducts, :OneToOne, :tensor_product]
7+
@test issetequal(names(TensorProducts), exports)
8+
end

test/test_tensor_product.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using Test: @test, @test_throws, @testset
2+
3+
using TensorProducts: , OneToOne, tensor_product
4+
5+
using BlockArrays: blockedrange, blockisequal
6+
7+
r0 = OneToOne()
8+
b1 = blockedrange([1, 2])
9+
10+
@testset "" begin
11+
@test () isa OneToOne
12+
@test (1:2) == 1:2
13+
@test (1:2, 1:3) == 1:6
14+
@test (1:2, 1:3, 1:4) == 1:24
15+
16+
@test (r0, r0) isa OneToOne
17+
@test blockisequal((b1, b1), blockedrange([1, 2, 2, 4]))
18+
end
19+
20+
@testset "tensor_product" begin
21+
@test tensor_product() isa OneToOne
22+
@test tensor_product(1:2) == 1:2
23+
@test tensor_product(1:2, 1:3) == 1:6
24+
@test tensor_product(1:2, 1:3, 1:4) == 1:24
25+
26+
@test_throws ArgumentError tensor_product(2:3, 1:2)
27+
@test_throws ArgumentError tensor_product(1:3, 2:2)
28+
@test_throws ArgumentError tensor_product(2:3, 2:2)
29+
30+
@test tensor_product(r0, r0) isa OneToOne
31+
@test blockisequal(tensor_product(b1, b1), blockedrange([1, 2, 2, 4]))
32+
end

0 commit comments

Comments
 (0)