diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl
index 78eddcfa..0ca52a12 100644
--- a/src/ITensorNetworks.jl
+++ b/src/ITensorNetworks.jl
@@ -62,6 +62,7 @@ include("solvers/linsolve.jl")
 include("solvers/sweep_plans/sweep_plans.jl")
 include("apply.jl")
 include("inner.jl")
+include("normalize.jl")
 include("expect.jl")
 include("environment.jl")
 include("exports.jl")
diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl
index c8bb57bc..b5ea3ef9 100644
--- a/src/abstractitensornetwork.jl
+++ b/src/abstractitensornetwork.jl
@@ -935,6 +935,26 @@ function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
   return tn12
 end
 
+""" Scale each tensor of the network by a scale factor on each vertex"""
+function scale!(tn::AbstractITensorNetwork, vertices_weights::Dictionary)
+  for v in keys(vertices_weights)
+    setindex_preserve_graph!(tn, vertices_weights[v] * tn[v], v)
+  end
+  return tn
+end
+
+""" Scale each tensor of the network via a function (vertex, ITensor) -> Number"""
+function scale!(tn::AbstractITensorNetwork, weight_function::Function)
+  vs = collect(vertices(tn))
+  vertices_weights = Dictionary(vs, [weight_function(v, tn[v]) for v in vs])
+  return scale!(tn, vertices_weights)
+end
+
+function scale(tn, args...)
+  tn = copy(tn)
+  return scale!(tn, args...)
+end
+
 Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)
 
 ITensors.hasqns(tn::AbstractITensorNetwork) = any(v -> hasqns(tn[v]), vertices(tn))
diff --git a/src/caches/abstractbeliefpropagationcache.jl b/src/caches/abstractbeliefpropagationcache.jl
index 55e1e9fe..ffd17b86 100644
--- a/src/caches/abstractbeliefpropagationcache.jl
+++ b/src/caches/abstractbeliefpropagationcache.jl
@@ -1,4 +1,4 @@
-using Graphs: IsDirected
+using Graphs: Graphs, IsDirected
 using SplitApplyCombine: group
 using LinearAlgebra: diag, dot
 using ITensors: dir
@@ -66,7 +66,7 @@ function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; k
   return not_implemented()
 end
 partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
-partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented()
+PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()
 
 function default_edge_sequence(
   bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
@@ -88,6 +88,10 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
   return unpartitioned_graph(partitioned_tensornetwork(bpc))
 end
 
+function setindex_preserve_graph!(bpc::AbstractBeliefPropagationCache, args...)
+  return setindex_preserve_graph!(tensornetwork(bpc), args...)
+end
+
 function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
   return ITensor[tensornetwork(bpc)[v] for v in verts]
 end
@@ -107,7 +111,7 @@ function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc)
 end
 
 function edge_scalars(
-  bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs...
+  bpc::AbstractBeliefPropagationCache, pes=partitionedges(bpc); kwargs...
 )
   return map(pe -> region_scalar(bpc, pe; kwargs...), pes)
 end
@@ -283,3 +287,88 @@ function update(
 )
   return update(Algorithm(alg), bpc; kwargs...)
 end
