Skip to content

Commit e08b46e

Browse files
authored
Allow custom axes when constructing operators, SymmetrySectors.jl and ITensorBase.jl extensions (#17)
1 parent 3df5640 commit e08b46e

File tree

12 files changed

+536
-118
lines changed

12 files changed

+536
-118
lines changed

Project.toml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
name = "QuantumOperatorDefinitions"
22
uuid = "826dd319-6fd5-459a-a990-3a4f214664bf"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.4"
4+
version = "0.1.5"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99

1010
[weakdeps]
11+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
12+
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
13+
LabelledNumbers = "f856a3a6-4152-4ec4-b2a7-02c1a55d7993"
1114
ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7"
15+
NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
16+
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
1217

1318
[extensions]
14-
QuantumOperatorDefinitionsITensorBaseExt = "ITensorBase"
19+
QuantumOperatorDefinitionsITensorBaseExt = ["ITensorBase", "NamedDimsArrays"]
20+
QuantumOperatorDefinitionsSymmetrySectorsExt = ["BlockArrays", "GradedUnitRanges", "LabelledNumbers", "SymmetrySectors"]
1521

1622
[compat]
23+
BlockArrays = "1.3.0"
24+
GradedUnitRanges = "0.1.2"
1725
ITensorBase = "0.1.10"
26+
LabelledNumbers = "0.1.0"
1827
LinearAlgebra = "1.10"
28+
NamedDimsArrays = "0.4.0"
1929
Random = "1.10"
30+
SymmetrySectors = "0.1.3"
2031
julia = "1.10"

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ julia> Pkg.add("QuantumOperatorDefinitions")
3333

3434
````julia
3535
using QuantumOperatorDefinitions: OpName, SiteType, StateName, , controlled, op, state
36-
using LinearAlgebra: Diagonal
3736
using SparseArrays: SparseMatrixCSC, SparseVector
3837
using Test: @test
3938

@@ -64,8 +63,6 @@ using Test: @test
6463
@test op("Y") == [0 -im; im 0]
6564
@test op("Z") == [1 0; 0 -1]
6665

67-
@test op("Z") isa Diagonal
68-
6966
@test op(Float32, "X") == [0 1; 1 0]
7067
@test eltype(op(Float32, "X")) === Float32
7168
@test op(SparseMatrixCSC, "X") == [0 1; 1 0]

examples/README.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ julia> Pkg.add("QuantumOperatorDefinitions")
3838
# ## Examples
3939

4040
using QuantumOperatorDefinitions: OpName, SiteType, StateName, , controlled, op, state
41-
using LinearAlgebra: Diagonal
4241
using SparseArrays: SparseMatrixCSC, SparseVector
4342
using Test: @test
4443

@@ -69,8 +68,6 @@ using Test: @test
6968
@test op("Y") == [0 -im; im 0]
7069
@test op("Z") == [1 0; 0 -1]
7170

72-
@test op("Z") isa Diagonal
73-
7471
@test op(Float32, "X") == [0 1; 1 0]
7572
@test eltype(op(Float32, "X")) === Float32
7673
@test op(SparseMatrixCSC, "X") == [0 1; 1 0]
Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,58 @@
11
module QuantumOperatorDefinitionsITensorBaseExt
22

3-
using ITensorBase: ITensor, Index, dag, gettag, prime
3+
using ITensorBase: ITensorBase, ITensor, Index, dag, gettag, prime, settag
4+
using NamedDimsArrays: dename
45
using QuantumOperatorDefinitions:
5-
QuantumOperatorDefinitions, OpName, SiteType, StateName, has_fermion_string
6+
QuantumOperatorDefinitions,
7+
@OpName_str,
8+
OpName,
9+
SiteType,
10+
StateName,
11+
has_fermion_string,
12+
name
613

714
function QuantumOperatorDefinitions.SiteType(r::Index)
8-
return SiteType(gettag(r, "sitetype", "Qudit"); dim=Int(length(r)))
15+
# We pass the axis of the (unnamed) Index because
16+
# the Index may have originated from a slice, in which
17+
# case the start may not be 1 (for NonContiguousIndex,
18+
# which we need to add support for, it may not even
19+
# be a unit range).
20+
return SiteType(
21+
gettag(r, "sitetype", "Qudit"); dim=Int.(length(r)), range=only(axes(dename(r)))
22+
)
923
end
1024

25+
function (rangetype::Type{<:Index})(t::SiteType)
26+
return settag(rangetype(AbstractUnitRange(t)), "sitetype", String(name(t)))
27+
end
28+
29+
# TODO: Define in terms of `OpName` directly, and define a generic
30+
# forwarding method `has_fermion_string(n::String, t) = has_fermion_string(OpName(n), t)`.
1131
function QuantumOperatorDefinitions.has_fermion_string(n::String, r::Index)
1232
return has_fermion_string(OpName(n), SiteType(r))
1333
end
1434

15-
function Base.AbstractArray(n::OpName, r::Index)
16-
# TODO: Define this with mapped dimnames.
17-
return ITensor(AbstractArray(n, SiteType(r)), (prime(r), dag(r)))
35+
function Base.axes(::OpName, domain::Tuple{Vararg{Index}})
36+
return (prime.(domain)..., dag.(domain)...)
1837
end
38+
## function Base.axes(::OpName"SWAP", domain::Tuple{Vararg{Index}})
39+
## return (prime.(reverse(domain))..., dag.(domain)...)
40+
## end
1941

20-
function Base.AbstractArray(n::StateName, r::Index)
21-
return ITensor(AbstractArray(n, SiteType(r)), (r,))
42+
# Fix ambiguity error with generic `AbstractArray` version.
43+
function ITensorBase.ITensor(n::Union{OpName,StateName}, domain::Index...)
44+
return ITensor(n, domain)
45+
end
46+
# Fix ambiguity error with generic `AbstractArray` version.
47+
function ITensorBase.ITensor(n::Union{OpName,StateName}, domain::Tuple{Vararg{Index}})
48+
return ITensor(AbstractArray(n, domain), axes(n, domain))
49+
end
50+
function (arrtype::Type{<:AbstractArray})(
51+
n::Union{OpName,StateName}, domain::Tuple{Vararg{Index}}
52+
)
53+
# Convert to `SiteType` in case the Index specifies a `"sitetype"` tag.
54+
# TODO: Try to build this into the generic codepath.
55+
return ITensor(arrtype(n, SiteType.(domain)), axes(n, domain))
2256
end
2357

2458
end
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
module QuantumOperatorDefinitionsSymmetrySectorsExt
2+
3+
using BlockArrays: blocklasts, blocklengths
4+
using GradedUnitRanges: AbstractGradedUnitRange, GradedOneTo, gradedrange
5+
using LabelledNumbers: label, labelled, unlabel
6+
using QuantumOperatorDefinitions:
7+
QuantumOperatorDefinitions,
8+
@SiteType_str,
9+
@GradingType_str,
10+
SiteType,
11+
GradingType,
12+
OpName,
13+
name
14+
using SymmetrySectors: ×, dual, SectorProduct, U1, Z
15+
16+
function Base.axes(::OpName, domain::Tuple{Vararg{AbstractGradedUnitRange}})
17+
return (domain..., dual.(domain)...)
18+
end
19+
20+
sortedunion(a, b) = sort(union(a, b))
21+
function QuantumOperatorDefinitions.combine_axes(a1::GradedOneTo, a2::GradedOneTo)
22+
return gradedrange(
23+
map(blocklengths(a1), blocklengths(a2)) do s1, s2
24+
l1 = unlabel(s1)
25+
l2 = unlabel(s2)
26+
@assert l1 == l2
27+
labelled(l1, label(s1) × label(s2))
28+
end,
29+
)
30+
end
31+
QuantumOperatorDefinitions.combine_axes(a::GradedOneTo, b::Base.OneTo) = a
32+
QuantumOperatorDefinitions.combine_axes(a::Base.OneTo, b::GradedOneTo) = b
33+
34+
function Base.AbstractUnitRange(::GradingType"N", t::SiteType)
35+
return gradedrange(map(i -> SectorProduct((; N=U1(i - 1))) => 1, 1:length(t)))
36+
end
37+
function Base.AbstractUnitRange(::GradingType"Sz", t::SiteType)
38+
return gradedrange(map(i -> SectorProduct((; Sz=U1(i - 1))) => 1, 1:length(t)))
39+
end
40+
function Base.AbstractUnitRange(::GradingType"Sz↑", t::SiteType)
41+
return AbstractUnitRange(GradingType"Sz"(), t)
42+
end
43+
function Base.AbstractUnitRange(::GradingType"Sz↓", t::SiteType)
44+
return gradedrange(map(i -> SectorProduct((; Sz=U1(-(i - 1)))) => 1, 1:length(t)))
45+
end
46+
47+
function sector(gradingtype::GradingType, sec)
48+
sectorname = Symbol(get(gradingtype, :name, name(gradingtype)))
49+
return SectorProduct(NamedTuple{(sectorname,)}((sec,)))
50+
end
51+
52+
function Base.AbstractUnitRange(s::GradingType"Nf", t::SiteType"Fermion")
53+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(1)) => 1])
54+
end
55+
# TODO: Write in terms of `GradingType"Nf"` definition.
56+
function Base.AbstractUnitRange(s::GradingType"NfParity", t::SiteType"Fermion")
57+
return gradedrange([sector(s, Z{2}(0)) => 1, sector(s, Z{2}(1)) => 1])
58+
end
59+
function Base.AbstractUnitRange(s::GradingType"Sz", t::SiteType"Fermion")
60+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(1)) => 1])
61+
end
62+
function Base.AbstractUnitRange(s::GradingType"Sz↑", t::SiteType"Fermion")
63+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(1)) => 1])
64+
end
65+
function Base.AbstractUnitRange(s::GradingType"Sz↓", t::SiteType"Fermion")
66+
return gradedrange([sector(s, U1(0)) => 1, sector(s, U1(-1)) => 1])
67+
end
68+
69+
# TODO: Write in terms of `SiteType"Fermion"` definitions.
70+
function Base.AbstractUnitRange(s::GradingType"Nf", t::SiteType"Electron")
71+
return gradedrange([
72+
sector(s, U1(0)) => 1,
73+
sector(s, U1(1)) => 1,
74+
sector(s, U1(1)) => 1,
75+
sector(s, U1(2)) => 1,
76+
])
77+
end
78+
# TODO: Write in terms of `GradingType"Nf"` definition.
79+
function Base.AbstractUnitRange(s::GradingType"NfParity", t::SiteType"Electron")
80+
return gradedrange([
81+
sector(s, Z{2}(0)) => 1,
82+
sector(s, Z{2}(1)) => 1,
83+
sector(s, Z{2}(1)) => 1,
84+
sector(s, Z{2}(0)) => 1,
85+
])
86+
end
87+
function Base.AbstractUnitRange(s::GradingType"Sz", t::SiteType"Electron")
88+
return gradedrange([
89+
sector(s, U1(0)) => 1,
90+
sector(s, U1(1)) => 1,
91+
sector(s, U1(-1)) => 1,
92+
sector(s, U1(0)) => 1,
93+
])
94+
end
95+
96+
end

