Skip to content

Commit

Permalink
Fix BP convergence metric for complex networks (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 authored Jun 25, 2024
1 parent bb50333 commit dec464f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 54 deletions.
21 changes: 11 additions & 10 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Graphs: IsDirected
using SplitApplyCombine: group
using LinearAlgebra: diag
using LinearAlgebra: diag, dot
using ITensors: dir
using ITensorMPS: ITensorMPS
using NamedGraphs.PartitionedGraphs:
Expand All @@ -12,10 +12,10 @@ using NamedGraphs.PartitionedGraphs:
partitionedges,
unpartitioned_graph
using SimpleTraits: SimpleTraits, Not, @traitfn
using NDTensors: NDTensors

default_message(inds_e) = ITensor[denseblocks(delta(i)) for i in inds_e]
# Default BP message for a partition edge: one dense `delta` tensor per index
# in `inds_e`, with element type `elt` so messages match the network's scalar type.
default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
default_messages(ptn::PartitionedGraph) = Dictionary()
default_message_norm(m::ITensor) = norm(m)
function default_message_update(contract_list::Vector{ITensor}; kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
Expand All @@ -33,17 +33,16 @@ default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertice
function default_partitioned_vertices(f::AbstractFormNetwork)
return group(v -> original_state_vertex(f, v), vertices(f))
end
default_cache_update_kwargs(cache) = (; maxiter=20, tol=1e-5)
# Default keyword arguments used when updating a cache: cap the number of BP
# iterations and set the convergence tolerance on the message diff.
function default_cache_update_kwargs(cache)
  return (; maxiter=25, tol=1e-8)
end
function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)
return (; partitioned_vertices=default_partitioned_vertices(ψ))
end

function message_diff(
message_a::Vector{ITensor}, message_b::Vector{ITensor}; message_norm=default_message_norm
)
#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
norm_lhs, norm_rhs = message_norm(lhs), message_norm(rhs)
return 0.5 * norm((denseblocks(lhs) / norm_lhs) - (denseblocks(rhs) / norm_rhs))
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
return 1 - f
end

struct BeliefPropagationCache{PTN,MTS,DM}
Expand Down Expand Up @@ -99,8 +98,10 @@ for f in [
end
end

# Scalar element type of the cache, delegated to the underlying tensor network.
# NOTE: the argument is constrained to `BeliefPropagationCache` — an untyped
# argument would define `NDTensors.scalartype(::Any)`, an over-broad method on a
# function this module does not own, hijacking dispatch for every other type.
NDTensors.scalartype(bp_cache::BeliefPropagationCache) = scalartype(tensornetwork(bp_cache))

function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
return default_message(bp_cache)(linkinds(bp_cache, edge))
return default_message(bp_cache)(scalartype(bp_cache), linkinds(bp_cache, edge))
end

function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
Expand Down
87 changes: 47 additions & 40 deletions test/test_belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ using ITensorNetworks:
tensornetwork,
update,
update_factor,
update_message
update_message,
message_diff
using ITensors: ITensors, ITensor, combiner, dag, inds, inner, op, prime, random_itensor
using ITensorNetworks.ModelNetworks: ModelNetworks
using ITensors.NDTensors: array
Expand All @@ -34,50 +35,56 @@ using NamedGraphs.PartitionedGraphs: PartitionVertex, partitionedges
using SplitApplyCombine: group
using StableRNGs: StableRNG
using Test: @test, @testset
@testset "belief_propagation" begin
ITensors.disable_warn_order()
g = named_grid((3, 3))
s = siteinds("S=1/2", g)
χ = 2
rng = StableRNG(1234)
ψ = random_tensornetwork(rng, s; link_space=χ)
ψψ = ψ prime(dag(ψ); sites=[])
bpc = BeliefPropagationCache(ψψ)
bpc = update(bpc; maxiter=50, tol=1e-10)
#Test messages are converged
for pe in partitionedges(partitioned_tensornetwork(bpc))
@test update_message(bpc, pe) message(bpc, pe) atol = 1e-8
end
#Test updating the underlying tensornetwork in the cache
v = first(vertices(ψψ))
rng = StableRNG(1234)
new_tensor = random_itensor(rng, inds(ψψ[v]))
bpc_updated = update_factor(bpc, v, new_tensor)
ψψ_updated = tensornetwork(bpc_updated)
@test ψψ_updated[v] == new_tensor

#Test forming a two-site RDM. Check it has the correct size, trace 1 and is PSD
vs = [(2, 2), (2, 3)]
@testset "belief_propagation (eltype=$elt)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
)
begin
ITensors.disable_warn_order()
g = named_grid((3, 3))
s = siteinds("S=1/2", g)
χ = 2
rng = StableRNG(1234)
ψ = random_tensornetwork(rng, elt, s; link_space=χ)
ψψ = ψ prime(dag(ψ); sites=[])
bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ)))
bpc = update(bpc; maxiter=25, tol=eps(real(elt)))
#Test messages are converged
for pe in partitionedges(partitioned_tensornetwork(bpc))
@test message_diff(update_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt))
@test eltype(only(message(bpc, pe))) == elt
end
#Test updating the underlying tensornetwork in the cache
v = first(vertices(ψψ))
rng = StableRNG(1234)
new_tensor = random_itensor(rng, inds(ψψ[v]))
bpc_updated = update_factor(bpc, v, new_tensor)
ψψ_updated = tensornetwork(bpc_updated)
@test ψψ_updated[v] == new_tensor

