From 0d031c45fdff73ba7a6f8737c5537b1af90d0de0 Mon Sep 17 00:00:00 2001
From: mtfishman <mfishman@flatironinstitute.org>
Date: Wed, 15 Jan 2025 10:37:30 -0500
Subject: [PATCH 1/3] Fix scalar contraction

---
 src/contract/allocate_output.jl             | 55 +++++++++++--
 src/contract/contract_matricize/contract.jl | 39 +++++++++
 test/test_basics.jl                         | 91 +++++++++++++++++----
 3 files changed, 165 insertions(+), 20 deletions(-)

diff --git a/src/contract/allocate_output.jl b/src/contract/allocate_output.jl
index 3d9efa3..610ff87 100644
--- a/src/contract/allocate_output.jl
+++ b/src/contract/allocate_output.jl
@@ -9,7 +9,7 @@ function output_axes(
   biperm1::BlockedPermutation{2},
   a2::AbstractArray,
   biperm2::BlockedPermutation{2},
-  α::Number=true,
+  α::Number=one(Bool),
 )
   axes_codomain, axes_contracted = blockpermute(axes(a1), biperm1)
   axes_contracted2, axes_domain = blockpermute(axes(a2), biperm2)
@@ -27,7 +27,7 @@ function output_axes(
   perm1::BlockedPermutation{1},
   a2::AbstractArray,
   perm2::BlockedPermutation{1},
-  α::Number=true,
+  α::Number=one(Bool),
 )
   axes_contracted = blockpermute(axes(a1), perm1)
   axes_contracted′ = blockpermute(axes(a2), perm2)
@@ -43,7 +43,7 @@ function output_axes(
   perm1::BlockedPermutation{1},
   a2::AbstractArray,
   biperm2::BlockedPermutation{2},
-  α::Number=true,
+  α::Number=one(Bool),
 )
   (axes_contracted,) = blockpermute(axes(a1), perm1)
   axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
@@ -59,7 +59,7 @@ function output_axes(
   perm1::BlockedPermutation{2},
   a2::AbstractArray,
   biperm2::BlockedPermutation{1},
-  α::Number=true,
+  α::Number=one(Bool),
 )
   axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
   (axes_contracted′,) = blockpermute(axes(a2), biperm2)
@@ -75,7 +75,7 @@ function output_axes(
   perm1::BlockedPermutation{1},
   a2::AbstractArray,
   perm2::BlockedPermutation{1},
-  α::Number=true,
+  α::Number=one(Bool),
 )
   @assert istrivialperm(Tuple(perm1))
   @assert istrivialperm(Tuple(perm2))
