Skip to content

Commit

Permalink
Merge pull request #29 from chemle/feature_base
Browse files Browse the repository at this point in the history
Add EMLE base model
  • Loading branch information
lohedges authored Oct 22, 2024
2 parents 9fbb809 + 85040b2 commit 760ac56
Show file tree
Hide file tree
Showing 5 changed files with 940 additions and 580 deletions.
57 changes: 37 additions & 20 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,27 +1163,36 @@ def run(self, path=None):
E_vac += delta_E
grad_vac += delta_grad

# Store a copy of the QM coordinates as a NumPy array.
# Store a copy of the atomic numbers and QM coordinates as NumPy arrays.
atomic_numbers_np = atomic_numbers
xyz_qm_np = xyz_qm

# Convert inputs to Torch tensors.
atomic_numbers = _torch.tensor(
atomic_numbers, dtype=_torch.int64, device=self._device
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)
xyz_qm = _torch.tensor(
xyz_qm, dtype=_torch.float32, device=self._device, requires_grad=True
)
xyz_mm = _torch.tensor(
xyz_mm, dtype=_torch.float32, device=self._device, requires_grad=True
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)

# Compute energy and gradients.
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
try:
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
except Exception as e:
msg = f"Failed to compute EMLE energies and gradients: {e}"
_logger.error(msg)
raise RuntimeError(msg)

# Compute the total energy and gradients.
E_tot = E_vac + E.sum().detach().cpu().numpy()
Expand Down Expand Up @@ -1283,7 +1292,7 @@ def run(self, path=None):

# Write out the QM region to the xyz trajectory file.
if self._qm_xyz_frequency > 0 and self._step % self._qm_xyz_frequency == 0:
atoms = _ase.Atoms(positions=xyz_qm_np, numbers=atomic_numbers)
atoms = _ase.Atoms(positions=xyz_qm_np, numbers=atomic_numbers_np)
if hasattr(self, "_max_f_std"):
atoms.info = {"max_f_std": self._max_f_std}
_ase_io.write(self._qm_xyz_file, atoms, append=True)
Expand Down Expand Up @@ -1553,23 +1562,31 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
)

# Convert inputs to Torch tensors.
atomic_numbers = _torch.tensor(
atomic_numbers, dtype=_torch.int64, device=self._device
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)
xyz_qm = _torch.tensor(
xyz_qm, dtype=_torch.float32, device=self._device, requires_grad=True
)
xyz_mm = _torch.tensor(
xyz_mm, dtype=_torch.float32, device=self._device, requires_grad=True
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)

# Compute energy and gradients.
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
try:
E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E.sum(), (xyz_qm, xyz_mm)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
except Exception as e:
msg = f"Failed to compute EMLE energies and gradients: {e}"
_logger.error(msg)
raise RuntimeError(msg)

# Compute the total energy and gradients.
E_tot = E_vac + E.sum().detach().cpu().numpy()
Expand Down
1 change: 1 addition & 0 deletions emle/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# avoid severe module import overheads when running the client code,
# which requires no EMLE functionality.

from ._emle_base import EMLEBase
from ._emle import EMLE
from ._ani import ANI2xEMLE
from ._mace import MACEEMLE
28 changes: 24 additions & 4 deletions emle/models/_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(
raise TypeError("'device' must be of type 'torch.device'")
else:
device = _torch.get_default_device()
self._device = device

if dtype is not None:
if not isinstance(dtype, _torch.dtype):
Expand Down Expand Up @@ -229,8 +230,15 @@ def __init__(
except:
pass

# Add a hook to the ANI2x model to capture the AEV features.
self._add_hook()

def _add_hook(self):
"""
Add a hook to the ANI2x model to capture the AEV features.
"""
# Assign a tensor attribute that can be used for assigning the AEVs.
self._ani2x.aev_computer._aev = _torch.empty(0, device=device)
self._ani2x.aev_computer._aev = _torch.empty(0, device=self._device)

# Hook the forward pass of the ANI2x model to get the AEV features.
# Note that this currently requires a patched versions of TorchANI and NNPOps.
Expand All @@ -241,7 +249,7 @@ def hook(
input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
output: Tuple[Tensor, Tensor],
):
module._aev = output[1][0]
module._aev = output[1]

else:

Expand All @@ -250,7 +258,7 @@ def hook(
input: Tuple[Tuple[Tensor, Tensor], Optional[Tensor], Optional[Tensor]],
output: _torchani.aev.SpeciesAEV,
):
module._aev = output[1][0]
module._aev = output[1]

# Register the hook.
self._aev_hook = self._ani2x.aev_computer.register_forward_hook(hook)
Expand All @@ -261,6 +269,13 @@ def to(self, *args, **kwargs):
"""
self._emle = self._emle.to(*args, **kwargs)
self._ani2x = self._ani2x.to(*args, **kwargs)

# Check for a device type in args and update the device attribute.
for arg in args:
if isinstance(arg, _torch.device):
self._device = arg
break

return self

def cpu(self, **kwargs):
Expand All @@ -269,6 +284,7 @@ def cpu(self, **kwargs):
"""
self._emle = self._emle.cpu(**kwargs)
self._ani2x = self._ani2x.cpu(**kwargs)
self._device = _torch.device("cpu")
return self

def cuda(self, **kwargs):
Expand All @@ -277,6 +293,7 @@ def cuda(self, **kwargs):
"""
self._emle = self._emle.cuda(**kwargs)
self._ani2x = self._ani2x.cuda(**kwargs)
self._device = _torch.device("cuda")
return self

def double(self):
Expand Down Expand Up @@ -306,6 +323,9 @@ def float(self):
except:
pass

# Re-append the hook.
self._add_hook()

return self

def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
Expand Down Expand Up @@ -351,7 +371,7 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):

# Set the AEVs captured by the forward hook as an attribute of the
# EMLE model.
self._emle._aev = self._ani2x.aev_computer._aev
self._emle._emle_base._aev = self._ani2x.aev_computer._aev

# Get the EMLE energy components.
E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
Expand Down
Loading

0 comments on commit 760ac56

Please sign in to comment.