Skip to content

Commit 94065f9

Browse files
committed
more simplification changes
1 parent 7521d50 commit 94065f9

File tree

1 file changed

+6
-13
lines changed

1 file changed

+6
-13
lines changed

tests/test_neighbors.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def ase_to_torch_batch(
9595
[0.2117993724186579, 1.0208820183960539, 7.305899571570074],
9696
],
9797
"numbers": [*[20] * 2, *[24] * 2, *[15] * 4, *[8] * 14],
98-
"pbc": torch.Tensor([True, True, True]),
98+
"pbc": [True, True, True],
9999
}
100100

101101

@@ -258,7 +258,7 @@ def test_neighbor_list_implementations(
258258
# Convert to torch tensors
259259
pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE)
260260
row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE)
261-
pbc: torch.Tensor = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE)
261+
pbc = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE)
262262

263263
# Get the neighbor list from the implementation being tested
264264
mapping, shifts = nl_implementation(
@@ -366,16 +366,12 @@ def test_primitive_neighbor_list_edge_cases() -> None:
366366
cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE)
367367

368368
# Test all PBC combinations
369-
for pbc in [
370-
torch.Tensor([True, False, False]),
371-
torch.Tensor([False, True, False]),
372-
torch.Tensor([False, False, True]),
373-
]:
369+
for pbc in [[True, False, False], [False, True, False], [False, False, True]]:
374370
idx_i, idx_j, _shifts = neighbors.primitive_neighbor_list(
375371
quantities="ijS",
376372
positions=pos,
377373
cell=cell,
378-
pbc=pbc,
374+
pbc=torch.tensor(pbc, device=DEVICE, dtype=DTYPE),
379375
cutoff=cutoff,
380376
device=DEVICE,
381377
dtype=DTYPE,
@@ -404,14 +400,11 @@ def test_standard_nl_edge_cases() -> None:
404400
cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE)
405401

406402
# Test different PBC combinations
407-
for pbc in (
408-
torch.Tensor([True, True, True]),
409-
torch.Tensor([False, False, False]),
410-
):
403+
for pbc in (True, False):
411404
mapping, _shifts = neighbors.standard_nl(
412405
positions=pos,
413406
cell=cell,
414-
pbc=pbc,
407+
pbc=torch.tensor([pbc] * 3, device=DEVICE, dtype=DTYPE),
415408
cutoff=cutoff,
416409
)
417410
assert len(mapping[0]) > 0 # Should find neighbors

0 commit comments

Comments
 (0)