Skip to content

Commit

Permalink
Rename kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jan 2, 2025
1 parent 9d64029 commit dc46339
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 26 deletions.
36 changes: 19 additions & 17 deletions src/caches/boundarympscache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ function planargraph_partitionpair(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
end

function BoundaryMPSCache(
bpc::BeliefPropagationCache; sort_f::Function=v -> first(v), message_rank::Int64=1
bpc::BeliefPropagationCache;
grouping_function::Function=v -> first(v),
message_rank::Int64=1,
)
bpc = insert_pseudo_planar_edges(bpc; sort_f)
bpc = insert_pseudo_planar_edges(bpc; grouping_function)
planar_graph = partitioned_graph(bpc)
vertex_groups = group(sort_f, collect(vertices(planar_graph)))
vertex_groups = group(grouping_function, collect(vertices(planar_graph)))
ppg = PartitionedGraph(planar_graph, vertex_groups)
bmpsc = BoundaryMPSCache(bpc, ppg)
return set_interpartition_messages(bmpsc, message_rank)
Expand Down Expand Up @@ -319,15 +321,15 @@ function mps_update(
end
bmpsc = if !isnothing(prev_v)
partition_update(
bmpsc,
prev_v,
cur_v;
message_update=ms -> default_message_update(ms; normalize=false),
)
bmpsc,
prev_v,
cur_v;
message_update=ms -> default_message_update(ms; normalize=false),
)
else
partition_update(
bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false)
)
bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false)
)
end
me = update_message(
bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize)
Expand Down Expand Up @@ -360,15 +362,15 @@ function mps_update(
end
bmpsc = if !isnothing(prev_v)
partition_update(
bmpsc,
prev_v,
cur_v;
message_update=ms -> default_message_update(ms; normalize=false),
)
bmpsc,
prev_v,
cur_v;
message_update=ms -> default_message_update(ms; normalize=false),
)
else
partition_update(
bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false)
)
bmpsc, cur_v; message_update=ms -> default_message_update(ms; normalize=false)
)
end

me_prev = only(message(bmpsc, update_pe))
Expand Down
8 changes: 5 additions & 3 deletions src/caches/boundarympscacheutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ function add_partitionedges(bpc::BeliefPropagationCache, pes::Vector{<:Partition
end

#Add partition edges necessary to connect up all vertices in a partition in the planar graph created by the sort function
function insert_pseudo_planar_edges(bpc::BeliefPropagationCache; sort_f=v -> first(v))
function insert_pseudo_planar_edges(
bpc::BeliefPropagationCache; grouping_function=v -> first(v)
)
pg = partitioned_graph(bpc)
partitions = unique(sort_f.(collect(vertices(pg))))
partitions = unique(grouping_function.(collect(vertices(pg))))
pseudo_edges = PartitionEdge[]
for p in partitions
vs = sort(filter(v -> sort_f(v) == p, collect(vertices(pg))))
vs = sort(filter(v -> grouping_function(v) == p, collect(vertices(pg))))
for i in 1:(length(vs) - 1)
if vs[i] neighbors(pg, vs[i + 1])
push!(pseudo_edges, PartitionEdge(NamedEdge(vs[i] => vs[i + 1])))
Expand Down
12 changes: 6 additions & 6 deletions test/test_boundarymps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,17 @@ using LinearAlgebra: norm
∂tn_∂vc_exact = contract(∂tn_∂vc; sequence=contraction_sequence(∂tn_∂vc; alg="greedy"))
∂tn_∂vc_exact /= norm(∂tn_∂vc_exact)

#Orthogonal Boundary MPS
tn_boundaryMPS = BoundaryMPSCache(tn; message_rank=1)
#Orthogonal Boundary MPS, group by row
tn_boundaryMPS = BoundaryMPSCache(tn; grouping_function=v -> last(v), message_rank=1)
tn_boundaryMPS = update(tn_boundaryMPS; mps_fit_kwargs)
∂tn_∂vc_boundaryMPS = contract(environment(tn_boundaryMPS, [vc]); sequence="automatic")
∂tn_∂vc_boundaryMPS /= norm(∂tn_∂vc_boundaryMPS)

@test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_exact) <= 10 * eps(real(elt))
@test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_bp) <= 10 * eps(real(elt))

#Biorthogonal Boundary MPS
tn_boundaryMPS = BoundaryMPSCache(tn; message_rank=1)
#Biorthogonal Boundary MPS, , group by row
tn_boundaryMPS = BoundaryMPSCache(tn; grouping_function=v -> last(v), message_rank=1)
tn_boundaryMPS = update(tn_boundaryMPS; alg="biorthogonal")
∂tn_∂vc_boundaryMPS = contract(environment(tn_boundaryMPS, [vc]); sequence="automatic")
∂tn_∂vc_boundaryMPS /= norm(∂tn_∂vc_boundaryMPS)
Expand All @@ -99,15 +99,15 @@ using LinearAlgebra: norm
ρ_exact = contract(ρ; sequence=contraction_sequence(ρ; alg="greedy"))
ρ_exact /= tr(ρ_exact)

#Orthogonal Boundary MPS
#Orthogonal Boundary MPS, group by column (default)
ψIψ_boundaryMPS = BoundaryMPSCache(ψIψ; message_rank=χ * χ)
ψIψ_boundaryMPS = update(ψIψ_boundaryMPS)
ρ_boundaryMPS = contract(environment(ψIψ_boundaryMPS, [vc]); sequence="automatic")
ρ_boundaryMPS /= tr(ρ_boundaryMPS)

@test norm(ρ_boundaryMPS - ρ_exact) <= 10 * eps(real(elt))

#BiOrthogonal Boundary MPS
#BiOrthogonal Boundary MPS, group by column (default)
ψIψ_boundaryMPS = BoundaryMPSCache(ψIψ; message_rank=χ * χ)
ψIψ_boundaryMPS = update(ψIψ_boundaryMPS; alg="biorthogonal", maxiter=50)
ρ_boundaryMPS = contract(environment(ψIψ_boundaryMPS, [vc]); sequence="automatic")
Expand Down

0 comments on commit dc46339

Please sign in to comment.