Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define BlockedTuple #9

Merged
merged 17 commits into from
Jan 10, 2025
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

[weakdeps]
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
Expand All @@ -23,4 +24,5 @@ EllipsisNotation = "1.8.0"
GradedUnitRanges = "0.1.0"
LinearAlgebra = "1.10"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.2.1"
julia = "1.10"
1 change: 1 addition & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TensorAlgebra
include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("blockedtuple.jl")
include("fusedims.jl")
include("splitdims.jl")
include("contract/contract.jl")
Expand Down
107 changes: 107 additions & 0 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
# like interface

using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange

using TypeParameterAccessors: unspecify_type_parameters

#
# ================================== AbstractBlockTuple ==================================
#
abstract type AbstractBlockTuple end

# Base interface
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)

Base.deepcopy(bt::AbstractBlockTuple) = deepcopy.(bt)

Base.firstindex(::AbstractBlockTuple) = 1

Base.getindex(bt::AbstractBlockTuple, i::Integer) = Tuple(bt)[i]
Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)]
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1})
r = Int.(br)
T = unspecify_type_parameters(typeof(bt))
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]]
return T{blocklengths(bt)[r]}(flat)
end
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
return bt[Block(bi)][only(bi.indices)]
end

Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt))
Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i)

Base.length(bt::AbstractBlockTuple) = length(Tuple(bt))

Base.lastindex(bt::AbstractBlockTuple) = length(bt)

function Base.map(f, bt::AbstractBlockTuple)
return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt)))
end

# Broadcast interface
Base.broadcastable(bt::AbstractBlockTuple) = bt
struct AbstractBlockTupleBroadcastStyle{BlockLengths,BT} <: Broadcast.BroadcastStyle end
function Base.BroadcastStyle(T::Type{<:AbstractBlockTuple})
return AbstractBlockTupleBroadcastStyle{blocklengths(T),unspecify_type_parameters(T)}()
end

# BroadcastStyle is not called for two identical styles
function Base.BroadcastStyle(
::AbstractBlockTupleBroadcastStyle, ::AbstractBlockTupleBroadcastStyle
)
throw(DimensionMismatch("Incompatible blocks"))
end
function Base.copy(
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
) where {BlockLengths,BT}
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...))
end

# BlockArrays interface
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1
end

function BlockArrays.blocklasts(bt::AbstractBlockTuple)
return cumsum(blocklengths(bt)[begin:end])
end

BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt))

BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt))

function BlockArrays.blocks(bt::AbstractBlockTuple)
bf = blockfirsts(bt)
bl = blocklasts(bt)
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
end

#
# ===================================== BlockedTuple =====================================
#
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
flat::Flat

function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
return new{BlockLengths,typeof(flat)}(flat)
end
end

# TensorAlgebra Interface
tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}})
return BlockedTuple{BlockLengths}(flat)
end
BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt))

# Base interface
Base.Tuple(bt::BlockedTuple) = bt.flat

# BlockArrays interface
function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths}
return BlockLengths
end
9 changes: 5 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
Expand All @@ -8,16 +9,16 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
TensorOperations = "4.1.1"
Aqua = "0.8.9"
SafeTestsets = "0.1"
Suppressor = "0.2"
TensorOperations = "5.1.3"
Test = "1.10"
55 changes: 55 additions & 0 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Test: @test, @test_throws

using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks
using TestExtras: @constinferred

using TensorAlgebra: BlockedTuple, tuplemortar

@testset "BlockedTuple" begin
flat = (true, 'a', 2, "b", 3.0)
divs = (1, 2, 2)

bt = BlockedTuple{divs}(flat)

@test (@constinferred Tuple(bt)) == flat
@test bt == tuplemortar(((true,), ('a', 2), ("b", 3.0)))
@test bt == BlockedTuple(flat, divs)
@test BlockedTuple(bt) == bt
@test blocklength(bt) == 3
@test blocklengths(bt) == (1, 2, 2)
@test (@constinferred blocks(bt)) == ((true,), ('a', 2), ("b", 3.0))

@test (@constinferred bt[1]) == true
@test (@constinferred bt[2]) == 'a'

# it is hard to make bt[Block(1)] type stable as compile-time knowledge of 1 is lost in Block
@test bt[Block(1)] == blocks(bt)[1]
@test bt[Block(2)] == blocks(bt)[2]
@test bt[Block(1):Block(2)] == tuplemortar(((true,), ('a', 2)))
@test bt[Block(2)[1:2]] == ('a', 2)
@test bt[2:4] == ('a', 2, "b")

@test firstindex(bt) == 1
@test lastindex(bt) == 5
@test length(bt) == 5

@test iterate(bt) == (1, 2)
@test iterate(bt, 2) == ('a', 3)
@test blockisequal(only(axes(bt)), blockedrange([1, 2, 2]))

@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat)

bt = tuplemortar(((1,), (4, 2), (5, 3)))
@test Tuple(bt) == (1, 4, 2, 5, 3)
@test blocklengths(bt) == (1, 2, 2)
@test deepcopy(bt) == bt

@test (@constinferred map(n -> n + 1, bt)) ==
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test bt .+ tuplemortar(((1,), (1, 1), (1, 1))) ==
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test_throws DimensionMismatch bt .+ tuplemortar(((1, 1), (1, 1), (1,)))

bt = tuplemortar(((1:2, 1:2), (1:3,)))
@test length.(bt) == tuplemortar(((2, 2), (3,)))
end
Loading