Skip to content

Commit

Permalink
Merge pull request #121 from neutrinoceros/allow_custom_deposit_methods
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros authored Jun 9, 2023
2 parents 40bcdb2 + 5925182 commit 1cbb243
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 44 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ GPGI stands for **G**eneric **P**article + **G**rid data **I**nterface

- [Installation](#installation)
- [Supported applications](#supported-applications)
* [Supported deposition methods](#supported-deposition-methods)
* [Builtin deposition methods](#builtin-deposition-methods)
* [Supported geometries](#supported-geometries)
- [Time complexity](#time-complexity)
- [Usage](#usage)
Expand Down Expand Up @@ -60,13 +60,17 @@ This example illustrates the simplest possible deposition method "Particle in Ce
that contains it.

More refined methods are also available.
### Supported deposition methods
### Builtin deposition methods
| method name | abreviated name | order |
|-------------------------|:---------------:|:-----:|
| Nearest Grid Point | NGP | 0 |
| Cloud in Cell | CIC | 1 |
| Triangular Shaped Cloud | TSC | 2 |

*new in gpgi 0.12*
User-defined alternative methods may be provided to `Dataset.deposit` as `method=my_func`.
Their signature need to be compatible with `gpgi.types.DepositionMethodT`.

### Supported geometries
| geometry name | axes order |
|---------------|-----------------------------|
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "gpgi"
version = "0.11.2"
version = "0.12.0"
description = "A Generic Particle+Grid Interface"
authors = [
{ name = "C.M.T. Robert" },
Expand Down
2 changes: 1 addition & 1 deletion src/gpgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

from .api import load

__version__ = "0.11.2"
__version__ = "0.12.0"
84 changes: 44 additions & 40 deletions src/gpgi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
import numpy as np

from ._boundaries import BoundaryRegistry
from .clib._deposition_methods import ( # type: ignore [import]
_deposit_cic_1D,
_deposit_cic_2D,
_deposit_cic_3D,
_deposit_ngp_1D,
_deposit_ngp_2D,
_deposit_ngp_3D,
_deposit_tsc_1D,
_deposit_tsc_2D,
_deposit_tsc_3D,
)

if TYPE_CHECKING:
from ._typing import HCIArray, RealArray
Expand Down Expand Up @@ -75,6 +86,25 @@ class DepositionMethod(enum.Enum):
]


_BUILTIN_METHODS: dict[DepositionMethod, list[DepositionMethodT]] = {
DepositionMethod.NEAREST_GRID_POINT: [
_deposit_ngp_1D,
_deposit_ngp_2D,
_deposit_ngp_3D,
],
DepositionMethod.CLOUD_IN_CELL: [
_deposit_cic_1D,
_deposit_cic_2D,
_deposit_cic_3D,
],
DepositionMethod.TRIANGULAR_SHAPED_CLOUD: [
_deposit_tsc_1D,
_deposit_tsc_2D,
_deposit_tsc_3D,
],
}


class GeometricData(Protocol):
geometry: Geometry
axes: tuple[Name, ...]
Expand Down Expand Up @@ -466,7 +496,8 @@ def deposit(
"cloud_in_cell",
"tsc",
"triangular_shaped_cloud",
],
]
| DepositionMethodT,
boundaries: dict[Name, tuple[Name, Name]] | None = None,
verbose: bool = False,
return_ghost_padded_array: bool = False,
Expand Down Expand Up @@ -517,18 +548,6 @@ def deposit(
Boundary recipes are applied the weight field (if any) first.
"""
from .clib._deposition_methods import ( # type: ignore [import]
_deposit_cic_1D,
_deposit_cic_2D,
_deposit_cic_3D,
_deposit_ngp_1D,
_deposit_ngp_2D,
_deposit_ngp_3D,
_deposit_tsc_1D,
_deposit_tsc_2D,
_deposit_tsc_3D,
)

if method in ("pic", "particle_in_cell"):
warnings.warn(
f"{method=!r} is a deprecated alias for method='ngp', "
Expand All @@ -538,13 +557,19 @@ def deposit(
)
method = "ngp"

if method in _deposition_method_names:
mkey = _deposition_method_names[method]
if callable(method):
func = method
else:
raise ValueError(
f"Unknown deposition method {method!r}, "
f"expected any of {tuple(_deposition_method_names.keys())}"
)
if method not in _deposition_method_names:
raise ValueError(
f"Unknown deposition method {method!r}, "
f"expected any of {tuple(_deposition_method_names.keys())}"
)

if (mkey := _deposition_method_names[method]) not in _BUILTIN_METHODS:
raise NotImplementedError(f"method {method} is not implemented yet")

func = _BUILTIN_METHODS[mkey][self.grid.ndim - 1]

if self.grid.size == 1:
warnings.warn(
Expand Down Expand Up @@ -577,26 +602,6 @@ def deposit(
self._sanitize_boundaries(boundaries)
self._sanitize_boundaries(weight_field_boundaries)

known_methods: dict[DepositionMethod, list[DepositionMethodT]] = {
DepositionMethod.NEAREST_GRID_POINT: [
_deposit_ngp_1D,
_deposit_ngp_2D,
_deposit_ngp_3D,
],
DepositionMethod.CLOUD_IN_CELL: [
_deposit_cic_1D,
_deposit_cic_2D,
_deposit_cic_3D,
],
DepositionMethod.TRIANGULAR_SHAPED_CLOUD: [
_deposit_tsc_1D,
_deposit_tsc_2D,
_deposit_tsc_3D,
],
}
if mkey not in known_methods:
raise NotImplementedError(f"method {method} is not implemented yet")

field = self.particles.fields[particle_field_key]
padded_ret_array = np.zeros(self.grid._padded_shape, dtype=field.dtype)
if weight_field is not None:
Expand All @@ -608,7 +613,6 @@ def deposit(

self._setup_host_cell_index(verbose=verbose)

func = known_methods[mkey][self.grid.ndim - 1]
tstart = monotonic_ns()
if weight_field is not None:
func(
Expand Down
5 changes: 5 additions & 0 deletions tests/test_deposit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unyt as un

import gpgi
from gpgi.types import _deposit_ngp_2D


@pytest.fixture()
Expand Down Expand Up @@ -93,6 +94,10 @@ def test_unknown_method(sample_2D_dataset):
sample_2D_dataset.deposit("density", method="test")


def test_callable_method(sample_2D_dataset):
sample_2D_dataset.deposit("mass", method=_deposit_ngp_2D)


@pytest.mark.parametrize("method", ["ngp", "cic", "tsc"])
@pytest.mark.mpl_image_compare
def test_2D_deposit(sample_2D_dataset, method):
Expand Down

0 comments on commit 1cbb243

Please sign in to comment.