Skip to content

Commit

Permalink
Add Sire callback function.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Nov 23, 2023
1 parent 0abd6c9 commit 152ae76
Showing 1 changed file with 293 additions and 3 deletions.
296 changes: 293 additions & 3 deletions emle/emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@


# Unit conversion factors, all derived from ase.units.
# NOTE(review): these rely on ase.units expressing lengths in Angstrom and
# energies in eV (so Bohr is the Bohr radius in Angstrom and Hartree, kcal,
# kJ are in eV, with mol being Avogadro's number) — confirm against the ASE
# documentation for the installed version.

# Length: Angstrom -> Bohr.
ANGSTROM_TO_BOHR = 1.0 / ase.units.Bohr
# Length: nanometer -> Bohr (1 nm = 10 Angstrom).
NANOMETER_TO_BOHR = 10.0 / ase.units.Bohr
# Length: Bohr -> Angstrom.
BOHR_TO_ANGSTROM = ase.units.Bohr
# Energy: eV -> Hartree.
EV_TO_HARTREE = 1.0 / ase.units.Hartree
# Energy: kcal/mol -> Hartree.
KCAL_MOL_TO_HARTREE = 1.0 / ase.units.Hartree * ase.units.kcal / ase.units.mol
# Energy: Hartree -> kJ/mol.
HARTREE_TO_KJ_MOL = ase.units.Hartree / ase.units.kJ * ase.units.mol

