22
33import pickle
44import torch
5+ import torch .nn .functional as F
56import networkx as nx
67import numpy as np
78
@@ -25,6 +26,7 @@ def __init__(
2526 clean_edges = True ,
2627 edge_map = GRAPH_KEYS ["edge_map" ][TOOL ],
2728 etype_key = "LW" ,
29+ distogram_only = False ,
2830 ** kwargs ,
2931 ):
3032 self .distograms_path = distograms_path
@@ -34,6 +36,7 @@ def __init__(
3436 self .tau = tau
3537 self .distogram_files_prefix = distogram_files_prefix
3638 self .distogram_files_suffix = distogram_files_suffix
39+ self .distogram_only = distogram_only
3740 super ().__init__ (framework , clean_edges , edge_map , etype_key , ** kwargs )
3841 pass
3942
@@ -72,15 +75,21 @@ def __call__(self, rna_graph, features_dict):
7275 new_edges = new_edges .t ()
7376
7477 new_edge_attr = torch .full ((new_edges .size (1 ),), max (self .edge_map .values ())+ 1 , dtype = torch .long )
75- pyg_graph .edge_index = torch .cat ([pyg_graph .edge_index , new_edges ], dim = 1 )
76- pyg_graph .edge_attr = torch .cat ([pyg_graph .edge_attr , new_edge_attr ], dim = 0 )
78+
79+ if self .distogram_only :
80+ pyg_graph .edge_index = new_edges
81+ pyg_graph .edge_attr = new_edge_attr
82+ else :
83+ pyg_graph .edge_index = torch .cat ([pyg_graph .edge_index , new_edges ], dim = 1 )
84+ pyg_graph .edge_attr = torch .cat ([pyg_graph .edge_attr , new_edge_attr ], dim = 0 )
7785
7886 if self .distogram_edge_features :
7987
8088 row , col = pyg_graph .edge_index
8189 edge_distances = torch .from_numpy (distogram [row ,col ])
82- pyg_graph .edge_attr = edge_distances
83- #pyg_graph.edge_attr = torch.cat([pyg_graph.edge_attr.unsqueeze(1), edge_distances], dim=1)
90+ num_classes = len (self .edge_map )
91+ edge_attr_one_hot = F .one_hot (pyg_graph .edge_attr .long (), num_classes = num_classes )
92+ pyg_graph .edge_attr = torch .cat ([edge_attr_one_hot , edge_distances ], dim = 1 ).float ()
8493
8594 return pyg_graph
8695
0 commit comments