@@ -95,7 +95,7 @@ def ase_to_torch_batch(
9595 [0.2117993724186579 , 1.0208820183960539 , 7.305899571570074 ],
9696 ],
9797 "numbers" : [* [20 ] * 2 , * [24 ] * 2 , * [15 ] * 4 , * [8 ] * 14 ],
98- "pbc" : torch . Tensor ( [True , True , True ]) ,
98+ "pbc" : [True , True , True ],
9999}
100100
101101
@@ -258,7 +258,7 @@ def test_neighbor_list_implementations(
258258 # Convert to torch tensors
259259 pos = torch .tensor (atoms .positions , device = DEVICE , dtype = DTYPE )
260260 row_vector_cell = torch .tensor (atoms .cell .array , device = DEVICE , dtype = DTYPE )
261- pbc : torch . Tensor = torch .tensor (atoms .pbc , device = DEVICE , dtype = DTYPE )
261+ pbc = torch .tensor (atoms .pbc , device = DEVICE , dtype = DTYPE )
262262
263263 # Get the neighbor list from the implementation being tested
264264 mapping , shifts = nl_implementation (
@@ -366,16 +366,12 @@ def test_primitive_neighbor_list_edge_cases() -> None:
366366 cutoff = torch .tensor (1.5 , device = DEVICE , dtype = DTYPE )
367367
368368 # Test all PBC combinations
369- for pbc in [
370- torch .Tensor ([True , False , False ]),
371- torch .Tensor ([False , True , False ]),
372- torch .Tensor ([False , False , True ]),
373- ]:
369+ for pbc in [[True , False , False ], [False , True , False ], [False , False , True ]]:
374370 idx_i , idx_j , _shifts = neighbors .primitive_neighbor_list (
375371 quantities = "ijS" ,
376372 positions = pos ,
377373 cell = cell ,
378- pbc = pbc ,
374+ pbc = torch . tensor ( pbc , device = DEVICE , dtype = DTYPE ) ,
379375 cutoff = cutoff ,
380376 device = DEVICE ,
381377 dtype = DTYPE ,
@@ -404,14 +400,11 @@ def test_standard_nl_edge_cases() -> None:
404400 cutoff = torch .tensor (1.5 , device = DEVICE , dtype = DTYPE )
405401
406402 # Test different PBC combinations
407- for pbc in (
408- torch .Tensor ([True , True , True ]),
409- torch .Tensor ([False , False , False ]),
410- ):
403+ for pbc in (True , False ):
411404 mapping , _shifts = neighbors .standard_nl (
412405 positions = pos ,
413406 cell = cell ,
414- pbc = pbc ,
407+ pbc = torch . tensor ([ pbc ] * 3 , device = DEVICE , dtype = DTYPE ) ,
415408 cutoff = cutoff ,
416409 )
417410 assert len (mapping [0 ]) > 0 # Should find neighbors
0 commit comments