+
+function scale!(bp_cache::AbstractBeliefPropagationCache, args...)
+  return scale!(tensornetwork(bp_cache), args...)
+end
+
+function rescale_messages(
+  bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge
+)
+  return rescale_messages(bp_cache, [partitionedge])
+end
+
+function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
+  return rescale_messages(bp_cache, partitionedges(bp_cache))
+end
+
+"""
+    rescale_partitions(bpc, partitions; verts_to_rescale=vertices(bpc, partitions))
+
+Return a copy of `bpc` whose tensors on `verts_to_rescale` are divided by their norm and
+then multiplied by weights whose product over each partition in `partitions` is the
+inverse of that partition's `region_scalar`. Partitions without such vertices are skipped.
+"""
+function rescale_partitions(
+  bpc::AbstractBeliefPropagationCache,
+  partitions::Vector;
+  verts_to_rescale::Vector=vertices(bpc, partitions),
+)
+  bpc = copy(bpc)
+  tn = tensornetwork(bpc)
+  norms = map(v -> inv(norm(tn[v])), verts_to_rescale)
+  scale!(bpc, Dictionary(verts_to_rescale, norms))
+
+  vertices_weights = Dictionary()
+  for pv in partitions
+    pv_vs = filter(v -> v ∈ verts_to_rescale, vertices(bpc, pv))
+    isempty(pv_vs) && continue
+
+    vn = region_scalar(bpc, pv)
+    s = isreal(vn) ? sign(vn) : 1.0
+    # Take the fractional power of `s * vn` (non-negative for real `vn`) so a negative
+    # real scalar cannot raise a `DomainError`; the sign is put on the first vertex only.
+    weight = inv((s * vn)^(1 / length(pv_vs)))
+    set!(vertices_weights, first(pv_vs), s * weight)
+    for v in pv_vs[2:length(pv_vs)]
+      set!(vertices_weights, v, weight)
+    end
+  end
+
+  scale!(bpc, vertices_weights)
+
+  return bpc
+end
+
+function rescale_partitions(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
+  return rescale_partitions(bpc, collect(partitions(bpc)), args...; kwargs...)
+end
+
+function rescale_partition(
+  bpc::AbstractBeliefPropagationCache, partition, args...; kwargs...
+)
+  return rescale_partitions(bpc, [partition], args...; kwargs...)
+end
+
+function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
+  bpc = rescale_messages(bpc)
+  bpc = rescale_partitions(bpc, args...; kwargs...)
+  return bpc
+end
+
+function logscalar(bpc::AbstractBeliefPropagationCache)
+  numerator_terms, denominator_terms = scalar_factors_quotient(bpc)
+  if any(t -> real(t) < 0, numerator_terms)
+    numerator_terms = complex.(numerator_terms)
+  end
+  if any(t -> real(t) < 0, denominator_terms)
+    denominator_terms = complex.(denominator_terms)
+  end
+
+  any(iszero, denominator_terms) && return -Inf
+  return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
+end
+
+function ITensors.scalar(bpc::AbstractBeliefPropagationCache)
+  return exp(logscalar(bpc))
+end
diff --git a/src/caches/beliefpropagationcache.jl b/src/caches/beliefpropagationcache.jl
index 2c0a23a5..92fd7f2c 100644
--- a/src/caches/beliefpropagationcache.jl
+++ b/src/caches/beliefpropagationcache.jl
@@ -9,7 +9,9 @@ using NamedGraphs.PartitionedGraphs:
   boundary_partitionedges,
   partitionvertices,
   partitionedges,
-  unpartitioned_graph
+  partitioned_vertices,
+  unpartitioned_graph,
+  which_partition
 using SimpleTraits: SimpleTraits, Not, @traitfn
 using NDTensors: NDTensors
 
@@ -80,7 +82,9 @@ function default_message_update_kwargs(
 end
 
 partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc))
-partitionpairs(bpc::BeliefPropagationCache) = partitionedges(partitioned_tensornetwork(bpc))
+function PartitionedGraphs.partitionedges(bpc::BeliefPropagationCache)
+  return partitionedges(partitioned_tensornetwork(bpc))
+end
 
 function set_messages(cache::BeliefPropagationCache, messages)
   return BeliefPropagationCache(partitioned_tensornetwork(cache), messages)
@@ -106,3 +110,23 @@ function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
   sequence = contraction_sequence(ts; alg="optimal")
   return contract(ts; sequence)[]
 end
