Skip to content

Commit 8453430

Browse files
JoeyT1994mtfishman
andauthored
Normalize (#192)
Co-authored-by: Matt Fishman <[email protected]>
1 parent 8f47ae4 commit 8453430

10 files changed

+272
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
4-
version = "0.13.6"
4+
version = "0.13.7"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/ITensorNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ include("solvers/linsolve.jl")
6262
include("solvers/sweep_plans/sweep_plans.jl")
6363
include("apply.jl")
6464
include("inner.jl")
65+
include("normalize.jl")
6566
include("expect.jl")
6667
include("environment.jl")
6768
include("exports.jl")

src/abstractitensornetwork.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ function map_vertex_data(f, tn::AbstractITensorNetwork)
399399
return tn
400400
end
401401

402-
# TODO: Define `@preserve_graph map_vertex_data(f, tn)`
402+
# TODO: Define @preserve_graph map_vertex_data(f, tn)`
403403
function map_vertex_data_preserve_graph(f, tn::AbstractITensorNetwork)
404404
tn = copy(tn)
405405
for v in vertices(tn)
@@ -408,6 +408,13 @@ function map_vertex_data_preserve_graph(f, tn::AbstractITensorNetwork)
408408
return tn
409409
end
410410

411+
function map_vertices_preserve_graph!(f, tn::AbstractITensorNetwork; vertices=vertices(tn))
412+
for v in vertices
413+
@preserve_graph tn[v] = f(v)
414+
end
415+
return tn
416+
end
417+
411418
function Base.conj(tn::AbstractITensorNetwork)
412419
# TODO: Use `@preserve_graph map_vertex_data(f, tn)`
413420
return map_vertex_data_preserve_graph(conj, tn)
@@ -935,6 +942,30 @@ function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
935942
return tn12
936943
end
937944

945+
""" Scale each tensor of the network via a function vertex -> Number"""
946+
function scale!(
947+
weight_function::Function,
948+
tn::AbstractITensorNetwork;
949+
vertices=collect(Graphs.vertices(tn)),
950+
)
951+
return map_vertices_preserve_graph!(v -> weight_function(v) * tn[v], tn; vertices)
952+
end
953+
954+
""" Scale each tensor of the network by a scale factor for each vertex in the keys of the dictionary"""
955+
function scale!(tn::AbstractITensorNetwork, vertices_weights::Dictionary)
956+
return scale!(v -> vertices_weights[v], tn; vertices=keys(vertices_weights))
957+
end
958+
959+
function scale(weight_function::Function, tn; kwargs...)
960+
tn = copy(tn)
961+
return scale!(weight_function, tn; kwargs...)
962+
end
963+
964+
function scale(tn::AbstractITensorNetwork, vertices_weights::Dictionary; kwargs...)
965+
tn = copy(tn)
966+
return scale!(tn, vertices_weights; kwargs...)
967+
end
968+
938969
Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)
939970

940971
ITensors.hasqns(tn::AbstractITensorNetwork) = any(v -> hasqns(tn[v]), vertices(tn))

src/caches/abstractbeliefpropagationcache.jl

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Graphs: IsDirected
1+
using Graphs: Graphs, IsDirected
22
using SplitApplyCombine: group
33
using LinearAlgebra: diag, dot
44
using ITensors: dir
@@ -66,7 +66,7 @@ function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; k
6666
return not_implemented()
6767
end
6868
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
69-
partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented()
69+
PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()
7070

7171
function default_edge_sequence(
7272
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
@@ -88,6 +88,10 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
8888
return unpartitioned_graph(partitioned_tensornetwork(bpc))
8989
end
9090

91+
function setindex_preserve_graph!(bpc::AbstractBeliefPropagationCache, args...)
92+
return setindex_preserve_graph!(tensornetwork(bpc), args...)
93+
end
94+
9195
function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
9296
return ITensor[tensornetwork(bpc)[v] for v in verts]
9397
end
@@ -107,7 +111,7 @@ function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc)
107111
end
108112

