@@ -113,9 +113,7 @@ def inverse_box(box: torch.Tensor) -> torch.Tensor:
113113
114114@deprecated ("Use wrap_positions instead" )
115115def pbc_wrap_general (
116- positions : torch .Tensor ,
117- lattice_vectors : torch .Tensor ,
118- pbc : torch .Tensor | bool = True , # noqa: FBT002
116+ positions : torch .Tensor , lattice_vectors : torch .Tensor
119117) -> torch .Tensor :
120118 """Apply periodic boundary conditions using lattice
121119 vector transformation method.
@@ -131,16 +129,10 @@ def pbc_wrap_general(
131129 containing particle positions in real space.
132130 lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing
133131 lattice vectors as columns (A matrix in the equations).
134- pbc (torch.Tensor | bool): Boolean tensor of shape (3,) or boolean indicating
135- whether periodic boundary conditions are applied in each dimension.
136- If a boolean is provided, all axes are assumed to have the same periodic
137- boundary conditions.
138132
139133 Returns:
140134 torch.Tensor: Wrapped positions in real space with same shape as input positions.
141135 """
142- if isinstance (pbc , bool ):
143- pbc = torch .tensor ([pbc ] * 3 )
144136 # Validate inputs
145137 if not torch .is_floating_point (positions ) or not torch .is_floating_point (
146138 lattice_vectors
@@ -157,10 +149,7 @@ def pbc_wrap_general(
157149 frac_coords = positions @ torch .linalg .inv (lattice_vectors ).T
158150
159151 # Wrap to reference cell [0,1) using modulo
160- wrapped_frac = frac_coords .clone ()
161- wrapped_frac [:, pbc [0 ]] = frac_coords [:, pbc [0 ]] % 1.0
162- wrapped_frac [:, pbc [1 ]] = frac_coords [:, pbc [1 ]] % 1.0
163- wrapped_frac [:, pbc [2 ]] = frac_coords [:, pbc [2 ]] % 1.0
152+ wrapped_frac = frac_coords % 1.0
164153
165154 # Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row
166155 return wrapped_frac @ lattice_vectors .T
@@ -223,9 +212,7 @@ def pbc_wrap_batched(
223212
224213 # Wrap to reference cell [0,1) using modulo
225214 wrapped_frac = frac_coords .clone ()
226- wrapped_frac [:, pbc [0 ]] = frac_coords [:, pbc [0 ]] % 1.0
227- wrapped_frac [:, pbc [1 ]] = frac_coords [:, pbc [1 ]] % 1.0
228- wrapped_frac [:, pbc [2 ]] = frac_coords [:, pbc [2 ]] % 1.0
215+ wrapped_frac [:, pbc ] = frac_coords [:, pbc ] % 1.0
229216
230217 # Transform back to real space: r = A·f
231218 # Get the cell for each atom based on its system index
@@ -262,7 +249,7 @@ def minimum_image_displacement(
262249 dr_frac = torch .einsum ("ij,...j->...i" , cell_inv , dr )
263250
264251 # Apply minimum image convention
265- dr_frac -= torch .round (dr_frac )
252+ dr_frac -= torch .where ( pbc , torch . round (dr_frac ), torch . zeros_like ( dr_frac ) )
266253
267254 # Convert back to cartesian
268255 return torch .einsum ("ij,...j->...i" , cell , dr_frac )
0 commit comments