Skip to content

Normalize #192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 89 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
ec7ec3b
New BP alternating update
JoeyT1994 May 6, 2024
bd05519
Working BP DMRG Solver
JoeyT1994 May 9, 2024
e116388
Merge remote-tracking branch 'upstream/main' into bp_alternating_update
JoeyT1994 May 9, 2024
cd2b139
New Changes
JoeyT1994 May 14, 2024
6391bfa
Merge remote-tracking branch 'upstream/main' into bp_alternating_update
JoeyT1994 May 14, 2024
fa91e7c
Merge remote-tracking branch 'upstream/main' into bp_alternating_update
JoeyT1994 May 15, 2024
201882a
Small changes
JoeyT1994 May 16, 2024
7228fb5
Changes
JoeyT1994 May 31, 2024
75d0c3b
Utils additions
JoeyT1994 May 31, 2024
c90139b
More stuff
JoeyT1994 Jun 2, 2024
e87e1b3
Big Improvements
JoeyT1994 Jun 7, 2024
8d780a8
Refactor code
JoeyT1994 Jun 7, 2024
e62ae0f
Save stuff
JoeyT1994 Jun 11, 2024
371492d
Commit 1
JoeyT1994 Jun 12, 2024
5138e51
Changes
JoeyT1994 Jun 12, 2024
275191a
Changes
JoeyT1994 Jun 12, 2024
194fba3
working implementation
JoeyT1994 Jun 12, 2024
50369c1
working implementation
JoeyT1994 Jun 12, 2024
0e5e5d8
Remove old changes
JoeyT1994 Jun 12, 2024
4bc0183
Revert
JoeyT1994 Jun 12, 2024
9e14f14
Revert
JoeyT1994 Jun 12, 2024
0a7355e
Revert
JoeyT1994 Jun 12, 2024
b07b978
Revert
JoeyT1994 Jun 12, 2024
440c267
Revert
JoeyT1994 Jun 12, 2024
ed7befa
Remove files
JoeyT1994 Jun 12, 2024
322dca4
Revert
JoeyT1994 Jun 12, 2024
54f41c0
Revert
JoeyT1994 Jun 12, 2024
dc0e132
Revert
JoeyT1994 Jun 12, 2024
2af3984
revert
JoeyT1994 Jun 12, 2024
30786bc
Working version
JoeyT1994 Jun 14, 2024
f0d4fc8
Merge branch 'ITensor:main' into bp_dmrg_alt_method
JoeyT1994 Jun 14, 2024
ed5037e
Improvements
JoeyT1994 Jun 14, 2024
e61e58c
Merge remote-tracking branch 'upstream/main' into bp_dmrg_alt_method
JoeyT1994 Jun 14, 2024
6998077
merge
JoeyT1994 Jun 14, 2024
511e09f
Merge branch 'bp_dmrg_alt_method' of github.com:JoeyT1994/ITensorNetw…
JoeyT1994 Jun 14, 2024
ed0c069
Improvements
JoeyT1994 Jun 14, 2024
553a983
Simplify
JoeyT1994 Jun 15, 2024
005b0e5
Change
JoeyT1994 Jun 16, 2024
af68e63
Working first commit
JoeyT1994 Jun 16, 2024
0704609
Revert some files
JoeyT1994 Jun 16, 2024
e1344f0
Revert expect
JoeyT1994 Jun 16, 2024
66319b0
Revert some changes
JoeyT1994 Jun 16, 2024
b098d44
Update src/caches/beliefpropagationcache.jl
JoeyT1994 Jun 16, 2024
b296277
Update src/caches/beliefpropagationcache.jl
JoeyT1994 Jun 16, 2024
1c87d22
Update src/normalize.jl
JoeyT1994 Jun 16, 2024
f88b21c
Merge remote-tracking branch 'upstream/main' into normalize!
JoeyT1994 Jun 26, 2024
6a8d4b9
Renormalize messages against themselves first
JoeyT1994 Jun 26, 2024
c845947
Blah
JoeyT1994 Sep 13, 2024
90c7251
Merge remote-tracking branch 'origin/main'
JoeyT1994 Oct 17, 2024
86f3087
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Oct 17, 2024
6ff0cd5
Bug fix in current ortho. Change test
JoeyT1994 Oct 17, 2024
34e8e5e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Nov 22, 2024
d096722
Fix bug
JoeyT1994 Nov 26, 2024
70a3f7e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Dec 5, 2024
2cb7f85
Refactor and bring down upstream changes
JoeyT1994 Dec 10, 2024
73e9e1e
Merge remote-tracking branch 'origin/main' into normalize!
JoeyT1994 Dec 10, 2024
4f4e2e5
Remove erroneous file
JoeyT1994 Dec 10, 2024
620da37
Allow rescaling flat networks with bp
JoeyT1994 Dec 10, 2024
180183e
Make generic to other algorithms
JoeyT1994 Dec 10, 2024
9d64fe8
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Mar 19, 2025
9d6c1bc
File removed
JoeyT1994 Mar 19, 2025
ae17245
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Mar 23, 2025
b648353
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 1, 2025
4e7d189
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 3, 2025
83c92b0
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 7, 2025
6f024ee
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Apr 11, 2025
bc7c238
Merge
JoeyT1994 Apr 18, 2025
22e7dbb
Improvements
JoeyT1994 Apr 18, 2025
e057da0
Merge remote-tracking branch 'upstream/main' into normalize!
JoeyT1994 Apr 18, 2025
dea3b6c
Version update
JoeyT1994 Apr 18, 2025
4fa8d92
Update src/caches/abstractbeliefpropagationcache.jl
JoeyT1994 Apr 21, 2025
153d205
Update src/normalize.jl
JoeyT1994 Apr 21, 2025
f7c8733
Formatting. Delete contraction_sequence_to_graph code
JoeyT1994 Apr 21, 2025
da56c33
Rescale local tensors first
JoeyT1994 Apr 21, 2025
1f9ec52
Use RegionScalar not dot
JoeyT1994 Apr 21, 2025
ed605cd
Contraction sequence test
JoeyT1994 Apr 21, 2025
6c528df
Comment out more
JoeyT1994 Apr 21, 2025
e6c0a21
Better default cache construction in expect
JoeyT1994 Apr 21, 2025
05d2e73
Rename vs_to_rescale -> verts_to_rescale
JoeyT1994 Apr 21, 2025
40f9228
Remove vs to rescale from exact function
JoeyT1994 Apr 21, 2025
00cc700
Revert commit
JoeyT1994 Apr 21, 2025
e6b753e
Remove EinExprs Tests
JoeyT1994 Apr 21, 2025
d35d812
Rescale message -> rescale messages
JoeyT1994 Apr 22, 2025
92d8243
All but HyPar for EinExprs tests
JoeyT1994 Apr 22, 2025
29ff795
Update src/caches/abstractbeliefpropagationcache.jl
JoeyT1994 Apr 22, 2025
f4c6952
Fix imports on contraction seq test
JoeyT1994 Apr 22, 2025
026e500
Merge branch 'normalize!' of github.com:JoeyT1994/ITensorNetworks.jl …
JoeyT1994 Apr 22, 2025
2df7e92
Formatting
JoeyT1994 Apr 22, 2025
17ccf85
Centre around scaling function
JoeyT1994 Apr 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.13.5"
version = "0.13.6"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
2 changes: 1 addition & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ include("formnetworks/abstractformnetwork.jl")
include("formnetworks/bilinearformnetwork.jl")
include("formnetworks/linearformnetwork.jl")
include("formnetworks/quadraticformnetwork.jl")
include("contraction_tree_to_graph.jl")
include("gauging.jl")
include("utils.jl")
include("update_observer.jl")
Expand Down Expand Up @@ -63,6 +62,7 @@ include("solvers/linsolve.jl")
include("solvers/sweep_plans/sweep_plans.jl")
include("apply.jl")
include("inner.jl")
include("normalize.jl")
include("expect.jl")
include("environment.jl")
include("exports.jl")
Expand Down
20 changes: 20 additions & 0 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,26 @@ function add(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork)
return tn12
end

