Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix total charge #38

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def run(self, path=None):

# Compute energy and gradients.
try:
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge)
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(E.sum(), (xyz_qm, xyz_mm))
dE_dxyz_qm_bohr = dE_dxyz_qm.cpu().numpy() * _BOHR_TO_ANGSTROM
dE_dxyz_mm_bohr = dE_dxyz_mm.cpu().numpy() * _BOHR_TO_ANGSTROM
Expand Down Expand Up @@ -1208,7 +1208,7 @@ def run(self, path=None):
E_mm_qm_vac, grad_mm_qm_vac = 0.0, _np.zeros_like(xyz_qm)

# Compute the embedding contributions.
E = self._emle_mm(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E = self._emle_mm(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge)
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(E.sum(), (xyz_qm, xyz_mm))
dE_dxyz_qm_bohr = dE_dxyz_qm.cpu().numpy() * _BOHR_TO_ANGSTROM
dE_dxyz_mm_bohr = dE_dxyz_mm.cpu().numpy() * _BOHR_TO_ANGSTROM
Expand Down Expand Up @@ -1779,6 +1779,7 @@ def _sire_callback_optimised(
# Compute the energy and gradients. Don't use optimised execution to
# avoid warmup costs.
with _torch.jit.optimized_execution(False):
# ANI-2x systems are always neutral, so charge not needed here
E = self._ani2x_emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm), allow_unused=allow_unused
Expand Down
20 changes: 18 additions & 2 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
emle_method="electrostatic",
alpha_mode="species",
mm_charges=None,
qm_charge=0,
model_index=None,
ani2x_model=None,
atomic_numbers=None,
Expand Down Expand Up @@ -101,6 +102,10 @@ def __init__(
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.

qm_charge: int
The charge on the QM region. This can also be passed when calling
the forward method. The non-default value will take precendence.

model_index: int
The index of the ANI2x model to use. If None, then the full 8 model
ensemble will be used.
Expand Down Expand Up @@ -171,6 +176,7 @@ def __init__(
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
qm_charge=qm_charge,
device=device,
dtype=dtype,
create_aev_calculator=False,
Expand Down Expand Up @@ -324,7 +330,14 @@ def float(self):

return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
def forward(
self,
atomic_numbers: Tensor,
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: int = 0,
) -> Tensor:
"""
Compute the the ANI2x and static and induced EMLE energy components.

Expand All @@ -343,6 +356,9 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.

qm_charge: int
The charge on the QM region.

Returns
-------

Expand Down Expand Up @@ -370,7 +386,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
self._emle._emle_base._emle_aev_computer._aev = self._ani2x.aev_computer._aev

# Get the EMLE energy components.
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

# Return the ANI2x and EMLE energy components.
return _torch.stack([E_vac, E_emle[0], E_emle[1]])
46 changes: 31 additions & 15 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
method="electrostatic",
alpha_mode="species",
atomic_numbers=None,
qm_charge=0,
mm_charges=None,
device=None,
dtype=None,
Expand Down Expand Up @@ -129,6 +130,10 @@ def __init__(
if you are using a fixed QM region, i.e. the same QM region for each
evalulation of the module.

qm_charge: int
The charge on the QM region. This can also be passed when calling
the forward method. The non-default value will take precendence.

mm_charges: List[float], Tuple[Float], numpy.ndarray, torch.Tensor
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.
Expand Down Expand Up @@ -192,6 +197,10 @@ def __init__(
)
self._atomic_numbers = atomic_numbers

if not isinstance(qm_charge, int):
raise TypeError("'qm_charge' must be of type 'int'")
self._qm_charge = qm_charge

if method == "mm":
if mm_charges is None:
raise ValueError("MM charges must be provided for the 'mm' method")
Expand Down Expand Up @@ -270,11 +279,6 @@ def __init__(
),
}

# Store the total charge.
q_total = _torch.tensor(
params.get("total_charge", 0), dtype=dtype, device=device
)

if method == "mm":
q_core_mm = _torch.tensor(mm_charges, dtype=dtype, device=device)
else:
Expand All @@ -284,7 +288,6 @@ def __init__(
self._device = device

# Register constants as buffers.
self.register_buffer("_q_total", q_total)
self.register_buffer("_q_core_mm", q_core_mm)

if not isinstance(create_aev_calculator, bool):
Expand Down Expand Up @@ -348,7 +351,6 @@ def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion on the model.
"""
self._q_total = self._q_total.to(*args, **kwargs)
self._q_core_mm = self._q_core_mm.to(*args, **kwargs)
self._emle_base = self._emle_base.to(*args, **kwargs)

Expand All @@ -364,33 +366,30 @@ def cuda(self, **kwargs):
"""
Move all model parameters and buffers to CUDA memory.
"""
self._q_total = self._q_total.cuda(**kwargs)
self._q_core_mm = self._q_core_mm.cuda(**kwargs)
self._emle_base = self._emle_base.cuda(**kwargs)

# Update the device attribute.
self._device = self._q_total.device
self._device = self._q_core_mm.device

return self

def cpu(self, **kwargs):
"""
Move all model parameters and buffers to CPU memory.
"""
self._q_total = self._q_total.cpu(**kwargs)
self._q_core_mm = self._q_core_mm.cpu(**kwargs)
self._emle_base = self._emle_base.cpu()

# Update the device attribute.
self._device = self._q_total.device
self._device = self._q_core_mm.device

return self

def double(self):
"""
Casts all floating point model parameters and buffers to float64 precision.
"""
self._q_total = self._q_total.double()
self._q_core_mm = self._q_core_mm.double()
self._emle_base = self._emle_base.double()
return self
Expand All @@ -399,12 +398,18 @@ def float(self):
"""
Casts all floating point model parameters and buffers to float32 precision.
"""
self._q_total = self._q_total.float()
self._q_core_mm = self._q_core_mm.float()
self._emle_base = self._emle_base.float()
return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
def forward(
self,
atomic_numbers: Tensor,
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: int = 0,
) -> Tensor:
"""
Computes the static and induced EMLE energy components.

Expand All @@ -423,18 +428,27 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.

qm_charge: int
The charge on the QM region.

Returns
-------

result: torch.Tensor (2,)
The static and induced EMLE energy components in Hartree.
"""

# If the QM charge is a non-default value in the constructor, then
# use this value when qm_charge is zero.
if self._qm_charge != 0 and qm_charge == 0:
qm_charge = self._qm_charge

# Store the inputs as internal attributes.
self._atomic_numbers = atomic_numbers
self._charges_mm = charges_mm
self._xyz_qm = xyz_qm
self._xyz_mm = xyz_mm
self._qm_charge = qm_charge

# If there are no point charges, return zeros.
if len(xyz_mm) == 0:
Expand All @@ -445,7 +459,9 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
# These are returned as batched tensors, so we need to extract the
# first element of each.
s, q_core, q_val, A_thole = self._emle_base(
atomic_numbers[None, :], xyz_qm[None, :, :], self._q_total[None]
atomic_numbers[None, :],
xyz_qm[None, :, :],
_torch.tensor([qm_charge], dtype=xyz_qm.dtype, device=xyz_qm.device),
)

# Convert coordinates to Bohr.
Expand Down
22 changes: 20 additions & 2 deletions emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from ._emle import _has_nnpops
from ._utils import _get_neighbor_pairs

from torch import Tensor

try:
from mace.calculators.foundations_models import mace_off as _mace_off

Expand Down Expand Up @@ -66,6 +68,7 @@ def __init__(
emle_method="electrostatic",
alpha_mode="species",
mm_charges=None,
qm_charge=0,
mace_model=None,
atomic_numbers=None,
device=None,
Expand Down Expand Up @@ -106,6 +109,10 @@ def __init__(
List of MM charges for atoms in the QM region in units of mod
electron charge. This is required if the 'mm' method is specified.

qm_charge: int
The charge on the QM region. This can also be passed when calling
the forward method. The non-default value will take precendence.

mace_model: str
Name of the MACE-OFF23 models to use.
Available models are 'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'.
Expand Down Expand Up @@ -173,6 +180,7 @@ def __init__(
alpha_mode=alpha_mode,
atomic_numbers=(atomic_numbers if atomic_numbers is not None else None),
mm_charges=mm_charges,
qm_charge=qm_charge,
device=device,
dtype=dtype,
create_aev_calculator=True,
Expand Down Expand Up @@ -361,7 +369,14 @@ def float(self):
self._mace = self._mace.float()
return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
def forward(
self,
atomic_numbers: Tensor,
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: int = 0,
) -> Tensor:
"""
Compute the the MACE and static and induced EMLE energy components.

Expand All @@ -380,6 +395,9 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
xyz_mm: torch.Tensor (N_MM_ATOMS, 3)
Positions of MM atoms in Angstrom.

qm_charge: int
The charge on the QM region.

Returns
-------

Expand Down Expand Up @@ -436,7 +454,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
return _torch.stack([E_vac, zero, zero])

# Get the EMLE energy components.
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge)

# Return the MACE and EMLE energy components.
return _torch.stack([E_vac, E_emle[0], E_emle[1]])
Loading