# Settings for the default model. For system specific models, these will be
# overwritten by values in the model file.
Expand Down Expand Up @@ -945,9 +947,9 @@ def __init__(
with open("emle_settings.yaml", "w") as f:
yaml.dump(self._settings, f)

# Match run function of other interface objects.
def run(self, path=None):
"""Calculate the energy and gradients.
"""
Calculate the energy and gradients.
Parameters
----------
Expand Down Expand Up @@ -1235,7 +1237,7 @@ def run(self, path=None):
self._step += 1

def set_lambda_interpolate(self, lambda_interpolate):
""" "
"""
Set the value of the lambda interpolation parameter. Note the server must
already be in 'interpolation' mode, i.e. the user must have specified an
initial value for 'lambda_interpolate' in the constructor.
Expand Down Expand Up @@ -1290,6 +1292,294 @@ def set_lambda_interpolate(self, lambda_interpolate):
# Reset the first step flag.
self._is_first_step = not self._restart

def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
    """
    A callback function to be used with Sire.

    Computes the in vacuo energy and gradients of the QM region with the
    configured backend, adds the EMLE embedding contribution and, when
    interpolation is enabled, mixes the result with an MM description of
    the QM region.

    As side effects this updates ``self._max_mm_atoms`` and
    ``self._species_id``, advances the internal step counter, and may
    append a record to "emle_log.txt".

    Parameters
    ----------

    atomic_numbers : [int]
        A list of atomic numbers for the QM region.

    charges_mm : [float]
        The charges on the MM atoms.

    xyz_qm : [[float, float, float]]
        The coordinates of the QM atoms in Angstrom.

    xyz_mm : [[float, float, float]]
        The coordinates of the MM atoms in Angstrom.

    Returns
    -------

    energy : float
        The energy in kJ/mol.

    force_qm : [[float, float, float]]
        The forces on the QM atoms in kJ/mol/nanometer.

    force_mm : [[float, float, float]]
        The forces on the MM atoms in kJ/mol/nanometer. One entry per MM
        atom that was passed in; internal padding is stripped.

    Raises
    ------

    ValueError
        If MM embedding is used and the number of QM atoms doesn't match the
        number of MM charges, if an atomic number isn't supported by the
        model, or if the ORCA backend is selected.

    RuntimeError
        If an internal backend, or the delta-learning correction, fails.
    """

    # For performance, we assume that the input is already validated.

    # Convert to numpy arrays.
    atomic_numbers = np.array(atomic_numbers)
    charges_mm = np.array(charges_mm)
    xyz_qm = np.array(xyz_qm)
    xyz_mm = np.array(xyz_mm)

    # Initialise a null ASE atoms object. It is created lazily by whichever
    # stage needs it first and then re-used.
    atoms = None

    # Make sure that the number of QM atoms matches the number of MM atoms
    # when using mm embedding.
    if self._method == "mm":
        if len(xyz_qm) != len(self._mm_charges):
            raise ValueError(
                f"MM embedding is specified but the number of atoms in the QM region ({len(xyz_qm)}) "
                f"doesn't match the number of MM charges ({len(self._mm_charges)})"
            )

    # Update the maximum number of MM atoms if this is the largest seen.
    num_mm_atoms = len(charges_mm)
    if num_mm_atoms > self._max_mm_atoms:
        self._max_mm_atoms = num_mm_atoms

    # Pad the MM coordinates and charges arrays to avoid re-jitting.
    if self._max_mm_atoms > num_mm_atoms:
        num_pad = self._max_mm_atoms - num_mm_atoms
        xyz_mm_pad = num_pad * [[0.0, 0.0, 0.0]]
        charges_mm_pad = num_pad * [0.0]
        xyz_mm = np.append(xyz_mm, xyz_mm_pad, axis=0)
        charges_mm = np.append(charges_mm, charges_mm_pad)

    # Convert the QM atomic numbers to elements and species IDs.
    species_id = []
    elements = []
    for atomic_number in atomic_numbers:
        try:
            species_id.append(self._hypers["global_species"].index(atomic_number))
            elements.append(ase.atom.Atom(atomic_number).symbol)
        except Exception as e:
            raise ValueError(
                f"Unsupported element index '{atomic_number}'. "
                f"The current model supports {', '.join(self._supported_elements)}"
            ) from e
    self._species_id = np.array(species_id)

    # First try to use the specified backend to compute in vacuo
    # energies and (optionally) gradients.

    # Internal backends.
    if not self._is_external_backend:
        # TorchANI.
        if self._backend == "torchani":
            try:
                E_vac, grad_vac = self._run_torchani(xyz_qm, atomic_numbers)
            except Exception as e:
                raise RuntimeError(
                    "Failed to calculate in vacuo energies using TorchANI backend!"
                ) from e

        # DeePMD.
        elif self._backend == "deepmd":
            try:
                E_vac, grad_vac = self._run_deepmd(xyz_qm, elements)
            except Exception as e:
                raise RuntimeError(
                    "Failed to calculate in vacuo energies using DeePMD backend!"
                ) from e

        # ORCA.
        elif self._backend == "orca":
            raise ValueError(
                "Sire interface is currently unsupported when using the ORCA backend!"
            )

        # Sander.
        elif self._backend == "sander":
            try:
                atoms = ase.atoms.Atoms(positions=xyz_qm, numbers=atomic_numbers)
                E_vac, grad_vac = self._run_pysander(
                    atoms, self._parm7, is_gas=True
                )
            except Exception as e:
                raise RuntimeError(
                    "Failed to calculate in vacuo energies using Sander backend!"
                ) from e

        # SQM.
        elif self._backend == "sqm":
            try:
                # NOTE(review): 'charge' is not defined anywhere in this
                # method, so unless it exists as a module-level name this
                # call fails with a NameError (reported as the RuntimeError
                # below, with the NameError chained as its cause). Confirm
                # where the QM charge should come from for the Sire callback.
                E_vac, grad_vac = self._run_sqm(xyz_qm, atomic_numbers, charge)
            except Exception as e:
                raise RuntimeError(
                    "Failed to calculate in vacuo energies using SQM backend!"
                ) from e

        # XTB.
        elif self._backend == "xtb":
            try:
                atoms = ase.atoms.Atoms(positions=xyz_qm, numbers=atomic_numbers)
                E_vac, grad_vac = self._run_xtb(atoms)
            except Exception as e:
                raise RuntimeError(
                    "Failed to calculate in vacuo energies using XTB backend!"
                ) from e

    # External backend. Any exception raised by the user supplied callback
    # is deliberately propagated unchanged so that its type and traceback
    # remain visible to the caller.
    else:
        atoms = ase.atoms.Atoms(positions=xyz_qm, numbers=atomic_numbers)
        E_vac, grad_vac = self._external_backend(atoms)

    # Apply delta-learning corrections using Rascal.
    if self._is_delta:
        try:
            if atoms is None:
                atoms = ase.atoms.Atoms(positions=xyz_qm, numbers=atomic_numbers)
            delta_E, delta_grad = self._run_rascal(atoms)
        except Exception as e:
            raise RuntimeError(
                "Failed to compute delta-learning corrections using Rascal!"
            ) from e

        # Add the delta-learning corrections to the in vacuo energies and gradients.
        E_vac += delta_E
        grad_vac += delta_grad

    # Convert units.
    xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR
    xyz_mm_bohr = xyz_mm * ANGSTROM_TO_BOHR

    # The SOAP features are evaluated from the Angstrom coordinates, so their
    # derivatives are rescaled to be with respect to Bohr.
    mol_soap, dsoap_dxyz = self._get_soap(atomic_numbers, xyz_qm, gradient=True)
    dsoap_dxyz_qm_bohr = dsoap_dxyz / ANGSTROM_TO_BOHR

    s, ds_dsoap = self._get_s(mol_soap, self._species_id, gradient=True)
    chi, dchi_dsoap = self._get_chi(mol_soap, self._species_id, gradient=True)
    ds_dxyz_qm_bohr = self._get_df_dxyz(ds_dsoap, dsoap_dxyz_qm_bohr)
    dchi_dxyz_qm_bohr = self._get_df_dxyz(dchi_dsoap, dsoap_dxyz_qm_bohr)

    # Convert inputs to PyTorch tensors.
    xyz_qm_bohr = torch.tensor(
        xyz_qm_bohr, dtype=torch.float32, device=self._device
    )
    xyz_mm_bohr = torch.tensor(
        xyz_mm_bohr, dtype=torch.float32, device=self._device
    )
    charges_mm = torch.tensor(charges_mm, dtype=torch.float32, device=self._device)
    s = torch.tensor(s, dtype=torch.float32, device=self._device)
    chi = torch.tensor(chi, dtype=torch.float32, device=self._device)

    # Compute gradients and energy. The QM gradient is the explicit
    # coordinate term plus the chain-rule terms through s and chi.
    grads, E = self._get_E_with_grad(charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi)
    dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
    dE_dxyz_qm_bohr = (
        dE_dxyz_qm_bohr_part.cpu().numpy()
        + dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
        + dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
    )

    # Compute the total energy and gradients.
    E_tot = E + E_vac
    grad_qm = dE_dxyz_qm_bohr + grad_vac
    grad_mm = dE_dxyz_mm_bohr.cpu().numpy()

    # Interpolate between the MM and ML/MM potential.
    if self._is_interpolate:
        # Create the ASE atoms object if it wasn't already created by the backend.
        if atoms is None:
            atoms = ase.atoms.Atoms(positions=xyz_qm, numbers=atomic_numbers)

        # Compute the in vacuo MM energy and gradients for the QM region.
        E_mm_qm_vac, grad_mm_qm_vac = self._run_pysander(
            atoms=atoms,
            parm7=self._parm7,
            is_gas=True,
        )

        # Swap the method to MM and recompute the gradients and energy. The
        # original method is always restored, even if the calculation fails,
        # so that the object isn't left stuck in MM mode.
        method = self._method
        self._method = "mm"
        try:
            grads, E = self._get_E_with_grad(
                charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi
            )
        finally:
            self._method = method

        dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
        dE_dxyz_qm_bohr = (
            dE_dxyz_qm_bohr_part.cpu().numpy()
            + dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
            + dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
        )
        dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()

        # Store the MM and EMLE energies. The MM energy is an approximation.
        E_mm = E_mm_qm_vac + E
        E_emle = E_tot

        # Work out the current value of lambda.
        if len(self._lambda_interpolate) == 1:
            lam = self._lambda_interpolate[0]
        else:
            offset = int(not self._restart)
            lam = self._lambda_interpolate[0] + (
                (self._step / (self._interpolate_steps - offset))
            ) * (self._lambda_interpolate[1] - self._lambda_interpolate[0])
            # Clamp to the valid range in case the step count overshoots.
            if lam < 0.0:
                lam = 0.0
            elif lam > 1.0:
                lam = 1.0

        # Calculate the lambda weighted energy and gradients.
        E_tot = lam * E_tot + (1 - lam) * E_mm
        grad_qm = lam * grad_qm + (1 - lam) * (grad_mm_qm_vac + dE_dxyz_qm_bohr)
        grad_mm = lam * grad_mm + (1 - lam) * dE_dxyz_mm_bohr

    # Log energies to file.
    if self._log > 0 and not self._is_first_step and self._step % self._log == 0:
        with open("emle_log.txt", "a+") as f:
            # Write the header.
            if self._step == 0:
                if self._is_interpolate:
                    f.write(
                        f"#{'Step':>9}{'λ':>22}{'E(λ) (Eh)':>22}{'E(λ=0) (Eh)':>22}{'E(λ=1) (Eh)':>22}\n"
                    )
                else:
                    f.write(f"#{'Step':>9}{'E_vac (Eh)':>22}{'E_tot (Eh)':>22}\n")
            # Write the record.
            if self._is_interpolate:
                f.write(
                    f"{self._step:>10}{lam:22.12f}{E_tot:22.12f}{E_mm:22.12f}{E_emle:22.12f}\n"
                )
            else:
                f.write(f"{self._step:>10}{E_vac:22.12f}{E_tot:22.12f}\n")

    # Increment the step counter.
    if self._is_first_step:
        self._is_first_step = False
    else:
        self._step += 1

    # Return the energy and forces in OpenMM units. The MM gradient is
    # truncated to the number of MM atoms actually passed in, so that the
    # rows added as padding above are not returned to the caller as forces.
    return (
        E_tot * HARTREE_TO_KJ_MOL,
        (-grad_qm * HARTREE_TO_KJ_MOL * NANOMETER_TO_BOHR).tolist(),
        (-grad_mm[:num_mm_atoms] * HARTREE_TO_KJ_MOL * NANOMETER_TO_BOHR).tolist(),
    )

def _get_E(self, charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi):
"""
Computes total EMLE embedding energy (sum of static and induced).
Expand Down

0 comments on commit 152ae76

Please sign in to comment.