Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utility to combine STLs #795

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- DrivAerML dataset support in FIGConvNet example.
- Retraining recipe for DoMINO from a pretrained model checkpoint
- Added Datacenter CFD use case.
- Added a utility to combine stl meshes.

### Changed

49 changes: 49 additions & 0 deletions modulus/utils/mesh/combine_stl_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import pyvista as pv


def combine_stls(
input_files: Union[str, List[str]], output_file: str, binary: bool = True
) -> None:
"""Combine multiple STL files into a single-body STL file using PyVista.
Also converts a single multi-body STL to a single-body STL.
Parameters
----------
input_files : Union[str, List[str]]
Path or list of paths to the input STL file(s) to be combined.
output_file : str
Path to save the combined STL file.
binary : bool, optional
Writes the file as binary when True and ASCII when False, by default True.
"""

# Ensure input_files is a list
if isinstance(input_files, str):
input_files = [input_files]

# Load all STL files as PyVista meshes
combined_mesh = pv.PolyData()
for file in input_files:
mesh = pv.read(file)
combined_mesh = combined_mesh.merge(mesh) # Merge all meshes into one

# Save the combined mesh as an STL file
combined_mesh.save(output_file, binary=binary)
104 changes: 104 additions & 0 deletions test/utils/test_mesh_utils.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
import random
import urllib
@@ -25,6 +26,63 @@
stl = pytest.importorskip("stl")


def compute_checksum(file_path):
"""Compute the SHA256 checksum of a given file."""
sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()


def write_multi_body_stl(filename, center=(0, 0, 0), side_length=1.0):
"""Creates a multi-body ASCII STL cube."""

half = side_length / 2
cx, cy, cz = center

# Define vertices
vertices = np.array(
[
[cx - half, cy - half, cz - half], # 0
[cx + half, cy - half, cz - half], # 1
[cx + half, cy + half, cz - half], # 2
[cx - half, cy + half, cz - half], # 3
[cx - half, cy - half, cz + half], # 4
[cx + half, cy - half, cz + half], # 5
[cx + half, cy + half, cz + half], # 6
[cx - half, cy + half, cz + half], # 7
]
)

# Define faces
faces = [
([0, 1, 2], [0, 2, 3]), # Bottom
([4, 5, 6], [4, 6, 7]), # Top
([0, 1, 5], [0, 5, 4]), # Front
([2, 3, 7], [2, 7, 6]), # Back
([0, 3, 7], [0, 7, 4]), # Left
([1, 2, 6], [1, 6, 5]), # Right
]

with open(filename, "w") as f:
for i, (tri1, tri2) in enumerate(faces):
f.write(f"solid body_{i}\n")
for tri in [tri1, tri2]:
v1, v2, v3 = vertices[tri]
normal = np.cross(v2 - v1, v3 - v1)
normal /= np.linalg.norm(normal) # Normalize the normal vector

f.write(f" facet normal {normal[0]} {normal[1]} {normal[2]}\n")
f.write(" outer loop\n")
f.write(f" vertex {v1[0]} {v1[1]} {v1[2]}\n")
f.write(f" vertex {v2[0]} {v2[1]} {v2[2]}\n")
f.write(f" vertex {v3[0]} {v3[1]} {v3[2]}\n")
f.write(" endloop\n")
f.write(" endfacet\n")
f.write(f"endsolid body_{i}\n")


@pytest.fixture
def download_stl(tmp_path):
url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Stanford_Bunny.stl"
@@ -226,3 +284,49 @@ def test_stl_gen(pytestconfig, backend, download_stl, tmp_path):
saved_stl = mesh.Mesh.from_file(str(output_filename))

assert saved_stl.vectors is not None


@pytest.fixture
def generate_test_stls(tmp_path):
"""Fixture to generate STL files in a temporary directory before running tests."""

from modulus.utils.mesh.combine_stl_files import combine_stls

cube_1_path = tmp_path / "cube_1.stl"
cube_2_path = tmp_path / "cube_2.stl"
cube_1_combined_path = tmp_path / "cube_1_combined.stl"
all_cubes_combined_path = tmp_path / "all_cubes_combined.stl"

# Generate STL files in tmp directory
write_multi_body_stl(cube_1_path, center=(0, 0, 0), side_length=1.0)
write_multi_body_stl(cube_2_path, center=(2, 2, 2), side_length=1.0)

# Combine STL files
combine_stls(
input_files=str(cube_1_path),
output_file=str(cube_1_combined_path),
)

combine_stls(
input_files=[str(cube_1_path), str(cube_2_path)],
output_file=str(all_cubes_combined_path),
)

return {
"cube_1_combined": cube_1_combined_path,
"all_cubes_combined": all_cubes_combined_path,
}


@import_or_fail(["pyvista"])
def test_combined_stl(generate_test_stls, pytestconfig):
"""Test to check combining stls."""

EXPECTED_CHECKSUMS = {
"cube_1_combined": "b5be925cbdfe6867a782c94321a4702cef397b4d17139dab2453d9ee8cbe0998",
"all_cubes_combined": "c2830f65700dfa66b3e65d16a982d76cd9841c4b7fbd37e2597c5b409acd1fee",
}

for key, expected_checksum in EXPECTED_CHECKSUMS.items():
computed_checksum = compute_checksum(generate_test_stls[key])
assert computed_checksum == expected_checksum, f"Checksum mismatch for {key}"