Skip to content

Commit

Permalink
Implement outer product contractions (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Jan 14, 2025
1 parent 5323bc4 commit 840975a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ using BlockArrays:
using EllipsisNotation: Ellipsis, var".."
using TupleTools: TupleTools

trivialperm(len) = ntuple(identity, len)
function istrivialperm(t::Tuple)
return t == trivialperm(length(t))
end

value(::Val{N}) where {N} = N

_flatten_tuples(t::Tuple) = t
Expand Down
16 changes: 16 additions & 0 deletions src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ function output_axes(
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
end

# Outer product.
function output_axes(
::typeof(contract),
biperm_dest::BlockedPermutation{2},
a1::AbstractArray,
perm1::BlockedPermutation{1},
a2::AbstractArray,
perm2::BlockedPermutation{1},
α::Number=true,
)
@assert istrivialperm(Tuple(perm1))
@assert istrivialperm(Tuple(perm2))
axes_dest = (axes(a1)..., axes(a2)...)
return genperm(axes_dest, invperm(Tuple(biperm_dest)))
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function allocate_output(
Expand Down
8 changes: 8 additions & 0 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,11 @@ function _mul!(
mul!(a_dest, a1, a2, α, β)
return a_dest
end

# Outer product.
function _mul!(
a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, transpose(a2), α, β)
return a_dest
end
25 changes: 25 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,31 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
50 * default_rtol(elt_dest)
end
end
@testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
elt2 in elts

a1 = randn(elt1, 2, 3)
a2 = randn(elt2, 4, 5)

elt_dest = promote_type(elt1, elt2)

a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l"))
@test labels == ("i", "j", "k", "l")
@test eltype(a_dest) === elt_dest
@test a_dest reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))

a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
@test eltype(a_dest) === elt_dest
@test a_dest permutedims(
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4)
)

a_dest = zeros(elt_dest, 2, 5, 3, 4)
TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
@test a_dest permutedims(
reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3)
)
end
end
@testset "qr (eltype=$elt)" for elt in elts
a = randn(elt, 5, 4, 3, 2)
Expand Down

0 comments on commit 840975a

Please sign in to comment.