Skip to content

Commit

Permalink
Fix optimisation of EMLEAEVComputer.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Dec 16, 2024
1 parent 68f3067 commit dbfc754
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,24 @@ def __init__(
# Optimise the AEV computer using NNPOps if available.
if _has_nnpops and atomic_numbers is not None:
try:
import ase
from torchani import SpeciesConverter

# Work out the species.
species = [ase.Atom(i).symbol for i in atomic_numbers]

# Create a species converter.
species_converter = SpeciesConverter(species)

atomic_numbers = _torch.tensor(
atomic_numbers, dtype=_torch.int64, device=device
)

atomic_numbers = atomic_numbers.reshape(1, *atomic_numbers.shape)
emle_aev_computer._aev_computer = (
_NNPOps.SymmetryFunctions.TorchANISymmetryFunctions(
emle_aev_computer._aev_computer.species_converter,
emle_aev_computer._aev_computer.aev_computer,
species_converter,
emle_aev_computer._aev_computer,
atomic_numbers,
)
)
Expand Down

0 comments on commit dbfc754

Please sign in to comment.