src/op.jl

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ struct OpName{Name,Params}
88
end
99
name(::OpName{Name}) where {Name} = Name
1010
params(n::OpName) = getfield(n, :params)
11-
1211
Base.getproperty(n::OpName, name::Symbol) = getfield(params(n), name)
12+
Base.get(t::OpName, name::Symbol, default) = get(params(t), name, default)
1313

1414
OpName{N}(; kwargs...) where {N} = OpName{N}((; kwargs...))
1515

@@ -54,9 +54,14 @@ end
5454
# Generic to `StateName` or `OpName`.
5555
const StateOrOpName = Union{StateName,OpName}
5656
alias(n::StateOrOpName) = n
57-
function (arrtype::Type{<:AbstractArray})(n::StateOrOpName, domain::Integer...)
57+
function (arrtype::Type{<:AbstractArray})(
58+
n::StateOrOpName, domain::Union{Integer,AbstractUnitRange}...
59+
)
5860
return arrtype(n, domain)
5961
end
62+
function (arrtype::Type{<:AbstractArray})(n::StateOrOpName, domain::Tuple{Vararg{Integer}})
63+
return arrtype(n, Base.oneto.(domain))
64+
end
6065
(arrtype::Type{<:AbstractArray})(n::StateOrOpName, ts::SiteType...) = arrtype(n, ts)
6166
function (n::StateOrOpName)(domain...)
6267
# TODO: Try one alias at a time?
@@ -87,32 +92,58 @@ function nsites(n::StateOrOpName)
8792
return nsites(n′)
8893
end
8994

