Skip to content

Commit

Permalink
Define TensorAlgebra.svd, define TensorAlgebra.qr rather than ove…
Browse files Browse the repository at this point in the history
…rloading `LinearAlgebra.qr` (#20)
  • Loading branch information
mtfishman authored Jan 17, 2025
1 parent 4453f64 commit 9c547b6
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
3 changes: 0 additions & 3 deletions src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl

This file was deleted.

69 changes: 0 additions & 69 deletions src/LinearAlgebraExtensions/qr.jl

This file was deleted.

3 changes: 1 addition & 2 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ include("contract/output_labels.jl")
include("contract/blockedperms.jl")
include("contract/allocate_output.jl")
include("contract/contract_matricize/contract.jl")
# TODO: Rename to `TensorAlgebraLinearAlgebraExt`.
include("LinearAlgebraExtensions/LinearAlgebraExtensions.jl")
include("factorizations.jl")

end
45 changes: 45 additions & 0 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using ArrayLayouts: LayoutMatrix
using LinearAlgebra: LinearAlgebra, Diagonal

function qr(a::AbstractArray, biperm::BlockedPermutation{2})
a_matricized = fusedims(a, biperm)
# TODO: Make this more generic, allow choosing thin or full,
# make sure this works on GPU.
q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
q_matricized = typeof(a_matricized)(q_fact)
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_q = (axes_codomain..., axes(q_matricized, 2))
axes_r = (axes(r_matricized, 1), axes_domain...)
q = splitdims(q_matricized, axes_q)
r = splitdims(r_matricized, axes_r)
return q, r
end

function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain)
# TODO: Generalize to conversion to `Tuple` isn't needed.
return qr(
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
)
end

function svd(a::AbstractArray, biperm::BlockedPermutation{2})
a_matricized = fusedims(a, biperm)
usv_matricized = LinearAlgebra.svd(a_matricized)
u_matricized = usv_matricized.U
s_diag = usv_matricized.S
v_matricized = usv_matricized.Vt
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_u = (axes_codomain..., axes(u_matricized, 2))
axes_v = (axes(v_matricized, 1), axes_domain...)
u = splitdims(u_matricized, axes_u)
# TODO: Use `DiagonalArrays.diagonal` to make it more general.
s = Diagonal(s_diag)
v = splitdims(v_matricized, axes_v)
return u, s, v
end

function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain)
return svd(
a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
)
end
17 changes: 15 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using EllipsisNotation: var".."
using LinearAlgebra: norm, qr
using LinearAlgebra: norm
using StableRNGs: StableRNG
using TensorAlgebra: contract, contract!, fusedims, splitdims
using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd
using TensorOperations: TensorOperations
using Test: @test, @test_broken, @testset

Expand Down Expand Up @@ -222,3 +222,16 @@ end
a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
@test a a′
end
@testset "svd (eltype=$elt)" for elt in elts
a = randn(elt, 5, 4, 3, 2)
labels_a = (:a, :b, :c, :d)
labels_u = (:b, :a)
labels_v = (:d, :c)
u, s, v = svd(a, labels_a, labels_u, labels_v)
label_u = :u
label_v = :v
# TODO: Define multi-arg `contract`?
us, labels_us = contract(u, (labels_u..., label_u), s, (label_u, label_v))
a′ = contract(labels_a, us, labels_us, v, (label_v, labels_v...))
@test a a′
end

0 comments on commit 9c547b6

Please sign in to comment.