diff --git a/src/caches/boundarympscache.jl b/src/caches/boundarympscache.jl index c5c57b4b..04a05f75 100644 --- a/src/caches/boundarympscache.jl +++ b/src/caches/boundarympscache.jl @@ -51,11 +51,13 @@ function planargraph_partitionpair(bmpsc::BoundaryMPSCache, pe::PartitionEdge) end function BoundaryMPSCache( - bpc::BeliefPropagationCache; sort_f::Function=v -> first(v), message_rank::Int64=1 + bpc::BeliefPropagationCache; + grouping_function::Function=v -> first(v), + message_rank::Int64=1, ) - bpc = insert_pseudo_planar_edges(bpc; sort_f) + bpc = insert_pseudo_planar_edges(bpc; grouping_function) planar_graph = partitioned_graph(bpc) - vertex_groups = group(sort_f, collect(vertices(planar_graph))) + vertex_groups = group(grouping_function, collect(vertices(planar_graph))) ppg = PartitionedGraph(planar_graph, vertex_groups) bmpsc = BoundaryMPSCache(bpc, ppg) return set_interpartition_messages(bmpsc, message_rank) @@ -319,15 +321,15 @@ function mps_update( end bmpsc = if !isnothing(prev_v) partition_update( - bmpsc, - prev_v, - cur_v; - message_update=ms -> default_message_update(ms; normalize=false), - ) + bmpsc, + prev_v, + cur_v; + message_update=ms -> default_message_update(ms; normalize=false), + ) else partition_update( - bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false) - ) + bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false) + ) end me = update_message( bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize) @@ -360,15 +362,15 @@ function mps_update( end bmpsc = if !isnothing(prev_v) partition_update( - bmpsc, - prev_v, - cur_v; - message_update=ms -> default_message_update(ms; normalize=false), - ) + bmpsc, + prev_v, + cur_v; + message_update=ms -> default_message_update(ms; normalize=false), + ) else partition_update( - bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false) - ) + bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false) + ) end me_prev = only(message(bmpsc, update_pe)) diff --git a/src/caches/boundarympscacheutils.jl b/src/caches/boundarympscacheutils.jl index fa3840e9..f64aa9cd 100644 --- a/src/caches/boundarympscacheutils.jl +++ b/src/caches/boundarympscacheutils.jl @@ -17,12 +17,14 @@ function add_partitionedges(bpc::BeliefPropagationCache, pes::Vector{<:Partition end #Add partition edges necessary to connect up all vertices in a partition in the planar graph created by the sort function -function insert_pseudo_planar_edges(bpc::BeliefPropagationCache; sort_f=v -> first(v)) +function insert_pseudo_planar_edges( + bpc::BeliefPropagationCache; grouping_function=v -> first(v) +) pg = partitioned_graph(bpc) - partitions = unique(sort_f.(collect(vertices(pg)))) + partitions = unique(grouping_function.(collect(vertices(pg)))) pseudo_edges = PartitionEdge[] for p in partitions - vs = sort(filter(v -> sort_f(v) == p, collect(vertices(pg)))) + vs = sort(filter(v -> grouping_function(v) == p, collect(vertices(pg)))) for i in 1:(length(vs) - 1) if vs[i] ∉ neighbors(pg, vs[i + 1]) push!(pseudo_edges, PartitionEdge(NamedEdge(vs[i] => vs[i + 1]))) diff --git a/test/test_boundarymps.jl b/test/test_boundarymps.jl index 4ea4f607..0f512359 100644 --- a/test/test_boundarymps.jl +++ b/test/test_boundarymps.jl @@ -69,8 +69,8 @@ using LinearAlgebra: norm ∂tn_∂vc_exact = contract(∂tn_∂vc; sequence=contraction_sequence(∂tn_∂vc; alg="greedy")) ∂tn_∂vc_exact /= norm(∂tn_∂vc_exact) - #Orthogonal Boundary MPS - tn_boundaryMPS = BoundaryMPSCache(tn; message_rank=1) + #Orthogonal Boundary MPS, group by row + tn_boundaryMPS = BoundaryMPSCache(tn; grouping_function=v -> last(v), message_rank=1) tn_boundaryMPS = update(tn_boundaryMPS; mps_fit_kwargs) ∂tn_∂vc_boundaryMPS = contract(environment(tn_boundaryMPS, [vc]); sequence="automatic") ∂tn_∂vc_boundaryMPS /= norm(∂tn_∂vc_boundaryMPS) @@ -78,8 +78,8 @@ using LinearAlgebra: norm @test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_exact) <= 10 * eps(real(elt)) @test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_bp) <= 10 * eps(real(elt)) - #Biorthogonal Boundary MPS - tn_boundaryMPS = BoundaryMPSCache(tn; message_rank=1) + #Biorthogonal Boundary MPS, , group by row + tn_boundaryMPS = BoundaryMPSCache(tn; grouping_function=v -> last(v), message_rank=1) tn_boundaryMPS = update(tn_boundaryMPS; alg="biorthogonal") ∂tn_∂vc_boundaryMPS = contract(environment(tn_boundaryMPS, [vc]); sequence="automatic") ∂tn_∂vc_boundaryMPS /= norm(∂tn_∂vc_boundaryMPS) @@ -99,7 +99,7 @@ using LinearAlgebra: norm ρ_exact = contract(ρ; sequence=contraction_sequence(ρ; alg="greedy")) ρ_exact /= tr(ρ_exact) - #Orthogonal Boundary MPS + #Orthogonal Boundary MPS, group by column (default) ψIψ_boundaryMPS = BoundaryMPSCache(ψIψ; message_rank=χ * χ) ψIψ_boundaryMPS = update(ψIψ_boundaryMPS) ρ_boundaryMPS = contract(environment(ψIψ_boundaryMPS, [vc]); sequence="automatic") @@ -107,7 +107,7 @@ using LinearAlgebra: norm @test norm(ρ_boundaryMPS - ρ_exact) <= 10 * eps(real(elt)) - #BiOrthogonal Boundary MPS + #BiOrthogonal Boundary MPS, group by column (default) ψIψ_boundaryMPS = BoundaryMPSCache(ψIψ; message_rank=χ * χ) ψIψ_boundaryMPS = update(ψIψ_boundaryMPS; alg="biorthogonal", maxiter=50) ρ_boundaryMPS = contract(environment(ψIψ_boundaryMPS, [vc]); sequence="automatic")