@@ -83,6 +83,49 @@ function output_axes(
   return genperm(axes_dest, invperm(Tuple(biperm_dest)))
 end
 
+# Array-scalar contraction.
+function output_axes(
+  ::typeof(contract),
+  perm_dest::BlockedPermutation{1},
+  a1::AbstractArray,
+  perm1::BlockedPermutation{1},
+  a2::AbstractArray,
+  perm2::BlockedPermutation{0},
+  α::Number=one(Bool),
+)
+  @assert istrivialperm(Tuple(perm1))
+  axes_dest = axes(a1)
+  return genperm(axes_dest, invperm(Tuple(perm_dest)))
+end
+
+# Scalar-array contraction.
+function output_axes(
+  ::typeof(contract),
+  perm_dest::BlockedPermutation{1},
+  a1::AbstractArray,
+  perm1::BlockedPermutation{0},
+  a2::AbstractArray,
+  perm2::BlockedPermutation{1},
+  α::Number=one(Bool),
+)
+  @assert istrivialperm(Tuple(perm2))
+  axes_dest = axes(a2)
+  return genperm(axes_dest, invperm(Tuple(perm_dest)))
+end
+
+# Scalar-scalar contraction.
+function output_axes(
+  ::typeof(contract),
+  perm_dest::BlockedPermutation{0},
+  a1::AbstractArray,
+  perm1::BlockedPermutation{0},
+  a2::AbstractArray,
+  perm2::BlockedPermutation{0},
+  α::Number=one(Bool),
+)
+  return ()
+end
+
 # TODO: Use `ArrayLayouts`-like `MulAdd` object,
 # i.e. `ContractAdd`?
 function allocate_output(
@@ -92,7 +135,7 @@ function allocate_output(
   biperm1::BlockedPermutation,
   a2::AbstractArray,
   biperm2::BlockedPermutation,
-  α::Number=true,
+  α::Number=one(Bool),
 )
   axes_dest = output_axes(contract, biperm_dest, a1, biperm1, a2, biperm2, α)
   return similar(a1, promote_type(eltype(a1), eltype(a2), typeof(α)), axes_dest)
diff --git a/src/contract/contract_matricize/contract.jl b/src/contract/contract_matricize/contract.jl
index 5fffd2b..1750ae2 100644
--- a/src/contract/contract_matricize/contract.jl
+++ b/src/contract/contract_matricize/contract.jl
@@ -62,3 +62,42 @@ function _mul!(
   mul!(a_dest, a1, transpose(a2), α, β)
   return a_dest
 end
+
+# Array-scalar contraction.
+function _mul!(
+  a_dest::AbstractVector,
+  a1::AbstractVector,
+  a2::AbstractArray{<:Any,0},
+  α::Number,
+  β::Number,
+)
+  α′ = a2[] * α
+  a_dest .= a1 .* α′ .+ a_dest .* β
+  return a_dest
+end
+
+# Scalar-array contraction.
+function _mul!(
+  a_dest::AbstractVector,
+  a1::AbstractArray{<:Any,0},
+  a2::AbstractVector,
+  α::Number,
+  β::Number,
+)
+  # Preserve the ordering in case of non-commutative algebra.
+  a_dest .= a1[] .* a2 .* α .+ a_dest .* β
+  return a_dest
+end
+
+# Scalar-scalar contraction.
+function _mul!(
+  a_dest::AbstractArray{<:Any,0},
+  a1::AbstractArray{<:Any,0},
+  a2::AbstractArray{<:Any,0},
+  α::Number,
+  β::Number,
+)
+  # Preserve the ordering in case of non-commutative algebra.
+  a_dest[] = a1[] * a2[] * α + a_dest[] * β
+  return a_dest
+end
diff --git a/test/test_basics.jl b/test/test_basics.jl
index 09fb26a..dde9bcc 100644
--- a/test/test_basics.jl
+++ b/test/test_basics.jl
@@ -1,8 +1,11 @@
 using EllipsisNotation: var".."
 using LinearAlgebra: norm, qr
-using TensorAlgebra: TensorAlgebra, fusedims, splitdims
-default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
+using StableRNGs: StableRNG
+using TensorAlgebra: contract, contract!, fusedims, splitdims
+using TensorOperations: TensorOperations
 using Test: @test, @test_broken, @testset
+
+default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
 const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
 
 @testset "TensorAlgebra" begin
@@ -90,14 +93,14 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
       labels_dest = map(i -> labels[i], d_dests)
 
       # Don't specify destination labels
-      a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
+      a_dest, labels_dest′ = contract(a1, labels1, a2, labels2)
       a_dest_tensoroperations = TensorOperations.tensorcontract(
         labels_dest′, a1, labels1, a2, labels2
       )
       @test a_dest ≈ a_dest_tensoroperations
 
       # Specify destination labels
-      a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
+      a_dest = contract(labels_dest, a1, labels1, a2, labels2)
       a_dest_tensoroperations = TensorOperations.tensorcontract(
         labels_dest, a1, labels1, a2, labels2
       )
@@ -111,7 +114,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
       β = elt_dest(2.4) # randn(elt_dest)
       a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
       a_dest = copy(a_dest_init)
-      TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
+      contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
       a_dest_tensoroperations = TensorOperations.tensorcontract(
         labels_dest, a1, labels1, a2, labels2
       )
@@ -124,28 +127,90 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
   @testset "outer product contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
     elt2 in elts
 
-    a1 = randn(elt1, 2, 3)
-    a2 = randn(elt2, 4, 5)
-
     elt_dest = promote_type(elt1, elt2)
 
-    a_dest, labels = TensorAlgebra.contract(a1, ("i", "j"), a2, ("k", "l"))
+    rng = StableRNG(123)
+    a1 = randn(rng, elt1, 2, 3)
+    a2 = randn(rng, elt2, 4, 5)
+
+    a_dest, labels = contract(a1, ("i", "j"), a2, ("k", "l"))
     @test labels == ("i", "j", "k", "l")
     @test eltype(a_dest) === elt_dest
     @test a_dest ≈ reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...))
 
-    a_dest = TensorAlgebra.contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
+    a_dest = contract(("i", "k", "j", "l"), a1, ("i", "j"), a2, ("k", "l"))
     @test eltype(a_dest) === elt_dest
     @test a_dest ≈ permutedims(
       reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 3, 2, 4)
     )
 
     a_dest = zeros(elt_dest, 2, 5, 3, 4)
