Skip to content

Commit

Permalink
Fix auxiliary tensors created on wrong device
Browse files Browse the repository at this point in the history
  • Loading branch information
kzinovjev committed Oct 21, 2024
1 parent 4e685b4 commit 1b81f8a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions emle/models/_emle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def forward(self, atomic_numbers, xyz_qm, q_total):
Valence widths, core charges, valence charges, A_thole tensor
"""

mask = atomic_numbers > 0
mask = _torch.tensor(atomic_numbers > 0, device=self._ref_mean_s.device)

# Convert the atomic numbers to species IDs.
species_id = self._species_map[atomic_numbers]
Expand Down Expand Up @@ -315,7 +315,7 @@ def _get_Kinv(cls, ref_features, sigma):

@classmethod
def _get_c(cls, n_ref, ref, Kinv):
mask = _torch.arange(ref.shape[1]) < n_ref[:, None]
mask = _torch.arange(ref.shape[1], device=n_ref.device) < n_ref[:, None]
ref_mean = _torch.sum(ref * mask, dim=1) / n_ref
ref_shifted = ref - ref_mean[:, None]
return ref_mean, (Kinv @ ref_shifted[:, :, None]).squeeze()
Expand Down

0 comments on commit 1b81f8a

Please sign in to comment.