Skip to content

Commit e920800

Browse files
authored
Optimize ttn_svd (#157)
1 parent ce7b3e4 commit e920800

File tree

2 files changed

+72
-37
lines changed

2 files changed

+72
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ ITensorNetworksEinExprsExt = "EinExprs"
4646
AbstractTrees = "0.4.4"
4747
Combinatorics = "1"
4848
Compat = "3, 4"
49-
DataGraphs = "0.1.7"
49+
DataGraphs = "0.1.13"
5050
DataStructures = "0.18"
5151
Dictionaries = "0.4"
5252
Distributions = "0.25.86"

src/treetensornetworks/opsum_to_ttn.jl

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,25 @@ using ITensors.NDTensors: Block, maxdim, nblocks, nnzblocks
66
using ITensors.Ops: Op, OpSum
77
using NamedGraphs: degrees, is_leaf, vertex_path
88
using StaticArrays: MVector
9-
9+
using NamedGraphs: boundary_edges
1010
# convert ITensors.OpSum to TreeTensorNetwork
1111

1212
#
1313
# Utility methods
1414
#
1515

16-
# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
17-
function find_index_in_tree(site, g::AbstractGraph, root_vertex)
18-
ordering = reverse(post_order_dfs_vertices(g, root_vertex))
19-
return findfirst(x -> x == site, ordering)
20-
end
21-
function find_index_in_tree(o::Op, g::AbstractGraph, root_vertex)
22-
return find_index_in_tree(ITensors.site(o), g, root_vertex)
16+
function align_edges(edges, reference_edges)
17+
return intersect(Iterators.flatten(zip(edges, reverse.(edges))), reference_edges)
2318
end
2419

25-
# determine 'support' of product operator on tree graph
26-
function span(t::Scaled{C,Prod{Op}}, g::AbstractGraph) where {C}
27-
spn = eltype(g)[]
28-
nterms = length(t)
29-
for i in 1:nterms, j in i:nterms
30-
path = vertex_path(g, ITensors.site(t[i]), ITensors.site(t[j]))
31-
spn = union(spn, path)
32-
end
33-
return spn
20+
function align_and_reorder_edges(edges, reference_edges)
21+
return intersect(reference_edges, align_edges(edges, reference_edges))
3422
end
3523

36-
# determine whether an operator string crosses a given graph vertex
37-
function crosses_vertex(t::Scaled{C,Prod{Op}}, g::AbstractGraph, v) where {C}
38-
return v span(t, g)
24+
function split_at_vertex(g::AbstractGraph, v)
25+
g = copy(g)
26+
rem_vertex!(g, v)
27+
return Set.(connected_components(g))
3928
end
4029

4130
#
@@ -45,7 +34,7 @@ end
4534
"""
4635
ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex, kwargs...)
4736
48-
Construct a dense TreeTensorNetwork from a symbolic OpSum representation of a
37+
Construct a TreeTensorNetwork from a symbolic OpSum representation of a
4938
Hamiltonian, compressing shared interaction channels.
5039
"""
5140
function ttn_svd(os::OpSum, sites::IndsNetwork, root_vertex; kwargs...)
@@ -71,9 +60,9 @@ function ttn_svd(
7160
thishasqns = any(v -> hasqns(sites[v]), vertices(sites))
7261

7362
# traverse tree outwards from root vertex
74-
vs = reverse(post_order_dfs_vertices(sites, root_vertex)) # store vertices in fixed ordering relative to root
63+
vs = _default_vertex_ordering(sites, root_vertex)
7564
# ToDo: Add check in ttn_svd that the ordering matches that of find_index_in_tree, which is used in sorteachterm #fermion-sign!
76-
es = reverse(reverse.(post_order_dfs_edges(sites, root_vertex))) # store edges in fixed ordering relative to root
65+
es = _default_edge_ordering(sites, root_vertex) # store edges in fixed ordering relative to root
7766
# some things to keep track of
7867
degrees = Dict(v => degree(sites, v) for v in vs) # rank of every TTN tensor in network
7968
Vs = Dict(e => Dict{QN,Matrix{coefficient_type}}() for e in es) # link isometries for SVD compression of TTN
@@ -105,6 +94,8 @@ function ttn_svd(
10594
for v in vs
10695
is_internal[v] = isempty(sites[v])
10796
if isempty(sites[v])
97+
# FIXME: This logic only works for trivial flux, breaks for nonzero flux
98+
# ToDo: add assert or fix and add test!
10899
sites[v] = [Index(Hflux => 1)]
109100
end
110101
end
@@ -118,35 +109,65 @@ function ttn_svd(
118109
# build compressed finite state machine representation
119110
for v in vs
120111
# for every vertex, find all edges that contain this vertex
121-
edges = filter(e -> dst(e) == v || src(e) == v, es)
112+
edges = align_and_reorder_edges(incident_edges(sites, v), es)
113+
122114
# use the corresponding ordering as index order for tensor elements at this site
123115
dim_in = findfirst(e -> dst(e) == v, edges)
124116
edge_in = (isnothing(dim_in) ? [] : edges[dim_in])
125117
dims_out = findall(e -> src(e) == v, edges)
126118
edges_out = edges[dims_out]
127119

120+
# for every site w except v, determine the incident edge to v that lies
121+
# in the edge_path(w,v)
122+
subgraphs = split_at_vertex(sites, v)
123+
_boundary_edges = align_edges(
124+
[only(boundary_edges(underlying_graph(sites), subgraph)) for subgraph in subgraphs],
125+
edges,
126+
)
127+
which_incident_edge = Dict(
128+
Iterators.flatten([
129+
subgraphs[i] .=> ((_boundary_edges[i]),) for i in eachindex(subgraphs)
130+
]),
131+
)
132+
128133
# sanity check, leaves only have single incoming or outgoing edge
129134
@assert !isempty(dims_out) || !isnothing(dim_in)
130135
(isempty(dims_out) || isnothing(dim_in)) && @assert is_leaf(sites, v)
131136

132137
for term in os
133138
# loop over OpSum and pick out terms that act on current vertex
134-
crosses_vertex(term, sites, v) || continue
139+
140+
factors = ITensors.terms(term)
141+
if v in ITensors.site.(factors)
142+
crosses_vertex = true
143+
else
144+
crosses_vertex =
145+
!isone(
146+
length(Set([which_incident_edge[site] for site in ITensors.site.(factors)]))
147+
)
148+
end
149+
#if term doesn't cross vertex, skip it
150+
crosses_vertex || continue
151+
152+
# filter out factor that acts on current vertex
153+
onsite = filter(t -> (ITensors.site(t) == v), factors)
154+
not_onsite_factors = setdiff(factors, onsite)
135155

136156
# filter out factors that come in from the direction of the incoming edge
137157
incoming = filter(
138-
t -> edge_in edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
158+
t -> which_incident_edge[ITensors.site(t)] == edge_in, not_onsite_factors
139159
)
160+
140161
# also store all non-incoming factors in standard order, used for channel merging
141162
not_incoming = filter(
142-
t -> edge_in edge_path(sites, ITensors.site(t), v), ITensors.terms(term)
163+
t -> (ITensors.site(t) == v) || which_incident_edge[ITensors.site(t)] != edge_in,
164+
factors,
143165
)
144-
# filter out factor that acts on current vertex
145-
onsite = filter(t -> (ITensors.site(t) == v), ITensors.terms(term))
166+
146167
# for every outgoing edge, filter out factors that go out along that edge
147168
outgoing = Dict(
148-
e => filter(t -> e edge_path(sites, v, ITensors.site(t)), ITensors.terms(term))
149-
for e in edges_out
169+
e => filter(t -> which_incident_edge[ITensors.site(t)] == e, not_onsite_factors) for
170+
e in edges_out
150171
)
151172

152173
# compute QNs
@@ -246,7 +267,8 @@ function ttn_svd(
246267

247268
for v in vs
248269
# redo the whole thing like before
249-
edges = filter(e -> dst(e) == v || src(e) == v, es)
270+
# ToDo: use neighborhood instead of going through all edges, see above
271+
edges = align_and_reorder_edges(incident_edges(sites, v), es)
250272
dim_in = findfirst(e -> dst(e) == v, edges)
251273
dims_out = findall(e -> src(e) == v, edges)
252274
# slice isometries at this vertex
@@ -340,9 +362,10 @@ function ttn_svd(
340362
if is_internal[v]
341363
H[v] += iT
342364
else
343-
if hasqns(iT)
344-
@assert flux(iT * Op) == Hflux
345-
end
365+
#ToDo: Remove this assert since it seems to be costly
366+
#if hasqns(iT)
367+
# @assert flux(iT * Op) == Hflux
368+
#end
346369
H[v] += (iT * Op)
347370
end
348371
end
@@ -409,12 +432,24 @@ function computeSiteProd(sites::IndsNetwork{V,<:Index}, ops::Prod{Op})::ITensor
409432
return T
410433
end
411434

435+
function _default_vertex_ordering(g::AbstractGraph, root_vertex)
436+
return reverse(post_order_dfs_vertices(g, root_vertex))
437+
end
438+
439+
function _default_edge_ordering(g::AbstractGraph, root_vertex)
440+
return reverse(reverse.(post_order_dfs_edges(g, root_vertex)))
441+
end
442+
412443
# This is almost an exact copy of ITensors/src/opsum_to_mpo_generic:sorteachterm except for the site ordering being
413444
# given via find_index_in_tree
414445
# changed `isless_site` to use tree vertex ordering relative to root
415446
function sorteachterm(os::OpSum, sites::IndsNetwork{V,<:Index}, root_vertex::V) where {V}
416447
os = copy(os)
417-
findpos(op::Op) = find_index_in_tree(op, sites, root_vertex)
448+
449+
# linear ordering of vertices in tree graph relative to chosen root, chosen outward from root
450+
ordering = _default_vertex_ordering(sites, root_vertex)
451+
site_positions = Dict(zip(ordering, 1:length(ordering)))
452+
findpos(op::Op) = site_positions[ITensors.site(op)]
418453
isless_site(o1::Op, o2::Op) = findpos(o1) < findpos(o2)
419454
N = nv(sites)
420455
for n in eachindex(os)

0 commit comments

Comments
 (0)