@@ -6,36 +6,25 @@ using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
6
6
using ITensors. Ops: Op, OpSum
7
7
using NamedGraphs: degrees, is_leaf, vertex_path
8
8
using StaticArrays: MVector
9
-
9
+ using NamedGraphs : boundary_edges
10
10
# convert ITensors.OpSum to TreeTensorNetwork
11
11
12
12
#
13
13
# Utility methods
14
14
#
15
15
16
- # linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
17
- function find_index_in_tree (site, g:: AbstractGraph , root_vertex)
18
- ordering = reverse (post_order_dfs_vertices (g, root_vertex))
19
- return findfirst (x -> x == site, ordering)
20
- end
21
- function find_index_in_tree (o:: Op , g:: AbstractGraph , root_vertex)
22
- return find_index_in_tree (ITensors. site (o), g, root_vertex)
16
+ function align_edges (edges, reference_edges)
17
+ return intersect (Iterators. flatten (zip (edges, reverse .(edges))), reference_edges)
23
18
end
24
19
25
- # determine 'support' of product operator on tree graph
26
- function span (t:: Scaled{C,Prod{Op}} , g:: AbstractGraph ) where {C}
27
- spn = eltype (g)[]
28
- nterms = length (t)
29
- for i in 1 : nterms, j in i: nterms
30
- path = vertex_path (g, ITensors. site (t[i]), ITensors. site (t[j]))
31
- spn = union (spn, path)
32
- end
33
- return spn
20
+ function align_and_reorder_edges (edges, reference_edges)
21
+ return intersect (reference_edges, align_edges (edges, reference_edges))
34
22
end
35
23
36
- # determine whether an operator string crosses a given graph vertex
37
- function crosses_vertex (t:: Scaled{C,Prod{Op}} , g:: AbstractGraph , v) where {C}
38
- return v ∈ span (t, g)
24
+ function split_at_vertex (g:: AbstractGraph , v)
25
+ g = copy (g)
26
+ rem_vertex! (g, v)
27
+ return Set .(connected_components (g))
39
28
end
40
29
41
30
#
45
34
"""
46
35
ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex, kwargs...)
47
36
48
- Construct a dense TreeTensorNetwork from a symbolic OpSum representation of a
37
+ Construct a TreeTensorNetwork from a symbolic OpSum representation of a
49
38
Hamiltonian, compressing shared interaction channels.
50
39
"""
51
40
function ttn_svd (os:: OpSum , sites:: IndsNetwork , root_vertex; kwargs... )
@@ -71,9 +60,9 @@ function ttn_svd(
71
60
thishasqns = any (v -> hasqns (sites[v]), vertices (sites))
72
61
73
62
# traverse tree outwards from root vertex
74
- vs = reverse ( post_order_dfs_vertices ( sites, root_vertex)) # store vertices in fixed ordering relative to root
63
+ vs = _default_vertex_ordering ( sites, root_vertex)
75
64
# ToDo: Add check in ttn_svd that the ordering matches that of find_index_in_tree, which is used in sorteachterm #fermion-sign!
76
- es = reverse ( reverse .( post_order_dfs_edges ( sites, root_vertex)) ) # store edges in fixed ordering relative to root
65
+ es = _default_edge_ordering ( sites, root_vertex) # store edges in fixed ordering relative to root
77
66
# some things to keep track of
78
67
degrees = Dict (v => degree (sites, v) for v in vs) # rank of every TTN tensor in network
79
68
Vs = Dict (e => Dict {QN,Matrix{coefficient_type}} () for e in es) # link isometries for SVD compression of TTN
@@ -105,6 +94,8 @@ function ttn_svd(
105
94
for v in vs
106
95
is_internal[v] = isempty (sites[v])
107
96
if isempty (sites[v])
97
+ # FIXME : This logic only works for trivial flux, breaks for nonzero flux
98
+ # ToDo: add assert or fix and add test!
108
99
sites[v] = [Index (Hflux => 1 )]
109
100
end
110
101
end
@@ -118,35 +109,65 @@ function ttn_svd(
118
109
# build compressed finite state machine representation
119
110
for v in vs
120
111
# for every vertex, find all edges that contain this vertex
121
- edges = filter (e -> dst (e) == v || src (e) == v, es)
112
+ edges = align_and_reorder_edges (incident_edges (sites, v), es)
113
+
122
114
# use the corresponding ordering as index order for tensor elements at this site
123
115
dim_in = findfirst (e -> dst (e) == v, edges)
124
116
edge_in = (isnothing (dim_in) ? [] : edges[dim_in])
125
117
dims_out = findall (e -> src (e) == v, edges)
126
118
edges_out = edges[dims_out]
127
119
120
+ # for every site w except v, determine the incident edge to v that lies
121
+ # in the edge_path(w,v)
122
+ subgraphs = split_at_vertex (sites, v)
123
+ _boundary_edges = align_edges (
124
+ [only (boundary_edges (underlying_graph (sites), subgraph)) for subgraph in subgraphs],
125
+ edges,
126
+ )
127
+ which_incident_edge = Dict (
128
+ Iterators. flatten ([
129
+ subgraphs[i] .=> ((_boundary_edges[i]),) for i in eachindex (subgraphs)
130
+ ]),
131
+ )
132
+
128
133
# sanity check, leaves only have single incoming or outgoing edge
129
134
@assert ! isempty (dims_out) || ! isnothing (dim_in)
130
135
(isempty (dims_out) || isnothing (dim_in)) && @assert is_leaf (sites, v)
131
136
132
137
for term in os
133
138
# loop over OpSum and pick out terms that act on current vertex
134
- crosses_vertex (term, sites, v) || continue
139
+
140
+ factors = ITensors. terms (term)
141
+ if v in ITensors. site .(factors)
142
+ crosses_vertex = true
143
+ else
144
+ crosses_vertex =
145
+ ! isone (
146
+ length (Set ([which_incident_edge[site] for site in ITensors. site .(factors)]))
147
+ )
148
+ end
149
+ # if term doesn't cross vertex, skip it
150
+ crosses_vertex || continue
151
+
152
+ # filter out factor that acts on current vertex
153
+ onsite = filter (t -> (ITensors. site (t) == v), factors)
154
+ not_onsite_factors = setdiff (factors, onsite)
135
155
136
156
# filter out factors that come in from the direction of the incoming edge
137
157
incoming = filter (
138
- t -> edge_in ∈ edge_path (sites, ITensors. site (t), v), ITensors . terms (term)
158
+ t -> which_incident_edge[ ITensors. site (t)] == edge_in, not_onsite_factors
139
159
)
160
+
140
161
# also store all non-incoming factors in standard order, used for channel merging
141
162
not_incoming = filter (
142
- t -> edge_in ∉ edge_path (sites, ITensors. site (t), v), ITensors. terms (term)
163
+ t -> (ITensors. site (t) == v) || which_incident_edge[ITensors. site (t)] != edge_in,
164
+ factors,
143
165
)
144
- # filter out factor that acts on current vertex
145
- onsite = filter (t -> (ITensors. site (t) == v), ITensors. terms (term))
166
+
146
167
# for every outgoing edge, filter out factors that go out along that edge
147
168
outgoing = Dict (
148
- e => filter (t -> e ∈ edge_path (sites, v, ITensors. site (t)), ITensors . terms (term))
149
- for e in edges_out
169
+ e => filter (t -> which_incident_edge[ ITensors. site (t)] == e, not_onsite_factors) for
170
+ e in edges_out
150
171
)
151
172
152
173
# compute QNs
@@ -246,7 +267,8 @@ function ttn_svd(
246
267
247
268
for v in vs
248
269
# redo the whole thing like before
249
- edges = filter (e -> dst (e) == v || src (e) == v, es)
270
+ # ToDo: use neighborhood instead of going through all edges, see above
271
+ edges = align_and_reorder_edges (incident_edges (sites, v), es)
250
272
dim_in = findfirst (e -> dst (e) == v, edges)
251
273
dims_out = findall (e -> src (e) == v, edges)
252
274
# slice isometries at this vertex
@@ -340,9 +362,10 @@ function ttn_svd(
340
362
if is_internal[v]
341
363
H[v] += iT
342
364
else
343
- if hasqns (iT)
344
- @assert flux (iT * Op) == Hflux
345
- end
365
+ # ToDo: Remove this assert since it seems to be costly
366
+ # if hasqns(iT)
367
+ # @assert flux(iT * Op) == Hflux
368
+ # end
346
369
H[v] += (iT * Op)
347
370
end
348
371
end
@@ -409,12 +432,24 @@ function computeSiteProd(sites::IndsNetwork{V,<:Index}, ops::Prod{Op})::ITensor
409
432
return T
410
433
end
411
434
435
+ function _default_vertex_ordering (g:: AbstractGraph , root_vertex)
436
+ return reverse (post_order_dfs_vertices (g, root_vertex))
437
+ end
438
+
439
+ function _default_edge_ordering (g:: AbstractGraph , root_vertex)
440
+ return reverse (reverse .(post_order_dfs_edges (g, root_vertex)))
441
+ end
442
+
412
443
# This is almost an exact copy of ITensors/src/opsum_to_mpo_generic:sorteachterm except for the site ordering being
413
444
# given via find_index_in_tree
414
445
# changed `isless_site` to use tree vertex ordering relative to root
415
446
function sorteachterm (os:: OpSum , sites:: IndsNetwork{V,<:Index} , root_vertex:: V ) where {V}
416
447
os = copy (os)
417
- findpos (op:: Op ) = find_index_in_tree (op, sites, root_vertex)
448
+
449
+ # linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
450
+ ordering = _default_vertex_ordering (sites, root_vertex)
451
+ site_positions = Dict (zip (ordering, 1 : length (ordering)))
452
+ findpos (op:: Op ) = site_positions[ITensors. site (op)]
418
453
isless_site (o1:: Op , o2:: Op ) = findpos (o1) < findpos (o2)
419
454
N = nv (sites)
420
455
for n in eachindex (os)
0 commit comments