109113
function edge_scalars(
110-
bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs...
114+
bpc::AbstractBeliefPropagationCache, pes=partitionedges(bpc); kwargs...
111115
)
112116
return map(pe -> region_scalar(bpc, pe; kwargs...), pes)
113117
end
@@ -283,3 +287,79 @@ function update(
283287
)
284288
return update(Algorithm(alg), bpc; kwargs...)
285289
end
290+
291+
function scale!(bp_cache::AbstractBeliefPropagationCache, args...)
292+
return scale!(tensornetwork(bp_cache), args...)
293+
end
294+
295+
function rescale_messages(
296+
bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge
297+
)
298+
return rescale_messages(bp_cache, [partitionedge])
299+
end
300+
301+
function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
302+
return rescale_messages(bp_cache, partitionedges(bp_cache))
303+
end
304+
305+
function rescale_partitions(
306+
bpc::AbstractBeliefPropagationCache,
307+
partitions::Vector;
308+
verts::Vector=vertices(bpc, partitions),
309+
)
310+
bpc = copy(bpc)
311+
tn = tensornetwork(bpc)
312+
norms = map(v -> inv(norm(tn[v])), verts)
313+
scale!(bpc, Dictionary(verts, norms))
314+
315+
vertices_weights = Dictionary()
316+
for pv in partitions
317+
pv_vs = filter(v -> v verts, vertices(bpc, pv))
318+
isempty(pv_vs) && continue
319+
320+
vn = region_scalar(bpc, pv)
321+
s = isreal(vn) ? sign(vn) : 1.0
322+
vn = s * inv(vn^(1 / length(pv_vs)))
323+
set!(vertices_weights, first(pv_vs), s*vn)
324+
for v in pv_vs[2:length(pv_vs)]
325+
set!(vertices_weights, v, vn)
326+
end
327+
end
328+
329+
scale!(bpc, vertices_weights)
330+
331+
return bpc
332+
end
333+
334+
function rescale_partitions(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
335+
return rescale_partitions(bpc, collect(partitions(bpc)), args...; kwargs...)
336+
end
337+
338+
function rescale_partition(
339+
bpc::AbstractBeliefPropagationCache, partition, args...; kwargs...
340+
)
341+
return rescale_partitions(bpc, [partition], args...; kwargs...)
342+
end
343+
344+
function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
345+
bpc = rescale_messages(bpc)
346+
bpc = rescale_partitions(bpc, args...; kwargs...)
347+
return bpc
348+
end
349+
350+
function logscalar(bpc::AbstractBeliefPropagationCache)
351+
numerator_terms, denominator_terms = scalar_factors_quotient(bpc)
352+
if any(t -> real(t) < 0, numerator_terms)
353+
numerator_terms = complex.(numerator_terms)
354+
end
355+
if any(t -> real(t) < 0, denominator_terms)
356+
denominator_terms = complex.(denominator_terms)
357+
end
358+
359+
any(iszero, denominator_terms) && return -Inf
360+
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
361+
end
362+
363+
function ITensors.scalar(bpc::AbstractBeliefPropagationCache)
364+
return exp(logscalar(bpc))
365+
end

src/caches/beliefpropagationcache.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ using NamedGraphs.PartitionedGraphs:
99
boundary_partitionedges,
1010
partitionvertices,
1111
partitionedges,
12-
unpartitioned_graph
12+
partitioned_vertices,
13+
unpartitioned_graph,
14+
which_partition
1315
using SimpleTraits: SimpleTraits, Not, @traitfn
1416
using NDTensors: NDTensors
1517

@@ -80,7 +82,9 @@ function default_message_update_kwargs(
8082
end
8183

8284
partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc))
83-
partitionpairs(bpc::BeliefPropagationCache) = partitionedges(partitioned_tensornetwork(bpc))
85+
function PartitionedGraphs.partitionedges(bpc::BeliefPropagationCache)
86+
partitionedges(partitioned_tensornetwork(bpc))
87+
end
8488

8589
function set_messages(cache::BeliefPropagationCache, messages)
8690
return BeliefPropagationCache(partitioned_tensornetwork(cache), messages)
@@ -106,3 +110,23 @@ function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
106110
sequence = contraction_sequence(ts; alg="optimal")
107111
return contract(ts; sequence)[]
108112
end
113+
114+
function rescale_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
115+
bp_cache = copy(bp_cache)
116+
mts = messages(bp_cache)
117+
for pe in pes
118+
me, mer = normalize.(mts[pe]), normalize.(mts[reverse(pe)])
119+
set!(mts, pe, me)
120+
set!(mts, reverse(pe), mer)
121+
n = region_scalar(bp_cache, pe)
122+
if isreal(n)
123+
me[1] *= sign(n)
124+
n *= sign(n)
125+
end
126+
127+
sf = (1 / sqrt(n)) ^ (1 / length(me))
128+
set!(mts, pe, sf .* me)
129+
set!(mts, reverse(pe), sf .* mer)
130+
end
131+
return bp_cache
132+
end

src/contract.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,7 @@ function logscalar(
6363
cache![] = update(cache![]; cache_update_kwargs...)
6464
end
6565

66-
numerator_terms, denominator_terms = scalar_factors_quotient(cache![])
67-
numerator_terms =
68-
any(t -> real(t) < 0, numerator_terms) ? complex.(numerator_terms) : numerator_terms
69-
denominator_terms = if any(t -> real(t) < 0, denominator_terms)
70-
complex.(denominator_terms)
71-
else
72-
denominator_terms
73-
end
74-
75-
any(iszero, denominator_terms) && return -Inf
76-
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
66+
return logscalar(cache![])
7767
end
7868

7969
function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)

src/expect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function expect(
2626
(cache!)=nothing,
2727
update_cache=isnothing(cache!),
2828
cache_update_kwargs=default_cache_update_kwargs(alg),
29-
cache_construction_kwargs=default_cache_construction_kwargs(alg, QuadraticFormNetwork(ψ)),
29+
cache_construction_kwargs=(;),
3030
kwargs...,
3131
)
3232
ψIψ = QuadraticFormNetwork(ψ)

