Skip to content

Commit e667c72

Browse files
committed
add edge attributes to distance-based graphs
1 parent 3731d24 commit e667c72

2 files changed

Lines changed: 11 additions & 2 deletions

File tree

src/rnaglib/transforms/represent/distograph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def __call__(self, rna_graph, features_dict):
7474
new_edges = torch.nonzero(proba_matrix > self.tau, as_tuple=False)
7575
new_edges = new_edges.t()
7676

77-
new_edge_attr = torch.full((new_edges.size(1),), max(self.edge_map.values())+1, dtype=torch.long)
77+
max_occupied_index = max(self.edge_map.values()) if self.graph_construction=="base_pair" else 0
78+
79+
new_edge_attr = torch.full((new_edges.size(1),), max_occupied_index+1, dtype=torch.long)
7880

7981
if self.distogram_only:
8082
pyg_graph.edge_index = new_edges

src/rnaglib/transforms/represent/graph.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def to_pyg(self, graph, features_dict):
117117
y = features_dict["rna_targets"].clone().detach()
118118

119119
if self.graph_construction in ["knn","threshold"]:
120+
120121
if self.purine_representative != self.pyrimidine_representative:
122+
121123
all_attrs_pyrimidine_rep = nx.get_node_attributes(graph, f'xyz_{self.pyrimidine_representative}')
122124
all_attrs_purine_rep = nx.get_node_attributes(graph, f'xyz_{self.purine_representative}')
123125
all_attrs_base_identity = nx.get_node_attributes(graph, 'nt')
@@ -128,7 +130,9 @@ def to_pyg(self, graph, features_dict):
128130
purine_rep_coords = torch.tensor(purine_rep_coords_list)
129131
purine_mask = torch.tensor(purine_mask_list)
130132
nucleotide_coords = purine_rep_coords*purine_mask.view(-1,1)+pyrimidine_rep_coords*(1-purine_mask).view(-1,1)
133+
131134
else:
135+
132136
all_attrs = nx.get_node_attributes(graph, f'xyz_{self.representative}')
133137
nucleotide_coords_list = [all_attrs[n] if all_attrs[n] is not None else float('nan') for n in node_map.keys()]
134138
nucleotide_coords = torch.tensor(nucleotide_coords_list)
@@ -137,6 +141,7 @@ def to_pyg(self, graph, features_dict):
137141
dist_matrix = torch.cdist(nucleotide_coords, nucleotide_coords)
138142

139143
if self.graph_construction == "knn":
144+
140145
# Find k+1 smallest elements
141146
_, indices = dist_matrix.topk(self.top_k + 1, largest=False)
142147
# Remove the first column (self-loops)
@@ -149,12 +154,14 @@ def to_pyg(self, graph, features_dict):
149154
edge_index= torch.stack([source, target], dim=0)
150155

151156
else:
157+
152158
dist_matrix = torch.cdist(nucleotide_coords, nucleotide_coords, p=2)
153159
dist_matrix.fill_diagonal_(float('inf'))
154160
edges = torch.nonzero(dist_matrix < self.threshold, as_tuple=False)
155161
edge_index = edges.t()
162+
edge_attrs = torch.zeros(edge_index.shape[1],dtype=int)
156163

157-
return Data(x=x, y=y, edge_index=edge_index)
164+
return Data(x=x, y=y, edge_attr=edge_attrs, edge_index=edge_index)
158165

159166
edge_index = [[node_map[u], node_map[v]] for u, v in sorted(graph.edges(), key=lambda x: (x[0].split('.')[1],int(x[0].split('.')[2]),x[1].split('.')[1],int(x[1].split('.')[2])))]
160167
edge_index = torch.tensor(edge_index, dtype=torch.long).T

0 commit comments

Comments
 (0)