@@ -155,11 +155,17 @@ def to_pyg(self, graph, features_dict):
155155 # targets: flatten the nearest neighbor indices
156156 target = neighbor_indices .flatten ()
157157 edge_index = torch .stack ([source , target ], dim = 0 )
158+ < << << << HEAD
158159 edge_attrs = torch .zeros (edge_index .shape [1 ],dtype = int )
159160
160161 elif self .graph_construction == "threshold" :
161162 edges = torch .nonzero (dist_matrix < self .threshold , as_tuple = False )
162163 edge_attrs = torch .zeros (edge_index .shape [1 ],dtype = int )
164+ == == == =
165+
166+ elif self .graph_construction == "threshold" :
167+ edges = torch .nonzero (dist_matrix < self .threshold , as_tuple = False )
168+ >> >> >> > 70 fdf59f1e46542d77c7d48f3930e1d152e223e0
163169 edge_index = edges .t ()
164170
165171 else :
@@ -171,6 +177,10 @@ def to_pyg(self, graph, features_dict):
171177 if self .distance_edge_features :
172178 edge_distances = dist_matrix [edge_index [0 , :], edge_index [1 , :]]
173179 edge_feats = rbf_expand (dists = edge_distances , num_bins = 64 , min_distance = 2.0 , max_distance = 22.0 )
180+ < << << << HEAD
181+ == == == =
182+ edge_attrs = torch .zeros (edge_index .shape [1 ],dtype = int )
183+ >> >> >> > 70 fdf59f1e46542d77c7d48f3930e1d152e223e0
174184 return Data (x = x , y = y , edge_attr = edge_attrs , edge_index = edge_index , edge_feats = edge_feats )
175185
176186 return Data (x = x , y = y , edge_attr = edge_attrs , edge_index = edge_index )
0 commit comments