Skip to content

Define BlockedTuple #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 10, 2025
1 change: 1 addition & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module TensorAlgebra
include("blockedpermutation.jl")
include("BaseExtensions/BaseExtensions.jl")
include("blockedtuple.jl")
include("fusedims.jl")
include("splitdims.jl")
include("contract/contract.jl")
Expand Down
76 changes: 76 additions & 0 deletions src/blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# This file defines BlockedTuple, a Tuple of heterogeneous Tuple with a BlockArrays.jl
# like interface

using BlockArrays: Block, BlockArrays, BlockIndexRange, BlockRange, blockedrange

struct BlockedTuple{Divs,Flat}
flat::Flat

function BlockedTuple{Divs}(flat::Tuple) where {Divs}
length(flat) != sum(Divs) && throw(DimensionMismatch("Invalid total length"))
return new{Divs,typeof(flat)}(flat)

Check warning on line 11 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L9-L11

Added lines #L9 - L11 were not covered by tests
end
end

# TensorAlgebra Interface
BlockedTuple(tt::Vararg{Tuple}) = BlockedTuple{length.(tt)}(flatten_tuples(tt))
BlockedTuple(bt::BlockedTuple) = bt
flatten_tuples(bt::BlockedTuple) = Tuple(bt)

Check warning on line 18 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L16-L18

Added lines #L16 - L18 were not covered by tests

# Base interface
Base.Tuple(bt::BlockedTuple) = bt.flat

Check warning on line 21 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L21

Added line #L21 was not covered by tests

Base.axes(bt::BlockedTuple) = (blockedrange([blocklengths(bt)...]),)

Check warning on line 23 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L23

Added line #L23 was not covered by tests

Base.broadcastable(bt::BlockedTuple) = bt

Check warning on line 25 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L25

Added line #L25 was not covered by tests
struct BlockedTupleBroadcastStyle{Divs} <: Broadcast.BroadcastStyle end
function Base.BroadcastStyle(::Type{<:BlockedTuple{Divs}}) where {Divs}
return BlockedTupleBroadcastStyle{Divs}()

Check warning on line 28 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L27-L28

Added lines #L27 - L28 were not covered by tests
end
function Base.BroadcastStyle(::BlockedTupleBroadcastStyle, ::BlockedTupleBroadcastStyle)
throw(DimensionMismatch("Incompatible blocks"))

Check warning on line 31 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L30-L31

Added lines #L30 - L31 were not covered by tests
end
# BroadcastStyle is not called for two identical styles
function Base.copy(bc::Broadcast.Broadcasted{BlockedTupleBroadcastStyle{Divs}}) where {Divs}
return BlockedTuple{Divs}(bc.f.((Tuple.(bc.args))...))

Check warning on line 35 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
end

Base.copy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(copy.(Tuple(bt)))

Check warning on line 38 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L38

Added line #L38 was not covered by tests

