Skip to content

Commit abfab20

Browse files
committed
Fix signs for alignment.
1 parent 611ab73 commit abfab20

1 file changed

Lines changed: 24 additions & 27 deletions

File tree

src/jamun/utils/align.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,14 @@ def C1(S: torch.Tensor) -> torch.Tensor:
     s1, s2, s3 = S[:, 0], S[:, 1], S[:, 2]  # Each is [batch_size]

     # Compute coefficients for each batch
-    c1 = 1 / (s1 + s2) + 1 / (s1 + s3)  # [batch_size]
-    c2 = 1 / (s2 + s1) + 1 / (s2 + s3)  # [batch_size]
-    c3 = 1 / (s3 + s1) + 1 / (s3 + s2)  # [batch_size]
+    r1 = 1 / (s1 + s2) + 1 / (s1 + s3)  # [batch_size]
+    r2 = 1 / (s2 + s1) + 1 / (s2 + s3)  # [batch_size]
+    r3 = 1 / (s3 + s1) + 1 / (s3 + s2)  # [batch_size]

     # Stack to create diagonal elements matrix
-    diag_elements = torch.stack([c1, c2, c3], dim=-1)  # [batch_size, 3]
+    C1 = torch.stack([r1, r2, r3], dim=-1)  # [batch_size, 3]

-    # Create batch of diagonal matrices
-    C1_batch = torch.diag_embed(diag_elements)  # [batch_size, 3, 3]
-
-    return -C1_batch / 2
+    return -C1 / 2


 def C2(S: torch.Tensor) -> torch.Tensor:
@@ -44,17 +41,14 @@ def C2(S: torch.Tensor) -> torch.Tensor:
     s1, s2, s3 = S[:, 0], S[:, 1], S[:, 2]  # Each is [batch_size]

     # Compute coefficients for each batch
-    c1 = 1 / (s1 + s2) ** 2 + 1 / (s1 + s3) ** 2  # [batch_size]
-    c2 = 1 / (s2 + s1) ** 2 + 1 / (s2 + s3) ** 2  # [batch_size]
-    c3 = 1 / (s3 + s1) ** 2 + 1 / (s3 + s2) ** 2  # [batch_size]
+    r1 = 1 / (s1 + s2) ** 2 + 1 / (s1 + s3) ** 2  # [batch_size]
+    r2 = 1 / (s2 + s1) ** 2 + 1 / (s2 + s3) ** 2  # [batch_size]
+    r3 = 1 / (s3 + s1) ** 2 + 1 / (s3 + s2) ** 2  # [batch_size]

     # Stack to create diagonal elements matrix
-    diag_elements = torch.stack([c1, c2, c3], dim=-1)  # [batch_size, 3]
-
-    # Create batch of diagonal matrices
-    C2_batch = torch.diag_embed(diag_elements)  # [batch_size, 3, 3]
+    C2 = torch.stack([r1, r2, r3], dim=-1)  # [batch_size, 3]

-    return -C2_batch / 8
+    return -C2 / 8


 def alignment_correction_upto_order(S: torch.Tensor, sigma: float, correction_order: int) -> torch.Tensor:
@@ -72,14 +66,14 @@ def alignment_correction_upto_order(S: torch.Tensor, sigma: float, correction_or
     batch_size = S.shape[0]
     assert S.shape == (batch_size, 3)

-    identity = torch.eye(3, device=S.device, dtype=S.dtype).unsqueeze(0).expand(batch_size, -1, -1)
+    ones = torch.ones((batch_size, 3), device=S.device, dtype=S.dtype)

     if correction_order == 0:
-        return identity
+        return ones
     if correction_order == 1:
-        return identity + (sigma**2) * C1(S)
+        return ones + (sigma**2) * C1(S)
     if correction_order == 2:
-        return identity + (sigma**2) * C1(S) + (sigma**4) * C2(S)
+        return ones + (sigma**2) * C1(S) + (sigma**4) * C2(S)

     raise ValueError(f"Correction order {correction_order} not supported.")

@@ -123,15 +117,18 @@ def kabsch_algorithm(
     H = torch.einsum("Ni,Nj,NG->Gij", A_c, B_c, batch_one_hot)

     # SVD to get rotation.
-    U, S_orig, VH = torch.linalg.svd(H)
-    S = alignment_correction_upto_order(S_orig, sigma=sigma, correction_order=correction_order)
-    R_check = torch.einsum("Gki,Gkk,Gjk->Gij", VH, S, U)  # V U^T
+    U, S, VH = torch.linalg.svd(H)

-    # Remove reflections.
+    # Compute corrected S.
+    R_check = torch.einsum("Gki,Gjk->Gij", VH, U)  # V U^T
     dets = torch.linalg.det(R_check)
-    signs = torch.eye(3, device=dets.device).repeat(num_graphs, 1, 1)  # repeat the identity matrix
-    signs[:, 2, 2] = dets
-    R = torch.einsum("Gki,Gkk,Gkk,Gjk->Gij", VH, signs, S, U)  # V S U^T
+    signs = torch.ones(num_graphs, 3, device=dets.device)
+    signs[:, 2] = dets
+    S = torch.einsum("Gk,Gk->Gk", signs, S)
+
+    # Remove reflections.
+    S = alignment_correction_upto_order(S, sigma=sigma, correction_order=correction_order)
+    R = torch.einsum("Gki,Gk,Gk,Gjk->Gij", VH, signs, S, U)  # V S U^T

     # Align y to x.
     RA_mu = torch.einsum("Gij,Gj->Gi", R, A_mu)

0 commit comments

Comments
 (0)