Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AbstractBlockPermutation <: AbstractBlockTuple #11

Merged
merged 11 commits into from
Jan 14, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.2"
version = "0.1.3"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
2 changes: 1 addition & 1 deletion src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ module TensorAlgebra

export contract, contract!

include("blockedtuple.jl")
include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("blockedtuple.jl")
include("fusedims.jl")
include("splitdims.jl")
include("contract/contract.jl")
Expand Down
154 changes: 66 additions & 88 deletions src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,82 +12,32 @@
_flatten_tuples() = ()
flatten_tuples(ts::Tuple) = _flatten_tuples(ts...)

_blocklength(blocklengths::Tuple{Vararg{Int}}) = length(blocklengths)
function _blockfirsts(blocklengths::Tuple{Vararg{Int}})
return ntuple(_blocklength(blocklengths)) do i
prev_blocklast =
isone(i) ? zero(eltype(blocklengths)) : _blocklasts(blocklengths)[i - 1]
return prev_blocklast + 1
end
end
_blocklasts(blocklengths::Tuple{Vararg{Int}}) = cumsum(blocklengths)

collect_tuple(x) = (x,)
collect_tuple(x::Ellipsis) = x
collect_tuple(t::Tuple) = t

const TupleOfTuples{N} = Tuple{Vararg{Tuple{Vararg{Int}},N}}

abstract type AbstractBlockedPermutation{BlockLength,Length} end

BlockArrays.blocks(blockedperm::AbstractBlockedPermutation) = error("Not implemented")

function Base.Tuple(blockedperm::AbstractBlockedPermutation)
return flatten_tuples(blocks(blockedperm))
end

function BlockArrays.blocklengths(blockedperm::AbstractBlockedPermutation)
return length.(blocks(blockedperm))
end

function BlockArrays.blockfirsts(blockedperm::AbstractBlockedPermutation)
return _blockfirsts(blocklengths(blockedperm))
end

function BlockArrays.blocklasts(blockedperm::AbstractBlockedPermutation)
return _blocklasts(blocklengths(blockedperm))
end
#
# =============================== AbstractBlockPermutation ===============================
#
abstract type AbstractBlockPermutation{BlockLength} <: AbstractBlockTuple{BlockLength} end

Base.iterate(permblocks::AbstractBlockedPermutation) = iterate(Tuple(permblocks))
function Base.iterate(permblocks::AbstractBlockedPermutation, state)
return iterate(Tuple(permblocks), state)
end
widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple

# Block a permutation based on the specified lengths.
# blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1))
# TODO: Optimize with StaticNumbers.jl or generated functions, see:
# https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567
function blockperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}})
starts = _blockfirsts(blocklengths)
stops = _blocklasts(blocklengths)
return blockedperm(ntuple(i -> perm[starts[i]:stops[i]], length(blocklengths))...)
end

function Base.invperm(blockedperm::AbstractBlockedPermutation)
return blockperm(invperm(Tuple(blockedperm)), blocklengths(blockedperm))
return blockedperm(BlockedTuple(perm, blocklengths))

Check warning on line 31 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L31

Added line #L31 was not covered by tests
end

Base.length(blockedperm::AbstractBlockedPermutation) = length(Tuple(blockedperm))
function BlockArrays.blocklength(blockedperm::AbstractBlockedPermutation)
return length(blocks(blockedperm))
function blockperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val)
return blockedperm(BlockedTuple(perm, BlockLengths))
end

function Base.getindex(blockedperm::AbstractBlockedPermutation, i::Int)
return Tuple(blockedperm)[i]
end

function Base.getindex(blockedperm::AbstractBlockedPermutation, I::AbstractUnitRange)
perm = Tuple(blockedperm)
return [perm[i] for i in I]
end

function Base.getindex(blockedperm::AbstractBlockedPermutation, b::Block)
return blocks(blockedperm)[Int(b)]
end

