Skip to content

Commit

Permalink
Move _get_neighbour_pairs to base class and make static.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Aug 9, 2024
1 parent 044305c commit 0ffc0c4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 55 deletions.
55 changes: 55 additions & 0 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,3 +911,58 @@ def _lambda5(au3):
result: torch.Tensor (N_ATOMS * 3, N_ATOMS * 3)
"""
return 1 - (1 + au3) * _torch.exp(-au3)

@staticmethod
def _get_neighbor_pairs(
positions: _torch.Tensor,
cell: Optional[_torch.Tensor],
cutoff: float,
dtype: _torch.dtype,
device: _torch.device,
) -> Tuple[_torch.Tensor, _torch.Tensor]:
"""
Get the shifts and edge indices.
Notes
-----
This method calculates the shifts and edge indices by determining neighbor pairs (``neighbors``)
and respective wrapped distances (``wrappedDeltas``) using ``NNPOps.neighbors.getNeighborPairs``.
After obtaining the ``neighbors`` and ``wrappedDeltas``, the pairs with negative indices (r>cutoff)
are filtered out, and the edge indices and shifts are finally calculated.
Parameters
----------
positions : _torch.Tensor
The positions of the atoms.
cell : _torch.Tensor
The cell vectors.
cutoff : float
The cutoff distance.
dtype : _torch.dtype
The data type.
device : _torch.device
The device.
Returns
-------
edgeIndex : _torch.Tensor
The edge indices.
shifts : _torch.Tensor
The shifts.
"""
# Get the neighbor pairs, shifts and edge indices.
neighbors, wrapped_deltas, _, _ = _getNeighborPairs(positions, cutoff, -1, cell)
mask = neighbors >= 0
neighbors = neighbors[mask].view(2, -1)
wrapped_deltas = wrapped_deltas[mask[0], :]

edge_index = _torch.hstack((neighbors, neighbors.flip(0))).to(_torch.int64)
if cell is not None:
deltas = positions[edge_index[0]] - positions[edge_index[1]]
wrapped_deltas = _torch.vstack((wrapped_deltas, -wrapped_deltas))
shifts_idx = _torch.mm(deltas - wrapped_deltas, _torch.linalg.inv(cell))
shifts = _torch.mm(shifts_idx, cell)
else:
shifts = _torch.zeros((edge_index.shape[1], 3), dtype=dtype, device=device)

return edge_index, shifts
55 changes: 0 additions & 55 deletions emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,61 +254,6 @@ def _get_node_attrs(self, atomic_numbers: _torch.Tensor) -> _torch.Tensor:
ids = self._atomic_numbers_to_indices(atomic_numbers, z_table=self._z_table)
return self._to_one_hot(ids, num_classes=len(self._z_table))

def _get_neighbor_pairs(
self,
positions: _torch.Tensor,
cell: Optional[_torch.Tensor],
cutoff: float,
dtype: _torch.dtype,
device: _torch.device,
) -> Tuple[_torch.Tensor, _torch.Tensor]:
"""
Get the shifts and edge indices.
Notes
-----
This method calculates the shifts and edge indices by determining neighbor pairs (``neighbors``)
and respective wrapped distances (``wrappedDeltas``) using ``NNPOps.neighbors.getNeighborPairs``.
After obtaining the ``neighbors`` and ``wrappedDeltas``, the pairs with negative indices (r>cutoff)
are filtered out, and the edge indices and shifts are finally calculated.
Parameters
----------
positions : _torch.Tensor
The positions of the atoms.
cell : _torch.Tensor
The cell vectors.
cutoff : float
The cutoff distance.
dtype : _torch.dtype
The data type.
device : _torch.device
The device.
Returns
-------
edgeIndex : _torch.Tensor
The edge indices.
shifts : _torch.Tensor
The shifts.
"""
# Get the neighbor pairs, shifts and edge indices.
neighbors, wrapped_deltas, _, _ = _getNeighborPairs(positions, cutoff, -1, cell)
mask = neighbors >= 0
neighbors = neighbors[mask].view(2, -1)
wrapped_deltas = wrapped_deltas[mask[0], :]

edge_index = _torch.hstack((neighbors, neighbors.flip(0))).to(_torch.int64)
if cell is not None:
deltas = positions[edge_index[0]] - positions[edge_index[1]]
wrapped_deltas = _torch.vstack((wrapped_deltas, -wrapped_deltas))
shifts_idx = _torch.mm(deltas - wrapped_deltas, _torch.linalg.inv(cell))
shifts = _torch.mm(shifts_idx, cell)
else:
shifts = _torch.zeros((edge_index.shape[1], 3), dtype=dtype, device=device)

return edge_index, shifts

def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion on the model.
Expand Down

0 comments on commit 0ffc0c4

Please sign in to comment.