diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 64f681e..4f38390 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -323,6 +323,7 @@ def __init__( q_core, aev_computer=self._aev_computer, aev_mask=aev_mask, + aev_mean=params.get("aev_mean"), alpha_mode=self._alpha_mode, species=params.get("species", self._species), device=device, diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 51d15ac..a0c9ef9 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -65,6 +65,7 @@ def __init__( q_core, aev_computer=None, aev_mask=None, + aev_mean=None, species=None, alpha_mode="species", device=None, @@ -102,6 +103,9 @@ def __init__( aev_mask: torch.Tensor Mask for features coming from aev_computer. + aev_mean: torch.Tensor + Mean values to be subtracted from features + species: List[int], Tuple[int], numpy.ndarray, torch.Tensor List of species (atomic numbers) supported by the EMLE model. @@ -203,6 +207,10 @@ def __init__( else: dtype = _torch.get_default_dtype() + self._aev_mean = None + if aev_mean is not None: + self._aev_mean = _torch.tensor(aev_mean, dtype=dtype, device=device) + # Store model parameters as tensors. self.a_QEq = _torch.nn.Parameter(params["a_QEq"]) self.a_Thole = _torch.nn.Parameter(params["a_Thole"]) @@ -425,6 +433,10 @@ def forward(self, atomic_numbers, xyz_qm, q_total): # The AEVs have been pre-computed by a parent model. else: aev = self._aev[:, :, self._aev_mask] + + if self._aev_mean is not None: + aev = aev - self._aev_mean[None, None, :] + aev = aev / _torch.linalg.norm(aev, ord=2, dim=2, keepdim=True) # Compute the MBIS valence shell widths.