Skip to content

Commit 70fdf59

Browse files
committed
add RBF encoding
1 parent 1d2dfe6 commit 70fdf59

2 files changed

Lines changed: 35 additions & 0 deletions

File tree

src/rnaglib/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from .task_utils import print_statistics
1919
from .task_utils import DummyResidueModel, DummyGraphModel
20+
from .represent_utils import rbf_expand
2021

2122
from .wrappers import rna_align_wrapper, cdhit_wrapper, locarna_wrapper, US_align_wrapper
2223

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
3+
def rbf_expand(dists, num_bins: int = 64, min_distance=2.0, max_distance=22.0, gamma=None):
4+
# Calculate centers
5+
# First bin
6+
centers = torch.zeros(num_bins)
7+
centers[0] = 1.0
8+
# Middle bins
9+
width = (max_distance-min_distance) / (num_bins-2)
10+
11+
for i in range(1, num_bins-1):
12+
centers[i] = min_distance + width * (i - 0.5)
13+
14+
# Last bin
15+
centers[-1] = max_distance
16+
centers = centers.view(1, -1)
17+
18+
if gamma is None:
19+
gamma = 1.0 / (width ** 2)
20+
21+
if dists.dim() == 1:
22+
dists = dists.unsqueeze(-1) # [E, 1]
23+
24+
# Calculate standard Gaussian for all: [E, 64]
25+
diff = dists - centers
26+
rbf = torch.exp(-gamma * diff.pow(2))
27+
28+
# Apply saturation to the tail
29+
# if dist > max_distance, force value to 1.0
30+
rbf[:, -1] = torch.where(dists.squeeze() > max_distance,
31+
torch.ones_like(rbf[:, -1]),
32+
rbf[:, -1])
33+
34+
return rbf

0 commit comments

Comments
 (0)