Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/schnetpack/transform/neighborlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):

class PymatgenNeighborList(NeighborListTransform):
"""
Calculate neighbor list using pymatgen.
Calculate neighbor list using pymatgen. Automatically casts Z and positions to np.float64.
"""

def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
Expand All @@ -241,6 +241,9 @@ def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
device = positions.device
dtype = positions.dtype

cell_np = cell_np.astype(np.float64, copy=False)
pos_np = pos_np.astype(np.float64, copy=False)

idx_i, idx_j, offsets, distances = find_points_in_spheres(
pos_np,
pos_np,
Expand Down
19 changes: 18 additions & 1 deletion tests/data/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,28 @@ def neighbor_list(request):
return neighbor_lists[request.param]


@pytest.fixture(params=[0, 1])
def precision(request):
precisions = [
torch.float64,
torch.float32,
]
return precisions[request.param]


class TestNeighborLists:
"""
Test for different neighbor lists defined in neighbor_list using the Argon environment fixtures (periodic and
non-periodic).

"""

def test_neighbor_list(self, neighbor_list, environment):
def test_neighbor_list(self, neighbor_list, environment, precision):
cutoff, props, neighbors_ref = environment

if precision == torch.float32:
_ = CastTo32()(props)

neighbor_list = neighbor_list(cutoff)
neighbors = neighbor_list(props)
R = props[structure.R]
Expand All @@ -44,6 +57,10 @@ def test_neighbor_list(self, neighbor_list, environment):
neighbors_ref = self._sort_neighbors(neighbors_ref)

for nbl, nbl_ref in zip(neighbors, neighbors_ref):

if nbl_ref.dtype == torch.float64:
nbl_ref = nbl_ref.to(dtype=precision)

torch.testing.assert_close(nbl, nbl_ref)

def _sort_neighbors(self, neighbors):
Expand Down