Skip to content

Commit

Permalink
Mask the AEV features.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Jul 31, 2024
1 parent 706313b commit 18e1f49
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class _AEVCalculator:
Calculates AEV feature vectors for a given system
"""

def __init__(self, aev_computer):
def __init__(self, aev_computer, aev_mask):
"""
Constructor
Expand All @@ -70,10 +70,11 @@ def __init__(self, aev_computer):
aev_computer: torchani.aev.AEVComputer
Computer for AEV features.
device: torch device
The PyTorch device to use for calculations.
aev_mask: torch.tensor (N_REF, )
The mask for the AEV features.
"""
self._aev_computer = aev_computer
self._aev_mask = aev_mask

def __call__(self, zid, xyz):
"""
Expand Down Expand Up @@ -102,7 +103,7 @@ def __call__(self, zid, xyz):
xyz = xyz.reshape(1, *xyz.shape) * _BOHR_TO_ANGSTROM

# Compute the AEVs.
aev = self._aev_computer((zid, xyz))[1][0]
aev = self._aev_computer((zid, xyz))[1][0][:, self._aev_mask]
return aev / _torch.linalg.norm(aev, axis=1, keepdims=True)


Expand Down Expand Up @@ -1132,7 +1133,9 @@ def __init__(
ani2x = _torchani.models.ANI2x(periodic_table_index=True).to(self._device)
self._aev_computer = ani2x.aev_computer

self._get_features = _AEVCalculator(self._aev_computer)
# Create the AEVCaclulator.
aev_mask = _torch.tensor(self._params["aev_mask"], dtype=_torch.bool, device=self._device)
self._get_features = _AEVCalculator(self._aev_computer, aev_mask)

self._q_core = _torch.tensor(
self._params["q_core"], dtype=_torch.float32, device=self._device
Expand Down

0 comments on commit 18e1f49

Please sign in to comment.