# Like `BlockRange`.
function blockeachindex(blockedperm::AbstractBlockedPermutation)
return ntuple(i -> Block(i), blocklength(blockedperm))
function Base.invperm(blockedperm::AbstractBlockPermutation)
# use Val to preserve compile time info
return blockperm(invperm(Tuple(blockedperm)), Val(blocklengths(blockedperm)))
end

#
Expand All @@ -97,7 +47,7 @@
# Bipartition a vector according to the
# bipartitioned permutation.
# Like `Base.permute!` block out-of-place and blocked.
function blockpermute(v, blockedperm::AbstractBlockedPermutation)
function blockpermute(v, blockedperm::AbstractBlockPermutation)
return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))
end

Expand All @@ -106,8 +56,8 @@
return blockedperm(length, permblocks...)
end

function blockedperm(length::Nothing, permblocks::Tuple{Vararg{Int}}...)
return blockedperm(Val(sum(Base.length, permblocks; init=zero(Bool))), permblocks...)
function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...)
return blockedperm(Val(sum(length, permblocks; init=zero(Bool))), permblocks...)
end

# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,))
Expand All @@ -119,11 +69,15 @@
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
end

function blockedperm(bt::AbstractBlockTuple)
return blockedperm(Val(length(bt)), blocks(bt)...)
end

function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
return maximum(specified_perm)
end

function _blockedperm_length(vallength::Val, specified_perm::Tuple{Vararg{Int}})
function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}})
return value(vallength)
end

Expand All @@ -148,45 +102,69 @@
return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)
end

struct BlockedPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <:
AbstractBlockedPermutation{BlockLength,Length}
blocks::Blocks
global function _BlockedPermutation(blocks::TupleOfTuples)
len = sum(length, blocks; init=zero(Bool))
blocklength = length(blocks)
return new{blocklength,len,typeof(blocks)}(blocks)
#
# ================================== BlockedPermutation ==================================
#

# for dispatch reason, it is convenient to have BlockLength as the first parameter
struct BlockedPermutation{BlockLength,BlockLengths,Flat} <:
AbstractBlockPermutation{BlockLength}
flat::Flat

function BlockedPermutation{BlockLength,BlockLengths}(
flat::Tuple
) where {BlockLength,BlockLengths}
length(flat) != sum(BlockLengths; init=0) &&
throw(DimensionMismatch("Invalid total length"))
length(BlockLengths) != BlockLength &&
throw(DimensionMismatch("Invalid total blocklength"))
any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length"))
return new{BlockLength,BlockLengths,typeof(flat)}(flat)
end
end

BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks)
# Base interface
Base.Tuple(blockedperm::BlockedPermutation) = getfield(blockedperm, :flat)

function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...)
@assert value(length) == sum(Base.length, permblocks; init=zero(Bool))
blockedperm = _BlockedPermutation(permblocks)
# BlockArrays interface
function BlockArrays.blocklengths(
::Type{<:BlockedPermutation{<:Any,BlockLengths}}
) where {BlockLengths}
return BlockLengths
end

function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...)
blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}(
flatten_tuples(permblocks)
)
@assert isperm(blockedperm)
return blockedperm
end

#
# ============================== BlockedTrivialPermutation ===============================
#
trivialperm(length::Union{Integer,Val}) = ntuple(identity, length)

struct BlockedTrivialPermutation{BlockLength,Length,Blocks<:TupleOfTuples{BlockLength}} <:
AbstractBlockedPermutation{BlockLength,Length}
blocks::Blocks
global function _BlockedTrivialPermutation(blocklengths::Tuple{Vararg{Int}})
len = sum(blocklengths; init=zero(Bool))
blocklength = length(blocklengths)
permblocks = blocks(blockperm(trivialperm(len), blocklengths))
return new{blocklength,len,typeof(permblocks)}(permblocks)
end
struct BlockedTrivialPermutation{BlockLength,BlockLengths} <:
AbstractBlockPermutation{BlockLength} end

Base.Tuple(blockedperm::BlockedTrivialPermutation) = trivialperm(length(blockedperm))

