Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
1 change: 1 addition & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ include("contract/blockedperms.jl")
include("contract/allocate_output.jl")
include("contract/contract_matricize/contract.jl")
include("factorizations.jl")
include("matrixfunctions.jl")

end
46 changes: 46 additions & 0 deletions src/matrixfunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# TensorAlgebra version of matrix functions.
const MATRIX_FUNCTIONS = [
:exp,
:cis,
:log,
:sqrt,
:cbrt,
:cos,
:sin,
:tan,
:csc,
:sec,
:cot,
:cosh,
:sinh,
:tanh,
:csch,
:sech,
:coth,
:acos,
:asin,
:atan,
:acsc,
:asec,
:acot,
:acosh,
:asinh,
:atanh,
:acsch,
:asech,
:acoth,
]

for f in MATRIX_FUNCTIONS
@eval begin
function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...)
return $f(a, biperm; kwargs...)
end
function $f(a::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...)
a_mat = matricize(a, biperm)
fa_mat = Base.$f(a_mat; kwargs...)
return unmatricize(fa_mat, axes(a)[biperm])
end
end
end
6 changes: 2 additions & 4 deletions test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
using LinearAlgebra: LinearAlgebra, norm, diag
using Test: @test, @testset

using TestExtras: @constinferred

using MatrixAlgebraKit: truncrank
using TensorAlgebra:
contract,
Expand All @@ -21,6 +17,8 @@ using TensorAlgebra:
right_polar,
svd,
svdvals
using Test: @test, @testset
using TestExtras: @constinferred

elts = (Float64, ComplexF64)

Expand Down
21 changes: 21 additions & 0 deletions test/test_matrixfunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using StableRNGs: StableRNG
using TensorAlgebra: TensorAlgebra, biperm
using Test: @test, @testset

@testset "Matrix functions (eltype=$elt)" for elt in (Float32, ComplexF64)
for f in TensorAlgebra.MATRIX_FUNCTIONS
f == :cbrt && elt <: Complex && continue
f == :cbrt && VERSION < v"1.11-" && continue
@eval begin
rng = StableRNG(123)
a = randn(rng, $elt, (2, 2, 2, 2))
for fa in (
TensorAlgebra.$f(a, (:a, :b, :c, :d), (:c, :b), (:d, :a)),
TensorAlgebra.$f(a, biperm((3, 2, 4, 1), Val(2))),
)
fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2))
@test fa ≈ fa′
end
end
end
end
Loading