Skip to content

[WIP] Define random_unitary constructor #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.1"
version = "0.2.2"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"

Expand All @@ -22,7 +23,8 @@ ArrayLayouts = "1.10.4"
BlockArrays = "1.2.0"
EllipsisNotation = "1.8.0"
GradedUnitRanges = "0.1.0"
LinearAlgebra = "1.10"
LinearAlgebra = "1.10.0"
Random = "1.10.0"
TupleTools = "1.6.0"
TypeParameterAccessors = "0.2.1, 0.3"
TypeParameterAccessors = "0.2.1, 0.3.0"
julia = "1.10"
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
module TensorAlgebraGradedUnitRangesExt
using GradedUnitRanges: AbstractGradedUnitRange, tensor_product
using TensorAlgebra: TensorAlgebra

using GradedUnitRanges: AbstractGradedUnitRange, dual, tensor_product
using GradedUnitRanges.BlockArrays: Block, blocklengths, blocksize
using Random: AbstractRNG
using TensorAlgebra: TensorAlgebra, random_unitary

function TensorAlgebra.:⊗(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
return tensor_product(a1, a2)
end

function TensorAlgebra.dual(a::AbstractGradedUnitRange)
return dual(a)

Check warning on line 13 in ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl#L12-L13

Added lines #L12 - L13 were not covered by tests
end

function TensorAlgebra.random_unitary(

Check warning on line 16 in ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl#L16

Added line #L16 was not covered by tests
rng::AbstractRNG,
elt::Type,
ax::Tuple{AbstractGradedUnitRange},
)
a = zeros(elt, dual.(ax)..., ax...)

Check warning on line 21 in ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl#L21

Added line #L21 was not covered by tests
# TODO: Define `blockdiagindices`.
for i in 1:minimum(blocksize(a))
a[Block(i, i)] = random_unitary(rng, elt, Int(blocklengths(only(ax))[i]))
end
return a

Check warning on line 26 in ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorAlgebraGradedUnitRangesExt/TensorAlgebraGradedUnitRangesExt.jl#L23-L26

Added lines #L23 - L26 were not covered by tests
end

end
1 change: 1 addition & 0 deletions src/TensorAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ include("contract/blockedperms.jl")
include("contract/allocate_output.jl")
include("contract/contract_matricize/contract.jl")
include("factorizations.jl")
include("random_unitary.jl")

end
69 changes: 69 additions & 0 deletions src/random_unitary.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Version of `sign` that returns one
# if `x == 0`.
function nonzero_sign(x)
iszero(x) && return one(x)
return sign(x)

Check warning on line 5 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L3-L5

Added lines #L3 - L5 were not covered by tests
end

using LinearAlgebra: LinearAlgebra, Diagonal, diag
function qr_positive(M::AbstractMatrix)
Q, R = LinearAlgebra.qr(M)
Q′ = typeof(R)(Q)
signs = nonzero_sign.(diag(R))
Q′ = Q′ * Diagonal(signs)
R = Diagonal(conj.(signs)) * R
return Q′, R

Check warning on line 15 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L9-L15

Added lines #L9 - L15 were not covered by tests
end

using Random: Random, AbstractRNG

dual(x) = x

Check warning on line 20 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L20

Added line #L20 was not covered by tests

function random_unitary(

Check warning on line 22 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L22

Added line #L22 was not covered by tests
rng::AbstractRNG,
elt::Type,
ax::Tuple{AbstractUnitRange,Vararg{AbstractUnitRange}},
)
ax_fused = ⊗(ax...)
a_fused = random_unitary(rng, elt, ax_fused)
return splitdims(a_fused, dual.(ax), ax)

Check warning on line 29 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L27-L29

Added lines #L27 - L29 were not covered by tests
end

# Copy of `Base.to_dim`:
# https://github.com/JuliaLang/julia/blob/1431bec1bcd205f181ca2a3f1c314247b64076df/base/array.jl#L439-L440
to_dim(d::Integer) = d
to_dim(d::Base.OneTo) = last(d)

Check warning on line 35 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L34-L35

Added lines #L34 - L35 were not covered by tests

# Matrix version.
function random_unitary(

Check warning on line 38 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L38

Added line #L38 was not covered by tests
rng::AbstractRNG,
elt::Type,
ax::Tuple{AbstractUnitRange},
)
return random_unitary(rng, elt, map(to_dim, ax))

Check warning on line 43 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L43

Added line #L43 was not covered by tests
end

function random_unitary(rng::AbstractRNG, elt::Type, dims::Tuple{Integer})
Q, _ = qr_positive(randn(rng, elt, (dims..., dims...)))
return Q

Check warning on line 48 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L46-L48

Added lines #L46 - L48 were not covered by tests
end

# Canonicalizing other kinds of inputs.
function random_unitary(rng::AbstractRNG, elt::Type, dims::Tuple{Vararg{Union{AbstractUnitRange,Integer}}})
return random_unitary(Random.default_rng(), elt, map(to_axis, dims))

Check warning on line 53 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L52-L53

Added lines #L52 - L53 were not covered by tests
end
function random_unitary(elt::Type, dims::Tuple{Vararg{Union{AbstractUnitRange,Integer}}})
return random_unitary(Random.default_rng(), elt, dims)

Check warning on line 56 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L55-L56

Added lines #L55 - L56 were not covered by tests
end
function random_unitary(rng::AbstractRNG, elt::Type, dims::Union{AbstractUnitRange,Integer}...)
return random_unitary(rng, elt, dims)

Check warning on line 59 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L58-L59

Added lines #L58 - L59 were not covered by tests
end
function random_unitary(elt::Type, dims::Union{AbstractUnitRange,Integer}...)
return random_unitary(Random.default_rng(), elt, dims)

Check warning on line 62 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L61-L62

Added lines #L61 - L62 were not covered by tests
end
function random_unitary(rng::AbstractRNG, dims::Union{AbstractUnitRange,Integer}...)
return random_unitary(rng, Float64, dims)

Check warning on line 65 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L64-L65

Added lines #L64 - L65 were not covered by tests
end
function random_unitary(dims::Union{AbstractUnitRange,Integer}...)
return random_unitary(Random.default_rng(), Float64, dims)

Check warning on line 68 in src/random_unitary.jl

View check run for this annotation

Codecov / codecov/patch

src/random_unitary.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
end
Loading