""" Scale each tensor of the network by a scale factor on each vertex"""
function scale!(tn::AbstractITensorNetwork, vertices_weights::Dictionary)
for v in keys(vertices_weights)
setindex_preserve_graph!(tn, vertices_weights[v] * tn[v], v)
end
return tn
end

""" Scale each tensor of the network via a function (vertex, ITensor) -> Number"""
function scale!(tn::AbstractITensorNetwork, weight_function::Function)
vs = collect(vertices(tn))
vertices_weights = Dictionary(vs, [weight_function(v, tn[v]) for v in vs])
return scale!(tn, vertices_weights)
end

function scale(tn, args...)
tn = copy(tn)
return scale!(tn, args...)
end

Base.:+(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork) = add(tn1, tn2)

ITensors.hasqns(tn::AbstractITensorNetwork) = any(v -> hasqns(tn[v]), vertices(tn))
86 changes: 83 additions & 3 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Graphs: IsDirected
using Graphs: Graphs, IsDirected
using SplitApplyCombine: group
using LinearAlgebra: diag, dot
using ITensors: dir
Expand Down Expand Up @@ -66,7 +66,7 @@ function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; k
return not_implemented()
end
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented()
PartitionedGraphs.partitionedges(bpc::AbstractBeliefPropagationCache) = not_implemented()

function default_edge_sequence(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
Expand All @@ -88,6 +88,10 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
return unpartitioned_graph(partitioned_tensornetwork(bpc))
end

function setindex_preserve_graph!(bpc::AbstractBeliefPropagationCache, args...)
return setindex_preserve_graph!(tensornetwork(bpc), args...)
end

function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
return ITensor[tensornetwork(bpc)[v] for v in verts]
end
Expand All @@ -107,7 +111,7 @@ function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc)
end

