Skip to content

Commit 49d7c4f

Browse files
authored
Matrix functions (#86)
1 parent fa70da3 commit 49d7c4f

File tree

5 files changed

+71
-5
lines changed

5 files changed

+71
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorAlgebra"
22
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.0"
4+
version = "0.4.1"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/TensorAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ include("contract/blockedperms.jl")
2929
include("contract/allocate_output.jl")
3030
include("contract/contract_matricize/contract.jl")
3131
include("factorizations.jl")
32+
include("matrixfunctions.jl")
3233

3334
end

src/matrixfunctions.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# TensorAlgebra version of matrix functions.
2+
const MATRIX_FUNCTIONS = [
3+
:exp,
4+
:cis,
5+
:log,
6+
:sqrt,
7+
:cbrt,
8+
:cos,
9+
:sin,
10+
:tan,
11+
:csc,
12+
:sec,
13+
:cot,
14+
:cosh,
15+
:sinh,
16+
:tanh,
17+
:csch,
18+
:sech,
19+
:coth,
20+
:acos,
21+
:asin,
22+
:atan,
23+
:acsc,
24+
:asec,
25+
:acot,
26+
:acosh,
27+
:asinh,
28+
:atanh,
29+
:acsch,
30+
:asech,
31+
:acoth,
32+
]
33+
34+
for f in MATRIX_FUNCTIONS
35+
@eval begin
36+
function $f(a::AbstractArray, labels_a, labels_codomain, labels_domain; kwargs...)
37+
biperm = blockedperm_indexin(Tuple.((labels_a, labels_codomain, labels_domain))...)
38+
return $f(a, biperm; kwargs...)
39+
end
40+
function $f(a::AbstractArray, biperm::AbstractBlockPermutation{2}; kwargs...)
41+
a_mat = matricize(a, biperm)
42+
fa_mat = Base.$f(a_mat; kwargs...)
43+
return unmatricize(fa_mat, axes(a)[biperm])
44+
end
45+
end
46+
end

test/test_factorizations.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
using LinearAlgebra: LinearAlgebra, norm, diag
2-
using Test: @test, @testset
3-
4-
using TestExtras: @constinferred
5-
62
using MatrixAlgebraKit: truncrank
73
using TensorAlgebra:
84
contract,
@@ -21,6 +17,8 @@ using TensorAlgebra:
2117
right_polar,
2218
svd,
2319
svdvals
20+
using Test: @test, @testset
21+
using TestExtras: @constinferred
2422

2523
elts = (Float64, ComplexF64)
2624

test/test_matrixfunctions.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using StableRNGs: StableRNG
2+
using TensorAlgebra: TensorAlgebra, biperm
3+
using Test: @test, @testset
4+
5+
@testset "Matrix functions (eltype=$elt)" for elt in (Float32, ComplexF64)
6+
for f in TensorAlgebra.MATRIX_FUNCTIONS
7+
f == :cbrt && elt <: Complex && continue
8+
f == :cbrt && VERSION < v"1.11-" && continue
9+
@eval begin
10+
rng = StableRNG(123)
11+
a = randn(rng, $elt, (2, 2, 2, 2))
12+
for fa in (
13+
TensorAlgebra.$f(a, (:a, :b, :c, :d), (:c, :b), (:d, :a)),
14+
TensorAlgebra.$f(a, biperm((3, 2, 4, 1), Val(2))),
15+
)
16+
fa′ = reshape($f(reshape(permutedims(a, (3, 2, 4, 1)), (4, 4))), (2, 2, 2, 2))
17+
@test fa fa′
18+
end
19+
end
20+
end
21+
end

0 commit comments

Comments
 (0)