From 9ec21a3565446a3222d6db05a5f505bac44714dd Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sat, 23 Aug 2025 22:37:50 -0400 Subject: [PATCH 01/12] `diag` with a `Val` band index --- src/bidiag.jl | 3 +++ src/dense.jl | 10 ++++++++++ src/diagonal.jl | 1 + src/tridiag.jl | 3 +++ 4 files changed, 17 insertions(+) diff --git a/src/bidiag.jl b/src/bidiag.jl index cc5e6de7..0e36e7da 100644 --- a/src/bidiag.jl +++ b/src/bidiag.jl @@ -418,6 +418,9 @@ function triu!(M::Bidiagonal{T}, k::Integer=0) where T return M end +diag(M::Bidiagonal, ::Val{0}) = M.dv +diag(M::Bidiagonal, ::Val{1}) = M.uplo == 'U' ? M.ev : zero(M.ev) +diag(M::Bidiagonal, ::Val{-1}) = M.uplo == 'L' ? M.ev : zero(M.ev) function diag(M::Bidiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n diff --git a/src/dense.jl b/src/dense.jl index f6f1accb..ce3b4244 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -307,6 +307,16 @@ julia> diag(A,1) """ diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))] +""" + diag(M, ::Val{k}) where {k} + +Return the `k`th diagonal of a matrix as a vector. +For structured matrices such as `Diagonal`, this may return the underlying +band instead of making a copy if `k` lies within the bandwidth of the matrix. +This means that the type of the result may vary depending on the values of `k`. +""" +diag(A::AbstractMatrix, ::Val{N}) where {N} = diag(A, N) + """ diagview(M, k::Integer=0) diff --git a/src/diagonal.jl b/src/diagonal.jl index e60fb009..ef2e7005 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -926,6 +926,7 @@ adjoint(D::Diagonal) = Diagonal(_vecadjoint(D.diag)) permutedims(D::Diagonal) = D permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D) +diag(D::Diagonal, ::Val{0}) = D.diag function diag(D::Diagonal, k::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of k diff --git a/src/tridiag.jl b/src/tridiag.jl index 519be750..3815fc02 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -700,6 +700,9 @@ issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y) \(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = copy(A) \ B +diag(M::Tridiagonal, ::Val{0}) = M.d +diag(M::Tridiagonal, ::Val{1}) = M.du +diag(M::Tridiagonal, ::Val{-1}) = M.dl function diag(M::Tridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n From dc17e920412872e04cb49d5f18c9530cd3bcf755 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 10:54:55 -0400 Subject: [PATCH 02/12] Specialize `diag` for `SymTridiagonal` --- src/tridiag.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tridiag.jl b/src/tridiag.jl index 3815fc02..6f749a88 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -195,6 +195,9 @@ _diagiter(M::SymTridiagonal) = (symmetric(x, :U) for x in M.dv) _eviter_transposed(M::SymTridiagonal{<:Number}) = _evview(M) _eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M)) +diag(M::SymTridiagonal{<:Number}, ::Val{0})= M.dv +diag(M::SymTridiagonal{<:Number}, ::Val{1})= _evview(M) +diag(M::SymTridiagonal{<:Number}, ::Val{-1})= _evview(M) function diag(M::SymTridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n From 09b8bf3a713961ff8631727d1b30d1d46d1666a7 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 10:58:45 -0400 Subject: [PATCH 03/12] Update docstring --- src/dense.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dense.jl b/src/dense.jl index ce3b4244..0980e2e6 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -311,9 +311,11 @@ diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))] diag(M, ::Val{k}) where {k} Return the `k`th diagonal of a matrix as a vector. -For structured matrices such as `Diagonal`, this may return the underlying +For banded matrix types such as `Diagonal`, this may return the underlying band instead of making a copy if `k` lies within the bandwidth of the matrix. -This means that the type of the result may vary depending on the values of `k`. + +!!! note + The type of the result may vary depending on the values of `k`. """ diag(A::AbstractMatrix, ::Val{N}) where {N} = diag(A, N) From 72fff90288e5d572d5d798916ee168ac26340c6b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 10:59:25 -0400 Subject: [PATCH 04/12] Change variable name --- src/dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index 0980e2e6..a4eb0756 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -317,7 +317,7 @@ band instead of making a copy if `k` lies within the bandwidth of the matrix. !!! note The type of the result may vary depending on the values of `k`. """ -diag(A::AbstractMatrix, ::Val{N}) where {N} = diag(A, N) +diag(A::AbstractMatrix, ::Val{k}) where {k} = diag(A, k) """ diagview(M, k::Integer=0) From daca97772b7938969a2d53f153a6d2f730037f32 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 11:02:57 -0400 Subject: [PATCH 05/12] Use 1-arg `diag` when possible --- src/dense.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/dense.jl b/src/dense.jl index a4eb0756..3f6ef78e 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -317,7 +317,11 @@ band instead of making a copy if `k` lies within the bandwidth of the matrix. !!! note The type of the result may vary depending on the values of `k`. """ -diag(A::AbstractMatrix, ::Val{k}) where {k} = diag(A, k) +function diag(A::AbstractMatrix, ::Val{k}) where {k} + # some types might have a specialized 1-arg `diag` method, + # and we may use this if possible + k == 0 ? diag(A) : diag(A, k) +end """ diagview(M, k::Integer=0) From bc371b0563b5d0ffb8b3f06b6b8e3c66fa3fc939 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 12:14:42 -0400 Subject: [PATCH 06/12] whitespace Co-authored-by: Daniel Karrasch --- src/tridiag.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tridiag.jl b/src/tridiag.jl index 6f749a88..43da35fc 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -195,9 +195,9 @@ _diagiter(M::SymTridiagonal) = (symmetric(x, :U) for x in M.dv) _eviter_transposed(M::SymTridiagonal{<:Number}) = _evview(M) _eviter_transposed(M::SymTridiagonal) = (transpose(x) for x in _evview(M)) -diag(M::SymTridiagonal{<:Number}, ::Val{0})= M.dv -diag(M::SymTridiagonal{<:Number}, ::Val{1})= _evview(M) -diag(M::SymTridiagonal{<:Number}, ::Val{-1})= _evview(M) +diag(M::SymTridiagonal{<:Number}, ::Val{0}) = M.dv +diag(M::SymTridiagonal{<:Number}, ::Val{1}) = _evview(M) +diag(M::SymTridiagonal{<:Number}, ::Val{-1}) = _evview(M) function diag(M::SymTridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n From 5b0bcec0e8f8294b67c7a55b345e6a64bcab9721 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 16:46:26 -0400 Subject: [PATCH 07/12] Specialize for triangular --- src/triangular.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/triangular.jl b/src/triangular.jl index d82ddd87..654840b0 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -549,7 +549,11 @@ adjoint!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U' , true, true adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , true, false)) diag(A::UpperOrLowerTriangular) = diag(A.data) +diag(A::UpperOrLowerTriangular, ::Val{0}) = diag(A.data, Val(0)) +diag(A::UpperOrUnitUpperTriangular, ::Val{1}) = diag(A.data, Val(1)) +diag(A::LowerOrUnitLowerTriangular, ::Val{-1}) = diag(A.data, Val(-1)) diag(A::Union{UnitLowerTriangular, UnitUpperTriangular}) = fill(oneunit(eltype(A)), size(A,1)) +diag(A::Union{UnitLowerTriangular, UnitUpperTriangular}, ::Val{0}) = diag(A) # Unary operations -(A::LowerTriangular) = LowerTriangular(-A.data) From 08275610a4fab07f20fed4e2f2989bcde1505005 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 17:12:37 -0400 Subject: [PATCH 08/12] Specialize for `Adjoint`/`Transpose` --- src/adjtrans.jl | 2 ++ src/triangular.jl | 4 ++-- src/tridiag.jl | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/adjtrans.jl b/src/adjtrans.jl index 96db07d5..29f8d064 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -577,6 +577,8 @@ _vecadjoint(A::Base.ReshapedArray{<:Any,1,<:AdjointAbsVec}) = adjoint(parent(A)) diagview(A::Transpose, k::Integer = 0) = _vectranspose(diagview(parent(A), -k)) diagview(A::Adjoint, k::Integer = 0) = _vecadjoint(diagview(parent(A), -k)) +diag(A::Transpose, ::Val{k}) where {k} = _vectranspose(diag(parent(A), Val(-k))) +diag(A::Adjoint, ::Val{k}) where {k} = _vecadjoint(diag(parent(A), Val(-k))) # triu and tril triu!(A::AdjOrTransAbsMat, k::Integer = 0) = wrapperop(A)(tril!(parent(A), -k)) diff --git a/src/triangular.jl b/src/triangular.jl index 654840b0..9e285346 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -2998,8 +2998,8 @@ logdet(A::UnitUpperTriangular{T}) where {T} = zero(T) logdet(A::UnitLowerTriangular{T}) where {T} = zero(T) logabsdet(A::UnitUpperTriangular{T}) where {T} = zero(T), one(T) logabsdet(A::UnitLowerTriangular{T}) where {T} = zero(T), one(T) -det(A::UpperTriangular) = prod(diag(A.data)) -det(A::LowerTriangular) = prod(diag(A.data)) +det(A::UpperTriangular) = prod(diag(A.data, Val(0))) +det(A::LowerTriangular) = prod(diag(A.data, Val(0))) function logabsdet(A::Union{UpperTriangular{T},LowerTriangular{T}}) where T sgn = one(T) abs_det = zero(real(T)) diff --git a/src/tridiag.jl b/src/tridiag.jl index 43da35fc..2426c096 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -111,7 +111,7 @@ function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal} checksquare(A) du = diag(A, 1) d = diag(A) - if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, -1))) + if !(_issymmetric(A) || _checksymmetric(d, du, diag(A, Val(-1)))) throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal")) end return SymTri(d, du) From 4417771d445f6b8614de2b65dfa42c1e08066ce7 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 24 Aug 2025 17:37:32 -0400 Subject: [PATCH 09/12] Tests for `Adjoint`/`Transpose` --- test/adjtrans.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/adjtrans.jl b/test/adjtrans.jl index 075f0acf..89d7cbd9 100644 --- a/test/adjtrans.jl +++ b/test/adjtrans.jl @@ -799,6 +799,25 @@ end end end +@testset "diag with a Val index" begin + @testset "$(typeof(A))" for A in Any[rand(4, 4), rand(ComplexF64,4,4), fill([1 2; 3 4], 4, 4), + Diagonal(1:4), Bidiagonal(1:4, 1:3, :U), + Tridiagonal(1:3, 1:4, 1:3), SymTridiagonal(1:4, 1:3)] + @testset for (wrap_fn, wrap_T) in ((transpose,Transpose), (adjoint,Adjoint)) + At = wrap_fn(A) + @test diag(At, 1) == diag(At, Val(1)) + @test diag(At, 0) == diag(At, Val(0)) + @test diag(At, -1) == diag(At, Val(-1)) + if !(At isa wrap_T) + AT = wrap_T(A) + @test diag(At, Val(1)) == diag(AT, Val(1)) + @test diag(At, Val(0)) == diag(AT, Val(0)) + @test diag(At, Val(-1)) == diag(AT, Val(-1)) + end + end + end +end + @testset "triu!/tril!" begin @testset for sz in ((4,4), (3,4), (4,3)) A = rand(sz...) From 723b66aabd1eaec1a09445ad4a29ddaf932d6bdf Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 1 Sep 2025 16:09:21 -0400 Subject: [PATCH 10/12] Tests for banded matrices --- test/bidiag.jl | 9 +++++++++ test/diagonal.jl | 5 +++++ test/triangular.jl | 9 +++++++++ test/tridiag.jl | 11 +++++++++++ 4 files changed, 34 insertions(+) diff --git a/test/bidiag.jl b/test/bidiag.jl index 6c20a4b8..d9051511 100644 --- a/test/bidiag.jl +++ b/test/bidiag.jl @@ -1298,4 +1298,13 @@ end end end +@testset "diag with a Val index" begin + B = Bidiagonal(1:4, 1:3, :U) + @test diag(B, Val(0)) === 1:4 + @test diag(B, Val(1)) === 1:3 + B = Bidiagonal(1:4, 1:3, :L) + @test diag(B, Val(0)) === 1:4 + @test diag(B, Val(-1)) === 1:3 +end + end # module TestBidiagonal diff --git a/test/diagonal.jl b/test/diagonal.jl index 712f426c..ce331483 100644 --- a/test/diagonal.jl +++ b/test/diagonal.jl @@ -1576,4 +1576,9 @@ end @test D == D2 end +@testset "diag with a Val index" begin + D = Diagonal(1:4) + @test diag(D, Val(0)) === 1:4 +end + end # module TestDiagonal diff --git a/test/triangular.jl b/test/triangular.jl index e823f698..efd03e2a 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -1110,4 +1110,13 @@ end end end +@testset "diag with a Val index" begin + U = UpperTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(U, Val(0)) === 1:4 + @test diag(U, Val(1)) === 1:3 + L = LowerTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(L, Val(0)) === 1:4 + @test diag(L, Val(-1)) === 2:4 +end + end # module TestTriangular diff --git a/test/tridiag.jl b/test/tridiag.jl index e955a37e..888557ac 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1255,4 +1255,15 @@ end end end +@testset "diag with a Val index" begin + T = Tridiagonal(2:4, 1:4, 1:3) + @test diag(T, Val(0)) === 1:4 + @test diag(T, Val(1)) === 1:3 + @test diag(T, Val(-1)) === 2:4 + ST = SymTridiagonal(1:4, 1:3) + @test diag(ST, Val(0)) === 1:4 + @test diag(ST, Val(1)) === 1:3 + @test diag(ST, Val(-1)) === 1:3 +end + end # module TestTridiagonal From 7d8085de4b60fbe0e4fcf748c1cf1488432f0382 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 1 Sep 2025 16:35:03 -0400 Subject: [PATCH 11/12] Specialize for arbitrary k for triangular/hessenberg --- src/hessenberg.jl | 3 +++ src/triangular.jl | 10 ++++++---- test/hessenberg.jl | 10 ++++++++++ test/triangular.jl | 11 +++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/hessenberg.jl b/src/hessenberg.jl index 7aab1c77..ef293981 100644 --- a/src/hessenberg.jl +++ b/src/hessenberg.jl @@ -667,3 +667,6 @@ function logdet(F::Hessenberg) d,s = logabsdet(F) return d + log(s) end + +diag(A::UpperHessenberg) = diag(A.data) +diag(A::UpperHessenberg, ::Val{k}) where {k} = k >= -1 ? diag(A.data, Val(k)) : diag(A, k) \ No newline at end of file diff --git a/src/triangular.jl b/src/triangular.jl index 9e285346..966e68e4 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -549,11 +549,13 @@ adjoint!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U' , true, true adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , true, false)) diag(A::UpperOrLowerTriangular) = diag(A.data) -diag(A::UpperOrLowerTriangular, ::Val{0}) = diag(A.data, Val(0)) -diag(A::UpperOrUnitUpperTriangular, ::Val{1}) = diag(A.data, Val(1)) -diag(A::LowerOrUnitLowerTriangular, ::Val{-1}) = diag(A.data, Val(-1)) diag(A::Union{UnitLowerTriangular, UnitUpperTriangular}) = fill(oneunit(eltype(A)), size(A,1)) -diag(A::Union{UnitLowerTriangular, UnitUpperTriangular}, ::Val{0}) = diag(A) +diag(A::UpperTriangular, ::Val{k}) where {k} = k >= 0 ? diag(A.data, Val(k)) : diag(A, k) +diag(A::LowerTriangular, ::Val{k}) where {k} = k <= 0 ? diag(A.data, Val(k)) : diag(A, k) +diag(A::UnitUpperTriangular, ::Val{0}) = diag(A) +diag(A::UnitLowerTriangular, ::Val{0}) = diag(A) +diag(A::UnitUpperTriangular, ::Val{k}) where {k} = k > 0 ? diag(A.data, Val(k)) : diag(A, k) +diag(A::UnitLowerTriangular, ::Val{k}) where {k} = k < 0 ? diag(A.data, Val(k)) : diag(A, k) # Unary operations -(A::LowerTriangular) = LowerTriangular(-A.data) diff --git a/test/hessenberg.jl b/test/hessenberg.jl index 5b61c798..81e6655a 100644 --- a/test/hessenberg.jl +++ b/test/hessenberg.jl @@ -320,4 +320,14 @@ end @test U == U2 end +@testset "diag with a Val index" begin + H = UpperHessenberg(Tridiagonal(1:3, 1:4, 1:3)) + @test diag(H, Val(0)) === 1:4 + @test diag(H, Val(1)) === 1:3 + @test diag(H, Val(-1)) === 1:3 + @test diag(H, Val(0)) == diag(H) == diag(H, 0) + @test diag(H, Val(2)) == diag(H, 2) + @test diag(H, Val(-2)) == diag(H, -2) +end + end # module TestHessenberg diff --git a/test/triangular.jl b/test/triangular.jl index efd03e2a..61a39950 100644 --- a/test/triangular.jl +++ b/test/triangular.jl @@ -1114,9 +1114,20 @@ end U = UpperTriangular(Tridiagonal(2:4, 1:4, 1:3)) @test diag(U, Val(0)) === 1:4 @test diag(U, Val(1)) === 1:3 + @test diag(U, Val(-1)) == diag(U, -1) == zeros(3) L = LowerTriangular(Tridiagonal(2:4, 1:4, 1:3)) @test diag(L, Val(0)) === 1:4 @test diag(L, Val(-1)) === 2:4 + @test diag(L, Val(1)) == diag(L, 1) == zeros(3) + + U = UnitUpperTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(U, Val(1)) === 1:3 + @test diag(U, Val(0)) == diag(U, 0) == diag(U) == ones(4) + @test diag(U, Val(-1)) == diag(U, -1) == zeros(3) + L = UnitLowerTriangular(Tridiagonal(2:4, 1:4, 1:3)) + @test diag(L, Val(-1)) === 2:4 + @test diag(L, Val(0)) == diag(L, 0) == diag(L) == ones(4) + @test diag(L, Val(1)) == diag(L, 1) == zeros(3) end end # module TestTriangular From b11d34595de5056be0c22a0ddee7e2bc6b0ca592 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 1 Sep 2025 16:40:09 -0400 Subject: [PATCH 12/12] Newline --- src/hessenberg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hessenberg.jl b/src/hessenberg.jl index ef293981..b65f7be8 100644 --- a/src/hessenberg.jl +++ b/src/hessenberg.jl @@ -669,4 +669,4 @@ function logdet(F::Hessenberg) end diag(A::UpperHessenberg) = diag(A.data) -diag(A::UpperHessenberg, ::Val{k}) where {k} = k >= -1 ? diag(A.data, Val(k)) : diag(A, k) \ No newline at end of file +diag(A::UpperHessenberg, ::Val{k}) where {k} = k >= -1 ? diag(A.data, Val(k)) : diag(A, k)