90-
function op_convert(
91-
arrtype::Type{<:AbstractArray{<:Any,N}},
92-
domain::Tuple{Vararg{Integer}},
93-
a::AbstractArray{<:Any,N},
94-
) where {N}
95-
# TODO: Check the dimensions.
96-
return convert(arrtype, a)
95+
# TODO: This does some unwanted conversions, like turning
96+
# `Diagonal` dense.
97+
function array(a::AbstractArray, ax::Tuple{Vararg{AbstractUnitRange}})
98+
return a[ax...]
9799
end
98-
function op_convert(
99-
arrtype::Type{<:AbstractArray}, domain::Tuple{Vararg{Integer}}, a::AbstractArray
100-
)
101-
# TODO: Check the dimensions.
102-
return convert(arrtype, a)
100+
101+
function Base.axes(::OpName, domain::Tuple{Vararg{AbstractUnitRange}})
102+
return (domain..., domain...)
103+
end
104+
function Base.axes(n::StateOrOpName, domain::Tuple{Vararg{Integer}})
105+
return axes(n, Base.OneTo.(domain))
106+
end
107+
function Base.axes(n::StateOrOpName, domain::Tuple{Vararg{SiteType}})
108+
return axes(n, AbstractUnitRange.(domain))
109+
end
110+
111+
## function Base.axes(::OpName"SWAP", domain::Tuple{Vararg{AbstractUnitRange}})
112+
## return (reverse(domain)..., domain...)
113+
## end
114+
115+
function reversed_sites(n::StateOrOpName, domain)
116+
return reverse_sites(n, reshape(n(domain...), length.(axes(n, reverse(domain)))))
117+
end
118+
function reverse_sites(n::OpName, a::AbstractArray)
119+
ndomain = Int(ndims(a)//2)
120+
perm1 = reverse(ntuple(identity, ndomain))
121+
perm2 = perm1 .+ ndomain
122+
perm = (perm1..., perm2...)
123+
return permutedims(a, perm)
103124
end
104-
function op_convert(
105-
arrtype::Type{<:AbstractArray{<:Any,N}}, domain::Tuple{Vararg{Integer}}, a::AbstractArray
106-
) where {N}
107-
size = (domain..., domain...)
108-
@assert length(size) == N
109-
return convert(arrtype, reshape(a, size))
125+
126+
function state_or_op_convert(
127+
n::StateOrOpName,
128+
arrtype::Type{<:AbstractArray},
129+
domain::Tuple{Vararg{AbstractUnitRange}},
130+
a::AbstractArray,
131+
)
132+
ax = axes(n, domain)
133+
a′ = reshape(a, length.(ax))
134+
a′′ = array(a′, ax)
135+
return convert(arrtype, a′′)
110136
end
111-
function (arrtype::Type{<:AbstractArray})(n::OpName, domain::Tuple{Vararg{SiteType}})
112-
return op_convert(arrtype, length.(domain), n(domain...))
137+
138+
function (arrtype::Type{<:AbstractArray})(n::StateOrOpName, domain::Tuple{Vararg{SiteType}})
139+
domain′ = AbstractUnitRange.(domain)
140+
return state_or_op_convert(n, arrtype, domain′, reversed_sites(n, domain))
113141
end
114-
function (arrtype::Type{<:AbstractArray})(n::OpName, domain::Tuple{Vararg{Integer}})
115-
return op_convert(arrtype, domain, n(Int.(domain)...))
142+
function (arrtype::Type{<:AbstractArray})(
143+
n::StateOrOpName, domain::Tuple{Vararg{AbstractUnitRange}}
144+
)
145+
# TODO: Make `(::OpName)(domain...)` constructor process more general inputs.
146+
return state_or_op_convert(n, arrtype, domain, reversed_sites(n, Int.(length.(domain))))
116147
end
117148

118149
function op(arrtype::Type{<:AbstractArray}, n::String, domain...; kwargs...)
@@ -475,13 +506,13 @@ function (n::OpName"Controlled")(domain...)
475506
# Number of control sites.
476507
nc = get(params(n), :ncontrol, length(domain) - nt)
477508
@assert length(domain) == nc + nt
478-
d_control = prod(to_dim.(domain[1:nc]))
509+
d_control = prod(to_dim.(domain)) - prod(to_dim.(domain[(nc + 1):end]))
479510
return cat(I(d_control), n.arg(domain[(nc + 1):end]...); dims=(1, 2))
480511
end
481-
@op_alias "CNOT" "Controlled" op = OpName"X"()
482-
@op_alias "CX" "Controlled" op = OpName"X"()
483-
@op_alias "CY" "Controlled" op = OpName"Y"()
484-
@op_alias "CZ" "Controlled" op = OpName"Z"()
512+
@op_alias "CNOT" "Controlled" arg = OpName"X"()
513+
@op_alias "CX" "Controlled" arg = OpName"X"()
514+
@op_alias "CY" "Controlled" arg = OpName"Y"()
515+
@op_alias "CZ" "Controlled" arg = OpName"Z"()
485516
function alias(n::OpName"CPhase")
486517
return controlled(OpName"Phase"(; params(n)...))
487518
end
@@ -504,17 +535,17 @@ function alias(::OpName"CRn")
504535
end
505536
@op_alias "CRn̂" "CRn"
506537

507-
@op_alias "CCNOT" "Controlled" ncontrol = 2 op = OpName"X"()
538+
@op_alias "CCNOT" "Controlled" ncontrol = 2 arg = OpName"X"()
508539
@op_alias "Toffoli" "CCNOT"
509540
@op_alias "CCX" "CCNOT"
510541
@op_alias "TOFF" "CCNOT"
511542

512-
@op_alias "CSWAP" "Controlled" ncontrol = 2 op = OpName"SWAP"()
543+
@op_alias "CSWAP" "Controlled" ncontrol = 2 arg = OpName"SWAP"()
513544
@op_alias "Fredkin" "CSWAP"
514545
@op_alias "CSwap" "CSWAP"
515546
@op_alias "CS" "CSWAP"
516547

517-
@op_alias "CCCNOT" "Controlled" ncontrol = 3 op = OpName"X"()
548+
@op_alias "CCCNOT" "Controlled" ncontrol = 3 arg = OpName"X"()
518549

519550
## # 1-qudit rotation around generic axis n̂.
520551
## # exp(-im * α / 2 * n̂ ⋅ σ⃗)

0 commit comments

Comments
 (0)