Skip to content

Commit 996f18d

Browse files
committed
added conversion function from graph lag format to compressd gunfolds format
1 parent 8ed7d65 commit 996f18d

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

gunfolds/conversions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,3 +845,26 @@ def encode_list_sccs(glist, scc_members=None):
845845
# if there is an edge between SCCs in the produced graph and none in the measured for nonsingleton SCCs - no go
846846
s += ':- directed(X,Y,U), scc(X,K), scc(Y,L), K != L, sccsize(L,Z), Z > 1, not dag(K,L,N), u(U,N).'
847847
return s
848+
849+
def Glag2CG(results):
850+
"""Converts lag graph format to gunfolds graph format,
851+
and A and B matrices representing directed and bidirected edges weights.
852+
853+
Args:
854+
results (dict): A dictionary containing:
855+
- 'graph': A 3D NumPy array of shape [N, N, 2] representing the graph structure.
856+
- 'val_matrix': A NumPy array of shape [N, N, 2] storing edge weights.
857+
858+
Returns:
859+
tuple: (graph_dict, A_matrix, B_matrix)
860+
"""
861+
862+
graph_array = results['graph']
863+
bidirected_edges = np.where(graph_array == 'o-o', 1, 0).astype(int)
864+
directed_edges = np.where(graph_array == '-->', 1, 0).astype(int)
865+
866+
graph_dict = adjs2graph(np.transpose(directed_edges[:, :, 1]), np.transpose((bidirected_edges[:, :, 0])))
867+
A_matrix = results['val_matrix'][:, :, 1]
868+
B_matrix = results['val_matrix'][:, :, 0]
869+
870+
return graph_dict, A_matrix, B_matrix

0 commit comments

Comments
 (0)