function edge_scalars(
bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs...
bpc::AbstractBeliefPropagationCache, pes=partitionedges(bpc); kwargs...
)
return map(pe -> region_scalar(bpc, pe; kwargs...), pes)
end
Expand Down Expand Up @@ -283,3 +287,79 @@ function update(
)
return update(Algorithm(alg), bpc; kwargs...)
end

function scale!(bp_cache::AbstractBeliefPropagationCache, args...)
return scale!(tensornetwork(bp_cache), args...)
end

function rescale_messages(
bp_cache::AbstractBeliefPropagationCache, partitionedge::PartitionEdge
)
return rescale_messages(bp_cache, [partitionedge])
end

function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
return rescale_messages(bp_cache, partitionedges(bp_cache))
end

function rescale_partitions(
bpc::AbstractBeliefPropagationCache,
partitions::Vector;
verts_to_rescale::Vector=vertices(bpc, partitions),
)
bpc = copy(bpc)
tn = tensornetwork(bpc)
norms = map(v -> inv(norm(tn[v])), verts_to_rescale)
scale!(bpc, Dictionary(verts_to_rescale, norms))

vertices_weights = Dictionary()
for pv in partitions
pv_vs = filter(v -> v ∈ verts_to_rescale, vertices(bpc, pv))
isempty(pv_vs) && continue

vn = region_scalar(bpc, pv)
s = isreal(vn) ? sign(vn) : 1.0
vn = s * inv(vn^(1 / length(pv_vs)))
set!(vertices_weights, first(pv_vs), s*vn)
for v in pv_vs[2:length(pv_vs)]
set!(vertices_weights, v, vn)
end
end

scale!(bpc, vertices_weights)

return bpc
end

function rescale_partitions(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
return rescale_partitions(bpc, collect(partitions(bpc)), args...; kwargs...)
end

function rescale_partition(
bpc::AbstractBeliefPropagationCache, partition, args...; kwargs...
)
return rescale_partitions(bpc, [partition], args...; kwargs...)
end

function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
bpc = rescale_messages(bpc)
bpc = rescale_partitions(bpc, args...; kwargs...)
return bpc
end

function logscalar(bpc::AbstractBeliefPropagationCache)
numerator_terms, denominator_terms = scalar_factors_quotient(bpc)
if any(t -> real(t) < 0, numerator_terms)
numerator_terms = complex.(numerator_terms)
end
if any(t -> real(t) < 0, denominator_terms)
denominator_terms = complex.(denominator_terms)
end

any(iszero, denominator_terms) && return -Inf
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
end

function ITensors.scalar(bpc::AbstractBeliefPropagationCache)
return exp(logscalar(bpc))
end
28 changes: 26 additions & 2 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ using NamedGraphs.PartitionedGraphs:
boundary_partitionedges,
partitionvertices,
partitionedges,
unpartitioned_graph
partitioned_vertices,
unpartitioned_graph,
which_partition
using SimpleTraits: SimpleTraits, Not, @traitfn
using NDTensors: NDTensors

Expand Down Expand Up @@ -80,7 +82,9 @@ function default_message_update_kwargs(
end

partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc))
partitionpairs(bpc::BeliefPropagationCache) = partitionedges(partitioned_tensornetwork(bpc))
function PartitionedGraphs.partitionedges(bpc::BeliefPropagationCache)
partitionedges(partitioned_tensornetwork(bpc))
end

function set_messages(cache::BeliefPropagationCache, messages)
return BeliefPropagationCache(partitioned_tensornetwork(cache), messages)
Expand All @@ -106,3 +110,23 @@ function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
sequence = contraction_sequence(ts; alg="optimal")
return contract(ts; sequence)[]
end

function rescale_messages(bp_cache::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
bp_cache = copy(bp_cache)
mts = messages(bp_cache)
for pe in pes
me, mer = normalize.(mts[pe]), normalize.(mts[reverse(pe)])
set!(mts, pe, me)
set!(mts, reverse(pe), mer)
n = region_scalar(bp_cache, pe)
if isreal(n)
me[1] *= sign(n)
n *= sign(n)
end

sf = (1 / sqrt(n)) ^ (1 / length(me))
set!(mts, pe, sf .* me)
set!(mts, reverse(pe), sf .* mer)
end
return bp_cache
end
12 changes: 1 addition & 11 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,7 @@ function logscalar(
cache![] = update(cache![]; cache_update_kwargs...)
end

numerator_terms, denominator_terms = scalar_factors_quotient(cache![])
numerator_terms =
any(t -> real(t) < 0, numerator_terms) ? complex.(numerator_terms) : numerator_terms
denominator_terms = if any(t -> real(t) < 0, denominator_terms)
complex.(denominator_terms)
else
denominator_terms
end

any(iszero, denominator_terms) && return -Inf
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
return logscalar(cache![])
end

function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)
Expand Down
82 changes: 0 additions & 82 deletions src/contraction_tree_to_graph.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/expect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function expect(
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(alg),
cache_construction_kwargs=default_cache_construction_kwargs(alg, QuadraticFormNetwork(ψ)),
cache_construction_kwargs=(;),
kwargs...,
)
ψIψ = QuadraticFormNetwork(ψ)
Expand Down
Loading
Loading