1
- using Graphs: IsDirected
1
+ using Graphs: Graphs, IsDirected
2
2
using SplitApplyCombine: group
3
3
using LinearAlgebra: diag, dot
4
4
using ITensors: dir
@@ -66,7 +66,7 @@ function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; k
66
66
return not_implemented ()
67
67
end
68
68
partitions (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
69
- partitionpairs (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
69
+ PartitionedGraphs . partitionedges (bpc:: AbstractBeliefPropagationCache ) = not_implemented ()
70
70
71
71
function default_edge_sequence (
72
72
bpc:: AbstractBeliefPropagationCache ; alg= default_message_update_alg (bpc)
@@ -88,6 +88,10 @@ function tensornetwork(bpc::AbstractBeliefPropagationCache)
88
88
return unpartitioned_graph (partitioned_tensornetwork (bpc))
89
89
end
90
90
91
+ function setindex_preserve_graph! (bpc:: AbstractBeliefPropagationCache , args... )
92
+ return setindex_preserve_graph! (tensornetwork (bpc), args... )
93
+ end
94
+
91
95
function factors (bpc:: AbstractBeliefPropagationCache , verts:: Vector )
92
96
return ITensor[tensornetwork (bpc)[v] for v in verts]
93
97
end
@@ -107,7 +111,7 @@ function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc)
107
111
end
108
112
109
113
function edge_scalars (
110
- bpc:: AbstractBeliefPropagationCache , pes= partitionpairs (bpc); kwargs...
114
+ bpc:: AbstractBeliefPropagationCache , pes= partitionedges (bpc); kwargs...
111
115
)
112
116
return map (pe -> region_scalar (bpc, pe; kwargs... ), pes)
113
117
end
@@ -283,3 +287,79 @@ function update(
283
287
)
284
288
return update (Algorithm (alg), bpc; kwargs... )
285
289
end
290
+
291
+ function scale! (bp_cache:: AbstractBeliefPropagationCache , args... )
292
+ return scale! (tensornetwork (bp_cache), args... )
293
+ end
294
+
295
+ function rescale_messages (
296
+ bp_cache:: AbstractBeliefPropagationCache , partitionedge:: PartitionEdge
297
+ )
298
+ return rescale_messages (bp_cache, [partitionedge])
299
+ end
300
+
301
+ function rescale_messages (bp_cache:: AbstractBeliefPropagationCache )
302
+ return rescale_messages (bp_cache, partitionedges (bp_cache))
303
+ end
304
+
305
+ function rescale_partitions (
306
+ bpc:: AbstractBeliefPropagationCache ,
307
+ partitions:: Vector ;
308
+ verts:: Vector = vertices (bpc, partitions),
309
+ )
310
+ bpc = copy (bpc)
311
+ tn = tensornetwork (bpc)
312
+ norms = map (v -> inv (norm (tn[v])), verts)
313
+ scale! (bpc, Dictionary (verts, norms))
314
+
315
+ vertices_weights = Dictionary ()
316
+ for pv in partitions
317
+ pv_vs = filter (v -> v ∈ verts, vertices (bpc, pv))
318
+ isempty (pv_vs) && continue
319
+
320
+ vn = region_scalar (bpc, pv)
321
+ s = isreal (vn) ? sign (vn) : 1.0
322
+ vn = s * inv (vn^ (1 / length (pv_vs)))
323
+ set! (vertices_weights, first (pv_vs), s* vn)
324
+ for v in pv_vs[2 : length (pv_vs)]
325
+ set! (vertices_weights, v, vn)
326
+ end
327
+ end
328
+
329
+ scale! (bpc, vertices_weights)
330
+
331
+ return bpc
332
+ end
333
+
334
+ function rescale_partitions (bpc:: AbstractBeliefPropagationCache , args... ; kwargs... )
335
+ return rescale_partitions (bpc, collect (partitions (bpc)), args... ; kwargs... )
336
+ end
337
+
338
+ function rescale_partition (
339
+ bpc:: AbstractBeliefPropagationCache , partition, args... ; kwargs...
340
+ )
341
+ return rescale_partitions (bpc, [partition], args... ; kwargs... )
342
+ end
343
+
344
+ function rescale (bpc:: AbstractBeliefPropagationCache , args... ; kwargs... )
345
+ bpc = rescale_messages (bpc)
346
+ bpc = rescale_partitions (bpc, args... ; kwargs... )
347
+ return bpc
348
+ end
349
+
350
+ function logscalar (bpc:: AbstractBeliefPropagationCache )
351
+ numerator_terms, denominator_terms = scalar_factors_quotient (bpc)
352
+ if any (t -> real (t) < 0 , numerator_terms)
353
+ numerator_terms = complex .(numerator_terms)
354
+ end
355
+ if any (t -> real (t) < 0 , denominator_terms)
356
+ denominator_terms = complex .(denominator_terms)
357
+ end
358
+
359
+ any (iszero, denominator_terms) && return - Inf
360
+ return sum (log .(numerator_terms)) - sum (log .((denominator_terms)))
361
+ end
362
+
363
+ function ITensors. scalar (bpc:: AbstractBeliefPropagationCache )
364
+ return exp (logscalar (bpc))
365
+ end
0 commit comments