From 9c547b6df8992b3116d9b529c57a8e1167ce402d Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Fri, 17 Jan 2025 15:33:58 -0500 Subject: [PATCH] Define `TensorAlgebra.svd`, define `TensorAlgebra.qr` rather than overloading `LinearAlgebra.qr` (#20) --- Project.toml | 2 +- .../LinearAlgebraExtensions.jl | 3 - src/LinearAlgebraExtensions/qr.jl | 69 ------------------- src/TensorAlgebra.jl | 3 +- src/factorizations.jl | 45 ++++++++++++ test/test_basics.jl | 17 ++++- 6 files changed, 62 insertions(+), 77 deletions(-) delete mode 100644 src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl delete mode 100644 src/LinearAlgebraExtensions/qr.jl create mode 100644 src/factorizations.jl diff --git a/Project.toml b/Project.toml index 7aad76b..f8700a8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl b/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl deleted file mode 100644 index 471f2bd..0000000 --- a/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl +++ /dev/null @@ -1,3 +0,0 @@ -module LinearAlgebraExtensions -include("qr.jl") -end diff --git a/src/LinearAlgebraExtensions/qr.jl b/src/LinearAlgebraExtensions/qr.jl deleted file mode 100644 index 903efa9..0000000 --- a/src/LinearAlgebraExtensions/qr.jl +++ /dev/null @@ -1,69 +0,0 @@ -using ArrayLayouts: LayoutMatrix -using LinearAlgebra: LinearAlgebra, qr -using ..TensorAlgebra: - TensorAlgebra, - BlockedPermutation, - blockedperm, - blockedperm_indexin, - blockpermute, - fusedims, - splitdims - -# TODO: Define as `tensor_qr`. -# TODO: This look generic but doesn't work for `BlockSparseArrays`. -function _qr(a::AbstractArray, biperm::BlockedPermutation{2}) - a_matricized = fusedims(a, biperm) - - # TODO: Make this more generic, allow choosing thin or full, - # make sure this works on GPU. - q_matricized, r_matricized = qr(a_matricized) - q_matricized_thin = typeof(a_matricized)(q_matricized) - - axes_codomain, axes_domain = blockpermute(axes(a), biperm) - axes_q = (axes_codomain..., axes(q_matricized_thin, 2)) - # TODO: Use `tuple_oneto(n) = ntuple(identity, n)`, currently in `BlockSparseArrays`. - biperm_q = blockedperm( - ntuple(identity, length(axes_codomain)), (length(axes_codomain) + 1,) - ) - axes_r = (axes(r_matricized, 1), axes_domain...) - biperm_r = blockedperm((1,), ntuple(identity, length(axes_domain)) .+ 1) - q = splitdims(q_matricized_thin, axes_q) - r = splitdims(r_matricized, axes_r) - return q, r -end - -function LinearAlgebra.qr(a::AbstractArray, biperm::BlockedPermutation{2}) - return _qr(a, biperm) -end - -# Fix ambiguity error with `LinearAlgebra`. -function LinearAlgebra.qr(a::AbstractMatrix, biperm::BlockedPermutation{2}) - return _qr(a, biperm) -end - -# Fix ambiguity error with `ArrayLayouts`. -function LinearAlgebra.qr(a::LayoutMatrix, biperm::BlockedPermutation{2}) - return _qr(a, biperm) -end - -# TODO: Define in terms of an inner function `_qr` or `tensor_qr`. -# TODO: this is type piracy -function LinearAlgebra.qr( - a::AbstractArray, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple -) - return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r)) -end - -# Fix ambiguity error with `LinearAlgebra`. -function LinearAlgebra.qr( - a::AbstractMatrix, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple -) - return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r)) -end - -# Fix ambiguity error with `ArrayLayouts`. -function LinearAlgebra.qr( - a::LayoutMatrix, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple -) - return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r)) -end diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl index bc5771a..caa2cc5 100644 --- a/src/TensorAlgebra.jl +++ b/src/TensorAlgebra.jl @@ -12,7 +12,6 @@ include("contract/output_labels.jl") include("contract/blockedperms.jl") include("contract/allocate_output.jl") include("contract/contract_matricize/contract.jl") -# TODO: Rename to `TensorAlgebraLinearAlgebraExt`. -include("LinearAlgebraExtensions/LinearAlgebraExtensions.jl") +include("factorizations.jl") end diff --git a/src/factorizations.jl b/src/factorizations.jl new file mode 100644 index 0000000..a017ca1 --- /dev/null +++ b/src/factorizations.jl @@ -0,0 +1,45 @@ +using ArrayLayouts: LayoutMatrix +using LinearAlgebra: LinearAlgebra, Diagonal + +function qr(a::AbstractArray, biperm::BlockedPermutation{2}) + a_matricized = fusedims(a, biperm) + # TODO: Make this more generic, allow choosing thin or full, + # make sure this works on GPU. + q_fact, r_matricized = LinearAlgebra.qr(a_matricized) + q_matricized = typeof(a_matricized)(q_fact) + axes_codomain, axes_domain = blockpermute(axes(a), biperm) + axes_q = (axes_codomain..., axes(q_matricized, 2)) + axes_r = (axes(r_matricized, 1), axes_domain...) + q = splitdims(q_matricized, axes_q) + r = splitdims(r_matricized, axes_r) + return q, r +end + +function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain) + # TODO: Generalize to conversion to `Tuple` isn't needed. + return qr( + a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)) + ) +end + +function svd(a::AbstractArray, biperm::BlockedPermutation{2}) + a_matricized = fusedims(a, biperm) + usv_matricized = LinearAlgebra.svd(a_matricized) + u_matricized = usv_matricized.U + s_diag = usv_matricized.S + v_matricized = usv_matricized.Vt + axes_codomain, axes_domain = blockpermute(axes(a), biperm) + axes_u = (axes_codomain..., axes(u_matricized, 2)) + axes_v = (axes(v_matricized, 1), axes_domain...) + u = splitdims(u_matricized, axes_u) + # TODO: Use `DiagonalArrays.diagonal` to make it more general. + s = Diagonal(s_diag) + v = splitdims(v_matricized, axes_v) + return u, s, v +end + +function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain) + return svd( + a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)) + ) +end diff --git a/test/test_basics.jl b/test/test_basics.jl index dde9bcc..221274b 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,7 @@ using EllipsisNotation: var".." -using LinearAlgebra: norm, qr +using LinearAlgebra: norm using StableRNGs: StableRNG -using TensorAlgebra: contract, contract!, fusedims, splitdims +using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd using TensorOperations: TensorOperations using Test: @test, @test_broken, @testset @@ -222,3 +222,16 @@ end a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)) @test a ≈ a′ end +@testset "svd (eltype=$elt)" for elt in elts + a = randn(elt, 5, 4, 3, 2) + labels_a = (:a, :b, :c, :d) + labels_u = (:b, :a) + labels_v = (:d, :c) + u, s, v = svd(a, labels_a, labels_u, labels_v) + label_u = :u + label_v = :v + # TODO: Define multi-arg `contract`? + us, labels_us = contract(u, (labels_u..., label_u), s, (label_u, label_v)) + a′ = contract(labels_a, us, labels_us, v, (label_v, labels_v...)) + @test a ≈ a′ +end