-    TensorAlgebra.contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
+    contract!(a_dest, ("i", "l", "j", "k"), a1, ("i", "j"), a2, ("k", "l"))
     @test a_dest ≈ permutedims(
       reshape(vec(a1) * transpose(vec(a2)), (size(a1)..., size(a2)...)), (1, 4, 2, 3)
     )
   end
+  @testset "scalar contraction (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts,
+    elt2 in elts
+
+    elt_dest = promote_type(elt1, elt2)
+
+    rng = StableRNG(123)
+    a = randn(rng, elt1, (2, 3, 4, 5))
+    s = randn(rng, elt2, ())
+    t = randn(rng, elt2, ())
+
+    labels_a = ("i", "j", "k", "l")
+
+    # Array-scalar contraction.
+    a_dest, labels_dest = contract(a, labels_a, s, ())
+    @test labels_dest == labels_a
+    @test a_dest ≈ a * s[]
+
+    # Scalar-array contraction.
+    a_dest, labels_dest = contract(s, (), a, labels_a)
+    @test labels_dest == labels_a
+    @test a_dest ≈ a * s[]
+
+    # Scalar-scalar contraction.
+    a_dest, labels_dest = contract(s, (), t, ())
+    @test labels_dest == ()
+    @test a_dest[] ≈ s[] * t[]
+
+    # Specify output labels.
+    labels_dest_example = ("j", "l", "i", "k")
+    size_dest_example = (3, 5, 2, 4)
+
+    # Array-scalar contraction.
+    a_dest = contract(labels_dest_example, a, labels_a, s, ())
+    @test size(a_dest) == size_dest_example
+    @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]
+
+    # Scalar-array contraction.
+    a_dest = contract(labels_dest_example, s, (), a, labels_a)
+    @test size(a_dest) == size_dest_example
+    @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]
+
+    # Scalar-scalar contraction.
+    a_dest = contract((), s, (), t, ())
+    @test size(a_dest) == ()
+    @test a_dest[] ≈ s[] * t[]
+
+    # Array-scalar contraction.
+    a_dest = zeros(elt_dest, size_dest_example)
+    contract!(a_dest, labels_dest_example, a, labels_a, s, ())
+    @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]
+
+    # Scalar-array contraction.
+    a_dest = zeros(elt_dest, size_dest_example)
+    contract!(a_dest, labels_dest_example, s, (), a, labels_a)
+    @test a_dest ≈ permutedims(a, (2, 4, 1, 3)) * s[]
+
+    # Scalar-scalar contraction.
+    a_dest = zeros(elt_dest, ())
+    contract!(a_dest, (), s, (), t, ())
+    @test a_dest[] ≈ s[] * t[]
+  end
 end
 @testset "qr (eltype=$elt)" for elt in elts
   a = randn(elt, 5, 4, 3, 2)
@@ -154,8 +219,6 @@ end
   labels_r = (:d, :c)
   q, r = qr(a, labels_a, labels_q, labels_r)
   label_qr = :qr
-  a′ = TensorAlgebra.contract(
-    labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)
-  )
+  a′ = contract(labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...))
   @test a ≈ a′
 end

From b996497d2ba05083e3ed5bd04f390519616d1cf4 Mon Sep 17 00:00:00 2001
From: mtfishman <mfishman@flatironinstitute.org>
Date: Wed, 15 Jan 2025 10:38:58 -0500
Subject: [PATCH 2/3] Add missing test dep

---
 test/Project.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/Project.toml b/test/Project.toml
index 60a147b..09a07c1 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -8,6 +8,7 @@ LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
 TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
 TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

From 0061180f7f5373c075573a8ef07c97c5f23d1701 Mon Sep 17 00:00:00 2001
From: mtfishman <mfishman@flatironinstitute.org>
Date: Wed, 15 Jan 2025 11:48:10 -0500
Subject: [PATCH 3/3] Bump version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 19abea2..cf45b57 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TensorAlgebra"
 uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
 authors = ["ITensor developers <support@itensor.org> and contributors"]
-version = "0.1.4"
+version = "0.1.5"
 
 [deps]
 ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"