# BlockArrays interface
function BlockArrays.blocklengths(
::Type{<:BlockedTrivialPermutation{<:Any,BlockLengths}}
) where {BlockLengths}
return BlockLengths
end

BlockArrays.blocks(blockedperm::BlockedTrivialPermutation) = getfield(blockedperm, :blocks)
blockedperm(tp::BlockedTrivialPermutation) = tp

function blockedtrivialperm(blocklengths::Tuple{Vararg{Int}})
return _BlockedTrivialPermutation(blocklengths)
return BlockedTrivialPermutation{length(blocklengths),blocklengths}()
end

function trivialperm(blockedperm::AbstractBlockedPermutation)
function trivialperm(blockedperm::AbstractBlockTuple)
return blockedtrivialperm(blocklengths(blockedperm))
end
Base.invperm(blockedperm::BlockedTrivialPermutation) = blockedperm
70 changes: 50 additions & 20 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
# like interface
# This file defines an abstract type AbstractBlockTuple and a concrete type BlockedTuple.
# These types allow to store a Tuple of heterogeneous Tuples with a BlockArrays.jl like
# interface.

using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange

Expand All @@ -8,7 +9,17 @@ using TypeParameterAccessors: unspecify_type_parameters
#
# ================================== AbstractBlockTuple ==================================
#
abstract type AbstractBlockTuple end
# AbstractBlockTuple imposes BlockLength as first type parameter for easy dispatch
# it makes no assumotion on storage type
abstract type AbstractBlockTuple{BlockLength} end

constructorof(type::Type{<:AbstractBlockTuple}) = unspecify_type_parameters(type)
widened_constructorof(type::Type{<:AbstractBlockTuple}) = constructorof(type)

# Like `BlockRange`.
function blockeachindex(bt::AbstractBlockTuple)
return ntuple(i -> Block(i), blocklength(bt))
end

# Base interface
Base.axes(bt::AbstractBlockTuple) = (blockedrange([blocklengths(bt)...]),)
Expand All @@ -22,9 +33,8 @@ Base.getindex(bt::AbstractBlockTuple, r::AbstractUnitRange) = Tuple(bt)[r]
Base.getindex(bt::AbstractBlockTuple, b::Block{1}) = blocks(bt)[Int(b)]
function Base.getindex(bt::AbstractBlockTuple, br::BlockRange{1})
r = Int.(br)
T = unspecify_type_parameters(typeof(bt))
flat = Tuple(bt)[blockfirsts(bt)[first(r)]:blocklasts(bt)[last(r)]]
return T{blocklengths(bt)[r]}(flat)
return widened_constructorof(typeof(bt))(flat, blocklengths(bt)[r])
end
function Base.getindex(bt::AbstractBlockTuple, bi::BlockIndexRange{1})
return bt[Block(bi)][only(bi.indices)]
Expand All @@ -33,12 +43,14 @@ end
Base.iterate(bt::AbstractBlockTuple) = iterate(Tuple(bt))
Base.iterate(bt::AbstractBlockTuple, i::Int) = iterate(Tuple(bt), i)

Base.length(bt::AbstractBlockTuple) = length(Tuple(bt))

Base.lastindex(bt::AbstractBlockTuple) = length(bt)

Base.length(bt::AbstractBlockTuple) = sum(blocklengths(bt); init=0)

function Base.map(f, bt::AbstractBlockTuple)
return unspecify_type_parameters(typeof(bt)){blocklengths(bt)}(map(f, Tuple(bt)))
BL = blocklengths(bt)
# use Val to preserve compile time knowledge of BL
return widened_constructorof(typeof(bt))(map(f, Tuple(bt)), Val(BL))
end

# Broadcast interface
Expand All @@ -57,19 +69,20 @@ end
function Base.copy(
bc::Broadcast.Broadcasted{AbstractBlockTupleBroadcastStyle{BlockLengths,BT}}
) where {BlockLengths,BT}
return BT{BlockLengths}(bc.f.((Tuple.(bc.args))...))
return widened_constructorof(BT)(bc.f.((Tuple.(bc.args))...), Val(BlockLengths))
end

