Skip to content

Commit

Permalink
Merge pull request #38 from chemle/fix_total_charge
Browse files Browse the repository at this point in the history
Fix total charge
  • Loading branch information
lohedges authored Nov 18, 2024
2 parents 32d56c5 + 3075a9b commit 1268c7f
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 21 deletions.
24 changes: 22 additions & 2 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
method="electrostatic",
alpha_mode="species",
atomic_numbers=None,
qm_charge=0,
backend="torchani",
external_backend=None,
plugin_path=".",
Expand Down Expand Up @@ -161,6 +162,12 @@ def __init__(
you are using a fixed QM region, i.e. the same QM region for each
call to the calculator.
qm_charge: int
The charge on the QM region. This is required when using an
EMLECalculator instance with the OpenMM interface. When using
the sander interface, the QM charge will be taken from the ORCA
input file.
external_backend: str
The name of an external backend to use to compute in vacuo energies.
This should be a callback function formatted as 'module.function'.
Expand Down Expand Up @@ -415,13 +422,23 @@ def __init__(
else:
self._mm_charges = None

if qm_charge is not None:
try:
qm_charge = int(qm_charge)
except:
msg = "'qm_charge' must be of type 'int'"
_logger.error(msg)
raise TypeError(msg)
self._qm_charge = qm_charge

# Create the EMLE model instance.
self._emle = _EMLE(
model=model,
method=method,
alpha_mode=alpha_mode,
atomic_numbers=atomic_numbers,
mm_charges=self._mm_charges,
qm_charge=self._qm_charge,
device=self._device,
)

Expand Down Expand Up @@ -962,6 +979,7 @@ def __init__(
"method": self._method,
"alpha_mode": self._alpha_mode,
"atomic_numbers": None if atomic_numbers is None else atomic_numbers,
"qm_charge": self._qm_charge,
"backend": self._backend,
"external_backend": None if external_backend is None else external_backend,
"mm_charges": None if mm_charges is None else self._mm_charges.tolist(),
Expand Down Expand Up @@ -1179,7 +1197,7 @@ def run(self, path=None):

# Compute energy and gradients.
try:
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge)
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(E.sum(), (xyz_qm, xyz_mm))
dE_dxyz_qm_bohr = dE_dxyz_qm.cpu().numpy() * _BOHR_TO_ANGSTROM
dE_dxyz_mm_bohr = dE_dxyz_mm.cpu().numpy() * _BOHR_TO_ANGSTROM
Expand Down Expand Up @@ -1208,7 +1226,7 @@ def run(self, path=None):
E_mm_qm_vac, grad_mm_qm_vac = 0.0, _np.zeros_like(xyz_qm)

# Compute the embedding contributions.
E = self._emle_mm(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E = self._emle_mm(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge)
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(E.sum(), (xyz_qm, xyz_mm))
dE_dxyz_qm_bohr = dE_dxyz_qm.cpu().numpy() * _BOHR_TO_ANGSTROM
dE_dxyz_mm_bohr = dE_dxyz_mm.cpu().numpy() * _BOHR_TO_ANGSTROM
Expand Down Expand Up @@ -1767,6 +1785,7 @@ def _sire_callback_optimised(
model_index=self._ani2x_model_index,
ani2x_model=self._torchani_model,
atomic_numbers=atomic_numbers,
qm_charge=self._qm_charge,
device=self._device,
)

Expand All @@ -1779,6 +1798,7 @@ def _sire_callback_optimised(
# Compute the energy and gradients. Don't use optimised execution to
# avoid warmup costs.
with _torch.jit.optimized_execution(False):
# ANI-2x systems are always neutral, so charge not needed here
E = self._ani2x_emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm), allow_unused=allow_unused
Expand Down
20 changes: 18 additions & 2 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
emle_method="electrostatic",
alpha_mode="species",
mm_charges=None,
qm_charge=0,
model_index=None,
ani2x_model=None,
atomic_numbers=None,
Expand Down Expand Up @@ -101,6 +102,10 @@ def __init__(
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.
qm_charge: int
The charge on the QM region. This can also be passed when calling
the forward method. The non-default value will take precendence.
model_index: int
The index of the ANI2x model to use. If None, then the full 8 model
ensemble will be used.
Expand Down Expand Up @@ -171,6 +176,7 @@ def __init__(
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
qm_charge=qm_charge,
device=device,
dtype=dtype,
create_aev_calculator=False,
Expand Down Expand Up @@ -324,7 +330,14 @@ def float(self):

return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
def forward(
self,
atomic_numbers: Tensor,
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: int = 0,
) -> Tensor:
"""
Compute the the ANI2x and static and induced EMLE energy components.
Expand All @@ -343,6 +356,9 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.
qm_charge: int
The charge on the QM region.
Returns
-------
Expand Down Expand Up @@ -370,7 +386,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
self._emle._emle_base._emle_aev_computer._aev = self._ani2x.aev_computer._aev

# Get the EMLE energy components.
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

# Return the ANI2x and EMLE energy components.
return _torch.stack([E_vac, E_emle[0], E_emle[1]])
46 changes: 31 additions & 15 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
method="electrostatic",
alpha_mode="species",
atomic_numbers=None,
qm_charge=0,
mm_charges=None,
device=None,
dtype=None,
Expand Down Expand Up @@ -129,6 +130,10 @@ def __init__(
if you are using a fixed QM region, i.e. the same QM region for each
evalulation of the module.
qm_charge: int
The charge on the QM region. This can also be passed when calling
the forward method. The non-default value will take precendence.
mm_charges: List[float], Tuple[Float], numpy.ndarray, torch.Tensor
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.
Expand Down Expand Up @@ -192,6 +197,10 @@ def __init__(
)
self._atomic_numbers = atomic_numbers

if not isinstance(qm_charge, int):
raise TypeError("'qm_charge' must be of type 'int'")
self._qm_charge = qm_charge

if method == "mm":
if mm_charges is None:
raise ValueError("MM charges must be provided for the 'mm' method")
Expand Down Expand Up @@ -270,11 +279,6 @@ def __init__(
),
}

# Store the total charge.
q_total = _torch.tensor(
params.get("total_charge", 0), dtype=dtype, device=device
)

if method == "mm":
q_core_mm = _torch.tensor(mm_charges, dtype=dtype, device=device)
else:
Expand All @@ -284,7 +288,6 @@ def __init__(
self._device = device

# Register constants as buffers.
self.register_buffer("_q_total", q_total)
self.register_buffer("_q_core_mm", q_core_mm)

if not isinstance(create_aev_calculator, bool):
Expand Down Expand Up @@ -348,7 +351,6 @@ def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion on the model.
"""
self._q_total = self._q_total.to(*args, **kwargs)
self._q_core_mm = self._q_core_mm.to(*args, **kwargs)
self._emle_base = self._emle_base.to(*args, **kwargs)

Expand All @@ -364,33 +366,30 @@ def cuda(self, **kwargs):
"""
Move all model parameters and buffers to CUDA memory.
"""
self._q_total = self._q_total.cuda(**kwargs)
self._q_core_mm = self._q_core_mm.cuda(**kwargs)
self._emle_base = self._emle_base.cuda(**kwargs)

# Update the device attribute.
self._device = self._q_total.device
self._device = self._q_core_mm.device

return self

def cpu(self, **kwargs):
"""
Move all model parameters and buffers to CPU memory.
"""
self._q_total = self._q_total.cpu(**kwargs)
self._q_core_mm = self._q_core_mm.cpu(**kwargs)
self._emle_base = self._emle_base.cpu()

# Update the device attribute.
self._device = self._q_total.device
self._device = self._q_core_mm.device

return self

def double(self):
"""
Casts all floating point model parameters and buffers to float64 precision.
"""
self._q_total = self._q_total.double()
self._q_core_mm = self._q_core_mm.double()
self._emle_base = self._emle_base.double()
return self
Expand All @@ -399,12 +398,18 @@ def float(self):
"""
Casts all floating point model parameters and buffers to float32 precision.
"""
self._q_total = self._q_total.float()
self._q_core_mm = self._q_core_mm.float()
self._emle_base = self._emle_base.float()
return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
def forward(
self,
atomic_numbers: Tensor,
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: int = 0,
) -> Tensor:
"""
Computes the static and induced EMLE energy components.
Expand All @@ -423,18 +428,27 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.
qm_charge: int
The charge on the QM region.
Returns
-------
result: torch.Tensor (2,)
The static and induced EMLE energy components in Hartree.
"""

# If the QM charge is a non-default value in the constructor, then
# use this value when qm_charge is zero.
if self._qm_charge != 0 and qm_charge == 0:
qm_charge = self._qm_charge

# Store the inputs as internal attributes.
self._atomic_numbers = atomic_numbers
self._charges_mm = charges_mm
self._xyz_qm = xyz_qm
self._xyz_mm = xyz_mm
self._qm_charge = qm_charge

# If there are no point charges, return zeros.
if len(xyz_mm) == 0:
Expand All @@ -445,7 +459,9 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
# These are returned as batched tensors, so we need to extract the
# first element of each.
s, q_core, q_val, A_thole = self._emle_base(
atomic_numbers[None, :], xyz_qm[None, :, :], self._q_total[None]
atomic_numbers[None, :],
xyz_qm[None, :, :],
_torch.tensor([qm_charge], dtype=xyz_qm.dtype, device=xyz_qm.device),
)

# Convert coordinates to Bohr.
Expand Down
22 changes: 20 additions & 2 deletions emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from ._emle import _has_nnpops
from ._utils import _get_neighbor_pairs

from torch import Tensor

try:
from mace.calculators.foundations_models import mace_off as _mace_off

Expand Down Expand Up @@ -66,6 +68,7 @@ def __init__(
emle_method="electrostatic",
alpha_mode="species",
mm_charges=None,
qm_charge=0,
mace_model=None,
atomic_numbers=None,
device=None,
Expand Down Expand Up @@ -106,6 +109,10 @@ def __init__(
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.
qm_charge: int
The charge on the QM region. This can also be passed when calling
the forward method. The non-default value will take precendence.
mace_model: str
Name of the MACE-OFF23 models to use.
Available models are 'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'.
Expand Down Expand Up @@ -173,6 +180,7 @@ def __init__(
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
qm_charge=qm_charge,
device=device,
dtype=dtype,
create_aev_calculator=True,
Expand Down Expand Up @@ -361,7 +369,14 @@ def float(self):
self._mace = self._mace.float()
return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
def forward(
self,
atomic_numbers: Tensor,
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: int = 0,
) -> Tensor:
"""
Compute the the MACE and static and induced EMLE energy components.
Expand All @@ -380,6 +395,9 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.
qm_charge: int
The charge on the QM region.
Returns
-------
Expand Down Expand Up @@ -436,7 +454,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
return _torch.stack([E_vac, zero, zero])

# Get the EMLE energy components.
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

# Return the MACE and EMLE energy components.
return _torch.stack([E_vac, E_emle[0], E_emle[1]])

0 comments on commit 1268c7f

Please sign in to comment.