@@ -17,17 +17,14 @@ def C1(S: torch.Tensor) -> torch.Tensor:
1717 s1 , s2 , s3 = S [:, 0 ], S [:, 1 ], S [:, 2 ] # Each is [batch_size]
1818
1919 # Compute coefficients for each batch
20- c1 = 1 / (s1 + s2 ) + 1 / (s1 + s3 ) # [batch_size]
21- c2 = 1 / (s2 + s1 ) + 1 / (s2 + s3 ) # [batch_size]
22- c3 = 1 / (s3 + s1 ) + 1 / (s3 + s2 ) # [batch_size]
20+ r1 = 1 / (s1 + s2 ) + 1 / (s1 + s3 ) # [batch_size]
21+ r2 = 1 / (s2 + s1 ) + 1 / (s2 + s3 ) # [batch_size]
22+ r3 = 1 / (s3 + s1 ) + 1 / (s3 + s2 ) # [batch_size]
2323
2424 # Stack to create diagonal elements matrix
25- diag_elements = torch .stack ([c1 , c2 , c3 ], dim = - 1 ) # [batch_size, 3]
25+ C1 = torch .stack ([r1 , r2 , r3 ], dim = - 1 ) # [batch_size, 3]
2626
27- # Create batch of diagonal matrices
28- C1_batch = torch .diag_embed (diag_elements ) # [batch_size, 3, 3]
29-
30- return - C1_batch / 2
27+ return - C1 / 2
3128
3229
3330def C2 (S : torch .Tensor ) -> torch .Tensor :
@@ -44,17 +41,14 @@ def C2(S: torch.Tensor) -> torch.Tensor:
4441 s1 , s2 , s3 = S [:, 0 ], S [:, 1 ], S [:, 2 ] # Each is [batch_size]
4542
4643 # Compute coefficients for each batch
47- c1 = 1 / (s1 + s2 ) ** 2 + 1 / (s1 + s3 ) ** 2 # [batch_size]
48- c2 = 1 / (s2 + s1 ) ** 2 + 1 / (s2 + s3 ) ** 2 # [batch_size]
49- c3 = 1 / (s3 + s1 ) ** 2 + 1 / (s3 + s2 ) ** 2 # [batch_size]
44+ r1 = 1 / (s1 + s2 ) ** 2 + 1 / (s1 + s3 ) ** 2 # [batch_size]
45+ r2 = 1 / (s2 + s1 ) ** 2 + 1 / (s2 + s3 ) ** 2 # [batch_size]
46+ r3 = 1 / (s3 + s1 ) ** 2 + 1 / (s3 + s2 ) ** 2 # [batch_size]
5047
5148 # Stack to create diagonal elements matrix
52- diag_elements = torch .stack ([c1 , c2 , c3 ], dim = - 1 ) # [batch_size, 3]
53-
54- # Create batch of diagonal matrices
55- C2_batch = torch .diag_embed (diag_elements ) # [batch_size, 3, 3]
49+ C2 = torch .stack ([r1 , r2 , r3 ], dim = - 1 ) # [batch_size, 3]
5650
57- return - C2_batch / 8
51+ return - C2 / 8
5852
5953
6054def alignment_correction_upto_order (S : torch .Tensor , sigma : float , correction_order : int ) -> torch .Tensor :
@@ -72,14 +66,14 @@ def alignment_correction_upto_order(S: torch.Tensor, sigma: float, correction_or
7266 batch_size = S .shape [0 ]
7367 assert S .shape == (batch_size , 3 )
7468
75- identity = torch .eye ( 3 , device = S .device , dtype = S .dtype ). unsqueeze ( 0 ). expand ( batch_size , - 1 , - 1 )
69+ ones = torch .ones (( batch_size , 3 ) , device = S .device , dtype = S .dtype )
7670
7771 if correction_order == 0 :
78- return identity
72+ return ones
7973 if correction_order == 1 :
80- return identity + (sigma ** 2 ) * C1 (S )
74+ return ones + (sigma ** 2 ) * C1 (S )
8175 if correction_order == 2 :
82- return identity + (sigma ** 2 ) * C1 (S ) + (sigma ** 4 ) * C2 (S )
76+ return ones + (sigma ** 2 ) * C1 (S ) + (sigma ** 4 ) * C2 (S )
8377
8478 raise ValueError (f"Correction order { correction_order } not supported." )
8579
@@ -123,15 +117,18 @@ def kabsch_algorithm(
123117 H = torch .einsum ("Ni,Nj,NG->Gij" , A_c , B_c , batch_one_hot )
124118
125119 # SVD to get rotation.
126- U , S_orig , VH = torch .linalg .svd (H )
127- S = alignment_correction_upto_order (S_orig , sigma = sigma , correction_order = correction_order )
128- R_check = torch .einsum ("Gki,Gkk,Gjk->Gij" , VH , S , U ) # V U^T
120+ U , S , VH = torch .linalg .svd (H )
129121
130- # Remove reflections.
122+ # Compute corrected S.
123+ R_check = torch .einsum ("Gki,Gjk->Gij" , VH , U ) # V U^T
131124 dets = torch .linalg .det (R_check )
132- signs = torch .eye (3 , device = dets .device ).repeat (num_graphs , 1 , 1 ) # repeat the identity matrix
133- signs [:, 2 , 2 ] = dets
134- R = torch .einsum ("Gki,Gkk,Gkk,Gjk->Gij" , VH , signs , S , U ) # V S U^T
125+ signs = torch .ones (num_graphs , 3 , device = dets .device )
126+ signs [:, 2 ] = dets
127+ S = torch .einsum ("Gk,Gk->Gk" , signs , S )
128+
129+ # Remove reflections.
130+ S = alignment_correction_upto_order (S , sigma = sigma , correction_order = correction_order )
131+ R = torch .einsum ("Gki,Gk,Gk,Gjk->Gij" , VH , signs , S , U ) # V S U^T
135132
136133 # Align y to x.
137134 RA_mu = torch .einsum ("Gij,Gj->Gi" , R , A_mu )
0 commit comments