From 0e473c0257e353379769954704be835e045e03e3 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 11 Dec 2024 09:18:15 +0000 Subject: [PATCH] Handle non-batched inputs in backend calculate methods. --- emle/_backends/_deepmd.py | 5 +++++ emle/_backends/_orca.py | 5 +++++ emle/_backends/_rascal.py | 5 +++++ emle/_backends/_sander.py | 5 +++++ emle/_backends/_sqm.py | 5 +++++ emle/_backends/_xtb.py | 5 +++++ emle/calculator.py | 30 ++++++------------------------ 7 files changed, 36 insertions(+), 24 deletions(-) diff --git a/emle/_backends/_deepmd.py b/emle/_backends/_deepmd.py index 635e8d5..ec9a156 100644 --- a/emle/_backends/_deepmd.py +++ b/emle/_backends/_deepmd.py @@ -139,6 +139,11 @@ def calculate(self, atomic_numbers, xyz, forces=True): f"match length of 'xyz' ({len(xyz)})" ) + # Convert to batched NumPy arrays. + if len(atomic_numbers.shape) == 1: + atomic_numbers = _np.expand_dims(atomic_numbers, axis=0) + xyz = _np.expand_dims(xyz, axis=0) + e_list = [] f_list = [] diff --git a/emle/_backends/_orca.py b/emle/_backends/_orca.py index e355f19..a2f1380 100644 --- a/emle/_backends/_orca.py +++ b/emle/_backends/_orca.py @@ -139,6 +139,11 @@ def calculate(self, atomic_numbers, xyz, forces=True): f"match length of 'xyz' ({len(xyz)})" ) + # Convert to batched NumPy arrays. + if len(atomic_numbers.shape) == 1: + atomic_numbers = _np.expand_dims(atomic_numbers, axis=0) + xyz = _np.expand_dims(xyz, axis=0) + # Lists to store results. results_energy = [] results_forces = [] diff --git a/emle/_backends/_rascal.py b/emle/_backends/_rascal.py index c988a1b..d326d18 100644 --- a/emle/_backends/_rascal.py +++ b/emle/_backends/_rascal.py @@ -117,6 +117,11 @@ def calculate(atomic_numbers, xyz, forces=True): f"match length of 'xyz' ({len(xyz)})" ) + # Convert to batched NumPy arrays. + if len(atomic_numbers.shape) == 1: + atomic_numbers = _np.expand_dims(atomic_numbers, axis=0) + xyz = _np.expand_dims(xyz, axis=0) + # Lists to store results. results_energy = [] results_forces = [] diff --git a/emle/_backends/_sander.py b/emle/_backends/_sander.py index 5631d9b..433737d 100644 --- a/emle/_backends/_sander.py +++ b/emle/_backends/_sander.py @@ -175,6 +175,11 @@ def calculate(self, atomic_numbers, xyz, forces=True): f"match length of 'xyz' ({len(xyz)})" ) + # Convert to batched NumPy arrays. + if len(atomic_numbers.shape) == 1: + atomic_numbers = _np.expand_dims(atomic_numbers, axis=0) + xyz = _np.expand_dims(xyz, axis=0) + if not isinstance(forces, bool): raise TypeError("'forces' must be of type 'bool'") diff --git a/emle/_backends/_sqm.py b/emle/_backends/_sqm.py index c05f97f..0d3a042 100644 --- a/emle/_backends/_sqm.py +++ b/emle/_backends/_sqm.py @@ -130,6 +130,11 @@ def calculate(self, atomic_numbers, xyz, qm_charge=None, forces=True): f"match length of 'xyz' ({len(xyz)})" ) + # Convert to batched NumPy arrays. + if len(atomic_numbers.shape) == 1: + atomic_numbers = _np.expand_dims(atomic_numbers, axis=0) + xyz = _np.expand_dims(xyz, axis=0) + # Lists to store results. results_energy = [] results_forces = [] diff --git a/emle/_backends/_xtb.py b/emle/_backends/_xtb.py index 0c9b3d6..3c65e58 100644 --- a/emle/_backends/_xtb.py +++ b/emle/_backends/_xtb.py @@ -75,6 +75,11 @@ def calculate(atomic_numbers, xyz, forces=True): f"match length of 'xyz' ({len(xyz)})" ) + # Convert to batched NumPy arrays. + if len(atomic_numbers.shape) == 1: + atomic_numbers = _np.expand_dims(atomic_numbers, axis=0) + xyz = _np.expand_dims(xyz, axis=0) + from xtb.ase.calculator import XTB as _XTB # Lists to store results. diff --git a/emle/calculator.py b/emle/calculator.py index b19a3b9..c21bda7 100644 --- a/emle/calculator.py +++ b/emle/calculator.py @@ -997,10 +997,7 @@ def run(self, path=None): # Non-Torch backends. elif backend is not None: try: - energy, forces = self._backend( - _np.expand_dims(atomic_numbers, axis=0), - _np.expand_dims(xyz_qm, axis=0), - ) + energy, forces = self._backend(atomic_numbers, xyz_qm) E_vac = energy[0] grad_vac = -forces[0] except Exception as e: @@ -1026,10 +1023,7 @@ def run(self, path=None): # Apply delta-learning corrections using Rascal. if self._is_delta and self._backend is not None: try: - energy, forces = self._rascal_calc( - _np.expand_dims(atomic_numbers, axis=0), - _np.expand_dims(xyz_qm, axis=0), - ) + energy, forces = self._rascal_calc(atomic_numbers, xyz_qm) delta_E = energy[0] delta_grad = -forces[0] except Exception as e: @@ -1120,10 +1114,7 @@ def run(self, path=None): backend = Sander(self._parm7) # Compute the in vacuo MM energy and forces for the QM region. - energy, forces = backend.calculate( - _np.expand_dims(atomic_numbers_np, axis=0), - _np.expand_dims(xyz_qm_np, axis=0), - ) + energy, forces = backend.calculate(atomic_numbers_np, xyz_qm_np) E_mm_qm_vac = energy[0] grad_mm_qm_vac = -forces[0] @@ -1376,10 +1367,7 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None # Non-Torch backends. elif self._backend is not None: try: - energy, forces = self._backend( - _np.expand_dims(atomic_numbers, axis=0), - _np.expand_dims(xyz_qm, axis=0), - ) + energy, forces = self._backend(atomic_numbers, xyz_qm) E_vac = energy[0] grad_vac = -forces[0] except Exception as e: @@ -1406,10 +1394,7 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None # Apply delta-learning corrections using Rascal. if self._is_delta and self._backend is not None: try: - energy, forces = self._rascal_calc( - _np.expand_dims(atomic_numbers, axis=0), - _np.expand_dims(xyz_qm, axis=0), - ) + energy, forces = self._rascal_calc(atomic_numbers, xyz_qm) delta_E = energy[0] delta_grad = -forces[0] except Exception as e: @@ -1513,10 +1498,7 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None backend = Sander(self._parm7) # Compute the in vacuo MM energy and forces for the QM region. - energy, forces = backend.calculate( - _np.expand_dims(atomic_numbers_np, axis=0), - _np.expand_dims(xyz_qm_np, axis=0), - ) + energy, forces = backend.calculate(atomic_numbers_np, xyz_qm_np) E_mm_qm_vac = energy[0] grad_mm_qm_vac = -forces[0]