Skip to content

Commit 7a37ab4

Browse files
committed
fix conflicts
2 parents 98dabf4 + 70fdf59 commit 7a37ab4

3 files changed

Lines changed: 11 additions & 3 deletions

File tree

src/rnaglib/transforms/represent/distograph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,15 @@ def __call__(self, rna_graph, features_dict):
6969
distogram_dict = pickle.load(f)
7070
distogram = distogram_dict["distogram"]["softmax"]
7171

72+
nb_bins = distogram.shape[2]
7273
chain_dict = get_sequences(base_graph)
7374
sorted_distogram_residues = [item for chain in sorted(chain_dict.keys()) for item in chain_dict[chain][1]]
74-
nb_bins = distogram.shape[2]
7575

7676
if self.distogram_edges:
7777
dist_tensor = torch.from_numpy(distogram)
7878
proba_matrix = dist_tensor[:, :, :self.B].sum(dim=2)
7979
proba_matrix.fill_diagonal_(float(0))
8080
new_edge_indices = torch.nonzero(proba_matrix > self.tau, as_tuple=False)
81-
8281
node_map = {n: i for i, n in enumerate(sorted(base_graph.nodes(), key=lambda x:(x.split('.')[1],int(x.split('.')[2]))))}
8382
new_edges = [[node_map[sorted_distogram_residues[u]],node_map[sorted_distogram_residues[v]]] for u, v in new_edge_indices]
8483
new_edges = torch.tensor(new_edges, dtype=torch.long).T

src/rnaglib/transforms/represent/graph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,17 @@ def to_pyg(self, graph, features_dict):
155155
# targets: flatten the nearest neighbor indices
156156
target = neighbor_indices.flatten()
157157
edge_index= torch.stack([source, target], dim=0)
158+
<<<<<<< HEAD
158159
edge_attrs = torch.zeros(edge_index.shape[1],dtype=int)
159160

160161
elif self.graph_construction == "threshold":
161162
edges = torch.nonzero(dist_matrix < self.threshold, as_tuple=False)
162163
edge_attrs = torch.zeros(edge_index.shape[1],dtype=int)
164+
=======
165+
166+
elif self.graph_construction == "threshold":
167+
edges = torch.nonzero(dist_matrix < self.threshold, as_tuple=False)
168+
>>>>>>> 70fdf59f1e46542d77c7d48f3930e1d152e223e0
163169
edge_index = edges.t()
164170

165171
else:
@@ -171,6 +177,10 @@ def to_pyg(self, graph, features_dict):
171177
if self.distance_edge_features:
172178
edge_distances = dist_matrix[edge_index[0, :], edge_index[1, :]]
173179
edge_feats = rbf_expand(dists=edge_distances, num_bins=64, min_distance=2.0, max_distance=22.0)
180+
<<<<<<< HEAD
181+
=======
182+
edge_attrs = torch.zeros(edge_index.shape[1],dtype=int)
183+
>>>>>>> 70fdf59f1e46542d77c7d48f3930e1d152e223e0
174184
return Data(x=x, y=y, edge_attr=edge_attrs, edge_index=edge_index, edge_feats=edge_feats)
175185

176186
return Data(x=x, y=y, edge_attr=edge_attrs, edge_index=edge_index)

src/rnaglib/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from .task_utils import print_statistics
1919
from .task_utils import DummyResidueModel, DummyGraphModel
20-
2120
from .represent_utils import rbf_expand
2221

2322
from .wrappers import rna_align_wrapper, cdhit_wrapper, locarna_wrapper, US_align_wrapper

0 commit comments

Comments
 (0)