From 0ff0ed3b769ffd1ab3e5d3963ab90087671bd330 Mon Sep 17 00:00:00 2001
From: mtfishman <mfishman@flatironinstitute.org>
Date: Fri, 17 Jan 2025 15:13:53 -0500
Subject: [PATCH 1/2] Define TensorAlgebra.svd, change qr namespace

---
 Project.toml                                  |  2 +-
 .../LinearAlgebraExtensions.jl                |  3 -
 src/LinearAlgebraExtensions/qr.jl             | 69 -------------------
 src/TensorAlgebra.jl                          |  3 +-
 src/factorizations.jl                         | 45 ++++++++++++
 test/test_basics.jl                           | 17 ++++-
 6 files changed, 62 insertions(+), 77 deletions(-)
 delete mode 100644 src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl
 delete mode 100644 src/LinearAlgebraExtensions/qr.jl
 create mode 100644 src/factorizations.jl

diff --git a/Project.toml b/Project.toml
index 7aad76b..f8700a8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TensorAlgebra"
 uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
 authors = ["ITensor developers <support@itensor.org> and contributors"]
-version = "0.1.6"
+version = "0.1.7"
 
 [deps]
 ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
diff --git a/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl b/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl
deleted file mode 100644
index 471f2bd..0000000
--- a/src/LinearAlgebraExtensions/LinearAlgebraExtensions.jl
+++ /dev/null
@@ -1,3 +0,0 @@
-module LinearAlgebraExtensions
-include("qr.jl")
-end
diff --git a/src/LinearAlgebraExtensions/qr.jl b/src/LinearAlgebraExtensions/qr.jl
deleted file mode 100644
index 903efa9..0000000
--- a/src/LinearAlgebraExtensions/qr.jl
+++ /dev/null
@@ -1,69 +0,0 @@
-using ArrayLayouts: LayoutMatrix
-using LinearAlgebra: LinearAlgebra, qr
-using ..TensorAlgebra:
-  TensorAlgebra,
-  BlockedPermutation,
-  blockedperm,
-  blockedperm_indexin,
-  blockpermute,
-  fusedims,
-  splitdims
-
-# TODO: Define as `tensor_qr`.
-# TODO: This look generic but doesn't work for `BlockSparseArrays`.
-function _qr(a::AbstractArray, biperm::BlockedPermutation{2})
-  a_matricized = fusedims(a, biperm)
-
-  # TODO: Make this more generic, allow choosing thin or full,
-  # make sure this works on GPU.
-  q_matricized, r_matricized = qr(a_matricized)
-  q_matricized_thin = typeof(a_matricized)(q_matricized)
-
-  axes_codomain, axes_domain = blockpermute(axes(a), biperm)
-  axes_q = (axes_codomain..., axes(q_matricized_thin, 2))
-  # TODO: Use `tuple_oneto(n) = ntuple(identity, n)`, currently in `BlockSparseArrays`.
-  biperm_q = blockedperm(
-    ntuple(identity, length(axes_codomain)), (length(axes_codomain) + 1,)
-  )
-  axes_r = (axes(r_matricized, 1), axes_domain...)
-  biperm_r = blockedperm((1,), ntuple(identity, length(axes_domain)) .+ 1)
-  q = splitdims(q_matricized_thin, axes_q)
-  r = splitdims(r_matricized, axes_r)
-  return q, r
-end
-
-function LinearAlgebra.qr(a::AbstractArray, biperm::BlockedPermutation{2})
-  return _qr(a, biperm)
-end
-
-# Fix ambiguity error with `LinearAlgebra`.
-function LinearAlgebra.qr(a::AbstractMatrix, biperm::BlockedPermutation{2})
-  return _qr(a, biperm)
-end
-
-# Fix ambiguity error with `ArrayLayouts`.
-function LinearAlgebra.qr(a::LayoutMatrix, biperm::BlockedPermutation{2})
-  return _qr(a, biperm)
-end
-
-# TODO: Define in terms of an inner function `_qr` or `tensor_qr`.
-# TODO: this is type piracy
-function LinearAlgebra.qr(
-  a::AbstractArray, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple
-)
-  return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r))
-end
-
-# Fix ambiguity error with `LinearAlgebra`.
-function LinearAlgebra.qr(
-  a::AbstractMatrix, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple
-)
-  return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r))
-end
-
-# Fix ambiguity error with `ArrayLayouts`.
-function LinearAlgebra.qr(
-  a::LayoutMatrix, labels_a::Tuple, labels_q::Tuple, labels_r::Tuple
-)
-  return qr(a, blockedperm_indexin(labels_a, labels_q, labels_r))
-end
diff --git a/src/TensorAlgebra.jl b/src/TensorAlgebra.jl
index bc5771a..caa2cc5 100644
--- a/src/TensorAlgebra.jl
+++ b/src/TensorAlgebra.jl
@@ -12,7 +12,6 @@ include("contract/output_labels.jl")
 include("contract/blockedperms.jl")
 include("contract/allocate_output.jl")
 include("contract/contract_matricize/contract.jl")