# BlockArrays interface
BlockArrays.blockfirsts(::AbstractBlockTuple{0}) = ()
function BlockArrays.blockfirsts(bt::AbstractBlockTuple)
return (0, cumsum(Base.front(blocklengths(bt)))...) .+ 1
end

function BlockArrays.blocklasts(bt::AbstractBlockTuple)
return cumsum(blocklengths(bt)[begin:end])
return cumsum(blocklengths(bt))
end

BlockArrays.blocklength(bt::AbstractBlockTuple) = length(blocklengths(bt))
BlockArrays.blocklength(::AbstractBlockTuple{BlockLength}) where {BlockLength} = BlockLength

BlockArrays.blocklengths(bt::AbstractBlockTuple) = blocklengths(typeof(bt))

Expand All @@ -79,29 +92,46 @@ function BlockArrays.blocks(bt::AbstractBlockTuple)
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))
end

#
# length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength"))

# ===================================== BlockedTuple =====================================
#
struct BlockedTuple{BlockLengths,Flat} <: AbstractBlockTuple
struct BlockedTuple{BlockLength,BlockLengths,Flat} <: AbstractBlockTuple{BlockLength}
flat::Flat

function BlockedTuple{BlockLengths}(flat::Tuple) where {BlockLengths}
length(flat) != sum(BlockLengths) && throw(DimensionMismatch("Invalid total length"))
return new{BlockLengths,typeof(flat)}(flat)
function BlockedTuple{BlockLength,BlockLengths}(
flat::Tuple
) where {BlockLength,BlockLengths}
length(BlockLengths) != BlockLength && throw(DimensionMismatch("Invalid blocklength"))
length(flat) != sum(BlockLengths; init=0) &&
throw(DimensionMismatch("Invalid total length"))
any(BlockLengths .< 0) && throw(DimensionMismatch("Invalid block length"))
return new{BlockLength,BlockLengths,typeof(flat)}(flat)
end
end

# TensorAlgebra Interface
tuplemortar(tt::Tuple{Vararg{Tuple}}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
function tuplemortar(tt::Tuple{Vararg{Tuple}})
return BlockedTuple{length(tt),length.(tt)}(flatten_tuples(tt))
end
function BlockedTuple(flat::Tuple, BlockLengths::Tuple{Vararg{Int}})
return BlockedTuple{BlockLengths}(flat)
return BlockedTuple{length(BlockLengths),BlockLengths}(flat)
end
function BlockedTuple(flat::Tuple, ::Val{BlockLengths}) where {BlockLengths}
# use Val to preserve compile time knowledge of BL
return BlockedTuple{length(BlockLengths),BlockLengths}(flat)
end
function BlockedTuple(bt::AbstractBlockTuple)
bl = blocklengths(bt)
return BlockedTuple{length(bl),bl}(Tuple(bt))
end
BlockedTuple(bt::AbstractBlockTuple) = BlockedTuple{blocklengths(bt)}(Tuple(bt))

# Base interface
Base.Tuple(bt::BlockedTuple) = bt.flat

# BlockArrays interface
function BlockArrays.blocklengths(::Type{<:BlockedTuple{BlockLengths}}) where {BlockLengths}
function BlockArrays.blocklengths(
::Type{<:BlockedTuple{<:Any,BlockLengths}}
) where {BlockLengths}
return BlockLengths
end
4 changes: 2 additions & 2 deletions src/fusedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ function fusedims(a::AbstractArray, permblocks...)
end

function fuseaxes(
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockedPermutation
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation
)
axesblocks = blockpermute(axes, blockedperm)
return map(block -> ⊗(block...), axesblocks)
end

function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockedPermutation)
function fuseaxes(a::AbstractArray, blockedperm::AbstractBlockPermutation)
return fuseaxes(axes(a), blockedperm)
end

Expand Down
Loading
Loading