Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch forward method of the EMLE class #39

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def float(self):
try:
from NNPOps import OptimizedTorchANI as _OptimizedTorchANI

species = self._atomic_numbers.reshape(1, *atomic_numbers.shape)
species = self._atomic_numbers.reshape(1, *self._atomic_numbers.shape)
self._ani2x = _OptimizedTorchANI(self._ani2x, species).to(self._device)
except:
pass
Expand All @@ -344,42 +344,44 @@ def forward(
Parameters
----------

atomic_numbers: torch.Tensor (N_QM_ATOMS,)
Atomic numbers of QM atoms.
atomic_numbers: torch.Tensor (N_QM_ATOMS,) or (BATCH, N_QM_ATOMS)
Atomic numbers of QM atoms. A non-existent atom is represented by -1.

charges_mm: torch.Tensor (max_mm_atoms,)
charges_mm: torch.Tensor (max_mm_atoms,) or (BATCH, max_mm_atoms)
MM point charges in atomic units.

xyz_qm: torch.Tensor (N_QM_ATOMS, 3)
xyz_qm: torch.Tensor (N_QM_ATOMS, 3) or (BATCH, N_QM_ATOMS, 3)
Positions of QM atoms in Angstrom.

xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.

qm_charge: int
qm_charge: int or torch.Tensor (BATCH,)
The charge on the QM region.

Returns
-------

result: torch.Tensor (3,)
result: torch.Tensor (3,) or (3, BATCH)
The ANI2x and static and induced EMLE energy components in Hartree.
"""

# Reshape the atomic numbers.
atomic_numbers_ani = atomic_numbers.unsqueeze(0)

# Reshape the coordinates.
xyz = xyz_qm.unsqueeze(0)
if atomic_numbers.ndim == 1:
# Batch the input tensors.
atomic_numbers = atomic_numbers.unsqueeze(0)
xyz_qm = xyz_qm.unsqueeze(0)
xyz_mm = xyz_mm.unsqueeze(0)
charges_mm = charges_mm.unsqueeze(0)

# Get the in vacuo energy.
E_vac = self._ani2x((atomic_numbers_ani, xyz)).energies[0]
E_vac = self._ani2x((atomic_numbers, xyz_qm)).energies

# If there are no point charges, return the in vacuo energy and zeros
# for the static and induced terms.
if len(xyz_mm) == 0:
zero = _torch.tensor(0.0, dtype=xyz_qm.dtype, device=xyz_qm.device)
return _torch.stack([E_vac, zero, zero])
if xyz_mm.shape[1] == 0:
zero = _torch.zeros(
atomic_numbers.shape[0], dtype=xyz_qm.dtype, device=xyz_qm.device
)
return _torch.stack((E_vac, zero, zero))

# Set the AEVs captured by the forward hook as an attribute of the
# EMLE AEVComputer instance.
Expand All @@ -389,4 +391,4 @@ def forward(
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

# Return the ANI2x and EMLE energy components.
return _torch.stack([E_vac, E_emle[0], E_emle[1]])
return _torch.stack((E_vac, E_emle[0], E_emle[1]))
84 changes: 50 additions & 34 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import torchani as _torchani

from torch import Tensor
from typing import Optional, Tuple, List
from typing import Union

from . import _patches
from . import EMLEBase as _EMLEBase
Expand Down Expand Up @@ -408,93 +408,109 @@ def forward(
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: int = 0,
qm_charge: Union[int, Tensor] = 0,
) -> Tensor:
"""
Computes the static and induced EMLE energy components.

Parameters
----------

atomic_numbers: torch.Tensor (N_QM_ATOMS,)
atomic_numbers: torch.Tensor (N_QM_ATOMS,) or (BATCH, N_QM_ATOMS)
Atomic numbers of QM atoms.

charges_mm: torch.Tensor (max_mm_atoms,)
charges_mm: torch.Tensor (max_mm_atoms,) or (BATCH, max_mm_atoms)
MM point charges in atomic units.

xyz_qm: torch.Tensor (N_QM_ATOMS, 3)
xyz_qm: torch.Tensor (N_QM_ATOMS, 3) or (BATCH, N_QM_ATOMS, 3)
Positions of QM atoms in Angstrom.

xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.

qm_charge: int
qm_charge: int or torch.Tensor (BATCH,)
The charge on the QM region.

Returns
-------

result: torch.Tensor (2,)
result: torch.Tensor (2,) or (2, BATCH)
The static and induced EMLE energy components in Hartree.
"""

# If the QM charge is a non-default value in the constructor, then
# use this value when qm_charge is zero.
if self._qm_charge != 0 and qm_charge == 0:
qm_charge = self._qm_charge

# Store the inputs as internal attributes.
self._atomic_numbers = atomic_numbers
self._charges_mm = charges_mm
self._xyz_qm = xyz_qm
self._xyz_mm = xyz_mm
self._qm_charge = qm_charge

if self._atomic_numbers.ndim == 1:
self._atomic_numbers = self._atomic_numbers.unsqueeze(0)
self._charges_mm = self._charges_mm.unsqueeze(0)
self._xyz_qm = self._xyz_qm.unsqueeze(0)
self._xyz_mm = self._xyz_mm.unsqueeze(0)

batch_size = self._atomic_numbers.shape[0]

# Ensure qm_charge is a tensor and repeat for batch size if necessary
if isinstance(qm_charge, int):
qm_charge = _torch.full(
(batch_size,),
qm_charge if qm_charge != 0 else self._qm_charge,
dtype=_torch.int64,
device=self._device,
)
elif isinstance(qm_charge, _torch.Tensor):
if qm_charge.ndim == 0:
qm_charge = qm_charge.repeat(batch_size).to(self._device)

# If there are no point charges, return zeros.
if len(xyz_mm) == 0:
return _torch.zeros(2, dtype=xyz_qm.dtype, device=xyz_qm.device)
if xyz_mm.shape[1] == 0:
return _torch.zeros(
2, batch_size, dtype=self._xyz_qm.dtype, device=self._xyz_qm.device
)

# Get the parameters from the base model:
# valence widths, core charges, valence charges, A_thole tensor
# These are returned as batched tensors and are passed on in batched
# form to the energy routines below.
s, q_core, q_val, A_thole = self._emle_base(
atomic_numbers[None, :],
xyz_qm[None, :, :],
_torch.tensor([qm_charge], dtype=xyz_qm.dtype, device=xyz_qm.device),
self._atomic_numbers,
self._xyz_qm,
qm_charge,
)

# Convert coordinates to Bohr.
ANGSTROM_TO_BOHR = 1.8897261258369282
xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR
xyz_mm_bohr = xyz_mm * ANGSTROM_TO_BOHR
xyz_qm_bohr = self._xyz_qm * ANGSTROM_TO_BOHR
xyz_mm_bohr = self._xyz_mm * ANGSTROM_TO_BOHR

# Compute the static energy.
if self._method == "mm":
q_core = self._q_core_mm[None, :]
q_core = self._q_core_mm.expand(batch_size, -1)
q_val = _torch.zeros_like(
q_core, dtype=charges_mm.dtype, device=self._device
q_core, dtype=self._charges_mm.dtype, device=self._device
)

mesh_data = self._emle_base._get_mesh_data(
xyz_qm_bohr[None, :, :], xyz_mm_bohr[None, :, :], s
)
mask = (self._atomic_numbers > 0).unsqueeze(-1)
mesh_data = self._emle_base._get_mesh_data(xyz_qm_bohr, xyz_mm_bohr, s, mask)

if self._method == "mechanical":
q_core = q_core + q_val
q_val = _torch.zeros_like(
q_core, dtype=charges_mm.dtype, device=self._device
q_core, dtype=self._charges_mm.dtype, device=self._device
)
E_static = self._emle_base.get_static_energy(
q_core, q_val, charges_mm[None, :], mesh_data
)[0]
q_core, q_val, self._charges_mm, mesh_data
)

# Compute the induced energy.
if self._method == "electrostatic":
E_ind = self._emle_base.get_induced_energy(
A_thole, charges_mm[None, :], s, mesh_data
)[0]
A_thole, self._charges_mm, s, mesh_data, mask
)
else:
E_ind = _torch.tensor(0.0, dtype=charges_mm.dtype, device=self._device)
E_ind = _torch.zeros_like(
E_static, dtype=self._charges_mm.dtype, device=self._device
)

return _torch.stack([E_static, E_ind])
return _torch.stack((E_static, E_ind), dim=0)
27 changes: 21 additions & 6 deletions emle/models/_emle_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ def get_induced_energy(
charges_mm: Tensor,
s: Tensor,
mesh_data: Tuple[Tensor, Tensor, Tensor],
mask: Tensor,
) -> Tensor:
"""
Calculate the induced electrostatic energy.
Expand All @@ -843,13 +844,16 @@ def get_induced_energy(
mesh_data: mesh_data object (output of self._get_mesh_data)
Mesh data object.

mask: torch.Tensor (N_BATCH, MAX_QM_ATOMS)
Mask for padded coordinates.

Returns
-------

result: torch.Tensor (N_BATCH,)
Induced electrostatic energy.
"""
mu_ind = EMLEBase._get_mu_ind(A_thole, mesh_data, charges_mm, s)
mu_ind = EMLEBase._get_mu_ind(A_thole, mesh_data, charges_mm, s, mask)
vpot_ind = EMLEBase._get_vpot_mu(mu_ind, mesh_data[2])
return _torch.sum(vpot_ind * charges_mm, dim=1) * 0.5

Expand All @@ -859,6 +863,7 @@ def _get_mu_ind(
mesh_data: Tuple[Tensor, Tensor, Tensor],
q: Tensor,
s: Tensor,
mask: Tensor,
) -> Tensor:
"""
Internal method, calculates induced atomic dipoles
Expand All @@ -881,6 +886,9 @@ def _get_mu_ind(
q_val: torch.Tensor (N_BATCH, N_QM_ATOMS,)
MBIS valence charges.

mask: torch.Tensor (N_BATCH, N_QM_ATOMS)
Mask for padded coordinates.

Returns
-------

Expand All @@ -889,7 +897,7 @@ def _get_mu_ind(
"""

r = 1.0 / mesh_data[0]
f1 = EMLEBase._get_f1_slater(r, s[:, :, None] * 2.0)
f1 = _torch.where(mask, EMLEBase._get_f1_slater(r, s[:, :, None] * 2.0), 0.0)
fields = _torch.sum(
mesh_data[2] * f1[..., None] * q[:, None, :, None], dim=2
).reshape(len(s), -1)
Expand Down Expand Up @@ -944,7 +952,7 @@ def _get_vpot_mu(mu: Tensor, T1: Tensor) -> Tensor:

@staticmethod
def _get_mesh_data(
xyz: Tensor, xyz_mesh: Tensor, s: Tensor
xyz: Tensor, xyz_mesh: Tensor, s: Tensor, mask: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Internal method, calculates mesh_data object.
Expand All @@ -961,6 +969,9 @@ def _get_mesh_data(
s: torch.Tensor (N_BATCH, MAX_QM_ATOMS,)
MBIS valence widths.

mask: torch.Tensor (N_BATCH, MAX_QM_ATOMS)
Mask for padded coordinates.

Returns
-------

Expand All @@ -969,11 +980,15 @@ def _get_mesh_data(
"""
rr = xyz_mesh[:, None, :, :] - xyz[:, :, None, :]
r = _torch.linalg.norm(rr, ord=2, dim=3)

# Mask for padded coordinates.
r_inv = _torch.where(mask, 1.0 / r, 0.0)
T0_slater = _torch.where(mask, EMLEBase._get_T0_slater(r, s[:, :, None]), 0.0)

return (
1.0 / r,
EMLEBase._get_T0_slater(r, s[:, :, None]),
-rr / r[..., None] ** 3,
r_inv,
T0_slater,
-rr * r_inv[..., None] ** 3,
)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,4 +457,4 @@ def forward(
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

# Return the MACE and EMLE energy components.
return _torch.stack([E_vac, E_emle[0], E_emle[1]])
return _torch.stack([E_vac, E_emle[0][0], E_emle[1][0]])
Loading