Base.deepcopy(bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(deepcopy.(Tuple(bt)))

Check warning on line 40 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L40

Added line #L40 was not covered by tests

Base.firstindex(::BlockedTuple) = 1

Check warning on line 42 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L42

Added line #L42 was not covered by tests

Base.getindex(bt::BlockedTuple, i::Integer) = Tuple(bt)[i]
Base.getindex(bt::BlockedTuple, r::AbstractUnitRange) = Tuple(bt)[r]
Base.getindex(bt::BlockedTuple, b::Block{1}) = blocks(bt)[Int(b)]
Base.getindex(bt::BlockedTuple, br::BlockRange{1}) = blocks(bt)[Int.(br)]
Base.getindex(bt::BlockedTuple, bi::BlockIndexRange{1}) = bt[Block(bi)][only(bi.indices)]

Check warning on line 48 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L44-L48

Added lines #L44 - L48 were not covered by tests

Base.iterate(bt::BlockedTuple) = iterate(Tuple(bt))
Base.iterate(bt::BlockedTuple, i::Int) = iterate(Tuple(bt), i)

Check warning on line 51 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L50-L51

Added lines #L50 - L51 were not covered by tests

Base.lastindex(bt::BlockedTuple) = length(bt)

Check warning on line 53 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L53

Added line #L53 was not covered by tests

Base.length(bt::BlockedTuple) = length(Tuple(bt))

Check warning on line 55 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L55

Added line #L55 was not covered by tests

Base.map(f, bt::BlockedTuple) = BlockedTuple{blocklengths(bt)}(map(f, Tuple(bt)))

Check warning on line 57 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L57

Added line #L57 was not covered by tests

# BlockArrays interface
function BlockArrays.blockfirsts(bt::BlockedTuple)
return (0, cumsum(blocklengths(bt)[begin:(end - 1)])...) .+ 1

Check warning on line 61 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
end

function BlockArrays.blocklasts(bt::BlockedTuple)
return cumsum(blocklengths(bt)[begin:end])

Check warning on line 65 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L64-L65

Added lines #L64 - L65 were not covered by tests
end

BlockArrays.blocklength(::BlockedTuple{Divs}) where {Divs} = length(Divs)

Check warning on line 68 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L68

Added line #L68 was not covered by tests

BlockArrays.blocklengths(::BlockedTuple{Divs}) where {Divs} = Divs

Check warning on line 70 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L70

Added line #L70 was not covered by tests

function BlockArrays.blocks(bt::BlockedTuple)
bf = blockfirsts(bt)
bl = blocklasts(bt)
return ntuple(i -> Tuple(bt)[bf[i]:bl[i]], blocklength(bt))

Check warning on line 75 in src/blockedtuple.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedtuple.jl#L72-L75

Added lines #L72 - L75 were not covered by tests
end
52 changes: 52 additions & 0 deletions test/test_blockedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using Test: @test, @test_throws

using BlockArrays: Block, blocklength, blocklengths, blockedrange, blockisequal, blocks

using TensorAlgebra: BlockedTuple, flatten_tuples

@testset "BlockedTuple" begin
flat = (1, 'a', 2, 'b', 3)
divs = (1, 2, 2)

bt = BlockedTuple{divs}(flat)

@test Tuple(bt) == flat
@test flatten_tuples(bt) == flat
@test bt == BlockedTuple((1,), ('a', 2), ('b', 3))
@test BlockedTuple(bt) == bt
@test blocklength(bt) == 3
@test blocklengths(bt) == (1, 2, 2)
@test blocks(bt) == ((1,), ('a', 2), ('b', 3))

@test bt[1] == 1
@test bt[2] == 'a'
@test bt[Block(1)] == blocks(bt)[1]
@test bt[Block(2)] == blocks(bt)[2]
@test bt[Block(1):Block(2)] == blocks(bt)[1:2]
@test bt[Block(2)[1:2]] == ('a', 2)
@test bt[2:4] == ('a', 2, 'b')

@test firstindex(bt) == 1
@test lastindex(bt) == 5
@test length(bt) == 5

@test iterate(bt) == (1, 2)
@test iterate(bt, 2) == ('a', 3)
@test blockisequal(only(axes(bt)), blockedrange([1, 2, 2]))

@test_throws DimensionMismatch BlockedTuple{(1, 2, 3)}(flat)

bt = BlockedTuple((1,), (4, 2), (5, 3))
@test Tuple(bt) == (1, 4, 2, 5, 3)
@test blocklengths(bt) == (1, 2, 2)
@test copy(bt) == bt
@test deepcopy(bt) == bt

@test map(n -> n + 1, bt) == BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test bt .+ BlockedTuple((1,), (1, 1), (1, 1)) ==
BlockedTuple{blocklengths(bt)}(Tuple(bt) .+ 1)
@test_throws DimensionMismatch bt .+ BlockedTuple((1, 1), (1, 1), (1,))

bt = BlockedTuple((1:2, 1:2), (1:3,))
@test length.(bt) == BlockedTuple((2, 2), (3,))
end
Loading