+
+function rescale_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
+  bp_cache = copy(bp_cache)
+  mts = messages(bp_cache)
+  for pe in pes
+    me, mer = normalize.(mts[pe]), normalize.(mts[reverse(pe)])
+    set!(mts, pe, me)
+    set!(mts, reverse(pe), mer)
+    n = region_scalar(bp_cache, pe)
+    if isreal(n)
+      me[1] *= sign(n)
+      n *= sign(n)
+    end
+
+    sf = (1 / sqrt(n)) ^ (1 / length(me))
+    set!(mts, pe, sf .* me)
+    set!(mts, reverse(pe), sf .* mer)
+  end
+  return bp_cache
+end
diff --git a/src/contract.jl b/src/contract.jl
index 016a7fd6..3fe8a915 100644
--- a/src/contract.jl
+++ b/src/contract.jl
@@ -63,17 +63,7 @@ function logscalar(
     cache![] = update(cache![]; cache_update_kwargs...)
   end
 
-  numerator_terms, denominator_terms = scalar_factors_quotient(cache![])
-  numerator_terms =
-    any(t -> real(t) < 0, numerator_terms) ? complex.(numerator_terms) : numerator_terms
-  denominator_terms = if any(t -> real(t) < 0, denominator_terms)
-    complex.(denominator_terms)
-  else
-    denominator_terms
-  end
-
-  any(iszero, denominator_terms) && return -Inf
-  return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
+  return logscalar(cache![])
 end
 
 function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)
diff --git a/src/expect.jl b/src/expect.jl
index 9ac804b9..47636d73 100644
--- a/src/expect.jl
+++ b/src/expect.jl
@@ -26,7 +26,7 @@ function expect(
   (cache!)=nothing,
   update_cache=isnothing(cache!),
   cache_update_kwargs=default_cache_update_kwargs(alg),
-  cache_construction_kwargs=default_cache_construction_kwargs(alg, QuadraticFormNetwork(ψ)),
+  cache_construction_kwargs=(;),
   kwargs...,
 )
   ψIψ = QuadraticFormNetwork(ψ)
diff --git a/src/normalize.jl b/src/normalize.jl
new file mode 100644
index 00000000..b6a0313d
--- /dev/null
+++ b/src/normalize.jl
@@ -0,0 +1,86 @@
+using LinearAlgebra
+
+function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
+  return rescale(Algorithm(alg), tn; kwargs...)
+end
+
+function rescale(alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...)
+  logn = logscalar(alg, tn; kwargs...)
+  vs = collect(vertices(tn))
+  c = inv(exp(logn / length(vs)))
+  vertices_weights = Dictionary(vs, [c for v in vs])
+  return scale(tn, vertices_weights)
+end
+
+function rescale(
+  alg::Algorithm,
+  tn::AbstractITensorNetwork,
+  args...;
+  (cache!)=nothing,
+  cache_construction_kwargs=default_cache_construction_kwargs(alg, tn),
+  update_cache=isnothing(cache!),
+  cache_update_kwargs=default_cache_update_kwargs(alg),
+  kwargs...,
+)
+  if isnothing(cache!)
+    cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
+  end
+
+  if update_cache
+    cache![] = update(cache![]; cache_update_kwargs...)
+  end
+
+  cache![] = rescale(cache![], args...; kwargs...)
+
+  return tensornetwork(cache![])
+end
+
+function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
+  return normalize(Algorithm(alg), tn; kwargs...)
+end
+
+function LinearAlgebra.normalize(
+  alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
+)
+  logn = logscalar(alg, inner_network(tn, tn); kwargs...)
+  vs = collect(vertices(tn))
+  c = inv(exp(logn / (2*length(vs))))
+  vertices_weights = Dictionary(vs, [c for v in vs])
+  return scale(tn, vertices_weights)
+end
+
+"""
+    normalize(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)
+
+Rescale `inner_network(tn, tn)` with `rescale(alg, ...)`, restricted to its ket and bra
+vertices, and return the ket network of the result.
+
+When `cache!` is `nothing`, the cache is built by `cache_construction_function`, whose
+default passes the algorithm defaults overridden by `cache_construction_kwargs`.
+"""
+function LinearAlgebra.normalize(
+  alg::Algorithm,
+  tn::AbstractITensorNetwork;
+  (cache!)=nothing,
+  cache_construction_kwargs=(;),
+  cache_construction_function=tn -> cache(
+    alg, tn; default_cache_construction_kwargs(alg, tn)..., cache_construction_kwargs...
+  ),
+  update_cache=isnothing(cache!),
+  cache_update_kwargs=default_cache_update_kwargs(alg),
+)
+  norm_tn = inner_network(tn, tn)
+  if isnothing(cache!)
+    cache! = Ref(cache_construction_function(norm_tn))
+  end
+
+  vs = collect(vertices(tn))
+  verts_to_rescale = vcat(
+    [ket_vertex(norm_tn, v) for v in vs], [bra_vertex(norm_tn, v) for v in vs]
+  )
+  norm_tn = rescale(
+    alg, norm_tn; verts_to_rescale, cache!, update_cache, cache_update_kwargs
+  )
+
+  return ket_network(norm_tn)
+end
diff --git a/test/test_belief_propagation.jl b/test/test_belief_propagation.jl
index 76e49ebd..5877dd45 100644
--- a/test/test_belief_propagation.jl
+++ b/test/test_belief_propagation.jl
@@ -50,7 +50,7 @@ using Test: @test, @testset
     bpc = BeliefPropagationCache(ψψ, group(v -> first(v), vertices(ψψ)))
     bpc = update(bpc; maxiter=25, tol=eps(real(elt)))
     #Test messages are converged
