@@ -117,7 +117,9 @@ def to_pyg(self, graph, features_dict):
117117 y = features_dict ["rna_targets" ].clone ().detach ()
118118
119119 if self .graph_construction in ["knn" ,"threshold" ]:
120+
120121 if self .purine_representative != self .pyrimidine_representative :
122+
121123 all_attrs_pyrimidine_rep = nx .get_node_attributes (graph , f'xyz_{ self .pyrimidine_representative } ' )
122124 all_attrs_purine_rep = nx .get_node_attributes (graph , f'xyz_{ self .purine_representative } ' )
123125 all_attrs_base_identity = nx .get_node_attributes (graph , 'nt' )
@@ -128,7 +130,9 @@ def to_pyg(self, graph, features_dict):
128130 purine_rep_coords = torch .tensor (purine_rep_coords_list )
129131 purine_mask = torch .tensor (purine_mask_list )
130132 nucleotide_coords = purine_rep_coords * purine_mask .view (- 1 ,1 )+ pyrimidine_rep_coords * (1 - purine_mask ).view (- 1 ,1 )
133+
131134 else :
135+
132136 all_attrs = nx .get_node_attributes (graph , f'xyz_{ self .representative } ' )
133137 nucleotide_coords_list = [all_attrs [n ] if all_attrs [n ] is not None else float ('nan' ) for n in node_map .keys ()]
134138 nucleotide_coords = torch .tensor (nucleotide_coords_list )
@@ -137,6 +141,7 @@ def to_pyg(self, graph, features_dict):
137141 dist_matrix = torch .cdist (nucleotide_coords , nucleotide_coords )
138142
139143 if self .graph_construction == "knn" :
144+
140145 # Find k+1 smallest elements
141146 _ , indices = dist_matrix .topk (self .top_k + 1 , largest = False )
142147 # Remove the first column (self-loops)
@@ -149,12 +154,14 @@ def to_pyg(self, graph, features_dict):
149154 edge_index = torch .stack ([source , target ], dim = 0 )
150155
151156 else :
157+
152158 dist_matrix = torch .cdist (nucleotide_coords , nucleotide_coords , p = 2 )
153159 dist_matrix .fill_diagonal_ (float ('inf' ))
154160 edges = torch .nonzero (dist_matrix < self .threshold , as_tuple = False )
155161 edge_index = edges .t ()
162+ edge_attrs = torch .zeros (edge_index .shape [1 ],dtype = int )
156163
157- return Data (x = x , y = y , edge_index = edge_index )
164+ return Data (x = x , y = y , edge_attr = edge_attrs , edge_index = edge_index )
158165
159166 edge_index = [[node_map [u ], node_map [v ]] for u , v in sorted (graph .edges (), key = lambda x : (x [0 ].split ('.' )[1 ],int (x [0 ].split ('.' )[2 ]),x [1 ].split ('.' )[1 ],int (x [1 ].split ('.' )[2 ])))]
160167 edge_index = torch .tensor (edge_index , dtype = torch .long ).T
0 commit comments