src/normalize.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using LinearAlgebra: normalize
2+
3+
function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
4+
return rescale(Algorithm(alg), tn; kwargs...)
5+
end
6+
7+
function rescale(alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...)
8+
logn = logscalar(alg, tn; kwargs...)
9+
vs = collect(vertices(tn))
10+
c = inv(exp(logn / length(vs)))
11+
vertices_weights = Dictionary(vs, [c for v in vs])
12+
return scale(tn, vertices_weights)
13+
end
14+
15+
function rescale(
16+
alg::Algorithm,
17+
tn::AbstractITensorNetwork,
18+
args...;
19+
(cache!)=nothing,
20+
cache_construction_kwargs=default_cache_construction_kwargs(alg, tn),
21+
update_cache=isnothing(cache!),
22+
cache_update_kwargs=default_cache_update_kwargs(alg),
23+
kwargs...,
24+
)
25+
if isnothing(cache!)
26+
cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
27+
end
28+
29+
if update_cache
30+
cache![] = update(cache![]; cache_update_kwargs...)
31+
end
32+
33+
cache![] = rescale(cache![], args...; kwargs...)
34+
35+
return tensornetwork(cache![])
36+
end
37+
38+
function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
39+
return normalize(Algorithm(alg), tn; kwargs...)
40+
end
41+
42+
function LinearAlgebra.normalize(
43+
alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
44+
)
45+
logn = logscalar(alg, inner_network(tn, tn); kwargs...)
46+
vs = collect(vertices(tn))
47+
c = inv(exp(logn / (2*length(vs))))
48+
vertices_weights = Dictionary(vs, [c for v in vs])
49+
return scale(tn, vertices_weights)
50+
end
51+
52+
function LinearAlgebra.normalize(
53+
alg::Algorithm,
54+
tn::AbstractITensorNetwork;
55+
(cache!)=nothing,
56+
cache_construction_function=tn ->
57+
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
58+
update_cache=isnothing(cache!),
59+
cache_update_kwargs=default_cache_update_kwargs(alg),
60+
cache_construction_kwargs=(;),
61+
)
62+
norm_tn = inner_network(tn, tn)
63+
if isnothing(cache!)
64+
cache! = Ref(cache(alg, norm_tn; cache_construction_kwargs...))
65+
end
66+
67+
vs = collect(vertices(tn))
68+
verts = vcat([ket_vertex(norm_tn, v) for v in vs], [bra_vertex(norm_tn, v) for v in vs])
69+
norm_tn = rescale(alg, norm_tn; verts, cache!, update_cache, cache_update_kwargs)
70+
71+
return ket_network(norm_tn)
72+
end

test/test_belief_propagation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ using Test: @test, @testset
5050
bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ)))
5151
bpc = update(bpc; maxiter=25, tol=eps(real(elt)))
5252
#Test messages are converged
53-
for pe in partitionedges(partitioned_tensornetwork(bpc))
53+
for pe in partitionedges(bpc)
5454
@test message_diff(updated_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt))
5555
@test eltype(only(message(bpc, pe))) == elt
5656
end

test/test_normalize.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
@eval module $(gensym())
2+
using ITensorNetworks:
3+
BeliefPropagationCache,
4+
QuadraticFormNetwork,
5+
edge_scalars,
6+
norm_sqr_network,
7+
random_tensornetwork,
8+
siteinds,
9+
vertex_scalars,
10+
rescale
11+
using ITensors: dag, inner, scalar
12+
using Graphs: SimpleGraph, uniform_tree
13+
using LinearAlgebra: normalize
14+
using NamedGraphs: NamedGraph
15+
using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
16+
using StableRNGs: StableRNG
17+
using TensorOperations: TensorOperations
18+
using Test: @test, @testset
19+
@testset "Normalize" begin
20+
21+
#First lets do a flat tree
22+
nx, ny = 2, 3
23+
χ = 2
24+
rng = StableRNG(1234)
25+
26+
g = named_comb_tree((nx, ny))
27+
tn = random_tensornetwork(rng, g; link_space=χ)
28+
29+
tn_r = rescale(tn; alg="exact")
30+
@test scalar(tn_r; alg="exact") 1.0
31+
32+
tn_r = rescale(tn; alg="bp")
33+
@test scalar(tn_r; alg="exact") 1.0
34+
35+
#Now a state on a loopy graph
36+
Lx, Ly = 3, 2
37+
χ = 2
38+
rng = StableRNG(1234)
39+
40+
g = named_grid((Lx, Ly))
41+
s = siteinds("S=1/2", g)
42+
x = random_tensornetwork(rng, s; link_space=χ)
43+
44+
ψ = normalize(x; alg="exact")
45+
@test scalar(norm_sqr_network(ψ); alg="exact") 1.0
46+
47+
ψIψ_bpc = Ref(BeliefPropagationCache(QuadraticFormNetwork(x)))
48+
ψ = normalize(x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true)
49+
ψIψ_bpc = ψIψ_bpc[]
50+
@test all(x -> x 1.0, edge_scalars(ψIψ_bpc))
51+
@test all(x -> x 1.0, vertex_scalars(ψIψ_bpc))
52+
@test scalar(QuadraticFormNetwork(ψ); alg="bp") 1.0
53+
end
54+
end

0 commit comments

Comments
 (0)