Skip to content

Commit

Permalink
Handle non-batched inputs in backend calculate methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Dec 11, 2024
1 parent 9483db1 commit 0e473c0
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 24 deletions.
5 changes: 5 additions & 0 deletions emle/_backends/_deepmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def calculate(self, atomic_numbers, xyz, forces=True):
f"match length of 'xyz' ({len(xyz)})"
)

# Convert to batched NumPy arrays.
if len(atomic_numbers.shape) == 1:
atomic_numbers = _np.expand_dims(atomic_numbers, axis=0)
xyz = _np.expand_dims(xyz, axis=0)

e_list = []
f_list = []

Expand Down
5 changes: 5 additions & 0 deletions emle/_backends/_orca.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ def calculate(self, atomic_numbers, xyz, forces=True):
f"match length of 'xyz' ({len(xyz)})"
)

# Convert to batched NumPy arrays.
if len(atomic_numbers.shape) == 1:
atomic_numbers = _np.expand_dims(atomic_numbers, axis=0)
xyz = _np.expand_dims(xyz, axis=0)

# Lists to store results.
results_energy = []
results_forces = []
Expand Down
5 changes: 5 additions & 0 deletions emle/_backends/_rascal.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ def calculate(atomic_numbers, xyz, forces=True):
f"match length of 'xyz' ({len(xyz)})"
)

# Convert to batched NumPy arrays.
if len(atomic_numbers.shape) == 1:
atomic_numbers = _np.expand_dims(atomic_numbers, axis=0)
xyz = _np.expand_dims(xyz, axis=0)

# Lists to store results.
results_energy = []
results_forces = []
Expand Down
5 changes: 5 additions & 0 deletions emle/_backends/_sander.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ def calculate(self, atomic_numbers, xyz, forces=True):
f"match length of 'xyz' ({len(xyz)})"
)

# Convert to batched NumPy arrays.
if len(atomic_numbers.shape) == 1:
atomic_numbers = _np.expand_dims(atomic_numbers, axis=0)
xyz = _np.expand_dims(xyz, axis=0)

if not isinstance(forces, bool):
raise TypeError("'forces' must be of type 'bool'")

Expand Down
5 changes: 5 additions & 0 deletions emle/_backends/_sqm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def calculate(self, atomic_numbers, xyz, qm_charge=None, forces=True):
f"match length of 'xyz' ({len(xyz)})"
)

# Convert to batched NumPy arrays.
if len(atomic_numbers.shape) == 1:
atomic_numbers = _np.expand_dims(atomic_numbers, axis=0)
xyz = _np.expand_dims(xyz, axis=0)

# Lists to store results.
results_energy = []
results_forces = []
Expand Down
5 changes: 5 additions & 0 deletions emle/_backends/_xtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def calculate(atomic_numbers, xyz, forces=True):
f"match length of 'xyz' ({len(xyz)})"
)

# Convert to batched NumPy arrays.
if len(atomic_numbers.shape) == 1:
atomic_numbers = _np.expand_dims(atomic_numbers, axis=0)
xyz = _np.expand_dims(xyz, axis=0)

from xtb.ase.calculator import XTB as _XTB

# Lists to store results.
Expand Down
30 changes: 6 additions & 24 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,10 +997,7 @@ def run(self, path=None):
# Non-Torch backends.
elif backend is not None:
try:
energy, forces = self._backend(
_np.expand_dims(atomic_numbers, axis=0),
_np.expand_dims(xyz_qm, axis=0),
)
energy, forces = self._backend(atomic_numbers, xyz_qm)
E_vac = energy[0]
grad_vac = -forces[0]
except Exception as e:
Expand All @@ -1026,10 +1023,7 @@ def run(self, path=None):
# Apply delta-learning corrections using Rascal.
if self._is_delta and self._backend is not None:
try:
energy, forces = self._rascal_calc(
_np.expand_dims(atomic_numbers, axis=0),
_np.expand_dims(xyz_qm, axis=0),
)
energy, forces = self._rascal_calc(atomic_numbers, xyz_qm)
delta_E = energy[0]
delta_grad = -forces[0]
except Exception as e:
Expand Down Expand Up @@ -1120,10 +1114,7 @@ def run(self, path=None):
backend = Sander(self._parm7)

# Compute the in vacuo MM energy and forces for the QM region.
energy, forces = backend.calculate(
_np.expand_dims(atomic_numbers_np, axis=0),
_np.expand_dims(xyz_qm_np, axis=0),
)
energy, forces = backend.calculate(atomic_numbers_np, xyz_qm_np)
E_mm_qm_vac = energy[0]
grad_mm_qm_vac = -forces[0]

Expand Down Expand Up @@ -1376,10 +1367,7 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
# Non-Torch backends.
elif self._backend is not None:
try:
energy, forces = self._backend(
_np.expand_dims(atomic_numbers, axis=0),
_np.expand_dims(xyz_qm, axis=0),
)
energy, forces = self._backend(atomic_numbers, xyz_qm)
E_vac = energy[0]
grad_vac = -forces[0]
except Exception as e:
Expand All @@ -1406,10 +1394,7 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
# Apply delta-learning corrections using Rascal.
if self._is_delta and self._backend is not None:
try:
energy, forces = self._rascal_calc(
_np.expand_dims(atomic_numbers, axis=0),
_np.expand_dims(xyz_qm, axis=0),
)
energy, forces = self._rascal_calc(atomic_numbers, xyz_qm)
delta_E = energy[0]
delta_grad = -forces[0]
except Exception as e:
Expand Down Expand Up @@ -1513,10 +1498,7 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None
backend = Sander(self._parm7)

# Compute the in vacuo MM energy and forces for the QM region.
energy, forces = backend.calculate(
_np.expand_dims(atomic_numbers_np, axis=0),
_np.expand_dims(xyz_qm_np, axis=0),
)
energy, forces = backend.calculate(atomic_numbers_np, xyz_qm_np)

E_mm_qm_vac = energy[0]
grad_mm_qm_vac = -forces[0]
Expand Down

0 comments on commit 0e473c0

Please sign in to comment.