Skip to content

Commit a294085

Browse files
committed
add possibility of distogram-only graph
1 parent 88781b4 commit a294085

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

src/rnaglib/transforms/represent/distograph.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pickle
44
import torch
5+
import torch.nn.functional as F
56
import networkx as nx
67
import numpy as np
78

@@ -25,6 +26,7 @@ def __init__(
2526
clean_edges=True,
2627
edge_map=GRAPH_KEYS["edge_map"][TOOL],
2728
etype_key="LW",
29+
distogram_only=False,
2830
**kwargs,
2931
):
3032
self.distograms_path = distograms_path
@@ -34,6 +36,7 @@ def __init__(
3436
self.tau = tau
3537
self.distogram_files_prefix = distogram_files_prefix
3638
self.distogram_files_suffix = distogram_files_suffix
39+
self.distogram_only = distogram_only
3740
super().__init__(framework, clean_edges, edge_map, etype_key, **kwargs)
3841
pass
3942

@@ -72,15 +75,21 @@ def __call__(self, rna_graph, features_dict):
7275
new_edges = new_edges.t()
7376

7477
new_edge_attr = torch.full((new_edges.size(1),), max(self.edge_map.values())+1, dtype=torch.long)
75-
pyg_graph.edge_index = torch.cat([pyg_graph.edge_index, new_edges], dim=1)
76-
pyg_graph.edge_attr = torch.cat([pyg_graph.edge_attr, new_edge_attr], dim=0)
78+
79+
if self.distogram_only:
80+
pyg_graph.edge_index = new_edges
81+
pyg_graph.edge_attr = new_edge_attr
82+
else:
83+
pyg_graph.edge_index = torch.cat([pyg_graph.edge_index, new_edges], dim=1)
84+
pyg_graph.edge_attr = torch.cat([pyg_graph.edge_attr, new_edge_attr], dim=0)
7785

7886
if self.distogram_edge_features:
7987

8088
row, col = pyg_graph.edge_index
8189
edge_distances = torch.from_numpy(distogram[row,col])
82-
pyg_graph.edge_attr = edge_distances
83-
#pyg_graph.edge_attr = torch.cat([pyg_graph.edge_attr.unsqueeze(1), edge_distances], dim=1)
90+
num_classes = len(self.edge_map)
91+
edge_attr_one_hot = F.one_hot(pyg_graph.edge_attr.long(), num_classes=num_classes)
92+
pyg_graph.edge_attr = torch.cat([edge_attr_one_hot, edge_distances], dim=1).float()
8493

8594
return pyg_graph
8695

0 commit comments

Comments
 (0)