Skip to content

Commit

Permalink
Fix analyzer and add license header and __all__.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Nov 1, 2024
1 parent d0da601 commit d036fca
Showing 1 changed file with 48 additions and 10 deletions.
58 changes: 48 additions & 10 deletions emle/_analyzer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,50 @@
from abc import ABC, abstractmethod
######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2024
#
# Authors: Lester Hedges <[email protected]>
# Kirill Zinovjev <[email protected]>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine If not, see <http://www.gnu.org/licenses/>.
######################################################################

# Note that this file is empty since EMLECalculator and Socket should
# be directly imported from their respective sub-modules. This is to
# avoid severe module import overheads when running the client code,
# which requires no EMLE functionality.

"""
Analyser for EMLE simulation output.
"""

__all__ = ["ANI2xBackend", "DeepMDBackend", "EMLEAnalyzer"]


from abc import ABC as _ABC
from abc import abstractmethod as _abstractmethod
import os as _os

import numpy as _np
import torch as _torch
import ase as _ase

from ._utils import pad_to_max
from .models._emle_pc import EMLEPC
from ._utils import pad_to_max as _pad_to_max
from .models._emle_pc import EMLEPC as _EMLEPC


class BaseBackend(ABC):
class BaseBackend(_ABC):

def __init__(self, torch_device=None):
self._device = torch_device
Expand Down Expand Up @@ -42,7 +77,7 @@ def __call__(self, atomic_numbers, xyz, gradient=False):
return e, f
return result.detach().cpu().numpy()

@abstractmethod
@_abstractmethod
def eval(self, atomic_numbers, xyz, gradient=False):
"""
atomic_numbers: (N_BATCH, N_QM_ATOMS,)
Expand Down Expand Up @@ -117,6 +152,9 @@ def __init__(
dtype = emle_base._dtype
device = emle_base._device

# Create the point charge utility class.
emle_pc = _EMLEPC()

atomic_numbers, qm_xyz = self._parse_qm_xyz(qm_xyz_filename)
pc_charges, pc_xyz = self._parse_pc_xyz(pc_xyz_filename)

Expand All @@ -137,11 +175,11 @@ def __init__(
)
self.alpha = self._get_mol_alpha(self.A_thole, self.atomic_numbers)

mesh_data = EMLEPC._get_mesh_data(self.qm_xyz, self.pc_xyz, self.s)
self.e_static = EMLEPC.get_E_static(
mesh_data = emle_pc._get_mesh_data(self.qm_xyz, self.pc_xyz, self.s)
self.e_static = emle_pc.get_E_static(
self.q_core, self.q_val, self.pc_charges, mesh_data
)
self.e_induced = EMLEPC.get_E_induced(
self.e_induced = emle_pc.get_E_induced(
self.A_thole, self.pc_charges, self.s, mesh_data
)

Expand All @@ -151,7 +189,7 @@ def __init__(
@staticmethod
def _parse_qm_xyz(qm_xyz_filename):
atoms = _ase.io.read(qm_xyz_filename, index=":")
atomic_numbers = pad_to_max([_.get_atomic_numbers() for _ in atoms], -1)
atomic_numbers = _pad_to_max([_.get_atomic_numbers() for _ in atoms], -1)
xyz = _np.array([_.get_positions() for _ in atoms])
return atomic_numbers, xyz

Expand All @@ -166,7 +204,7 @@ def _parse_pc_xyz(pc_xyz_filename):
file.readline()
except ValueError:
break
padded_frames = pad_to_max(frames)
padded_frames = _pad_to_max(frames)
return padded_frames[:, :, 0], padded_frames[:, :, 1:]

@staticmethod
Expand Down

0 comments on commit d036fca

Please sign in to comment.