From 19630aa6109f1fad983299b6feac0027d093892b Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 9 Jan 2025 16:44:30 -0500 Subject: [PATCH] Unify update function --- src/caches/boundarympscache.jl | 232 ++++++++++++++++++++------------- 1 file changed, 138 insertions(+), 94 deletions(-) diff --git a/src/caches/boundarympscache.jl b/src/caches/boundarympscache.jl index 0e7dfbf3..c4ab049e 100644 --- a/src/caches/boundarympscache.jl +++ b/src/caches/boundarympscache.jl @@ -232,31 +232,28 @@ function partition_update(bmpsc::BoundaryMPSCache, partition::Int64) return bmpsc end -#Update all messages within a partition along the path from from v1 to v2 -function partition_update(bmpsc::BoundaryMPSCache, v1, v2) - return update( - Algorithm("SimpleBP"), - bmpsc, - PartitionEdge.(a_star(ppg(bmpsc), v1, v2)); - message_update_kwargs=(; normalize=false), - ) +function partition_update_sequence(bmpsc::BoundaryMPSCache, v1, v2) + return PartitionEdge.(a_star(ppg(bmpsc), v1, v2)) end - -#Update all message tensors within a partition pointing towards v -function partition_update(bmpsc::BoundaryMPSCache, v) +function partition_update_sequence(bmpsc::BoundaryMPSCache, v) pv = planargraph_partition(bmpsc, v) g = subgraph(unpartitioned_graph(ppg(bmpsc)), planargraph_vertices(bmpsc, pv)) + return PartitionEdge.(post_order_dfs_edges(g, v)) +end + +#Update all messages within a partition along the path from from v1 to v2 +function partition_update(bmpsc::BoundaryMPSCache, args...) return update( Algorithm("SimpleBP"), bmpsc, - PartitionEdge.(post_order_dfs_edges(g, v)); + partition_update_sequence(bmpsc, args...); message_update_kwargs=(; normalize=false), ) end #Move the orthogonality centre one step on an interpartition from the message tensor on pe1 to that on pe2 function gauge_step( - alg::Algorithm"orthogonalize", + alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, pe1::PartitionEdge, pe2::PartitionEdge; @@ -276,7 +273,7 @@ end #Move the biorthogonality centre one step on an interpartition from the partition edge pe1 (and its reverse) to that on pe2 function gauge_step( - alg::Algorithm"biorthogonalize", + alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, pe1::PartitionEdge, pe2::PartitionEdge, @@ -323,114 +320,161 @@ function gauge_walk(alg::Algorithm, bmpsc::BoundaryMPSCache, seq::Vector; kwargs return bmpsc end +function gauge(alg::Algorithm, bmpsc::BoundaryMPSCache, args...; kwargs...) + return gauge_walk(alg, bmpsc, mps_gauge_update_sequence(bmpsc, args...); kwargs...) +end + #Move the orthogonality centre on an interpartition to the message tensor on pe or between two pes function ITensorMPS.orthogonalize(bmpsc::BoundaryMPSCache, args...; kwargs...) - return gauge_walk( - Algorithm("orthogonalize"), bmpsc, mps_gauge_update_sequence(bmpsc, args...); kwargs... - ) + return gauge(Algorithm("orthogonal"), bmpsc, args...; kwargs...) end #Move the biorthogonality centre on an interpartition to the message tensor or between two pes function biorthogonalize(bmpsc::BoundaryMPSCache, args...; kwargs...) - return gauge_walk( - Algorithm("biorthogonalize"), - bmpsc, - mps_gauge_update_sequence(bmpsc, args...); - kwargs..., - ) + return gauge(Algorithm("biorthogonal"), bmpsc, args...; kwargs...) end -#Update all the message tensors on an interpartition via an orthogonal fitting procedure -#TODO: Unify this to one function and make two-site possible -function update( +function default_inserter( alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, + pe::PartitionEdge, + me::Vector{ITensor}, +) + return set_message(bmpsc, reverse(pe), dag.(me)) +end + +function default_inserter( + alg::Algorithm"biorthogonal", + bmpsc::BoundaryMPSCache, + pe::PartitionEdge, + me::Vector{ITensor}, +) + p_above, p_below = partitionedge_above(bmpsc, pe), partitionedge_below(bmpsc, pe) + me = only(me) + me_prev = only(message(bmpsc, pe)) + if !isnothing(p_above) + me = replaceind( + me, + commonind(me, only(message(bmpsc, reverse(p_above)))), + commonind(me_prev, only(message(bmpsc, p_above))), + ) + end + if !isnothing(p_below) + me = replaceind( + me, + commonind(me, only(message(bmpsc, reverse(p_below)))), + commonind(me_prev, only(message(bmpsc, p_below))), + ) + end + return set_message(bmpsc, pe, ITensor[me]) +end + +function default_updater( + alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v +) + bmpsc = if !isnothing(prev_pe) + gauge(alg, bmpsc, reverse(prev_pe), reverse(update_pe)) + else + gauge(alg, bmpsc, reverse(update_pe)) + end + bmpsc = if !isnothing(prev_v) + partition_update(bmpsc, prev_v, cur_v) + else + partition_update(bmpsc, cur_v) + end + return bmpsc +end + +function default_updater( + alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v +) + bmpsc = if !isnothing(prev_pe) + gauge(alg, bmpsc, prev_pe, update_pe) + else + gauge(alg, bmpsc, update_pe) + end + bmpsc = if !isnothing(prev_v) + partition_update(bmpsc, prev_v, cur_v) + else + partition_update(bmpsc, cur_v) + end + return bmpsc +end + +function default_cache_prep_function( + alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, partitionpair +) + return bmpsc +end +function default_cache_prep_function( + alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, partitionpair +) + return switch_messages(bmpsc, partitionpair) +end + +default_niters(alg::Algorithm"orthogonal") = 25 +default_niters(alg::Algorithm"biorthogonal") = 3 +default_tolerance(alg::Algorithm"orthogonal") = 1e-10 +default_tolerance(alg::Algorithm"biorthogonal") = nothing + +function default_costfunction( + alg::Algorithm"orthogonal", + bmpsc::BoundaryMPSCache, + pe::PartitionEdge, + me::Vector{ITensor}, +) + return region_scalar(bp_cache(bmpsc), src(pe)) / norm(only(me)) +end + +function default_costfunction( + alg::Algorithm"biorthogonal", + bmpsc::BoundaryMPSCache, + pe::PartitionEdge, + me::Vector{ITensor}, +) + return region_scalar(bp_cache(bmpsc), src(pe)) / + dot(only(me), only(message(bmpsc, reverse(pe)))) +end + +#Update all the message tensors on an interpartition via a specified fitting procedure +#TODO: Make two-site possible +function update( + alg::Algorithm, + bmpsc::BoundaryMPSCache, partitionpair::Pair; - niters::Int64=25, - tolerance=1e-10, + inserter=default_inserter, + costfunction=default_costfunction, + updater=default_updater, + cache_prep_function=default_cache_prep_function, + niters::Int64=default_niters(alg), + tolerance=default_tolerance(alg), normalize=true, ) - bmpsc = switch_messages(bmpsc, partitionpair) + bmpsc = cache_prep_function(alg, bmpsc, partitionpair) pes = planargraph_partitionpair_partitionedges(bmpsc, partitionpair) update_seq = vcat(pes, reverse(pes)[2:length(pes)]) prev_v, prev_pe = nothing, nothing - prev_costfunction = 0 + prev_cf = 0 for i in 1:niters - costfunction = 0 + cf = 0 for update_pe in update_seq cur_v = parent(src(update_pe)) - bmpsc = if !isnothing(prev_pe) - orthogonalize(bmpsc, reverse(prev_pe), reverse(update_pe)) - else - orthogonalize(bmpsc, reverse(update_pe)) - end - bmpsc = if !isnothing(prev_v) - partition_update(bmpsc, prev_v, cur_v) - else - partition_update(bmpsc, cur_v) - end + bmpsc = updater(alg, bmpsc, prev_pe, update_pe, prev_v, cur_v) me = updated_message( bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize) ) - costfunction += region_scalar(bp_cache(bmpsc), src(update_pe)) / norm(me) - bmpsc = set_message(bmpsc, reverse(update_pe), dag.(me)) + cf += costfunction(alg, bmpsc, update_pe, me) + bmpsc = inserter(alg, bmpsc, update_pe, me) prev_v, prev_pe = cur_v, update_pe end - epsilon = abs(costfunction - prev_costfunction) / length(update_seq) + epsilon = abs(cf - prev_cf) / length(update_seq) if !isnothing(tolerance) && epsilon < tolerance - return switch_messages(bmpsc, partitionpair) + return cache_prep_function(alg, bmpsc, partitionpair) else - prev_costfunction = costfunction + prev_cf = cf end end - return switch_messages(bmpsc, partitionpair) -end - -#Update all the message tensors on an interpartition via a biorthogonal fitting procedure -function update( - alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, partitionpair::Pair; normalize=true -) - prev_v, prev_pe = nothing, nothing - for update_pe in planargraph_partitionpair_partitionedges(bmpsc, partitionpair) - cur_v = parent(src(update_pe)) - bmpsc = if !isnothing(prev_pe) - biorthogonalize(bmpsc, prev_pe, update_pe) - else - biorthogonalize(bmpsc, update_pe) - end - bmpsc = if !isnothing(prev_v) - partition_update(bmpsc, prev_v, cur_v) - else - partition_update(bmpsc, cur_v) - end - - me_prev = only(message(bmpsc, update_pe)) - me = only( - updated_message( - bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize) - ), - ) - p_above, p_below = partitionedge_above(bmpsc, update_pe), - partitionedge_below(bmpsc, update_pe) - if !isnothing(p_above) - me = replaceind( - me, - commonind(me, only(message(bmpsc, reverse(p_above)))), - commonind(me_prev, only(message(bmpsc, p_above))), - ) - end - if !isnothing(p_below) - me = replaceind( - me, - commonind(me, only(message(bmpsc, reverse(p_below)))), - commonind(me_prev, only(message(bmpsc, p_below))), - ) - end - bmpsc = set_message(bmpsc, update_pe, ITensor[me]) - prev_v, prev_pe = cur_v, update_pe - end - - return bmpsc + return cache_prep_function(alg, bmpsc, partitionpair) end #Assume all vertices live in the same partition for now