diff --git a/CITATION.cff b/CITATION.cff index 30b3df17..7d6050d9 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -14,10 +14,17 @@ authors: family-names: Smeets given-names: Stef orcid: "0000-0002-5413-9038" + - + affiliation: "University of Twente" + family-names: "Corbijn van Willenswaard" + given-names: "Lars J." + orcid: "0000-0001-6554-1527" version: "0.6.0" repository-code: "https://github.com/hpgem/nanomesh" identifiers: keywords: + - "image-analysis" + - "finite-element-analysis" - "materials-science" - "mesh-generation" - "microscopy" diff --git a/nanomesh/_doc.py b/nanomesh/_doc.py index 96a89f02..3abddb57 100644 --- a/nanomesh/_doc.py +++ b/nanomesh/_doc.py @@ -29,11 +29,15 @@ class DocFormatterMeta(type): Updates instances of `{classname}` to `classname`. """ - def __new__(mcls, classname, bases, cls_dict): - cls = super().__new__(mcls, classname, bases, cls_dict) + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + return cls + + def __new__(mcls, classname, bases, cls_dict, **kwargs): + cls = super().__new__(mcls, classname, bases, cls_dict, **kwargs) for name, method in inspect.getmembers(cls): - is_private = name.startswith('_') + is_private = name.startswith('__') is_none = (not method) or (not method.__doc__) if any((is_private, is_none)): diff --git a/nanomesh/image/_base.py b/nanomesh/image/_base.py index 52c39523..edd3c507 100644 --- a/nanomesh/image/_base.py +++ b/nanomesh/image/_base.py @@ -1,6 +1,6 @@ import operator import os -from typing import Callable, Union +from typing import Any, Callable, Dict, Union import numpy as np @@ -38,6 +38,15 @@ class GenericImage(object, metaclass=DocFormatterMeta): image : {shape}numpy.ndarray The raw image data """ + _registry: Dict[int, Any] = {} + + def __init_subclass__(cls, ndim: int, **kwargs): + super().__init_subclass__(**kwargs) + cls._registry[ndim] = cls + + def __new__(cls, image: np.ndarray): + subclass = cls._registry.get(image.ndim, cls) + return super().__new__(subclass) def __init__(self, image: np.ndarray): self.image = image diff --git a/nanomesh/image/_plane.py b/nanomesh/image/_plane.py index a5162819..9b56ef08 100644 --- a/nanomesh/image/_plane.py +++ b/nanomesh/image/_plane.py @@ -20,7 +20,7 @@ @doc(GenericImage, prefix='Data class for working with 2D image data', shape='(i,j) ') -class Plane(GenericImage): +class Plane(GenericImage, ndim=2): @classmethod def load(cls, filename: os.PathLike, **kwargs) -> Plane: diff --git a/nanomesh/image/_volume.py b/nanomesh/image/_volume.py index f79219bf..911e3334 100644 --- a/nanomesh/image/_volume.py +++ b/nanomesh/image/_volume.py @@ -19,7 +19,7 @@ @doc(GenericImage, prefix='Generic class for working with 3D (volumetric) image data', shape='(i,j,k) ') -class Volume(GenericImage): +class Volume(GenericImage, ndim=3): @classmethod def load(cls, filename: os.PathLike, **kwargs) -> 'Volume': diff --git a/nanomesh/image2mesh/_base.py b/nanomesh/image2mesh/_base.py index 4bbe6a97..914b30b2 100644 --- a/nanomesh/image2mesh/_base.py +++ b/nanomesh/image2mesh/_base.py @@ -1,20 +1,20 @@ from __future__ import annotations import logging -from abc import ABC, abstractmethod -from typing import Union +from abc import abstractmethod +from typing import Any, Dict, Union import numpy as np from .._doc import doc -from ..image import Plane, Volume +from ..image import GenericImage from ..mesh._base import GenericMesh logger = logging.getLogger(__name__) @doc(prefix='mesh from image data') -class AbstractMesher(ABC): +class AbstractMesher: """Utility class to generate a {prefix}. Parameters @@ -31,9 +31,21 @@ class AbstractMesher(ABC): contour : GenericMesh Stores the contour mesh. """ + _registry: Dict[int, Any] = {} - def __init__(self, image: Union[np.ndarray, Plane, Volume]): - if isinstance(image, (Plane, Volume)): + def __init_subclass__(cls, ndim: int, **kwargs): + super().__init_subclass__(**kwargs) + cls._registry[ndim] = cls + + def __new__(cls, image: Union[np.ndarray, GenericImage]): + if isinstance(image, GenericImage): + image = image.image + ndim = image.ndim + subclass = cls._registry.get(ndim, cls) + return super().__new__(subclass) + + def __init__(self, image: Union[np.ndarray, GenericImage]): + if isinstance(image, GenericImage): image = image.image self.contour: GenericMesh | None = None @@ -42,10 +54,11 @@ def __init__(self, image: Union[np.ndarray, Plane, Volume]): def __repr__(self): """Canonical string representation.""" + contour_str = self.contour.__repr__(indent=4) if self.contour else None s = ( f'{self.__class__.__name__}(', f' image = {self.image!r},', - f' contour = {self.contour.__repr__(indent=4)}' + f' contour = {contour_str}' ')', ) return '\n'.join(s) diff --git a/nanomesh/image2mesh/mesher2d/_mesher.py b/nanomesh/image2mesh/mesher2d/_mesher.py index 3db35b49..3a3194f4 100644 --- a/nanomesh/image2mesh/mesher2d/_mesher.py +++ b/nanomesh/image2mesh/mesher2d/_mesher.py @@ -153,7 +153,7 @@ def _generate_segments(polygons: List[Polygon]) -> np.ndarray: @doc(AbstractMesher, prefix='triangular mesh from 2D image data') -class Mesher2D(AbstractMesher): +class Mesher2D(AbstractMesher, ndim=2): def __init__(self, image: np.ndarray | Plane): super().__init__(image) diff --git a/nanomesh/image2mesh/mesher3d/_mesher.py b/nanomesh/image2mesh/mesher3d/_mesher.py index f3ed13ae..7b8c28a3 100644 --- a/nanomesh/image2mesh/mesher3d/_mesher.py +++ b/nanomesh/image2mesh/mesher3d/_mesher.py @@ -227,7 +227,7 @@ def generate_envelope(mesh: TriangleMesh, @doc(AbstractMesher, prefix='tetrahedral mesh from 3D (volumetric) image data') -class Mesher3D(AbstractMesher): +class Mesher3D(AbstractMesher, ndim=3): def __init__(self, image: np.ndarray): super().__init__(image) diff --git a/nanomesh/mesh/__init__.py b/nanomesh/mesh/__init__.py index b209da05..3964a0e4 100644 --- a/nanomesh/mesh/__init__.py +++ b/nanomesh/mesh/__init__.py @@ -1,11 +1,8 @@ -from ._base import GenericMesh, registry +from ._base import GenericMesh from ._line import LineMesh from ._tetra import TetraMesh from ._triangle import TriangleMesh -for _mesh_class in (LineMesh, TriangleMesh, TetraMesh): - registry[_mesh_class.cell_type] = _mesh_class - __all__ = [ 'LineMesh', 'GenericMesh', diff --git a/nanomesh/mesh/_base.py b/nanomesh/mesh/_base.py index 7b7c52a3..08bd458c 100644 --- a/nanomesh/mesh/_base.py +++ b/nanomesh/mesh/_base.py @@ -10,8 +10,6 @@ from .._doc import DocFormatterMeta, doc from ..region_markers import RegionMarkerList -registry: Dict[str, Any] = {} - @doc(prefix='Generic mesh class', dim_points='n', dim_cells='j') class GenericMesh(object, metaclass=DocFormatterMeta): @@ -32,11 +30,22 @@ class GenericMesh(object, metaclass=DocFormatterMeta): Additional cell data. Argument must be a 1D numpy array matching the number of cells defined by `i`. """ + _registry: Dict[int, Any] = {} cell_type: str = 'base' + def __init_subclass__(cls, cell_dim: int, **kwargs): + super().__init_subclass__(**kwargs) + cls._registry[cell_dim] = cls + + def __new__(cls, points: np.ndarray, cells: np.ndarray, *args, **kwargs): + cell_dim = cells.shape[1] + subclass = cls._registry.get(cell_dim, cls) + return super().__new__(subclass) + def __init__(self, points: np.ndarray, cells: np.ndarray, + *, fields: Dict[str, int] = None, region_markers: RegionMarkerList = None, **cell_data): @@ -106,21 +115,7 @@ def from_meshio(cls, mesh: 'meshio.Mesh'): key = key.replace(':ref', '-ref') cell_data[key] = value[0] - return GenericMesh.create(points=points, cells=cells, **cell_data) - - @classmethod - def create(cls, points, cells, **cell_data): - """Class dispatcher.""" - cell_dimensions = cells.shape[1] - if cell_dimensions == 2: - item_class = registry['line'] - elif cell_dimensions == 3: - item_class = registry['triangle'] - elif cell_dimensions == 4: - item_class = registry['tetra'] - else: - item_class = cls - return item_class(points=points, cells=cells, **cell_data) + return GenericMesh(points=points, cells=cells, **cell_data) def write(self, *args, **kwargs): """Simple wrapper around :func:`meshio.write`.""" diff --git a/nanomesh/mesh/_line.py b/nanomesh/mesh/_line.py index b061a108..7ee00515 100644 --- a/nanomesh/mesh/_line.py +++ b/nanomesh/mesh/_line.py @@ -16,7 +16,7 @@ prefix='Data class for line meshes', dim_points='2 or 3', dim_cells='2') -class LineMesh(GenericMesh): +class LineMesh(GenericMesh, cell_dim=2): cell_type = 'line' def plot_mpl(self, *args, **kwargs) -> plt.Axes: diff --git a/nanomesh/mesh/_tetra.py b/nanomesh/mesh/_tetra.py index 31297e63..8af54f9a 100644 --- a/nanomesh/mesh/_tetra.py +++ b/nanomesh/mesh/_tetra.py @@ -11,7 +11,7 @@ prefix='Data class for tetrahedral meshes', dim_points='3', dim_cells='4') -class TetraMesh(GenericMesh): +class TetraMesh(GenericMesh, cell_dim=4): cell_type = 'tetra' def to_open3d(self): diff --git a/nanomesh/mesh/_triangle.py b/nanomesh/mesh/_triangle.py index eb2de503..1ea54114 100644 --- a/nanomesh/mesh/_triangle.py +++ b/nanomesh/mesh/_triangle.py @@ -22,7 +22,7 @@ prefix='Data class for triangle meshes', dim_points='2 or 3', dim_cells='3') -class TriangleMesh(GenericMesh, PruneZ0Mixin): +class TriangleMesh(GenericMesh, PruneZ0Mixin, cell_dim=3): cell_type = 'triangle' def plot(self, **kwargs): diff --git a/nanomesh/mesh_container.py b/nanomesh/mesh_container.py index 9035742b..d7312cc8 100644 --- a/nanomesh/mesh_container.py +++ b/nanomesh/mesh_container.py @@ -194,10 +194,10 @@ def get(self, cell_type: str = None): fields = self.field_to_number.get(cell_type, None) - return GenericMesh.create(cells=cells, - points=points, - fields=fields, - **cell_data) + return GenericMesh(cells=cells, + points=points, + fields=fields, + **cell_data) def get_all_cell_data(self, cell_type: str = None) -> dict: """Get all cell data for given `cell_type`. diff --git a/tests/conftest.py b/tests/conftest.py index 6b13e688..6e0510e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,7 @@ def line_mesh(): cells = np.zeros((5, 2), dtype=int) cell_data = {LABEL_KEY: np.arange(5)} - mesh = LineMesh.create(cells=cells, points=points, **cell_data) + mesh = LineMesh(cells=cells, points=points, **cell_data) mesh.default_key = LABEL_KEY assert isinstance(mesh, LineMesh) return mesh @@ -40,7 +40,7 @@ def triangle_mesh_2d(): cells = np.zeros((5, 3), dtype=int) cell_data = {LABEL_KEY: np.arange(5)} - mesh = TriangleMesh.create(cells=cells, points=points, **cell_data) + mesh = TriangleMesh(cells=cells, points=points, **cell_data) mesh.default_key = LABEL_KEY assert isinstance(mesh, TriangleMesh) return mesh @@ -52,7 +52,7 @@ def triangle_mesh_3d(): cells = np.zeros((5, 3), dtype=int) cell_data = {LABEL_KEY: np.arange(5)} - mesh = TriangleMesh.create(cells=cells, points=points, **cell_data) + mesh = TriangleMesh(cells=cells, points=points, **cell_data) mesh.default_key = LABEL_KEY assert isinstance(mesh, TriangleMesh) return mesh @@ -64,7 +64,7 @@ def tetra_mesh(): cells = np.zeros((5, 4), dtype=int) cell_data = {LABEL_KEY: np.arange(5)} - mesh = TetraMesh.create(cells=cells, points=points, **cell_data) + mesh = TetraMesh(cells=cells, points=points, **cell_data) assert isinstance(mesh, TetraMesh) return mesh diff --git a/tests/test_mesh.py b/tests/test_mesh.py index 57710ec7..28204f41 100644 --- a/tests/test_mesh.py +++ b/tests/test_mesh.py @@ -16,7 +16,7 @@ def test_create(n_points, n_cells, expected): points = np.arange(5 * n_points).reshape(5, n_points) cells = np.zeros((5, n_cells)) - mesh = GenericMesh.create(points=points, cells=cells) + mesh = GenericMesh(points=points, cells=cells) assert mesh.cell_type == expected diff --git a/tests/test_subclassing.py b/tests/test_subclassing.py new file mode 100644 index 00000000..31bfa0ce --- /dev/null +++ b/tests/test_subclassing.py @@ -0,0 +1,76 @@ +import numpy as np +import pytest + +from nanomesh import (LineMesh, Mesher2D, Mesher3D, Plane, TetraMesh, + TriangleMesh, Volume) +from nanomesh.image import GenericImage +from nanomesh.image2mesh._base import AbstractMesher as GenericMesher +from nanomesh.mesh import GenericMesh + +im1d = np.arange(24) +im2d = np.arange(24).reshape(6, 4) +im3d = np.arange(24).reshape(3, 4, 2) + +points = np.array(( + (1, 0, 0), + (0, 1, 0), + (0, 0, 1), + (1, 1, 1), +)) + +lines = np.array(( + (0, 1), + (1, 2), + (2, 3), +)) + +triangles = np.array(( + (0, 1, 2), + (1, 2, 3), + (3, 0, 1), +)) + +tetras = np.array(( + (0, 1, 2, 3), + (3, 2, 1, 0), +)) + +other = np.array(( + (0, ), + (1, ), + (2, ), + (3, ), +)) + + +@pytest.mark.parametrize('data,instance', ( + (im1d, GenericImage), + (im2d, Plane), + (im2d, Plane), + (im3d, Volume), +)) +def test_image_subclassing(data, instance): + image = GenericImage(data) + assert isinstance(image, instance) + + +@pytest.mark.parametrize('data,instance', ( + ((points, other), GenericMesh), + ((points, lines), LineMesh), + ((points, triangles), TriangleMesh), + ((points, tetras), TetraMesh), +)) +def test_mesh_subclassing(data, instance): + mesh = GenericMesh(*data) + assert isinstance(mesh, instance) + + +@pytest.mark.parametrize('data,instance', ( + (im2d, Mesher2D), + (Plane(im2d), Mesher2D), + (im3d, Mesher3D), + (Volume(im3d), Mesher3D), +)) +def test_mesher_subclassing(data, instance): + image = GenericMesher(data) + assert isinstance(image, instance)