diff --git a/CHANGELOG.md b/CHANGELOG.md index b01a48c5fb..b0bc9fed7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - DrivAerML dataset support in FIGConvNet example. - Retraining recipe for DoMINO from a pretrained model checkpoint - Added Datacenter CFD use case. +- Added a utility to combine stl meshes. ### Changed diff --git a/modulus/utils/mesh/combine_stl_files.py b/modulus/utils/mesh/combine_stl_files.py new file mode 100644 index 0000000000..b98f5bbb43 --- /dev/null +++ b/modulus/utils/mesh/combine_stl_files.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import pyvista as pv + + +def combine_stls( + input_files: Union[str, List[str]], output_file: str, binary: bool = True +) -> None: + """Combine multiple STL files into a single-body STL file using PyVista. + Also converts a single multi-body STL to a single-body STL. + + Parameters + ---------- + input_files : Union[str, List[str]] + Path or list of paths to the input STL file(s) to be combined. + output_file : str + Path to save the combined STL file. + binary : bool, optional + Writes the file as binary when True and ASCII when False, by default True. + """ + + # Ensure input_files is a list + if isinstance(input_files, str): + input_files = [input_files] + + # Load all STL files as PyVista meshes + combined_mesh = pv.PolyData() + for file in input_files: + mesh = pv.read(file) + combined_mesh = combined_mesh.merge(mesh) # Merge all meshes into one + + # Save the combined mesh as an STL file + combined_mesh.save(output_file, binary=binary) diff --git a/test/utils/test_mesh_utils.py b/test/utils/test_mesh_utils.py index 8127b065cc..6eec694592 100644 --- a/test/utils/test_mesh_utils.py +++ b/test/utils/test_mesh_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib import os import random import urllib @@ -25,6 +26,63 @@ stl = pytest.importorskip("stl") +def compute_checksum(file_path): + """Compute the SHA256 checksum of a given file.""" + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + + +def write_multi_body_stl(filename, center=(0, 0, 0), side_length=1.0): + """Creates a multi-body ASCII STL cube.""" + + half = side_length / 2 + cx, cy, cz = center + + # Define vertices + vertices = np.array( + [ + [cx - half, cy - half, cz - half], # 0 + [cx + half, cy - half, cz - half], # 1 + [cx + half, cy + half, cz - half], # 2 + [cx - half, cy + half, cz - half], # 3 + [cx - half, cy - half, cz + half], # 4 + [cx + half, cy - half, cz + half], # 5 + [cx + half, cy + half, cz + half], # 6 + [cx - half, cy + half, cz + half], # 7 + ] + ) + + # Define faces + faces = [ + ([0, 1, 2], [0, 2, 3]), # Bottom + ([4, 5, 6], [4, 6, 7]), # Top + ([0, 1, 5], [0, 5, 4]), # Front + ([2, 3, 7], [2, 7, 6]), # Back + ([0, 3, 7], [0, 7, 4]), # Left + ([1, 2, 6], [1, 6, 5]), # Right + ] + + with open(filename, "w") as f: + for i, (tri1, tri2) in enumerate(faces): + f.write(f"solid body_{i}\n") + for tri in [tri1, tri2]: + v1, v2, v3 = vertices[tri] + normal = np.cross(v2 - v1, v3 - v1) + normal /= np.linalg.norm(normal) # Normalize the normal vector + + f.write(f" facet normal {normal[0]} {normal[1]} {normal[2]}\n") + f.write(" outer loop\n") + f.write(f" vertex {v1[0]} {v1[1]} {v1[2]}\n") + f.write(f" vertex {v2[0]} {v2[1]} {v2[2]}\n") + f.write(f" vertex {v3[0]} {v3[1]} {v3[2]}\n") + f.write(" endloop\n") + f.write(" endfacet\n") + f.write(f"endsolid body_{i}\n") + + @pytest.fixture def download_stl(tmp_path): url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Stanford_Bunny.stl" @@ -226,3 +284,49 @@ def test_stl_gen(pytestconfig, backend, download_stl, tmp_path): saved_stl = mesh.Mesh.from_file(str(output_filename)) assert saved_stl.vectors is not None + + +@pytest.fixture +def generate_test_stls(tmp_path): + """Fixture to generate STL files in a temporary directory before running tests.""" + + from modulus.utils.mesh.combine_stl_files import combine_stls + + cube_1_path = tmp_path / "cube_1.stl" + cube_2_path = tmp_path / "cube_2.stl" + cube_1_combined_path = tmp_path / "cube_1_combined.stl" + all_cubes_combined_path = tmp_path / "all_cubes_combined.stl" + + # Generate STL files in tmp directory + write_multi_body_stl(cube_1_path, center=(0, 0, 0), side_length=1.0) + write_multi_body_stl(cube_2_path, center=(2, 2, 2), side_length=1.0) + + # Combine STL files + combine_stls( + input_files=str(cube_1_path), + output_file=str(cube_1_combined_path), + ) + + combine_stls( + input_files=[str(cube_1_path), str(cube_2_path)], + output_file=str(all_cubes_combined_path), + ) + + return { + "cube_1_combined": cube_1_combined_path, + "all_cubes_combined": all_cubes_combined_path, + } + + +@import_or_fail(["pyvista"]) +def test_combined_stl(generate_test_stls, pytestconfig): + """Test to check combining stls.""" + + EXPECTED_CHECKSUMS = { + "cube_1_combined": "b5be925cbdfe6867a782c94321a4702cef397b4d17139dab2453d9ee8cbe0998", + "all_cubes_combined": "c2830f65700dfa66b3e65d16a982d76cd9841c4b7fbd37e2597c5b409acd1fee", + } + + for key, expected_checksum in EXPECTED_CHECKSUMS.items(): + computed_checksum = compute_checksum(generate_test_stls[key]) + assert computed_checksum == expected_checksum, f"Checksum mismatch for {key}"