Skip to content

[ENH] Refactor mesh extraction and add scalar field at interface to structural elements | | GEN-12031 #1019

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion gempy/core/data/geo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def solutions(self) -> Solutions:
return self._solutions

@solutions.setter
def solutions(self, value):
def solutions(self, value: Solutions):
# * This is set from the gempy engine

self._solutions = value
Expand All @@ -161,6 +161,8 @@ def solutions(self, value):

# * Set solutions per element
for e, element in enumerate(self.structural_frame.structural_elements[:-1]): # * Ignore basement
element.scalar_field_at_interface = value.scalar_field_at_surface_points[e]

if self._solutions.dc_meshes is None:
continue
dc_mesh = self._solutions.dc_meshes[e]
Expand Down
2 changes: 1 addition & 1 deletion gempy/core/data/structural_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class StructuralElement:
# ? Should we extract this to a separate class?
vertices: Optional[np.ndarray] = None #: The vertices of the element in 3D space.
edges: Optional[np.ndarray] = None #: The edges of the element in 3D space.
scalar_field: Optional[float] = None #: The scalar field value for the element.
scalar_field_at_interface: Optional[float] = None #: The scalar field value for the element.

_id: int = -1

Expand Down
86 changes: 12 additions & 74 deletions gempy/modules/mesh_extranction/marching_cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,91 +34,27 @@ def set_meshes_with_marching_cubes(model: GeoModel) -> None:

output_lvl0: list[InterpOutput] = model.solutions.octrees_output[0].outputs_centers

# TODO: How to get this properly in gempy
# get a list of indices of the lithological groups
lith_group_indices = []
fault_group_indices = []
index = 0
for i in model.structural_frame.structural_groups:
if i.is_fault:
fault_group_indices.append(index)
else:
lith_group_indices.append(index)
index += 1

# extract scalar field values at surface points
scalar_values = model.solutions.raw_arrays.scalar_field_at_surface_points

# TODO: Here I just get my own masks, cause the gempy masks dont work as expected
masks = _get_masking_arrays(lith_group_indices, model, scalar_values)

# TODO: Attribute of element.scalar_field was None, changed it to scalar field value of that element
# This should probably be done somewhere else and maybe renamed to scalar_field_value?
# This is just the most basic solution to be clear what I did
_set_scalar_field_to_element(model, output_lvl0, structural_groups)

# Trying to use the exiting gempy masks
# masks = []
# masks.append(
# np.ones_like(model.solutions.raw_arrays.scalar_field_matrix[0].reshape(model.grid.regular_grid.resolution),
# dtype=bool))
# for idx in lith_group_indices:
# output_group: InterpOutput = output_lvl0[idx]
# masks.append(output_group.mask_components[8:].reshape(model.grid.regular_grid.resolution))

non_fault_counter = 0
for e, structural_group in enumerate(structural_groups):
if e >= len(output_lvl0):
continue

# Outdated?
# output_group: InterpOutput = output_lvl0[e]
# scalar_field_matrix = output_group.exported_fields_dense_grid.scalar_field

# Specify the correct scalar field, can be removed in the future
scalar_field = model.solutions.raw_arrays.scalar_field_matrix[e].reshape(model.grid.regular_grid.resolution)

# pick mask depending on whether the structural group is a fault or not
if structural_group.is_fault:
mask = np.ones_like(scalar_field, dtype=bool)
output_group: InterpOutput = output_lvl0[e]
scalar_field_matrix = output_group.exported_fields_dense_grid.scalar_field
if structural_group.is_fault is False:
slice_: slice = output_group.grid.dense_grid_slice
mask = output_group.combined_scalar_field.squeezed_mask_array[slice_]
else:
mask = masks[non_fault_counter] # TODO: I need the entry without faults here
non_fault_counter += 1
mask = np.ones_like(scalar_field_matrix, dtype=bool)

for element in structural_group.elements:
extract_mesh_for_element(
structural_element=element,
regular_grid=regular_grid,
scalar_field=scalar_field,
scalar_field=scalar_field_matrix,
mask=mask
)


# TODO: This should be set somewhere else
def _set_scalar_field_to_element(model, output_lvl0, structural_groups):
counter = 0
for e, structural_group in enumerate(structural_groups):
if e >= len(output_lvl0):
continue

for element in structural_group.elements:
element.scalar_field = model.solutions.scalar_field_at_surface_points[counter]
counter += 1


# TODO: This should be set somewhere else
def _get_masking_arrays(lith_group_indices, model, scalar_values):
masks = []
masks.append(np.ones_like(model.solutions.raw_arrays.scalar_field_matrix[0].reshape(model.grid.regular_grid.resolution),
dtype=bool))
for idx in lith_group_indices:
mask = model.solutions.raw_arrays.scalar_field_matrix[idx].reshape(model.grid.regular_grid.resolution) <= \
scalar_values[idx][-1]

masks.append(mask)
return masks


def extract_mesh_for_element(structural_element: StructuralElement,
regular_grid: RegularGrid,
scalar_field: np.ndarray,
Expand All @@ -138,10 +74,12 @@ def extract_mesh_for_element(structural_element: StructuralElement,
"""
# Extract mesh using marching cubes
verts, faces, _, _ = measure.marching_cubes(
volume=scalar_field,
level=structural_element.scalar_field,
volume=scalar_field.reshape(regular_grid.resolution),
level=structural_element.scalar_field_at_interface,
spacing=(regular_grid.dx, regular_grid.dy, regular_grid.dz),
mask=mask
mask=mask.reshape(regular_grid.resolution) if mask is not None else None,
allow_degenerate=False,
method="lewiner"
)

# Adjust vertices to correct coordinates in the model's extent
Expand Down
10 changes: 9 additions & 1 deletion test/test_modules/test_marching_cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,20 @@ def test_marching_cubes_implementation():
# assert arrays.scalar_field_matrix.shape == (3, 8_000) # * 3 surfaces, 8000 points

marching_cubes.set_meshes_with_marching_cubes(model)

# Assert
assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID
assert model.solutions.dc_meshes is None
assert model.structural_frame.structural_groups[0].elements[0].vertices.shape == (600, 3)
assert model.structural_frame.structural_groups[1].elements[0].vertices.shape == (860, 3)
assert model.structural_frame.structural_groups[2].elements[0].vertices.shape == (1_256, 3)
assert model.structural_frame.structural_groups[2].elements[1].vertices.shape == (1_680, 3)

if PLOT:
gpv = require_gempy_viewer()
gtv: gpv.GemPyToVista = gpv.plot_3d(
model=model,
show_data=True,
image=False,
image=True,
show=True
)