#Test forming a two-site RDM. Check it has the correct size, trace 1 and is PSD
vs = [(2, 2), (2, 3)]

ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
env_tensors = environment(bpc, [(v, 2) for v in vs])
rdm = contract(vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]))
ψψsplit = split_index(ψψ, NamedEdge.([(v, 1) => (v, 2) for v in vs]))
env_tensors = environment(bpc, [(v, 2) for v in vs])
rdm = contract(vcat(env_tensors, ITensor[ψψsplit[vp] for vp in [(v, 2) for v in vs]]))

rdm = array((rdm * combiner(inds(rdm; plev=0)...)) * combiner(inds(rdm; plev=1)...))
rdm /= tr(rdm)
rdm = array((rdm * combiner(inds(rdm; plev=0)...)) * combiner(inds(rdm; plev=1)...))
rdm /= tr(rdm)

eigs = eigvals(rdm)
@test size(rdm) == (2^length(vs), 2^length(vs))
eigs = eigvals(rdm)
@test size(rdm) == (2^length(vs), 2^length(vs))

@test all(eig -> imag(eig) 0, eigs)
@test all(eig -> real(eig) >= -eps(eltype(eig)), eigs)
@test all(eig -> abs(imag(eig)) <= eps(real(elt)), eigs)
@test all(eig -> real(eig) >= -eps(real(elt)), eigs)

#Test edge case of network which evaluates to 0
χ = 2
g = named_grid((3, 1))
rng = StableRNG(1234)
ψ = random_tensornetwork(rng, ComplexF64, g; link_space=χ)
ψ[(1, 1)] = 0.0 * ψ[(1, 1)]
@test iszero(scalar(ψ; alg="bp"))
    #Test edge case of network which evaluates to 0
χ = 2
g = named_grid((3, 1))
rng = StableRNG(1234)
ψ = random_tensornetwork(rng, elt, g; link_space=χ)
ψ[(1, 1)] = 0 * ψ[(1, 1)]
@test iszero(scalar(ψ; alg="bp"))
end
end
end
6 changes: 2 additions & 4 deletions test/test_gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ using Test: @test, @testset
ψ = random_tensornetwork(rng, s; link_space=χ)

# Move directly to vidal gauge
ψ_vidal = VidalITensorNetwork(
ψ; cache_update_kwargs=(; maxiter=20, tol=1e-12, verbose=true)
)
ψ_vidal = VidalITensorNetwork(ψ; cache_update_kwargs=(; maxiter=30, verbose=true))
@test gauge_error(ψ_vidal) < 1e-8

# Move to symmetric gauge
Expand All @@ -38,7 +36,7 @@ using Test: @test, @testset
bp_cache = cache_ref[]

# Test we just did a gauge transform and didn't change the overall network
@test inner(ψ_symm, ψ) / sqrt(inner(ψ_symm, ψ_symm) * inner(ψ, ψ)) 1.0
@test inner(ψ_symm, ψ) / sqrt(inner(ψ_symm, ψ_symm) * inner(ψ, ψ)) 1.0 atol = 1e-8

#Test all message tensors are approximately diagonal even when we keep running BP
bp_cache = update(bp_cache; maxiter=10)
Expand Down

0 comments on commit dec464f

Please sign in to comment.