From b7327cad864a943ee055b2bfa43691970b3b6f88 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 25 Nov 2024 17:05:27 +0000 Subject: [PATCH 1/7] Implement batch forward method for the EMLE class --- emle/models/_emle.py | 58 ++++++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index e46af40..f611d39 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -416,20 +416,20 @@ def forward( Parameters ---------- - atomic_numbers: torch.Tensor (N_QM_ATOMS,) + atomic_numbers: torch.Tensor (N_QM_ATOMS,) or (BATCH, N_QM_ATOMS) Atomic numbers of QM atoms. - charges_mm: torch.Tensor (max_mm_atoms,) + charges_mm: torch.Tensor (max_mm_atoms,) or (BATCH, max_mm_atoms) MM point charges in atomic units. - xyz_qm: torch.Tensor (N_QM_ATOMS, 3) + xyz_qm: torch.Tensor (N_QM_ATOMS, 3) or (BATCH, N_QM_ATOMS, 3) Positions of QM atoms in Angstrom. - xyz_mm: torch.Tensor (N_MM_ATOMS, 3) + xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3) Positions of MM atoms in Angstrom. - qm_charge: int - The charge on the QM region. + qm_charge: int + The charge on the QM region. Returns ------- @@ -450,51 +450,61 @@ def forward( self._xyz_mm = xyz_mm self._qm_charge = qm_charge + if self._atomic_numbers.ndim == 1: + self._atomic_numbers = atomic_numbers.unsqueeze(0) + self._charges_mm = charges_mm.unsqueeze(0) + self._xyz_qm = xyz_qm.unsqueeze(0) + self._xyz_mm = xyz_mm.unsqueeze(0) + + # Batch size + batch_size = self._atomic_numbers.shape[0] + self._qm_charge = _torch.tensor([qm_charge] * batch_size, dtype=charges_mm.dtype, device=self._device) + # If there are no point charges, return zeros. 
- if len(xyz_mm) == 0: - return _torch.zeros(2, dtype=xyz_qm.dtype, device=xyz_qm.device) + if xyz_mm.shape[1] == 0: + return _torch.zeros(2, batch_size, dtype=self._xyz_qm.dtype, device=self._xyz_qm.device) # Get the parameters from the base model: # valence widths, core charges, valence charges, A_thole tensor # These are returned as batched tensors, so we need to extract the # first element of each. s, q_core, q_val, A_thole = self._emle_base( - atomic_numbers[None, :], - xyz_qm[None, :, :], - _torch.tensor([qm_charge], dtype=xyz_qm.dtype, device=xyz_qm.device), + self._atomic_numbers, + self._xyz_qm, + self._qm_charge, ) # Convert coordinates to Bohr. ANGSTROM_TO_BOHR = 1.8897261258369282 - xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR - xyz_mm_bohr = xyz_mm * ANGSTROM_TO_BOHR + xyz_qm_bohr = self._xyz_qm * ANGSTROM_TO_BOHR + xyz_mm_bohr = self._xyz_mm * ANGSTROM_TO_BOHR # Compute the static energy. if self._method == "mm": - q_core = self._q_core_mm[None, :] + q_core = self._q_core_mm q_val = _torch.zeros_like( - q_core, dtype=charges_mm.dtype, device=self._device + q_core, dtype=self._charges_mm.dtype, device=self._device ) mesh_data = self._emle_base._get_mesh_data( - xyz_qm_bohr[None, :, :], xyz_mm_bohr[None, :, :], s + xyz_qm_bohr, xyz_mm_bohr, s ) if self._method == "mechanical": q_core = q_core + q_val q_val = _torch.zeros_like( - q_core, dtype=charges_mm.dtype, device=self._device + q_core, dtype=self._charges_mm.dtype, device=self._device ) E_static = self._emle_base.get_static_energy( - q_core, q_val, charges_mm[None, :], mesh_data - )[0] + q_core, q_val, self._charges_mm, mesh_data + ) # Compute the induced energy. 
if self._method == "electrostatic": E_ind = self._emle_base.get_induced_energy( - A_thole, charges_mm[None, :], s, mesh_data - )[0] + A_thole, self._charges_mm, s, mesh_data + ) else: - E_ind = _torch.tensor(0.0, dtype=charges_mm.dtype, device=self._device) - - return _torch.stack([E_static, E_ind]) + E_ind = _torch.zeros_like(E_static, dtype=self._charges_mm.dtype, device=self._device) + + return _torch.stack((E_static, E_ind), dim=0) \ No newline at end of file From f808d0ddd56e09bd96560edd59f8ceb03a10c6c8 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 25 Nov 2024 17:43:05 +0000 Subject: [PATCH 2/7] Ensure TorchScript compatibility --- emle/models/_ani.py | 2 +- emle/models/_emle.py | 24 +++++++++++------------- emle/models/_mace.py | 2 +- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/emle/models/_ani.py b/emle/models/_ani.py index 0152ad6..2034a4b 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -389,4 +389,4 @@ def forward( E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge) # Return the ANI2x and EMLE energy components. - return _torch.stack([E_vac, E_emle[0], E_emle[1]]) + return _torch.stack([E_vac, E_emle[0][0], E_emle[1][0]]) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index f611d39..7e1c8bb 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -34,7 +34,7 @@ import torchani as _torchani from torch import Tensor -from typing import Optional, Tuple, List +from typing import Union from . import _patches from . import EMLEBase as _EMLEBase @@ -408,7 +408,7 @@ def forward( charges_mm: Tensor, xyz_qm: Tensor, xyz_mm: Tensor, - qm_charge: int = 0, + qm_charge: Union[int, Tensor] = _torch.tensor(0, dtype=_torch.int64), ) -> Tensor: """ Computes the static and induced EMLE energy components. @@ -428,7 +428,7 @@ def forward( xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3) Positions of MM atoms in Angstrom. 
- qm_charge: int + qm_charge: int or torch.Tensor (BATCH,) The charge on the QM region. Returns @@ -437,18 +437,11 @@ def forward( result: torch.Tensor (2,) The static and induced EMLE energy components in Hartree. """ - - # If the QM charge is a non-default value in the constructor, then - # use this value when qm_charge is zero. - if self._qm_charge != 0 and qm_charge == 0: - qm_charge = self._qm_charge - # Store the inputs as internal attributes. self._atomic_numbers = atomic_numbers self._charges_mm = charges_mm self._xyz_qm = xyz_qm self._xyz_mm = xyz_mm - self._qm_charge = qm_charge if self._atomic_numbers.ndim == 1: self._atomic_numbers = atomic_numbers.unsqueeze(0) @@ -456,9 +449,14 @@ def forward( self._xyz_qm = xyz_qm.unsqueeze(0) self._xyz_mm = xyz_mm.unsqueeze(0) - # Batch size batch_size = self._atomic_numbers.shape[0] - self._qm_charge = _torch.tensor([qm_charge] * batch_size, dtype=charges_mm.dtype, device=self._device) + + # Ensure qm_charge is a tensor and repeat for batch size if necessary + if isinstance(qm_charge, int): + qm_charge = _torch.full((batch_size,), qm_charge if qm_charge != 0 else self._qm_charge, dtype=_torch.int64, device=self._device) + elif isinstance(qm_charge, _torch.Tensor): + if qm_charge.ndim == 0: + qm_charge = qm_charge.repeat(batch_size) # If there are no point charges, return zeros. if xyz_mm.shape[1] == 0: @@ -471,7 +469,7 @@ def forward( s, q_core, q_val, A_thole = self._emle_base( self._atomic_numbers, self._xyz_qm, - self._qm_charge, + qm_charge, ) # Convert coordinates to Bohr. diff --git a/emle/models/_mace.py b/emle/models/_mace.py index ce99fae..f76a449 100644 --- a/emle/models/_mace.py +++ b/emle/models/_mace.py @@ -457,4 +457,4 @@ def forward( E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge) # Return the MACE and EMLE energy components. 
- return _torch.stack([E_vac, E_emle[0], E_emle[1]]) + return _torch.stack([E_vac, E_emle[0][0], E_emle[1][0]]) From 94ec4179b55711952498880a24099fdb80eadad9 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 25 Nov 2024 18:10:41 +0000 Subject: [PATCH 3/7] Fixes to batched implementation --- emle/models/_emle.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 7e1c8bb..1a3b4bc 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -408,7 +408,7 @@ def forward( charges_mm: Tensor, xyz_qm: Tensor, xyz_mm: Tensor, - qm_charge: Union[int, Tensor] = _torch.tensor(0, dtype=_torch.int64), + qm_charge: Union[int, Tensor] = 0, ) -> Tensor: """ Computes the static and induced EMLE energy components. @@ -456,7 +456,7 @@ def forward( qm_charge = _torch.full((batch_size,), qm_charge if qm_charge != 0 else self._qm_charge, dtype=_torch.int64, device=self._device) elif isinstance(qm_charge, _torch.Tensor): if qm_charge.ndim == 0: - qm_charge = qm_charge.repeat(batch_size) + qm_charge = qm_charge.repeat(batch_size).to(self._device) # If there are no point charges, return zeros. if xyz_mm.shape[1] == 0: @@ -479,7 +479,7 @@ def forward( # Compute the static energy. 
if self._method == "mm": - q_core = self._q_core_mm + q_core = self._q_core_mm.expand(batch_size, -1) q_val = _torch.zeros_like( q_core, dtype=self._charges_mm.dtype, device=self._device ) From bccb1b569c3c3e6368dd227613483ebd9ea18be6 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Tue, 26 Nov 2024 15:06:56 +0000 Subject: [PATCH 4/7] Fix get_mesh_data calculation to handle batches with different molecules by applying mask correctly --- emle/models/_emle.py | 3 ++- emle/models/_emle_base.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 1a3b4bc..85fa86f 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -484,8 +484,9 @@ def forward( q_core, dtype=self._charges_mm.dtype, device=self._device ) + mask = self._atomic_numbers > 0 mesh_data = self._emle_base._get_mesh_data( - xyz_qm_bohr, xyz_mm_bohr, s + xyz_qm_bohr, xyz_mm_bohr, s, mask ) if self._method == "mechanical": diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 763648b..e46e50d 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -944,7 +944,7 @@ def _get_vpot_mu(mu: Tensor, T1: Tensor) -> Tensor: @staticmethod def _get_mesh_data( - xyz: Tensor, xyz_mesh: Tensor, s: Tensor + xyz: Tensor, xyz_mesh: Tensor, s: Tensor, mask: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: """ Internal method, calculates mesh_data object. @@ -961,6 +961,9 @@ def _get_mesh_data( s: torch.Tensor (N_BATCH, MAX_QM_ATOMS,) MBIS valence widths. + mask: torch.Tensor (N_BATCH, MAX_QM_ATOMS) + Mask for padded coordinates. + Returns ------- @@ -969,11 +972,16 @@ def _get_mesh_data( """ rr = xyz_mesh[:, None, :, :] - xyz[:, :, None, :] r = _torch.linalg.norm(rr, ord=2, dim=3) + + # Mask for padded coordinates. 
+ mask = mask.unsqueeze(-1) + r_inv = _torch.where(mask, 1.0 / r, 0.0) + T0_slater = _torch.where(mask, EMLEBase._get_T0_slater(r, s[:, :, None]), 0.0) return ( - 1.0 / r, - EMLEBase._get_T0_slater(r, s[:, :, None]), - -rr / r[..., None] ** 3, + r_inv, + T0_slater, + -rr * r_inv[..., None] ** 3, ) @staticmethod From 763daec556aede955513ca60ed8306803836e9a9 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Tue, 26 Nov 2024 15:51:59 +0000 Subject: [PATCH 5/7] Fix masking of overpolarization correction --- emle/models/_emle.py | 4 ++-- emle/models/_emle_base.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 85fa86f..a161112 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -484,7 +484,7 @@ def forward( q_core, dtype=self._charges_mm.dtype, device=self._device ) - mask = self._atomic_numbers > 0 + mask = (self._atomic_numbers > 0).unsqueeze(-1) mesh_data = self._emle_base._get_mesh_data( xyz_qm_bohr, xyz_mm_bohr, s, mask ) @@ -501,7 +501,7 @@ def forward( # Compute the induced energy. if self._method == "electrostatic": E_ind = self._emle_base.get_induced_energy( - A_thole, self._charges_mm, s, mesh_data + A_thole, self._charges_mm, s, mesh_data, mask ) else: E_ind = _torch.zeros_like(E_static, dtype=self._charges_mm.dtype, device=self._device) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index e46e50d..02e1bf0 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -824,6 +824,7 @@ def get_induced_energy( charges_mm: Tensor, s: Tensor, mesh_data: Tuple[Tensor, Tensor, Tensor], + mask: Tensor, ) -> Tensor: """ Calculate the induced electrostatic energy. @@ -843,13 +844,16 @@ def get_induced_energy( mesh_data: mesh_data object (output of self._get_mesh_data) Mesh data object. + mask: torch.Tensor (N_BATCH, MAX_QM_ATOMS) + Mask for padded coordinates. + Returns ------- result: torch.Tensor (N_BATCH,) Induced electrostatic energy. 
""" - mu_ind = EMLEBase._get_mu_ind(A_thole, mesh_data, charges_mm, s) + mu_ind = EMLEBase._get_mu_ind(A_thole, mesh_data, charges_mm, s, mask) vpot_ind = EMLEBase._get_vpot_mu(mu_ind, mesh_data[2]) return _torch.sum(vpot_ind * charges_mm, dim=1) * 0.5 @@ -859,6 +863,7 @@ def _get_mu_ind( mesh_data: Tuple[Tensor, Tensor, Tensor], q: Tensor, s: Tensor, + mask: Tensor, ) -> Tensor: """ Internal method, calculates induced atomic dipoles @@ -881,6 +886,9 @@ def _get_mu_ind( q_val: torch.Tensor (N_BATCH, N_QM_ATOMS,) MBIS valence charges. + mask: torch.Tensor (N_BATCH, N_QM_ATOMS) + Mask for padded coordinates. + Returns ------- @@ -889,7 +897,7 @@ def _get_mu_ind( """ r = 1.0 / mesh_data[0] - f1 = EMLEBase._get_f1_slater(r, s[:, :, None] * 2.0) + f1 = _torch.where(mask, EMLEBase._get_f1_slater(r, s[:, :, None] * 2.0), 0.0) fields = _torch.sum( mesh_data[2] * f1[..., None] * q[:, None, :, None], dim=2 ).reshape(len(s), -1) @@ -974,7 +982,6 @@ def _get_mesh_data( r = _torch.linalg.norm(rr, ord=2, dim=3) # Mask for padded coordinates. 
- mask = mask.unsqueeze(-1) r_inv = _torch.where(mask, 1.0 / r, 0.0) T0_slater = _torch.where(mask, EMLEBase._get_T0_slater(r, s[:, :, None]), 0.0) From 37e0e389eb012d69ae267cdc28fa157c16fcc814 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Wed, 27 Nov 2024 18:03:00 +0000 Subject: [PATCH 6/7] Add batch support to the forward method of ANI2xEMLE --- emle/models/_ani.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/emle/models/_ani.py b/emle/models/_ani.py index 2034a4b..5e5cb75 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -320,7 +320,7 @@ def float(self): try: from NNPOps import OptimizedTorchANI as _OptimizedTorchANI - species = self._atomic_numbers.reshape(1, *atomic_numbers.shape) + species = self._atomic_numbers.reshape(1, *self._atomic_numbers.shape) self._ani2x = _OptimizedTorchANI(self._ani2x, species).to(self._device) except: pass @@ -344,42 +344,44 @@ def forward( Parameters ---------- - atomic_numbers: torch.Tensor (N_QM_ATOMS,) - Atomic numbers of QM atoms. + atomic_numbers: torch.Tensor (N_QM_ATOMS,) or (BATCH, N_QM_ATOMS) + Atomic numbers of QM atoms. A non-existent atom is represented by -1. - charges_mm: torch.Tensor (max_mm_atoms,) + charges_mm: torch.Tensor (max_mm_atoms,) or (BATCH, max_mm_atoms) MM point charges in atomic units. - xyz_qm: torch.Tensor (N_QM_ATOMS, 3) + xyz_qm: torch.Tensor (N_QM_ATOMS, 3) or (BATCH, N_QM_ATOMS, 3) Positions of QM atoms in Angstrom. - xyz_mm: torch.Tensor (N_MM_ATOMS, 3) + xyz_mm: torch.Tensor (N_MM_ATOMS, 3) or (BATCH, N_MM_ATOMS, 3) Positions of MM atoms in Angstrom. - qm_charge: int + qm_charge: int or torch.Tensor (BATCH,) The charge on the QM region. Returns ------- - result: torch.Tensor (3,) + result: torch.Tensor (3,) or (3, BATCH) The ANI2x and static and induced EMLE energy components in Hartree. """ - - # Reshape the atomic numbers. 
- atomic_numbers_ani = atomic_numbers.unsqueeze(0) - - # Reshape the coordinates, - xyz = xyz_qm.unsqueeze(0) + if atomic_numbers.ndim == 1: + # Batch the input tensors. + atomic_numbers = atomic_numbers.unsqueeze(0) + xyz_qm = xyz_qm.unsqueeze(0) + xyz_mm = xyz_mm.unsqueeze(0) + charges_mm = charges_mm.unsqueeze(0) # Get the in vacuo energy. - E_vac = self._ani2x((atomic_numbers_ani, xyz)).energies[0] + E_vac = self._ani2x((atomic_numbers, xyz_qm)).energies # If there are no point charges, return the in vacuo energy and zeros # for the static and induced terms. - if len(xyz_mm) == 0: - zero = _torch.tensor(0.0, dtype=xyz_qm.dtype, device=xyz_qm.device) - return _torch.stack([E_vac, zero, zero]) + if xyz_mm.shape[1] == 0: + zero = _torch.zeros( + atomic_numbers.shape[0], dtype=xyz_qm.dtype, device=xyz_qm.device + ) + return _torch.stack((E_vac, zero, zero)) # Set the AEVs captured by the forward hook as an attribute of the # EMLE AEVComputer instance. @@ -389,4 +391,4 @@ def forward( E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge) # Return the ANI2x and EMLE energy components. - return _torch.stack([E_vac, E_emle[0][0], E_emle[1][0]]) + return _torch.stack((E_vac, E_emle[0], E_emle[1])) From 8207fd4ceee079cad025a9309cc7c8cc857cef13 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Wed, 27 Nov 2024 18:05:27 +0000 Subject: [PATCH 7/7] Black formatting --- emle/models/_emle.py | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index a161112..d58bf61 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -429,12 +429,12 @@ def forward( Positions of MM atoms in Angstrom. qm_charge: int or torch.Tensor (BATCH,) - The charge on the QM region. + The charge on the QM region. Returns ------- - result: torch.Tensor (2,) + result: torch.Tensor (2,) or (2, BATCH) The static and induced EMLE energy components in Hartree. 
""" # Store the inputs as internal attributes. @@ -444,23 +444,30 @@ def forward( self._xyz_mm = xyz_mm if self._atomic_numbers.ndim == 1: - self._atomic_numbers = atomic_numbers.unsqueeze(0) - self._charges_mm = charges_mm.unsqueeze(0) - self._xyz_qm = xyz_qm.unsqueeze(0) - self._xyz_mm = xyz_mm.unsqueeze(0) - + self._atomic_numbers = self._atomic_numbers.unsqueeze(0) + self._charges_mm = self._charges_mm.unsqueeze(0) + self._xyz_qm = self._xyz_qm.unsqueeze(0) + self._xyz_mm = self._xyz_mm.unsqueeze(0) + batch_size = self._atomic_numbers.shape[0] # Ensure qm_charge is a tensor and repeat for batch size if necessary if isinstance(qm_charge, int): - qm_charge = _torch.full((batch_size,), qm_charge if qm_charge != 0 else self._qm_charge, dtype=_torch.int64, device=self._device) + qm_charge = _torch.full( + (batch_size,), + qm_charge if qm_charge != 0 else self._qm_charge, + dtype=_torch.int64, + device=self._device, + ) elif isinstance(qm_charge, _torch.Tensor): if qm_charge.ndim == 0: qm_charge = qm_charge.repeat(batch_size).to(self._device) # If there are no point charges, return zeros. if xyz_mm.shape[1] == 0: - return _torch.zeros(2, batch_size, dtype=self._xyz_qm.dtype, device=self._xyz_qm.device) + return _torch.zeros( + 2, batch_size, dtype=self._xyz_qm.dtype, device=self._xyz_qm.device + ) # Get the parameters from the base model: # valence widths, core charges, valence charges, A_thole tensor @@ -468,7 +475,7 @@ def forward( # first element of each. 
s, q_core, q_val, A_thole = self._emle_base( self._atomic_numbers, - self._xyz_qm, + self._xyz_qm, qm_charge, ) @@ -485,9 +492,7 @@ def forward( ) mask = (self._atomic_numbers > 0).unsqueeze(-1) - mesh_data = self._emle_base._get_mesh_data( - xyz_qm_bohr, xyz_mm_bohr, s, mask - ) + mesh_data = self._emle_base._get_mesh_data(xyz_qm_bohr, xyz_mm_bohr, s, mask) if self._method == "mechanical": q_core = q_core + q_val @@ -504,6 +509,8 @@ def forward( A_thole, self._charges_mm, s, mesh_data, mask ) else: - E_ind = _torch.zeros_like(E_static, dtype=self._charges_mm.dtype, device=self._device) - - return _torch.stack((E_static, E_ind), dim=0) \ No newline at end of file + E_ind = _torch.zeros_like( + E_static, dtype=self._charges_mm.dtype, device=self._device + ) + + return _torch.stack((E_static, E_ind), dim=0)