-    for pe in partitionedges(partitioned_tensornetwork(bpc))
+    for pe in partitionedges(bpc)
       @test message_diff(updated_message(bpc, pe), message(bpc, pe)) < 10 * eps(real(elt))
       @test eltype(only(message(bpc, pe))) == elt
     end
diff --git a/test/test_normalize.jl b/test/test_normalize.jl
new file mode 100644
index 00000000..1c30c20c
--- /dev/null
+++ b/test/test_normalize.jl
@@ -0,0 +1,54 @@
+@eval module $(gensym())
+using ITensorNetworks:
+  BeliefPropagationCache,
+  QuadraticFormNetwork,
+  edge_scalars,
+  norm_sqr_network,
+  random_tensornetwork,
+  siteinds,
+  vertex_scalars,
+  rescale
+using ITensors: dag, inner, scalar
+using Graphs: SimpleGraph, uniform_tree
+using LinearAlgebra: normalize
+using NamedGraphs: NamedGraph
+using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
+using StableRNGs: StableRNG
+using TensorOperations: TensorOperations
+using Test: @test, @testset
+@testset "Normalize" begin
+
+  #First let's do a flat tree
+  nx, ny = 2, 3
+  χ = 2
+  rng = StableRNG(1234)
+
+  g = named_comb_tree((nx, ny))
+  tn = random_tensornetwork(rng, g; link_space=χ)
+
+  tn_r = rescale(tn; alg="exact")
+  @test scalar(tn_r; alg="exact") ≈ 1.0
+
+  tn_r = rescale(tn; alg="bp")
+  @test scalar(tn_r; alg="exact") ≈ 1.0
+
+  #Now a state on a loopy graph
+  Lx, Ly = 3, 2
+  χ = 2
+  rng = StableRNG(1234)
+
+  g = named_grid((Lx, Ly))
+  s = siteinds("S=1/2", g)
+  x = random_tensornetwork(rng, s; link_space=χ)
+
+  ψ = normalize(x; alg="exact")
+  @test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0
+
+  ψIψ_bpc = Ref(BeliefPropagationCache(QuadraticFormNetwork(x)))
+  ψ = normalize(x; alg="bp", (cache!)=ψIψ_bpc, update_cache=true)
+  ψIψ_bpc = ψIψ_bpc[]
+  @test all(x -> x ≈ 1.0, edge_scalars(ψIψ_bpc))
+  @test all(x -> x ≈ 1.0, vertex_scalars(ψIψ_bpc))
+  @test scalar(QuadraticFormNetwork(ψ); alg="bp") ≈ 1.0
+end
+end