-# TODO: Rename to `TensorAlgebraLinearAlgebraExt`.
-include("LinearAlgebraExtensions/LinearAlgebraExtensions.jl")
+include("factorizations.jl")
 
 end
diff --git a/src/factorizations.jl b/src/factorizations.jl
new file mode 100644
index 0000000..ae7dd37
--- /dev/null
+++ b/src/factorizations.jl
@@ -0,0 +1,45 @@
+using ArrayLayouts: LayoutMatrix
+using LinearAlgebra: LinearAlgebra, Diagonal
+
+function qr(a::AbstractArray, biperm::BlockedPermutation{2})
+  a_matricized = fusedims(a, biperm)
+  # TODO: Make this more generic, allow choosing thin or full,
+  # make sure this works on GPU.
+  q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
+  q_matricized = typeof(a_matricized)(q_fact)
+  axes_codomain, axes_domain = blockpermute(axes(a), biperm)
+  axes_q = (axes_codomain..., axes(q_matricized, 2))
+  axes_r = (axes(r_matricized, 1), axes_domain...)
+  q = splitdims(q_matricized, axes_q)
+  r = splitdims(r_matricized, axes_r)
+  return q, r
+end
+
+function qr(
+  a::AbstractArray, labels_a, labels_codomain, labels_domain
+)
+  # TODO: Generalize to conversion to `Tuple` isn't needed.
+  return qr(a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)))
+end
+
+function svd(a::AbstractArray, biperm::BlockedPermutation{2})
+  a_matricized = fusedims(a, biperm)
+  usv_matricized = LinearAlgebra.svd(a_matricized)
+  u_matricized = usv_matricized.U
+  s_diag = usv_matricized.S
+  v_matricized = usv_matricized.Vt
+  axes_codomain, axes_domain = blockpermute(axes(a), biperm)
+  axes_u = (axes_codomain..., axes(u_matricized, 2))
+  axes_v = (axes(v_matricized, 1), axes_domain...)
+  u = splitdims(u_matricized, axes_u)
+  # TODO: Use `DiagonalArrays.diagonal` to make it more general.
+  s = Diagonal(s_diag)
+  v = splitdims(v_matricized, axes_v)
+  return u, s, v
+end
+
+function svd(
+  a::AbstractArray, labels_a, labels_codomain, labels_domain
+)
+  return svd(a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)))
+end
diff --git a/test/test_basics.jl b/test/test_basics.jl
index dde9bcc..221274b 100644
--- a/test/test_basics.jl
+++ b/test/test_basics.jl
@@ -1,7 +1,7 @@
 using EllipsisNotation: var".."
-using LinearAlgebra: norm, qr
+using LinearAlgebra: norm
 using StableRNGs: StableRNG
-using TensorAlgebra: contract, contract!, fusedims, splitdims
+using TensorAlgebra: contract, contract!, fusedims, qr, splitdims, svd
 using TensorOperations: TensorOperations
 using Test: @test, @test_broken, @testset
 
@@ -222,3 +222,16 @@ end
   a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
   @test a ≈ a′
 end
+@testset "svd (eltype=$elt)" for elt in elts
+  a = randn(elt, 5, 4, 3, 2)
+  labels_a = (:a, :b, :c, :d)
+  labels_u = (:b, :a)
+  labels_v = (:d, :c)
+  u, s, v = svd(a, labels_a, labels_u, labels_v)
+  label_u = :u
+  label_v = :v
+  # TODO: Define multi-arg `contract`?
+  us, labels_us = contract(u, (labels_u..., label_u), s, (label_u, label_v))
+  a′ = contract(labels_a, us, labels_us, v, (label_v, labels_v...))
+  @test a ≈ a′
+end

From 6ddc169d52eb44c366af2ab738f84db8bb2031cc Mon Sep 17 00:00:00 2001
From: mtfishman <mfishman@flatironinstitute.org>
Date: Fri, 17 Jan 2025 15:18:35 -0500
Subject: [PATCH 2/2] Format

---
 src/factorizations.jl | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/factorizations.jl b/src/factorizations.jl
index ae7dd37..a017ca1 100644
--- a/src/factorizations.jl
+++ b/src/factorizations.jl
@@ -15,11 +15,11 @@ function qr(a::AbstractArray, biperm::BlockedPermutation{2})
   return q, r
 end
 
-function qr(
-  a::AbstractArray, labels_a, labels_codomain, labels_domain
-)
+function qr(a::AbstractArray, labels_a, labels_codomain, labels_domain)
   # TODO: Generalize to conversion to `Tuple` isn't needed.
-  return qr(a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)))
+  return qr(
+    a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
+  )
 end
 
 function svd(a::AbstractArray, biperm::BlockedPermutation{2})
@@ -38,8 +38,8 @@ function svd(a::AbstractArray, biperm::BlockedPermutation{2})
   return u, s, v
 end
 
-function svd(
-  a::AbstractArray, labels_a, labels_codomain, labels_domain
-)
-  return svd(a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain)))
+function svd(a::AbstractArray, labels_a, labels_codomain, labels_domain)
+  return svd(
+    a, blockedperm_indexin(Tuple(labels_a), Tuple(labels_codomain), Tuple(labels_domain))
+  )
 end