Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of emle-analyze script #33

Merged
merged 20 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions bin/emle-analyze
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python

# Command-line tool that evaluates an EMLE model (and, optionally, a gas
# phase ML backend) on the frames of a trajectory, reads the matching
# reference data from an ORCA tarball and stores both in a single mat file.

import argparse


def main():
    """Parse the command line, run the analysis and write the output file."""
    arg_parser = argparse.ArgumentParser(
        description="Analysis tool for ML(EMLE)/MM simulations"
    )
    arg_parser.add_argument("--orca-tarball", type=str,
                            metavar='name.tar',
                            required=True,
                            help="ORCA tarball")
    arg_parser.add_argument("--emle-model", type=str, metavar='name.mat',
                            required=True, help="EMLE model file")
    arg_parser.add_argument("--backend", type=str,
                            choices=["deepmd", "ani2x"],
                            help="Gas phase ML backend ('deepmd' or 'ani2x')")
    arg_parser.add_argument("--deepmd-model", type=str,
                            metavar='name.pb',
                            help="Deepmd model file (for backend='deepmd')")
    arg_parser.add_argument("--qm-xyz", type=str,
                            metavar='name.xyz', required=True,
                            help="QM xyz file")
    arg_parser.add_argument("--pc-xyz", type=str,
                            metavar='name.pc', required=True,
                            help="Point charges xyz file")
    arg_parser.add_argument("-q", type=int,
                            metavar='charge', default=0,
                            help="Total charge of the ML region")
    arg_parser.add_argument("--alpha", action="store_true",
                            help="Extract molecular dipolar polarizabilities")
    arg_parser.add_argument("output", type=str,
                            help="Output mat file")
    args = arg_parser.parse_args()

    # Validate the option combination before anything else is imported, so
    # that a usage error is reported immediately.
    if args.backend == "deepmd" and not args.deepmd_model:
        arg_parser.error("--deepmd-model is required when backend='deepmd'")

    # These imports are kept after argument parsing, so "--help" and usage
    # errors do not have to load them.
    import scipy.io

    from emle.models._emle import EMLE
    from emle._orca_parser import ORCAParser
    from emle._analyzer import EMLEAnalyzer, ANI2xBackend, DeepMDBackend

    # The backend is optional: without it no in-vacuo ML energy is stored.
    backend = None
    if args.backend == "ani2x":
        backend = ANI2xBackend()
    elif args.backend == "deepmd":
        backend = DeepMDBackend(args.deepmd_model)

    emle_base = EMLE(model=args.emle_model)._emle_base

    analyzer = EMLEAnalyzer(args.qm_xyz, args.pc_xyz, args.q,
                            emle_base, backend)

    # A distinct name is used here so that the argument parser is not
    # shadowed by the ORCA parser.
    orca = ORCAParser(args.orca_tarball, decompose=True, alpha=args.alpha)

    result = {'z': orca.z,
              'xyz': orca.xyz,
              'E_vac_qm': orca.vac_E,
              'E_static_qm': orca.E_static,
              'E_induced_qm': orca.E_induced,
              's_qm': orca.mbis['s'],
              'q_core_qm': orca.mbis['q_core'],
              'q_val_qm': orca.mbis['q_val'],
              'E_static_emle': analyzer.e_static,
              'E_induced_emle': analyzer.e_induced,
              's_emle': analyzer.s,
              'q_core_emle': analyzer.q_core,
              'q_val_emle': analyzer.q_val,
              'alpha_emle': analyzer.alpha}
    # Optional entries are only written when they were actually computed.
    if args.backend:
        result['E_vac_emle'] = analyzer.e_backend
    if args.alpha:
        result['alpha_qm'] = orca.alpha

    scipy.io.savemat(args.output, result)


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions bin/emle-server
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ deepmd_model = os.getenv("EMLE_DEEPMD_MODEL")
deepmd_deviation = os.getenv("EMLE_DEEPMD_DEVIATION")
deepmd_deviation_threshold = os.getenv("EMLE_DEEPMD_DEVIATION_THRESHOLD")
qm_xyz_file = os.getenv("EMLE_QM_XYZ_FILE")
pc_xyz_file = os.getenv("EMLE_PC_XYZ_FILE")
try:
qm_xyz_frequency = int(os.getenv("EMLE_QM_XYZ_FREQUENCY"))
except:
Expand Down Expand Up @@ -143,6 +144,7 @@ env = {
"deepmd_deviation": deepmd_deviation,
"deepmd_deviation_threshold": deepmd_deviation_threshold,
"qm_xyz_file": qm_xyz_file,
"pc_xyz_file": pc_xyz_file,
"qm_xyz_frequency": qm_xyz_frequency,
"ani2x_model_index": ani2x_model_index,
"rascal_model": rascal_model,
Expand Down
183 changes: 183 additions & 0 deletions emle/_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from abc import ABC, abstractmethod
import os as _os

import numpy as _np
import torch as _torch
import ase as _ase

from ._utils import pad_to_max
from .models._emle_pc import EMLEPC


class BaseBackend(ABC):
    """
    Abstract base class of the in-vacuo ML backends.

    Subclasses implement ``eval``. ``__call__`` wraps it and, when a torch
    device was given, converts the inputs to torch tensors on that device
    and the outputs back to NumPy arrays.
    """

    def __init__(self, torch_device=None):
        """
        torch_device: torch.device or None
            Device the inputs are moved to before calling ``eval``. When
            it is not set, inputs and outputs are passed through untouched.
        """
        self._device = torch_device

    def __call__(self, atomic_numbers, xyz, gradient=False):
        """
        atomic_numbers: np.ndarray (N_BATCH, N_QM_ATOMS,)
            The atomic numbers of the atoms.

        xyz: np.ndarray (N_BATCH, N_QM_ATOMS, 3)
            The positions of the atoms.

        gradient: bool
            Whether the gradient should be calculated

        Returns energy (and, optionally, gradient) as np.ndarrays
        """
        # Backends without a torch device work on the raw inputs and are
        # responsible for the type of their own result.
        if not self._device:
            return self.eval(atomic_numbers, xyz, gradient)

        atomic_numbers = _torch.tensor(atomic_numbers, device=self._device)
        # The positions must track gradients when a gradient is requested,
        # otherwise an autograd-based ``eval`` cannot differentiate the
        # energy with respect to them.
        xyz = _torch.tensor(
            xyz, device=self._device, requires_grad=bool(gradient)
        )

        result = self.eval(atomic_numbers, xyz, gradient)

        if gradient:
            e = result[0].detach().cpu().numpy()
            f = result[1].detach().cpu().numpy()
            return e, f
        return result.detach().cpu().numpy()

    @abstractmethod
    def eval(self, atomic_numbers, xyz, gradient=False):
        """
        atomic_numbers: (N_BATCH, N_QM_ATOMS,)
            The atomic numbers of the atoms.

        xyz: (N_BATCH, N_QM_ATOMS, 3)
            The positions of the atoms.

        gradient: bool
            Whether the gradient should be calculated

        Returns the energy, or an (energy, gradient) pair when the
        gradient is requested.
        """
        pass


class ANI2xBackend(BaseBackend):
    """In-vacuo backend that evaluates energies with the torchani ANI2x model."""

    def __init__(self, device=None, ani2x_model_index=None):
        """
        device: torch.device or None
            Device the model is moved to. When None, "cuda" is used if
            torch reports it as available and "cpu" otherwise.

        ani2x_model_index: int or None
            Passed to the ANI2x constructor as ``model_index``.
        """
        import torchani as _torchani

        if device is None:
            device_name = "cuda" if _torch.cuda.is_available() else "cpu"
            device = _torch.device(device_name)

        super().__init__(device)

        model = _torchani.models.ANI2x(
            periodic_table_index=True, model_index=ani2x_model_index
        )
        self._ani2x = model.to(device)

    def eval(self, atomic_numbers, xyz, do_gradient=False):
        """
        Return the ANI2x energies and, when ``do_gradient`` is set, also
        their gradient with respect to ``xyz`` obtained through autograd.
        """
        species_coordinates = (atomic_numbers, xyz.float())
        energy = self._ani2x(species_coordinates).energies
        if do_gradient:
            (gradient,) = _torch.autograd.grad(energy.sum(), xyz)
            return energy, gradient
        return energy


class DeepMDBackend(BaseBackend):
    """In-vacuo backend that evaluates energies with a DeePMD model file."""

    def __init__(self, model=None):
        """
        model: str
            Path to the DeePMD model file.

        Raises ValueError when the file cannot be located and RuntimeError
        when the DeePMD potential cannot be created from it.
        """

        super().__init__()

        # os.path.isfile(None) raises TypeError, so a missing path is
        # checked first and reported with the same ValueError as a bad one.
        if model is None or not _os.path.isfile(model):
            raise ValueError(f"Unable to locate DeePMD model file: '{model}'")

        try:
            from deepmd.infer import DeepPot as _DeepPot

            self._dp = _DeepPot(model)
            # Map element symbols to the type indices used by the model.
            self._z_map = {
                element: index
                for index, element in enumerate(self._dp.get_type_map())
            }
        except Exception as e:
            # Chain the cause so that the original traceback is preserved.
            raise RuntimeError(
                f"Unable to create the DeePMD potentials: {e}"
            ) from e

    def eval(self, atomic_numbers, xyz, do_gradient=False):
        """
        Return the flattened DeePMD energies and, when ``do_gradient`` is
        set, also the gradient of the energy with respect to ``xyz``.
        """
        # Assuming all the frames are of the same system
        atom_types = [
            self._z_map[_ase.Atom(z).symbol] for z in atomic_numbers[0]
        ]
        e, f, _ = self._dp.eval(xyz, cells=None, atom_types=atom_types)
        e = e.flatten()
        if not do_gradient:
            return e
        # DeepPot.eval returns forces; the backend contract (and the ANI2x
        # backend) return the gradient, which is the negative of the force.
        return e, -f


class EMLEAnalyzer:
    """
    Evaluate an EMLE model over the frames of a QM xyz file and a point
    charges file.

    After construction ``s``, ``q_core``, ``q_val``, ``alpha``,
    ``e_static`` and ``e_induced`` are NumPy arrays. ``e_backend`` holds
    the result of the optional backend call, or None without a backend.
    """

    def __init__(self, qm_xyz_filename, pc_xyz_filename, q_total,
                 emle_base, backend=None):
        """
        qm_xyz_filename: str
            File with the QM region frames (read with ase).

        pc_xyz_filename: str
            File with the point charge frames.

        q_total: int
            Total charge passed to the EMLE model for every frame.

        emle_base:
            Callable EMLE base model; its ``_dtype`` and ``_device``
            attributes determine where the tensors are created.

        backend: BaseBackend or None
            Optional in-vacuo backend, called with the atomic numbers and
            positions of all the frames.
        """

        self.q_total = q_total
        dtype = emle_base._dtype
        device = emle_base._device

        atomic_numbers, qm_xyz = self._parse_qm_xyz(qm_xyz_filename)
        pc_charges, pc_xyz = self._parse_pc_xyz(pc_xyz_filename)

        # Always define the attribute so that it can be read safely even
        # when no backend was requested.
        self.e_backend = None
        if backend:
            self.e_backend = backend(atomic_numbers, qm_xyz)

        self.atomic_numbers = _torch.tensor(atomic_numbers,
                                            dtype=_torch.int,
                                            device=device)
        self.qm_xyz = _torch.tensor(qm_xyz, dtype=dtype, device=device)
        self.pc_charges = _torch.tensor(pc_charges, dtype=dtype, device=device)
        self.pc_xyz = _torch.tensor(pc_xyz, dtype=dtype, device=device)

        # One total charge value per frame.
        self.s, self.q_core, self.q_val, self.A_thole = emle_base(
            self.atomic_numbers,
            self.qm_xyz,
            _torch.ones(len(self.qm_xyz), device=device) * self.q_total
        )
        self.alpha = self._get_mol_alpha(self.A_thole, self.atomic_numbers)

        mesh_data = EMLEPC._get_mesh_data(self.qm_xyz, self.pc_xyz, self.s)
        self.e_static = EMLEPC.get_E_static(self.q_core,
                                            self.q_val,
                                            self.pc_charges,
                                            mesh_data)
        self.e_induced = EMLEPC.get_E_induced(self.A_thole,
                                              self.pc_charges,
                                              self.s,
                                              mesh_data)

        # Expose the results as NumPy arrays.
        for attr in ('s', 'q_core', 'q_val', 'alpha', 'e_static', 'e_induced'):
            setattr(self, attr, getattr(self, attr).detach().cpu().numpy())

    @staticmethod
    def _parse_qm_xyz(qm_xyz_filename):
        """
        Read all the frames of the QM xyz file.

        Returns the atomic numbers (padded with -1 through ``pad_to_max``)
        and the positions of every frame.
        """
        # "import ase" alone does not guarantee that the ase.io submodule
        # is loaded, so it is imported explicitly here.
        import ase.io as _ase_io

        frames = _ase_io.read(qm_xyz_filename, index=':')
        atomic_numbers = pad_to_max(
            [frame.get_atomic_numbers() for frame in frames], -1
        )
        # NOTE(review): positions are not padded, so this presumably
        # expects every frame to hold the same number of atoms - confirm.
        xyz = _np.array([frame.get_positions() for frame in frames])
        return atomic_numbers, xyz

    @staticmethod
    def _parse_pc_xyz(pc_xyz_filename):
        """
        Read all the frames of the point charges file.

        Each frame is a line with the number of rows followed by that many
        rows whose first column is the charge and whose remaining columns
        are the position. Returns the padded (charges, positions) pair.
        """
        frames = []
        with open(pc_xyz_filename, 'r') as file:
            while True:
                try:
                    n = int(file.readline().strip())
                    # ndmin=2 keeps a single-row frame two-dimensional, so
                    # that it pads and slices like any other frame.
                    frames.append(_np.loadtxt(file, max_rows=n, ndmin=2))
                    file.readline()
                except ValueError:
                    # int('') at the end of the file ends the loop.
                    break
        padded_frames = pad_to_max(frames)
        return padded_frames[:, :, 0], padded_frames[:, :, 1:]

    @staticmethod
    def _get_mol_alpha(A_thole, atomic_numbers):
        """
        A_thole: torch.Tensor (N_BATCH, 3 * N_ATOMS, 3 * N_ATOMS)

        atomic_numbers: torch.Tensor (N_BATCH, N_ATOMS)
            Entries that are not positive are treated as padding.

        Returns a (N_BATCH, 3, 3) tensor: the inverse of A_thole with the
        padding rows and columns zeroed, summed over both atom indices.
        """
        mask = atomic_numbers > 0
        # Pairwise atom mask, expanded to the 3x3 block of every atom pair.
        mask_mat = mask[:, :, None] * mask[:, None, :]
        mask_mat = mask_mat.repeat_interleave(3, dim=1)
        mask_mat = mask_mat.repeat_interleave(3, dim=2)

        n_mols = A_thole.shape[0]
        n_atoms = A_thole.shape[1] // 3
        Ainv = _torch.linalg.inv(A_thole) * mask_mat
        return _torch.sum(Ainv.reshape(n_mols, n_atoms, 3, n_atoms, 3),
                          dim=(1, 3))
Loading
Loading