diff --git a/jax_cfd/spectral/utils.py b/jax_cfd/spectral/utils.py index 97a6b49..4cc9cad 100644 --- a/jax_cfd/spectral/utils.py +++ b/jax_cfd/spectral/utils.py @@ -129,8 +129,8 @@ def brick_wall_filter_2d(grid: grids.Grid): """Implements the 2/3 rule.""" n, m = grid.shape filter_ = jnp.zeros((n, m // 2 + 1)) - filter_ = filter_.at[:int(2 / 3 * n) // 2, :int(2 / 3 * (m // 2 + 1))].set(1) - filter_ = filter_.at[-int(2 / 3 * n) // 2:, :int(2 / 3 * (m // 2 + 1))].set(1) + filter_ = filter_.at[:(int(2 / 3 * n) // 2 + 1), :int(2 / 3 * (m // 2 + 1))].set(1) + filter_ = filter_.at[-(int(2 / 3 * n) // 2):, :int(2 / 3 * (m // 2 + 1))].set(1) return filter_