Skip to content

Commit

Permalink
Add batch support to the forward method of ANI2xEMLE
Browse files Browse the repository at this point in the history
  • Loading branch information
JMorado committed Nov 27, 2024
1 parent 763daec commit 37e0e38
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def float(self):
try:
from NNPOps import OptimizedTorchANI as _OptimizedTorchANI

species = self._atomic_numbers.reshape(1, *atomic_numbers.shape)
species = self._atomic_numbers.reshape(1, *self._atomic_numbers.shape)
self._ani2x = _OptimizedTorchANI(self._ani2x, species).to(self._device)
except:
pass
Expand All @@ -344,42 +344,44 @@ def forward(
Parameters
----------
atomic_numbers: torch.Tensor (N_QM_ATOMS,)
Atomic numbers of QM atoms.
atomic_numbers: torch.Tensor (N_QM_ATOMS,) or (BATCH, N_QM_ATOMS)
Atomic numbers of QM atoms. A non-existent atom is represented by -1.
charges_mm: torch.Tensor (max_mm_atoms,)
charges_mm: torch.Tensor (max_mm_atoms,) or (BATCH, max_mm_atoms)
MM point charges in atomic units.
xyz_qm: torch.Tensor (N_QM_ATOMS, 3)
xyz_qm: torch.Tensor (N_QM_ATOMS, 3) or (BATCH, N_QM_ATOMS, 3)
Positions of QM atoms in Angstrom.
xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.
qm_charge: int
qm_charge: int or torch.Tensor (BATCH,)
The charge on the QM region.
Returns
-------
result: torch.Tensor (3,)
result: torch.Tensor (3,) or (3, BATCH)
The ANI2x and static and induced EMLE energy components in Hartree.
"""

# Reshape the atomic numbers.
atomic_numbers_ani = atomic_numbers.unsqueeze(0)

# Reshape the coordinates,
xyz = xyz_qm.unsqueeze(0)
if atomic_numbers.ndim == 1:
# Batch the inputs tensors.
atomic_numbers = atomic_numbers.unsqueeze(0)
xyz_qm = xyz_qm.unsqueeze(0)
xyz_mm = xyz_mm.unsqueeze(0)
charges_mm = charges_mm.unsqueeze(0)

# Get the in vacuo energy.
E_vac = self._ani2x((atomic_numbers_ani, xyz)).energies[0]
E_vac = self._ani2x((atomic_numbers, xyz_qm)).energies

# If there are no point charges, return the in vacuo energy and zeros
# for the static and induced terms.
if len(xyz_mm) == 0:
zero = _torch.tensor(0.0, dtype=xyz_qm.dtype, device=xyz_qm.device)
return _torch.stack([E_vac, zero, zero])
if xyz_mm.shape[1] == 0:
zero = _torch.zeros(
atomic_numbers.shape[0], dtype=xyz_qm.dtype, device=xyz_qm.device
)
return _torch.stack((E_vac, zero, zero))

# Set the AEVs captured by the forward hook as an attribute of the
# EMLE AEVComputer instance.
Expand All @@ -389,4 +391,4 @@ def forward(
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

# Return the ANI2x and EMLE energy components.
return _torch.stack([E_vac, E_emle[0][0], E_emle[1][0]])
return _torch.stack((E_vac, E_emle[0], E_emle[1]))

0 comments on commit 37e0e38

Please sign in to comment.