diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c186e94e6..586cb6ea84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Refactored DiTBlock to be more modular - Added NATTEN 2D neighborhood attention backend for DiTBlock - Migrated blood flow example to PyTorch Geometric. +- Refactored DoMINO model code and examples for performance optimizations and improved readability. - Migrated HydroGraphNet example to PyTorch Geometric. - Support for saving and loading nested `physicsnemo.Module`s. It is now possible to create nested modules with `m = Module(submodule, ...)`, and save diff --git a/docs/img/domino/combined-training-curve.png b/docs/img/domino/combined-training-curve.png new file mode 100644 index 0000000000..9a56f9d76d Binary files /dev/null and b/docs/img/domino/combined-training-curve.png differ diff --git a/docs/img/domino/drag-r2.jpg b/docs/img/domino/drag-r2.jpg new file mode 100644 index 0000000000..1411bd0277 Binary files /dev/null and b/docs/img/domino/drag-r2.jpg differ diff --git a/docs/img/domino/lift-r2.jpg b/docs/img/domino/lift-r2.jpg new file mode 100644 index 0000000000..de7813af24 Binary files /dev/null and b/docs/img/domino/lift-r2.jpg differ diff --git a/docs/img/domino/surface-training-curve.png b/docs/img/domino/surface-training-curve.png new file mode 100644 index 0000000000..992acbd424 Binary files /dev/null and b/docs/img/domino/surface-training-curve.png differ diff --git a/docs/img/domino_perf.png b/docs/img/domino_perf.png new file mode 100644 index 0000000000..0038267354 Binary files /dev/null and b/docs/img/domino_perf.png differ diff --git a/examples/cfd/external_aerodynamics/domino/README.md b/examples/cfd/external_aerodynamics/domino/README.md index d66d42f45a..ee456cc573 100644 --- a/examples/cfd/external_aerodynamics/domino/README.md +++ b/examples/cfd/external_aerodynamics/domino/README.md @@ -77,19 +77,24 @@ please refer to their [paper](https://arxiv.org/pdf/2408.11969). #### Data Preprocessing -`PhysicsNeMo` has a related project to help with data processing, called [PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator). +`PhysicsNeMo` has a related project to help with data processing, called +[PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator). Using `PhysicsNeMo-Curator`, the data needed to train a DoMINO model can be setup easily. -Please refer to [these instructions on getting started](https://github.com/NVIDIA/physicsnemo-curator?tab=readme-ov-file#what-is-physicsnemo-curator) +Please refer to +[these instructions on getting started](https://github.com/NVIDIA/physicsnemo-curator?tab=readme-ov-file#what-is-physicsnemo-curator) with `PhysicsNeMo-Curator`. -Download the DrivAer ML dataset using the [provided instructions in PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/domino/README.md#download-drivaerml-dataset). +Download the DrivAer ML dataset using the +[provided instructions in PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/domino/README.md#download-drivaerml-dataset). The first step for running the DoMINO pipeline requires processing the raw data (vtp, vtu and stl) into either Zarr or NumPy format for training. Each of the raw simulations files are downloaded in `vtp`, `vtu` and `stl` formats. For instructions on running data processing to produce a DoMINO training ready dataset, -please refer to [How-to Curate data for DoMINO Model](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/domino/README.md). +please refer to +[How-to Curate data for DoMINO Model](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/domino/README.md). -Caching is implemented in [`CachedDoMINODataset`](https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/datapipes/cae/domino_datapipe.py#L1250). +Caching is implemented in +[`CachedDoMINODataset`](https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/datapipes/cae/domino_datapipe.py#L1250). Optionally, users can run `cache_data.py` to save outputs of DoMINO datapipe in the `.npy` files. The DoMINO datapipe is set up to calculate Signed Distance Field and Nearest Neighbor interpolations on-the-fly during @@ -101,6 +106,36 @@ processed files. The final processed dataset should be divided and saved into 2 directories, for training and validation. +#### Data Scaling factors + +DoMINO has several data-specific configuration tools that rely on some +knowledge of the dataset: + +- The output fields (the labels) are normalized during training to a mean + of zero and a standard deviation of one, averaged over the dataset. + The scaling is controlled by passing the `volume_factors` and + `surface_factors` values to the datapipe. +- The input locations are scaled by, and optionally cropped to, used defined + bounding boxes for both surface and volume. Whether cropping occurs, or not, + is controlled by the `sample_in_bbox` value of the datapipe. Normalization + to the bounding box is enabled with `normalize_coordinates`. By default, + both are set to true. The value of the boxes are configured in the + `config.yaml` file, and are configured separately for surface and volume. + +> Note: The datapipe module has a helper function `create_domino_dataset` +> with sensible defaults to help create a Domino Datapipe. + +To facilitate setting reasonable values of these, you can use the +`compute_statistics.py` script. This will load the core dataset as defined +in your `config.yaml` file, loop over several events (200, by default), and +both print and store the surface/volume field statistics as well as the +coordinate statistics. + +> Note that, for volumetric fields especially, the min/max found may be +> significantly outside the surface region. Many simulations extend volumetric +> sampling to far field, and you may instead want to crop significant amounts +> of volumetric distance. + #### Training Specify the training and validation data paths, bounding box sizes etc. in the @@ -176,9 +211,6 @@ The `domain_size` represents the number of GPUs used for each batch - setting but with extra overhead. `shard_grid` and `shard_points` will enable domain parallelism over the latent space and input/output points, respectively. -Please see `src/train_sharded.py` for more details regarding the changes -from the standard training script required for domain parallel DoMINO training. - As one last note regarding domain-parallel training: in the phase of the DoMINO where the output solutions are calculated, the model can used two different techniques (numerically identical) to calculate the output. Due to the @@ -189,6 +221,114 @@ launch overhead at the cost of more memory use. For non-sharded training, the `two-loop` setting is more optimal. The difference in `one-loop` or `two-loop` is purely computational, not algorithmic. +### Performance Optimizations + +The training and inference scripts for DoMINO contain several performance +enhancements to accelerate the training and usage of the model. In this +section we'll highlight several of them, as well as how to customize them +if needed. + +#### Memory Pool Optimizations + +The preprocessor of DoMINO requires a computation of k Nearest Neighbors, +which is accelerated via the `cuml` Neighbors tool. By default, `cuml` and +`torch` both use memory allocation pools to speed up allocating tensors, but +they do not use the same pool. This means that during preprocessing, it's +possible for the kNN operation to spend a significant amount of time in +memory allocations - and further, it limits the available memory to `torch`. + +To mitigate this, by default in DoMINO we use the Rapids Memory Manager +([`rmm`](https://github.com/rapidsai/rmm)). If, for some reason, you wish +to disable this you can do so with an environment variable: + +```bash +export PHYSICSNEMO_DISABLE_RMM=True +``` + +Or remove this line from the training script: + +```python +from physicsnemo.utils.memory import unified_gpu_memory +``` + +> Note - why not make it configurable? We have to set up the shared memory +> pool allocation very early in the program, before the config has even +> been read. So, we enable by default and the opt-out path is via the +> environment. + +#### Reduced Volume Reads + +The dataset size for volumetric data can be quite substantial - DrivAerML, for +example, has mesh sizes of 160M points per example. Even though the models +do not process all 160M points, in order to down sample dynamically they all +must be read from disk - which can exceed bandwidth and CPU decoding capacity +on nodes with multiple GPUs. + +As a performance enhancement, DoMINO's data pipeline offers a mitigation: instead +of reading an entire volumetric mesh, during preprocessing we _shuffle_ the +volumetric inputs and outputs (in tandem) and subsequent reads choose random +slices of the volumetric data. By default, DoMINO will read about 100x more data +than necessary for the sampling size. This allows the pipeline to still apply +cuts for data inside of the bounding box, and further random sampling to improve +training stability. To enable/disable this parameter, set +`data.volume_sample_from_disk=True` (enable) or `False` (disable) + +> Note - if you volumetric data is not larger than a few million mesh points, +> pre-shuffling and sampling from disk is likely not necessary for you. + +`physicsnemo-curator` supports shuffling the volumetric data during preprocessing. +If, however, you've already preprocessed your data and just want to apply +shuffling, use the script at `src/shuffle_volumetric_curator_output.py` + +The shuffling script will also apply sharding to the output files, which +improves IO performance. So, `zarr>=3.0` is required to use the outputs from +curator. `src/shuffle_volumetric_curator_output.py` is meant to be an example of how +to apply shuffling, so modify and update as you need for your dataset. + +> If you have tensorstore installed (it's in `requirements.txt`), the data reader +> will work equally well with Zarr 2 or Zarr 3 files. + +#### Overall Performance + +DoMINO is a computationally complex and challenging workload. Over the course +of several releases, we have chipped away at performance bottlenecks to speed +up the training and inference time (with `inference_on_stl.py`). Overall +training performance has decreased from about 5 days to just over 4 hours, with +eight H100 GPUs. We hope these optimizations enable you to explore more +parameters and surrogate models; if there is a performance issue you see, +please open an issue on GitHub. + +![Results from DoMINO for RTWT SC demo](../../../../docs/img/domino_perf.png) + +### Example Training Results + +To provide an example of what a successful training should look like, we include here +some example results. Training curves may look similar to this: + +![Combined Training Curve](../../../../docs/img/domino/combined-training-curve.png) + +And, when evaluating the results on the validation dataset, this particular +run had the following L2 and R2 Metrics: + +| Metric | Surface Only | Combined | +|--------------------:|:------------:|:--------:| +| X Velocity | N/A | 0.086 | +| Y Velocity | N/A | 0.185 | +| Z Velocity | N/A | 0.197 | +| Volumetric Pressure | N/A | 0.106 | +| Turb. V | N/A | 0.134 | +| Surface Pressure | 0.101 | 0.105 | +| X-Tau (Shear) | 0.138 | 0.145 | +| Y-Tau (Shear) | 0.174 | 0.185 | +| Z-Tau (Shear) | 0.198 | 0.207 | +| Drag R2 | 0.983 | 0.975 | +| Lift R2 | 0.971 | 0.968 | + +With the PhysicsNeMo CFD tool, you can create plots of the lift and drag +forces computed by domino vs. the CFD Solver. For example, here is the drag force: + +![Draf Force R^2](../../../../docs/img/domino/drag-r2.jpg) + ### Training with Physics Losses DoMINO supports enforcing of PDE residuals as soft constraints. This can be used diff --git a/examples/cfd/external_aerodynamics/domino/requirements.txt b/examples/cfd/external_aerodynamics/domino/requirements.txt index 4c689c85e2..1d2cfe7dd9 100644 --- a/examples/cfd/external_aerodynamics/domino/requirements.txt +++ b/examples/cfd/external_aerodynamics/domino/requirements.txt @@ -3,3 +3,4 @@ warp-lang tensorboard cuml einops +tensorstore \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/domino/src/benchmark_dataloader.py b/examples/cfd/external_aerodynamics/domino/src/benchmark_dataloader.py new file mode 100644 index 0000000000..04ca2340e9 --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/benchmark_dataloader.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a distributed pipeline for training the DoMINO model on +CFD datasets. It includes the computation of scaling factors, instantiating +the DoMINO model and datapipe, automatically loading the most recent checkpoint, +training the model in parallel using DistributedDataParallel across multiple +GPUs, calculating the loss and updating model parameters using mixed precision. +This is a common recipe that enables training of combined models for surface and +volume as well either of them separately. Validation is also conducted every epoch, +where predictions are compared against ground truth values. The code logs training +and validation metrics to TensorBoard. The train tab in config.yaml can be used to +specify batch size, number of epochs and other training parameters. +""" + +import time +import os +import re +import torch +import torchinfo + +from typing import Literal, Any + + +import hydra +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf + +# This will set up the cupy-ecosystem and pytorch to share memory pools +from physicsnemo.utils.memory import unified_gpu_memory + + +import torch.distributed as dist +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter +from nvtx import annotate as nvtx_annotate +import torch.cuda.nvtx as nvtx + + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.utils import load_checkpoint, save_checkpoint +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.datapipes.cae.domino_datapipe import ( + DoMINODataPipe, + compute_scaling_factors, + create_domino_dataset, +) +from physicsnemo.models.domino.model import DoMINO +from physicsnemo.utils.domino.utils import * + +# This is included for GPU memory tracking: +from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo +import time + +from utils import ( + ScalingFactors, + get_keys_to_read, + coordinate_distributed_environment, + load_scaling_factors, +) + + +from physicsnemo.utils.profiling import profile, Profiler + + +def benchmark_io_epoch( + dataloader, + logger, + gpu_handle, + epoch_index, + device, +): + dist = DistributedManager() + + # If you tell the dataloader the indices in advance, it will preload + # and pre-preprocess data + # dataloader.set_indices(indices) + + gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) + start_time = time.perf_counter() + for i_batch, sample_batched in enumerate(dataloader): + # Gather data and report + elapsed_time = time.perf_counter() - start_time + start_time = time.perf_counter() + gpu_end_info = nvmlDeviceGetMemoryInfo(gpu_handle) + gpu_memory_used = gpu_end_info.used / (1024**3) + gpu_memory_delta = (gpu_end_info.used - gpu_start_info.used) / (1024**3) + + logging_string = f"Device {device}, batch processed: {i_batch + 1}\n" + logging_string += f" GPU memory used: {gpu_memory_used:.3f} Gb\n" + logging_string += f" GPU memory delta: {gpu_memory_delta:.3f} Gb\n" + logging_string += f" Time taken: {elapsed_time:.2f} seconds\n" + logger.info(logging_string) + gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) + + return + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + # initialize distributed manager + DistributedManager.initialize() + dist = DistributedManager() + + # Initialize NVML + nvmlInit() + + gpu_handle = nvmlDeviceGetHandleByIndex(dist.device.index) + + model_type = cfg.model.model_type + + logger = PythonLogger("Train") + logger = RankZeroLoggingWrapper(logger, dist) + + logger.info(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + ################################ + # Get scaling factors + ################################ + vol_factors, surf_factors = load_scaling_factors(cfg) + + keys_to_read, keys_to_read_if_available = get_keys_to_read( + cfg, model_type, get_ground_truth=True + ) + + domain_mesh, data_mesh, placements = coordinate_distributed_environment(cfg) + + train_dataset = create_domino_dataset( + cfg, + phase="train", + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, + vol_factors=vol_factors, + surf_factors=surf_factors, + device_mesh=domain_mesh, + placements=placements, + ) + train_sampler = DistributedSampler( + train_dataset, num_replicas=data_mesh.size(), rank=data_mesh.get_local_rank() + ) + + for epoch in range(0, cfg.train.epochs): + start_time = time.perf_counter() + logger.info(f"Device {dist.device}, epoch {epoch}:") + + train_sampler.set_epoch(epoch) + + train_dataset.dataset.set_indices(list(train_sampler)) + + epoch_start_time = time.perf_counter() + with Profiler(): + benchmark_io_epoch( + dataloader=train_dataset, + logger=logger, + gpu_handle=gpu_handle, + epoch_index=epoch, + device=dist.device, + ) + epoch_end_time = time.perf_counter() + logger.info( + f"Device {dist.device}, Epoch {epoch} took {epoch_end_time - epoch_start_time:.3f} seconds" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/external_aerodynamics/domino/src/compute_statistics.py b/examples/cfd/external_aerodynamics/domino/src/compute_statistics.py new file mode 100644 index 0000000000..991105492e --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/compute_statistics.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compute and save scaling factors for DoMINO datasets. + +This script computes mean, standard deviation, minimum, and maximum values +for all field variables in a DoMINO dataset. The computed statistics are +saved in a structured format that can be easily loaded and used for +normalization during training and inference. + +The script uses the same configuration system as the training script, +ensuring consistency in dataset handling and processing parameters. +""" + +import os +import time +from pathlib import Path + +import hydra +import torch +from omegaconf import DictConfig, OmegaConf + +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.datapipes.cae.domino_datapipe import compute_scaling_factors +from utils import ScalingFactors + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + """ + Main function to compute and save scaling factors. + + Args: + cfg: Hydra configuration object containing all parameters + """ + ################################ + # Initialize distributed manager + ################################ + DistributedManager.initialize() + dist = DistributedManager() + + ################################ + # Initialize logger + ################################ + logger = PythonLogger("ComputeStatistics") + logger = RankZeroLoggingWrapper(logger, dist) + + logger.info("Starting scaling factors computation") + logger.info(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") + + ################################ + # Create output directory + ################################ + output_dir = os.path.dirname(cfg.data.scaling_factors) + os.makedirs(output_dir, exist_ok=True) + + if dist.world_size > 1: + torch.distributed.barrier() + + ################################ + # Check if scaling exists + ################################ + pickle_path = output_dir + "/scaling_factors.pkl" + + try: + scaling_factors = ScalingFactors.load(pickle_path) + logger.info(f"Scaling factors loaded from: {pickle_path}") + except FileNotFoundError: + logger.info(f"Scaling factors not found at: {pickle_path}; recomputing.") + scaling_factors = None + + ################################ + # Compute scaling factors + ################################ + if scaling_factors is None: + logger.info("Computing scaling factors from dataset...") + start_time = time.perf_counter() + + target_keys = [ + "volume_fields", + "surface_fields", + "stl_centers", + "volume_mesh_centers", + "surface_mesh_centers", + ] + + mean, std, min_val, max_val = compute_scaling_factors( + cfg=cfg, + input_path=cfg.data.input_dir, + target_keys=target_keys, + max_samples=cfg.data.max_samples_for_statistics, + ) + mean = {k: m.cpu().numpy() for k, m in mean.items()} + std = {k: s.cpu().numpy() for k, s in std.items()} + min_val = {k: m.cpu().numpy() for k, m in min_val.items()} + max_val = {k: m.cpu().numpy() for k, m in max_val.items()} + + compute_time = time.perf_counter() - start_time + logger.info( + f"Scaling factors computation completed in {compute_time:.2f} seconds" + ) + + ################################ + # Create structured data object + ################################ + dataset_info = { + "input_path": cfg.data.input_dir, + "model_type": cfg.model.model_type, + "normalization": cfg.model.normalization, + "compute_time": compute_time, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "config_name": cfg.project.name, + } + + scaling_factors = ScalingFactors( + mean=mean, + std=std, + min_val=min_val, + max_val=max_val, + field_keys=target_keys, + ) + + ################################ + # Save scaling factors + ################################ + if dist.rank == 0: + # Save as structured pickle file + pickle_path = output_dir + "/scaling_factors.pkl" + scaling_factors.save(pickle_path) + logger.info(f"Scaling factors saved to: {pickle_path}") + + # Save summary report + summary_path = output_dir + "/scaling_factors_summary.txt" + with open(summary_path, "w") as f: + f.write(scaling_factors.summary()) + logger.info(f"Summary report saved to: {summary_path}") + + ################################ + # Display summary + ################################ + logger.info("Scaling factors computation summary:") + logger.info(f"Field keys processed: {scaling_factors.field_keys}") + + logger.info("Scaling factors computation completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml index 84256a0d97..b074681ce4 100644 --- a/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml +++ b/examples/cfd/external_aerodynamics/domino/src/conf/config.yaml @@ -18,7 +18,7 @@ # │ Project Details │ # └───────────────────────────────────────────┘ project: # Project name - name: AWS_Dataset + name: DrivAerML_Dataset exp_tag: 1 # Experiment tag # Main output directory. @@ -62,13 +62,13 @@ variables: global_parameters: inlet_velocity: type: vector - reference: [38.89] # vector [30, 0, 0] should be specified as [30], while [30, 30, 0] should be [30, 30]. + reference: [30.00] # vector [30, 0, 0] should be specified as [30], while [30, 30, 0] should be [30, 30]. air_density: type: scalar - reference: 1.226 + reference: 1.205 # ┌───────────────────────────────────────────┐ -# │ Training Data Configs │ +# │ Data Configs │ # └───────────────────────────────────────────┘ data: # Input directory for training and validation data input_dir: /user/data/aws_data_all/ @@ -77,10 +77,16 @@ data: # Input directory for training and validation data min: [-3.5, -2.25, -0.32] max: [8.5, 2.25, 3.00] bounding_box_surface: # Bounding box dimensions for car surface - min: [-1.1, -1.2, -0.32] - max: [4.5, 1.2, 1.3] + min: [-1.5, -1.4, -0.32] + max: [5.0, 1.4, 1.4] gpu_preprocessing: true gpu_output: true + normalize_coordinates: true + sample_in_bbox: true + sampling: true + scaling_factors: ${project_dir}/scaling_factors/scaling_factors.pkl + volume_sample_from_disk: true + max_samples_for_statistics: 200 # ┌───────────────────────────────────────────┐ # │ Domain Parallelism Settings │ @@ -95,13 +101,12 @@ domain_parallelism: # └───────────────────────────────────────────┘ model: model_type: combined # train which model? surface, volume, combined - activation: "relu" # "relu" or "gelu" + activation: "gelu" # "relu" or "gelu" loss_function: loss_type: "mse" # mse or rmse area_weighing_factor: 10000 # Generally inverse of maximum area interp_res: [128, 64, 64] # resolution of latent space 128, 64, 48 use_sdf_in_basis_func: true # SDF in basis function network - positional_encoding: false # calculate positional encoding? volume_points_sample: 8192 # Number of points to sample in volume per epoch surface_points_sample: 8192 # Number of points to sample on surface per epoch surface_sampling_algorithm: area_weighted #random or area_weighted @@ -109,7 +114,7 @@ model: num_neighbors_surface: 7 # How many neighbors on surface? num_neighbors_volume: 10 # How many neighbors on volume? combine_volume_surface: false # combine volume and surface encodings - return_volume_neighbors: true # Whether to return volume neighbors or not + return_volume_neighbors: false # Whether to return volume neighbors or not use_surface_normals: true # Use surface normals and surface areas for surface computation? use_surface_area: true # Use only surface normals and not surface area integral_loss_scaling_factor: 100 # Scale integral loss by this factor @@ -119,9 +124,6 @@ model: vol_loss_scaling: 1.0 # scale volume loss with this factor in combined mode geometry_encoding_type: both # geometry encoder type, sdf, stl, both solution_calculation_mode: two-loop # one-loop is better for sharded, two-loop is lower memory but more overhead. Physics losses are not supported via one-loop presently. - resampling_surface_mesh: # resampling of surface mesh before constructing kd tree - resample: false #false or true - points: 1_000_000 # number of points geometry_rep: # Hyperparameters for geometry representation network geo_conv: base_neurons: 32 # 256 or 64 @@ -131,8 +133,8 @@ model: surface_radii: [0.01, 0.05, 1.0] # radii for surface surface_hops: 1 # Number of surface iterations volume_hops: 1 # Number of volume iterations - volume_neighbors_in_radius: [10, 10, 10, 10] # Number of neighbors in radius for volume - surface_neighbors_in_radius: [10, 10, 10] # Number of neighbors in radius for surface + volume_neighbors_in_radius: [32, 64, 128, 256] # Number of neighbors in radius for volume + surface_neighbors_in_radius: [8, 16, 128] # Number of neighbors in radius for surface fourier_features: false num_modes: 5 activation: ${model.activation} @@ -142,6 +144,8 @@ model: processor_type: conv # conv or unet (conv is better; fno, fignet to be added) self_attention: false # can be used only with unet cross_attention: false # can be used only with unet + surface_sdf_scaling_factor: [0.01, 0.02, 0.04] # Scaling factor for SDF, smaller is more emphasis on surface + volume_sdf_scaling_factor: [0.04] # Scaling factor for SDF, smaller is more emphasis on surface nn_basis_functions: # Hyperparameters for basis function network base_layer: 512 fourier_features: true @@ -174,15 +178,35 @@ model: # └───────────────────────────────────────────┘ train: # Training configurable parameters epochs: 1000 - checkpoint_interval: 1 + checkpoint_interval: 2 dataloader: batch_size: 1 - pin_memory: false # if the preprocessing is outputing GPU data, set this to false + preload_depth: 1 + pin_memory: True # if the preprocessing is outputing GPU data, set this to false sampler: shuffle: true drop_last: false checkpoint_dir: /user/models/ # Use only for retraining add_physics_loss: false + lr_scheduler: + name: MultiStepLR # Also supports CosineAnnealingLR + milestones: [50, 200, 400, 500, 600, 700, 800, 900] # only used if lr_scheduler is MultiStepLR + gamma: 0.5 # only used if lr_scheduler is MultiStepLR + T_max: ${train.epochs} # only used if lr_scheduler is CosineAnnealingLR + eta_min: 1e-6 # only used if lr_scheduler is CosineAnnealingLR + optimizer: + name: Adam # or AdamW + lr: 0.001 + weight_decay: 0.0 + amp: + enabled: true + autocast: + dtype: torch.float16 + scaler: + _target_: torch.cuda.amp.GradScaler + enabled: ${..enabled} + clip_grad: true + grad_max_norm: 2.0 # ┌───────────────────────────────────────────┐ @@ -191,7 +215,8 @@ train: # Training configurable parameters val: # Validation configurable parameters dataloader: batch_size: 1 - pin_memory: false # if the preprocessing is outputing GPU data, set this to false + preload_depth: 1 + pin_memory: true # if the preprocessing is outputing GPU data, set this to false sampler: shuffle: true drop_last: false @@ -205,4 +230,6 @@ eval: # Testing configurable parameters checkpoint_name: DoMINO.0.455.pt # Name of checkpoint to select from saved checkpoints scaling_param_path: /user/scaling_params refine_stl: False # Automatically refine STL during inference - stencil_size: 7 # Stencil size for evaluating surface and volume model + #TODO - This was hardcoded anyways, remove it. + # stencil_size: 7 # Stencil size for evaluating surface and volume model + num_points: 1_240_000 # Number of points to sample on surface and volume per batch diff --git a/examples/cfd/external_aerodynamics/domino/src/deprecated/README.md b/examples/cfd/external_aerodynamics/domino/src/deprecated/README.md new file mode 100644 index 0000000000..fb7d062f56 --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/deprecated/README.md @@ -0,0 +1,5 @@ +# DoMINO Deprecation + +The files in this folder have been deprecated as of the PhysicsNeMo 25.11 release - +they are no longer officially supported. They are kept here only as a reference, +and may be removed in a future release. diff --git a/examples/cfd/external_aerodynamics/domino/src/deprecated/inference_on_stl.py b/examples/cfd/external_aerodynamics/domino/src/deprecated/inference_on_stl.py new file mode 100644 index 0000000000..b48e2b50f2 --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/deprecated/inference_on_stl.py @@ -0,0 +1,1617 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code defines a standalone distributed inference pipeline the DoMINO model. +This inference pipeline can be used to evaluate the model given an STL and +an inflow speed. The pre-trained model checkpoint can be specified in this script +or inferred from the config file. The results are calculated on a point cloud +sampled in the volume around the STL and on the surface of the STL. They are stored +in a dictionary, which can be written out for visualization. +""" + +import os +import time + +import hydra, re +from hydra import compose, initialize +from hydra.utils import to_absolute_path +from omegaconf import DictConfig, OmegaConf + +import numpy as np +import torch + +from physicsnemo.models.domino.model import DoMINO +from physicsnemo.utils.domino.utils import ( + unnormalize, + create_directory, + nd_interpolator, + get_filenames, + write_to_vtp, +) +from torch.cuda.amp import autocast +from torch.nn.parallel import DistributedDataParallel +from physicsnemo.distributed import DistributedManager + +from numpy.typing import NDArray +from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable +import warp as wp +from pathlib import Path +import pandas as pd +import matplotlib.pyplot as plt +import pyvista as pv + +try: + from physicsnemo.sym.geometry.tessellation import Tessellation + + SYM_AVAILABLE = True +except ImportError: + SYM_AVAILABLE = False + + +def combine_stls(stl_path, stl_files): + meshes = [] + combined_mesh = pv.PolyData() + for file in stl_files: + if ".stl" in file and "single_solid" not in file: + stl_file_path = os.path.join(stl_path, file) + reader = pv.get_reader(stl_file_path) + mesh_stl = reader.read() + combined_mesh = combined_mesh.merge(mesh_stl) + # meshes.append(mesh_stl) + break + # combined_mesh = pv.merge(meshes) + return combined_mesh + + +def plot(truth, prediction, var, save_path, axes_titles=None, plot_error=True): + if plot_error: + c = 3 + else: + c = 2 + fig, axes = plt.subplots(1, c, figsize=(15, 5)) + error = truth - prediction + # Plot Truth + im = axes[0].imshow( + truth, + cmap="jet", + vmax=np.ma.masked_invalid(truth).max(), + vmin=np.ma.masked_invalid(truth).min(), + ) + axes[0].axis("off") + cbar = fig.colorbar(im, ax=axes[0], orientation="vertical") + cbar.ax.tick_params(labelsize=12) + if axes_titles is None: + axes[0].set_title(f"{var} Truth") + else: + axes[0].set_title(axes_titles[0]) + + # Plot Predicted + im = axes[1].imshow( + prediction, + cmap="jet", + vmax=np.ma.masked_invalid(prediction).max(), + vmin=np.ma.masked_invalid(prediction).min(), + ) + axes[1].axis("off") + cbar = fig.colorbar(im, ax=axes[1], orientation="vertical") + cbar.ax.tick_params(labelsize=12) + if axes_titles is None: + axes[1].set_title(f"{var} Predicted") + else: + axes[1].set_title(axes_titles[1]) + + if plot_error: + # Plot Error + im = axes[2].imshow( + error, + cmap="jet", + vmax=np.ma.masked_invalid(error).max(), + vmin=np.ma.masked_invalid(error).min(), + ) + axes[2].axis("off") + cbar = fig.colorbar(im, ax=axes[2], orientation="vertical") + cbar.ax.tick_params(labelsize=12) + if axes_titles is None: + axes[2].set_title(f"{var} Error") + else: + axes[2].set_title(axes_titles[2]) + + MAE = np.mean(np.ma.masked_invalid((error))) + + if MAE: + fig.suptitle(f"MAE {MAE}", fontsize=18, x=0.5) + + plt.tight_layout() + + path_to_save_path = os.path.join(save_path) + plt.savefig(path_to_save_path, bbox_inches="tight", pad_inches=0.1) + plt.close() + + +@wp.kernel +def _bvh_query_distance( + mesh: wp.uint64, + points: wp.array(dtype=wp.vec3f), + max_dist: wp.float32, + sdf: wp.array(dtype=wp.float32), + sdf_hit_point: wp.array(dtype=wp.vec3f), + sdf_hit_point_id: wp.array(dtype=wp.int32), +): + """ + Computes the signed distance from each point in the given array `points` + to the mesh represented by `mesh`,within the maximum distance `max_dist`, + and stores the result in the array `sdf`. + + Parameters: + mesh (wp.uint64): The identifier of the mesh. + points (wp.array): An array of 3D points for which to compute the + signed distance. + max_dist (wp.float32): The maximum distance within which to search + for the closest point on the mesh. + sdf (wp.array): An array to store the computed signed distances. + sdf_hit_point (wp.array): An array to store the computed hit points. + sdf_hit_point_id (wp.array): An array to store the computed hit point ids. + + Returns: + None + """ + tid = wp.tid() + + res = wp.mesh_query_point_sign_winding_number(mesh, points[tid], max_dist) + + mesh_ = wp.mesh_get(mesh) + + p0 = mesh_.points[mesh_.indices[3 * res.face + 0]] + p1 = mesh_.points[mesh_.indices[3 * res.face + 1]] + p2 = mesh_.points[mesh_.indices[3 * res.face + 2]] + + p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2 + + sdf[tid] = res.sign * wp.abs(wp.length(points[tid] - p_closest)) + sdf_hit_point[tid] = p_closest + sdf_hit_point_id[tid] = res.face + + +def signed_distance_field( + mesh_vertices: list[tuple[float, float, float]], + mesh_indices: NDArray[float], + input_points: list[tuple[float, float, float]], + max_dist: float = 1e8, + include_hit_points: bool = False, + include_hit_points_id: bool = False, + device: int = 0, +) -> wp.array: + """ + Computes the signed distance field (SDF) for a given mesh and input points. + + Parameters: + ---------- + mesh_vertices (list[tuple[float, float, float]]): List of vertices defining the mesh. + mesh_indices (list[tuple[int, int, int]]): List of indices defining the triangles of the mesh. + input_points (list[tuple[float, float, float]]): List of input points for which to compute the SDF. + max_dist (float, optional): Maximum distance within which to search for + the closest point on the mesh. Default is 1e8. + include_hit_points (bool, optional): Whether to include hit points in + the output. Default is False. + include_hit_points_id (bool, optional): Whether to include hit point + IDs in the output. Default is False. + + Returns: + ------- + wp.array: An array containing the computed signed distance field. + + Example: + ------- + >>> mesh_vertices = [(0, 0, 0), (1, 0, 0), (0, 1, 0)] + >>> mesh_indices = np.array((0, 1, 2)) + >>> input_points = [(0.5, 0.5, 0.5)] + >>> signed_distance_field(mesh_vertices, mesh_indices, input_points).numpy() + Module ... + array([0.5], dtype=float32) + """ + + wp.init() + # mesh = wp.Mesh( + # wp.array(mesh_vertices.cpu(), dtype=wp.vec3), wp.array(mesh_indices.cpu(), dtype=wp.int32) + # ) + mesh = wp.Mesh( + wp.from_torch(mesh_vertices, dtype=wp.vec3), + wp.from_torch(mesh_indices, dtype=wp.int32), + ) + + sdf_points = wp.from_torch(input_points, dtype=wp.vec3) + sdf = wp.zeros(shape=sdf_points.shape, dtype=wp.float32) + sdf_hit_point = wp.zeros(shape=sdf_points.shape, dtype=wp.vec3f) + sdf_hit_point_id = wp.zeros(shape=sdf_points.shape, dtype=wp.int32) + wp.launch( + kernel=_bvh_query_distance, + dim=len(sdf_points), + inputs=[mesh.id, sdf_points, max_dist, sdf, sdf_hit_point, sdf_hit_point_id], + ) + if include_hit_points and include_hit_points_id: + return ( + wp.to_torch(sdf), + wp.to_torch(sdf_hit_point), + wp.to_torch(sdf_hit_point_id), + ) + elif include_hit_points: + return (wp.to_torch(sdf), wp.to_torch(sdf_hit_point)) + elif include_hit_points_id: + return (wp.to_torch(sdf), wp.to_torch(sdf_hit_point_id)) + else: + return wp.to_torch(sdf) + + +def shuffle_array_torch(surface_vertices, geometry_points, device): + idx = torch.unsqueeze( + torch.randperm(surface_vertices.shape[0])[:geometry_points], -1 + ).to(device) + idx = idx.repeat(1, 3) + surface_sampled = torch.gather(surface_vertices, 0, idx) + return surface_sampled + + +class inferenceDataPipe: + def __init__( + self, + device: int = 0, + grid_resolution: Optional[list] = [256, 96, 64], + normalize_coordinates: bool = False, + geom_points_sample: int = 300000, + positional_encoding: bool = False, + surface_vertices=None, + surface_indices=None, + surface_areas=None, + surface_centers=None, + use_sdf_basis=False, + ): + self.surface_vertices = surface_vertices + self.surface_indices = surface_indices + self.surface_areas = surface_areas + self.surface_centers = surface_centers + self.device = device + self.grid_resolution = grid_resolution + self.normalize_coordinates = normalize_coordinates + self.geom_points_sample = geom_points_sample + self.positional_encoding = positional_encoding + self.use_sdf_basis = use_sdf_basis + torch.manual_seed(int(42 + torch.cuda.current_device())) + self.data_dict = {} + + def clear_dict(self): + del self.data_dict + + def clear_volume_dict(self): + del self.data_dict["volume_mesh_centers"] + del self.data_dict["pos_enc_closest"] + del self.data_dict["pos_normals_com"] + del self.data_dict["sdf_nodes"] + + def create_grid_torch(self, mx, mn, nres): + start_time = time.time() + dx = torch.linspace(mn[0], mx[0], nres[0], device=self.device) + dy = torch.linspace(mn[1], mx[1], nres[1], device=self.device) + dz = torch.linspace(mn[2], mx[2], nres[2], device=self.device) + + xv, yv, zv = torch.meshgrid(dx, dy, dz, indexing="ij") + xv = torch.unsqueeze(xv, -1) + yv = torch.unsqueeze(yv, -1) + zv = torch.unsqueeze(zv, -1) + grid = torch.cat((xv, yv, zv), axis=-1) + return grid + + def process_surface_mesh(self, bounding_box=None, bounding_box_surface=None): + # Use coarse mesh to calculate SDF + surface_vertices = self.surface_vertices + surface_indices = self.surface_indices + surface_areas = self.surface_areas + surface_centers = self.surface_centers + + start_time = time.time() + + if bounding_box is None: + # Create a bounding box + s_max = torch.amax(surface_vertices, 0) + s_min = torch.amin(surface_vertices, 0) + + c_max = s_max + (s_max - s_min) / 2 + c_min = s_min - (s_max - s_min) / 2 + c_min[2] = s_min[2] + else: + c_min = bounding_box[0] + c_max = bounding_box[1] + + if bounding_box_surface is None: + # Create a bounding box + s_max = torch.amax(surface_vertices, 0) + s_min = torch.amin(surface_vertices, 0) + + surf_max = s_max + (s_max - s_min) / 2 + surf_min = s_min - (s_max - s_min) / 2 + surf_min[2] = s_min[2] + else: + surf_min = bounding_box_surface[0] + surf_max = bounding_box_surface[1] + + nx, ny, nz = self.grid_resolution + + grid = self.create_grid_torch(c_max, c_min, self.grid_resolution) + grid_reshaped = torch.reshape(grid, (nx * ny * nz, 3)) + + # SDF on grid + sdf_grid = signed_distance_field( + surface_vertices, surface_indices, grid_reshaped, device=self.device + ) + sdf_grid = torch.reshape(sdf_grid, (nx, ny, nz)) + + surface_areas = torch.unsqueeze(surface_areas, -1) + center_of_mass = torch.sum(surface_centers * surface_areas, 0) / torch.sum( + surface_areas + ) + + s_grid = self.create_grid_torch(surf_max, surf_min, self.grid_resolution) + surf_grid_reshaped = torch.reshape(s_grid, (nx * ny * nz, 3)) + + surf_sdf_grid = signed_distance_field( + surface_vertices, surface_indices, surf_grid_reshaped, device=self.device + ) + surf_sdf_grid = torch.reshape(surf_sdf_grid, (nx, ny, nz)) + + if self.normalize_coordinates: + sdf_grid = ( + 2.0 + * (sdf_grid - torch.amax(grid)) + / (torch.amax(grid) - torch.amin(grid)) + - 1.0 + ) + surf_sdf_grid = ( + 2.0 + * (surf_sdf_grid - torch.amax(s_grid)) + / (torch.amax(s_grid) - torch.amin(s_grid)) + - 1.0 + ) + grid = 2.0 * (grid - c_min) / (c_max - c_min) - 1.0 + s_grid = 2.0 * (s_grid - surf_min) / (surf_max - surf_min) - 1.0 + + surface_vertices = torch.unsqueeze(surface_vertices, 0) + grid = torch.unsqueeze(grid, 0) + s_grid = torch.unsqueeze(s_grid, 0) + sdf_grid = torch.unsqueeze(sdf_grid, 0) + surf_sdf_grid = torch.unsqueeze(surf_sdf_grid, 0) + max_min = [c_min, c_max] + surf_max_min = [surf_min, surf_max] + center_of_mass = center_of_mass + + return ( + surface_vertices, + grid, + sdf_grid, + max_min, + s_grid, + surf_sdf_grid, + surf_max_min, + center_of_mass, + ) + + def sample_stl_points( + self, + num_points, + stl_centers, + stl_area, + stl_normals, + max_min, + center_of_mass, + bounding_box=None, + stencil_size=7, + ): + if bounding_box is not None: + c_max = bounding_box[1] + c_min = bounding_box[0] + else: + c_min = max_min[0] + c_max = max_min[1] + + start_time = time.time() + + nx, ny, nz = self.grid_resolution + + idx = np.arange(stl_centers.shape[0]) + # np.random.shuffle(idx) + if num_points is not None: + idx = idx[:num_points] + + surface_coordinates = stl_centers + surface_normals = stl_normals + surface_area = stl_area + + if stencil_size > 1: + interp_func = KDTree(surface_coordinates) + dd, ii = interp_func.query(surface_coordinates, k=stencil_size) + surface_neighbors = surface_coordinates[ii] + surface_neighbors = surface_neighbors[:, 1:] + 1e-6 + surface_neighbors_normals = surface_normals[ii] + surface_neighbors_normals = surface_neighbors_normals[:, 1:] + surface_neighbors_area = surface_area[ii] + surface_neighbors_area = surface_neighbors_area[:, 1:] + else: + surface_neighbors = np.expand_dims(surface_coordinates, 1) + 1e-6 + surface_neighbors_normals = np.expand_dims(surface_normals, 1) + surface_neighbors_area = np.expand_dims(surface_area, 1) + + surface_coordinates = torch.from_numpy(surface_coordinates).to(self.device) + surface_normals = torch.from_numpy(surface_normals).to(self.device) + surface_area = torch.from_numpy(surface_area).to(self.device) + surface_neighbors = torch.from_numpy(surface_neighbors).to(self.device) + surface_neighbors_normals = torch.from_numpy(surface_neighbors_normals).to( + self.device + ) + surface_neighbors_area = torch.from_numpy(surface_neighbors_area).to( + self.device + ) + + pos_normals_com = surface_coordinates - center_of_mass + + if self.normalize_coordinates: + surface_coordinates = ( + 2.0 * (surface_coordinates - c_min) / (c_max - c_min) - 1.0 + ) + surface_neighbors = ( + 2.0 * (surface_neighbors - c_min) / (c_max - c_min) - 1.0 + ) + + surface_coordinates = surface_coordinates[idx] + surface_area = surface_area[idx] + surface_normals = surface_normals[idx] + pos_normals_com = pos_normals_com[idx] + surface_coordinates = torch.unsqueeze(surface_coordinates, 0) + surface_normals = torch.unsqueeze(surface_normals, 0) + surface_area = torch.unsqueeze(surface_area, 0) + pos_normals_com = torch.unsqueeze(pos_normals_com, 0) + + surface_neighbors = surface_neighbors[idx] + surface_neighbors_normals = surface_neighbors_normals[idx] + surface_neighbors_area = surface_neighbors_area[idx] + surface_neighbors = torch.unsqueeze(surface_neighbors, 0) + surface_neighbors_normals = torch.unsqueeze(surface_neighbors_normals, 0) + surface_neighbors_area = torch.unsqueeze(surface_neighbors_area, 0) + + scaling_factors = [c_max, c_min] + + return ( + surface_coordinates, + surface_neighbors, + surface_normals, + surface_neighbors_normals, + surface_area, + surface_neighbors_area, + pos_normals_com, + scaling_factors, + idx, + ) + + def sample_points_on_surface( + self, + num_points_surf, + max_min, + center_of_mass, + stl_path, + bounding_box=None, + stencil_size=7, + ): + if bounding_box is not None: + c_max = bounding_box[1] + c_min = bounding_box[0] + else: + c_min = max_min[0] + c_max = max_min[1] + + start_time = time.time() + + nx, ny, nz = self.grid_resolution + + obj = Tessellation.from_stl(stl_path, airtight=False) + + boundary = obj.sample_boundary(num_points_surf) + surface_coordinates = np.concatenate( + [ + np.float32(boundary["x"]), + np.float32(boundary["y"]), + np.float32(boundary["z"]), + ], + axis=1, + ) + surface_normals = np.concatenate( + [ + np.float32(boundary["normal_x"]), + np.float32(boundary["normal_y"]), + np.float32(boundary["normal_z"]), + ], + axis=1, + ) + + surface_area = np.float32(boundary["area"]) + + if self.normalize_coordinates: + surface_coordinates = ( + 2.0 * (surface_coordinates - c_min) / (c_max - c_min) - 1.0 + ) + center_of_mass_normalized = ( + 2.0 * (center_of_mass - c_min) / (c_max - c_min) - 1.0 + ) + else: + center_of_mass_normalized = center_of_mass + + interp_func = KDTree(surface_coordinates) + dd, ii = interp_func.query(surface_coordinates, k=stencil_size) + surface_neighbors = surface_coordinates[ii] + surface_neighbors = surface_neighbors[:, 1:] + surface_neighbors_normals = surface_normals[ii] + surface_neighbors_normals = surface_neighbors_normals[:, 1:] + surface_neighbors_area = surface_area[ii] + surface_neighbors_area = surface_neighbors_area[:, 1:] + + surface_coordinates = torch.from_numpy(surface_coordinates).to(self.device) + surface_normals = torch.from_numpy(surface_normals).to(self.device) + surface_area = torch.from_numpy(surface_area).to(self.device) + surface_neighbors = torch.from_numpy(surface_neighbors).to(self.device) + surface_neighbors_normals = torch.from_numpy(surface_neighbors_normals).to( + self.device + ) + surface_neighbors_area = torch.from_numpy(surface_neighbors_area).to( + self.device + ) + + pos_normals_com = surface_coordinates - center_of_mass_normalized + + surface_coordinates = torch.unsqueeze(surface_coordinates, 0) + surface_normals = torch.unsqueeze(surface_normals, 0) + surface_area = torch.unsqueeze(surface_area, 0) + pos_normals_com = torch.unsqueeze(pos_normals_com, 0) + + surface_neighbors = torch.unsqueeze(surface_neighbors, 0) + surface_neighbors_normals = torch.unsqueeze(surface_neighbors_normals, 0) + surface_neighbors_area = torch.unsqueeze(surface_neighbors_area, 0) + + scaling_factors = [c_max, c_min] + + return ( + surface_coordinates, + surface_neighbors, + surface_normals, + surface_neighbors_normals, + surface_area, + surface_neighbors_area, + pos_normals_com, + scaling_factors, + ) + + def sample_points_in_volume( + self, num_points_vol, max_min, center_of_mass, bounding_box=None + ): + if bounding_box is not None: + c_max = bounding_box[1] + c_min = bounding_box[0] + else: + c_min = max_min[0] + c_max = max_min[1] + + start_time = time.time() + + nx, ny, nz = self.grid_resolution + for k in range(10): + if k > 0: + num_pts_vol = num_points_vol - int(volume_coordinates.shape[0] / 2) + else: + num_pts_vol = int(1.25 * num_points_vol) + + volume_coordinates_sub = (c_max - c_min) * torch.rand( + num_pts_vol, 3, device=self.device, dtype=torch.float32 + ) + c_min + + sdf_nodes, sdf_node_closest_point = signed_distance_field( + self.surface_vertices, + self.surface_indices, + volume_coordinates_sub, + include_hit_points=True, + device=self.device, + ) + sdf_nodes = torch.unsqueeze(sdf_nodes, -1) + + idx = torch.unsqueeze(torch.where((sdf_nodes > 0))[0], -1) + idx = idx.repeat(1, volume_coordinates_sub.shape[1]) + if k == 0: + volume_coordinates = torch.gather(volume_coordinates_sub, 0, idx) + else: + volume_coordinates_1 = torch.gather(volume_coordinates_sub, 0, idx) + volume_coordinates = torch.cat( + (volume_coordinates, volume_coordinates_1), axis=0 + ) + + if volume_coordinates.shape[0] > num_points_vol: + volume_coordinates = volume_coordinates[:num_points_vol] + break + + sdf_nodes, sdf_node_closest_point = signed_distance_field( + self.surface_vertices, + self.surface_indices, + volume_coordinates, + include_hit_points=True, + device=self.device, + ) + sdf_nodes = torch.unsqueeze(sdf_nodes, -1) + + if self.normalize_coordinates: + volume_coordinates = ( + 2.0 * (volume_coordinates - c_min) / (c_max - c_min) - 1.0 + ) + sdf_nodes = ( + 2.0 + * (sdf_nodes - torch.amax(c_max)) + / (torch.amax(c_max) - torch.amin(c_min)) + - 1.0 + ) + sdf_node_closest_point = ( + 2.0 * (sdf_node_closest_point - c_min) / (c_max - c_min) - 1.0 + ) + center_of_mass_normalized = ( + 2.0 * (center_of_mass - c_min) / (c_max - c_min) - 1.0 + ) + else: + center_of_mass_normalized = center_of_mass + + pos_normals_closest = volume_coordinates - sdf_node_closest_point + pos_normals_com = volume_coordinates - center_of_mass_normalized + + volume_coordinates = torch.unsqueeze(volume_coordinates, 0) + pos_normals_com = torch.unsqueeze(pos_normals_com, 0) + + if self.use_sdf_basis: + pos_normals_closest = torch.unsqueeze(pos_normals_closest, 0) + sdf_nodes = torch.unsqueeze(sdf_nodes, 0) + + scaling_factors = [c_max, c_min] + return ( + volume_coordinates, + pos_normals_com, + pos_normals_closest, + sdf_nodes, + scaling_factors, + ) + + +class dominoInference: + def __init__( + self, + cfg: DictConfig, + dist: None, + cached_geo_encoding: bool = False, + ): + self.cfg = cfg + self.dist = dist + self.stream_velocity = None + self.stencil_size = None + self.stl_path = None + self.stl_vertices = None + self.stl_centers = None + self.surface_areas = None + self.mesh_indices_flattened = None + self.length_scale = 1.0 + if self.dist is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = self.dist.device + + self.air_density = torch.full((1, 1), 1.205, dtype=torch.float32).to( + self.device + ) + ( + self.num_vol_vars, + self.num_surf_vars, + self.num_global_features, + ) = self.get_num_variables() + self.model = None + self.grid_resolution = torch.tensor(self.cfg.model.interp_res).to(self.device) + self.vol_factors = None + self.bounding_box_min_max = None + self.bounding_box_surface_min_max = None + self.center_of_mass = None + self.grid = None + self.geometry_encoding = None + self.geometry_encoding_surface = None + self.cached_geo_encoding = cached_geo_encoding + self.out_dict = {} + + def get_geometry_encoding(self): + return self.geometry_encoding + + def get_geometry_encoding_surface(self): + return self.geometry_encoding_surface + + def get_out_dict(self): + return self.out_dict + + def clear_out_dict(self): + self.out_dict.clear() + + def initialize_data_processor(self): + self.ifp = inferenceDataPipe( + device=self.device, + surface_vertices=self.stl_vertices, + surface_indices=self.mesh_indices_flattened, + surface_areas=self.surface_areas, + surface_centers=self.stl_centers, + grid_resolution=self.grid_resolution, + normalize_coordinates=True, + geom_points_sample=300000, + positional_encoding=False, + use_sdf_basis=self.cfg.model.use_sdf_in_basis_func, + ) + + def load_bounding_box(self): + if ( + self.cfg.data.bounding_box.min is not None + and self.cfg.data.bounding_box.max is not None + ): + c_min = torch.from_numpy( + np.array(self.cfg.data.bounding_box.min, dtype=np.float32) + ).to(self.device) + c_max = torch.from_numpy( + np.array(self.cfg.data.bounding_box.max, dtype=np.float32) + ).to(self.device) + self.bounding_box_min_max = [c_min, c_max] + + if ( + self.cfg.data.bounding_box_surface.min is not None + and self.cfg.data.bounding_box_surface.max is not None + ): + c_min = torch.from_numpy( + np.array(self.cfg.data.bounding_box_surface.min, dtype=np.float32) + ).to(self.device) + c_max = torch.from_numpy( + np.array(self.cfg.data.bounding_box_surface.max, dtype=np.float32) + ).to(self.device) + self.bounding_box_surface_min_max = [c_min, c_max] + + def load_volume_scaling_factors(self): + scaling_param_path = self.cfg.eval.scaling_param_path + vol_factors_path = os.path.join( + scaling_param_path, "volume_scaling_factors.npy" + ) + + vol_factors = np.load(vol_factors_path, allow_pickle=True) + vol_factors = torch.from_numpy(vol_factors).to(self.device) + + return vol_factors + + def load_surface_scaling_factors(self): + scaling_param_path = self.cfg.eval.scaling_param_path + surf_factors_path = os.path.join( + scaling_param_path, "surface_scaling_factors.npy" + ) + + surf_factors = np.load(surf_factors_path, allow_pickle=True) + surf_factors = torch.from_numpy(surf_factors).to(self.device) + + return surf_factors + + def read_stl(self): + stl_files = get_filenames(self.stl_path) + mesh_stl = combine_stls(self.stl_path, stl_files) + if self.cfg.eval.refine_stl: + mesh_stl = mesh_stl.subdivide( + nsub=2, subfilter="linear" + ) # .smooth(n_iter=20) + stl_vertices = mesh_stl.points + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + stl_centers = mesh_stl.cell_centers().points + # Assuming triangular elements + stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[:, 1:] + mesh_indices_flattened = stl_faces.flatten() + + surface_areas = mesh_stl.compute_cell_sizes( + length=False, area=True, volume=False + ) + surface_areas = np.array(surface_areas.cell_data["Area"]) + + surface_normals = np.array(mesh_stl.cell_normals, dtype=np.float32) + + self.stl_vertices = torch.from_numpy(np.float32(stl_vertices)).to(self.device) + self.stl_centers = torch.from_numpy(np.float32(stl_centers)).to(self.device) + self.surface_areas = torch.from_numpy(np.float32(surface_areas)).to(self.device) + self.stl_normals = -1.0 * torch.from_numpy(np.float32(surface_normals)).to( + self.device + ) + self.mesh_indices_flattened = torch.from_numpy( + np.int32(mesh_indices_flattened) + ).to(self.device) + self.length_scale = length_scale + self.mesh_stl = mesh_stl + + def read_stl_trimesh( + self, stl_vertices, stl_faces, stl_centers, surface_normals, surface_areas + ): + mesh_indices_flattened = stl_faces.flatten() + length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + self.stl_vertices = torch.from_numpy(stl_vertices).to(self.device) + self.stl_centers = torch.from_numpy(stl_centers).to(self.device) + self.stl_normals = -1.0 * torch.from_numpy(surface_normals).to(self.device) + self.surface_areas = torch.from_numpy(surface_areas).to(self.device) + self.mesh_indices_flattened = torch.from_numpy( + np.int32(mesh_indices_flattened) + ).to(self.device) + self.length_scale = length_scale + + def get_num_variables(self): + volume_variable_names = list(self.cfg.variables.volume.solution.keys()) + num_vol_vars = 0 + for j in volume_variable_names: + if self.cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + + surface_variable_names = list(self.cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if self.cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + + num_global_features = 0 + global_params_names = list(cfg.variables.global_parameters.keys()) + for param in global_params_names: + if cfg.variables.global_parameters[param].type == "vector": + num_global_features += len( + cfg.variables.global_parameters[param].reference + ) + elif cfg.variables.global_parameters[param].type == "scalar": + num_global_features += 1 + else: + raise ValueError(f"Unknown global parameter type") + + return num_vol_vars, num_surf_vars, num_global_features + + def initialize_model(self, model_path): + model = ( + DoMINO( + input_features=3, + output_features_vol=self.num_vol_vars, + output_features_surf=self.num_surf_vars, + global_features=self.num_global_features, + model_parameters=self.cfg.model, + ) + .to(self.device) + .eval() + ) + model = torch.compile(model, disable=True) + + checkpoint_iter = torch.load( + to_absolute_path(model_path), map_location=self.dist.device + ) + + model.load_state_dict(checkpoint_iter) + + if self.dist is not None: + if self.dist.world_size > 1: + model = DistributedDataParallel( + model, + device_ids=[self.dist.local_rank], + output_device=self.dist.device, + broadcast_buffers=self.dist.broadcast_buffers, + find_unused_parameters=self.dist.find_unused_parameters, + gradient_as_bucket_view=True, + static_graph=True, + ) + + self.model = model + self.vol_factors = self.load_volume_scaling_factors() + self.surf_factors = self.load_surface_scaling_factors() + self.load_bounding_box() + + def set_stream_velocity(self, stream_velocity): + self.stream_velocity = torch.full( + (1, 1), stream_velocity, dtype=torch.float32 + ).to(self.device) + + def set_stencil_size(self, stencil_size): + self.stencil_size = stencil_size + + def set_air_density(self, air_density): + self.air_density = torch.full((1, 1), air_density, dtype=torch.float32).to( + self.device + ) + + def set_stl_path(self, filename): + self.stl_path = filename + + @torch.no_grad() + def compute_geo_encoding(self, cached_geom_path=None): + start_time = time.time() + + if not self.cached_geo_encoding: + ( + surface_vertices, + grid, + sdf_grid, + max_min, + s_grid, + surf_sdf_grid, + surf_max_min, + center_of_mass, + ) = self.ifp.process_surface_mesh( + self.bounding_box_min_max, self.bounding_box_surface_min_max + ) + if self.bounding_box_min_max is None: + self.bounding_box_min_max = max_min + if self.bounding_box_surface_min_max is None: + self.bounding_box_surface_min_max = surf_max_min + self.center_of_mass = center_of_mass + self.grid = grid + self.s_grid = s_grid + self.sdf_grid = sdf_grid + self.surf_sdf_grid = surf_sdf_grid + self.out_dict["sdf"] = sdf_grid + + geo_encoding, geo_encoding_surface = self.calculate_geometry_encoding( + surface_vertices, grid, sdf_grid, s_grid, surf_sdf_grid, self.model + ) + else: + out_dict_cached = torch.load(cached_geom_path, map_location=self.device) + self.bounding_box_min_max = out_dict_cached["bounding_box_min_max"] + self.grid = out_dict_cached["grid"] + self.sdf_grid = out_dict_cached["sdf_grid"] + self.center_of_mass = out_dict_cached["com"] + geo_encoding = out_dict_cached["geo_encoding"] + geo_encoding_surface = out_dict_cached["geo_encoding_surface"] + self.out_dict["sdf"] = self.sdf_grid + torch.cuda.synchronize() + print("Time taken for geo encoding = %f" % (time.time() - start_time)) + + self.geometry_encoding = geo_encoding + self.geometry_encoding_surface = geo_encoding_surface + + def compute_forces(self): + pressure = self.out_dict["pressure_surface"] + wall_shear = self.out_dict["wall-shear-stress"] + # sampling_indices = self.out_dict["sampling_indices"] + + surface_normals = self.stl_normals[self.sampling_indices] + surface_areas = self.surface_areas[self.sampling_indices] + + drag_force = torch.sum( + pressure[0, :, 0] * surface_normals[:, 0] * surface_areas + - wall_shear[0, :, 0] * surface_areas + ) + lift_force = torch.sum( + pressure[0, :, 0] * surface_normals[:, 2] * surface_areas + - wall_shear[0, :, 2] * surface_areas + ) + + self.out_dict["drag_force"] = drag_force + self.out_dict["lift_force"] = lift_force + + @torch.inference_mode() + def compute_surface_solutions(self, num_sample_points=None, plot_solutions=False): + total_time = 0.0 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + geo_encoding = self.geometry_encoding_surface + j = 0 + + with autocast(enabled=True): + start_event.record() + ( + surface_mesh_centers, + surface_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + pos_normals_com, + surf_scaling_factors, + sampling_indices, + ) = self.ifp.sample_stl_points( + num_sample_points, + self.stl_centers.cpu().numpy(), + self.surface_areas.cpu().numpy(), + self.stl_normals.cpu().numpy(), + max_min=self.bounding_box_surface_min_max, + center_of_mass=self.center_of_mass, + stencil_size=self.stencil_size, + ) + end_event.record() + end_event.synchronize() + cur_time = start_event.elapsed_time(end_event) / 1000.0 + print(f"sample_points_in_surface time (s): {cur_time:.4f}") + # vol_coordinates_all.append(volume_mesh_centers) + surface_coordinates_all = surface_mesh_centers + + inner_time = time.time() + start_event.record() + if num_sample_points == None: + point_batch_size = 512_000 + num_points = surface_coordinates_all.shape[1] + subdomain_points = int(np.floor(num_points / point_batch_size)) + surface_solutions = torch.zeros(1, num_points, self.num_surf_vars).to( + self.device + ) + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + surface_solutions_batch = self.compute_solution_on_surface( + geo_encoding, + surface_mesh_centers[:, start_idx:end_idx], + surface_neighbors[:, start_idx:end_idx], + surface_normals[:, start_idx:end_idx], + surface_neighbors_normals[:, start_idx:end_idx], + surface_areas[:, start_idx:end_idx], + surface_neighbors_areas[:, start_idx:end_idx], + pos_normals_com[:, start_idx:end_idx], + self.s_grid, + self.model, + inlet_velocity=self.stream_velocity, + air_density=self.air_density, + ) + surface_solutions[:, start_idx:end_idx] = surface_solutions_batch + else: + point_batch_size = 512_000 + num_points = num_sample_points + subdomain_points = int(np.floor(num_points / point_batch_size)) + surface_solutions = torch.zeros(1, num_points, self.num_surf_vars).to( + self.device + ) + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + surface_solutions_batch = self.compute_solution_on_surface( + geo_encoding, + surface_mesh_centers[:, start_idx:end_idx], + surface_neighbors[:, start_idx:end_idx], + surface_normals[:, start_idx:end_idx], + surface_neighbors_normals[:, start_idx:end_idx], + surface_areas[:, start_idx:end_idx], + surface_neighbors_areas[:, start_idx:end_idx], + pos_normals_com[:, start_idx:end_idx], + self.s_grid, + self.model, + inlet_velocity=self.stream_velocity, + air_density=self.air_density, + ) + # print(torch.amax(surface_solutions_batch, (0, 1)), torch.amin(surface_solutions_batch, (0, 1))) + surface_solutions[:, start_idx:end_idx] = surface_solutions_batch + + # print(surface_solutions.shape) + end_event.record() + end_event.synchronize() + cur_time = start_event.elapsed_time(end_event) / 1000.0 + print(f"compute_solution time (s): {cur_time:.4f}") + total_time += float(time.time() - inner_time) + surface_solutions_all = surface_solutions + print( + "Time taken for compute solution on surface for=%f, %f" + % (time.time() - inner_time, torch.cuda.utilization(self.device)) + ) + cmax = surf_scaling_factors[0] + cmin = surf_scaling_factors[1] + + surface_coordinates_all = torch.reshape( + surface_coordinates_all, (1, num_points, 3) + ) + surface_solutions_all = torch.reshape(surface_solutions_all, (1, num_points, 4)) + + if self.surf_factors is not None: + surface_solutions_all = unnormalize( + surface_solutions_all, self.surf_factors[0], self.surf_factors[1] + ) + + self.out_dict["surface_coordinates"] = ( + 0.5 * (surface_coordinates_all + 1.0) * (cmax - cmin) + cmin + ) + self.out_dict["pressure_surface"] = ( + surface_solutions_all[:, :, :1] + * self.stream_velocity**2.0 + * self.air_density + ) + self.out_dict["wall-shear-stress"] = ( + surface_solutions_all[:, :, 1:4] + * self.stream_velocity**2.0 + * self.air_density + ) + self.sampling_indices = sampling_indices + + @torch.inference_mode() + def compute_volume_solutions(self, num_sample_points, plot_solutions=False): + total_time = 0.0 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + geo_encoding = self.geometry_encoding + j = 0 + + # Compute volume + point_batch_size = 512_000 + num_points = num_sample_points + subdomain_points = int(np.floor(num_points / point_batch_size)) + volume_solutions = torch.zeros(1, num_points, self.num_vol_vars).to(self.device) + volume_coordinates = torch.zeros(1, num_points, 3).to(self.device) + + for p in range(subdomain_points + 1): + start_idx = p * point_batch_size + end_idx = (p + 1) * point_batch_size + if end_idx > num_points: + point_batch_size = num_points - start_idx + end_idx = num_points + + with autocast(enabled=True): + inner_time = time.time() + start_event.record() + ( + volume_mesh_centers, + pos_normals_com, + pos_normals_closest, + sdf_nodes, + scaling_factors, + ) = self.ifp.sample_points_in_volume( + num_points_vol=point_batch_size, + max_min=self.bounding_box_min_max, + center_of_mass=self.center_of_mass, + ) + end_event.record() + end_event.synchronize() + cur_time = start_event.elapsed_time(end_event) / 1000.0 + print(f"sample_points_in_volume time (s): {cur_time:.4f}") + + volume_coordinates[:, start_idx:end_idx] = volume_mesh_centers + + start_event.record() + + volume_solutions_batch = self.compute_solution_in_volume( + geo_encoding, + volume_mesh_centers, + sdf_nodes, + pos_normals_closest, + pos_normals_com, + self.grid, + self.model, + use_sdf_basis=self.cfg.model.use_sdf_in_basis_func, + inlet_velocity=self.stream_velocity, + air_density=self.air_density, + ) + volume_solutions[:, start_idx:end_idx] = volume_solutions_batch + end_event.record() + end_event.synchronize() + cur_time = start_event.elapsed_time(end_event) / 1000.0 + print(f"compute_solution time (s): {cur_time:.4f}") + total_time += float(time.time() - inner_time) + # volume_solutions_all = volume_solutions + print( + "Time taken for compute solution in volume for =%f" + % (time.time() - inner_time) + ) + # print("Points processed:", end_idx) + print("Total time measured = %f" % total_time) + print("Points processed:", end_idx) + + cmax = scaling_factors[0] + cmin = scaling_factors[1] + volume_coordinates_all = volume_coordinates + volume_solutions_all = volume_solutions + + cmax = scaling_factors[0] + cmin = scaling_factors[1] + + volume_coordinates_all = torch.reshape( + volume_coordinates_all, (1, num_sample_points, 3) + ) + volume_solutions_all = torch.reshape( + volume_solutions_all, (1, num_sample_points, self.num_vol_vars) + ) + + if self.vol_factors is not None: + volume_solutions_all = unnormalize( + volume_solutions_all, self.vol_factors[0], self.vol_factors[1] + ) + + self.out_dict["coordinates"] = ( + 0.5 * (volume_coordinates_all + 1.0) * (cmax - cmin) + cmin + ) + self.out_dict["velocity"] = ( + volume_solutions_all[:, :, :3] * self.stream_velocity + ) + self.out_dict["pressure"] = ( + volume_solutions_all[:, :, 3:4] + * self.stream_velocity**2.0 + * self.air_density + ) + # self.out_dict["turbulent-kinetic-energy"] = ( + # volume_solutions_all[:, :, 4:5] + # * self.stream_velocity**2.0 + # * self.air_density + # ) + # self.out_dict["turbulent-viscosity"] = ( + # volume_solutions_all[:, :, 5:] * self.stream_velocity * self.length_scale + # ) + self.out_dict["bounding_box_dims"] = torch.vstack(self.bounding_box_min_max) + + if plot_solutions: + print("Plotting solutions") + plot_save_path = os.path.join(self.cfg.output, "plots/contours/") + create_directory(plot_save_path) + + p_grid = 0.5 * (self.grid + 1.0) * (cmax - cmin) + cmin + p_grid = p_grid.cpu().numpy() + sdf_grid = self.sdf_grid.cpu().numpy() + volume_coordinates_all = ( + 0.5 * (volume_coordinates_all + 1.0) * (cmax - cmin) + cmin + ) + volume_solutions_all[:, :, :3] = ( + volume_solutions_all[:, :, :3] * self.stream_velocity + ) + volume_solutions_all[:, :, 3:4] = ( + volume_solutions_all[:, :, 3:4] + * self.stream_velocity**2.0 + * self.air_density + ) + # volume_solutions_all[:, :, 4:5] = ( + # volume_solutions_all[:, :, 4:5] + # * self.stream_velocity**2.0 + # * self.air_density + # ) + # volume_solutions_all[:, :, 5] = ( + # volume_solutions_all[:, :, 5] * self.stream_velocity * self.length_scale + # ) + volume_coordinates_all = volume_coordinates_all.cpu().numpy() + volume_solutions_all = volume_solutions_all.cpu().numpy() + + # ND interpolation on a grid + prediction_grid = nd_interpolator( + volume_coordinates_all, volume_solutions_all[0], p_grid[0] + ) + nx, ny, nz, vars = prediction_grid.shape + idx = np.where(sdf_grid[0] < 0.0) + prediction_grid[idx] = float("inf") + axes_titles = ["y/4 plane", "y/2 plane"] + + plot( + prediction_grid[:, int(ny / 4), :, 0], + prediction_grid[:, int(ny / 2), :, 0], + var="x-vel", + save_path=plot_save_path + f"x-vel-midplane_{self.stream_velocity}.png", + axes_titles=axes_titles, + plot_error=False, + ) + plot( + prediction_grid[:, int(ny / 4), :, 1], + prediction_grid[:, int(ny / 2), :, 1], + var="y-vel", + save_path=plot_save_path + f"y-vel-midplane_{self.stream_velocity}.png", + axes_titles=axes_titles, + plot_error=False, + ) + plot( + prediction_grid[:, int(ny / 4), :, 2], + prediction_grid[:, int(ny / 2), :, 2], + var="z-vel", + save_path=plot_save_path + f"z-vel-midplane_{self.stream_velocity}.png", + axes_titles=axes_titles, + plot_error=False, + ) + plot( + prediction_grid[:, int(ny / 4), :, 3], + prediction_grid[:, int(ny / 2), :, 3], + var="pres", + save_path=plot_save_path + f"pres-midplane_{self.stream_velocity}.png", + axes_titles=axes_titles, + plot_error=False, + ) + # plot( + # prediction_grid[:, int(ny / 4), :, 4], + # prediction_grid[:, int(ny / 2), :, 4], + # var="tke", + # save_path=plot_save_path + f"tke-midplane_{self.stream_velocity}.png", + # axes_titles=axes_titles, + # plot_error=False, + # ) + # plot( + # prediction_grid[:, int(ny / 4), :, 5], + # prediction_grid[:, int(ny / 2), :, 5], + # var="nut", + # save_path=plot_save_path + f"nut-midplane_{self.stream_velocity}.png", + # axes_titles=axes_titles, + # plot_error=False, + # ) + + def cold_start(self, cached_geom_path=None): + print("Cold start") + self.compute_geo_encoding(cached_geom_path) + self.compute_volume_solutions(num_sample_points=10) + self.clear_out_dict() + + @torch.no_grad() + def calculate_geometry_encoding( + self, geo_centers, p_grid, sdf_grid, s_grid, sdf_surf_grid, model + ): + vol_min = self.bounding_box_min_max[0] + vol_max = self.bounding_box_min_max[1] + surf_min = self.bounding_box_surface_min_max[0] + surf_max = self.bounding_box_surface_min_max[1] + + geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + if self.dist.world_size == 1: + encoding_g_vol = model.geo_rep_volume(geo_centers_vol, p_grid, sdf_grid) + else: + encoding_g_vol = model.module.geo_rep_volume( + geo_centers_vol, p_grid, sdf_grid + ) + + geo_centers_surf = 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + + if self.dist.world_size == 1: + encoding_g_surf = model.geo_rep_surface( + geo_centers_surf, s_grid, sdf_surf_grid + ) + else: + encoding_g_surf = model.module.geo_rep_surface( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + if self.dist.world_size == 1: + encoding_g_surf1 = model.geo_rep_surface1( + geo_centers_surf, s_grid, sdf_surf_grid + ) + else: + encoding_g_surf1 = model.module.geo_rep_surface1( + geo_centers_surf, s_grid, sdf_surf_grid + ) + + geo_encoding = 0.5 * encoding_g_surf1 + 0.5 * encoding_g_vol + geo_encoding_surface = 0.5 * encoding_g_surf + return geo_encoding, geo_encoding_surface + + @torch.no_grad() + def compute_solution_on_surface( + self, + geo_encoding, + surface_mesh_centers, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + pos_normals_com, + s_grid, + model, + inlet_velocity, + air_density, + ): + """ + Global parameters: For this particular case, the model was trained on single velocity/density values + across all simulations. Hence, global_params_values and global_params_reference are the same. + """ + global_params_values = torch.cat( + (inlet_velocity, air_density), axis=1 + ) # (1, 2) + global_params_values = torch.unsqueeze(global_params_values, -1) # (1, 2, 1) + + global_params_reference = torch.cat( + (inlet_velocity, air_density), axis=1 + ) # (1, 2) + global_params_reference = torch.unsqueeze( + global_params_reference, -1 + ) # (1, 2, 1) + + if self.dist.world_size == 1: + geo_encoding_local = model.geo_encoding_local( + geo_encoding, surface_mesh_centers, s_grid, mode="surface" + ) + else: + geo_encoding_local = model.module.geo_encoding_local( + geo_encoding, surface_mesh_centers, s_grid, mode="surface" + ) + + pos_encoding = pos_normals_com + surface_areas = torch.unsqueeze(surface_areas, -1) + surface_neighbors_areas = torch.unsqueeze(surface_neighbors_areas, -1) + + if self.dist.world_size == 1: + pos_encoding = model.position_encoder(pos_encoding, eval_mode="surface") + tpredictions_batch = model.calculate_solution_with_neighbors( + surface_mesh_centers, + geo_encoding_local, + pos_encoding, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + global_params_values, + global_params_reference, + ) + else: + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="surface" + ) + tpredictions_batch = model.module.calculate_solution_with_neighbors( + surface_mesh_centers, + geo_encoding_local, + pos_encoding, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + global_params_values, + global_params_reference, + ) + + return tpredictions_batch + + @torch.no_grad() + def compute_solution_in_volume( + self, + geo_encoding, + volume_mesh_centers, + sdf_nodes, + pos_enc_closest, + pos_normals_com, + p_grid, + model, + use_sdf_basis, + inlet_velocity, + air_density, + ): + ## Global parameters + global_params_values = torch.cat( + (inlet_velocity, air_density), axis=1 + ) # (1, 2) + global_params_values = torch.unsqueeze(global_params_values, -1) # (1, 2, 1) + + global_params_reference = torch.cat( + (inlet_velocity, air_density), axis=1 + ) # (1, 2) + global_params_reference = torch.unsqueeze( + global_params_reference, -1 + ) # (1, 2, 1) + + if self.dist.world_size == 1: + geo_encoding_local = model.geo_encoding_local( + geo_encoding, volume_mesh_centers, p_grid, mode="volume" + ) + else: + geo_encoding_local = model.module.geo_encoding_local( + geo_encoding, volume_mesh_centers, p_grid, mode="volume" + ) + if use_sdf_basis: + pos_encoding = torch.cat( + (sdf_nodes, pos_enc_closest, pos_normals_com), axis=-1 + ) + else: + pos_encoding = pos_normals_com + + if self.dist.world_size == 1: + pos_encoding = model.position_encoder(pos_encoding, eval_mode="volume") + tpredictions_batch = model.calculate_solution( + volume_mesh_centers, + geo_encoding_local, + pos_encoding, + global_params_values, + global_params_reference, + num_sample_points=self.stencil_size, + eval_mode="volume", + ) + else: + pos_encoding = model.module.position_encoder( + pos_encoding, eval_mode="volume" + ) + tpredictions_batch = model.module.calculate_solution( + volume_mesh_centers, + geo_encoding_local, + pos_encoding, + global_params_values, + global_params_reference, + num_sample_points=self.stencil_size, + eval_mode="volume", + ) + return tpredictions_batch + + +if __name__ == "__main__": + OmegaConf.register_new_resolver("eval", eval) + with initialize(version_base="1.3", config_path="conf"): + cfg = compose(config_name="config") + + DistributedManager.initialize() + dist = DistributedManager() + + if dist.world_size > 1: + torch.distributed.barrier() + + input_path = cfg.eval.test_path + dirnames = get_filenames(input_path) + dev_id = torch.cuda.current_device() + num_files = int(len(dirnames) / 8) + dirnames_per_gpu = dirnames[int(num_files * dev_id) : int(num_files * (dev_id + 1))] + + domino = dominoInference(cfg, dist, False) + domino.initialize_model( + model_path="/lustre/models/DoMINO.0.7.pt" + ) ## Replace the model path with location of the trained model + + for count, dirname in enumerate(dirnames_per_gpu): + # print(f"Processing file {dirname}") + filepath = os.path.join(input_path, dirname) + + STREAM_VELOCITY = 30.0 + AIR_DENSITY = 1.205 + + # Neighborhood points sampled for evaluation, tradeoff between accuracy and speed + STENCIL_SIZE = ( + 7 # Higher stencil size -> more accuracy but more evaluation time + ) + + domino.set_stl_path(filepath) + domino.set_stream_velocity(STREAM_VELOCITY) + domino.set_stencil_size(STENCIL_SIZE) + + domino.read_stl() + + domino.initialize_data_processor() + + # Calculate geometry encoding + domino.compute_geo_encoding() + + # Calculate volume solutions + domino.compute_volume_solutions( + num_sample_points=10_256_000, plot_solutions=False + ) + + # Calculate surface solutions + domino.compute_surface_solutions() + domino.compute_forces() + out_dict = domino.get_out_dict() + + print( + "Dirname:", + dirname, + "Drag:", + out_dict["drag_force"], + "Lift:", + out_dict["lift_force"], + ) + vtp_path = f"/lustre/snidhan/physicsnemo-work/domino-global-param-runs/stl-results/pred_{dirname}_4.vtp" + domino.mesh_stl.save(vtp_path) + reader = vtk.vtkXMLPolyDataReader() + reader.SetFileName(f"{vtp_path}") + reader.Update() + polydata_surf = reader.GetOutput() + + surfParam_vtk = numpy_support.numpy_to_vtk( + out_dict["pressure_surface"][0].cpu().numpy() + ) + surfParam_vtk.SetName(f"Pressure") + polydata_surf.GetCellData().AddArray(surfParam_vtk) + + surfParam_vtk = numpy_support.numpy_to_vtk( + out_dict["wall-shear-stress"][0].cpu().numpy() + ) + surfParam_vtk.SetName(f"Wall-shear-stress") + polydata_surf.GetCellData().AddArray(surfParam_vtk) + + write_to_vtp(polydata_surf, vtp_path) + exit() diff --git a/examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py b/examples/cfd/external_aerodynamics/domino/src/deprecated/openfoam_datapipe.py similarity index 100% rename from examples/cfd/external_aerodynamics/domino/src/openfoam_datapipe.py rename to examples/cfd/external_aerodynamics/domino/src/deprecated/openfoam_datapipe.py diff --git a/examples/cfd/external_aerodynamics/domino/src/retraining.py b/examples/cfd/external_aerodynamics/domino/src/deprecated/retraining.py similarity index 100% rename from examples/cfd/external_aerodynamics/domino/src/retraining.py rename to examples/cfd/external_aerodynamics/domino/src/deprecated/retraining.py diff --git a/examples/cfd/external_aerodynamics/domino/src/train_sharded.py b/examples/cfd/external_aerodynamics/domino/src/deprecated/train_sharded.py similarity index 99% rename from examples/cfd/external_aerodynamics/domino/src/train_sharded.py rename to examples/cfd/external_aerodynamics/domino/src/deprecated/train_sharded.py index f321f50b12..3b1c818cc2 100644 --- a/examples/cfd/external_aerodynamics/domino/src/train_sharded.py +++ b/examples/cfd/external_aerodynamics/domino/src/deprecated/train_sharded.py @@ -79,7 +79,7 @@ from physicsnemo.launch.utils import load_checkpoint, save_checkpoint from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper -from physicsnemo.datapipes.cae.domino_datapipe import ( +from physicsnemo.datapipes.cae.domino_datapipe2 import ( compute_scaling_factors, create_domino_dataset, ) diff --git a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py index a85cc7df86..7276807bfa 100644 --- a/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py +++ b/examples/cfd/external_aerodynamics/domino/src/inference_on_stl.py @@ -15,1572 +15,643 @@ # limitations under the License. """ -This code defines a standalone distributed inference pipeline the DoMINO model. -This inference pipeline can be used to evaluate the model given an STL and -an inflow speed. The pre-trained model checkpoint can be specified in this script -or inferred from the config file. The results are calculated on a point cloud -sampled in the volume around the STL and on the surface of the STL. They are stored -in a dictionary, which can be written out for visualization. +This code shows how to use a trained DoMINO model, with it's corresponding +preprocessing pipeline, to infer values on and around an STL mesh file. + +This script uses the meshes from the DrivaerML dataset, however, the logic +is largely the same. As an overview: +- Load the model +- Set up the preprocessor +- Loop over meshes +- In each mesh, sample random points on the surface, volume, or both +- Preprocess the points and run them through the model +- Process the STL mesh centers, too +- Collect the results and return +- Save the results to file. """ -import os import time +import os +import re +from typing import Literal, Any -import hydra, re -from hydra import compose, initialize +import apex +import numpy as np +import hydra from hydra.utils import to_absolute_path from omegaconf import DictConfig, OmegaConf - -import numpy as np import torch -from physicsnemo.models.domino.model import DoMINO -from physicsnemo.utils.domino.utils import ( - unnormalize, - create_directory, - nd_interpolator, - get_filenames, - write_to_vtp, -) -from torch.cuda.amp import autocast +# This will set up the cupy-ecosystem and pytorch to share memory pools +from physicsnemo.utils.memory import unified_gpu_memory + +import torchinfo +import torch.distributed as dist +from torch.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel -from physicsnemo.distributed import DistributedManager +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter +from nvtx import annotate as nvtx_annotate +import torch.cuda.nvtx as nvtx -from numpy.typing import NDArray -from typing import Any, Iterable, List, Literal, Mapping, Optional, Union, Callable -import warp as wp -from pathlib import Path -import pandas as pd -import matplotlib.pyplot as plt -import pyvista as pv - -try: - from physicsnemo.sym.geometry.tessellation import Tessellation - - SYM_AVAILABLE = True -except ImportError: - SYM_AVAILABLE = False - - -def combine_stls(stl_path, stl_files): - meshes = [] - combined_mesh = pv.PolyData() - for file in stl_files: - if ".stl" in file and "single_solid" not in file: - stl_file_path = os.path.join(stl_path, file) - reader = pv.get_reader(stl_file_path) - mesh_stl = reader.read() - combined_mesh = combined_mesh.merge(mesh_stl) - # meshes.append(mesh_stl) - break - # combined_mesh = pv.merge(meshes) - return combined_mesh - - -def plot(truth, prediction, var, save_path, axes_titles=None, plot_error=True): - if plot_error: - c = 3 - else: - c = 2 - fig, axes = plt.subplots(1, c, figsize=(15, 5)) - error = truth - prediction - # Plot Truth - im = axes[0].imshow( - truth, - cmap="jet", - vmax=np.ma.masked_invalid(truth).max(), - vmin=np.ma.masked_invalid(truth).min(), - ) - axes[0].axis("off") - cbar = fig.colorbar(im, ax=axes[0], orientation="vertical") - cbar.ax.tick_params(labelsize=12) - if axes_titles is None: - axes[0].set_title(f"{var} Truth") - else: - axes[0].set_title(axes_titles[0]) - - # Plot Predicted - im = axes[1].imshow( - prediction, - cmap="jet", - vmax=np.ma.masked_invalid(prediction).max(), - vmin=np.ma.masked_invalid(prediction).min(), - ) - axes[1].axis("off") - cbar = fig.colorbar(im, ax=axes[1], orientation="vertical") - cbar.ax.tick_params(labelsize=12) - if axes_titles is None: - axes[1].set_title(f"{var} Predicted") - else: - axes[1].set_title(axes_titles[1]) - - if plot_error: - # Plot Error - im = axes[2].imshow( - error, - cmap="jet", - vmax=np.ma.masked_invalid(error).max(), - vmin=np.ma.masked_invalid(error).min(), - ) - axes[2].axis("off") - cbar = fig.colorbar(im, ax=axes[2], orientation="vertical") - cbar.ax.tick_params(labelsize=12) - if axes_titles is None: - axes[2].set_title(f"{var} Error") - else: - axes[2].set_title(axes_titles[2]) +from physicsnemo.distributed import DistributedManager +from physicsnemo.launch.utils import load_checkpoint, save_checkpoint +from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper - MAE = np.mean(np.ma.masked_invalid((error))) +from physicsnemo.datapipes.cae.domino_datapipe import ( + DoMINODataPipe, + create_domino_dataset, +) - if MAE: - fig.suptitle(f"MAE {MAE}", fontsize=18, x=0.5) - plt.tight_layout() +from physicsnemo.models.domino.model import DoMINO +from physicsnemo.utils.domino.utils import sample_points_on_mesh - path_to_save_path = os.path.join(save_path) - plt.savefig(path_to_save_path, bbox_inches="tight", pad_inches=0.1) - plt.close() +from utils import ScalingFactors, get_keys_to_read, coordinate_distributed_environment +# This is included for GPU memory tracking: +from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo +import time -@wp.kernel -def _bvh_query_distance( - mesh: wp.uint64, - points: wp.array(dtype=wp.vec3f), - max_dist: wp.float32, - sdf: wp.array(dtype=wp.float32), - sdf_hit_point: wp.array(dtype=wp.vec3f), - sdf_hit_point_id: wp.array(dtype=wp.int32), -): - """ - Computes the signed distance from each point in the given array `points` - to the mesh represented by `mesh`,within the maximum distance `max_dist`, - and stores the result in the array `sdf`. - - Parameters: - mesh (wp.uint64): The identifier of the mesh. - points (wp.array): An array of 3D points for which to compute the - signed distance. - max_dist (wp.float32): The maximum distance within which to search - for the closest point on the mesh. - sdf (wp.array): An array to store the computed signed distances. - sdf_hit_point (wp.array): An array to store the computed hit points. - sdf_hit_point_id (wp.array): An array to store the computed hit point ids. - - Returns: - None - """ - tid = wp.tid() - res = wp.mesh_query_point_sign_winding_number(mesh, points[tid], max_dist) +# Initialize NVML +nvmlInit() - mesh_ = wp.mesh_get(mesh) - p0 = mesh_.points[mesh_.indices[3 * res.face + 0]] - p1 = mesh_.points[mesh_.indices[3 * res.face + 1]] - p2 = mesh_.points[mesh_.indices[3 * res.face + 2]] +from physicsnemo.utils.profiling import profile, Profiler - p_closest = res.u * p0 + res.v * p1 + (1.0 - res.u - res.v) * p2 - sdf[tid] = res.sign * wp.abs(wp.length(points[tid] - p_closest)) - sdf_hit_point[tid] = p_closest - sdf_hit_point_id[tid] = res.face +from loss import compute_loss_dict +from utils import get_num_vars -def signed_distance_field( - mesh_vertices: list[tuple[float, float, float]], - mesh_indices: NDArray[float], - input_points: list[tuple[float, float, float]], - max_dist: float = 1e8, - include_hit_points: bool = False, - include_hit_points_id: bool = False, - device: int = 0, -) -> wp.array: +def reject_interior_volume_points( + preprocessed_data: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: """ - Computes the signed distance field (SDF) for a given mesh and input points. - - Parameters: - ---------- - mesh_vertices (list[tuple[float, float, float]]): List of vertices defining the mesh. - mesh_indices (list[tuple[int, int, int]]): List of indices defining the triangles of the mesh. - input_points (list[tuple[float, float, float]]): List of input points for which to compute the SDF. - max_dist (float, optional): Maximum distance within which to search for - the closest point on the mesh. Default is 1e8. - include_hit_points (bool, optional): Whether to include hit points in - the output. Default is False. - include_hit_points_id (bool, optional): Whether to include hit point - IDs in the output. Default is False. - - Returns: - ------- - wp.array: An array containing the computed signed distance field. - - Example: - ------- - >>> mesh_vertices = [(0, 0, 0), (1, 0, 0), (0, 1, 0)] - >>> mesh_indices = np.array((0, 1, 2)) - >>> input_points = [(0.5, 0.5, 0.5)] - >>> signed_distance_field(mesh_vertices, mesh_indices, input_points).numpy() - Module ... - array([0.5], dtype=float32) + Reject volume points that are inside the STL mesh. """ - - wp.init() - # mesh = wp.Mesh( - # wp.array(mesh_vertices.cpu(), dtype=wp.vec3), wp.array(mesh_indices.cpu(), dtype=wp.int32) - # ) - mesh = wp.Mesh( - wp.from_torch(mesh_vertices, dtype=wp.vec3), - wp.from_torch(mesh_indices, dtype=wp.int32), - ) - - sdf_points = wp.from_torch(input_points, dtype=wp.vec3) - sdf = wp.zeros(shape=sdf_points.shape, dtype=wp.float32) - sdf_hit_point = wp.zeros(shape=sdf_points.shape, dtype=wp.vec3f) - sdf_hit_point_id = wp.zeros(shape=sdf_points.shape, dtype=wp.int32) - wp.launch( - kernel=_bvh_query_distance, - dim=len(sdf_points), - inputs=[mesh.id, sdf_points, max_dist, sdf, sdf_hit_point, sdf_hit_point_id], + ###################################################### + # Use the sign of the volume SDF to filter out points + # That are inside the STL mesh + ###################################################### + sdf_nodes = preprocessed_data["sdf_nodes"] + # The sfd_nodes tensor typically has shape (n_vol_points, 1) + valid_volume_idx = sdf_nodes > 0 + # So remove it if it's there: + valid_volume_idx = valid_volume_idx.squeeze(-1) + # Apply this selection to all the volume points: + for key in [ + "volume_mesh_centers", + "sdf_nodes", + "pos_volume_closest", + "pos_volume_center_of_mass", + ]: + preprocessed_data[key] = preprocessed_data[key][valid_volume_idx] + + return preprocessed_data + + +def sample_volume_points( + c_min: torch.Tensor, + c_max: torch.Tensor, + n_points: int, + device: torch.device, + eps: float = 1e-7, +) -> torch.Tensor: + """ + Generate a set of random points interior to the specified bounding box. + + Args: + c_min: The minimum coordinate of the bounding box. + c_max: The maximum coordinate of the bounding box. + n_points: The number of points to sample. + device: The device to sample the points on. + eps: The small edge factor to shift away from the lower bound. + """ + # We use a small edge factor to shift away from the lower bound, + # which can, in some cases, be exactly on the border. + uniform_points = ( + torch.rand(n_points, 3, device=device, dtype=torch.float32) * (1 - 2 * eps) + + eps ) - if include_hit_points and include_hit_points_id: - return ( - wp.to_torch(sdf), - wp.to_torch(sdf_hit_point), - wp.to_torch(sdf_hit_point_id), - ) - elif include_hit_points: - return (wp.to_torch(sdf), wp.to_torch(sdf_hit_point)) - elif include_hit_points_id: - return (wp.to_torch(sdf), wp.to_torch(sdf_hit_point_id)) - else: - return wp.to_torch(sdf) - - -def shuffle_array_torch(surface_vertices, geometry_points, device): - idx = torch.unsqueeze( - torch.randperm(surface_vertices.shape[0])[:geometry_points], -1 - ).to(device) - idx = idx.repeat(1, 3) - surface_sampled = torch.gather(surface_vertices, 0, idx) - return surface_sampled - - -class inferenceDataPipe: - def __init__( - self, - device: int = 0, - grid_resolution: Optional[list] = [256, 96, 64], - normalize_coordinates: bool = False, - geom_points_sample: int = 300000, - positional_encoding: bool = False, - surface_vertices=None, - surface_indices=None, - surface_areas=None, - surface_centers=None, - use_sdf_basis=False, - ): - self.surface_vertices = surface_vertices - self.surface_indices = surface_indices - self.surface_areas = surface_areas - self.surface_centers = surface_centers - self.device = device - self.grid_resolution = grid_resolution - self.normalize_coordinates = normalize_coordinates - self.geom_points_sample = geom_points_sample - self.positional_encoding = positional_encoding - self.use_sdf_basis = use_sdf_basis - torch.manual_seed(int(42 + torch.cuda.current_device())) - self.data_dict = {} - - def clear_dict(self): - del self.data_dict - - def clear_volume_dict(self): - del self.data_dict["volume_mesh_centers"] - del self.data_dict["pos_enc_closest"] - del self.data_dict["pos_normals_com"] - del self.data_dict["sdf_nodes"] - - def create_grid_torch(self, mx, mn, nres): - start_time = time.time() - dx = torch.linspace(mn[0], mx[0], nres[0], device=self.device) - dy = torch.linspace(mn[1], mx[1], nres[1], device=self.device) - dz = torch.linspace(mn[2], mx[2], nres[2], device=self.device) - - xv, yv, zv = torch.meshgrid(dx, dy, dz, indexing="ij") - xv = torch.unsqueeze(xv, -1) - yv = torch.unsqueeze(yv, -1) - zv = torch.unsqueeze(zv, -1) - grid = torch.cat((xv, yv, zv), axis=-1) - return grid - - def process_surface_mesh(self, bounding_box=None, bounding_box_surface=None): - # Use coarse mesh to calculate SDF - surface_vertices = self.surface_vertices - surface_indices = self.surface_indices - surface_areas = self.surface_areas - surface_centers = self.surface_centers - - start_time = time.time() - - if bounding_box is None: - # Create a bounding box - s_max = torch.amax(surface_vertices, 0) - s_min = torch.amin(surface_vertices, 0) - - c_max = s_max + (s_max - s_min) / 2 - c_min = s_min - (s_max - s_min) / 2 - c_min[2] = s_min[2] - else: - c_min = bounding_box[0] - c_max = bounding_box[1] - - if bounding_box_surface is None: - # Create a bounding box - s_max = torch.amax(surface_vertices, 0) - s_min = torch.amin(surface_vertices, 0) - - surf_max = s_max + (s_max - s_min) / 2 - surf_min = s_min - (s_max - s_min) / 2 - surf_min[2] = s_min[2] - else: - surf_min = bounding_box_surface[0] - surf_max = bounding_box_surface[1] - - nx, ny, nz = self.grid_resolution - - grid = self.create_grid_torch(c_max, c_min, self.grid_resolution) - grid_reshaped = torch.reshape(grid, (nx * ny * nz, 3)) - - # SDF on grid - sdf_grid = signed_distance_field( - surface_vertices, surface_indices, grid_reshaped, device=self.device - ) - sdf_grid = torch.reshape(sdf_grid, (nx, ny, nz)) - - surface_areas = torch.unsqueeze(surface_areas, -1) - center_of_mass = torch.sum(surface_centers * surface_areas, 0) / torch.sum( - surface_areas - ) - - s_grid = self.create_grid_torch(surf_max, surf_min, self.grid_resolution) - surf_grid_reshaped = torch.reshape(s_grid, (nx * ny * nz, 3)) - - surf_sdf_grid = signed_distance_field( - surface_vertices, surface_indices, surf_grid_reshaped, device=self.device - ) - surf_sdf_grid = torch.reshape(surf_sdf_grid, (nx, ny, nz)) - - if self.normalize_coordinates: - grid = 2.0 * (grid - c_min) / (c_max - c_min) - 1.0 - s_grid = 2.0 * (s_grid - surf_min) / (surf_max - surf_min) - 1.0 - - surface_vertices = torch.unsqueeze(surface_vertices, 0) - grid = torch.unsqueeze(grid, 0) - s_grid = torch.unsqueeze(s_grid, 0) - sdf_grid = torch.unsqueeze(sdf_grid, 0) - surf_sdf_grid = torch.unsqueeze(surf_sdf_grid, 0) - max_min = [c_min, c_max] - surf_max_min = [surf_min, surf_max] - center_of_mass = center_of_mass - - return ( - surface_vertices, - grid, - sdf_grid, - max_min, - s_grid, - surf_sdf_grid, - surf_max_min, - center_of_mass, - ) - - def sample_stl_points( - self, - num_points, - stl_centers, - stl_area, - stl_normals, - max_min, - center_of_mass, - bounding_box=None, - stencil_size=7, - ): - if bounding_box is not None: - c_max = bounding_box[1] - c_min = bounding_box[0] - else: - c_min = max_min[0] - c_max = max_min[1] - - start_time = time.time() - - nx, ny, nz = self.grid_resolution - - idx = np.arange(stl_centers.shape[0]) - # np.random.shuffle(idx) - if num_points is not None: - idx = idx[:num_points] - - surface_coordinates = stl_centers - surface_normals = stl_normals - surface_area = stl_area - - if stencil_size > 1: - interp_func = KDTree(surface_coordinates) - dd, ii = interp_func.query(surface_coordinates, k=stencil_size) - surface_neighbors = surface_coordinates[ii] - surface_neighbors = surface_neighbors[:, 1:] + 1e-6 - surface_neighbors_normals = surface_normals[ii] - surface_neighbors_normals = surface_neighbors_normals[:, 1:] - surface_neighbors_area = surface_area[ii] - surface_neighbors_area = surface_neighbors_area[:, 1:] - else: - surface_neighbors = np.expand_dims(surface_coordinates, 1) + 1e-6 - surface_neighbors_normals = np.expand_dims(surface_normals, 1) - surface_neighbors_area = np.expand_dims(surface_area, 1) - - surface_coordinates = torch.from_numpy(surface_coordinates).to(self.device) - surface_normals = torch.from_numpy(surface_normals).to(self.device) - surface_area = torch.from_numpy(surface_area).to(self.device) - surface_neighbors = torch.from_numpy(surface_neighbors).to(self.device) - surface_neighbors_normals = torch.from_numpy(surface_neighbors_normals).to( - self.device - ) - surface_neighbors_area = torch.from_numpy(surface_neighbors_area).to( - self.device - ) - - pos_normals_com = surface_coordinates - center_of_mass - - if self.normalize_coordinates: - surface_coordinates = ( - 2.0 * (surface_coordinates - c_min) / (c_max - c_min) - 1.0 - ) - surface_neighbors = ( - 2.0 * (surface_neighbors - c_min) / (c_max - c_min) - 1.0 - ) - - surface_coordinates = surface_coordinates[idx] - surface_area = surface_area[idx] - surface_normals = surface_normals[idx] - pos_normals_com = pos_normals_com[idx] - surface_coordinates = torch.unsqueeze(surface_coordinates, 0) - surface_normals = torch.unsqueeze(surface_normals, 0) - surface_area = torch.unsqueeze(surface_area, 0) - pos_normals_com = torch.unsqueeze(pos_normals_com, 0) - - surface_neighbors = surface_neighbors[idx] - surface_neighbors_normals = surface_neighbors_normals[idx] - surface_neighbors_area = surface_neighbors_area[idx] - surface_neighbors = torch.unsqueeze(surface_neighbors, 0) - surface_neighbors_normals = torch.unsqueeze(surface_neighbors_normals, 0) - surface_neighbors_area = torch.unsqueeze(surface_neighbors_area, 0) - - scaling_factors = [c_max, c_min] - - return ( - surface_coordinates, - surface_neighbors, - surface_normals, - surface_neighbors_normals, - surface_area, - surface_neighbors_area, - pos_normals_com, - scaling_factors, - idx, - ) - - def sample_points_on_surface( - self, - num_points_surf, - max_min, - center_of_mass, - stl_path, - bounding_box=None, - stencil_size=7, - ): - if bounding_box is not None: - c_max = bounding_box[1] - c_min = bounding_box[0] - else: - c_min = max_min[0] - c_max = max_min[1] - - start_time = time.time() - - nx, ny, nz = self.grid_resolution - - obj = Tessellation.from_stl(stl_path, airtight=False) - - boundary = obj.sample_boundary(num_points_surf) - surface_coordinates = np.concatenate( - [ - np.float32(boundary["x"]), - np.float32(boundary["y"]), - np.float32(boundary["z"]), - ], - axis=1, - ) - surface_normals = np.concatenate( - [ - np.float32(boundary["normal_x"]), - np.float32(boundary["normal_y"]), - np.float32(boundary["normal_z"]), - ], - axis=1, - ) - - surface_area = np.float32(boundary["area"]) - - interp_func = KDTree(surface_coordinates) - dd, ii = interp_func.query(surface_coordinates, k=stencil_size) - surface_neighbors = surface_coordinates[ii] - surface_neighbors = surface_neighbors[:, 1:] - surface_neighbors_normals = surface_normals[ii] - surface_neighbors_normals = surface_neighbors_normals[:, 1:] - surface_neighbors_area = surface_area[ii] - surface_neighbors_area = surface_neighbors_area[:, 1:] - - surface_coordinates = torch.from_numpy(surface_coordinates).to(self.device) - surface_normals = torch.from_numpy(surface_normals).to(self.device) - surface_area = torch.from_numpy(surface_area).to(self.device) - surface_neighbors = torch.from_numpy(surface_neighbors).to(self.device) - surface_neighbors_normals = torch.from_numpy(surface_neighbors_normals).to( - self.device - ) - surface_neighbors_area = torch.from_numpy(surface_neighbors_area).to( - self.device - ) - - pos_normals_com = surface_coordinates - center_of_mass - - if self.normalize_coordinates: - surface_coordinates = ( - 2.0 * (surface_coordinates - c_min) / (c_max - c_min) - 1.0 - ) - - surface_coordinates = torch.unsqueeze(surface_coordinates, 0) - surface_normals = torch.unsqueeze(surface_normals, 0) - surface_area = torch.unsqueeze(surface_area, 0) - pos_normals_com = torch.unsqueeze(pos_normals_com, 0) - - surface_neighbors = torch.unsqueeze(surface_neighbors, 0) - surface_neighbors_normals = torch.unsqueeze(surface_neighbors_normals, 0) - surface_neighbors_area = torch.unsqueeze(surface_neighbors_area, 0) - - scaling_factors = [c_max, c_min] - - return ( - surface_coordinates, - surface_neighbors, - surface_normals, - surface_neighbors_normals, - surface_area, - surface_neighbors_area, - pos_normals_com, - scaling_factors, - ) - - def sample_points_in_volume( - self, num_points_vol, max_min, center_of_mass, bounding_box=None - ): - if bounding_box is not None: - c_max = bounding_box[1] - c_min = bounding_box[0] - else: - c_min = max_min[0] - c_max = max_min[1] - - start_time = time.time() - - nx, ny, nz = self.grid_resolution - for k in range(10): - if k > 0: - num_pts_vol = num_points_vol - int(volume_coordinates.shape[0] / 2) - else: - num_pts_vol = int(1.25 * num_points_vol) - - volume_coordinates_sub = (c_max - c_min) * torch.rand( - num_pts_vol, 3, device=self.device, dtype=torch.float32 - ) + c_min - - sdf_nodes, sdf_node_closest_point = signed_distance_field( - self.surface_vertices, - self.surface_indices, - volume_coordinates_sub, - include_hit_points=True, - device=self.device, - ) - sdf_nodes = torch.unsqueeze(sdf_nodes, -1) - - idx = torch.unsqueeze(torch.where((sdf_nodes > 0))[0], -1) - idx = idx.repeat(1, volume_coordinates_sub.shape[1]) - if k == 0: - volume_coordinates = torch.gather(volume_coordinates_sub, 0, idx) - else: - volume_coordinates_1 = torch.gather(volume_coordinates_sub, 0, idx) - volume_coordinates = torch.cat( - (volume_coordinates, volume_coordinates_1), axis=0 - ) - - if volume_coordinates.shape[0] > num_points_vol: - volume_coordinates = volume_coordinates[:num_points_vol] - break - - sdf_nodes, sdf_node_closest_point = signed_distance_field( - self.surface_vertices, - self.surface_indices, - volume_coordinates, - include_hit_points=True, - device=self.device, - ) - sdf_nodes = torch.unsqueeze(sdf_nodes, -1) - - pos_normals_closest = volume_coordinates - sdf_node_closest_point - pos_normals_com = volume_coordinates - center_of_mass - - if self.normalize_coordinates: - volume_coordinates = ( - 2.0 * (volume_coordinates - c_min) / (c_max - c_min) - 1.0 - ) - - volume_coordinates = torch.unsqueeze(volume_coordinates, 0) - pos_normals_com = torch.unsqueeze(pos_normals_com, 0) - - if self.use_sdf_basis: - pos_normals_closest = torch.unsqueeze(pos_normals_closest, 0) - sdf_nodes = torch.unsqueeze(sdf_nodes, 0) - - scaling_factors = [c_max, c_min] - return ( - volume_coordinates, - pos_normals_com, - pos_normals_closest, - sdf_nodes, - scaling_factors, - ) - - -class dominoInference: - def __init__( - self, - cfg: DictConfig, - dist: None, - cached_geo_encoding: bool = False, - ): - self.cfg = cfg - self.dist = dist - self.stream_velocity = None - self.stencil_size = None - self.stl_path = None - self.stl_vertices = None - self.stl_centers = None - self.surface_areas = None - self.mesh_indices_flattened = None - self.length_scale = 1.0 - if self.dist is None: - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - else: - self.device = self.dist.device - - self.air_density = torch.full((1, 1), 1.205, dtype=torch.float32).to( - self.device - ) - ( - self.num_vol_vars, - self.num_surf_vars, - self.num_global_features, - ) = self.get_num_variables() - self.model = None - self.grid_resolution = torch.tensor(self.cfg.model.interp_res).to(self.device) - self.vol_factors = None - self.bounding_box_min_max = None - self.bounding_box_surface_min_max = None - self.center_of_mass = None - self.grid = None - self.geometry_encoding = None - self.geometry_encoding_surface = None - self.cached_geo_encoding = cached_geo_encoding - self.out_dict = {} - - def get_geometry_encoding(self): - return self.geometry_encoding - - def get_geometry_encoding_surface(self): - return self.geometry_encoding_surface - - def get_out_dict(self): - return self.out_dict - - def clear_out_dict(self): - self.out_dict.clear() - - def initialize_data_processor(self): - self.ifp = inferenceDataPipe( - device=self.device, - surface_vertices=self.stl_vertices, - surface_indices=self.mesh_indices_flattened, - surface_areas=self.surface_areas, - surface_centers=self.stl_centers, - grid_resolution=self.grid_resolution, - normalize_coordinates=True, - geom_points_sample=300000, - positional_encoding=False, - use_sdf_basis=self.cfg.model.use_sdf_in_basis_func, - ) - - def load_bounding_box(self): - if ( - self.cfg.data.bounding_box.min is not None - and self.cfg.data.bounding_box.max is not None - ): - c_min = torch.from_numpy( - np.array(self.cfg.data.bounding_box.min, dtype=np.float32) - ).to(self.device) - c_max = torch.from_numpy( - np.array(self.cfg.data.bounding_box.max, dtype=np.float32) - ).to(self.device) - self.bounding_box_min_max = [c_min, c_max] - - if ( - self.cfg.data.bounding_box_surface.min is not None - and self.cfg.data.bounding_box_surface.max is not None - ): - c_min = torch.from_numpy( - np.array(self.cfg.data.bounding_box_surface.min, dtype=np.float32) - ).to(self.device) - c_max = torch.from_numpy( - np.array(self.cfg.data.bounding_box_surface.max, dtype=np.float32) - ).to(self.device) - self.bounding_box_surface_min_max = [c_min, c_max] - - def load_volume_scaling_factors(self): - scaling_param_path = self.cfg.eval.scaling_param_path - vol_factors_path = os.path.join( - scaling_param_path, "volume_scaling_factors.npy" - ) + sampled_volume_points = (c_max - c_min) * uniform_points + c_min + return sampled_volume_points + + +def inference_on_single_stl( + stl_coordinates: torch.Tensor, + stl_faces: torch.Tensor, + global_params_values: torch.Tensor, + global_params_reference: torch.Tensor, + model: DoMINO, + datapipe: DoMINODataPipe, + batch_size: int, + total_points: int, + gpu_handle: int | None = None, + logger: PythonLogger | None = None, +): + """ + Perform model inference on a single STL mesh. - vol_factors = np.load(vol_factors_path, allow_pickle=True) - vol_factors = torch.from_numpy(vol_factors).to(self.device) + This function will take the input mesh + faces and + then sample the surface and volume to produce the model outputs + at `total_points` locations in batches of `batch_size`. - return vol_factors - def load_surface_scaling_factors(self): - scaling_param_path = self.cfg.eval.scaling_param_path - surf_factors_path = os.path.join( - scaling_param_path, "surface_scaling_factors.npy" - ) - surf_factors = np.load(surf_factors_path, allow_pickle=True) - surf_factors = torch.from_numpy(surf_factors).to(self.device) - - return surf_factors - - def read_stl(self): - stl_files = get_filenames(self.stl_path) - mesh_stl = combine_stls(self.stl_path, stl_files) - if self.cfg.eval.refine_stl: - mesh_stl = mesh_stl.subdivide( - nsub=2, subfilter="linear" - ) # .smooth(n_iter=20) - stl_vertices = mesh_stl.points - length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) - stl_centers = mesh_stl.cell_centers().points - # Assuming triangular elements - stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[:, 1:] - mesh_indices_flattened = stl_faces.flatten() - - surface_areas = mesh_stl.compute_cell_sizes( - length=False, area=True, volume=False - ) - surface_areas = np.array(surface_areas.cell_data["Area"]) - - surface_normals = np.array(mesh_stl.cell_normals, dtype=np.float32) - - self.stl_vertices = torch.from_numpy(np.float32(stl_vertices)).to(self.device) - self.stl_centers = torch.from_numpy(np.float32(stl_centers)).to(self.device) - self.surface_areas = torch.from_numpy(np.float32(surface_areas)).to(self.device) - self.stl_normals = -1.0 * torch.from_numpy(np.float32(surface_normals)).to( - self.device - ) - self.mesh_indices_flattened = torch.from_numpy( - np.int32(mesh_indices_flattened) - ).to(self.device) - self.length_scale = length_scale - self.mesh_stl = mesh_stl - - def read_stl_trimesh( - self, stl_vertices, stl_faces, stl_centers, surface_normals, surface_areas - ): - mesh_indices_flattened = stl_faces.flatten() - length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) - self.stl_vertices = torch.from_numpy(stl_vertices).to(self.device) - self.stl_centers = torch.from_numpy(stl_centers).to(self.device) - self.stl_normals = -1.0 * torch.from_numpy(surface_normals).to(self.device) - self.surface_areas = torch.from_numpy(surface_areas).to(self.device) - self.mesh_indices_flattened = torch.from_numpy( - np.int32(mesh_indices_flattened) - ).to(self.device) - self.length_scale = length_scale - - def get_num_variables(self): - volume_variable_names = list(self.cfg.variables.volume.solution.keys()) - num_vol_vars = 0 - for j in volume_variable_names: - if self.cfg.variables.volume.solution[j] == "vector": - num_vol_vars += 3 - else: - num_vol_vars += 1 - - surface_variable_names = list(self.cfg.variables.surface.solution.keys()) - num_surf_vars = 0 - for j in surface_variable_names: - if self.cfg.variables.surface.solution[j] == "vector": - num_surf_vars += 3 - else: - num_surf_vars += 1 - - num_global_features = 0 - global_params_names = list(cfg.variables.global_parameters.keys()) - for param in global_params_names: - if cfg.variables.global_parameters[param].type == "vector": - num_global_features += len( - cfg.variables.global_parameters[param].reference + Args: + stl_coordinates: The coordinates of the STL mesh. + stl_faces: The faces of the STL mesh. + global_params_values: The values of the global parameters. + global_params_reference: The reference values of the global parameters. + model: The model to use for inference. + datapipe: The datapipe to use for preprocessing. + batch_size: The batch size to use for inference. + total_points: The total number of points to process. + gpu_handle: The GPU handle to use for inference. + logger: The logger to use for logging. + """ + device = stl_coordinates.device + batch_start_time = time.perf_counter() + ###################################################### + # The IO only reads in "stl_faces" and "stl_coordinates". + # "stl_areas" and "stl_centers" would be computed by + # pyvista on CPU - instead, we do it on the GPU + # right here. + ###################################################### + + # Center is a mean of the 3 vertices + triangle_vertices = stl_coordinates[stl_faces.reshape((-1, 3))] + stl_centers = triangle_vertices.mean(dim=-1) + ###################################################### + # Area we compute from the cross product of two sides: + ###################################################### + d1 = triangle_vertices[:, 1] - triangle_vertices[:, 0] + d2 = triangle_vertices[:, 2] - triangle_vertices[:, 0] + stl_mesh_normals = torch.linalg.cross(d1, d2, dim=1) + normals_norm = torch.linalg.norm(stl_mesh_normals, dim=1) + stl_mesh_normals = stl_mesh_normals / normals_norm.unsqueeze(1) + stl_areas = 0.5 * normals_norm + + ###################################################### + # For computing the points, we take those stl objects, + # sample in chunks of `batch_size` until we've + # accumulated `total_points` predictions. + ###################################################### + + batch_output_dict = {} + N = 2 + total_points_processed = 0 + + # Use these lists to build up the output tensors: + surface_results = [] + volume_results = [] + + while total_points_processed < total_points: + inner_loop_start_time = time.perf_counter() + + ###################################################### + # Create the dictionary as the preprocessing expects: + ###################################################### + inference_dict = { + "stl_coordinates": stl_coordinates, + "stl_faces": stl_faces, + "stl_centers": stl_centers, + "stl_areas": stl_areas, + "global_params_values": global_params_values, + "global_params_reference": global_params_reference, + } + + # If the surface data is part of the model, sample the surface: + + if datapipe.model_type == "surface" or datapipe.model_type == "combined": + ###################################################### + # This function will sample points on the STL surface + ###################################################### + sampled_points, sampled_faces, sampled_areas, sampled_normals = ( + sample_points_on_mesh( + stl_coordinates, + stl_faces, + batch_size, + mesh_normals=stl_mesh_normals, + mesh_areas=stl_areas, ) - elif cfg.variables.global_parameters[param].type == "scalar": - num_global_features += 1 - else: - raise ValueError(f"Unknown global parameter type") - - return num_vol_vars, num_surf_vars, num_global_features - - def initialize_model(self, model_path): - model = ( - DoMINO( - input_features=3, - output_features_vol=self.num_vol_vars, - output_features_surf=self.num_surf_vars, - global_features=self.num_global_features, - model_parameters=self.cfg.model, ) - .to(self.device) - .eval() - ) - model = torch.compile(model, disable=True) - - checkpoint_iter = torch.load( - to_absolute_path(model_path), map_location=self.dist.device - ) - model.load_state_dict(checkpoint_iter) - - if self.dist is not None: - if self.dist.world_size > 1: - model = DistributedDataParallel( - model, - device_ids=[self.dist.local_rank], - output_device=self.dist.device, - broadcast_buffers=self.dist.broadcast_buffers, - find_unused_parameters=self.dist.find_unused_parameters, - gradient_as_bucket_view=True, - static_graph=True, - ) - - self.model = model - self.vol_factors = self.load_volume_scaling_factors() - self.surf_factors = self.load_surface_scaling_factors() - self.load_bounding_box() - - def set_stream_velocity(self, stream_velocity): - self.stream_velocity = torch.full( - (1, 1), stream_velocity, dtype=torch.float32 - ).to(self.device) - - def set_stencil_size(self, stencil_size): - self.stencil_size = stencil_size - - def set_air_density(self, air_density): - self.air_density = torch.full((1, 1), air_density, dtype=torch.float32).to( - self.device - ) - - def set_stl_path(self, filename): - self.stl_path = filename - - @torch.no_grad() - def compute_geo_encoding(self, cached_geom_path=None): - start_time = time.time() - - if not self.cached_geo_encoding: - ( - surface_vertices, - grid, - sdf_grid, - max_min, - s_grid, - surf_sdf_grid, - surf_max_min, - center_of_mass, - ) = self.ifp.process_surface_mesh( - self.bounding_box_min_max, self.bounding_box_surface_min_max - ) - if self.bounding_box_min_max is None: - self.bounding_box_min_max = max_min - if self.bounding_box_surface_min_max is None: - self.bounding_box_surface_min_max = surf_max_min - self.center_of_mass = center_of_mass - self.grid = grid - self.s_grid = s_grid - self.sdf_grid = sdf_grid - self.surf_sdf_grid = surf_sdf_grid - self.out_dict["sdf"] = sdf_grid - - geo_encoding, geo_encoding_surface = self.calculate_geometry_encoding( - surface_vertices, grid, sdf_grid, s_grid, surf_sdf_grid, self.model - ) + inference_dict["surface_mesh_centers"] = sampled_points + inference_dict["surface_normals"] = sampled_normals + inference_dict["surface_areas"] = sampled_areas + inference_dict["surface_faces"] = sampled_faces + + # If the volume data is part of the model, sample the volume: + if datapipe.model_type == "volume" or datapipe.model_type == "combined": + ###################################################### + # Build up volume points too with uniform sampling + ###################################################### + c_min = datapipe.config.bounding_box_dims[1] + c_max = datapipe.config.bounding_box_dims[0] + inference_dict["volume_mesh_centers"] = sample_volume_points( + c_min, + c_max, + batch_size, + device, + ) + + ###################################################### + # Pre-process the data with the datapipe: + ###################################################### + preprocessed_data = datapipe.process_data(inference_dict) + + if datapipe.model_type == "volume" or datapipe.model_type == "combined": + preprocessed_data = reject_interior_volume_points(preprocessed_data) + + ###################################################### + # Add a batch dimension to the data_dict + # (normally this is added in __getitem__ of the datapipe) + ###################################################### + preprocessed_data = {k: v.unsqueeze(0) for k, v in preprocessed_data.items()} + + ###################################################### + # Forward pass through the model: + ###################################################### + with torch.no_grad(): + output_vol, output_surf = model(preprocessed_data) + + ###################################################### + # unnormalize the outputs with the datapipe + # Whatever settings are configured for normalizing the + # output fields - even though we don't have ground + # truth here - are reused to undo that for the predictions + ###################################################### + output_vol, output_surf = datapipe.unscale_model_outputs( + output_vol, output_surf + ) + + surface_results.append(output_surf) + volume_results.append(output_vol) + + total_points_processed += batch_size + + current_loop_time = time.perf_counter() + + logging_string = f"Device {device} processed {total_points_processed} points of {total_points}\n" + if gpu_handle is not None: + gpu_info = nvmlDeviceGetMemoryInfo(gpu_handle) + gpu_memory_used = gpu_info.used / (1024**3) + logging_string += f" GPU memory used: {gpu_memory_used:.3f} Gb\n" + + logging_string += f" Time taken since batch start: {current_loop_time - batch_start_time:.2f} seconds\n" + logging_string += f" iteration throughput: {batch_size / (current_loop_time - inner_loop_start_time):.1f} points per second\n" + logging_string += f" Batch mean throughput: {total_points_processed / (current_loop_time - batch_start_time):.1f} points per second.\n" + + if logger is not None: + logger.info(logging_string) else: - out_dict_cached = torch.load(cached_geom_path, map_location=self.device) - self.bounding_box_min_max = out_dict_cached["bounding_box_min_max"] - self.grid = out_dict_cached["grid"] - self.sdf_grid = out_dict_cached["sdf_grid"] - self.center_of_mass = out_dict_cached["com"] - geo_encoding = out_dict_cached["geo_encoding"] - geo_encoding_surface = out_dict_cached["geo_encoding_surface"] - self.out_dict["sdf"] = self.sdf_grid - torch.cuda.synchronize() - print("Time taken for geo encoding = %f" % (time.time() - start_time)) - - self.geometry_encoding = geo_encoding - self.geometry_encoding_surface = geo_encoding_surface - - def compute_forces(self): - pressure = self.out_dict["pressure_surface"] - wall_shear = self.out_dict["wall-shear-stress"] - # sampling_indices = self.out_dict["sampling_indices"] - - surface_normals = self.stl_normals[self.sampling_indices] - surface_areas = self.surface_areas[self.sampling_indices] - - drag_force = torch.sum( - pressure[0, :, 0] * surface_normals[:, 0] * surface_areas - - wall_shear[0, :, 0] * surface_areas - ) - lift_force = torch.sum( - pressure[0, :, 0] * surface_normals[:, 2] * surface_areas - - wall_shear[0, :, 2] * surface_areas - ) + print(logging_string) + + ###################################################### + # Here at the end, get the values for the stl centers + # by updating the previous inference dict + # Only do this if the surface is part of the computation + # Comments are shorter here - it's a condensed version + # of the above logic. + ###################################################### + if datapipe.model_type == "surface" or datapipe.model_type == "combined": + inference_dict = { + "stl_coordinates": stl_coordinates, + "stl_faces": stl_faces, + "stl_centers": stl_centers, + "stl_areas": stl_areas, + "global_params_values": global_params_values, + "global_params_reference": global_params_reference, + } + inference_dict["surface_mesh_centers"] = stl_centers + inference_dict["surface_normals"] = stl_mesh_normals + inference_dict["surface_areas"] = stl_areas + inference_dict["surface_faces"] = stl_faces + + if datapipe.model_type == "combined" or datapipe.model_type == "volume": + c_min = datapipe.config.bounding_box_dims[1] + c_max = datapipe.config.bounding_box_dims[0] + inference_dict["volume_mesh_centers"] = sample_volume_points( + c_min, + c_max, + stl_centers.shape[0], + device, + ) + + # Preprocess: + preprocessed_data = datapipe.process_data(inference_dict) + + # Pull out the invalid volume points again, if needed: + if datapipe.model_type == "combined" or datapipe.model_type == "volume": + preprocessed_data = reject_interior_volume_points(preprocessed_data) + + # Run the model forward: + with torch.no_grad(): + preprocessed_data = { + k: v.unsqueeze(0) for k, v in preprocessed_data.items() + } + _, output_surf = model(preprocessed_data) + + # Unnormalize the outputs: + _, stl_center_results = datapipe.unscale_model_outputs(None, output_surf) - self.out_dict["drag_force"] = drag_force - self.out_dict["lift_force"] = lift_force - - @torch.inference_mode() - def compute_surface_solutions(self, num_sample_points=None, plot_solutions=False): - total_time = 0.0 - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - geo_encoding = self.geometry_encoding_surface - j = 0 - - with autocast(enabled=True): - start_event.record() - ( - surface_mesh_centers, - surface_neighbors, - surface_normals, - surface_neighbors_normals, - surface_areas, - surface_neighbors_areas, - pos_normals_com, - surf_scaling_factors, - sampling_indices, - ) = self.ifp.sample_stl_points( - num_sample_points, - self.stl_centers.cpu().numpy(), - self.surface_areas.cpu().numpy(), - self.stl_normals.cpu().numpy(), - max_min=self.bounding_box_surface_min_max, - center_of_mass=self.center_of_mass, - stencil_size=self.stencil_size, - ) - end_event.record() - end_event.synchronize() - cur_time = start_event.elapsed_time(end_event) / 1000.0 - print(f"sample_points_in_surface time (s): {cur_time:.4f}") - # vol_coordinates_all.append(volume_mesh_centers) - surface_coordinates_all = surface_mesh_centers - - inner_time = time.time() - start_event.record() - if num_sample_points == None: - point_batch_size = 512_000 - num_points = surface_coordinates_all.shape[1] - subdomain_points = int(np.floor(num_points / point_batch_size)) - surface_solutions = torch.zeros(1, num_points, self.num_surf_vars).to( - self.device - ) - for p in range(subdomain_points + 1): - start_idx = p * point_batch_size - end_idx = (p + 1) * point_batch_size - surface_solutions_batch = self.compute_solution_on_surface( - geo_encoding, - surface_mesh_centers[:, start_idx:end_idx], - surface_neighbors[:, start_idx:end_idx], - surface_normals[:, start_idx:end_idx], - surface_neighbors_normals[:, start_idx:end_idx], - surface_areas[:, start_idx:end_idx], - surface_neighbors_areas[:, start_idx:end_idx], - pos_normals_com[:, start_idx:end_idx], - self.s_grid, - self.model, - inlet_velocity=self.stream_velocity, - air_density=self.air_density, - ) - surface_solutions[:, start_idx:end_idx] = surface_solutions_batch - else: - point_batch_size = 512_000 - num_points = num_sample_points - subdomain_points = int(np.floor(num_points / point_batch_size)) - surface_solutions = torch.zeros(1, num_points, self.num_surf_vars).to( - self.device - ) - for p in range(subdomain_points + 1): - start_idx = p * point_batch_size - end_idx = (p + 1) * point_batch_size - surface_solutions_batch = self.compute_solution_on_surface( - geo_encoding, - surface_mesh_centers[:, start_idx:end_idx], - surface_neighbors[:, start_idx:end_idx], - surface_normals[:, start_idx:end_idx], - surface_neighbors_normals[:, start_idx:end_idx], - surface_areas[:, start_idx:end_idx], - surface_neighbors_areas[:, start_idx:end_idx], - pos_normals_com[:, start_idx:end_idx], - self.s_grid, - self.model, - inlet_velocity=self.stream_velocity, - air_density=self.air_density, - ) - # print(torch.amax(surface_solutions_batch, (0, 1)), torch.amin(surface_solutions_batch, (0, 1))) - surface_solutions[:, start_idx:end_idx] = surface_solutions_batch - - # print(surface_solutions.shape) - end_event.record() - end_event.synchronize() - cur_time = start_event.elapsed_time(end_event) / 1000.0 - print(f"compute_solution time (s): {cur_time:.4f}") - total_time += float(time.time() - inner_time) - surface_solutions_all = surface_solutions - print( - "Time taken for compute solution on surface for=%f, %f" - % (time.time() - inner_time, torch.cuda.utilization(self.device)) - ) - cmax = surf_scaling_factors[0] - cmin = surf_scaling_factors[1] + else: + stl_center_results = None - surface_coordinates_all = torch.reshape( - surface_coordinates_all, (1, num_points, 3) - ) - surface_solutions_all = torch.reshape(surface_solutions_all, (1, num_points, 4)) + # Stack up the results into one big tensor for surface and volume: + if len(surface_results) > 0 and all([s is not None for s in surface_results]): + surface_results = torch.cat(surface_results, dim=1) + else: + surface_results = None + if len(volume_results) > 0 and all([v is not None for v in volume_results]): + volume_results = torch.cat(volume_results, dim=1) + else: + volume_results = None - if self.surf_factors is not None: - surface_solutions_all = unnormalize( - surface_solutions_all, self.surf_factors[0], self.surf_factors[1] - ) + return stl_center_results, surface_results, volume_results - self.out_dict["surface_coordinates"] = ( - 0.5 * (surface_coordinates_all + 1.0) * (cmax - cmin) + cmin - ) - self.out_dict["pressure_surface"] = ( - surface_solutions_all[:, :, :1] - * self.stream_velocity**2.0 - * self.air_density - ) - self.out_dict["wall-shear-stress"] = ( - surface_solutions_all[:, :, 1:4] - * self.stream_velocity**2.0 - * self.air_density - ) - self.sampling_indices = sampling_indices - - @torch.inference_mode() - def compute_volume_solutions(self, num_sample_points, plot_solutions=False): - total_time = 0.0 - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - geo_encoding = self.geometry_encoding - j = 0 - - # Compute volume - point_batch_size = 512_000 - num_points = num_sample_points - subdomain_points = int(np.floor(num_points / point_batch_size)) - volume_solutions = torch.zeros(1, num_points, self.num_vol_vars).to(self.device) - volume_coordinates = torch.zeros(1, num_points, 3).to(self.device) - - for p in range(subdomain_points + 1): - start_idx = p * point_batch_size - end_idx = (p + 1) * point_batch_size - if end_idx > num_points: - point_batch_size = num_points - start_idx - end_idx = num_points - - with autocast(enabled=True): - inner_time = time.time() - start_event.record() - ( - volume_mesh_centers, - pos_normals_com, - pos_normals_closest, - sdf_nodes, - scaling_factors, - ) = self.ifp.sample_points_in_volume( - num_points_vol=point_batch_size, - max_min=self.bounding_box_min_max, - center_of_mass=self.center_of_mass, - ) - end_event.record() - end_event.synchronize() - cur_time = start_event.elapsed_time(end_event) / 1000.0 - print(f"sample_points_in_volume time (s): {cur_time:.4f}") - - volume_coordinates[:, start_idx:end_idx] = volume_mesh_centers - - start_event.record() - - volume_solutions_batch = self.compute_solution_in_volume( - geo_encoding, - volume_mesh_centers, - sdf_nodes, - pos_normals_closest, - pos_normals_com, - self.grid, - self.model, - use_sdf_basis=self.cfg.model.use_sdf_in_basis_func, - inlet_velocity=self.stream_velocity, - air_density=self.air_density, - ) - volume_solutions[:, start_idx:end_idx] = volume_solutions_batch - end_event.record() - end_event.synchronize() - cur_time = start_event.elapsed_time(end_event) / 1000.0 - print(f"compute_solution time (s): {cur_time:.4f}") - total_time += float(time.time() - inner_time) - # volume_solutions_all = volume_solutions - print( - "Time taken for compute solution in volume for =%f" - % (time.time() - inner_time) - ) - # print("Points processed:", end_idx) - print("Total time measured = %f" % total_time) - print("Points processed:", end_idx) - cmax = scaling_factors[0] - cmin = scaling_factors[1] - volume_coordinates_all = volume_coordinates - volume_solutions_all = volume_solutions +def inference_epoch( + dataloader: DoMINODataPipe, + sampler: DistributedSampler, + model: DoMINO, + gpu_handle: int, + logger: PythonLogger, + batch_size: int = 24_000, + total_points: int = 1_024_000, +): + ###################################################### + # Inference can run in a distributed way by coordinating + # the indices for each rank, which the sampler does + ###################################################### + + batch_start_time = time.perf_counter() + + # N.B. - iterating over the dataset directly here. + # That's because we need to sample on the STL and volume and + # that means we'll preprocess after that. + for i_batch, sample_batched in enumerate(dataloader.dataset): + dataloading_time = time.perf_counter() - batch_start_time + + logger.info( + f"Batch {i_batch} data loading time: {dataloading_time:.3f} seconds" + ) + + procesing_time_start = time.perf_counter() + stl_center_results, surface_results, volume_results = inference_on_single_stl( + sample_batched["stl_coordinates"], + sample_batched["stl_faces"], + sample_batched["global_params_values"], + sample_batched["global_params_reference"], + model, + dataloader, + batch_size, + total_points, + gpu_handle, + logger, + ) + + ###################################################### + # Peel off pressure, velocity, nut, shear, etc. + # Also compute drag, lift forces. + ###################################################### + # TODO + # TODO + # TODO + # TODO + # TODO + # TODO + # TODO + + procesing_time_end = time.perf_counter() + logger.info( + f"Batch {i_batch} GPU processing time: {procesing_time_end - procesing_time_start:.3f} seconds" + ) + logger.info( + f"Batch {i_batch} stl points: {stl_center_results.shape[1] if stl_center_results is not None else 0}" + ) + + output_start_time = time.perf_counter() + ###################################################### + # Save the outputs to file: + ###################################################### + # TODO + # TODO + # TODO + # TODO + # TODO + # TODO + output_end_time = time.perf_counter() + logger.info( + f"Batch {i_batch} output time: {output_end_time - output_start_time:.3f} seconds" + ) + + batch_start_time = time.perf_counter() + + +@hydra.main(version_base="1.3", config_path="conf", config_name="config") +def main(cfg: DictConfig) -> None: + ###################################################### + # initialize distributed manager + ###################################################### + DistributedManager.initialize() + dist = DistributedManager() - cmax = scaling_factors[0] - cmin = scaling_factors[1] + # DoMINO supports domain parallel training and inference. This function helps coordinate + # how to set that up, if needed. + domain_mesh, data_mesh, placements = coordinate_distributed_environment(cfg) - volume_coordinates_all = torch.reshape( - volume_coordinates_all, (1, num_sample_points, 3) - ) - volume_solutions_all = torch.reshape( - volume_solutions_all, (1, num_sample_points, self.num_vol_vars) - ) + ###################################################### + # Initialize NVML + ###################################################### + nvmlInit() + gpu_handle = nvmlDeviceGetHandleByIndex(dist.device.index) - if self.vol_factors is not None: - volume_solutions_all = unnormalize( - volume_solutions_all, self.vol_factors[0], self.vol_factors[1] - ) + ###################################################### + # Initialize logger + ###################################################### - self.out_dict["coordinates"] = ( - 0.5 * (volume_coordinates_all + 1.0) * (cmax - cmin) + cmin - ) - self.out_dict["velocity"] = ( - volume_solutions_all[:, :, :3] * self.stream_velocity - ) - self.out_dict["pressure"] = ( - volume_solutions_all[:, :, 3:4] - * self.stream_velocity**2.0 - * self.air_density - ) - # self.out_dict["turbulent-kinetic-energy"] = ( - # volume_solutions_all[:, :, 4:5] - # * self.stream_velocity**2.0 - # * self.air_density - # ) - # self.out_dict["turbulent-viscosity"] = ( - # volume_solutions_all[:, :, 5:] * self.stream_velocity * self.length_scale - # ) - self.out_dict["bounding_box_dims"] = torch.vstack(self.bounding_box_min_max) - - if plot_solutions: - print("Plotting solutions") - plot_save_path = os.path.join(self.cfg.output, "plots/contours/") - create_directory(plot_save_path) - - p_grid = 0.5 * (self.grid + 1.0) * (cmax - cmin) + cmin - p_grid = p_grid.cpu().numpy() - sdf_grid = self.sdf_grid.cpu().numpy() - volume_coordinates_all = ( - 0.5 * (volume_coordinates_all + 1.0) * (cmax - cmin) + cmin - ) - volume_solutions_all[:, :, :3] = ( - volume_solutions_all[:, :, :3] * self.stream_velocity - ) - volume_solutions_all[:, :, 3:4] = ( - volume_solutions_all[:, :, 3:4] - * self.stream_velocity**2.0 - * self.air_density - ) - # volume_solutions_all[:, :, 4:5] = ( - # volume_solutions_all[:, :, 4:5] - # * self.stream_velocity**2.0 - # * self.air_density - # ) - # volume_solutions_all[:, :, 5] = ( - # volume_solutions_all[:, :, 5] * self.stream_velocity * self.length_scale - # ) - volume_coordinates_all = volume_coordinates_all.cpu().numpy() - volume_solutions_all = volume_solutions_all.cpu().numpy() - - # ND interpolation on a grid - prediction_grid = nd_interpolator( - volume_coordinates_all, volume_solutions_all[0], p_grid[0] - ) - nx, ny, nz, vars = prediction_grid.shape - idx = np.where(sdf_grid[0] < 0.0) - prediction_grid[idx] = float("inf") - axes_titles = ["y/4 plane", "y/2 plane"] - - plot( - prediction_grid[:, int(ny / 4), :, 0], - prediction_grid[:, int(ny / 2), :, 0], - var="x-vel", - save_path=plot_save_path + f"x-vel-midplane_{self.stream_velocity}.png", - axes_titles=axes_titles, - plot_error=False, - ) - plot( - prediction_grid[:, int(ny / 4), :, 1], - prediction_grid[:, int(ny / 2), :, 1], - var="y-vel", - save_path=plot_save_path + f"y-vel-midplane_{self.stream_velocity}.png", - axes_titles=axes_titles, - plot_error=False, - ) - plot( - prediction_grid[:, int(ny / 4), :, 2], - prediction_grid[:, int(ny / 2), :, 2], - var="z-vel", - save_path=plot_save_path + f"z-vel-midplane_{self.stream_velocity}.png", - axes_titles=axes_titles, - plot_error=False, - ) - plot( - prediction_grid[:, int(ny / 4), :, 3], - prediction_grid[:, int(ny / 2), :, 3], - var="pres", - save_path=plot_save_path + f"pres-midplane_{self.stream_velocity}.png", - axes_titles=axes_titles, - plot_error=False, - ) - # plot( - # prediction_grid[:, int(ny / 4), :, 4], - # prediction_grid[:, int(ny / 2), :, 4], - # var="tke", - # save_path=plot_save_path + f"tke-midplane_{self.stream_velocity}.png", - # axes_titles=axes_titles, - # plot_error=False, - # ) - # plot( - # prediction_grid[:, int(ny / 4), :, 5], - # prediction_grid[:, int(ny / 2), :, 5], - # var="nut", - # save_path=plot_save_path + f"nut-midplane_{self.stream_velocity}.png", - # axes_titles=axes_titles, - # plot_error=False, - # ) - - def cold_start(self, cached_geom_path=None): - print("Cold start") - self.compute_geo_encoding(cached_geom_path) - self.compute_volume_solutions(num_sample_points=10) - self.clear_out_dict() - - @torch.no_grad() - def calculate_geometry_encoding( - self, geo_centers, p_grid, sdf_grid, s_grid, sdf_surf_grid, model - ): - vol_min = self.bounding_box_min_max[0] - vol_max = self.bounding_box_min_max[1] - surf_min = self.bounding_box_surface_min_max[0] - surf_max = self.bounding_box_surface_min_max[1] - - geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 - if self.dist.world_size == 1: - encoding_g_vol = model.geo_rep_volume(geo_centers_vol, p_grid, sdf_grid) - else: - encoding_g_vol = model.module.geo_rep_volume( - geo_centers_vol, p_grid, sdf_grid - ) + logger = PythonLogger("Inference") + logger = RankZeroLoggingWrapper(logger, dist) - geo_centers_surf = 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + logger.info(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") - if self.dist.world_size == 1: - encoding_g_surf = model.geo_rep_surface( - geo_centers_surf, s_grid, sdf_surf_grid - ) - else: - encoding_g_surf = model.module.geo_rep_surface( - geo_centers_surf, s_grid, sdf_surf_grid - ) + ###################################################### + # Get scaling factors + # Likely, you want to reuse the scaling factors from training. + ###################################################### + vol_factors, surf_factors = load_scaling_factors(cfg) - if self.dist.world_size == 1: - encoding_g_surf1 = model.geo_rep_surface1( - geo_centers_surf, s_grid, sdf_surf_grid - ) - else: - encoding_g_surf1 = model.module.geo_rep_surface1( - geo_centers_surf, s_grid, sdf_surf_grid - ) + ###################################################### + # Configure the model + ###################################################### + model_type = cfg.model.model_type + num_vol_vars, num_surf_vars, num_global_features = get_num_vars(cfg, model_type) - geo_encoding = 0.5 * encoding_g_surf1 + 0.5 * encoding_g_vol - geo_encoding_surface = 0.5 * encoding_g_surf - return geo_encoding, geo_encoding_surface - - @torch.no_grad() - def compute_solution_on_surface( - self, - geo_encoding, - surface_mesh_centers, - surface_mesh_neighbors, - surface_normals, - surface_neighbors_normals, - surface_areas, - surface_neighbors_areas, - pos_normals_com, - s_grid, - model, - inlet_velocity, - air_density, - ): - """ - Global parameters: For this particular case, the model was trained on single velocity/density values - across all simulations. Hence, global_params_values and global_params_reference are the same. - """ - global_params_values = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_values = torch.unsqueeze(global_params_values, -1) # (1, 2, 1) - - global_params_reference = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_reference = torch.unsqueeze( - global_params_reference, -1 - ) # (1, 2, 1) - - if self.dist.world_size == 1: - geo_encoding_local = model.geo_encoding_local( - geo_encoding, surface_mesh_centers, s_grid, mode="surface" - ) - else: - geo_encoding_local = model.module.geo_encoding_local( - geo_encoding, surface_mesh_centers, s_grid, mode="surface" - ) + if model_type == "combined" or model_type == "surface": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + else: + surface_variable_names = [] - pos_encoding = pos_normals_com - surface_areas = torch.unsqueeze(surface_areas, -1) - surface_neighbors_areas = torch.unsqueeze(surface_neighbors_areas, -1) - - if self.dist.world_size == 1: - pos_encoding = model.position_encoder(pos_encoding, eval_mode="surface") - tpredictions_batch = model.calculate_solution_with_neighbors( - surface_mesh_centers, - geo_encoding_local, - pos_encoding, - surface_mesh_neighbors, - surface_normals, - surface_neighbors_normals, - surface_areas, - surface_neighbors_areas, - global_params_values, - global_params_reference, - ) - else: - pos_encoding = model.module.position_encoder( - pos_encoding, eval_mode="surface" - ) - tpredictions_batch = model.module.calculate_solution_with_neighbors( - surface_mesh_centers, - geo_encoding_local, - pos_encoding, - surface_mesh_neighbors, - surface_normals, - surface_neighbors_normals, - surface_areas, - surface_neighbors_areas, - global_params_values, - global_params_reference, - ) + if model_type == "combined" or model_type == "volume": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + else: + volume_variable_names = [] + + ###################################################### + # Check that the sample size is equal. + # unequal samples could be done but they aren't, here.s + ###################################################### + if cfg.model.model_type == "combined": + if cfg.model.volume_points_sample != cfg.model.surface_points_sample: + raise ValueError( + "Volume and surface points sample must be equal for combined model" + ) + + # Get the number of sample points: + sample_points = ( + cfg.model.surface_points_sample + if cfg.model.model_type == "surface" + else cfg.model.volume_points_sample + ) - return tpredictions_batch - - @torch.no_grad() - def compute_solution_in_volume( - self, - geo_encoding, - volume_mesh_centers, - sdf_nodes, - pos_enc_closest, - pos_normals_com, - p_grid, - model, - use_sdf_basis, - inlet_velocity, - air_density, - ): - ## Global parameters - global_params_values = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_values = torch.unsqueeze(global_params_values, -1) # (1, 2, 1) - - global_params_reference = torch.cat( - (inlet_velocity, air_density), axis=1 - ) # (1, 2) - global_params_reference = torch.unsqueeze( - global_params_reference, -1 - ) # (1, 2, 1) - - if self.dist.world_size == 1: - geo_encoding_local = model.geo_encoding_local( - geo_encoding, volume_mesh_centers, p_grid, mode="volume" - ) - else: - geo_encoding_local = model.module.geo_encoding_local( - geo_encoding, volume_mesh_centers, p_grid, mode="volume" - ) - if use_sdf_basis: - pos_encoding = torch.cat( - (sdf_nodes, pos_enc_closest, pos_normals_com), axis=-1 - ) - else: - pos_encoding = pos_normals_com - - if self.dist.world_size == 1: - pos_encoding = model.position_encoder(pos_encoding, eval_mode="volume") - tpredictions_batch = model.calculate_solution( - volume_mesh_centers, - geo_encoding_local, - pos_encoding, - global_params_values, - global_params_reference, - num_sample_points=self.stencil_size, - eval_mode="volume", - ) - else: - pos_encoding = model.module.position_encoder( - pos_encoding, eval_mode="volume" - ) - tpredictions_batch = model.module.calculate_solution( - volume_mesh_centers, - geo_encoding_local, - pos_encoding, - global_params_values, - global_params_reference, - num_sample_points=self.stencil_size, - eval_mode="volume", - ) - return tpredictions_batch + ###################################################### + # If the batch size doesn't evenly divide + # the num points, that's ok. But print a warning + # that the total points will get tweaked. + ###################################################### + if cfg.eval.num_points % sample_points != 0: + logger.warning( + f"Batch size {sample_points} doesn't evenly divide num points {cfg.eval.num_points}." + ) + logger.warning( + f"Total points will be rounded up to {((cfg.eval.num_points // sample_points) + 1) * sample_points}." + ) + + ###################################################### + # Configure the dataset + # We are applying preprocessing in a separate step + # for this - so the dataset and datapipe are separate + ###################################################### + + # This helper function is to determine which keys to read from the data + # (and which to use default values for, if they aren't present - like + # air_density, for example) + keys_to_read, keys_to_read_if_available = get_keys_to_read( + cfg, model_type, get_ground_truth=True + ) + # Override the model type + # For the inference pipeline, we adjust the tooling a little for the data. + # We use only a bare STL dataset that will read the mesh coordinates + # and triangle definitions. We'll compute the centers and normals + # on the GPU (instead of on the CPU, as pyvista would do) and + # then we can sample from that mesh on the GPU. + # test_dataset = DrivaerMLDataset( + # data_dir=cfg.eval.test_path, + # keys_to_read=[ + # "stl_coordinates", + # "stl_faces", + # ], + # keys_to_read_if_available=keys_to_read_if_available, + # output_device=dist.device, + # ) + # Volumetric data will be generated on the fly on the GPU. + + ###################################################### + # Configure the datapipe + # We _won't_ iterate over the datapipe, however, we can use the + # datapipe processing tools on the sampled surface and + # volume points with the same preprocessing. + # It also is used to un-normalize the model outputs. + ###################################################### + overrides = {} + if hasattr(cfg.data, "gpu_preprocessing"): + overrides["gpu_preprocessing"] = cfg.data.gpu_preprocessing + + if hasattr(cfg.data, "gpu_output"): + overrides["gpu_output"] = cfg.data.gpu_output + + test_dataloader = create_domino_dataset( + cfg, + phase="test", + keys_to_read=["stl_coordinates", "stl_faces"], + keys_to_read_if_available=keys_to_read_if_available, + vol_factors=vol_factors, + surf_factors=surf_factors, + normalize_coordinates=cfg.data.normalize_coordinates, + sample_in_bbox=cfg.data.sample_in_bbox, + sampling=cfg.data.sampling, + device_mesh=domain_mesh, + placements=placements, + ) -if __name__ == "__main__": - OmegaConf.register_new_resolver("eval", eval) - with initialize(version_base="1.3", config_path="conf"): - cfg = compose(config_name="config") + ###################################################### + # The sampler is used in multi-gpu inference to + # coordinate the batches used for each rank. + ###################################################### + test_sampler = DistributedSampler( + test_dataloader, + num_replicas=data_mesh.size(), + rank=data_mesh.get_local_rank(), + **cfg.train.sampler, + ) - DistributedManager.initialize() - dist = DistributedManager() + ###################################################### + # Configure the model + # and move it to the device. + ###################################################### + model = DoMINO( + input_features=3, + output_features_vol=num_vol_vars, + output_features_surf=num_surf_vars, + global_features=num_global_features, + model_parameters=cfg.model, + ).to(dist.device) + + # Print model summary (structure and parmeter count). + logger.info(f"Model summary:\n{torchinfo.summary(model, verbose=0, depth=2)}\n") if dist.world_size > 1: torch.distributed.barrier() - input_path = cfg.eval.test_path - dirnames = get_filenames(input_path) - dev_id = torch.cuda.current_device() - num_files = int(len(dirnames) / 8) - dirnames_per_gpu = dirnames[int(num_files * dev_id) : int(num_files * (dev_id + 1))] - - domino = dominoInference(cfg, dist, False) - domino.initialize_model( - model_path="/lustre/models/DoMINO.0.7.pt" - ) ## Replace the model path with location of the trained model - - for count, dirname in enumerate(dirnames_per_gpu): - # print(f"Processing file {dirname}") - filepath = os.path.join(input_path, dirname) - - STREAM_VELOCITY = 30.0 - AIR_DENSITY = 1.205 - - # Neighborhood points sampled for evaluation, tradeoff between accuracy and speed - STENCIL_SIZE = ( - 7 # Higher stencil size -> more accuracy but more evaluation time - ) - - domino.set_stl_path(filepath) - domino.set_stream_velocity(STREAM_VELOCITY) - domino.set_stencil_size(STENCIL_SIZE) - - domino.read_stl() + load_checkpoint( + to_absolute_path(cfg.resume_dir), + models=model, + device=dist.device, + ) - domino.initialize_data_processor() + start_time = time.perf_counter() - # Calculate geometry encoding - domino.compute_geo_encoding() + # This controls what indices to use for each epoch. + test_sampler.set_epoch(0) - # Calculate volume solutions - domino.compute_volume_solutions( - num_sample_points=10_256_000, plot_solutions=False - ) + prof = Profiler() - # Calculate surface solutions - domino.compute_surface_solutions() - domino.compute_forces() - out_dict = domino.get_out_dict() - - print( - "Dirname:", - dirname, - "Drag:", - out_dict["drag_force"], - "Lift:", - out_dict["lift_force"], - ) - vtp_path = f"/lustre/snidhan/physicsnemo-work/domino-global-param-runs/stl-results/pred_{dirname}_4.vtp" - domino.mesh_stl.save(vtp_path) - reader = vtk.vtkXMLPolyDataReader() - reader.SetFileName(f"{vtp_path}") - reader.Update() - polydata_surf = reader.GetOutput() - - surfParam_vtk = numpy_support.numpy_to_vtk( - out_dict["pressure_surface"][0].cpu().numpy() + model.eval() + epoch_start_time = time.perf_counter() + with prof: + inference_epoch( + dataloader=test_dataloader, + sampler=test_sampler, + model=model, + logger=logger, + gpu_handle=gpu_handle, + batch_size=sample_points, + total_points=cfg.eval.num_points, ) - surfParam_vtk.SetName(f"Pressure") - polydata_surf.GetCellData().AddArray(surfParam_vtk) + epoch_end_time = time.perf_counter() + logger.info( + f"Device {dist.device}, Epoch took {epoch_end_time - epoch_start_time:.3f} seconds" + ) - surfParam_vtk = numpy_support.numpy_to_vtk( - out_dict["wall-shear-stress"][0].cpu().numpy() - ) - surfParam_vtk.SetName(f"Wall-shear-stress") - polydata_surf.GetCellData().AddArray(surfParam_vtk) - write_to_vtp(polydata_surf, vtp_path) - exit() +if __name__ == "__main__": + # Profiler().enable("torch") + # Profiler().initialize() + main() + # Profiler().finalize() diff --git a/examples/cfd/external_aerodynamics/domino/src/loss.py b/examples/cfd/external_aerodynamics/domino/src/loss.py new file mode 100644 index 0000000000..3ab52c7903 --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/loss.py @@ -0,0 +1,553 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Literal, Any + +from physicsnemo.utils.domino.utils import unnormalize + +from typing import Literal, Any + +import torch.cuda.nvtx as nvtx + +from physicsnemo.utils.domino.utils import * + + +def compute_physics_loss( + output: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + loss_type: Literal["mse", "rmse"], + dims: tuple[int, ...] | None, + first_deriv: torch.nn.Module, + eqn: Any, + bounding_box: torch.Tensor, + vol_factors: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute physics-based loss terms for Navier-Stokes equations. + + Args: + output: Model output containing (output, coords_neighbors, output_neighbors, neighbors_list) + target: Ground truth values + mask: Mask for valid values + loss_type: Type of loss to calculate ("mse" or "rmse") + dims: Dimensions for loss calculation + first_deriv: First derivative calculator + eqn: Equations + bounding_box: Bounding box for normalization + vol_factors: Volume factors for normalization + + Returns: + Tuple of (data_loss, continuity_loss, momentum_x_loss, momentum_y_loss, momentum_z_loss) + """ + # Physics loss enabled + output, coords_neighbors, output_neighbors, neighbors_list = output + batch_size = output.shape[1] + fields, num_neighbors = output_neighbors.shape[3], output_neighbors.shape[2] + coords_total = coords_neighbors[0, :] + output_total = output_neighbors[0, :] + output_total_unnormalized = unnormalize( + output_total, vol_factors[0], vol_factors[1] + ) + coords_total_unnormalized = unnormalize( + coords_total, bounding_box[0], bounding_box[1] + ) + + # compute first order gradients on all the nodes from the neighbors_list + grad_list = {} + for parent_id, neighbor_ids in neighbors_list.items(): + neighbor_ids_tensor = torch.tensor(neighbor_ids).to( + output_total_unnormalized.device + ) + du = ( + output_total_unnormalized[:, [parent_id]] + - output_total_unnormalized[:, neighbor_ids_tensor] + ) + dv = ( + coords_total_unnormalized[:, [parent_id]] + - coords_total_unnormalized[:, neighbor_ids_tensor] + ) + grads = first_deriv.forward( + coords=None, connectivity_tensor=None, y=None, du=du, dv=dv + ) + grad = torch.cat(grads, dim=1) + grad_list[parent_id] = grad + + # compute second order gradients on only the center node + neighbor_ids_tensor = torch.tensor(neighbors_list[0]).to( + output_total_unnormalized.device + ) + grad_neighbors_center = torch.stack([v for v in grad_list.values()], dim=1) + grad_neighbors_center = grad_neighbors_center.reshape( + batch_size, len(neighbors_list[0]) + 1, -1 + ) + + du = grad_neighbors_center[:, [0]] - grad_neighbors_center[:, neighbor_ids_tensor] + dv = ( + coords_total_unnormalized[:, [0]] + - coords_total_unnormalized[:, neighbor_ids_tensor] + ) + + # second order gradients + ggrads_center = first_deriv.forward( + coords=None, connectivity_tensor=None, y=None, du=du, dv=dv + ) + ggrad_center = torch.cat(ggrads_center, dim=1) + grad_neighbors_center = grad_neighbors_center.reshape( + batch_size, len(neighbors_list[0]) + 1, 3, -1 + ) + + # Get the outputs on the original nodes + fields_center_unnormalized = output_total_unnormalized[:, 0, :] + grad_center = grad_neighbors_center[:, 0, :, :] + grad_grad_uvw_center = ggrad_center[:, :, :9] + + nu = 1.507 * 1e-5 + + dict_mapping = { + "u": fields_center_unnormalized[:, [0]], + "v": fields_center_unnormalized[:, [1]], + "w": fields_center_unnormalized[:, [2]], + "p": fields_center_unnormalized[:, [3]], + "nu": nu + fields_center_unnormalized[:, [4]], + "u__x": grad_center[:, 0, [0]], + "u__y": grad_center[:, 1, [0]], + "u__z": grad_center[:, 2, [0]], + "v__x": grad_center[:, 0, [1]], + "v__y": grad_center[:, 1, [1]], + "v__z": grad_center[:, 2, [1]], + "w__x": grad_center[:, 0, [2]], + "w__y": grad_center[:, 1, [2]], + "w__z": grad_center[:, 2, [2]], + "p__x": grad_center[:, 0, [3]], + "p__y": grad_center[:, 1, [3]], + "p__z": grad_center[:, 2, [3]], + "nu__x": grad_center[:, 0, [4]], + "nu__y": grad_center[:, 1, [4]], + "nu__z": grad_center[:, 2, [4]], + "u__x__x": grad_grad_uvw_center[:, 0, [0]], + "u__x__y": grad_grad_uvw_center[:, 1, [0]], + "u__x__z": grad_grad_uvw_center[:, 2, [0]], + "u__y__x": grad_grad_uvw_center[:, 1, [0]], # same as __x__y + "u__y__y": grad_grad_uvw_center[:, 1, [1]], + "u__y__z": grad_grad_uvw_center[:, 2, [1]], + "u__z__x": grad_grad_uvw_center[:, 2, [0]], # same as __x__z + "u__z__y": grad_grad_uvw_center[:, 2, [1]], # same as __y__z + "u__z__z": grad_grad_uvw_center[:, 2, [2]], + "v__x__x": grad_grad_uvw_center[:, 0, [3]], + "v__x__y": grad_grad_uvw_center[:, 1, [3]], + "v__x__z": grad_grad_uvw_center[:, 2, [3]], + "v__y__x": grad_grad_uvw_center[:, 1, [3]], # same as __x__y + "v__y__y": grad_grad_uvw_center[:, 1, [4]], + "v__y__z": grad_grad_uvw_center[:, 2, [4]], + "v__z__x": grad_grad_uvw_center[:, 2, [3]], # same as __x__z + "v__z__y": grad_grad_uvw_center[:, 2, [4]], # same as __y__z + "v__z__z": grad_grad_uvw_center[:, 2, [5]], + "w__x__x": grad_grad_uvw_center[:, 0, [6]], + "w__x__y": grad_grad_uvw_center[:, 1, [6]], + "w__x__z": grad_grad_uvw_center[:, 2, [6]], + "w__y__x": grad_grad_uvw_center[:, 1, [6]], # same as __x__y + "w__y__y": grad_grad_uvw_center[:, 1, [7]], + "w__y__z": grad_grad_uvw_center[:, 2, [7]], + "w__z__x": grad_grad_uvw_center[:, 2, [6]], # same as __x__z + "w__z__y": grad_grad_uvw_center[:, 2, [7]], # same as __y__z + "w__z__z": grad_grad_uvw_center[:, 2, [8]], + } + continuity = eqn["continuity"].evaluate(dict_mapping)["continuity"] + momentum_x = eqn["momentum_x"].evaluate(dict_mapping)["momentum_x"] + momentum_y = eqn["momentum_y"].evaluate(dict_mapping)["momentum_y"] + momentum_z = eqn["momentum_z"].evaluate(dict_mapping)["momentum_z"] + + # Compute the weights for the equation residuals + weight_continuity = torch.sigmoid(0.5 * (torch.abs(continuity) - 10)) + weight_momentum_x = torch.sigmoid(0.5 * (torch.abs(momentum_x) - 10)) + weight_momentum_y = torch.sigmoid(0.5 * (torch.abs(momentum_y) - 10)) + weight_momentum_z = torch.sigmoid(0.5 * (torch.abs(momentum_z) - 10)) + + weighted_continuity = weight_continuity * torch.abs(continuity) + weighted_momentum_x = weight_momentum_x * torch.abs(momentum_x) + weighted_momentum_y = weight_momentum_y * torch.abs(momentum_y) + weighted_momentum_z = weight_momentum_z * torch.abs(momentum_z) + + # Compute data loss + num = torch.sum(mask * (output - target) ** 2.0, dims) + if loss_type == "rmse": + denom = torch.sum(mask * target**2.0, dims) + else: + denom = torch.sum(mask) + + del coords_total, output_total + torch.cuda.empty_cache() + + return ( + torch.mean(num / denom), + torch.mean(torch.abs(weighted_continuity)), + torch.mean(torch.abs(weighted_momentum_x)), + torch.mean(torch.abs(weighted_momentum_y)), + torch.mean(torch.abs(weighted_momentum_z)), + ) + + +def loss_fn( + output: torch.Tensor, + target: torch.Tensor, + loss_type: Literal["mse", "rmse"], + padded_value: float = -10, +) -> torch.Tensor: + """Calculate mean squared error or root mean squared error with masking for padded values. + + Args: + output: Predicted values from the model + target: Ground truth values + loss_type: Type of loss to calculate ("mse" or "rmse") + padded_value: Value used for padding in the tensor + + Returns: + Calculated loss as a scalar tensor + """ + mask = abs(target - padded_value) > 1e-3 + + if loss_type == "rmse": + dims = (0, 1) + else: + dims = None + + num = torch.sum(mask * (output - target) ** 2.0, dims) + if loss_type == "rmse": + denom = torch.sum(mask * (target - torch.mean(target, (0, 1))) ** 2.0, dims) + loss = torch.mean(num / denom) + elif loss_type == "mse": + denom = torch.sum(mask) + loss = torch.mean(num / denom) + else: + raise ValueError(f"Invalid loss type: {loss_type}") + return loss + + +def loss_fn_with_physics( + output: torch.Tensor, + target: torch.Tensor, + loss_type: Literal["mse", "rmse"], + padded_value: float = -10, + first_deriv: torch.nn.Module = None, + eqn: Any = None, + bounding_box: torch.Tensor = None, + vol_factors: torch.Tensor = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate loss with physics-based terms for appropriate equations. + + Args: + output: Predicted values from the model (with neighbor data when physics enabled) + target: Ground truth values + loss_type: Type of loss to calculate ("mse" or "rmse") + padded_value: Value used for padding in the tensor + first_deriv: First derivative calculator + eqn: Equations + bounding_box: Bounding box for normalization + vol_factors: Volume factors for normalization + + Returns: + Tuple of (data_loss, continuity_loss, momentum_x_loss, momentum_y_loss, momentum_z_loss) + """ + mask = abs(target - padded_value) > 1e-3 + + if loss_type == "rmse": + dims = (0, 1) + else: + dims = None + + # Call the physics loss computation function + return compute_physics_loss( + output=output, + target=target, + mask=mask, + loss_type=loss_type, + dims=dims, + first_deriv=first_deriv, + eqn=eqn, + bounding_box=bounding_box, + vol_factors=vol_factors, + ) + + +def loss_fn_surface( + output: torch.Tensor, target: torch.Tensor, loss_type: Literal["mse", "rmse"] +) -> torch.Tensor: + """Calculate loss for surface data by handling scalar and vector components separately. + + Args: + output: Predicted surface values from the model + target: Ground truth surface values + loss_type: Type of loss to calculate ("mse" or "rmse") + + Returns: + Combined scalar and vector loss as a scalar tensor + """ + # Separate the scalar and vector components: + output_scalar, output_vector = torch.split(output, [1, 3], dim=2) + target_scalar, target_vector = torch.split(target, [1, 3], dim=2) + + numerator = torch.mean((output_scalar - target_scalar) ** 2.0) + vector_diff_sq = torch.mean((target_vector - output_vector) ** 2.0, (0, 1)) + if loss_type == "mse": + masked_loss_pres = numerator + masked_loss_ws = torch.sum(vector_diff_sq) + else: + denom = torch.mean((target_scalar - torch.mean(target_scalar, (0, 1))) ** 2.0) + masked_loss_pres = numerator / denom + + # Compute the mean diff**2 of the vector component, leave the last dimension: + masked_loss_ws_num = vector_diff_sq + masked_loss_ws_denom = torch.mean( + (target_vector - torch.mean(target_vector, (0, 1))) ** 2.0, (0, 1) + ) + masked_loss_ws = torch.sum(masked_loss_ws_num / masked_loss_ws_denom) + + loss = masked_loss_pres + masked_loss_ws + + return loss / 4.0 + + +def loss_fn_area( + output: torch.Tensor, + target: torch.Tensor, + normals: torch.Tensor, + area: torch.Tensor, + area_scaling_factor: float, + loss_type: Literal["mse", "rmse"], +) -> torch.Tensor: + """Calculate area-weighted loss for surface data considering normal vectors. + + Args: + output: Predicted surface values from the model + target: Ground truth surface values + normals: Normal vectors for the surface + area: Area values for surface elements + area_scaling_factor: Scaling factor for area weighting + loss_type: Type of loss to calculate ("mse" or "rmse") + + Returns: + Area-weighted loss as a scalar tensor + """ + area = area * area_scaling_factor + area_scale_factor = area + + # Separate the scalar and vector components. + target_scalar, target_vector = torch.split( + target * area_scale_factor, [1, 3], dim=2 + ) + output_scalar, output_vector = torch.split( + output * area_scale_factor, [1, 3], dim=2 + ) + + # Apply the normals to the scalar components (only [:,:,0]): + normals, _ = torch.split(normals, [1, normals.shape[-1] - 1], dim=2) + target_scalar = target_scalar * normals + output_scalar = output_scalar * normals + + # Compute the mean diff**2 of the scalar component: + masked_loss_pres = torch.mean(((output_scalar - target_scalar) ** 2.0), dim=(0, 1)) + if loss_type == "rmse": + masked_loss_pres /= torch.mean( + (target_scalar - torch.mean(target_scalar, (0, 1))) ** 2.0, dim=(0, 1) + ) + + # Compute the mean diff**2 of the vector component, leave the last dimension: + masked_loss_ws = torch.mean((target_vector - output_vector) ** 2.0, (0, 1)) + if loss_type == "rmse": + masked_loss_ws /= torch.mean( + (target_vector - torch.mean(target_vector, (0, 1))) ** 2.0, (0, 1) + ) + + # Combine the scalar and vector components: + loss = 0.25 * (masked_loss_pres + torch.sum(masked_loss_ws)) + + return loss + + +def integral_loss_fn( + output, target, area, normals, stream_velocity=None, padded_value=-10 +): + drag_loss = drag_loss_fn( + output, target, area, normals, stream_velocity=stream_velocity, padded_value=-10 + ) + lift_loss = lift_loss_fn( + output, target, area, normals, stream_velocity=stream_velocity, padded_value=-10 + ) + return lift_loss + drag_loss + + +def lift_loss_fn(output, target, area, normals, stream_velocity=None, padded_value=-10): + vel_inlet = stream_velocity # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + normals = torch.select(normals, 2, 2) + # output_true_0 = output_true[:, :, 0] + output_true_0 = output_true.select(2, 0) + output_pred_0 = output_pred.select(2, 0) + + pres_true = output_true_0 * normals + pres_pred = output_pred_0 * normals + + wz_true = output_true[:, :, -1] + wz_pred = output_pred[:, :, -1] + + masked_pred = torch.mean(pres_pred + wz_pred, (1)) + masked_truth = torch.mean(pres_true + wz_true, (1)) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = torch.mean(loss) + return loss + + +def drag_loss_fn(output, target, area, normals, stream_velocity=None, padded_value=-10): + vel_inlet = stream_velocity # Get this from the dataset + mask = abs(target - padded_value) > 1e-3 + output_true = target * mask * area * (vel_inlet) ** 2.0 + output_pred = output * mask * area * (vel_inlet) ** 2.0 + + pres_true = output_true[:, :, 0] * normals[:, :, 0] + pres_pred = output_pred[:, :, 0] * normals[:, :, 0] + + wx_true = output_true[:, :, 1] + wx_pred = output_pred[:, :, 1] + + masked_pred = torch.mean(pres_pred + wx_pred, (1)) + masked_truth = torch.mean(pres_true + wx_true, (1)) + + loss = (masked_pred - masked_truth) ** 2.0 + loss = torch.mean(loss) + return loss + + +def compute_loss_dict( + prediction_vol: torch.Tensor, + prediction_surf: torch.Tensor, + batch_inputs: dict, + loss_fn_type: dict, + integral_scaling_factor: float, + surf_loss_scaling: float, + vol_loss_scaling: float, + first_deriv: torch.nn.Module | None = None, + eqn: Any = None, + bounding_box: torch.Tensor | None = None, + vol_factors: torch.Tensor | None = None, + add_physics_loss: bool = False, +) -> tuple[torch.Tensor, dict]: + """ + Compute the loss terms in a single function call. + + Computes: + - Volume loss if prediction_vol is not None + - Surface loss if prediction_surf is not None + - Integral loss if prediction_surf is not None + - Total loss as a weighted sum of the above + + Returns: + - Total loss as a scalar tensor + - Dictionary of loss terms (for logging, etc) + """ + nvtx.range_push("Loss Calculation") + total_loss_terms = [] + loss_dict = {} + + if prediction_vol is not None: + target_vol = batch_inputs["volume_fields"] + + if add_physics_loss: + loss_vol = loss_fn_with_physics( + prediction_vol, + target_vol, + loss_fn_type.loss_type, + padded_value=-10, + first_deriv=first_deriv, + eqn=eqn, + bounding_box=bounding_box, + vol_factors=vol_factors, + ) + loss_dict["loss_vol"] = loss_vol[0] + loss_dict["loss_continuity"] = loss_vol[1] + loss_dict["loss_momentum_x"] = loss_vol[2] + loss_dict["loss_momentum_y"] = loss_vol[3] + loss_dict["loss_momentum_z"] = loss_vol[4] + total_loss_terms.append(loss_vol[0]) + total_loss_terms.append(loss_vol[1]) + total_loss_terms.append(loss_vol[2]) + total_loss_terms.append(loss_vol[3]) + total_loss_terms.append(loss_vol[4]) + else: + loss_vol = loss_fn( + prediction_vol, + target_vol, + loss_fn_type.loss_type, + padded_value=-10, + ) + loss_dict["loss_vol"] = loss_vol + total_loss_terms.append(loss_vol) + + if prediction_surf is not None: + target_surf = batch_inputs["surface_fields"] + surface_areas = batch_inputs["surface_areas"] + surface_areas = torch.unsqueeze(surface_areas, -1) + surface_normals = batch_inputs["surface_normals"] + + # Needs to be taken from the dataset + stream_velocity = batch_inputs["global_params_values"][:, 0, :] + + loss_surf = loss_fn_surface( + prediction_surf, + target_surf, + loss_fn_type.loss_type, + ) + + loss_surf_area = loss_fn_area( + prediction_surf, + target_surf, + surface_normals, + surface_areas, + area_scaling_factor=loss_fn_type.area_weighing_factor, + loss_type=loss_fn_type.loss_type, + ) + + if loss_fn_type.loss_type == "mse": + loss_surf = loss_surf * surf_loss_scaling + loss_surf_area = loss_surf_area * surf_loss_scaling + + total_loss_terms.append(loss_surf) + loss_dict["loss_surf"] = loss_surf + total_loss_terms.append(loss_surf_area) + loss_dict["loss_surf_area"] = loss_surf_area + loss_integral = ( + integral_loss_fn( + prediction_surf, + target_surf, + surface_areas, + surface_normals, + stream_velocity, + padded_value=-10, + ) + ) * integral_scaling_factor + loss_dict["loss_integral"] = loss_integral + total_loss_terms.append(loss_integral) + + total_loss = sum(total_loss_terms) + loss_dict["total_loss"] = total_loss + nvtx.range_pop() + + return total_loss, loss_dict diff --git a/examples/cfd/external_aerodynamics/domino/src/shuffle_volumetric_curator_output.py b/examples/cfd/external_aerodynamics/domino/src/shuffle_volumetric_curator_output.py new file mode 100644 index 0000000000..553d4e575a --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/shuffle_volumetric_curator_output.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import multiprocessing as mp +from functools import partial + +import numpy as np +import shutil + +import zarr +from numcodecs import Blosc + +""" +This script reads each zarr file from a specified directory, and copies the +data to the output directory. For the keys "volume_fields" and "volume_mesh_centers", +the script will apply a permutation (aka shuffle) of those fields in tandem. + +Since the datasets used are often very large, this script also applies +sharding to the output files which is a Zarr3 feature. + +Therefore, zarr >= 3.0 is required. +""" + + +def check_file_completeness(input_file: str, output_file: str) -> bool: + """ + Check if the output file exists and contains all required data from input file. + """ + if not os.path.exists(output_file): + return False + + in_file = zarr.open(input_file, mode="r") + try: + out_file = zarr.open(output_file, mode="r") + except zarr.errors.PathNotFoundError: + print(f"No output, returning False") + return False + + # Check if all keys except 'filename' exist and have same shapes + for key in in_file.keys(): + if key == "filename": + continue + if key not in out_file and key not in out_file.attrs: + print(f"Key {key} not in output, returning False") + return False + if isinstance(in_file[key], zarr.Array): + if key in out_file.attrs: + continue + if in_file[key].shape != out_file[key].shape: + print(f"Key {key} shape mismatch, returning False") + return False + return True + + +def store_array(store, name: str, data: np.ndarray): + # By default, chunk size is 10k points: + chunk_size = (10_000,) + data.shape[1:] + # By default, shard size is 2 million points: + shard_size = (2_000_000,) + data.shape[1:] + + zarr.create_array( + store=store, + name=name, + data=data, + chunks=chunk_size, + shards=shard_size, + compressors="auto", + ) + + +def copy_file_with_shuffled_volume_data( + input_file: str, output_file: str, random_seed: int | None = None +): + """ + Copy a file with shuffled volume data, using Zarr v3 sharding for efficient storage. + Only processes if the output file doesn't exist or is incomplete. + """ + file_is_complete = check_file_completeness(input_file, output_file) + if file_is_complete: + print(f"Skipping {output_file} - already complete") + return True + + print(f"Processing {input_file} -> {output_file}") + + # return False + + # if the output folder exists but isn't complete, purge it. + # It's probably an interrupted conversion. + if os.path.exists(output_file): + shutil.rmtree(output_file) + + # return file_is_complete + volume_keys = ["volume_fields", "volume_mesh_centers"] + + in_file = zarr.open(input_file, mode="r") + + # Create store with sharding configuration + store = zarr.storage.LocalStore(output_file) + root = zarr.group(store=store) + + # First copy all non-volume data + for key in in_file.keys(): + if key not in volume_keys: + if key == "filename": + continue + in_data = in_file[key] + if in_data.shape != (): + # For array data, use the same chunks as input but with sharding + store_array(store, key, in_data[:]) + else: + # Store scalar values as attributes + root.attrs[key] = in_data[()] + + # Open and shuffle the volume data + volume_fields = in_file["volume_fields"][:] + volume_mesh_centers = in_file["volume_mesh_centers"][:] + + if random_seed is not None: + np.random.seed(random_seed) + + # Generate a permutation + permutation = np.random.permutation(volume_fields.shape[0]) + + # Shuffle the volume data + shuffled_volume_fields = volume_fields[permutation] + shuffled_volume_mesh_centers = volume_mesh_centers[permutation] + + store_array(store, "volume_fields", shuffled_volume_fields) + store_array(store, "volume_mesh_centers", shuffled_volume_mesh_centers) + + print(f"Processed {output_file} - COMPLETE") + return True + + +def process_file(file: str, top_dir: str, out_dir: str): + """ + Process a single file, creating output directory if needed. + """ + os.makedirs(out_dir, exist_ok=True) + input_path = os.path.join(top_dir, file) + output_path = os.path.join(out_dir, file) + return copy_file_with_shuffled_volume_data(input_path, output_path) + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Shuffle volumetric curator output") + parser.add_argument("--input-dir", required=True, help="Input directory path") + parser.add_argument("--output-dir", required=True, help="Output directory path") + parser.add_argument( + "--num-cores", type=int, default=64, help="Number of cores to use" + ) + args = parser.parse_args() + + # Get list of files to process + files = os.listdir(args.input_dir) + + # Create a partial function with fixed directories + process_func = partial( + process_file, top_dir=args.input_dir, out_dir=args.output_dir + ) + + # Use multiprocessing to process files in parallel + num_cores = max(1, args.num_cores) # Leave one core free + print(f"Processing {len(files)} files using {num_cores} cores") + + with mp.Pool(num_cores) as pool: + results = pool.map(process_func, files) + print(f"Results: {results}") + print(f"Total conversions: {sum(results)}") + + +if __name__ == "__main__": + main() diff --git a/examples/cfd/external_aerodynamics/domino/src/test.py b/examples/cfd/external_aerodynamics/domino/src/test.py index 944910f9f8..4bb32dc6cd 100644 --- a/examples/cfd/external_aerodynamics/domino/src/test.py +++ b/examples/cfd/external_aerodynamics/domino/src/test.py @@ -33,6 +33,9 @@ from hydra.utils import to_absolute_path from omegaconf import DictConfig, OmegaConf +# This will set up the cupy-ecosystem and pytorch to share memory pools +from physicsnemo.utils.memory import unified_gpu_memory + import numpy as np import cupy as cp @@ -53,8 +56,12 @@ from physicsnemo.distributed import DistributedManager from physicsnemo.datapipes.cae.domino_datapipe import DoMINODataPipe from physicsnemo.models.domino.model import DoMINO +from physicsnemo.models.domino.geometry_rep import scale_sdf from physicsnemo.utils.domino.utils import * +from physicsnemo.utils.domino.vtk_file_utils import * from physicsnemo.utils.sdf import signed_distance_field +from physicsnemo.utils.neighbors import knn +from utils import ScalingFactors, load_scaling_factors # AIR_DENSITY = 1.205 # STREAM_VELOCITY = 30.00 @@ -84,7 +91,7 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): with torch.no_grad(): point_batch_size = 256000 - data_dict = dict_to_device(data_dict, device) + # data_dict = dict_to_device(data_dict, device) # Non-dimensionalization factors length_scale = data_dict["length_scale"] @@ -110,11 +117,16 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): p_grid = data_dict["grid"] sdf_grid = data_dict["sdf_grid"] # Scaling factors - vol_max = data_dict["volume_min_max"][:, 1] - vol_min = data_dict["volume_min_max"][:, 0] + if "volume_min_max" in data_dict.keys(): + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] + geo_centers_vol = ( + 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + ) + else: + geo_centers_vol = geo_centers # Normalize based on computational domain - geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 encoding_g_vol = model.geo_rep_volume(geo_centers_vol, p_grid, sdf_grid) if output_features_surf is not None: @@ -147,10 +159,12 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] p_grid = data_dict["grid"] - prediction_vol = np.zeros_like(target_vol.cpu().numpy()) + prediction_vol = torch.zeros_like(target_vol) num_points = volume_mesh_centers.shape[1] subdomain_points = int(np.floor(num_points / point_batch_size)) - + sdf_scaling_factor = ( + cfg.model.geometry_rep.geo_processor.volume_sdf_scaling_factor + ) start_time = time.time() for p in range(subdomain_points + 1): @@ -162,58 +176,64 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): :, start_idx:end_idx ] sdf_nodes_batch = sdf_nodes[:, start_idx:end_idx] + scaled_sdf_nodes_batch = [] + for p in range(len(sdf_scaling_factor)): + scaled_sdf_nodes_batch.append( + scale_sdf(sdf_nodes_batch, sdf_scaling_factor[p]) + ) + scaled_sdf_nodes_batch = torch.cat(scaled_sdf_nodes_batch, dim=-1) + pos_volume_closest_batch = pos_volume_closest[:, start_idx:end_idx] pos_normals_com_batch = pos_volume_center_of_mass[ :, start_idx:end_idx ] - geo_encoding_local = model.geo_encoding_local( + geo_encoding_local = model.volume_local_geo_encodings( 0.5 * encoding_g_vol, volume_mesh_centers_batch, p_grid, - mode="volume", ) if cfg.model.use_sdf_in_basis_func: - pos_encoding = torch.cat( + pos_encoding_all = torch.cat( ( sdf_nodes_batch, + scaled_sdf_nodes_batch, pos_volume_closest_batch, pos_normals_com_batch, ), axis=-1, ) else: - pos_encoding = pos_normals_com_batch - pos_encoding = model.position_encoder( - pos_encoding, eval_mode="volume" - ) - tpredictions_batch = model.calculate_solution( + pos_encoding_all = pos_normals_com_batch + + pos_encoding = model.fc_p_vol(pos_encoding_all) + tpredictions_batch = model.solution_calculator_vol( volume_mesh_centers_batch, geo_encoding_local, pos_encoding, global_params_values, global_params_reference, - num_sample_points=cfg.model.num_neighbors_volume, - eval_mode="volume", ) running_tloss_vol += loss_fn(tpredictions_batch, target_batch) - prediction_vol[:, start_idx:end_idx] = ( - tpredictions_batch.cpu().numpy() - ) + prediction_vol[:, start_idx:end_idx] = tpredictions_batch - prediction_vol = unnormalize(prediction_vol, vol_factors[0], vol_factors[1]) + if cfg.model.normalization == "min_max_scaling": + prediction_vol = unnormalize( + prediction_vol, vol_factors[0], vol_factors[1] + ) + elif cfg.model.normalization == "mean_std_scaling": + prediction_vol = unstandardize( + prediction_vol, vol_factors[0], vol_factors[1] + ) + # print(np.amax(prediction_vol, axis=(0, 1)), np.amin(prediction_vol, axis=(0, 1))) - prediction_vol[:, :, :3] = ( - prediction_vol[:, :, :3] * stream_velocity[0, 0].cpu().numpy() - ) + prediction_vol[:, :, :3] = prediction_vol[:, :, :3] * stream_velocity[0, 0] prediction_vol[:, :, 3] = ( prediction_vol[:, :, 3] - * stream_velocity[0, 0].cpu().numpy() ** 2.0 - * air_density[0, 0].cpu().numpy() + * stream_velocity[0, 0] ** 2.0 + * air_density[0, 0] ) prediction_vol[:, :, 4] = ( - prediction_vol[:, :, 4] - * stream_velocity[0, 0].cpu().numpy() - * length_scale[0].cpu().numpy() + prediction_vol[:, :, 4] * stream_velocity[0, 0] * length_scale[0] ) else: prediction_vol = None @@ -236,7 +256,7 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): subdomain_points = int(np.floor(num_points / point_batch_size)) target_surf = data_dict["surface_fields"] - prediction_surf = np.zeros_like(target_surf.cpu().numpy()) + prediction_surf = torch.zeros_like(target_surf) start_time = time.time() @@ -262,18 +282,14 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): pos_surface_center_of_mass_batch = pos_surface_center_of_mass[ :, start_idx:end_idx ] - geo_encoding_local = model.geo_encoding_local( + geo_encoding_local = model.surface_local_geo_encodings( 0.5 * encoding_g_surf, surface_mesh_centers_batch, s_grid, - mode="surface", - ) - pos_encoding = pos_surface_center_of_mass_batch - pos_encoding = model.position_encoder( - pos_encoding, eval_mode="surface" ) + pos_encoding = model.fc_p_surf(pos_surface_center_of_mass_batch) - tpredictions_batch = model.calculate_solution_with_neighbors( + tpredictions_batch = model.solution_calculator_surf( surface_mesh_centers_batch, geo_encoding_local, pos_encoding, @@ -284,20 +300,22 @@ def test_step(data_dict, model, device, cfg, vol_factors, surf_factors): surface_neighbors_areas_batch, global_params_values, global_params_reference, - num_sample_points=cfg.model.num_neighbors_surface, ) running_tloss_surf += loss_fn(tpredictions_batch, target_batch) - prediction_surf[:, start_idx:end_idx] = ( - tpredictions_batch.cpu().numpy() - ) + prediction_surf[:, start_idx:end_idx] = tpredictions_batch + if cfg.model.normalization == "min_max_scaling": + prediction_surf = unnormalize( + prediction_surf, surf_factors[0], surf_factors[1] + ) + elif cfg.model.normalization == "mean_std_scaling": + prediction_surf = unstandardize( + prediction_surf, surf_factors[0], surf_factors[1] + ) prediction_surf = ( - unnormalize(prediction_surf, surf_factors[0], surf_factors[1]) - * stream_velocity[0, 0].cpu().numpy() ** 2.0 - * air_density[0, 0].cpu().numpy() + prediction_surf * stream_velocity[0, 0] ** 2.0 * air_density[0, 0] ) - else: prediction_surf = None @@ -346,22 +364,12 @@ def main(cfg: DictConfig): else: global_features += 1 - vol_save_path = os.path.join( - cfg.eval.scaling_param_path, "volume_scaling_factors.npy" - ) - surf_save_path = os.path.join( - cfg.eval.scaling_param_path, "surface_scaling_factors.npy" - ) - if os.path.exists(vol_save_path): - vol_factors = np.load(vol_save_path) - else: - vol_factors = None - - if os.path.exists(surf_save_path): - surf_factors = np.load(surf_save_path) - else: - surf_factors = None + ###################################################### + # Get scaling factors - precompute them if this fails! + ###################################################### + pickle_path = os.path.join(cfg.data.scaling_factors) + vol_factors, surf_factors = load_scaling_factors(cfg) print("Vol factors:", vol_factors) print("Surf factors:", surf_factors) @@ -429,41 +437,56 @@ def main(cfg: DictConfig): :, 1: ] # Assuming triangular elements mesh_indices_flattened = stl_faces.flatten() - length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) + length_scale = np.array( + np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)), + dtype=np.float32, + ) + length_scale = torch.from_numpy(length_scale).to(torch.float32).to(dist.device) stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False) stl_sizes = np.array(stl_sizes.cell_data["Area"], dtype=np.float32) stl_centers = np.array(mesh_stl.cell_centers().points, dtype=np.float32) + # Convert to torch tensors and load on device + stl_vertices = torch.from_numpy(stl_vertices).to(torch.float32).to(dist.device) + stl_sizes = torch.from_numpy(stl_sizes).to(torch.float32).to(dist.device) + stl_centers = torch.from_numpy(stl_centers).to(torch.float32).to(dist.device) + mesh_indices_flattened = ( + torch.from_numpy(mesh_indices_flattened).to(torch.int32).to(dist.device) + ) + # Center of mass calculation center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) - if cfg.data.bounding_box_surface is None: - s_max = np.amax(stl_vertices, 0) - s_min = np.amin(stl_vertices, 0) - else: - bounding_box_dims_surf = [] - bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.max)) - bounding_box_dims_surf.append(np.asarray(cfg.data.bounding_box_surface.min)) - s_max = np.float32(bounding_box_dims_surf[0]) - s_min = np.float32(bounding_box_dims_surf[1]) + s_max = ( + torch.from_numpy(np.asarray(cfg.data.bounding_box_surface.max)) + .to(torch.float32) + .to(dist.device) + ) + s_min = ( + torch.from_numpy(np.asarray(cfg.data.bounding_box_surface.min)) + .to(torch.float32) + .to(dist.device) + ) nx, ny, nz = cfg.model.interp_res - surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) - surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) + surf_grid = create_grid( + s_max, s_min, torch.from_numpy(np.asarray([nx, ny, nz])).to(dist.device) + ) + + normed_stl_vertices_cp = normalize(stl_vertices, s_max, s_min) + surf_grid_normed = normalize(surf_grid, s_max, s_min) # SDF calculation on the grid using WARP - sdf_surf_grid = signed_distance_field( - cp.asarray(stl_vertices).astype(cp.float32), - cp.asarray(mesh_indices_flattened).astype(cp.int32), - cp.asarray(surf_grid_reshaped).astype(cp.float32), + time_start = time.time() + sdf_surf_grid, _ = signed_distance_field( + normed_stl_vertices_cp, + mesh_indices_flattened, + surf_grid_normed, use_sign_winding_number=True, - return_cupy=False, - ).reshape(nx, ny, nz) + ) - surf_grid = np.float32(surf_grid) - sdf_surf_grid = np.float32(sdf_surf_grid) - surf_grid_max_min = np.float32(np.asarray([s_min, s_max])) + surf_grid_max_min = torch.stack([s_min, s_max]) # Get global parameters and global parameters scaling from config.yaml global_params_names = list(cfg.variables.global_parameters.keys()) @@ -492,6 +515,9 @@ def main(cfg: DictConfig): global_params_reference = np.array( global_params_reference_list, dtype=np.float32 ) + global_params_reference = torch.from_numpy(global_params_reference).to( + dist.device + ) # Define the list of global parameter values for each simulation. # Note: The user must ensure that the values provided here correspond to the @@ -507,7 +533,12 @@ def main(cfg: DictConfig): raise ValueError( f"Global parameter {key} not supported for this dataset" ) - global_params_values = np.array(global_params_values_list, dtype=np.float32) + global_params_values_list = np.array( + global_params_values_list, dtype=np.float32 + ) + global_params_values = torch.from_numpy(global_params_values_list).to( + dist.device + ) # Read VTP if model_type == "surface" or model_type == "combined": @@ -535,11 +566,27 @@ def main(cfg: DictConfig): surface_normals = ( surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis] ) + surface_coordinates = ( + torch.from_numpy(surface_coordinates).to(torch.float32).to(dist.device) + ) + surface_normals = ( + torch.from_numpy(surface_normals).to(torch.float32).to(dist.device) + ) + surface_sizes = ( + torch.from_numpy(surface_sizes).to(torch.float32).to(dist.device) + ) + surface_fields = ( + torch.from_numpy(surface_fields).to(torch.float32).to(dist.device) + ) if cfg.model.num_neighbors_surface > 1: - interp_func = KDTree(surface_coordinates) - dd, ii = interp_func.query( - surface_coordinates, k=cfg.model.num_neighbors_surface + time_start = time.time() + # print(f"file: {dirname}, surface coordinates shape: {surface_coordinates.shape}") + # try: + ii, dd = knn( + points=surface_coordinates, + queries=surface_coordinates, + k=cfg.model.num_neighbors_surface, ) surface_neighbors = surface_coordinates[ii] @@ -549,27 +596,26 @@ def main(cfg: DictConfig): surface_neighbors_normals = surface_neighbors_normals[:, 1:] surface_neighbors_sizes = surface_sizes[ii] surface_neighbors_sizes = surface_neighbors_sizes[:, 1:] + # except: + # print(f"file: {dirname}, memory error in knn") + # print("setting surface neighbors to 0") + # surface_neighbors = surface_coordinates + # surface_neighbors_normals = surface_normals + # surface_neighbors_sizes = surface_sizes + # cfg.model.num_neighbors_surface = 1 else: surface_neighbors = surface_coordinates surface_neighbors_normals = surface_normals surface_neighbors_sizes = surface_sizes - dx, dy, dz = ( - (s_max[0] - s_min[0]) / nx, - (s_max[1] - s_min[1]) / ny, - (s_max[2] - s_min[2]) / nz, - ) - - if cfg.model.positional_encoding: - pos_surface_center_of_mass = calculate_normal_positional_encoding( - surface_coordinates, center_of_mass, cell_length=[dx, dy, dz] - ) + if cfg.data.normalize_coordinates: + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) + center_of_mass_normalized = normalize(center_of_mass, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) else: - pos_surface_center_of_mass = surface_coordinates - center_of_mass - - surface_coordinates = normalize(surface_coordinates, s_max, s_min) - surface_neighbors = normalize(surface_neighbors, s_max, s_min) - surf_grid = normalize(surf_grid, s_max, s_min) + center_of_mass_normalized = center_of_mass + pos_surface_center_of_mass = surface_coordinates - center_of_mass_normalized else: surface_coordinates = None @@ -591,65 +637,60 @@ def main(cfg: DictConfig): polydata_vol, volume_variable_names ) volume_fields = np.concatenate(volume_fields, axis=-1) + volume_coordinates = ( + torch.from_numpy(volume_coordinates).to(torch.float32).to(dist.device) + ) + volume_fields = ( + torch.from_numpy(volume_fields).to(torch.float32).to(dist.device) + ) - bounding_box_dims = [] - bounding_box_dims.append(np.asarray(cfg.data.bounding_box.max)) - bounding_box_dims.append(np.asarray(cfg.data.bounding_box.min)) - - v_max = np.amax(volume_coordinates, 0) - v_min = np.amin(volume_coordinates, 0) - if bounding_box_dims is None: - c_max = s_max + (s_max - s_min) / 2 - c_min = s_min - (s_max - s_min) / 2 - c_min[2] = s_min[2] - else: - c_max = np.float32(bounding_box_dims[0]) - c_min = np.float32(bounding_box_dims[1]) - - dx, dy, dz = ( - (c_max[0] - c_min[0]) / nx, - (c_max[1] - c_min[1]) / ny, - (c_max[2] - c_min[2]) / nz, + c_max = ( + torch.from_numpy(np.asarray(cfg.data.bounding_box.max)) + .to(torch.float32) + .to(dist.device) ) + c_min = ( + torch.from_numpy(np.asarray(cfg.data.bounding_box.min)) + .to(torch.float32) + .to(dist.device) + ) + # Generate a grid of specified resolution to map the bounding box # The grid is used for capturing structured geometry features and SDF representation of geometry - grid = create_grid(c_max, c_min, [nx, ny, nz]) - grid_reshaped = grid.reshape(nx * ny * nz, 3) + grid = create_grid( + c_max, c_min, torch.from_numpy(np.asarray([nx, ny, nz])).to(dist.device) + ) + + if cfg.data.normalize_coordinates: + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(grid, c_max, c_min) + center_of_mass_normalized = normalize(center_of_mass, c_max, c_min) + normed_stl_vertices_vol = normalize(stl_vertices, c_max, c_min) + else: + center_of_mass_normalized = center_of_mass # SDF calculation on the grid using WARP - sdf_grid = signed_distance_field( - cp.asarray(stl_vertices).astype(cp.float32), - cp.asarray(mesh_indices_flattened).astype(cp.int32), - cp.asarray(grid_reshaped).astype(cp.float32), + time_start = time.time() + sdf_grid, _ = signed_distance_field( + normed_stl_vertices_vol, + mesh_indices_flattened, + grid, use_sign_winding_number=True, - return_cupy=False, - ).reshape(nx, ny, nz) + ) # SDF calculation + time_start = time.time() sdf_nodes, sdf_node_closest_point = signed_distance_field( - cp.asarray(stl_vertices).astype(cp.float32), - cp.asarray(mesh_indices_flattened).astype(cp.int32), - cp.asarray(volume_coordinates).astype(cp.float32), - include_hit_points=True, + normed_stl_vertices_vol, + mesh_indices_flattened, + volume_coordinates, use_sign_winding_number=True, - return_cupy=False, ) sdf_nodes = sdf_nodes.reshape(-1, 1) + vol_grid_max_min = torch.stack([c_min, c_max]) - if cfg.model.positional_encoding: - pos_volume_closest = calculate_normal_positional_encoding( - volume_coordinates, sdf_node_closest_point, cell_length=[dx, dy, dz] - ) - pos_volume_center_of_mass = calculate_normal_positional_encoding( - volume_coordinates, center_of_mass, cell_length=[dx, dy, dz] - ) - else: - pos_volume_closest = volume_coordinates - sdf_node_closest_point - pos_volume_center_of_mass = volume_coordinates - center_of_mass - - volume_coordinates = normalize(volume_coordinates, c_max, c_min) - grid = normalize(grid, c_max, c_min) - vol_grid_max_min = np.asarray([c_min, c_max]) + pos_volume_closest = volume_coordinates - sdf_node_closest_point + pos_volume_center_of_mass = volume_coordinates - center_of_mass_normalized else: volume_coordinates = None @@ -659,7 +700,8 @@ def main(cfg: DictConfig): # print(f"Processed sdf and normalized") - geom_centers = np.float32(stl_vertices) + geom_centers = stl_vertices + # print(f"Geom centers max: {np.amax(geom_centers, axis=0)}, min: {np.amin(geom_centers, axis=0)}") if model_type == "combined": # Add the parameters to the dictionary @@ -684,35 +726,27 @@ def main(cfg: DictConfig): "surface_fields": surface_fields, "volume_min_max": vol_grid_max_min, "surface_min_max": surf_grid_max_min, - "length_scale": np.array(length_scale, dtype=np.float32), - "global_params_values": np.expand_dims( - np.array(global_params_values, dtype=np.float32), -1 - ), - "global_params_reference": np.expand_dims( - np.array(global_params_reference, dtype=np.float32), -1 - ), + "length_scale": length_scale, + "global_params_values": torch.unsqueeze(global_params_values, -1), + "global_params_reference": torch.unsqueeze(global_params_reference, -1), } elif model_type == "surface": data_dict = { - "pos_surface_center_of_mass": np.float32(pos_surface_center_of_mass), - "geometry_coordinates": np.float32(geom_centers), - "surf_grid": np.float32(surf_grid), - "sdf_surf_grid": np.float32(sdf_surf_grid), - "surface_mesh_centers": np.float32(surface_coordinates), - "surface_mesh_neighbors": np.float32(surface_neighbors), - "surface_normals": np.float32(surface_normals), - "surface_neighbors_normals": np.float32(surface_neighbors_normals), - "surface_areas": np.float32(surface_sizes), - "surface_neighbors_areas": np.float32(surface_neighbors_sizes), - "surface_fields": np.float32(surface_fields), - "surface_min_max": np.float32(surf_grid_max_min), - "length_scale": np.array(length_scale, dtype=np.float32), - "global_params_values": np.expand_dims( - np.array(global_params_values, dtype=np.float32), -1 - ), - "global_params_reference": np.expand_dims( - np.array(global_params_reference, dtype=np.float32), -1 - ), + "pos_surface_center_of_mass": pos_surface_center_of_mass, + "geometry_coordinates": geom_centers, + "surf_grid": surf_grid, + "sdf_surf_grid": sdf_surf_grid, + "surface_mesh_centers": surface_coordinates, + "surface_mesh_neighbors": surface_neighbors, + "surface_normals": surface_normals, + "surface_neighbors_normals": surface_neighbors_normals, + "surface_areas": surface_sizes, + "surface_neighbors_areas": surface_neighbors_sizes, + "surface_fields": surface_fields, + "surface_min_max": surf_grid_max_min, + "length_scale": length_scale, + "global_params_values": torch.unsqueeze(global_params_values, -1), + "global_params_reference": torch.unsqueeze(global_params_reference, -1), } elif model_type == "volume": data_dict = { @@ -728,66 +762,65 @@ def main(cfg: DictConfig): "volume_mesh_centers": volume_coordinates, "volume_min_max": vol_grid_max_min, "surface_min_max": surf_grid_max_min, - "length_scale": np.array(length_scale, dtype=np.float32), - "global_params_values": np.expand_dims( - np.array(global_params_values, dtype=np.float32), -1 - ), - "global_params_reference": np.expand_dims( - np.array(global_params_reference, dtype=np.float32), -1 - ), + "length_scale": length_scale, + "global_params_values": torch.unsqueeze(global_params_values, -1), + "global_params_reference": torch.unsqueeze(global_params_reference, -1), } - data_dict = { - key: torch.from_numpy(np.expand_dims(np.float32(value), 0)) - for key, value in data_dict.items() - } + data_dict = {key: torch.unsqueeze(value, 0) for key, value in data_dict.items()} prediction_vol, prediction_surf = test_step( data_dict, model, dist.device, cfg, vol_factors, surf_factors ) if prediction_surf is not None: - surface_sizes = np.expand_dims(surface_sizes, -1) + surface_sizes = torch.unsqueeze(surface_sizes, -1) - pres_x_pred = np.sum( + pres_x_pred = torch.sum( prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0] ) - shear_x_pred = np.sum(prediction_surf[0, :, 1] * surface_sizes[:, 0]) + shear_x_pred = torch.sum(prediction_surf[0, :, 1] * surface_sizes[:, 0]) - pres_x_true = np.sum( + pres_x_true = torch.sum( surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0] ) - shear_x_true = np.sum(surface_fields[:, 1] * surface_sizes[:, 0]) + shear_x_true = torch.sum(surface_fields[:, 1] * surface_sizes[:, 0]) - force_x_pred = np.sum( + force_x_pred = torch.sum( prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0] - prediction_surf[0, :, 1] * surface_sizes[:, 0] ) - force_x_true = np.sum( + force_x_true = torch.sum( surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0] - surface_fields[:, 1] * surface_sizes[:, 0] ) - force_y_pred = np.sum( + force_y_pred = torch.sum( prediction_surf[0, :, 0] * surface_normals[:, 1] * surface_sizes[:, 0] - prediction_surf[0, :, 2] * surface_sizes[:, 0] ) - force_y_true = np.sum( + force_y_true = torch.sum( surface_fields[:, 0] * surface_normals[:, 1] * surface_sizes[:, 0] - surface_fields[:, 2] * surface_sizes[:, 0] ) - force_z_pred = np.sum( + force_z_pred = torch.sum( prediction_surf[0, :, 0] * surface_normals[:, 2] * surface_sizes[:, 0] - prediction_surf[0, :, 3] * surface_sizes[:, 0] ) - force_z_true = np.sum( + force_z_true = torch.sum( surface_fields[:, 0] * surface_normals[:, 2] * surface_sizes[:, 0] - surface_fields[:, 3] * surface_sizes[:, 0] ) - print("Drag=", dirname, force_x_pred, force_x_true) - print("Lift=", dirname, force_z_pred, force_z_true) - print("Side=", dirname, force_y_pred, force_y_true) + print( + "Drag=", dirname, force_x_pred.cpu().numpy(), force_x_true.cpu().numpy() + ) + print( + "Lift=", dirname, force_z_pred.cpu().numpy(), force_z_true.cpu().numpy() + ) + print( + "Side=", dirname, force_y_pred.cpu().numpy(), force_y_true.cpu().numpy() + ) aero_forces_all.append( [ dirname, @@ -800,14 +833,18 @@ def main(cfg: DictConfig): ] ) - l2_gt = np.mean(np.square(surface_fields), (0)) - l2_error = np.mean(np.square(prediction_surf[0] - surface_fields), (0)) - l2_surface_all.append(np.sqrt(l2_error / l2_gt)) + l2_gt = torch.mean(torch.square(surface_fields), (0)) + l2_error = torch.mean( + torch.square(prediction_surf[0] - surface_fields), (0) + ) + l2_surface_all.append( + np.sqrt(l2_error.cpu().numpy()) / np.sqrt(l2_gt.cpu().numpy()) + ) print( "Surface L-2 norm:", dirname, - np.sqrt(l2_error) / np.sqrt(l2_gt), + np.sqrt(l2_error.cpu().numpy()) / np.sqrt(l2_gt.cpu().numpy()), ) if prediction_vol is not None: @@ -816,7 +853,7 @@ def main(cfg: DictConfig): c_min = vol_grid_max_min[0] c_max = vol_grid_max_min[1] volume_coordinates = unnormalize(volume_coordinates, c_max, c_min) - ids_in_bbox = np.where( + ids_in_bbox = torch.where( (volume_coordinates[:, 0] < c_min[0]) | (volume_coordinates[:, 0] > c_max[0]) | (volume_coordinates[:, 1] < c_min[1]) @@ -826,36 +863,49 @@ def main(cfg: DictConfig): ) target_vol[ids_in_bbox] = 0.0 prediction_vol[ids_in_bbox] = 0.0 - l2_gt = np.mean(np.square(target_vol), (0)) - l2_error = np.mean(np.square(prediction_vol - target_vol), (0)) + l2_gt = torch.mean(torch.square(target_vol), (0)) + l2_error = torch.mean(torch.square(prediction_vol - target_vol), (0)) print( "Volume L-2 norm:", dirname, - np.sqrt(l2_error) / np.sqrt(l2_gt), + np.sqrt(l2_error.cpu().numpy()) / np.sqrt(l2_gt.cpu().numpy()), + ) + l2_volume_all.append( + np.sqrt(l2_error.cpu().numpy()) / np.sqrt(l2_gt.cpu().numpy()) ) - l2_volume_all.append(np.sqrt(l2_error) / np.sqrt(l2_gt)) + # import pdb; pdb.set_trace() if prediction_surf is not None: - surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 0:1]) + surfParam_vtk = numpy_support.numpy_to_vtk( + prediction_surf[0, :, 0:1].cpu().numpy() + ) surfParam_vtk.SetName(f"{surface_variable_names[0]}Pred") celldata_all.GetCellData().AddArray(surfParam_vtk) - surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 1:]) + surfParam_vtk = numpy_support.numpy_to_vtk( + prediction_surf[0, :, 1:].cpu().numpy() + ) surfParam_vtk.SetName(f"{surface_variable_names[1]}Pred") celldata_all.GetCellData().AddArray(surfParam_vtk) write_to_vtp(celldata_all, vtp_pred_save_path) if prediction_vol is not None: - volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 0:3]) + volParam_vtk = numpy_support.numpy_to_vtk( + prediction_vol[:, 0:3].cpu().numpy() + ) volParam_vtk.SetName(f"{volume_variable_names[0]}Pred") polydata_vol.GetPointData().AddArray(volParam_vtk) - volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 3:4]) + volParam_vtk = numpy_support.numpy_to_vtk( + prediction_vol[:, 3:4].cpu().numpy() + ) volParam_vtk.SetName(f"{volume_variable_names[1]}Pred") polydata_vol.GetPointData().AddArray(volParam_vtk) - volParam_vtk = numpy_support.numpy_to_vtk(prediction_vol[:, 4:5]) + volParam_vtk = numpy_support.numpy_to_vtk( + prediction_vol[:, 4:5].cpu().numpy() + ) volParam_vtk.SetName(f"{volume_variable_names[2]}Pred") polydata_vol.GetPointData().AddArray(volParam_vtk) diff --git a/examples/cfd/external_aerodynamics/domino/src/train.py b/examples/cfd/external_aerodynamics/domino/src/train.py index 96e30b58e7..55731696d2 100644 --- a/examples/cfd/external_aerodynamics/domino/src/train.py +++ b/examples/cfd/external_aerodynamics/domino/src/train.py @@ -30,18 +30,21 @@ import time import os import re -import torch -import torchinfo - from typing import Literal, Any +from tabulate import tabulate import apex import numpy as np import hydra from hydra.utils import to_absolute_path from omegaconf import DictConfig, OmegaConf + +# This will set up the cupy-ecosystem and pytorch to share memory pools +from physicsnemo.utils.memory import unified_gpu_memory + +import torchinfo import torch.distributed as dist -from torch.cuda.amp import GradScaler, autocast +from torch.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -56,16 +59,18 @@ from physicsnemo.datapipes.cae.domino_datapipe import ( DoMINODataPipe, - compute_scaling_factors, create_domino_dataset, ) from physicsnemo.models.domino.model import DoMINO from physicsnemo.utils.domino.utils import * +from utils import ScalingFactors, get_keys_to_read, coordinate_distributed_environment + # This is included for GPU memory tracking: from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo import time + # Initialize NVML nvmlInit() @@ -73,530 +78,8 @@ from physicsnemo.utils.profiling import profile, Profiler -# Profiler().enable("line_profiler") -# Profiler().initialize() - - -def compute_physics_loss( - output: torch.Tensor, - target: torch.Tensor, - mask: torch.Tensor, - loss_type: Literal["mse", "rmse"], - dims: tuple[int, ...] | None, - first_deriv: torch.nn.Module, - eqn: Any, - bounding_box: torch.Tensor, - vol_factors: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute physics-based loss terms for Navier-Stokes equations. - - Args: - output: Model output containing (output, coords_neighbors, output_neighbors, neighbors_list) - target: Ground truth values - mask: Mask for valid values - loss_type: Type of loss to calculate ("mse" or "rmse") - dims: Dimensions for loss calculation - first_deriv: First derivative calculator - eqn: Equations - bounding_box: Bounding box for normalization - vol_factors: Volume factors for normalization - - Returns: - Tuple of (data_loss, continuity_loss, momentum_x_loss, momentum_y_loss, momentum_z_loss) - """ - # Physics loss enabled - output, coords_neighbors, output_neighbors, neighbors_list = output - batch_size = output.shape[1] - fields, num_neighbors = output_neighbors.shape[3], output_neighbors.shape[2] - coords_total = coords_neighbors[0, :] - output_total = output_neighbors[0, :] - output_total_unnormalized = unnormalize( - output_total, vol_factors[0], vol_factors[1] - ) - coords_total_unnormalized = unnormalize( - coords_total, bounding_box[0], bounding_box[1] - ) - - # compute first order gradients on all the nodes from the neighbors_list - grad_list = {} - for parent_id, neighbor_ids in neighbors_list.items(): - neighbor_ids_tensor = torch.tensor(neighbor_ids).to( - output_total_unnormalized.device - ) - du = ( - output_total_unnormalized[:, [parent_id]] - - output_total_unnormalized[:, neighbor_ids_tensor] - ) - dv = ( - coords_total_unnormalized[:, [parent_id]] - - coords_total_unnormalized[:, neighbor_ids_tensor] - ) - grads = first_deriv.forward( - coords=None, connectivity_tensor=None, y=None, du=du, dv=dv - ) - grad = torch.cat(grads, dim=1) - grad_list[parent_id] = grad - - # compute second order gradients on only the center node - neighbor_ids_tensor = torch.tensor(neighbors_list[0]).to( - output_total_unnormalized.device - ) - grad_neighbors_center = torch.stack([v for v in grad_list.values()], dim=1) - grad_neighbors_center = grad_neighbors_center.reshape( - batch_size, len(neighbors_list[0]) + 1, -1 - ) - - du = grad_neighbors_center[:, [0]] - grad_neighbors_center[:, neighbor_ids_tensor] - dv = ( - coords_total_unnormalized[:, [0]] - - coords_total_unnormalized[:, neighbor_ids_tensor] - ) - - # second order gradients - ggrads_center = first_deriv.forward( - coords=None, connectivity_tensor=None, y=None, du=du, dv=dv - ) - ggrad_center = torch.cat(ggrads_center, dim=1) - grad_neighbors_center = grad_neighbors_center.reshape( - batch_size, len(neighbors_list[0]) + 1, 3, -1 - ) - - # Get the outputs on the original nodes - fields_center_unnormalized = output_total_unnormalized[:, 0, :] - grad_center = grad_neighbors_center[:, 0, :, :] - grad_grad_uvw_center = ggrad_center[:, :, :9] - - nu = 1.507 * 1e-5 - - dict_mapping = { - "u": fields_center_unnormalized[:, [0]], - "v": fields_center_unnormalized[:, [1]], - "w": fields_center_unnormalized[:, [2]], - "p": fields_center_unnormalized[:, [3]], - "nu": nu + fields_center_unnormalized[:, [4]], - "u__x": grad_center[:, 0, [0]], - "u__y": grad_center[:, 1, [0]], - "u__z": grad_center[:, 2, [0]], - "v__x": grad_center[:, 0, [1]], - "v__y": grad_center[:, 1, [1]], - "v__z": grad_center[:, 2, [1]], - "w__x": grad_center[:, 0, [2]], - "w__y": grad_center[:, 1, [2]], - "w__z": grad_center[:, 2, [2]], - "p__x": grad_center[:, 0, [3]], - "p__y": grad_center[:, 1, [3]], - "p__z": grad_center[:, 2, [3]], - "nu__x": grad_center[:, 0, [4]], - "nu__y": grad_center[:, 1, [4]], - "nu__z": grad_center[:, 2, [4]], - "u__x__x": grad_grad_uvw_center[:, 0, [0]], - "u__x__y": grad_grad_uvw_center[:, 1, [0]], - "u__x__z": grad_grad_uvw_center[:, 2, [0]], - "u__y__x": grad_grad_uvw_center[:, 1, [0]], # same as __x__y - "u__y__y": grad_grad_uvw_center[:, 1, [1]], - "u__y__z": grad_grad_uvw_center[:, 2, [1]], - "u__z__x": grad_grad_uvw_center[:, 2, [0]], # same as __x__z - "u__z__y": grad_grad_uvw_center[:, 2, [1]], # same as __y__z - "u__z__z": grad_grad_uvw_center[:, 2, [2]], - "v__x__x": grad_grad_uvw_center[:, 0, [3]], - "v__x__y": grad_grad_uvw_center[:, 1, [3]], - "v__x__z": grad_grad_uvw_center[:, 2, [3]], - "v__y__x": grad_grad_uvw_center[:, 1, [3]], # same as __x__y - "v__y__y": grad_grad_uvw_center[:, 1, [4]], - "v__y__z": grad_grad_uvw_center[:, 2, [4]], - "v__z__x": grad_grad_uvw_center[:, 2, [3]], # same as __x__z - "v__z__y": grad_grad_uvw_center[:, 2, [4]], # same as __y__z - "v__z__z": grad_grad_uvw_center[:, 2, [5]], - "w__x__x": grad_grad_uvw_center[:, 0, [6]], - "w__x__y": grad_grad_uvw_center[:, 1, [6]], - "w__x__z": grad_grad_uvw_center[:, 2, [6]], - "w__y__x": grad_grad_uvw_center[:, 1, [6]], # same as __x__y - "w__y__y": grad_grad_uvw_center[:, 1, [7]], - "w__y__z": grad_grad_uvw_center[:, 2, [7]], - "w__z__x": grad_grad_uvw_center[:, 2, [6]], # same as __x__z - "w__z__y": grad_grad_uvw_center[:, 2, [7]], # same as __y__z - "w__z__z": grad_grad_uvw_center[:, 2, [8]], - } - continuity = eqn["continuity"].evaluate(dict_mapping)["continuity"] - momentum_x = eqn["momentum_x"].evaluate(dict_mapping)["momentum_x"] - momentum_y = eqn["momentum_y"].evaluate(dict_mapping)["momentum_y"] - momentum_z = eqn["momentum_z"].evaluate(dict_mapping)["momentum_z"] - - # Compute the weights for the equation residuals - weight_continuity = torch.sigmoid(0.5 * (torch.abs(continuity) - 10)) - weight_momentum_x = torch.sigmoid(0.5 * (torch.abs(momentum_x) - 10)) - weight_momentum_y = torch.sigmoid(0.5 * (torch.abs(momentum_y) - 10)) - weight_momentum_z = torch.sigmoid(0.5 * (torch.abs(momentum_z) - 10)) - - weighted_continuity = weight_continuity * torch.abs(continuity) - weighted_momentum_x = weight_momentum_x * torch.abs(momentum_x) - weighted_momentum_y = weight_momentum_y * torch.abs(momentum_y) - weighted_momentum_z = weight_momentum_z * torch.abs(momentum_z) - - # Compute data loss - num = torch.sum(mask * (output - target) ** 2.0, dims) - if loss_type == "rmse": - denom = torch.sum(mask * target**2.0, dims) - else: - denom = torch.sum(mask) - - del coords_total, output_total - torch.cuda.empty_cache() - - return ( - torch.mean(num / denom), - torch.mean(torch.abs(weighted_continuity)), - torch.mean(torch.abs(weighted_momentum_x)), - torch.mean(torch.abs(weighted_momentum_y)), - torch.mean(torch.abs(weighted_momentum_z)), - ) - - -def loss_fn( - output: torch.Tensor, - target: torch.Tensor, - loss_type: Literal["mse", "rmse"], - padded_value: float = -10, -) -> torch.Tensor: - """Calculate mean squared error or root mean squared error with masking for padded values. - - Args: - output: Predicted values from the model - target: Ground truth values - loss_type: Type of loss to calculate ("mse" or "rmse") - padded_value: Value used for padding in the tensor - - Returns: - Calculated loss as a scalar tensor - """ - mask = abs(target - padded_value) > 1e-3 - - if loss_type == "rmse": - dims = (0, 1) - else: - dims = None - - num = torch.sum(mask * (output - target) ** 2.0, dims) - if loss_type == "rmse": - denom = torch.sum(mask * target**2.0, dims) - loss = torch.mean(torch.sqrt(num / denom)) - elif loss_type == "mse": - denom = torch.sum(mask) - loss = torch.mean(num / denom) - else: - raise ValueError(f"Invalid loss type: {loss_type}") - return loss - - -def loss_fn_with_physics( - output: torch.Tensor, - target: torch.Tensor, - loss_type: Literal["mse", "rmse"], - padded_value: float = -10, - first_deriv: torch.nn.Module = None, - eqn: Any = None, - bounding_box: torch.Tensor = None, - vol_factors: torch.Tensor = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Calculate loss with physics-based terms for appropriate equations. - - Args: - output: Predicted values from the model (with neighbor data when physics enabled) - target: Ground truth values - loss_type: Type of loss to calculate ("mse" or "rmse") - padded_value: Value used for padding in the tensor - first_deriv: First derivative calculator - eqn: Equations - bounding_box: Bounding box for normalization - vol_factors: Volume factors for normalization - - Returns: - Tuple of (data_loss, continuity_loss, momentum_x_loss, momentum_y_loss, momentum_z_loss) - """ - mask = abs(target - padded_value) > 1e-3 - - if loss_type == "rmse": - dims = (0, 1) - else: - dims = None - - # Call the physics loss computation function - return compute_physics_loss( - output=output, - target=target, - mask=mask, - loss_type=loss_type, - dims=dims, - first_deriv=first_deriv, - eqn=eqn, - bounding_box=bounding_box, - vol_factors=vol_factors, - ) - - -def loss_fn_surface( - output: torch.Tensor, target: torch.Tensor, loss_type: Literal["mse", "rmse"] -) -> torch.Tensor: - """Calculate loss for surface data by handling scalar and vector components separately. - - Args: - output: Predicted surface values from the model - target: Ground truth surface values - loss_type: Type of loss to calculate ("mse" or "rmse") - - Returns: - Combined scalar and vector loss as a scalar tensor - """ - # Separate the scalar and vector components: - output_scalar, output_vector = torch.split(output, [1, 3], dim=2) - target_scalar, target_vector = torch.split(target, [1, 3], dim=2) - - numerator = torch.mean((output_scalar - target_scalar) ** 2.0) - vector_diff_sq = torch.mean((target_vector - output_vector) ** 2.0, (0, 1)) - if loss_type == "mse": - masked_loss_pres = numerator - masked_loss_ws = torch.sum(vector_diff_sq) - else: - denom = torch.mean((target_scalar) ** 2.0) - masked_loss_pres = numerator / denom - - # Compute the mean diff**2 of the vector component, leave the last dimension: - masked_loss_ws_num = vector_diff_sq - masked_loss_ws_denom = torch.mean((target_vector) ** 2.0, (0, 1)) - masked_loss_ws = torch.sum(masked_loss_ws_num / masked_loss_ws_denom) - - loss = masked_loss_pres + masked_loss_ws - - return loss / 4.0 - - -def loss_fn_area( - output: torch.Tensor, - target: torch.Tensor, - normals: torch.Tensor, - area: torch.Tensor, - area_scaling_factor: float, - loss_type: Literal["mse", "rmse"], -) -> torch.Tensor: - """Calculate area-weighted loss for surface data considering normal vectors. - - Args: - output: Predicted surface values from the model - target: Ground truth surface values - normals: Normal vectors for the surface - area: Area values for surface elements - area_scaling_factor: Scaling factor for area weighting - loss_type: Type of loss to calculate ("mse" or "rmse") - - Returns: - Area-weighted loss as a scalar tensor - """ - area = area * area_scaling_factor - area_scale_factor = area - - # Separate the scalar and vector components. - target_scalar, target_vector = torch.split( - target * area_scale_factor, [1, 3], dim=2 - ) - output_scalar, output_vector = torch.split( - output * area_scale_factor, [1, 3], dim=2 - ) - - # Apply the normals to the scalar components (only [:,:,0]): - normals, _ = torch.split(normals, [1, normals.shape[-1] - 1], dim=2) - target_scalar = target_scalar * normals - output_scalar = output_scalar * normals - - # Compute the mean diff**2 of the scalar component: - masked_loss_pres = torch.mean(((output_scalar - target_scalar) ** 2.0), dim=(0, 1)) - if loss_type == "rmse": - masked_loss_pres /= torch.mean(target_scalar**2.0, dim=(0, 1)) - - # Compute the mean diff**2 of the vector component, leave the last dimension: - masked_loss_ws = torch.mean((target_vector - output_vector) ** 2.0, (0, 1)) - - if loss_type == "rmse": - masked_loss_ws /= torch.mean((target_vector) ** 2.0, (0, 1)) - - # Combine the scalar and vector components: - loss = 0.25 * (masked_loss_pres + torch.sum(masked_loss_ws)) - - return loss - - -def integral_loss_fn( - output, target, area, normals, stream_velocity=None, padded_value=-10 -): - drag_loss = drag_loss_fn( - output, target, area, normals, stream_velocity=stream_velocity, padded_value=-10 - ) - lift_loss = lift_loss_fn( - output, target, area, normals, stream_velocity=stream_velocity, padded_value=-10 - ) - return lift_loss + drag_loss - - -def lift_loss_fn(output, target, area, normals, stream_velocity=None, padded_value=-10): - vel_inlet = stream_velocity # Get this from the dataset - mask = abs(target - padded_value) > 1e-3 - - output_true = target * mask * area * (vel_inlet) ** 2.0 - output_pred = output * mask * area * (vel_inlet) ** 2.0 - - normals = torch.select(normals, 2, 2) - # output_true_0 = output_true[:, :, 0] - output_true_0 = output_true.select(2, 0) - output_pred_0 = output_pred.select(2, 0) - - pres_true = output_true_0 * normals - pres_pred = output_pred_0 * normals - - wz_true = output_true[:, :, -1] - wz_pred = output_pred[:, :, -1] - - masked_pred = torch.mean(pres_pred + wz_pred, (1)) - masked_truth = torch.mean(pres_true + wz_true, (1)) - - loss = (masked_pred - masked_truth) ** 2.0 - loss = torch.mean(loss) - return loss - - -def drag_loss_fn(output, target, area, normals, stream_velocity=None, padded_value=-10): - vel_inlet = stream_velocity # Get this from the dataset - mask = abs(target - padded_value) > 1e-3 - output_true = target * mask * area * (vel_inlet) ** 2.0 - output_pred = output * mask * area * (vel_inlet) ** 2.0 - - pres_true = output_true[:, :, 0] * normals[:, :, 0] - pres_pred = output_pred[:, :, 0] * normals[:, :, 0] - - wx_true = output_true[:, :, 1] - wx_pred = output_pred[:, :, 1] - - masked_pred = torch.mean(pres_pred + wx_pred, (1)) - masked_truth = torch.mean(pres_true + wx_true, (1)) - - loss = (masked_pred - masked_truth) ** 2.0 - loss = torch.mean(loss) - return loss - - -def compute_loss_dict( - prediction_vol: torch.Tensor, - prediction_surf: torch.Tensor, - batch_inputs: dict, - loss_fn_type: dict, - integral_scaling_factor: float, - surf_loss_scaling: float, - vol_loss_scaling: float, - first_deriv: torch.nn.Module | None = None, - eqn: Any = None, - bounding_box: torch.Tensor | None = None, - vol_factors: torch.Tensor | None = None, - add_physics_loss: bool = False, -) -> tuple[torch.Tensor, dict]: - """ - Compute the loss terms in a single function call. - - Computes: - - Volume loss if prediction_vol is not None - - Surface loss if prediction_surf is not None - - Integral loss if prediction_surf is not None - - Total loss as a weighted sum of the above - - Returns: - - Total loss as a scalar tensor - - Dictionary of loss terms (for logging, etc) - """ - nvtx.range_push("Loss Calculation") - total_loss_terms = [] - loss_dict = {} - - if prediction_vol is not None: - target_vol = batch_inputs["volume_fields"] - - if add_physics_loss: - loss_vol = loss_fn_with_physics( - prediction_vol, - target_vol, - loss_fn_type.loss_type, - padded_value=-10, - first_deriv=first_deriv, - eqn=eqn, - bounding_box=bounding_box, - vol_factors=vol_factors, - ) - loss_dict["loss_vol"] = loss_vol[0] - loss_dict["loss_continuity"] = loss_vol[1] - loss_dict["loss_momentum_x"] = loss_vol[2] - loss_dict["loss_momentum_y"] = loss_vol[3] - loss_dict["loss_momentum_z"] = loss_vol[4] - total_loss_terms.append(loss_vol[0]) - total_loss_terms.append(loss_vol[1]) - total_loss_terms.append(loss_vol[2]) - total_loss_terms.append(loss_vol[3]) - total_loss_terms.append(loss_vol[4]) - else: - loss_vol = loss_fn( - prediction_vol, - target_vol, - loss_fn_type.loss_type, - padded_value=-10, - ) - loss_dict["loss_vol"] = loss_vol - total_loss_terms.append(loss_vol) - - if prediction_surf is not None: - target_surf = batch_inputs["surface_fields"] - surface_areas = batch_inputs["surface_areas"] - surface_areas = torch.unsqueeze(surface_areas, -1) - surface_normals = batch_inputs["surface_normals"] - - # Needs to be taken from the dataset - stream_velocity = batch_inputs["global_params_values"][:, 0, :] - - loss_surf = loss_fn_surface( - prediction_surf, - target_surf, - loss_fn_type.loss_type, - ) - - loss_surf_area = loss_fn_area( - prediction_surf, - target_surf, - surface_normals, - surface_areas, - area_scaling_factor=loss_fn_type.area_weighing_factor, - loss_type=loss_fn_type.loss_type, - ) - - if loss_fn_type.loss_type == "mse": - loss_surf = loss_surf * surf_loss_scaling - loss_surf_area = loss_surf_area * surf_loss_scaling - - total_loss_terms.append(loss_surf) - loss_dict["loss_surf"] = loss_surf - total_loss_terms.append(loss_surf_area) - loss_dict["loss_surf_area"] = loss_surf_area - loss_integral = ( - integral_loss_fn( - prediction_surf, - target_surf, - surface_areas, - surface_normals, - stream_velocity, - padded_value=-10, - ) - ) * integral_scaling_factor - loss_dict["loss_integral"] = loss_integral - total_loss_terms.append(loss_integral) - - total_loss = sum(total_loss_terms) - loss_dict["total_loss"] = total_loss - nvtx.range_pop() - - return total_loss, loss_dict +from loss import compute_loss_dict +from utils import get_num_vars, load_scaling_factors, compute_l2, all_reduce_dict def validation_step( @@ -604,6 +87,8 @@ def validation_step( model, device, logger, + tb_writer, + epoch_index, use_sdf_basis=False, use_surface_normals=False, integral_scaling_factor=1.0, @@ -615,13 +100,17 @@ def validation_step( bounding_box: torch.Tensor | None = None, vol_factors: torch.Tensor | None = None, add_physics_loss=False, + autocast_enabled=None, ): + dm = DistributedManager() running_vloss = 0.0 with torch.no_grad(): + metrics = None + for i_batch, sample_batched in enumerate(dataloader): sampled_batched = dict_to_device(sample_batched, device) - with autocast(enabled=True): + with autocast("cuda", enabled=autocast_enabled, cache_enabled=False): if add_physics_loss: prediction_vol, prediction_surf = model( sampled_batched, return_volume_neighbors=True @@ -645,8 +134,37 @@ def validation_step( ) running_vloss += loss.item() + local_metrics = compute_l2( + prediction_surf, prediction_vol, sampled_batched, dataloader + ) + if metrics is None: + metrics = local_metrics + else: + metrics = { + key: metrics[key] + local_metrics[key] for key in metrics.keys() + } avg_vloss = running_vloss / (i_batch + 1) + metrics = {key: metrics[key] / (i_batch + 1) for key in metrics.keys()} + + metrics = all_reduce_dict(metrics, dm) + + if dm.rank == 0: + logger.info( + f" Device {device}, batch: {i_batch + 1}, VAL loss norm: {loss.detach().item():.5f}" + ) + tb_x = epoch_index + for key in metrics.keys(): + tb_writer.add_scalar(f"L2 Metrics/val/{key}", metrics[key], tb_x) + + metrics_table = tabulate( + [[k, v] for k, v in metrics.items()], + headers=["Metric", "Average Value"], + tablefmt="pretty", + ) + logger.info( + f"\nEpoch {epoch_index} VALIDATION Average Metrics:\n{metrics_table}\n" + ) return avg_vloss @@ -670,9 +188,13 @@ def train_epoch( eqn: Any = None, bounding_box: torch.Tensor | None = None, vol_factors: torch.Tensor | None = None, + surf_factors: torch.Tensor | None = None, add_physics_loss=False, + autocast_enabled=None, + grad_clip_enabled=None, + grad_max_norm=None, ): - dist = DistributedManager() + dm = DistributedManager() running_loss = 0.0 last_loss = 0.0 @@ -680,106 +202,177 @@ def train_epoch( gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) start_time = time.perf_counter() - for i_batch, sample_batched in enumerate(dataloader): - sampled_batched = dict_to_device(sample_batched, device) - - if add_physics_loss: - autocast_enabled = False - else: - autocast_enabled = True - with autocast(enabled=autocast_enabled): - with nvtx.range("Model Forward Pass"): - if add_physics_loss: - prediction_vol, prediction_surf = model( - sampled_batched, return_volume_neighbors=True - ) - else: - prediction_vol, prediction_surf = model(sampled_batched) + with Profiler(): + io_start_time = time.perf_counter() + metrics = None + for i_batch, sampled_batched in enumerate(dataloader): + io_end_time = time.perf_counter() + if add_physics_loss: + autocast_enabled = False + + with autocast("cuda", enabled=autocast_enabled, cache_enabled=False): + with nvtx.range("Model Forward Pass"): + if add_physics_loss: + prediction_vol, prediction_surf = model( + sampled_batched, return_volume_neighbors=True + ) + else: + prediction_vol, prediction_surf = model(sampled_batched) - loss, loss_dict = compute_loss_dict( - prediction_vol, - prediction_surf, - sampled_batched, - loss_fn_type, - integral_scaling_factor, - surf_loss_scaling, - vol_loss_scaling, - first_deriv, - eqn, - bounding_box, - vol_factors, - add_physics_loss, - ) - - loss = loss / loss_interval - scaler.scale(loss).backward() + loss, loss_dict = compute_loss_dict( + prediction_vol, + prediction_surf, + sampled_batched, + loss_fn_type, + integral_scaling_factor, + surf_loss_scaling, + vol_loss_scaling, + first_deriv, + eqn, + bounding_box, + vol_factors, + add_physics_loss, + ) - if ((i_batch + 1) % loss_interval == 0) or (i_batch + 1 == len(dataloader)): - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() + # Compute metrics: + if isinstance(prediction_vol, tuple): + # This is if return_neighbors is on for volume: + prediction_vol = prediction_vol[0] - # Gather data and report - running_loss += loss.item() - elapsed_time = time.perf_counter() - start_time - start_time = time.perf_counter() - gpu_end_info = nvmlDeviceGetMemoryInfo(gpu_handle) - gpu_memory_used = gpu_end_info.used / (1024**3) - gpu_memory_delta = (gpu_end_info.used - gpu_start_info.used) / (1024**3) - - logging_string = f"Device {device}, batch processed: {i_batch + 1}\n" - # Format the loss dict into a string: - loss_string = ( - " " - + "\t".join([f"{key.replace('loss_', ''):<10}" for key in loss_dict.keys()]) - + "\n" - ) - loss_string += ( - " " + f"\t".join([f"{l.item():<10.3e}" for l in loss_dict.values()]) + "\n" - ) + local_metrics = compute_l2( + prediction_surf, prediction_vol, sampled_batched, dataloader + ) + if metrics is None: + metrics = local_metrics + else: + # Sum the running total: + metrics = { + key: metrics[key] + local_metrics[key] for key in metrics.keys() + } + + loss = loss / loss_interval + scaler.scale(loss).backward() + + if ((i_batch + 1) % loss_interval == 0) or (i_batch + 1 == len(dataloader)): + if grad_clip_enabled: + # Unscales the gradients of optimizer's assigned params in-place. + scaler.unscale_(optimizer) + + # Since the gradients of optimizer's assigned params are unscaled, clips as usual. + torch.nn.utils.clip_grad_norm_(model.parameters(), grad_max_norm) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + # Gather data and report + running_loss += loss.detach().item() + elapsed_time = time.perf_counter() - start_time + io_time = io_end_time - io_start_time + start_time = time.perf_counter() + gpu_end_info = nvmlDeviceGetMemoryInfo(gpu_handle) + gpu_memory_used = gpu_end_info.used / (1024**3) + gpu_memory_delta = (gpu_end_info.used - gpu_start_info.used) / (1024**3) + + logging_string = f"Device {device}, batch processed: {i_batch + 1}\n" + # Format the loss dict into a string: + loss_string = ( + " " + + "\t".join( + [f"{key.replace('loss_', ''):<10}" for key in loss_dict.keys()] + ) + + "\n" + ) + loss_string += ( + " " + + f"\t".join( + [f"{l.detach().item():<10.3e}" for l in loss_dict.values()] + ) + + "\n" + ) - logging_string += loss_string - logging_string += f" GPU memory used: {gpu_memory_used:.3f} Gb\n" - logging_string += f" GPU memory delta: {gpu_memory_delta:.3f} Gb\n" - logging_string += f" Time taken: {elapsed_time:.2f} seconds\n" - logger.info(logging_string) - gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) + logging_string += loss_string + logging_string += f" GPU memory used: {gpu_memory_used:.3f} Gb (delta: {gpu_memory_delta:.3f})\n" + logging_string += f" Timings: (IO: {io_time:.2f}, Model: {elapsed_time - io_time:.2f}, Total: {elapsed_time:.2f})s\n" + logger.info(logging_string) + gpu_start_info = nvmlDeviceGetMemoryInfo(gpu_handle) + io_start_time = time.perf_counter() last_loss = running_loss / (i_batch + 1) # loss per batch - if dist.rank == 0: + # Normalize metrics: + metrics = {key: metrics[key] / (i_batch + 1) for key in metrics.keys()} + # reduce metrics across batch: + metrics = all_reduce_dict(metrics, dm) + if dm.rank == 0: logger.info( - f" Device {device}, batch: {i_batch + 1}, loss norm: {loss.item():.5f}" + f" Device {device}, batch: {i_batch + 1}, loss norm: {loss.detach().item():.5f}" ) tb_x = epoch_index * len(dataloader) + i_batch + 1 tb_writer.add_scalar("Loss/train", last_loss, tb_x) + for key in metrics.keys(): + tb_writer.add_scalar(f"L2 Metrics/train/{key}", metrics[key], epoch_index) + + metrics_table = tabulate( + [[k, v] for k, v in metrics.items()], + headers=["Metric", "Average Value"], + tablefmt="pretty", + ) + logger.info(f"\nEpoch {epoch_index} Average Metrics:\n{metrics_table}\n") return last_loss @hydra.main(version_base="1.3", config_path="conf", config_name="config") def main(cfg: DictConfig) -> None: + ###################################################### # initialize distributed manager + ###################################################### DistributedManager.initialize() dist = DistributedManager() + # DoMINO supports domain parallel training. This function helps coordinate + # how to set that up, if needed. + domain_mesh, data_mesh, placements = coordinate_distributed_environment(cfg) + + ################################ # Initialize NVML + ################################ nvmlInit() - gpu_handle = nvmlDeviceGetHandleByIndex(dist.device.index) - compute_scaling_factors( - cfg=cfg, - input_path=cfg.data.input_dir, - use_cache=cfg.data_processor.use_cache, - ) - model_type = cfg.model.model_type + ###################################################### + # Initialize logger + ###################################################### logger = PythonLogger("Train") logger = RankZeroLoggingWrapper(logger, dist) logger.info(f"Config summary:\n{OmegaConf.to_yaml(cfg, sort_keys=True)}") - # Get physics imports conditionally + ###################################################### + # Get scaling factors - precompute them if this fails! + ###################################################### + vol_factors, surf_factors = load_scaling_factors(cfg) + + ###################################################### + # Configure the model + ###################################################### + model_type = cfg.model.model_type + num_vol_vars, num_surf_vars, num_global_features = get_num_vars(cfg, model_type) + + if model_type == "combined" or model_type == "surface": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + else: + surface_variable_names = [] + + if model_type == "combined" or model_type == "volume": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + else: + volume_variable_names = [] + + ###################################################### + # Configure physics loss + # Unless enabled, these are null-ops + ###################################################### add_physics_loss = getattr(cfg.train, "add_physics_loss", False) if add_physics_loss: @@ -789,56 +382,15 @@ def main(cfg: DictConfig) -> None: else: PDE = FirstDeriv = IncompressibleNavierStokes = None - num_vol_vars = 0 - volume_variable_names = [] - if model_type == "volume" or model_type == "combined": - volume_variable_names = list(cfg.variables.volume.solution.keys()) - for j in volume_variable_names: - if cfg.variables.volume.solution[j] == "vector": - num_vol_vars += 3 - else: - num_vol_vars += 1 - else: - num_vol_vars = None - - num_surf_vars = 0 - surface_variable_names = [] - if model_type == "surface" or model_type == "combined": - surface_variable_names = list(cfg.variables.surface.solution.keys()) - num_surf_vars = 0 - for j in surface_variable_names: - if cfg.variables.surface.solution[j] == "vector": - num_surf_vars += 3 - else: - num_surf_vars += 1 - else: - num_surf_vars = None - - num_global_features = 0 - global_params_names = list(cfg.variables.global_parameters.keys()) - for param in global_params_names: - if cfg.variables.global_parameters[param].type == "vector": - num_global_features += len(cfg.variables.global_parameters[param].reference) - elif cfg.variables.global_parameters[param].type == "scalar": - num_global_features += 1 - else: - raise ValueError(f"Unknown global parameter type") - - vol_save_path = os.path.join( - "outputs", cfg.project.name, "volume_scaling_factors.npy" - ) - surf_save_path = os.path.join( - "outputs", cfg.project.name, "surface_scaling_factors.npy" - ) - if os.path.exists(vol_save_path): - vol_factors = np.load(vol_save_path) - vol_factors_tensor = ( - torch.from_numpy(vol_factors).to(dist.device) if add_physics_loss else None - ) - else: - vol_factors = None - vol_factors_tensor = None + # Initialize physics components conditionally + first_deriv = None + eqn = None + if add_physics_loss: + first_deriv = FirstDeriv(dim=3, direct_input=True) + eqn = IncompressibleNavierStokes(rho=1.226, nu="nu", dim=3, time=False) + eqn = eqn.make_nodes(return_as_dict=True) + # The bounding box is used in calculating the physics loss: bounding_box = None if add_physics_loss: bounding_box = cfg.data.bounding_box @@ -846,57 +398,74 @@ def main(cfg: DictConfig) -> None: torch.from_numpy( np.stack([bounding_box["max"], bounding_box["min"]], axis=0) ) - .to(vol_factors_tensor.dtype) + .to(vol_factors.dtype) .to(dist.device) ) - if os.path.exists(surf_save_path): - surf_factors = np.load(surf_save_path) - else: - surf_factors = None + ###################################################### + # Configure the dataset + ###################################################### - train_dataset = create_domino_dataset( + # This helper function is to determine which keys to read from the data + # (and which to use default values for, if they aren't present - like + # air_density, for example) + keys_to_read, keys_to_read_if_available = get_keys_to_read( + cfg, model_type, get_ground_truth=True + ) + + # The dataset actually works in two pieces + # The core dataset just reads data from disk, and puts it on the GPU if needed. + # The data processesing pipeline will preprocess that data and prepare it for the model. + # Obviously, you need both, so this function will return the datapipeline in + # a way that can be iterated over. + # + # To properly shuffle the data, we use a distributed sampler too. + # It's configured properly for optional domain parallelism, and you have + # to make sure to call set_epoch below. + + train_dataloader = create_domino_dataset( cfg, phase="train", - volume_variable_names=volume_variable_names, - surface_variable_names=surface_variable_names, + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, vol_factors=vol_factors, surf_factors=surf_factors, + device_mesh=domain_mesh, + placements=placements, + normalize_coordinates=cfg.data.normalize_coordinates, + sample_in_bbox=cfg.data.sample_in_bbox, + sampling=cfg.data.sampling, + ) + train_sampler = DistributedSampler( + train_dataloader, + num_replicas=data_mesh.size(), + rank=data_mesh.get_local_rank(), + **cfg.train.sampler, ) - val_dataset = create_domino_dataset( + + val_dataloader = create_domino_dataset( cfg, phase="val", - volume_variable_names=volume_variable_names, - surface_variable_names=surface_variable_names, + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, vol_factors=vol_factors, surf_factors=surf_factors, + device_mesh=domain_mesh, + placements=placements, + normalize_coordinates=cfg.data.normalize_coordinates, + sample_in_bbox=cfg.data.sample_in_bbox, + sampling=cfg.data.sampling, ) - - train_sampler = DistributedSampler( - train_dataset, - num_replicas=dist.world_size, - rank=dist.rank, - **cfg.train.sampler, - ) - val_sampler = DistributedSampler( - val_dataset, - num_replicas=dist.world_size, - rank=dist.rank, + val_dataloader, + num_replicas=data_mesh.size(), + rank=data_mesh.get_local_rank(), **cfg.val.sampler, ) - train_dataloader = DataLoader( - train_dataset, - sampler=train_sampler, - **cfg.train.dataloader, - ) - val_dataloader = DataLoader( - val_dataset, - sampler=val_sampler, - **cfg.val.dataloader, - ) - + ###################################################### + # Configure the model + ###################################################### model = DoMINO( input_features=3, output_features_vol=num_vol_vars, @@ -904,7 +473,6 @@ def main(cfg: DictConfig) -> None: global_features=num_global_features, model_parameters=cfg.model, ).to(dist.device) - model = torch.compile(model, disable=True) # TODO make this configurable # Print model summary (structure and parmeter count). logger.info(f"Model summary:\n{torchinfo.summary(model, verbose=0, depth=2)}\n") @@ -920,23 +488,45 @@ def main(cfg: DictConfig) -> None: static_graph=True, ) - # optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=0.001) - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, milestones=[50, 100, 200, 250, 300, 350, 400, 450], gamma=0.5 - ) + ###################################################### + # Initialize optimzer and gradient scaler + ###################################################### - # Initialize physics components conditionally - first_deriv = None - eqn = None - if add_physics_loss: - first_deriv = FirstDeriv(dim=3, direct_input=True) - eqn = IncompressibleNavierStokes(rho=1.226, nu="nu", dim=3, time=False) - eqn = eqn.make_nodes(return_as_dict=True) + optimizer_class = None + if cfg.train.optimizer.name == "Adam": + optimizer_class = torch.optim.Adam + elif cfg.train.optimizer.name == "AdamW": + optimizer_class = torch.optim.AdamW + else: + raise ValueError(f"Unsupported optimizer: {cfg.train.optimizer.name}") + optimizer = optimizer_class( + model.parameters(), + lr=cfg.train.optimizer.lr, + weight_decay=cfg.train.optimizer.weight_decay, + ) + if cfg.train.lr_scheduler.name == "MultiStepLR": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=cfg.train.lr_scheduler.milestones, + gamma=cfg.train.lr_scheduler.gamma, + ) + elif cfg.train.lr_scheduler.name == "CosineAnnealingLR": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=cfg.train.lr_scheduler.T_max, + eta_min=cfg.train.lr_scheduler.eta_min, + ) + else: + raise ValueError(f"Unsupported scheduler: {cfg.train.lr_scheduler.name}") # Initialize the scaler for mixed precision scaler = GradScaler() + ###################################################### + # Initialize output tools + ###################################################### + + # Tensorboard Writer to track training. writer = SummaryWriter(os.path.join(cfg.output, "tensorboard")) epoch_number = 0 @@ -952,6 +542,9 @@ def main(cfg: DictConfig) -> None: if dist.world_size > 1: torch.distributed.barrier() + ###################################################### + # Load checkpoint if available + ###################################################### init_epoch = load_checkpoint( to_absolute_path(cfg.resume_dir), models=model, @@ -977,6 +570,10 @@ def main(cfg: DictConfig) -> None: initial_integral_factor_orig = cfg.model.integral_loss_scaling_factor + ###################################################### + # Begin Training loop over epochs + ###################################################### + for epoch in range(init_epoch, cfg.train.epochs): start_time = time.perf_counter() logger.info(f"Device {dist.device}, epoch {epoch_number}:") @@ -986,8 +583,11 @@ def main(cfg: DictConfig) -> None: "Physics loss enabled - mixed precision (autocast) will be disabled as physics loss computation is not supported with mixed precision" ) + # This controls what indices to use for each epoch. train_sampler.set_epoch(epoch) val_sampler.set_epoch(epoch) + train_dataloader.dataset.set_indices(list(train_sampler)) + val_dataloader.dataset.set_indices(list(val_sampler)) initial_integral_factor = initial_integral_factor_orig @@ -1015,8 +615,11 @@ def main(cfg: DictConfig) -> None: first_deriv=first_deriv, eqn=eqn, bounding_box=bounding_box, - vol_factors=vol_factors_tensor, + vol_factors=vol_factors, add_physics_loss=add_physics_loss, + autocast_enabled=cfg.train.amp.enabled, + grad_clip_enabled=cfg.train.amp.clip_grad, + grad_max_norm=cfg.train.amp.grad_max_norm, ) epoch_end_time = time.perf_counter() logger.info( @@ -1030,6 +633,8 @@ def main(cfg: DictConfig) -> None: model=model, device=dist.device, logger=logger, + tb_writer=writer, + epoch_index=epoch, use_sdf_basis=cfg.model.use_sdf_in_basis_func, use_surface_normals=cfg.model.use_surface_normals, integral_scaling_factor=initial_integral_factor, @@ -1039,8 +644,9 @@ def main(cfg: DictConfig) -> None: first_deriv=first_deriv, eqn=eqn, bounding_box=bounding_box, - vol_factors=vol_factors_tensor, + vol_factors=vol_factors, add_physics_loss=add_physics_loss, + autocast_enabled=cfg.train.amp.enabled, ) scheduler.step() diff --git a/examples/cfd/external_aerodynamics/domino/src/utils.py b/examples/cfd/external_aerodynamics/domino/src/utils.py new file mode 100644 index 0000000000..9c144fa0c3 --- /dev/null +++ b/examples/cfd/external_aerodynamics/domino/src/utils.py @@ -0,0 +1,467 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dataclasses import dataclass +from typing import Dict, Optional, Any +import numpy as np +import torch +import torch.distributed as dist +import pickle +from pathlib import Path +from typing import Literal +from omegaconf import DictConfig +from physicsnemo.distributed import DistributedManager + +from torch.distributed.tensor.placement_types import ( + Shard, + Replicate, +) + + +def get_num_vars(cfg: dict, model_type: Literal["volume", "surface", "combined"]): + """Calculate the number of variables for volume, surface, and global features. + + This function analyzes the configuration to determine how many variables are needed + for different mesh data types based on the model type. Vector variables contribute + 3 components (x, y, z) while scalar variables contribute 1 component each. + + Args: + cfg: Configuration object containing variable definitions for volume, surface, + and global parameters with their types (scalar/vector). + model_type (str): Type of model - can be "volume", "surface", or "combined". + Determines which variable types are included in the count. + + Returns: + tuple: A 3-tuple containing: + - num_vol_vars (int or None): Number of volume variables. None if model_type + is not "volume" or "combined". + - num_surf_vars (int or None): Number of surface variables. None if model_type + is not "surface" or "combined". + - num_global_features (int): Number of global parameter features. + """ + num_vol_vars = 0 + volume_variable_names = [] + if model_type == "volume" or model_type == "combined": + volume_variable_names = list(cfg.variables.volume.solution.keys()) + for j in volume_variable_names: + if cfg.variables.volume.solution[j] == "vector": + num_vol_vars += 3 + else: + num_vol_vars += 1 + else: + num_vol_vars = None + + num_surf_vars = 0 + surface_variable_names = [] + if model_type == "surface" or model_type == "combined": + surface_variable_names = list(cfg.variables.surface.solution.keys()) + num_surf_vars = 0 + for j in surface_variable_names: + if cfg.variables.surface.solution[j] == "vector": + num_surf_vars += 3 + else: + num_surf_vars += 1 + else: + num_surf_vars = None + + num_global_features = 0 + global_params_names = list(cfg.variables.global_parameters.keys()) + for param in global_params_names: + if cfg.variables.global_parameters[param].type == "vector": + num_global_features += len(cfg.variables.global_parameters[param].reference) + elif cfg.variables.global_parameters[param].type == "scalar": + num_global_features += 1 + else: + raise ValueError(f"Unknown global parameter type") + + return num_vol_vars, num_surf_vars, num_global_features + + +def get_keys_to_read( + cfg: dict, + model_type: Literal["volume", "surface", "combined"], + get_ground_truth: bool = True, +): + """ + This function helps configure the keys to read from the dataset. + + And, if some global parameter values are provided in the config, + they are also read here and passed to the dataset. + + """ + + # Always read these keys: + keys_to_read = ["stl_coordinates", "stl_centers", "stl_faces", "stl_areas"] + + # If these keys are in the config, use them, else provide defaults in + # case they aren't in the dataset: + cfg_params_vec = [] + for key in cfg.variables.global_parameters: + if cfg.variables.global_parameters[key].type == "vector": + cfg_params_vec.extend(cfg.variables.global_parameters[key].reference) + else: + cfg_params_vec.append(cfg.variables.global_parameters[key].reference) + keys_to_read_if_available = { + "global_params_values": torch.tensor(cfg_params_vec).reshape(-1, 1), + "global_params_reference": torch.tensor(cfg_params_vec).reshape(-1, 1), + } + + # Volume keys: + volume_keys = [ + "volume_mesh_centers", + ] + if get_ground_truth: + volume_keys.append("volume_fields") + + # Surface keys: + surface_keys = [ + "surface_mesh_centers", + "surface_normals", + "surface_areas", + ] + if get_ground_truth: + surface_keys.append("surface_fields") + + if model_type == "volume" or model_type == "combined": + keys_to_read.extend(volume_keys) + if model_type == "surface" or model_type == "combined": + keys_to_read.extend(surface_keys) + + return keys_to_read, keys_to_read_if_available + + +def coordinate_distributed_environment(cfg: DictConfig): + """ + Initialize the distributed env for DoMINO. This is actually always a 2D Mesh: + one dimension is the data-parallel dimension (DDP), and the other is the + domain dimension. + + For the training scripts, we need to know the rank, size of each dimension, + and return the domain_mesh and placements for the loader. + + Args: + cfg: Configuration object containing the domain parallelism configuration. + + Returns: + domain_mesh: torch.distributed.DeviceMesh: The domain mesh for the domain parallel dimension. + data_mesh: torch.distributed.DeviceMesh: The data mesh for the data parallel dimension. + placements: dict[str, torch.distributed.tensor.Placement]: The placements for the data set + """ + + if not DistributedManager.is_initialized(): + DistributedManager.initialize() + dist = DistributedManager() + + # Default to no domain parallelism: + domain_size = cfg.get("domain_parallelism", {}).get("domain_size", 1) + + # Initialize the device mesh: + mesh = dist.initialize_mesh( + mesh_shape=(-1, domain_size), mesh_dim_names=("ddp", "domain") + ) + domain_mesh = mesh["domain"] + data_mesh = mesh["ddp"] + + if domain_size > 1: + # Define the default placements for each tensor that might show up in + # the data. Note that we'll define placements for all keys, even if + # they aren't actually used. + + # Note that placements are defined for pre-batched data, no batch index! + + grid_like_placement = [ + Shard(0), + ] + point_like_placement = [ + Shard(0), + ] + replicate_placement = [ + Replicate(), + ] + placements = { + "stl_coordinates": point_like_placement, + "stl_centers": point_like_placement, + "stl_faces": point_like_placement, + "stl_areas": point_like_placement, + "surface_fields": point_like_placement, + "volume_mesh_centers": point_like_placement, + "volume_fields": point_like_placement, + "surface_mesh_centers": point_like_placement, + "surface_normals": point_like_placement, + "surface_areas": point_like_placement, + } + else: + domain_mesh = None + placements = None + + return domain_mesh, data_mesh, placements + + +@dataclass +class ScalingFactors: + """ + Data structure for storing scaling factors computed for DoMINO datasets. + + This class provides a clean, easily serializable format for storing + mean, std, min, and max values for different array keys in the dataset. + Uses numpy arrays for easy serialization and cross-platform compatibility. + + Attributes: + mean: Dictionary mapping keys to mean numpy arrays + std: Dictionary mapping keys to standard deviation numpy arrays + min_val: Dictionary mapping keys to minimum value numpy arrays + max_val: Dictionary mapping keys to maximum value numpy arrays + field_keys: List of field keys for which statistics were computed + """ + + mean: Dict[str, np.ndarray] + std: Dict[str, np.ndarray] + min_val: Dict[str, np.ndarray] + max_val: Dict[str, np.ndarray] + field_keys: list[str] + + def to_torch( + self, device: Optional[torch.device] = None + ) -> Dict[str, Dict[str, torch.Tensor]]: + """Convert numpy arrays to torch tensors for use in training/inference.""" + device = device or torch.device("cpu") + + return { + "mean": {k: torch.from_numpy(v).to(device) for k, v in self.mean.items()}, + "std": {k: torch.from_numpy(v).to(device) for k, v in self.std.items()}, + "min_val": { + k: torch.from_numpy(v).to(device) for k, v in self.min_val.items() + }, + "max_val": { + k: torch.from_numpy(v).to(device) for k, v in self.max_val.items() + }, + } + + def save(self, filepath: str | Path) -> None: + """Save scaling factors to pickle file.""" + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + with open(filepath, "wb") as f: + pickle.dump(self, f) + + @classmethod + def load(cls, filepath: str | Path) -> "ScalingFactors": + """Load scaling factors from pickle file.""" + with open(filepath, "rb") as f: + factors = pickle.load(f) + return factors + + def get_field_shapes(self) -> Dict[str, tuple]: + """Get the shape of each field's statistics.""" + return {key: self.mean[key].shape for key in self.field_keys} + + def summary(self) -> str: + """Generate a human-readable summary of the scaling factors.""" + summary = ["Scaling Factors Summary:"] + summary.append(f"Field Keys: {self.field_keys}") + + for key in self.field_keys: + mean_val = self.mean[key] + std_val = self.std[key] + min_val = self.min_val[key] + max_val = self.max_val[key] + + summary.append(f"\n{key}:") + summary.append(f" Shape: {mean_val.shape}") + summary.append(f" Mean: {mean_val}") + summary.append(f" Std: {std_val}") + summary.append(f" Min: {min_val}") + summary.append(f" Max: {max_val}") + + return "\n".join(summary) + + +def load_scaling_factors( + cfg: DictConfig, logger=None +) -> tuple[torch.Tensor, torch.Tensor]: + """Load scaling factors from the configuration.""" + pickle_path = os.path.join(cfg.data.scaling_factors) + + try: + scaling_factors = ScalingFactors.load(pickle_path) + if logger is not None: + logger.info(f"Scaling factors loaded from: {pickle_path}") + except FileNotFoundError: + raise FileNotFoundError( + f"Scaling factors not found at: {pickle_path}; please run compute_statistics.py to compute them." + ) + + if cfg.model.normalization == "min_max_scaling": + vol_factors = np.asarray( + [ + scaling_factors.max_val["volume_fields"], + scaling_factors.min_val["volume_fields"], + ] + ) + surf_factors = np.asarray( + [ + scaling_factors.max_val["surface_fields"], + scaling_factors.min_val["surface_fields"], + ] + ) + elif cfg.model.normalization == "mean_std_scaling": + vol_factors = np.asarray( + [ + scaling_factors.mean["volume_fields"], + scaling_factors.std["volume_fields"], + ] + ) + surf_factors = np.asarray( + [ + scaling_factors.mean["surface_fields"], + scaling_factors.std["surface_fields"], + ] + ) + else: + raise ValueError(f"Invalid normalization mode: {cfg.model.normalization}") + + vol_factors_tensor = torch.from_numpy(vol_factors) + surf_factors_tensor = torch.from_numpy(surf_factors) + + dm = DistributedManager() + vol_factors_tensor = vol_factors_tensor.to(dm.device, dtype=torch.float32) + surf_factors_tensor = surf_factors_tensor.to(dm.device, dtype=torch.float32) + + return vol_factors_tensor, surf_factors_tensor + + +def compute_l2( + pred_surface: torch.Tensor | None, + pred_volume: torch.Tensor | None, + batch, + dataloader, +) -> dict[str, torch.Tensor]: + """ + Compute the L2 norm between prediction and target. + + Requires the dataloader to unscale back to original values + """ + + l2_dict = {} + + if pred_surface is not None: + _, target_surface = dataloader.unscale_model_outputs( + surface_fields=batch["surface_fields"] + ) + _, pred_surface = dataloader.unscale_model_outputs(surface_fields=pred_surface) + l2_surface = metrics_fn_surface(pred_surface, target_surface) + l2_dict.update(l2_surface) + if pred_volume is not None: + target_volume, _ = dataloader.unscale_model_outputs( + volume_fields=batch["volume_fields"] + ) + pred_volume, _ = dataloader.unscale_model_outputs(volume_fields=pred_volume) + l2_volume = metrics_fn_volume(pred_volume, target_volume) + l2_dict.update(l2_volume) + + return l2_dict + + +def metrics_fn_surface( + pred: torch.Tensor, + target: torch.Tensor, +) -> dict[str, torch.Tensor]: + """ + Computes L2 surface metrics between prediction and target. + + Args: + pred: Predicted values (normalized). + target: Target values (normalized). + + Returns: + Dictionary of L2 surface metrics for pressure and shear components. + """ + + l2_num = (pred - target) ** 2 + l2_num = torch.sum(l2_num, dim=1) + l2_num = torch.sqrt(l2_num) + + l2_denom = target**2 + l2_denom = torch.sum(l2_denom, dim=1) + l2_denom = torch.sqrt(l2_denom) + + l2 = l2_num / l2_denom + + metrics = { + "l2_surf_pressure": torch.mean(l2[:, 0]), + "l2_shear_x": torch.mean(l2[:, 1]), + "l2_shear_y": torch.mean(l2[:, 2]), + "l2_shear_z": torch.mean(l2[:, 3]), + } + + return metrics + + +def metrics_fn_volume( + pred: torch.Tensor, + target: torch.Tensor, +) -> dict[str, torch.Tensor]: + """ + Computes L2 volume metrics between prediction and target. + """ + l2_num = (pred - target) ** 2 + l2_num = torch.sum(l2_num, dim=1) + l2_num = torch.sqrt(l2_num) + + l2_denom = target**2 + l2_denom = torch.sum(l2_denom, dim=1) + l2_denom = torch.sqrt(l2_denom) + + l2 = l2_num / l2_denom + + metrics = { + "l2_vol_pressure": torch.mean(l2[:, 3]), + "l2_velocity_x": torch.mean(l2[:, 0]), + "l2_velocity_y": torch.mean(l2[:, 1]), + "l2_velocity_z": torch.mean(l2[:, 2]), + "l2_nut": torch.mean(l2[:, 4]), + } + + return metrics + + +def all_reduce_dict( + metrics: dict[str, torch.Tensor], dm: DistributedManager +) -> dict[str, torch.Tensor]: + """ + Reduces a dictionary of metrics across all distributed processes. + + Args: + metrics: Dictionary of metric names to torch.Tensor values. + dm: DistributedManager instance for distributed context. + + Returns: + Dictionary of reduced metrics. + """ + # TODO - update this to use domains and not the full world + + if dm.world_size == 1: + return metrics + + for key, value in metrics.items(): + dist.all_reduce(value) + value = value / dm.world_size + metrics[key] = value + + return metrics diff --git a/physicsnemo/datapipes/cae/cae_dataset.py b/physicsnemo/datapipes/cae/cae_dataset.py new file mode 100644 index 0000000000..8a2dfdfc5c --- /dev/null +++ b/physicsnemo/datapipes/cae/cae_dataset.py @@ -0,0 +1,1275 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib +import time +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import torch +import torch.distributed as dist +import zarr +from torch.distributed.tensor import Replicate, Shard + +try: + import tensorstore as ts + + TENSORSTORE_AVAILABLE = True +except ImportError: + TENSORSTORE_AVAILABLE = False + +try: + import pyvista as pv + + PV_AVAILABLE = True +except ImportError: + PV_AVAILABLE = False + +from physicsnemo.distributed import ShardTensor, ShardTensorSpec +from physicsnemo.distributed.utils import compute_split_shapes + +# Abstractions: +# - want to read npy/npz/.zarr/.stl/.vtp files +# - Need to share next level abstractions +# - Domain parallel dataloading is supported: output will be ShardTensor instead. +# - need to be able to configure preprocessing +# - CPU -> GPU transfer happens here, needs to be isolated in it's own stream +# - Output of dataloader should be torch.Tensor objects. + + +""" +This datapipe handles reading files from Zarr and piping into torch.Tensor objects. + +It's expected that the files are organized as groups, with each .zarr +file representing one training example. To improve IO performance, the files +should be chunked for each array. The reader takes a list of keys in the +group to read, and will not read keys that are not specified. The exception +is if _no_ keys are passed, in which case _all_ keys will be read. +""" + + +class BackendReader(ABC): + """ + Abstract base class for backend readers. + """ + + def __init__( + self, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + ) -> None: + """ + Initialize the backend reader. + """ + self.keys_to_read = keys_to_read + self.keys_to_read_if_available = keys_to_read_if_available + + self.volume_sampling_size = None + + self.is_volumetric = any(["volume" in key for key in self.keys_to_read]) + + @abstractmethod + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read a file and return a dictionary of tensors. + """ + pass + + @abstractmethod + def read_file_sharded( + self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh + ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]: + """ + Read a file and return a dictionary of tensors ready to convert to ShardTensors. + + NOTE: this function does not actually convert torch tensors to ShardTensors. + It's possible that the conversion, in some cases, can be a collective function. + Due to the async nature of the loader, we don't rely on any ordering of + collectives and defer them to the last possible minute. + + Additionally, these functions return CPU tensors and we don't actually + define shard tensors on cpu. + + So, the dataset itself will convert a local tensor + shard info to shard tensor + after the cpu-> gpu movement. + """ + pass + + def fill_optional_keys( + self, data: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """ + Fill missing keys with the keys from the keys_to_read_if_available dictionary. + """ + for key in self.keys_to_read_if_available: + if key not in data.keys(): + data[key] = self.keys_to_read_if_available[key] + return data + + def _get_slice_boundaries( + self, array_shape: tuple[int], this_rank: int, n_splits: int, split_dim: int = 0 + ) -> tuple[int, int, tuple | None]: + """ + For an array, determine the slice boundaries for parallel reading. + + Args: + array_shape: The total shape of the target array. + this_rank: The rank of the distributed process. + n_splits: The size of the distributed process. + split_dim: The dimension to split, default is 0. + + Returns: + The slice boundaries for parallel reading. + """ + # Determine what slice this rank should read + + sections = compute_split_shapes(array_shape[split_dim], n_splits) + + global_chunk_start = sum(sections[:this_rank]) + global_chunk_stop = global_chunk_start + sections[this_rank] + + chunk_sizes = tuple( + array_shape[:split_dim] + (section,) + array_shape[split_dim + 1 :] + for section in sections + ) + + return global_chunk_start, global_chunk_stop, chunk_sizes + + def set_volume_sampling_size(self, volume_sampling_size: int): + """ + Set the volume sampling size. When set, the readers will + assume the volumetric data is shuffled on disk and read only + contiguous chunks of the data up to the sampling size. + + + Args: + volume_sampling_size: The total size of the volume sampling. + + """ + self.volume_sampling_size = volume_sampling_size + + def select_random_sections_from_slice( + self, + slice_start: int, + slice_stop: int, + n_points: int, + ) -> slice: + """ + + select the contiguous chunks of the volume data to read. + + Args: + n_volume_points: The number of points to sample from the volume. + + Returns: + A tuple of the start and stop indices of the contiguous chunks. + """ + + if slice_stop - slice_start < n_points: + raise ValueError( + f"Slice size {slice_stop - slice_start} is less than the number of points {n_points}" + ) + + # Choose a random start point that will fit the entire n_points region: + start = np.random.randint(slice_start, slice_stop - n_points) + return slice(start, start + n_points) + + +class NpyFileReader(BackendReader): + """ + Reader for numpy files. + """ + + def __init__( + self, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + ) -> None: + super().__init__(keys_to_read, keys_to_read_if_available) + + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read a file and return a dictionary of tensors. + """ + data = np.load(filename, allow_pickle=True).item() + + missing_keys = set(self.keys_to_read) - set(data.keys()) + + if len(missing_keys) > 0: + raise ValueError(f"Keys {missing_keys} not found in file {filename}") + + data = {key: torch.from_numpy(data[key]) for key in self.keys_to_read} + + return self.fill_optional_keys(data) + + def read_file_sharded( + self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh + ) -> dict[str, ShardTensor]: + pass + + def set_volume_sampling_size(self, volume_sampling_size: int): + """ + This is not supported for npy files. + """ + raise NotImplementedError( + "volume sampling directly from disk is not supported for npy files." + ) + + +class NpzFileReader(BackendReader): + """ + Reader for npz files. + """ + + def __init__( + self, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + ) -> None: + super().__init__(keys_to_read, keys_to_read_if_available) + + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read a file and return a dictionary of tensors. + """ + in_data = np.load(filename) + + keys_found = set(in_data.keys()) + keys_missing = set(self.keys_to_read) - keys_found + if len(keys_missing) > 0: + raise ValueError(f"Keys {keys_missing} not found in file {filename}") + + # Make sure to select the slice outside of the loop. + if self.is_volumetric: + if self.volume_sampling_size is not None: + volume_slice = self.select_random_sections_from_slice( + 0, + in_data["volume_mesh_centers"].shape[0], + self.volume_sampling_size, + ) + else: + volume_slice = slice(0, in_data["volume_mesh_centers"].shape[0]) + + # This is a slower basic way to do this, to be improved: + data = {} + for key in self.keys_to_read: + if "volume" not in key: + data[key] = torch.from_numpy(in_data[key][:]) + else: + data[key] = torch.from_numpy(in_data[key][volume_slice]) + + # data = {key: torch.from_numpy(in_data[key][:]) for key in self.keys_to_read} + + return self.fill_optional_keys(data) + + def read_file_sharded( + self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh + ) -> dict[str, ShardTensor]: + pass + + def set_volume_sampling_size(self, volume_sampling_size: int): + """ + This is not supported for npz files. + """ + raise NotImplementedError( + "volume sampling directly from disk is not supported for npz files." + ) + + +class ZarrFileReader(BackendReader): + """ + Reader for zarr files. + """ + + def __init__( + self, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + ) -> None: + super().__init__(keys_to_read, keys_to_read_if_available) + + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read a file and return a dictionary of tensors. + """ + group = zarr.open_group(filename, mode="r") + + missing_keys = set(self.keys_to_read) - set(group.keys()) + + if len(missing_keys) > 0: + raise ValueError(f"Keys {missing_keys} not found in file {filename}") + + # Make sure to select the slice outside of the loop. + if self.is_volumetric: + if self.volume_sampling_size is not None: + volume_slice = self.select_random_sections_from_slice( + 0, + group["volume_mesh_centers"].shape[0], + self.volume_sampling_size, + ) + else: + volume_slice = slice(0, group["volume_mesh_centers"].shape[0]) + + # This is a slower basic way to do this, to be improved: + data = {} + for key in self.keys_to_read: + if "volume" not in key: + data[key] = torch.from_numpy(group[key][:]) + else: + data[key] = torch.from_numpy(group[key][volume_slice]) + + return self.fill_optional_keys(data) + + def read_file_sharded( + self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh + ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]: + """ + Read a file and return a dictionary of tensors. + """ + + # We need the coordinates of this GPU: + this_rank = device_mesh.get_local_rank() + domain_size = dist.get_world_size(group=device_mesh.get_group()) + + group = zarr.open_group(filename, mode="r") + + missing_keys = set(self.keys_to_read) - set(group.keys()) + + if len(missing_keys) > 0: + raise ValueError(f"Keys {missing_keys} not found in file {filename}") + + data = {} + specs = {} + for key in self.keys_to_read: + # Open the array in zarr without reading it and get info: + zarr_array = group[key] + array_shape = zarr_array.shape + if array_shape == (): + # Read scalars from every rank and use replicate sharding + raw_data = torch.from_numpy(zarr_array[:]) + placement = [ + Replicate(), + ] + chunk_sizes = None + else: + target_dim = 0 + if array_shape[target_dim] < domain_size: + # If the array is smaller than the number of ranks, + # again read and use replicate sharding: + raw_data = torch.from_numpy(zarr_array[:]) + placement = [ + Replicate(), + ] + chunk_sizes = None + else: + # Read partially from the data and use Shard(target_dim) sharding + chunk_start, chunk_stop, chunk_sizes = self._get_slice_boundaries( + zarr_array.shape, this_rank, domain_size + ) + raw_data = torch.from_numpy(zarr_array[chunk_start:chunk_stop]) + placement = [ + Shard(target_dim), + ] + + # Turn chunk sizes into a dict over mesh dim 0: + chunk_sizes = {0: chunk_sizes} + + # + data[key] = raw_data + specs[key] = (placement, chunk_sizes) + + # Patch in the optional keys: + data = self.fill_optional_keys(data) + for key in data.keys(): + if key not in specs: + specs[key] = ( + [ + Replicate(), + ], + {}, + ) + + return data, specs + + +if PV_AVAILABLE: + + class VTKFileReader(BackendReader): + """ + Reader for vtk files. + """ + + def __init__( + self, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + ) -> None: + super().__init__(keys_to_read, keys_to_read_if_available) + + self.stl_file_keys = [ + "stl_coordinates", + "stl_centers", + "stl_faces", + "stl_areas", + ] + self.vtp_file_keys = [ + "surface_mesh_centers", + "surface_normals", + "surface_mesh_sizes", + "CpMeanTrim", + "pMeanTrim", + "wallShearStressMeanTrim", + ] + self.vtu_file_keys = [ + "volume_mesh_centers", + "volume_fields", + ] + + self.exclude_patterns = [ + "single_solid", + ] + + def get_file_name(self, dir_name: pathlib.Path, extension: str) -> pathlib.Path: + """ + Get the file name for a given directory and extension. + """ + # >>> matches = [p for p in list(dir_name.iterdir()) if p.suffix == ".stl" and not any(pattern in p.name for pattern in exclude_patterns)] + matches = [ + p + for p in dir_name.iterdir() + if p.suffix == extension + and not any(pattern in p.name for pattern in self.exclude_patterns) + ] + if len(matches) == 0: + raise FileNotFoundError(f"No {extension} files found in {dir_name}") + fname = matches[0] + return dir_name / fname + + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read a set of files and return a dictionary of tensors. + """ + + # This reader attempts to only read what's necessary, and not more. + # So, the functions that do the reading are each "one file" functions + # and we open them for processing only when necessary. + + return_data = {} + + # Note that this reader is, already, running in a background thread. + # It may or may not help to further thread these calls. + if any(key in self.stl_file_keys for key in self.keys_to_read): + stl_path = self.get_file_name(filename, ".stl") + stl_data = self.read_data_from_stl(stl_path) + return_data.update(stl_data) + if any(key in self.vtp_file_keys for key in self.keys_to_read): + vtp_path = self.get_file_name(filename, ".vtp") + vtp_data = self.read_data_from_vtp(vtp_path) + return_data.update(vtp_data) + if any(key in self.vtu_file_keys for key in self.keys_to_read): + raise NotImplementedError("VTU files are not supported yet.") + + return self.fill_optional_keys(return_data) + + def read_file_sharded( + self, filename: pathlib.Path, parallel_rank: int, parallel_size: int + ) -> tuple[dict[str, torch.Tensor], dict[str, ShardTensorSpec]]: + """ + Read a file and return a dictionary of tensors. + """ + raise NotImplementedError("Not implemented yet.") + + def read_data_from_stl( + self, + stl_path: str, + ) -> dict: + """ + Reads surface mesh data from an STL file and prepares a batch dictionary for inference. + + Args: + stl_path (str): Path to the STL file. + + Returns: + dict: Batch dictionary with mesh faces and coordinates as torch tensors. + """ + + mesh = pv.read(stl_path) + + batch = {} + + faces = mesh.faces.reshape(-1, 4) + faces = faces[:, 1:] + + batch["stl_faces"] = faces.flatten() + + batch["stl_coordinates"] = mesh.points + batch["surface_normals"] = mesh.cell_normals + + batch = {k: torch.from_numpy(v) for k, v in batch.items()} + + return batch + + def read_data_from_vtp(self, vtp_path: str) -> dict: + """ + Read vtp file from a file + """ + + raise NotImplementedError("Not implemented yet.") + + def set_volume_sampling_size(self, volume_sampling_size: int): + """ + This is not supported for vtk files. + """ + raise NotImplementedError( + "volume sampling directly from disk is not supported for vtk files." + ) + + +if TENSORSTORE_AVAILABLE: + + class TensorStoreZarrReader(BackendReader): + """ + Reader for tensorstore zarr files. + """ + + def __init__( + self, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + cache_bytes_limit: int = 10_000_000, + data_copy_concurrency: int = 72, + file_io_concurrency: int = 72, + ) -> None: + super().__init__(keys_to_read, keys_to_read_if_available) + + self.spec_template = { + "driver": "auto", + "kvstore": { + "driver": "file", + "path": None, + }, + } + + self.context = ts.Context( + { + "cache_pool": {"total_bytes_limit": cache_bytes_limit}, + "data_copy_concurrency": {"limit": data_copy_concurrency}, + "file_io_concurrency": {"limit": file_io_concurrency}, + } + ) + + def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read a file and return a dictionary of tensors. + """ + + # Trigger an async open of each data item: + read_futures = {} + for key in self.keys_to_read: + spec = self.spec_template.copy() + spec["kvstore"]["path"] = str(filename) + "/" + str(key) + + read_futures[key] = ts.open( + spec, create=False, open=True, context=self.context + ) + + # Wait for all the opens to conclude: + read_futures = { + key: read_futures[key].result() for key in read_futures.keys() + } + + # Make sure to select the slice outside of the loop. + # We need + if self.is_volumetric: + if self.volume_sampling_size is not None: + volume_slice = self.select_random_sections_from_slice( + 0, + read_futures["volume_mesh_centers"].shape[0], + self.volume_sampling_size, + ) + else: + volume_slice = slice( + 0, read_futures["volume_mesh_centers"].shape[0] + ) + + # Trigger an async read of each data item: + # (Each item will be a numpy ndarray after this:) + tensor_futures = {} + for key in self.keys_to_read: + if "volume" not in key: + tensor_futures[key] = read_futures[key].read() + # For the volume data, read the slice: + else: + tensor_futures[key] = read_futures[key][volume_slice].read() + + # Convert them to torch tensors: + # (make sure to block for the result) + data = { + key: torch.as_tensor(tensor_futures[key].result(), dtype=torch.float32) + for key in self.keys_to_read + } + + return self.fill_optional_keys(data) + + def read_file_sharded( + self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh + ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]: + """ + Read a file and return a dictionary of tensors. + """ + + # We need the coordinates of this GPU: + this_rank = device_mesh.get_local_rank() + domain_size = dist.get_world_size(group=device_mesh.get_group()) + + # This pulls a list of store objects in tensorstore: + stores = {} + for key in self.keys_to_read: + spec = self.spec_template.copy() + spec["kvstore"]["path"] = str(filename) + "/" + str(key) + + stores[key] = ts.open( + spec, create=False, open=True, context=self.context + ) + + stores = {key: stores[key].result() for key in stores.keys()} + + data = {} + specs = {} + for key in self.keys_to_read: + # Open the array in zarr without reading it and get info: + store = stores[key] + array_shape = store.shape + if array_shape == (): + # Read scalars from every rank and use replicate sharding + _slice = np.s_[:] + # raw_data = torch.from_numpy(store[:]) + placement = [ + Replicate(), + ] + chunk_sizes = None + else: + target_dim = 0 + if array_shape[target_dim] < domain_size: + # If the array is smaller than the number of ranks, + # again read and use replicate sharding: + _slice = np.s_[:] + # raw_data = torch.from_numpy(store[:]) + placement = [ + Replicate(), + ] + chunk_sizes = None + else: + # Read partially from the data and use Shard(target_dim) sharding + chunk_start, chunk_stop, chunk_sizes = ( + self._get_slice_boundaries( + store.shape, this_rank, domain_size + ) + ) + _slice = np.s_[chunk_start:chunk_stop] + # raw_data = torch.from_numpy(zarr_array[chunk_start:chunk_stop]) + placement = [ + Shard(target_dim), + ] + + # Turn chunk sizes into a dict over mesh dim 0: + chunk_sizes = {0: chunk_sizes} + + # Trigger the reads as async: + data[key] = store[_slice].read() + specs[key] = (placement, chunk_sizes) + + # Finally, await the full data read: + for key in self.keys_to_read: + data[key] = torch.as_tensor(data[key].result()) + + # Patch in the optional keys: + data = self.fill_optional_keys(data) + for key in data.keys(): + if key not in specs: + specs[key] = ( + [ + Replicate(), + ], + {}, + ) + + return data, specs + +else: + + class TensorStoreZarrReader(BackendReader): + """ + Null reader for tensorstore zarr files. + """ + + def __init__( + self, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + ) -> None: + # Raise an exception on construction if we get here: + raise NotImplementedError( + "TensorStoreZarrReader is not available without tensorstore. `pip install tensorstore`." + ) + + +def is_vtk_directory(file: pathlib.Path) -> bool: + """ + Check if a file is a vtk directory. + """ + return file.is_dir() and all( + [f.suffix in [".vtp", ".stl", ".vtu", ".vtk", ".csv"] for f in file.iterdir()] + ) + + +class CAEDataset: + """ + Dataset reader for DrivaerML and similar datasets. In general, this + dataset supports reading dictionary-like data, and returning a + dictionary of torch.Tensor objects. + + When constructed, the user must pass a directory of data examples. + The dataset will inspect the folder, identify all children, and decide: + - If every file is a directory ending in .zarr, the zarr reader is used. + - If every file is .npy, the .npy reader is used. + - If every file is .npz, the .npz reader is used. + - If every file is a directory without an extension, it's assumed to be .stl/.vtp/.vtu + + The user can optionally force one path with a parameter. + + The flow of this dataset is: + - Load data from file, using a thread. + - Each individual file reading tool may or may not have it's own threading + or multi processing enabled. That's up to it. This just does async + loading. + - Data should come out of the readers in dict{str : torch.Tensor} format + - The data is transferred from CPU to GPU in a separate stream. + + Users can call __getitem__(i), which will trigger the pipeline, + or they can call `preload(i)`, which will start the pipeline for index `i`. + Subsequent calls to `__getitem__(i)` should be faster since the IO is in + progress or complete. + + Using the `__iter__` functionality will automatically enable preloading. + + """ + + def __init__( + self, + data_dir: str | pathlib.Path, + keys_to_read: list[str] | None, + keys_to_read_if_available: dict[str, torch.Tensor] | None, + output_device: torch.device, + preload_depth: int = 2, + pin_memory: bool = False, + device_mesh: torch.distributed.DeviceMesh | None = None, + placements: dict[str, torch.distributed.tensor.Placement] | None = None, + consumer_stream: torch.cuda.Stream | None = None, + ) -> None: + if isinstance(data_dir, str): + data_dir = pathlib.Path(data_dir) + + # Verify the data directory exists: + if not data_dir.exists(): + raise FileNotFoundError(f"Data directory {data_dir} does not exist") + + # Verify the data directory is a directory: + if not data_dir.is_dir(): + raise NotADirectoryError(f"Data directory {data_dir} is not a directory") + + self._keys_to_read = keys_to_read + + # Make sure the optional keys are on the right device: + self._keys_to_read_if_available = { + k: v.to(output_device) for k, v in keys_to_read_if_available.items() + } + + self.file_reader, self._filenames = self._infer_file_type_and_filenames( + data_dir + ) + + self.pin_memory = pin_memory + + # Check the file names; some can be read well in parallel, while others + # are not parallelizable. + + self._length = len(self._filenames) + + self.output_device = output_device + if output_device.type == "cuda": + self._data_loader_stream = torch.cuda.Stream() + else: + self._data_loader_stream = None + + self.device_mesh = device_mesh + self.placements = placements + # This tracks global tensor info + # so we can convert to ShardTensor at the right time. + self.shard_spec = {} + + if self.device_mesh is not None: + if self.device_mesh.ndim != 1: + raise ValueError("Device mesh must be one dimensional") + + # This is thread storage for data preloading: + self._preload_queue = {} + self._transfer_events = {} + self.preload_depth = preload_depth + self.preload_executor = ThreadPoolExecutor(max_workers=max(1, preload_depth)) + + if consumer_stream is None and self.output_device.type == "cuda": + consumer_stream = torch.cuda.current_stream() + + self.consumer_stream = consumer_stream + + def set_indices(self, indices: list[int]): + """ + Set the indices for the dataset for this epoch. + """ + + # TODO - this needs to block while anything is in the preprocess queue. + + self.indices = indices + + def idx_to_index(self, idx): + if hasattr(self, "indices"): + return self.indices[idx] + + return idx + + def _infer_file_type_and_filenames( + self, data_dir: pathlib.Path + ) -> tuple[str, list[str]]: + """ + Infer the file type and filenames from the data directory. + """ + + # We validated the directory exists and is a directory already. + + # List the files: + files = list(data_dir.iterdir()) + + # Initialize the file reader object + # Note that for some of these, they could be functions + # But others benefit from having a state, so we use classes: + + if all(file.suffix == ".npy" for file in files): + file_reader = NpyFileReader( + self._keys_to_read, self._keys_to_read_if_available + ) + return file_reader, files + elif all(file.suffix == ".npz" for file in files): + file_reader = NpzFileReader( + self._keys_to_read, self._keys_to_read_if_available + ) + return file_reader, files + elif all(file.suffix == ".zarr" and file.is_dir() for file in files): + if TENSORSTORE_AVAILABLE: + file_reader = TensorStoreZarrReader( + self._keys_to_read, self._keys_to_read_if_available + ) + else: + file_reader = ZarrFileReader( + self._keys_to_read, self._keys_to_read_if_available + ) + return file_reader, files + elif all(is_vtk_directory(file) for file in files): + file_reader = VTKFileReader( + self._keys_to_read, self._keys_to_read_if_available + ) + return file_reader, files + # Each "file" here is a directory of .vtp, stl, etc. + else: + # TODO - support folders of stl, vtp, vtu. + raise ValueError(f"Unsupported file type: {files[0]}") + + def _move_to_gpu( + self, data: dict[str, torch.Tensor], idx: int + ) -> dict[str, torch.Tensor]: + """Convert numpy arrays to torch tensors and move to GPU if available. + + Args: + data: Dictionary of key to torch tensor. + + Returns: + Dictionary of key to torch tensor on GPU if available. + """ + + if self.output_device.type != "cuda": + return data + + result = {} + + with torch.cuda.stream(self._data_loader_stream): + for key in data.keys(): + if data[key].device == self.output_device: + result[key] = data[key] + continue + if self.pin_memory: + result[key] = ( + data[key].pin_memory().to(self.output_device, non_blocking=True) + ) + else: + result[key] = data[key].to(self.output_device, non_blocking=True) + # Move to GPU if available + # result[key] = data[key].to(self.output_device, non_blocking=True) + result[key].record_stream(self.consumer_stream) + + # Mark the consumer stream: + transfer_event = torch.cuda.Event() + transfer_event.record(self._data_loader_stream) + self._transfer_events[idx] = transfer_event + + return result + + def _convert_to_shard_tensors( + self, + tensors: dict[str, torch.Tensor], + filename: str, + ) -> dict[str, ShardTensor]: + """Convert tensors to ShardTensor objects for distributed training. + + Args: + tensors: Dictionary of key to torch tensor. + + Returns: + Dictionary of key to torch tensor or ShardTensor. + """ + + if self.device_mesh is None: + return tensors + + spec_dict = self.shard_spec.pop(filename) + result = {} + for key in tensors.keys(): + placement, chunk_sizes = spec_dict[key] + + result[key] = ShardTensor.from_local( + local_tensor=tensors[key], + device_mesh=self.device_mesh, + placements=placement, + sharding_shapes=chunk_sizes, + ) + + return result + + def preload(self, idx: int) -> None: + """ + Asynchronously preload the data for the given index (up to CPU, not GPU). + Only one preload operation is supported at a time. + + Args: + idx: Index of the sample to preload. + """ + if idx in self._preload_queue: + # Skip items that are already in the queue + return + + def _preload_worker(): + data = self._read_file(self._filenames[idx]) + if "stl_faces" in data: + data["stl_faces"] = data["stl_faces"].to(torch.int32) + # Convert to torch tensors + return self._move_to_gpu(data, idx) + + self._preload_queue[idx] = self.preload_executor.submit(_preload_worker) + + def get_preloaded(self, idx: int) -> dict[str, torch.Tensor] | None: + """ + Retrieve the preloaded data (blocking if not ready). + + Returns: + (idx, data) tuple where data is a dictionary of key to numpy array or torch tensor. + + Raises: + RuntimeError: If no preload is in progress. + Exception: If preload failed. + """ + + if idx not in self._preload_queue: + return None + + result = self._preload_queue[ + idx + ].result() # This will block until the result is ready + self._preload_queue.pop(idx) # Clear the future after getting the result + + return result + + def __iter__(self): + # When starting the iterator method, start loading the data + # at idx = 0, idx = 1 + # Start preprocessing at idx = 0, when the load completes + + self.i = 0 + + N = len(self.indices) if hasattr(self, "indices") else len(self) + for i in range(self.preload_depth): + # Trigger the dataset to start loading index 0: + if N > i + 1: + self.preload(self.idx_to_index(self.i + i)) + + return self + + def __next__(self): + N = len(self.indices) if hasattr(self, "indices") else len(self._filenames) + + # Iteration bounds are based on the counter, not the random-access index + if self.i >= N: + self.i = 0 + raise StopIteration + + # This is the file random access index + target_index = self.idx_to_index(self.i) + + # Before returning, put the next two target indexes into the queue: + for preload_i in range(self.preload_depth): + next_iteration_index = self.i + preload_i + 1 + if N > next_iteration_index: + preload_idx = self.idx_to_index(next_iteration_index) + self.preload(preload_idx) + + # Send up the random-access data: + data = self.__getitem__(target_index) + + self.i += 1 + + return data + + def __len__(self): + return len(self._filenames) + + def _read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]: + """ + Read a file and return a dictionary of tensors. + """ + if self.device_mesh is not None: + tensor_dict, spec_dict = self.file_reader.read_file_sharded( + filename, self.device_mesh + ) + self.shard_spec[filename] = spec_dict + return tensor_dict + else: + return self.file_reader.read_file(filename) + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor | ShardTensor]: + """ + Get a data sample. + + Flow is: + - Read data, or get preloaded data if this idx is preloaded. + - Move data to GPU, if needed. + - Preloading data will move to GPU if it can. + - If domain parallelism is enabled, convert to ShardTensors. + - Return + + Args: + idx: Index of the sample to retrieve + + Returns: + Dictionary containing tensors/ShardTensors for the requested data + """ + + if idx >= len(self._filenames): + raise IndexError( + f"Index {idx} out of range for dataset of size {len(self._filenames)}" + ) + + # Attempt to get preloaded data: + data = self.get_preloaded(idx) + if data is None: + # Read data from zarr file + data = self._read_file(self._filenames[idx]) + data = self._move_to_gpu(data, idx) + + # This blocks until the preprocessing has transferred to GPU + if idx in self._transfer_events: + self.consumer_stream.wait_event(self._transfer_events[idx]) + self._transfer_events.pop(idx) + + # Convert to ShardTensors if using domain parallelism + if self.device_mesh is not None: + data = self._convert_to_shard_tensors(data, self._filenames[idx]) + + return data + + def set_volume_sampling_size(self, volume_sampling_size: int): + """ + Set the volume sampling size. When set, the readers will + assume the volumetric data is shuffled on disk and read only + contiguous chunks of the data up to the sampling size. + + Args: + volume_sampling_size: The total size of the volume sampling. + """ + self.file_reader.set_volume_sampling_size(volume_sampling_size) + + def close(self): + """ + Explicitly close the dataset and cleanup resources, including the ThreadPoolExecutor. + """ + if hasattr(self, "preload_executor") and self.preload_executor is not None: + self.preload_executor.shutdown(wait=True) + self.preload_executor = None + + def __del__(self): + """ + Cleanup resources when the dataset is destroyed. + """ + self.close() + + +def compute_mean_std_min_max( + dataset: CAEDataset, field_keys: list[str], max_samples: int = 20 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the mean, standard deviation, minimum, and maximum for a specified field + across all samples in a dataset. + + Uses a numerically stable online algorithm for mean and variance. + + Args: + dataset (CAEDataset): The dataset to process. + field_key (str): The key for the field to normalize. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + mean, std, min, max tensors for the field. + """ + N = {} + mean = {} + M2 = {} # Sum of squares of differences from the current mean + min_val = {} + max_val = {} + + # Read the first data item to get the shapes: + example_data = dataset[0] + + # Create placeholders for the accumulators: + for key in field_keys: + N[key] = torch.zeros(1, dtype=torch.int64, device=example_data[key].device) + mean[key] = torch.zeros( + example_data[key].shape[-1], + device=example_data[key].device, + dtype=torch.float64, + ) + M2[key] = torch.zeros( + example_data[key].shape[-1], + device=example_data[key].device, + dtype=torch.float64, + ) + min_val[key] = torch.full( + (example_data[key].shape[-1],), + float("inf"), + device=example_data[key].device, + ) + max_val[key] = torch.full( + (example_data[key].shape[-1],), + float("-inf"), + device=example_data[key].device, + ) + + global_start = time.perf_counter() + start = time.perf_counter() + data_list = np.arange(len(dataset)) + np.random.shuffle(data_list) + for i, j in enumerate(data_list): + data = dataset[j] + if i >= max_samples: + break + + for field_key in field_keys: + field_data = data[field_key] + + # Compute batch statistics + batch_mean = field_data.mean(axis=(0)) + batch_M2 = ((field_data - batch_mean) ** 2).sum(axis=(0)) + batch_n = field_data.shape[0] + + # Update running mean and M2 (Welford's algorithm) + delta = batch_mean - mean[field_key] + N[field_key] += batch_n # batch_n should also be torch.int64 + mean[field_key] = mean[field_key] + delta * (batch_n / N[field_key]) + M2[field_key] = ( + M2[field_key] + + batch_M2 + + delta**2 * (batch_n * N[field_key]) / N[field_key] + ) + + end = time.perf_counter() + iteration_time = end - start + print( + f"on iteration {i} of {max_samples}, time: {iteration_time:.2f} seconds for file: {j}" + ) + start = time.perf_counter() + + var = {} + std = {} + for field_key in field_keys: + var[field_key] = M2[field_key] / ( + N[field_key].item() - 1 + ) # Convert N to Python int for division + std[field_key] = torch.sqrt(var[field_key]) + + start = time.perf_counter() + for i, j in enumerate(data_list): + data = dataset[j] + if i >= max_samples: + break + + for field_key in field_keys: + field_data = data[field_key] + + batch_n = field_data.shape[0] + + # # Update min/max + + mean_sample = mean[field_key] + std_sample = std[field_key] + mask = torch.ones_like(field_data, dtype=torch.bool) + for v in range(field_data.shape[-1]): + outliers = (field_data[:, v] < mean_sample[v] - 9.0 * std_sample[v]) | ( + field_data[:, v] > mean_sample[v] + 9.0 * std_sample[v] + ) + mask[:, v] = ~outliers + + batch_min = [] + batch_max = [] + for v in range(field_data.shape[-1]): + batch_min.append(field_data[mask[:, v], v].min()) + batch_max.append(field_data[mask[:, v], v].max()) + + batch_min = torch.stack(batch_min) + batch_max = torch.stack(batch_max) + + min_val[field_key] = torch.minimum(min_val[field_key], batch_min) + max_val[field_key] = torch.maximum(max_val[field_key], batch_max) + + end = time.perf_counter() + iteration_time = end - start + print( + f"on iteration {i} of {max_samples}, time: {iteration_time:.2f} seconds for file: {j}" + ) + start = time.perf_counter() + + global_end = time.perf_counter() + global_time = global_end - global_start + + print(f"Total time: {global_time:.2f} seconds for {max_samples} samples") + + return mean, std, min_val, max_val diff --git a/physicsnemo/datapipes/cae/domino_datapipe.py b/physicsnemo/datapipes/cae/domino_datapipe.py index 0a3ec9e38b..491ab5a199 100644 --- a/physicsnemo/datapipes/cae/domino_datapipe.py +++ b/physicsnemo/datapipes/cae/domino_datapipe.py @@ -17,7 +17,7 @@ """ This code provides the datapipe for reading the processed npy files, generating multi-res grids, calculating signed distance fields, -positional encodings, sampling random points in the volume and on surface, +sampling random points in the volume and on surface, normalizing fields and returning the output tensors as a dictionary. This datapipe also non-dimensionalizes the fields, so the order in which the variables should @@ -26,102 +26,46 @@ variable names, domain resolution, sampling size etc. are configurable in config.yaml. """ -import os -import time -from concurrent.futures import ThreadPoolExecutor -from contextlib import nullcontext from dataclasses import dataclass from pathlib import Path -from typing import Literal, Optional, Protocol, Sequence, Union +from typing import Iterable, Literal, Optional, Protocol, Sequence, Union -import cuml -import cupy as cp import numpy as np import torch import torch.cuda.nvtx as nvtx -import zarr from omegaconf import DictConfig -from scipy.spatial import KDTree -from torch import Tensor -from torch.utils.data import Dataset, default_collate +from torch.distributed.tensor.placement_types import Replicate +from torch.utils.data import Dataset +from physicsnemo.datapipes.cae.cae_dataset import ( + CAEDataset, + compute_mean_std_min_max, +) from physicsnemo.distributed import DistributedManager +from physicsnemo.distributed.shard_tensor import ShardTensor, scatter_tensor from physicsnemo.utils.domino.utils import ( - ArrayType, - area_weighted_shuffle_array, calculate_center_of_mass, - calculate_normal_positional_encoding, create_grid, get_filenames, - mean_std_sampling, normalize, pad, - # sample_array, shuffle_array, - solution_weighted_shuffle_array, standardize, + unnormalize, + unstandardize, ) +from physicsnemo.utils.neighbors import knn from physicsnemo.utils.profiling import profile from physicsnemo.utils.sdf import signed_distance_field -""" -These functions, below, are to handle the SDF calculation which only -accepts torch tensors. The entire pipeline is moving to torch, so -these aren't necessary after that. -""" - - -def _convert_array_to_torch(array: cp.ndarray | np.ndarray) -> torch.Tensor: - """ - TEMPORARY function to convert cupy and numpy arrays to torch tensors. - """ - if isinstance(array, cp.ndarray): - return torch.utils.dlpack.from_dlpack(array) - elif isinstance(array, np.ndarray): - return torch.from_numpy(array) - else: - raise ValueError(f"Unsupported array type: {type(array)}") - - -def _convert_torch_to_array(array: torch.Tensor, provider) -> cp.ndarray | np.ndarray: - """ - TEMPORARY function to convert torch tensors to cupy arrays. - """ - return provider.from_dlpack(array) - - -def domino_collate_fn(batch): - """ - This function is a custom collation function to move cupy data to torch tensors on the device. - - For things that aren't cupy arrays, fall back to torch.data.default_convert. Data, here, - is a dictionary of numpy arrays or cupy arrays. - - """ - - def convert(obj): - if isinstance(obj, cp.ndarray): - return torch.utils.dlpack.from_dlpack(obj.toDlpack()) - elif isinstance(obj, list): - return [convert(x) for x in obj] - elif isinstance(obj, tuple): - return tuple(convert(x) for x in obj) - elif isinstance(obj, dict): - return {k: convert(v) for k, v in obj.items()} - else: - return obj - - batch = [convert(sample) for sample in batch] - return default_collate(batch) - class BoundingBox(Protocol): """ Type definition for the required format of bounding box dimensions. """ - min: ArrayType - max: ArrayType + min: Sequence + max: Sequence @dataclass @@ -134,8 +78,6 @@ class DoMINODataConfig: surface_variables: (Surface specific) Names of surface variables. surface_points_sample: (Surface specific) Number of surface points to sample per batch. num_surface_neighbors: (Surface specific) Number of surface neighbors to consider for nearest neighbors approach. - resample_surfaces: (Surface specific) Whether to resample the surface before kdtree/knn. Not available if caching. - resampling_points: (Surface specific) Number of points to resample the surface to. surface_sampling_algorithm: (Surface specific) Algorithm to use for surface sampling ("area_weighted" or "random"). surface_factors: (Surface specific) Non-dimensionalization factors for surface variables. If set, and scaling_type is: @@ -145,6 +87,9 @@ class DoMINODataConfig: attributes that are arraylike. volume_variables: (Volume specific) Names of volume variables. volume_points_sample: (Volume specific) Number of volume points to sample per batch. + volume_sample_from_disk: (Volume specific) If the volume data is in a shuffled state on disk, + read contiguous chunks of the data rather than the entire volume data. This greatly + accelerates IO in bandwidth limited systems or when the volumetric data is very large. volume_factors: (Volume specific) Non-dimensionalization factors for volume variables scaling. If set, and scaling_type is: - min_max_scaling -> rescale volume_fields to the min/max set here @@ -168,10 +113,6 @@ class DoMINODataConfig: - volume.points_sample geom_points_sample: Number of STL points sampled per batch. Independent of volume.points_sample and surface.points_sample. - positional_encoding: Whether to use positional encoding. Affects the calculation of: - - pos_volume_closest - - pos_volume_center_of_mass - - pos_surface_centter_of_mass scaling_type: Scaling type for volume variables. If used, will rescale the volume_fields and surface fields outputs. Requires volume.factor and surface.factor to be set. @@ -186,15 +127,13 @@ class DoMINODataConfig: You might choose gpu_preprocessing=True and gpu_output=False if caching. """ - data_path: Path + data_path: Path | None phase: Literal["train", "val", "test"] # Surface-specific variables: surface_variables: Optional[Sequence] = ("pMean", "wallShearStress") surface_points_sample: int = 1024 num_surface_neighbors: int = 11 - resample_surfaces: bool = False - resampling_points: int = 1_000_000 surface_sampling_algorithm: str = Literal["area_weighted", "random"] surface_factors: Optional[Sequence] = None bounding_box_dims_surf: Optional[Union[BoundingBox, Sequence]] = None @@ -202,15 +141,15 @@ class DoMINODataConfig: # Volume specific variables: volume_variables: Optional[Sequence] = ("UMean", "pMean") volume_points_sample: int = 1024 + volume_sample_from_disk: bool = False volume_factors: Optional[Sequence] = None bounding_box_dims: Optional[Union[BoundingBox, Sequence]] = None - grid_resolution: Union[Sequence, ArrayType] = (256, 96, 64) + grid_resolution: Sequence = (256, 96, 64) normalize_coordinates: bool = False sample_in_bbox: bool = False sampling: bool = False geom_points_sample: int = 300000 - positional_encoding: bool = False scaling_type: Optional[Literal["min_max_scaling", "mean_std_scaling"]] = None compute_scaling_factors: bool = False caching: bool = False @@ -219,16 +158,17 @@ class DoMINODataConfig: gpu_output: bool = True def __post_init__(self): - # Ensure data_path is a Path object: - if isinstance(self.data_path, str): - self.data_path = Path(self.data_path) - self.data_path = self.data_path.expanduser() + if self.data_path is not None: + # Ensure data_path is a Path object: + if isinstance(self.data_path, str): + self.data_path = Path(self.data_path) + self.data_path = self.data_path.expanduser() - if not self.data_path.exists(): - raise ValueError(f"Path {self.data_path} does not exist") + if not self.data_path.exists(): + raise ValueError(f"Path {self.data_path} does not exist") - if not self.data_path.is_dir(): - raise ValueError(f"Path {self.data_path} is not a directory") + if not self.data_path.is_dir(): + raise ValueError(f"Path {self.data_path} is not a directory") # Object if caching settings are impossible: if self.caching: @@ -236,8 +176,6 @@ def __post_init__(self): raise ValueError("Sampling should be False for caching") if self.compute_scaling_factors: raise ValueError("Compute scaling factors should be False for caching") - if self.resample_surfaces: - raise ValueError("Resample surface should be False for caching") if self.phase not in [ "train", @@ -258,559 +196,308 @@ def __post_init__(self): ##### TODO -# - put model type in config or leave in __init__ -# - check the bounding box protocol works +# - The SDF normalization here is based on using a normalized mesh and +# a normalized coordinate. The alternate method is to normalize to the min/max of the grid. class DoMINODataPipe(Dataset): """ Datapipe for DoMINO + Leverages a dataset for the actual reading of the data, and this + object is responsible for preprocessing the data. + """ def __init__( self, input_path, model_type: Literal["surface", "volume", "combined"], + pin_memory: bool = False, **data_config_overrides, ): # Perform config packaging and validation self.config = DoMINODataConfig(data_path=input_path, **data_config_overrides) + # Set up the distributed manager: if not DistributedManager.is_initialized(): DistributedManager.initialize() dist = DistributedManager() - if self.config.gpu_preprocessing or self.config.gpu_output: - # Make sure we move data to the right device: - target_device = dist.device.index - self.device_context = cp.cuda.Device(target_device) - self.device_context.use() - else: - self.device_context = nullcontext() - self.device = dist.device - - if self.config.deterministic: - np.random.seed(42) - cp.random.seed(42) - else: - np.random.seed(seed=int(time.time())) - cp.random.seed(seed=int(time.time())) + # Set devices for the preprocessing and IO target + self.preproc_device = ( + dist.device if self.config.gpu_preprocessing else torch.device("cpu") + ) + # The cae_dataset will automatically target this device + # In an async transfer. + self.output_device = ( + dist.device if self.config.gpu_output else torch.device("cpu") + ) + # Model type determines whether we process surface, volume, or both. self.model_type = model_type - self.filenames = get_filenames(self.config.data_path, exclude_dirs=True) - total_files = len(self.filenames) - - self.indices = np.array(range(total_files)) - - # Why shuffle the indices here if only using random access below? - - np.random.shuffle(self.indices) - - # Determine the array provider based on what device - # will do preprocessing: - self.array_provider = cp if self.config.gpu_preprocessing else np # Update the arrays for bounding boxes: - if hasattr(self.config.bounding_box_dims, "max") and hasattr( self.config.bounding_box_dims, "min" ): self.config.bounding_box_dims = [ - self.array_provider.asarray(self.config.bounding_box_dims.max).astype( - "float32" + torch.tensor( + self.config.bounding_box_dims.max, + device=self.preproc_device, + dtype=torch.float32, ), - self.array_provider.asarray(self.config.bounding_box_dims.min).astype( - "float32" + torch.tensor( + self.config.bounding_box_dims.min, + device=self.preproc_device, + dtype=torch.float32, ), ] + self.default_volume_grid = create_grid( + self.config.bounding_box_dims[0], + self.config.bounding_box_dims[1], + self.config.grid_resolution, + ) + + # And, do the surface bounding box if supplied: if hasattr(self.config.bounding_box_dims_surf, "max") and hasattr( self.config.bounding_box_dims_surf, "min" ): self.config.bounding_box_dims_surf = [ - self.array_provider.asarray( - self.config.bounding_box_dims_surf.max - ).astype("float32"), - self.array_provider.asarray( - self.config.bounding_box_dims_surf.min - ).astype("float32"), + torch.tensor( + self.config.bounding_box_dims_surf.max, + device=self.preproc_device, + dtype=torch.float32, + ), + torch.tensor( + self.config.bounding_box_dims_surf.min, + device=self.preproc_device, + dtype=torch.float32, + ), ] - # Used if threaded data is enabled: - self.max_workers = 24 - # Create a single thread pool for the class - self.executor = ThreadPoolExecutor(max_workers=self.max_workers) - - # Define here the keys to read for each __getitem__ call - - # Always read these keys - self.keys_to_read = ["stl_coordinates", "stl_centers", "stl_faces", "stl_areas"] - with self.device_context: - xp = self.array_provider - self.keys_to_read_if_available = { - "global_params_values": xp.asarray([30.0, 1.226]), - "global_params_reference": xp.asarray([30.0, 1.226]), - } - self.volume_keys = ["volume_mesh_centers", "volume_fields"] - self.surface_keys = [ - "surface_mesh_centers", - "surface_normals", - "surface_areas", - "surface_fields", - ] - - if self.model_type == "volume" or self.model_type == "combined": - self.keys_to_read.extend(self.volume_keys) - if self.model_type == "surface" or self.model_type == "combined": - self.keys_to_read.extend(self.surface_keys) - - def __del__(self): - # Clean up the executor when the instance is being destroyed - if hasattr(self, "executor"): - self.executor.shutdown() - - @profile - def read_data_zarr(self, filepath): - # def create_pinned_streaming_space(shape, dtype): - # # TODO - this function could boost performance a little, but - # # the pinned memory pool seems too small. - # if self.array_provider == cp: - # nbytes = np.prod(shape) * dtype.itemsize - # ptr = cp.cuda.alloc_pinned_memory(nbytes) - # arr = np.frombuffer(ptr, dtype) - # return arr.reshape(shape) - # else: - # return np.empty(shape, dtype=dtype) - - def read_chunk_into_array(ram_array, fs_zarr_array, slice): - ram_array[slice] = fs_zarr_array[slice] - - @profile - def chunked_aligned_read(zarr_group, key, futures): - zarr_array = zarr_group[key] - - shape = zarr_array.shape - chunk_size = zarr_array.chunks[0] - - # Pre-allocate the full result array - result_shape = zarr_array.shape - result_dtype = zarr_array.dtype - - result = np.empty(result_shape, dtype=result_dtype) - - for start in range(0, shape[0], chunk_size): - end = min(start + chunk_size, shape[0]) - read_slice = np.s_[start:end] - futures.append( - self.executor.submit( - read_chunk_into_array, result, zarr_array, read_slice - ) - ) - - return result + self.default_surface_grid = create_grid( + self.config.bounding_box_dims_surf[0], + self.config.bounding_box_dims_surf[1], + self.config.grid_resolution, + ) - with zarr.open_group(filepath, mode="r") as z: - data = {} - futures = [] - if "volume_fields" in z.keys(): - data["volume_fields"] = chunked_aligned_read( - z, "volume_fields", futures + # Ensure the volume and surface scaling factors are torch tensors + # and on the right device: + if self.config.volume_factors is not None: + if not isinstance(self.config.volume_factors, torch.Tensor): + self.config.volume_factors = torch.from_numpy( + self.config.volume_factors ) - if "volume_mesh_centers" in z.keys(): - data["volume_mesh_centers"] = chunked_aligned_read( - z, "volume_mesh_centers", futures + self.config.volume_factors = self.config.volume_factors.to( + self.preproc_device, dtype=torch.float32 + ) + if self.config.surface_factors is not None: + if not isinstance(self.config.surface_factors, torch.Tensor): + self.config.surface_factors = torch.from_numpy( + self.config.surface_factors ) + self.config.surface_factors = self.config.surface_factors.to( + self.preproc_device, dtype=torch.float32 + ) - for key in self.keys_to_read: - if z[key].shape == (): - data[key] = z[key] - elif key in ["volume_fields", "volume_mesh_centers"]: - continue - else: - data[key] = np.empty(z[key].shape, dtype=z[key].dtype) - slice = np.s_[:] - futures.append( - self.executor.submit( - read_chunk_into_array, data[key], z[key], slice - ) - ) - - # Now wait for all the futures to complete - for future in futures: - result = future.result() - if isinstance(result, tuple) and len(result) == 2: - key, value = result - data[key] = value - - # Move big data to GPU - for key in data.keys(): - data[key] = self.array_provider.asarray(data[key]) - - # Optional, maybe-present keys - for key in self.keys_to_read_if_available: - if key not in data.keys(): - data[key] = self.keys_to_read_if_available[key] - - return data - - @profile - def read_data_npy(self, filepath): - with open(filepath, "rb") as f: - data = np.load(f, allow_pickle=True).item() + self.dataset = None - for key in self.keys_to_read_if_available: - if key not in data.keys(): - data[key] = self.keys_to_read_if_available[key] + def compute_stl_scaling_and_surface_grids( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the min and max for the defining mesh. - if "filename" in data.keys(): - data.pop("filename", None) + If the user supplies a bounding box, we use that. Otherwise, + it raises an error. - if not (isinstance(data["stl_coordinates"], np.ndarray)): - data["stl_coordinates"] = np.asarray(data["stl_coordinates"]) + The returned min/max and grid are used for surface data. + """ - # Maybe move to GPU: - with self.device_context: - for key in data.keys(): - if data[key] is not None: - data[key] = self.array_provider.asarray(data[key]) - return data + # Check the bounding box is not unit length - @profile - def read_data_npz( - self, - filepath, - max_workers=None, - ): - if max_workers is not None: - self.max_workers = max_workers - - def load_one(key): - with np.load(filepath) as data: - return key, data[key] - - def check_optional_keys(): - with np.load(filepath) as data: - optional_results = {} - for key in self.keys_to_read_if_available: - if key in data.keys(): - optional_results[key] = data[key] - else: - optional_results[key] = self.keys_to_read_if_available[key] - with self.device_context: - optional_results = { - key: self.array_provider.asarray(value) - for key, value in optional_results.items() - } - return optional_results - - # Use the class-level executor instead of creating a new one - results = dict(self.executor.map(load_one, self.keys_to_read)) - - # Move the results to the GPU: - with self.device_context: - for key in results.keys(): - results[key] = self.array_provider.asarray(results[key]) - - # Check the optional ones: - optional_results = check_optional_keys() - results.update(optional_results) - - return results + if self.config.bounding_box_dims_surf is not None: + s_max = self.config.bounding_box_dims_surf[0] + s_min = self.config.bounding_box_dims_surf[1] + surf_grid = self.default_surface_grid + else: + raise ValueError("Bounding box dimensions are not set in config") - def __len__(self): - return len(self.indices) + return s_min, s_max, surf_grid - @profile - def preprocess_combined(self, data_dict): - # Pull these out and force to fp32: - with self.device_context: - global_params_values = data_dict["global_params_values"].astype( - self.array_provider.float32 - ) - global_params_reference = data_dict["global_params_reference"].astype( - self.array_provider.float32 - ) + def compute_volume_scaling_and_grids( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the min and max and grid for volume data. - # Pull these pieces out of the data_dict for manipulation - stl_vertices = data_dict["stl_coordinates"] - stl_centers = data_dict["stl_centers"] - mesh_indices_flattened = data_dict["stl_faces"] - stl_sizes = data_dict["stl_areas"] - idx = np.where(stl_sizes > 0.0) - stl_sizes = stl_sizes[idx] - stl_centers = stl_centers[idx] + If the user supplies a bounding box, we use that. Otherwise, + it raises an error. - xp = self.array_provider + """ - # Make sure the mesh_indices_flattened is an integer array: - if mesh_indices_flattened.dtype != xp.int32: - mesh_indices_flattened = mesh_indices_flattened.astype(xp.int32) + # Determine the volume min / max locations + if self.config.bounding_box_dims is not None: + c_max = self.config.bounding_box_dims[0] + c_min = self.config.bounding_box_dims[1] + volume_grid = self.default_volume_grid + else: + raise ValueError("Bounding box dimensions are not set in config") - center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes) + return c_min, c_max, volume_grid - if self.config.bounding_box_dims_surf is None: - s_max = xp.amax(stl_vertices, 0) - s_min = xp.amin(stl_vertices, 0) - else: - s_max = xp.asarray(self.config.bounding_box_dims_surf[0]) - s_min = xp.asarray(self.config.bounding_box_dims_surf[1]) - - # SDF calculation on the grid using WARP - if not self.config.compute_scaling_factors: - nx, ny, nz = self.config.grid_resolution - surf_grid = create_grid(s_max, s_min, [nx, ny, nz]) - surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3) - - sdf_surf_grid, _ = signed_distance_field( - _convert_array_to_torch(stl_vertices), - _convert_array_to_torch(mesh_indices_flattened), - _convert_array_to_torch(surf_grid_reshaped), - use_sign_winding_number=True, - ) - sdf_surf_grid = sdf_surf_grid.reshape(nx, ny, nz) - sdf_surf_grid = _convert_torch_to_array(sdf_surf_grid, self.array_provider) + @profile + def downsample_geometry( + self, + stl_vertices, + ) -> torch.Tensor: + """ + Downsample the geometry to the desired number of points. - else: - surf_grid = None - sdf_surf_grid = None + Args: + stl_vertices: The vertices of the surface. + """ if self.config.sampling: - # nvtx.range_push("Geometry Sampling") geometry_points = self.config.geom_points_sample + geometry_coordinates_sampled, idx_geometry = shuffle_array( stl_vertices, geometry_points ) if geometry_coordinates_sampled.shape[0] < geometry_points: - geometry_coordinates_sampled = pad( - geometry_coordinates_sampled, geometry_points, pad_value=-100.0 + raise ValueError( + "Surface mesh has fewer points than requested sample size" ) geom_centers = geometry_coordinates_sampled - # nvtx.range_pop() else: geom_centers = stl_vertices - # geom_centers = self.array_provider.float32(geom_centers) - - surf_grid_max_min = xp.stack([s_min, s_max]) - - return_dict = { - "surf_grid": surf_grid, - "sdf_surf_grid": sdf_surf_grid, - "surface_min_max": surf_grid_max_min, - "global_params_values": xp.expand_dims( - xp.array(global_params_values, dtype=xp.float32), -1 - ), - "global_params_reference": xp.expand_dims( - xp.array(global_params_reference, dtype=xp.float32), -1 - ), - "geometry_coordinates": geom_centers, - } - - return ( - return_dict, - s_min, - s_max, - mesh_indices_flattened, - stl_vertices, - center_of_mass, - ) + return geom_centers - @profile - def preprocess_surface(self, data_dict, core_dict, center_of_mass, s_min, s_max): + def process_surface( + self, + s_min: torch.Tensor, + s_max: torch.Tensor, + c_min: torch.Tensor, + c_max: torch.Tensor, + *, # Forcing the rest by keyword only since it's a long list ... + center_of_mass: torch.Tensor, + surf_grid: torch.Tensor, + surface_coordinates: torch.Tensor, + surface_normals: torch.Tensor, + surface_sizes: torch.Tensor, + stl_vertices: torch.Tensor, + stl_indices: torch.Tensor, + surface_fields: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: nx, ny, nz = self.config.grid_resolution return_dict = {} - surface_coordinates = data_dict["surface_mesh_centers"] - surface_normals = data_dict["surface_normals"] - surface_sizes = data_dict["surface_areas"] - surface_fields = data_dict["surface_fields"] - idx = np.where(surface_sizes > 0) + ######################################################################## + # Remove any sizes <= 0: + ######################################################################## + idx = surface_sizes > 0 surface_sizes = surface_sizes[idx] - surface_fields = surface_fields[idx] surface_normals = surface_normals[idx] surface_coordinates = surface_coordinates[idx] + if surface_fields is not None: + surface_fields = surface_fields[idx] + + ######################################################################## + # Reject surface points outside of the Bounding Box + # NOTE - this is using the VOLUME bounding box! + ######################################################################## + if self.config.sample_in_bbox: + ids_min = surface_coordinates[:] > c_min + ids_max = surface_coordinates[:] < c_max + + ids_in_bbox = ids_min & ids_max + ids_in_bbox = ids_in_bbox.all(dim=-1) + + surface_coordinates = surface_coordinates[ids_in_bbox] + surface_normals = surface_normals[ids_in_bbox] + surface_sizes = surface_sizes[ids_in_bbox] + if surface_fields is not None: + surface_fields = surface_fields[ids_in_bbox] + + ######################################################################## + # Perform Down sampling of the surface fields. + # Note that we snapshot the full surface coordinates for + # use in the kNN in the next step. + ######################################################################## - xp = self.array_provider + full_surface_coordinates = surface_coordinates + full_surface_normals = surface_normals + full_surface_sizes = surface_sizes - if self.config.resample_surfaces: - if self.config.resampling_points > surface_coordinates.shape[0]: - resampling_points = surface_coordinates.shape[0] + if self.config.sampling: + # Perform the down sampling: + if self.config.surface_sampling_algorithm == "area_weighted": + weights = surface_sizes else: - resampling_points = self.config.resampling_points + weights = None - surface_coordinates, idx_s = shuffle_array( - surface_coordinates, resampling_points + surface_coordinates_sampled, idx_surface = shuffle_array( + surface_coordinates, + self.config.surface_points_sample, + weights=weights, ) - surface_normals = surface_normals[idx_s] - surface_sizes = surface_sizes[idx_s] - surface_fields = surface_fields[idx_s] - - if not self.config.compute_scaling_factors: - c_max = self.config.bounding_box_dims[0] - c_min = self.config.bounding_box_dims[1] - - if self.config.sample_in_bbox: - # TODO - clean this up with vectorization? - # TODO - the xp.where is likely a useless op. Need to check. - ids_in_bbox = xp.where( - (surface_coordinates[:, 0] > c_min[0]) - & (surface_coordinates[:, 0] < c_max[0]) - & (surface_coordinates[:, 1] > c_min[1]) - & (surface_coordinates[:, 1] < c_max[1]) - & (surface_coordinates[:, 2] > c_min[2]) - & (surface_coordinates[:, 2] < c_max[2]) - ) - surface_coordinates = surface_coordinates[ids_in_bbox] - surface_normals = surface_normals[ids_in_bbox] - surface_sizes = surface_sizes[ids_in_bbox] - surface_fields = surface_fields[ids_in_bbox] - # Compute the positional encoding before sampling - if self.config.positional_encoding: - dx, dy, dz = ( - (s_max[0] - s_min[0]) / nx, - (s_max[1] - s_min[1]) / ny, - (s_max[2] - s_min[2]) / nz, - ) - pos_normals_com_surface = calculate_normal_positional_encoding( - surface_coordinates, center_of_mass, cell_dimensions=[dx, dy, dz] - ) - else: - pos_normals_com_surface = surface_coordinates - xp.asarray( - center_of_mass + if surface_coordinates_sampled.shape[0] < self.config.surface_points_sample: + raise ValueError( + "Surface mesh has fewer points than requested sample size" ) - # Fit the kNN (or KDTree, if CPU) on ALL points: - if self.config.num_surface_neighbors > 1: - if self.array_provider == cp: - knn = cuml.neighbors.NearestNeighbors( - n_neighbors=self.config.num_surface_neighbors, - algorithm="rbc", - ) - knn.fit(surface_coordinates) - else: - # Under the hood this is instantiating a KDTree. - # aka here knn is a type, not a class, technically. - interp_func = KDTree(surface_coordinates) - - if self.config.sampling: - # Perform the down sampling: - if self.config.surface_sampling_algorithm == "area_weighted": - ( - surface_coordinates_sampled, - idx_surface, - ) = area_weighted_shuffle_array( - surface_coordinates, - self.config.surface_points_sample, - surface_sizes, - ) - elif self.config.surface_sampling_algorithm == "solution_weighted": - ( - surface_coordinates_sampled, - idx_surface, - ) = solution_weighted_shuffle_array( - surface_coordinates, - self.config.surface_points_sample, - surface_fields[:, 0], - scaling_factor=0.5, - ) - else: - surface_coordinates_sampled, idx_surface = shuffle_array( - surface_coordinates, self.config.surface_points_sample - ) - - if ( - surface_coordinates_sampled.shape[0] - < self.config.surface_points_sample - ): - surface_coordinates_sampled = pad( - surface_coordinates_sampled, - self.config.surface_points_sample, - pad_value=-10.0, - ) - - # Select out the sampled points for non-neighbor arrays: + # Select out the sampled points for non-neighbor arrays: + if surface_fields is not None: surface_fields = surface_fields[idx_surface] - pos_normals_com_surface = pos_normals_com_surface[idx_surface] - - # Now, perform the kNN on the sampled points: - if self.config.num_surface_neighbors > 1: - if self.array_provider == cp: - ii = knn.kneighbors( - surface_coordinates_sampled, return_distance=False - ) - else: - _, ii = interp_func.query( - surface_coordinates_sampled, - k=self.config.num_surface_neighbors, - ) - - # Pull out the neighbor elements. Note that ii is the index into the original - # points - but only exists for the sampled points - # In other words, a point from `surface_coordinates_sampled` has neighbors - # from the full `surface_coordinates` array. - surface_neighbors = surface_coordinates[ii][:, 1:] - surface_neighbors_normals = surface_normals[ii][:, 1:] - surface_neighbors_sizes = surface_sizes[ii][:, 1:] - else: - surface_neighbors = surface_coordinates - surface_neighbors_normals = surface_normals - surface_neighbors_sizes = surface_sizes - - # We could index into these above the knn step too; they aren't dependent on that. - surface_normals = surface_normals[idx_surface] - surface_sizes = surface_sizes[idx_surface] - - # Update the coordinates to the sampled points: - surface_coordinates = surface_coordinates_sampled - - else: - # We are *not* sampling, kNN on ALL points: - if self.array_provider == cp: - ii = knn.kneighbors(surface_coordinates, return_distance=False) - else: - _, ii = interp_func.query( - surface_coordinates, - k=self.config.num_surface_neighbors, - ) - - # Construct the neighbors arrays: - surface_neighbors = surface_coordinates[ii][:, 1:] - surface_neighbors_normals = surface_normals[ii][:, 1:] - surface_neighbors_sizes = surface_sizes[ii][:, 1:] - - # Have to normalize neighbors after the kNN and sampling - if self.config.normalize_coordinates: - core_dict["surf_grid"] = normalize(core_dict["surf_grid"], s_max, s_min) - surface_coordinates = normalize(surface_coordinates, s_max, s_min) - surface_neighbors = normalize(surface_neighbors, s_max, s_min) - - if self.config.scaling_type is not None: - if self.config.surface_factors is not None: - if self.config.scaling_type == "mean_std_scaling": - surf_mean = self.config.surface_factors[0] - surf_std = self.config.surface_factors[1] - # TODO - Are these array calls needed? - surface_fields = standardize( - surface_fields, xp.asarray(surf_mean), xp.asarray(surf_std) - ) - elif self.config.scaling_type == "min_max_scaling": - surf_min = self.config.surface_factors[1] - surf_max = self.config.surface_factors[0] - # TODO - Are these array calls needed? - surface_fields = normalize( - surface_fields, xp.asarray(surf_max), xp.asarray(surf_min) - ) + # Subsample the normals and sizes: + surface_normals = surface_normals[idx_surface] + surface_sizes = surface_sizes[idx_surface] + # Update the coordinates to the sampled points: + surface_coordinates = surface_coordinates_sampled + + ######################################################################## + # Perform a kNN on the surface to find the neighbor information + ######################################################################## + if self.config.num_surface_neighbors > 1: + # Perform the kNN: + neighbor_indices, neighbor_distances = knn( + points=full_surface_coordinates, + queries=surface_coordinates, + k=self.config.num_surface_neighbors, + ) + # print(f"Full surface coordinates shape: {full_surface_coordinates.shape}") + # Pull out the neighbor elements. + # Note that `neighbor_indices` is the index into the original, + # full sized tensors (full_surface_coordinates, etc). + surface_neighbors = full_surface_coordinates[neighbor_indices][:, 1:] + surface_neighbors_normals = full_surface_normals[neighbor_indices][:, 1:] + surface_neighbors_sizes = full_surface_sizes[neighbor_indices][:, 1:] else: - surface_sizes = None - surface_normals = None - surface_neighbors = None - surface_neighbors_normals = None - surface_neighbors_sizes = None - pos_normals_com_surface = None + surface_neighbors = surface_coordinates + surface_neighbors_normals = surface_normals + surface_neighbors_sizes = surface_sizes + + # Better to normalize everything after the kNN and sampling + if self.config.normalize_coordinates: + surface_coordinates = normalize(surface_coordinates, s_max, s_min) + surface_neighbors = normalize(surface_neighbors, s_max, s_min) + center_of_mass = normalize(center_of_mass, s_max, s_min) + + pos_normals_com_surface = surface_coordinates - center_of_mass + + ######################################################################## + # Apply scaling to the targets, if desired: + ######################################################################## + if self.config.scaling_type is not None and surface_fields is not None: + surface_fields = self.scale_model_targets( + surface_fields, self.config.surface_factors + ) return_dict.update( { @@ -821,477 +508,469 @@ def preprocess_surface(self, data_dict, core_dict, center_of_mass, s_min, s_max) "surface_neighbors_normals": surface_neighbors_normals, "surface_areas": surface_sizes, "surface_neighbors_areas": surface_neighbors_sizes, - "surface_fields": surface_fields, } ) + if surface_fields is not None: + return_dict["surface_fields"] = surface_fields return return_dict - @profile - def preprocess_volume( + def process_volume( self, - data_dict, - core_dict, - s_min, - s_max, - mesh_indices_flattened, - stl_vertices, - center_of_mass, - ): - return_dict = {} - - nx, ny, nz = self.config.grid_resolution + c_min: torch.Tensor, + c_max: torch.Tensor, + volume_coordinates: torch.Tensor, + volume_grid: torch.Tensor, + center_of_mass: torch.Tensor, + stl_vertices: torch.Tensor, + stl_indices: torch.Tensor, + volume_fields: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + """ + Preprocess the volume data. - xp = self.array_provider + First, if configured, we reject points not in the volume bounding box. - # # Temporary: convert to cupy here: - volume_coordinates = data_dict["volume_mesh_centers"] - volume_fields = data_dict["volume_fields"] + Next, if sampling is enabled, we sample the volume points and apply that + sampling to the ground truth too, if it's present. - if not self.config.compute_scaling_factors: - if self.config.bounding_box_dims is None: - c_max = s_max + (s_max - s_min) / 2 - c_min = s_min - (s_max - s_min) / 2 - c_min[2] = s_min[2] - else: - c_max = xp.asarray(self.config.bounding_box_dims[0]) - c_min = xp.asarray(self.config.bounding_box_dims[1]) - - if self.config.sample_in_bbox: - # TODO - xp.where can probably be removed. - ids_in_bbox = self.array_provider.where( - (volume_coordinates[:, 0] > c_min[0]) - & (volume_coordinates[:, 0] < c_max[0]) - & (volume_coordinates[:, 1] > c_min[1]) - & (volume_coordinates[:, 1] < c_max[1]) - & (volume_coordinates[:, 2] > c_min[2]) - & (volume_coordinates[:, 2] < c_max[2]) - ) - volume_coordinates = volume_coordinates[ids_in_bbox] + """ + ######################################################################## + # Reject points outside the volumetric BBox + ######################################################################## + if self.config.sample_in_bbox: + # Remove points in the volume that are outside + # of the bbox area. + min_check = volume_coordinates[:] > c_min + max_check = volume_coordinates[:] < c_max + + ids_in_bbox = min_check & max_check + ids_in_bbox = ids_in_bbox.all(dim=1) + + volume_coordinates = volume_coordinates[ids_in_bbox] + if volume_fields is not None: volume_fields = volume_fields[ids_in_bbox] - dx, dy, dz = ( - (c_max[0] - c_min[0]) / nx, - (c_max[1] - c_min[1]) / ny, - (c_max[2] - c_min[2]) / nz, - ) + ######################################################################## + # Apply sampling to the volume coordinates and fields + ######################################################################## - # Generate a grid of specified resolution to map the bounding box - # The grid is used for capturing structured geometry features and SDF representation of geometry - grid = create_grid(c_max, c_min, [nx, ny, nz]) - grid_reshaped = grid.reshape(nx * ny * nz, 3) - - # SDF calculation on the grid using WARP - sdf_grid, _ = signed_distance_field( - _convert_array_to_torch(stl_vertices), - _convert_array_to_torch(mesh_indices_flattened), - _convert_array_to_torch(grid_reshaped), - use_sign_winding_number=True, + # If the volume data has been sampled from disk, directly, then + # still apply sampling. We over-pull from disk deliberately. + if self.config.sampling: + # Generate a series of idx to sample the volume + # without replacement + volume_coordinates_sampled, idx_volume = shuffle_array( + volume_coordinates, self.config.volume_points_sample ) - sdf_grid = sdf_grid.reshape((nx, ny, nz)) - sdf_grid = _convert_torch_to_array(sdf_grid, self.array_provider) - - if self.config.sampling: - volume_coordinates_sampled, idx_volume = shuffle_array( - volume_coordinates, self.config.volume_points_sample + volume_coordinates_sampled = volume_coordinates[idx_volume] + # In case too few points are in the sampled data (because the + # inputs were too few), pad the outputs: + if volume_coordinates_sampled.shape[0] < self.config.volume_points_sample: + raise ValueError( + "Volume mesh has fewer points than requested sample size" ) - if ( - volume_coordinates_sampled.shape[0] - < self.config.volume_points_sample - ): - volume_coordinates_sampled = pad( - volume_coordinates_sampled, - self.config.volume_points_sample, - pad_value=-10.0, - ) + + # Apply the same sampling to the targets, too: + if volume_fields is not None: volume_fields = volume_fields[idx_volume] - volume_coordinates = volume_coordinates_sampled - sdf_nodes, sdf_node_closest_point = signed_distance_field( - _convert_array_to_torch(stl_vertices), - _convert_array_to_torch(mesh_indices_flattened), - _convert_array_to_torch(volume_coordinates), - use_sign_winding_number=True, + volume_coordinates = volume_coordinates_sampled + + ######################################################################## + # Apply normalization to the coordinates, if desired: + ######################################################################## + if self.config.normalize_coordinates: + volume_coordinates = normalize(volume_coordinates, c_max, c_min) + grid = normalize(volume_grid, c_max, c_min) + normed_vertices = normalize(stl_vertices, c_max, c_min) + center_of_mass = normalize(center_of_mass, c_max, c_min) + else: + grid = volume_grid + normed_vertices = stl_vertices + center_of_mass = center_of_mass + + ######################################################################## + # Apply scaling to the targets, if desired: + ######################################################################## + if self.config.scaling_type is not None and volume_fields is not None: + volume_fields = self.scale_model_targets( + volume_fields, self.config.volume_factors ) - sdf_nodes = _convert_torch_to_array(sdf_nodes, self.array_provider) - sdf_node_closest_point = _convert_torch_to_array( - sdf_node_closest_point, self.array_provider + + ######################################################################## + # Compute Signed Distance Function for volumetric quantities + # Note - the SDF happens here, after volume data processing finishes, + # because we need to use the (maybe) normalized volume coordinates and grid + ######################################################################## + + # SDF calculation on the volume grid using WARP + sdf_grid, _ = signed_distance_field( + normed_vertices, + stl_indices, + grid, + use_sign_winding_number=True, + ) + + # Get the SDF of all the selected volume coordinates, + # And keep the closest point to each one. + sdf_nodes, sdf_node_closest_point = signed_distance_field( + normed_vertices, + stl_indices, + volume_coordinates, + use_sign_winding_number=True, + ) + sdf_nodes = sdf_nodes.reshape((-1, 1)) + + # Use the closest point from the mesh to compute the volume encodings: + pos_normals_closest_vol, pos_normals_com_vol = self.calculate_volume_encoding( + volume_coordinates, sdf_node_closest_point, center_of_mass + ) + + return_dict = { + "volume_mesh_centers": volume_coordinates, + "sdf_nodes": sdf_nodes, + "grid": grid, + "sdf_grid": sdf_grid, + "pos_volume_closest": pos_normals_closest_vol, + "pos_volume_center_of_mass": pos_normals_com_vol, + } + + if volume_fields is not None: + return_dict["volume_fields"] = volume_fields + + return return_dict + + def calculate_volume_encoding( + self, + volume_coordinates: torch.Tensor, + sdf_node_closest_point: torch.Tensor, + center_of_mass: torch.Tensor, + ): + pos_normals_closest_vol = volume_coordinates - sdf_node_closest_point + pos_normals_com_vol = volume_coordinates - center_of_mass + + return pos_normals_closest_vol, pos_normals_com_vol + + @torch.no_grad() + def process_data(self, data_dict): + # Validate that all required keys are present in data_dict + required_keys = [ + "global_params_values", + "global_params_reference", + "stl_coordinates", + "stl_faces", + "stl_centers", + "stl_areas", + ] + missing_keys = [key for key in required_keys if key not in data_dict] + if missing_keys: + raise ValueError( + f"Missing required keys in data_dict: {missing_keys}. " + f"Required keys are: {required_keys}" ) - # TODO - is this needed? - sdf_nodes = xp.asarray(sdf_nodes) - sdf_node_closest_point = xp.asarray(sdf_node_closest_point) + # Start building the preprocessed return dict: + return_dict = { + "global_params_values": data_dict["global_params_values"], + "global_params_reference": data_dict["global_params_reference"], + } - sdf_nodes = sdf_nodes.reshape((-1, 1)) + ######################################################################## + # Process the core STL information + ######################################################################## - if self.config.positional_encoding: - pos_normals_closest_vol = calculate_normal_positional_encoding( - volume_coordinates, - sdf_node_closest_point, - cell_dimensions=[dx, dy, dz], - ) - pos_normals_com_vol = calculate_normal_positional_encoding( - volume_coordinates, center_of_mass, cell_dimensions=[dx, dy, dz] - ) - else: - pos_normals_closest_vol = volume_coordinates - sdf_node_closest_point - pos_normals_com_vol = volume_coordinates - center_of_mass + # This function gets information about the surface scale, + # and decides what the surface grid will be: - if self.config.normalize_coordinates: - volume_coordinates = normalize(volume_coordinates, c_max, c_min) - grid = normalize(grid, c_max, c_min) - - if self.config.scaling_type is not None: - if self.config.volume_factors is not None: - if self.config.scaling_type == "mean_std_scaling": - vol_mean = self.config.volume_factors[0] - vol_std = self.config.volume_factors[1] - volume_fields = standardize(volume_fields, vol_mean, vol_std) - elif self.config.scaling_type == "min_max_scaling": - vol_min = xp.asarray(self.config.volume_factors[1]) - vol_max = xp.asarray(self.config.volume_factors[0]) - volume_fields = normalize(volume_fields, vol_max, vol_min) - - vol_grid_max_min = xp.stack([c_min, c_max]) + stl_coordinates = data_dict["stl_coordinates"] + + s_min, s_max, surf_grid = self.compute_stl_scaling_and_surface_grids() + if isinstance(stl_coordinates, ShardTensor): + mesh = stl_coordinates._spec.mesh + # Then, replicate the bounding box along the mesh if present. + s_max = scatter_tensor( + s_max, + 0, + mesh=mesh, + placements=[ + Replicate(), + ], + global_shape=s_max.shape, + dtype=s_max.dtype, + requires_grad=False, + ) + s_min = scatter_tensor( + s_min, + 0, + mesh=mesh, + placements=[ + Replicate(), + ], + global_shape=s_min.shape, + dtype=s_min.dtype, + requires_grad=False, + ) + surf_grid = scatter_tensor( + surf_grid, + 0, + mesh=mesh, + placements=[ + Replicate(), + ], + global_shape=surf_grid.shape, + dtype=surf_grid.dtype, + requires_grad=False, + ) + + # We always need to calculate the SDF on the surface grid: + # This is for the SDF Later: + if self.config.normalize_coordinates: + normed_vertices = normalize(data_dict["stl_coordinates"], s_max, s_min) + surf_grid = normalize(surf_grid, s_max, s_min) else: - pos_normals_closest_vol = None - pos_normals_com_vol = None - sdf_nodes = None - sdf_grid = None - grid = None - vol_grid_max_min = None + normed_vertices = data_dict["stl_coordinates"] - return_dict.update( - { - "pos_volume_closest": pos_normals_closest_vol, - "pos_volume_center_of_mass": pos_normals_com_vol, - "grid": grid, - "sdf_grid": sdf_grid, - "sdf_nodes": sdf_nodes, - "volume_fields": volume_fields, - "volume_mesh_centers": volume_coordinates, - "volume_min_max": vol_grid_max_min, - } + # For SDF calculations, make sure the mesh_indices_flattened is an integer array: + mesh_indices_flattened = data_dict["stl_faces"].to(torch.int32) + + # Compute signed distance function for the surface grid: + sdf_surf_grid, _ = signed_distance_field( + mesh_vertices=normed_vertices, + mesh_indices=mesh_indices_flattened, + input_points=surf_grid, + use_sign_winding_number=True, ) + return_dict["sdf_surf_grid"] = sdf_surf_grid + return_dict["surf_grid"] = surf_grid - return return_dict + # Store this only if normalization is active: + if self.config.normalize_coordinates: + return_dict["surface_min_max"] = torch.stack([s_min, s_max]) - @profile - def preprocess_data(self, data_dict): - ( - return_dict, - s_min, - s_max, - mesh_indices_flattened, - stl_vertices, - center_of_mass, - ) = self.preprocess_combined(data_dict) + # This is a center of mass computation for the stl surface, + # using the size of each mesh point as weight. + center_of_mass = calculate_center_of_mass( + data_dict["stl_centers"], data_dict["stl_areas"] + ) - if self.model_type == "volume" or self.model_type == "combined": - volume_dict = self.preprocess_volume( - data_dict, - return_dict, + # This will apply downsampling if needed to the geometry coordinates + geom_centers = self.downsample_geometry( + stl_vertices=data_dict["stl_coordinates"], + ) + return_dict["geometry_coordinates"] = geom_centers + + ######################################################################## + # Determine the volumetric bounds of the data: + ######################################################################## + # Compute the min/max for volume an the unnomralized grid: + c_min, c_max, volume_grid = self.compute_volume_scaling_and_grids() + + ######################################################################## + # Process the surface data + ######################################################################## + if self.model_type == "surface" or self.model_type == "combined": + surface_fields_raw = ( + data_dict["surface_fields"] if "surface_fields" in data_dict else None + ) + surface_dict = self.process_surface( s_min, s_max, - mesh_indices_flattened, - stl_vertices, - center_of_mass, + c_min, + c_max, + center_of_mass=center_of_mass, + surf_grid=surf_grid, + surface_coordinates=data_dict["surface_mesh_centers"], + surface_normals=data_dict["surface_normals"], + surface_sizes=data_dict["surface_areas"], + stl_vertices=data_dict["stl_coordinates"], + stl_indices=mesh_indices_flattened, + surface_fields=surface_fields_raw, ) - return_dict.update(volume_dict) + return_dict.update(surface_dict) - if self.model_type == "surface" or self.model_type == "combined": - surface_dict = self.preprocess_surface( - data_dict, return_dict, center_of_mass, s_min, s_max + ######################################################################## + # Process the volume data + ######################################################################## + # For volume data, we store this only if normalizing coordinates: + if self.model_type == "volume" or self.model_type == "combined": + if self.config.normalize_coordinates: + return_dict["volume_min_max"] = torch.stack([c_min, c_max]) + + if self.model_type == "volume" or self.model_type == "combined": + volume_fields_raw = ( + data_dict["volume_fields"] if "volume_fields" in data_dict else None ) - return_dict.update(surface_dict) + volume_dict = self.process_volume( + c_min, + c_max, + volume_coordinates=data_dict["volume_mesh_centers"], + volume_grid=volume_grid, + center_of_mass=center_of_mass, + stl_vertices=data_dict["stl_coordinates"], + stl_indices=mesh_indices_flattened, + volume_fields=volume_fields_raw, + ) + + return_dict.update(volume_dict) return return_dict - @profile + def scale_model_targets( + self, fields: torch.Tensor, factors: torch.Tensor + ) -> torch.Tensor: + """ + Scale the model targets based on the configured scaling factors. + """ + if self.config.scaling_type == "mean_std_scaling": + field_mean = factors[0] + field_std = factors[1] + return standardize(fields, field_mean, field_std) + elif self.config.scaling_type == "min_max_scaling": + field_min = factors[1] + field_max = factors[0] + return normalize(fields, field_max, field_min) + + def unscale_model_outputs( + self, + volume_fields: torch.Tensor | None = None, + surface_fields: torch.Tensor | None = None, + ): + """ + Unscale the model outputs based on the configured scaling factors. + + The unscaling is included here to make it a consistent interface regardless + of the scaling factors and type used. + + """ + + if volume_fields is not None: + if self.config.scaling_type == "mean_std_scaling": + vol_mean = self.config.volume_factors[0] + vol_std = self.config.volume_factors[1] + volume_fields = unstandardize(volume_fields, vol_mean, vol_std) + elif self.config.scaling_type == "min_max_scaling": + vol_min = self.config.volume_factors[1] + vol_max = self.config.volume_factors[0] + volume_fields = unnormalize(volume_fields, vol_max, vol_min) + if surface_fields is not None: + if self.config.scaling_type == "mean_std_scaling": + surf_mean = self.config.surface_factors[0] + surf_std = self.config.surface_factors[1] + surface_fields = unstandardize(surface_fields, surf_mean, surf_std) + elif self.config.scaling_type == "min_max_scaling": + surf_min = self.config.surface_factors[1] + surf_max = self.config.surface_factors[0] + surface_fields = unnormalize(surface_fields, surf_max, surf_min) + + return volume_fields, surface_fields + + def set_dataset(self, dataset: Iterable) -> None: + """ + Pass a dataset to the datapipe to enable iterating over both in one pass. + """ + self.dataset = dataset + + if self.config.volume_sample_from_disk: + # We deliberately double the data to read compared to the sampling size: + self.dataset.set_volume_sampling_size( + 100 * self.config.volume_points_sample + ) + + def __len__(self): + if self.dataset is not None: + return len(self.dataset) + else: + return 0 + def __getitem__(self, idx): """ Function for fetching and processing a single file's data. Domino, in general, expects one example per file and the files are relatively large due to the mesh size. + + Requires the user to have set a dataset via `set_dataset`. """ + if self.dataset is None: + raise ValueError("Dataset is not present") - if self.config.deterministic: - self.array_provider.random.seed(idx) - # But also always set numpy: - np.random.seed(idx) + # Get the data from the dataset. + # Under the hood, this may be fetching preloaded data. + data_dict = self.dataset[idx] - index = self.indices[idx] - cfd_filename = self.filenames[index] + return self.__call__(data_dict) - # Get all of the data: - filepath = self.config.data_path / cfd_filename + def __call__(self, data_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Process the incoming data dictionary. + - Processes the data + - moves it to GPU + - adds a batch dimension - if filepath.suffix == ".zarr": - data_dict = self.read_data_zarr(filepath) - elif filepath.suffix == ".npz": - data_dict = self.read_data_npz(filepath) - elif filepath.suffix == ".npy": - data_dict = self.read_data_npy(filepath) - else: - raise ValueError(f"Unsupported file extension: {filepath.suffix}") - - return_dict = self.preprocess_data(data_dict) - - # return only pytorch tensor objects. - # If returning on CPU (but processed on GPU), convert below. - # This assumes we keep the data on the device it's on. - for key, value in return_dict.items(): - if isinstance(value, np.ndarray): - return_dict[key] = torch.from_numpy(value) - elif isinstance(value, cp.ndarray): - return_dict[key] = torch.utils.dlpack.from_dlpack(value.toDlpack()) - - if self.config.gpu_output: - # Make sure this is all on the GPU. - # Everything here should be a torch tensor now. - for key, value in return_dict.items(): - if isinstance(value, torch.Tensor) and not value.is_cuda: - return_dict[key] = value.pin_memory().to(self.device) - else: - # Make sure everything is on the CPU. - for key, value in return_dict.items(): - if isinstance(value, torch.Tensor) and value.is_cuda: - return_dict[key] = value.cpu() + Args: + data_dict: Dictionary containing the data to process as torch.Tensors. - return return_dict + Returns: + Dictionary containing the processed data as torch.Tensors. + """ + data_dict = self.process_data(data_dict) -@profile -def compute_scaling_factors(cfg: DictConfig, input_path: str, use_cache: bool) -> None: - model_type = cfg.model.model_type - max_scaling_factor_files = 20 - - if model_type == "volume" or model_type == "combined": - vol_save_path = os.path.join(cfg.project_dir, "volume_scaling_factors.npy") - if not os.path.exists(vol_save_path): - print("Computing volume scaling factors") - volume_variable_names = list(cfg.variables.volume.solution.keys()) - - fm_dict = DoMINODataPipe( - input_path, - phase="train", - grid_resolution=cfg.model.interp_res, - volume_variables=volume_variable_names, - surface_variables=None, - normalize_coordinates=True, - sampling=False, - sample_in_bbox=True, - volume_points_sample=cfg.model.volume_points_sample, - geom_points_sample=cfg.model.geom_points_sample, - positional_encoding=cfg.model.positional_encoding, - model_type=cfg.model.model_type, - bounding_box_dims=cfg.data.bounding_box, - bounding_box_dims_surf=cfg.data.bounding_box_surface, - compute_scaling_factors=True, - gpu_preprocessing=True, - gpu_output=True, - ) + # If the data is not on the target device, put it there: + for key, value in data_dict.items(): + if value.device != self.output_device: + data_dict[key] = value.to(self.output_device) + + # Add a batch dimension to the data_dict + data_dict = {k: v.unsqueeze(0) for k, v in data_dict.items()} - # Calculate mean - if cfg.model.normalization == "mean_std_scaling": - for j in range(len(fm_dict)): - print("On iteration {j}") - d_dict = fm_dict[j] - vol_fields = d_dict["volume_fields"] - - if vol_fields is not None: - if j == 0: - vol_fields_sum = np.mean(vol_fields, 0) - else: - vol_fields_sum += np.mean(vol_fields, 0) - else: - vol_fields_sum = 0.0 - - vol_fields_mean = vol_fields_sum / len(fm_dict) - - for j in range(len(fm_dict)): - print("On iteration {j} again") - d_dict = fm_dict[j] - vol_fields = d_dict["volume_fields"] - - if vol_fields is not None: - if j == 0: - vol_fields_sum_square = np.mean( - (vol_fields - vol_fields_mean) ** 2.0, 0 - ) - else: - vol_fields_sum_square += np.mean( - (vol_fields - vol_fields_mean) ** 2.0, 0 - ) - else: - vol_fields_sum_square = 0.0 - - vol_fields_std = np.sqrt(vol_fields_sum_square / len(fm_dict)) - - vol_scaling_factors = [vol_fields_mean, vol_fields_std] - - if cfg.model.normalization == "min_max_scaling": - for j in range(len(fm_dict)): - print(f"Min max scaling on iteration {j}") - d_dict = fm_dict[j] - vol_fields = d_dict["volume_fields"] - - if vol_fields.device.type == "cuda": - xp = cp - vol_fields = vol_fields.cuda() - vol_fields = cp.from_dlpack(vol_fields) - else: - xp = np - vol_fields = vol_fields.cpu().numpy() - - if vol_fields is not None: - vol_mean = xp.mean(vol_fields, 0) - vol_std = xp.std(vol_fields, 0) - vol_idx = mean_std_sampling( - vol_fields, vol_mean, vol_std, tolerance=12.0 - ) - vol_fields_sampled = xp.delete(vol_fields, vol_idx, axis=0) - if j == 0: - vol_fields_max = xp.amax(vol_fields_sampled, 0) - vol_fields_min = xp.amin(vol_fields_sampled, 0) - else: - vol_fields_max1 = xp.amax(vol_fields_sampled, 0) - vol_fields_min1 = xp.amin(vol_fields_sampled, 0) - - for k in range(vol_fields.shape[-1]): - if vol_fields_max1[k] > vol_fields_max[k]: - vol_fields_max[k] = vol_fields_max1[k] - - if vol_fields_min1[k] < vol_fields_min[k]: - vol_fields_min[k] = vol_fields_min1[k] - else: - vol_fields_max = 0.0 - vol_fields_min = 0.0 - - if j > max_scaling_factor_files: - break - vol_scaling_factors = [vol_fields_max, vol_fields_min] - - for i, item in enumerate(vol_scaling_factors): - if isinstance(item, cp.ndarray): - vol_scaling_factors[i] = item.get() - - np.save(vol_save_path, vol_scaling_factors) - - if model_type == "surface" or model_type == "combined": - surf_save_path = os.path.join(cfg.project_dir, "surface_scaling_factors.npy") - - if not os.path.exists(surf_save_path): - print("Computing surface scaling factors") - volume_variable_names = list(cfg.variables.volume.solution.keys()) - surface_variable_names = list(cfg.variables.surface.solution.keys()) - - fm_dict = DoMINODataPipe( - input_path, - phase="train", - grid_resolution=cfg.model.interp_res, - volume_variables=None, - surface_variables=surface_variable_names, - normalize_coordinates=True, - sampling=False, - sample_in_bbox=True, - volume_points_sample=cfg.model.volume_points_sample, - geom_points_sample=cfg.model.geom_points_sample, - positional_encoding=cfg.model.positional_encoding, - model_type=cfg.model.model_type, - bounding_box_dims=cfg.data.bounding_box, - bounding_box_dims_surf=cfg.data.bounding_box_surface, - compute_scaling_factors=True, + return data_dict + + def __iter__(self): + if self.dataset is None: + raise ValueError( + "Dataset is not present, can not use the datapipe as an iterator." ) - # Calculate mean - if cfg.model.normalization == "mean_std_scaling": - for j in range(len(fm_dict)): - print(f"Mean std scaling on iteration {j}") - d_dict = fm_dict[j] - surf_fields = d_dict["surface_fields"].cpu().numpy() - - if surf_fields is not None: - if j == 0: - surf_fields_sum = np.mean(surf_fields, 0) - else: - surf_fields_sum += np.mean(surf_fields, 0) - else: - surf_fields_sum = 0.0 - - surf_fields_mean = surf_fields_sum / len(fm_dict) - - for j in range(len(fm_dict)): - print(f"Mean std scaling on iteration {j} again") - d_dict = fm_dict[j] - surf_fields = d_dict["surface_fields"] - - if surf_fields is not None: - if j == 0: - surf_fields_sum_square = np.mean( - (surf_fields - surf_fields_mean) ** 2.0, 0 - ) - else: - surf_fields_sum_square += np.mean( - (surf_fields - surf_fields_mean) ** 2.0, 0 - ) - else: - surf_fields_sum_square = 0.0 - - surf_fields_std = np.sqrt(surf_fields_sum_square / len(fm_dict)) - - surf_scaling_factors = [surf_fields_mean, surf_fields_std] - - if cfg.model.normalization == "min_max_scaling": - for j in range(len(fm_dict)): - print(f"Min max scaling on iteration {j}") - d_dict = fm_dict[j] - surf_fields = d_dict["surface_fields"] - if surf_fields.device.type == "cuda": - xp = cp - surf_fields = surf_fields.cuda() - surf_fields = cp.from_dlpack(surf_fields) - else: - xp = np - surf_fields = surf_fields.cpu().numpy() - - if surf_fields is not None: - surf_mean = xp.mean(surf_fields, 0) - surf_std = xp.std(surf_fields, 0) - surf_idx = mean_std_sampling( - surf_fields, surf_mean, surf_std, tolerance=12.0 - ) - surf_fields_sampled = xp.delete(surf_fields, surf_idx, axis=0) - if j == 0: - surf_fields_max = xp.amax(surf_fields_sampled, 0) - surf_fields_min = xp.amin(surf_fields_sampled, 0) - else: - surf_fields_max1 = xp.amax(surf_fields_sampled, 0) - surf_fields_min1 = xp.amin(surf_fields_sampled, 0) - - for k in range(surf_fields.shape[-1]): - if surf_fields_max1[k] > surf_fields_max[k]: - surf_fields_max[k] = surf_fields_max1[k] - - if surf_fields_min1[k] < surf_fields_min[k]: - surf_fields_min[k] = surf_fields_min1[k] - else: - surf_fields_max = 0.0 - surf_fields_min = 0.0 - - if j > max_scaling_factor_files: - break - - surf_scaling_factors = [surf_fields_max, surf_fields_min] - - for i, item in enumerate(surf_scaling_factors): - if isinstance(item, cp.ndarray): - surf_scaling_factors[i] = item.get() - - np.save(surf_save_path, surf_scaling_factors) + for i, batch in enumerate(self.dataset): + yield self.__call__(batch) + + +def compute_scaling_factors( + cfg: DictConfig, + input_path: str, + target_keys: list[str], + max_samples=20, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Using the dataset at the path, compute the mean, std, min, and max of the target keys. + + Args: + cfg: Hydra configuration object containing all parameters + input_path: Path to the dataset to load. + target_keys: List of keys to compute the mean, std, min, and max of. + use_cache: (deprecated) This argument has no effect. + """ + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + dataset = CAEDataset( + data_dir=input_path, + keys_to_read=target_keys, + keys_to_read_if_available={}, + output_device=device, + ) + + mean, std, min_val, max_val = compute_mean_std_min_max( + dataset, + field_keys=target_keys, + max_samples=max_samples, + ) + + return mean, std, min_val, max_val class CachedDoMINODataset(Dataset): @@ -1362,7 +1041,8 @@ def __getitem__(self, idx): filepath = self.data_path / cfd_filename result = np.load(filepath, allow_pickle=True).item() result = { - k: v.numpy() if isinstance(v, Tensor) else v for k, v in result.items() + k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v + for k, v in result.items() } nvtx.range_pop() @@ -1394,10 +1074,10 @@ def __getitem__(self, idx): # Sample surface points if present if "surface_mesh_centers" in result and self.surface_points: if self.surface_sampling_algorithm == "area_weighted": - coords_sampled, idx_surface = area_weighted_shuffle_array( - result["surface_mesh_centers"], - self.surface_points, - result["surface_areas"], + coords_sampled, idx_surface = shuffle_array( + points=result["surface_mesh_centers"], + n_points=self.surface_points, + weights=result["surface_areas"], ) else: coords_sampled, idx_surface = shuffle_array( @@ -1444,12 +1124,28 @@ def __getitem__(self, idx): def create_domino_dataset( - cfg, phase, volume_variable_names, surface_variable_names, vol_factors, surf_factors + cfg: DictConfig, + phase: Literal["train", "val", "test"], + keys_to_read: list[str], + keys_to_read_if_available: dict[str, torch.Tensor], + vol_factors: list[float], + surf_factors: list[float], + normalize_coordinates: bool = True, + sample_in_bbox: bool = True, + sampling: bool = True, + device_mesh: torch.distributed.DeviceMesh | None = None, + placements: dict[str, torch.distributed.tensor.Placement] | None = None, ): + model_type = cfg.model.model_type if phase == "train": input_path = cfg.data.input_dir + dataloader_cfg = cfg.train.dataloader elif phase == "val": input_path = cfg.data.input_dir_val + dataloader_cfg = cfg.val.dataloader + elif phase == "test": + input_path = cfg.eval.test_path + dataloader_cfg = None else: raise ValueError(f"Invalid phase {phase}") @@ -1457,7 +1153,7 @@ def create_domino_dataset( return CachedDoMINODataset( input_path, phase=phase, - sampling=True, + sampling=sampling, volume_points_sample=cfg.model.volume_points_sample, surface_points_sample=cfg.model.surface_points_sample, geom_points_sample=cfg.model.geom_points_sample, @@ -1465,6 +1161,15 @@ def create_domino_dataset( surface_sampling_algorithm=cfg.model.surface_sampling_algorithm, ) else: + # The dataset path works in two pieces: + # There is a core "dataset" which is loading data and moving to GPU + # And there is the preprocess step, here. + + # Optionally, and for backwards compatibility, the preprocess + # object can accept a dataset which will enable it as an iterator. + # The iteration function will loop over the dataset, preprocess the + # output, and return it. + overrides = {} if hasattr(cfg.data, "gpu_preprocessing"): overrides["gpu_preprocessing"] = cfg.data.gpu_preprocessing @@ -1472,32 +1177,60 @@ def create_domino_dataset( if hasattr(cfg.data, "gpu_output"): overrides["gpu_output"] = cfg.data.gpu_output - return DoMINODataPipe( + dm = DistributedManager() + + if cfg.data.gpu_preprocessing: + device = dm.device + consumer_stream = torch.cuda.default_stream() + else: + device = torch.device("cpu") + consumer_stream = None + + if dataloader_cfg is not None: + preload_depth = dataloader_cfg.preload_depth + pin_memory = dataloader_cfg.pin_memory + else: + preload_depth = 1 + pin_memory = False + + dataset = CAEDataset( + data_dir=input_path, + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, + output_device=device, + preload_depth=preload_depth, + pin_memory=pin_memory, + device_mesh=device_mesh, + placements=placements, + consumer_stream=consumer_stream, + ) + + datapipe = DoMINODataPipe( input_path, phase=phase, grid_resolution=cfg.model.interp_res, - volume_variables=volume_variable_names, - surface_variables=surface_variable_names, - normalize_coordinates=True, - sampling=True, - sample_in_bbox=True, + normalize_coordinates=normalize_coordinates, + sampling=sampling, + sample_in_bbox=sample_in_bbox, volume_points_sample=cfg.model.volume_points_sample, surface_points_sample=cfg.model.surface_points_sample, geom_points_sample=cfg.model.geom_points_sample, - positional_encoding=cfg.model.positional_encoding, volume_factors=vol_factors, surface_factors=surf_factors, scaling_type=cfg.model.normalization, - model_type=cfg.model.model_type, + model_type=model_type, bounding_box_dims=cfg.data.bounding_box, bounding_box_dims_surf=cfg.data.bounding_box_surface, + volume_sample_from_disk=cfg.data.volume_sample_from_disk, num_surface_neighbors=cfg.model.num_neighbors_surface, - resample_surfaces=cfg.model.resampling_surface_mesh.resample, - resampling_points=cfg.model.resampling_surface_mesh.points, surface_sampling_algorithm=cfg.model.surface_sampling_algorithm, **overrides, ) + datapipe.set_dataset(dataset) + + return datapipe + if __name__ == "__main__": fm_data = DoMINODataPipe( diff --git a/physicsnemo/datapipes/cae/domino_sharded_datapipe.py b/physicsnemo/datapipes/cae/domino_sharded_datapipe.py deleted file mode 100644 index fe2b0d5fcf..0000000000 --- a/physicsnemo/datapipes/cae/domino_sharded_datapipe.py +++ /dev/null @@ -1,176 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import asdict - -import torch - -from physicsnemo.utils.version_check import check_module_requirements - -from .domino_datapipe import DoMINODataPipe - -# Prevent importing this module if the minimum version of pytorch is not met. -check_module_requirements("physicsnemo.distributed.shard_tensor") - -from torch.distributed.tensor.placement_types import ( # noqa: E402 - Replicate, - Shard, -) - -from physicsnemo.distributed.shard_tensor import ShardTensor # noqa: E402 - - -class ShardedDoMINODataPipe(DoMINODataPipe): - """ - An extension of the DoMINODataPipe for domain parallel training. - - How this works: - 1. the preprocessing is done in cupy or numpy in the base class, which we - want to keep. - 2. Dataloading is done on one file per idx in __getitem__. For sharded data, - we want to load one file per mesh and shard or replicate the data as needed. - 3. The sharding can be either on the grid or the point clouds. We shard the grids - after loading point data, so data loading only worries about the point clouds. - 4. For numpy files (.npz, .npy), each rank reads the whole file and takes only - the data it needs, in the end. Because data loading is the bulk of the time, - this preprocesses everything independently and then shards. - 5. For Zarr files, each rank can read slices of the data independently. So - infer the chunk size, based on the number of ranks in the mesh and sharding, - and then read the right slice. - 6. For some of the pipeline, we need the full data. So it gets gathered locally. - 7. After preprocessing, the data is chunked into appropriate shards and sent out. - 8. This file provides a wrapper function for the collate function (like a decorator) - that will turn appropriate cupy into tensors and then into shard tensors. - - """ - - def __init__( - self, - input_path, - model_type, - domain_mesh, - shard_point_cloud, - shard_grid, - **config_overrides, - ): - # if 'gpu_output' not in config_overrides: - config_overrides["gpu_output"] = True - - # First, initialize the super class. - super().__init__( - input_path, - model_type, - **config_overrides, - ) - - self.domain_mesh = domain_mesh - - self.shard_point_cloud = shard_point_cloud - self.shard_grid = shard_grid - - # These are keys that are point-like - self.point_cloud_keys = [ - "volume_fields", - "pos_volume_closest", - "pos_volume_center_of_mass", - "pos_surface_center_of_mass", - "geometry_coordinates", - "surface_mesh_centers", - "surface_mesh_neighbors", - "sdf_nodes", - "surface_normals", - "surface_neighbors_normals", - "surface_areas", - "surface_neighbors_areas", - "volume_mesh_centers", - "surface_fields", - ] - - # These keys are grid-like - self.grid_keys = [ - "grid", - "surf_grid", - "sdf_grid", - "sdf_surf_grid", - ] - - # These keys are scalar-like and should never be sharded - self.scalar_keys = [ - "global_params_values", - "global_params_reference", - "surface_min_max", - "volume_min_max", - "length_scale", - ] - - def __getitem__(self, idx): - single_dict = super().__getitem__(idx) - - # Here, we're assuming that the data is already replicated. - # Turn all the pieces of the dict into ShardTensors with that placement. - default_placement = [ - Replicate(), - ] - for key, value in single_dict.items(): - if isinstance(value, torch.Tensor): - single_dict[key] = ShardTensor.from_local( - value, self.domain_mesh, default_placement - ) - - # # Now, shard the data. - sharding = [ - Shard(0), - ] - if self.shard_point_cloud: - for key in self.point_cloud_keys: - if key in single_dict: - single_dict[key] = single_dict[key].redistribute( - placements=sharding - ) - - if self.shard_grid: - for key in self.grid_keys: - if key in single_dict: - single_dict[key] = single_dict[key].redistribute( - placements=sharding - ) - - return single_dict - - -def create_sharded_domino_dataset( - base_dataset, - domain_mesh, - shard_point_cloud, - shard_grid, -): - # Pull off the data path, model type, and config_dict: - data_path = base_dataset.config.data_path - model_type = base_dataset.model_type - config_dict = asdict(base_dataset.config) - - # Make sure the input path is not included in the config_dict: - config_dict.pop("data_path") - - # Use the configuration of the base dataset to create a sharded dataset: - return ShardedDoMINODataPipe( - input_path=data_path, - model_type=model_type, - domain_mesh=domain_mesh, - shard_point_cloud=shard_point_cloud, - shard_grid=shard_grid, - **config_dict, - ) diff --git a/physicsnemo/models/domino/encodings.py b/physicsnemo/models/domino/encodings.py new file mode 100644 index 0000000000..7b27eeb134 --- /dev/null +++ b/physicsnemo/models/domino/encodings.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code contains the DoMINO model architecture. +The DoMINO class contains an architecture to model both surface and +volume quantities together as well as separately (controlled using +the config.yaml file) +""" + +import torch +import torch.nn as nn +from einops import rearrange + +from physicsnemo.models.layers import BQWarp + +from .mlps import LocalPointConv + + +class LocalGeometryEncoding(nn.Module): + """ + A local geometry encoding module. + + This will apply a ball query to the input features, mapping the point cloud + to the volume mesh, and then apply a local point convolution to the output. + + Args: + radius: The radius of the ball query. + neighbors_in_radius: The number of neighbors in the radius of the ball query. + total_neighbors_in_radius: The total number of neighbors in the radius of the ball query. + base_layer: The number of neurons in the hidden layer of the MLP. + activation: The activation function to use in the MLP. + grid_resolution: The resolution of the grid. + """ + + def __init__( + self, + radius: float, + neighbors_in_radius: int, + total_neighbors_in_radius: int, + base_layer: int, + activation: nn.Module, + grid_resolution: tuple[int, int, int], + ): + super().__init__() + self.bq_warp = BQWarp( + radius=radius, + neighbors_in_radius=neighbors_in_radius, + ) + self.local_point_conv = LocalPointConv( + input_features=total_neighbors_in_radius, + base_layer=base_layer, + output_features=neighbors_in_radius, + activation=activation, + ) + self.grid_resolution = grid_resolution + + def forward( + self, + encoding_g: torch.Tensor, + volume_mesh_centers: torch.Tensor, + p_grid: torch.Tensor, + ) -> torch.Tensor: + batch_size = volume_mesh_centers.shape[0] + nx, ny, nz = self.grid_resolution + + p_grid = torch.reshape(p_grid, (batch_size, nx * ny * nz, 3)) + mapping, outputs = self.bq_warp( + volume_mesh_centers, p_grid, reverse_mapping=False + ) + + mapping = mapping.type(torch.int64) + mask = mapping != 0 + + encoding_g_inner = [] + for j in range(encoding_g.shape[1]): + geo_encoding = rearrange(encoding_g[:, j], "b nx ny nz -> b 1 (nx ny nz)") + + geo_encoding_sampled = torch.index_select( + geo_encoding, 2, mapping.flatten() + ) + geo_encoding_sampled = torch.reshape(geo_encoding_sampled, mask.shape) + geo_encoding_sampled = geo_encoding_sampled * mask + + encoding_g_inner.append(geo_encoding_sampled) + encoding_g_inner = torch.cat(encoding_g_inner, dim=2) + encoding_g_inner = self.local_point_conv(encoding_g_inner) + + return encoding_g_inner + + +class MultiGeometryEncoding(nn.Module): + """ + Module to apply multiple local geometry encodings + + This will stack several local geometry encodings together, and concatenate the results. + + Args: + radii: The list of radii of the local geometry encodings. + neighbors_in_radius: The list of number of neighbors in the radius of the local geometry encodings. + geo_encoding_type: The type of geometry encoding to use. Can be "both", "stl", or "sdf". + base_layer: The number of neurons in the hidden layer of the MLP. + activation: The activation function to use in the MLP. + grid_resolution: The resolution of the grid. + """ + + def __init__( + self, + radii: list[float], + neighbors_in_radius: list[int], + geo_encoding_type: str, + n_upstream_radii: int, + base_layer: int, + activation: nn.Module, + grid_resolution: tuple[int, int, int], + ): + super().__init__() + + self.local_geo_encodings = nn.ModuleList( + [ + LocalGeometryEncoding( + radius=r, + neighbors_in_radius=n, + total_neighbors_in_radius=self.calculate_total_neighbors_in_radius( + geo_encoding_type, n, n_upstream_radii + ), + base_layer=base_layer, + activation=activation, + grid_resolution=grid_resolution, + ) + for r, n in zip(radii, neighbors_in_radius) + ] + ) + + def calculate_total_neighbors_in_radius( + self, geo_encoding_type: str, neighbors_in_radius: int, n_upstream_radii: int + ) -> int: + if geo_encoding_type == "both": + total_neighbors_in_radius = neighbors_in_radius * (n_upstream_radii + 1) + elif geo_encoding_type == "stl": + total_neighbors_in_radius = neighbors_in_radius * (n_upstream_radii) + elif geo_encoding_type == "sdf": + total_neighbors_in_radius = neighbors_in_radius + + return total_neighbors_in_radius + + def forward( + self, + encoding_g: torch.Tensor, + volume_mesh_centers: torch.Tensor, + p_grid: torch.Tensor, + ) -> torch.Tensor: + return torch.cat( + [ + local_geo_encoding(encoding_g, volume_mesh_centers, p_grid) + for local_geo_encoding in self.local_geo_encodings + ], + dim=-1, + ) diff --git a/physicsnemo/models/domino/geometry_rep.py b/physicsnemo/models/domino/geometry_rep.py new file mode 100644 index 0000000000..eee192e600 --- /dev/null +++ b/physicsnemo/models/domino/geometry_rep.py @@ -0,0 +1,499 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from physicsnemo.models.layers import BQWarp, Mlp, fourier_encode, get_activation +from physicsnemo.models.unet import UNet + +# from .encodings import fourier_encode + + +def scale_sdf(sdf: torch.Tensor, scaling_factor: float = 0.04) -> torch.Tensor: + """ + Scale a signed distance function (SDF) to emphasize surface regions. + + This function applies a non-linear scaling to the SDF values that compresses + the range while preserving the sign, effectively giving more weight to points + near surfaces where abs(SDF) is small. + + Args: + sdf: Tensor containing signed distance function values + + Returns: + Tensor with scaled SDF values in range [-1, 1] + """ + return sdf / (scaling_factor + torch.abs(sdf)) + + +class GeoConvOut(nn.Module): + """ + Geometry layer to project STL geometry data onto regular grids. + """ + + def __init__( + self, + input_features: int, + neighbors_in_radius: int, + model_parameters, + grid_resolution=None, + ): + """ + Initialize the GeoConvOut layer. + + Args: + input_features: Number of input feature dimensions + neighbors_in_radius: Number of neighbors in radius + model_parameters: Configuration parameters for the model + grid_resolution: Resolution of the output grid [nx, ny, nz] + """ + super().__init__() + if grid_resolution is None: + grid_resolution = [256, 96, 64] + base_neurons = model_parameters.base_neurons + self.fourier_features = model_parameters.fourier_features + self.num_modes = model_parameters.num_modes + + if self.fourier_features: + input_features_calculated = ( + input_features * (1 + 2 * self.num_modes) * neighbors_in_radius + ) + else: + input_features_calculated = input_features * neighbors_in_radius + + self.mlp = Mlp( + in_features=input_features_calculated, + hidden_features=[base_neurons, base_neurons // 2], + out_features=model_parameters.base_neurons_in, + act_layer=get_activation(model_parameters.activation), + drop=0.0, + ) + + self.grid_resolution = grid_resolution + + self.activation = get_activation(model_parameters.activation) + + self.neighbors_in_radius = neighbors_in_radius + + if self.fourier_features: + self.register_buffer( + "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes)) + ) + + def forward( + self, + x: torch.Tensor, + grid: torch.Tensor, + radius: float = 0.025, + neighbors_in_radius: int = 10, + ) -> torch.Tensor: + """ + Process and project geometric features onto a 3D grid. + + Args: + x: Input tensor containing coordinates of the neighboring points + (batch_size, nx*ny*nz, n_points, 3) + grid: Input tensor represented as a grid of shape + (batch_size, nx, ny, nz, 3) + + Returns: + Processed geometry features of shape (batch_size, base_neurons_in, nx, ny, nz) + """ + + nx, ny, nz = ( + self.grid_resolution[0], + self.grid_resolution[1], + self.grid_resolution[2], + ) + grid = grid.reshape(1, nx * ny * nz, 3, 1) + + x = rearrange( + x, "b x y z -> b x (y z)", x=nx * ny * nz, y=self.neighbors_in_radius, z=3 + ) + if self.fourier_features: + facets = torch.cat((x, fourier_encode(x, self.freqs)), axis=-1) + else: + facets = x + + x = F.tanh(self.mlp(facets)) + + x = rearrange(x, "b (x y z) c -> b c x y z", x=nx, y=ny, z=nz) + + return x + + +class GeoProcessor(nn.Module): + """Geometry processing layer using CNNs""" + + def __init__(self, input_filters: int, output_filters: int, model_parameters): + """ + Initialize the GeoProcessor network. + + Args: + input_filters: Number of input channels + model_parameters: Configuration parameters for the model + """ + super().__init__() + base_filters = model_parameters.base_filters + self.conv1 = nn.Conv3d( + input_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv2 = nn.Conv3d( + base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv3 = nn.Conv3d( + 2 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv3_1 = nn.Conv3d( + 4 * base_filters, 4 * base_filters, kernel_size=3, padding="same" + ) + self.conv4 = nn.Conv3d( + 4 * base_filters, 2 * base_filters, kernel_size=3, padding="same" + ) + self.conv5 = nn.Conv3d( + 4 * base_filters, base_filters, kernel_size=3, padding="same" + ) + self.conv6 = nn.Conv3d( + 2 * base_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv7 = nn.Conv3d( + 2 * input_filters, input_filters, kernel_size=3, padding="same" + ) + self.conv8 = nn.Conv3d( + input_filters, output_filters, kernel_size=3, padding="same" + ) + self.avg_pool = torch.nn.AvgPool3d((2, 2, 2)) + self.max_pool = nn.MaxPool3d(2) + self.upsample = nn.Upsample(scale_factor=2, mode="nearest") + self.activation = get_activation(model_parameters.activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Process geometry information through the 3D CNN network. + + The network follows an encoder-decoder architecture with skip connections: + 1. Downsampling path (encoder) with three levels of max pooling + 2. Processing loop in the bottleneck + 3. Upsampling path (decoder) with skip connections from the encoder + + Args: + x: Input tensor containing grid-represented geometry of shape + (batch_size, input_filters, nx, ny, nz) + + Returns: + Processed geometry features of shape (batch_size, 1, nx, ny, nz) + """ + # Encoder + x0 = x + x = self.conv1(x) + x = self.activation(x) + x = self.max_pool(x) + + x1 = x + x = self.conv2(x) + x = self.activation(x) + x = self.max_pool(x) + + x2 = x + x = self.conv3(x) + x = self.activation(x) + x = self.max_pool(x) + + # Processor loop + x = self.activation(self.conv3_1(x)) + + # Decoder + x = self.conv4(x) + x = self.activation(x) + x = self.upsample(x) + x = torch.cat((x, x2), dim=1) + + x = self.conv5(x) + x = self.activation(x) + x = self.upsample(x) + x = torch.cat((x, x1), dim=1) + + x = self.conv6(x) + x = self.activation(x) + x = self.upsample(x) + x = torch.cat((x, x0), dim=1) + + x = self.activation(self.conv7(x)) + x = self.conv8(x) + + return x + + +class GeometryRep(nn.Module): + """ + Geometry representation module that processes STL geometry data. + + This module constructs a multiscale representation of geometry by: + 1. Computing multi-scale geometry encoding for local and global context + 2. Processing signed distance field (SDF) data for surface information + + The combined encoding enables the model to reason about both local and global + geometric properties. + """ + + def __init__( + self, + input_features: int, + radii: Sequence[float], + neighbors_in_radius, + hops=1, + sdf_scaling_factor: Sequence[float] = [0.04], + model_parameters=None, + # activation_conv: nn.Module, + # activation_processor: nn.Module, + ): + """ + Initialize the GeometryRep module. + + Args: + input_features: Number of input feature dimensions + model_parameters: Configuration parameters for the model + """ + super().__init__() + geometry_rep = model_parameters.geometry_rep + self.geo_encoding_type = model_parameters.geometry_encoding_type + self.cross_attention = geometry_rep.geo_processor.cross_attention + self.self_attention = geometry_rep.geo_processor.self_attention + self.activation_conv = get_activation(geometry_rep.geo_conv.activation) + self.activation_processor = geometry_rep.geo_processor.activation + self.sdf_scaling_factor = sdf_scaling_factor + + self.bq_warp = nn.ModuleList() + self.geo_processors = nn.ModuleList() + for j in range(len(radii)): + self.bq_warp.append( + BQWarp( + radius=radii[j], + neighbors_in_radius=neighbors_in_radius[j], + ) + ) + if geometry_rep.geo_processor.processor_type == "unet": + h = geometry_rep.geo_processor.base_filters + if self.self_attention: + normalization_in_unet = "layernorm" + else: + normalization_in_unet = None + self.geo_processors.append( + UNet( + in_channels=geometry_rep.geo_conv.base_neurons_in, + out_channels=geometry_rep.geo_conv.base_neurons_out, + model_depth=3, + feature_map_channels=[ + h, + 2 * h, + 4 * h, + ], + num_conv_blocks=1, + kernel_size=3, + stride=1, + conv_activation=self.activation_processor, + padding=1, + padding_mode="zeros", + pooling_type="MaxPool3d", + pool_size=2, + normalization=normalization_in_unet, + use_attn_gate=self.self_attention, + attn_decoder_feature_maps=[4 * h, 2 * h], + attn_feature_map_channels=[2 * h, h], + attn_intermediate_channels=4 * h, + gradient_checkpointing=True, + ) + ) + elif geometry_rep.geo_processor.processor_type == "conv": + self.geo_processors.append( + nn.Sequential( + GeoProcessor( + input_filters=geometry_rep.geo_conv.base_neurons_in, + output_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ), + ) + ) + else: + raise ValueError("Invalid prompt. Specify unet or conv ...") + + self.geo_conv_out = nn.ModuleList() + self.geo_processor_out = nn.ModuleList() + for u in range(len(radii)): + self.geo_conv_out.append( + GeoConvOut( + input_features=input_features, + neighbors_in_radius=neighbors_in_radius[u], + model_parameters=geometry_rep.geo_conv, + grid_resolution=model_parameters.interp_res, + ) + ) + self.geo_processor_out.append( + nn.Conv3d( + geometry_rep.geo_conv.base_neurons_out, + 1, + kernel_size=3, + padding="same", + ) + ) + + if geometry_rep.geo_processor.processor_type == "unet": + h = geometry_rep.geo_processor.base_filters + if self.self_attention: + normalization_in_unet = "layernorm" + else: + normalization_in_unet = None + + self.geo_processor_sdf = UNet( + in_channels=5 + len(self.sdf_scaling_factor), + out_channels=geometry_rep.geo_conv.base_neurons_out, + model_depth=3, + feature_map_channels=[ + h, + 2 * h, + 4 * h, + ], + num_conv_blocks=1, + kernel_size=3, + stride=1, + conv_activation=self.activation_processor, + padding=1, + padding_mode="zeros", + pooling_type="MaxPool3d", + pool_size=2, + normalization=normalization_in_unet, + use_attn_gate=self.self_attention, + attn_decoder_feature_maps=[4 * h, 2 * h], + attn_feature_map_channels=[2 * h, h], + attn_intermediate_channels=4 * h, + gradient_checkpointing=True, + ) + elif geometry_rep.geo_processor.processor_type == "conv": + self.geo_processor_sdf = nn.Sequential( + GeoProcessor( + input_filters=5 + len(self.sdf_scaling_factor), + output_filters=geometry_rep.geo_conv.base_neurons_out, + model_parameters=geometry_rep.geo_processor, + ), + ) + else: + raise ValueError("Invalid prompt. Specify unet or conv ...") + self.radii = radii + self.neighbors_in_radius = neighbors_in_radius + self.hops = hops + + self.geo_processor_sdf_out = nn.Conv3d( + geometry_rep.geo_conv.base_neurons_out, 1, kernel_size=3, padding="same" + ) + + if self.cross_attention: + self.combined_unet = UNet( + in_channels=1 + len(radii), + out_channels=1 + len(radii), + model_depth=3, + feature_map_channels=[ + h, + 2 * h, + 4 * h, + ], + num_conv_blocks=1, + kernel_size=3, + stride=1, + conv_activation=self.activation_processor, + padding=1, + padding_mode="zeros", + pooling_type="MaxPool3d", + pool_size=2, + normalization="layernorm", + use_attn_gate=True, + attn_decoder_feature_maps=[4 * h, 2 * h], + attn_feature_map_channels=[2 * h, h], + attn_intermediate_channels=4 * h, + gradient_checkpointing=True, + ) + + def forward( + self, x: torch.Tensor, p_grid: torch.Tensor, sdf: torch.Tensor + ) -> torch.Tensor: + """ + Process geometry data to create a comprehensive representation. + + This method combines short-range, long-range, and SDF-based geometry + encodings to create a rich representation of the geometry. + + Args: + x: Input tensor containing geometric point data + p_grid: Grid points for sampling + sdf: Signed distance field tensor + + Returns: + Comprehensive geometry encoding that concatenates short-range, + SDF-based, and long-range features + """ + if self.geo_encoding_type == "both" or self.geo_encoding_type == "stl": + # Calculate multi-scale geoemtry dependency + x_encoding = [] + for j in range(len(self.radii)): + mapping, k_short = self.bq_warp[j](x, p_grid) + x_encoding_inter = self.geo_conv_out[j](k_short, p_grid) + # Propagate information in the geometry enclosed BBox + for _ in range(self.hops): + dx = self.geo_processors[j](x_encoding_inter) / self.hops + x_encoding_inter = x_encoding_inter + dx + x_encoding_inter = self.geo_processor_out[j](x_encoding_inter) + x_encoding.append(x_encoding_inter) + x_encoding = torch.cat(x_encoding, dim=1) + + if self.geo_encoding_type == "both" or self.geo_encoding_type == "sdf": + # Expand SDF + sdf = torch.unsqueeze(sdf, 1) + # Binary sdf + binary_sdf = torch.where(sdf >= 0, 0.0, 1.0) + # Gradients of SDF + sdf_x, sdf_y, sdf_z = torch.gradient(sdf, dim=[2, 3, 4]) + + scaled_sdf = [] + # Scaled sdf to emphasize near surface + for s in range(len(self.sdf_scaling_factor)): + s_sdf = scale_sdf(sdf, self.sdf_scaling_factor[s]) + scaled_sdf.append(s_sdf) + + scaled_sdf = torch.cat(scaled_sdf, dim=1) + + # Process SDF and its computed features + sdf = torch.cat((sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1) + + sdf_encoding = self.geo_processor_sdf(sdf) + sdf_encoding = self.geo_processor_sdf_out(sdf_encoding) + + if self.geo_encoding_type == "both": + # Geometry encoding comprised of short-range, long-range and SDF features + encoding_g = torch.cat((x_encoding, sdf_encoding), 1) + elif self.geo_encoding_type == "sdf": + encoding_g = sdf_encoding + elif self.geo_encoding_type == "stl": + encoding_g = x_encoding + + if self.cross_attention: + encoding_g = self.combined_unet(encoding_g) + + return encoding_g diff --git a/physicsnemo/models/domino/mlps.py b/physicsnemo/models/domino/mlps.py new file mode 100644 index 0000000000..f074fa7735 --- /dev/null +++ b/physicsnemo/models/domino/mlps.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains specific MLPs for the DoMINO model. + +The main feature here is we've locked in the number of layers. +""" + +import torch.nn as nn + +from physicsnemo.models.layers import Mlp + + +class AggregationModel(Mlp): + """ + Neural network module to aggregate local geometry encoding with basis functions. + + This module combines basis function representations with geometry encodings + to predict the final output quantities. It serves as the final prediction layer + that integrates all available information sources. + + It is implemented as a straightforward MLP with 5 total layers. + + """ + + def __init__( + self, + input_features: int, + output_features: int, + base_layer: int, + activation: nn.Module, + ): + hidden_features = [base_layer, base_layer, base_layer, base_layer] + + super().__init__( + in_features=input_features, + hidden_features=hidden_features, + out_features=output_features, + act_layer=activation, + drop=0.0, + ) + + +class LocalPointConv(Mlp): + """Layer for local geometry point kernel + + This is a straight forward MLP, with exactly two layers. + """ + + def __init__( + self, + input_features: int, + base_layer: int, + output_features: int, + activation: nn.Module, + ): + super().__init__( + in_features=input_features, + hidden_features=base_layer, + out_features=output_features, + act_layer=activation, + drop=0.0, + ) diff --git a/physicsnemo/models/domino/model.py b/physicsnemo/models/domino/model.py index c95f971e97..9f46947f2b 100644 --- a/physicsnemo/models/domino/model.py +++ b/physicsnemo/models/domino/model.py @@ -21,857 +21,18 @@ the config.yaml file) """ -import math -from collections import defaultdict -from typing import Callable, Literal, Sequence - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - -from physicsnemo.models.unet import UNet -from physicsnemo.utils.neighbors import radius_search -from physicsnemo.utils.profiling import profile - - -def get_activation(activation: Literal["relu", "gelu"]) -> Callable: - """ - Return a PyTorch activation function corresponding to the given name. - """ - if activation == "relu": - return F.relu - elif activation == "gelu": - return F.gelu - else: - raise ValueError(f"Activation function {activation} not found") - - -def fourier_encode(coords, num_freqs): - """Function to caluculate fourier features""" - # Create a range of frequencies - freqs = torch.exp(torch.linspace(0, math.pi, num_freqs, device=coords.device)) - # Generate sine and cosine features - features = [torch.sin(coords * f) for f in freqs] + [ - torch.cos(coords * f) for f in freqs - ] - ret = torch.cat(features, dim=-1) - return ret - - -def fourier_encode_vectorized(coords, freqs): - """Vectorized Fourier feature encoding""" - D = coords.shape[-1] - F = freqs.shape[0] - - freqs = freqs[None, None, :, None] # reshape to [*, F, 1] for broadcasting - - coords = coords.unsqueeze(-2) # [*, 1, D] - scaled = (coords * freqs).reshape(*coords.shape[:-2], D * F) # [*, D, F] - features = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1) # [*, D, 2F] - - return features.reshape(*coords.shape[:-2], D * 2 * F) # [*, D * 2F] - - -def calculate_pos_encoding(nx, d=8): - """Function to caluculate positional encoding""" - vec = [] - for k in range(int(d / 2)): - vec.append(torch.sin(nx / 10000 ** (2 * (k) / d))) - vec.append(torch.cos(nx / 10000 ** (2 * (k) / d))) - return vec - - -def scale_sdf(sdf: torch.Tensor) -> torch.Tensor: - """ - Scale a signed distance function (SDF) to emphasize surface regions. - - This function applies a non-linear scaling to the SDF values that compresses - the range while preserving the sign, effectively giving more weight to points - near surfaces where abs(SDF) is small. - - Args: - sdf: Tensor containing signed distance function values - - Returns: - Tensor with scaled SDF values in range [-1, 1] - """ - return sdf / (0.4 + torch.abs(sdf)) - - -class BQWarp(nn.Module): - """ - Warp-based ball-query layer for finding neighboring points within a specified radius. - - This layer uses an accelerated ball query implementation to efficiently find points - within a specified radius of query points. - """ - - def __init__( - self, - grid_resolution=None, - radius: float = 0.25, - neighbors_in_radius: int = 10, - ): - """ - Initialize the BQWarp layer. - - Args: - grid_resolution: Resolution of the grid in each dimension [nx, ny, nz] - radius: Radius for ball query operation - neighbors_in_radius: Maximum number of neighbors to return within radius - """ - super().__init__() - if grid_resolution is None: - grid_resolution = [256, 96, 64] - - self.radius = radius - self.neighbors_in_radius = neighbors_in_radius - self.grid_resolution = grid_resolution - - def forward( - self, x: torch.Tensor, p_grid: torch.Tensor, reverse_mapping: bool = True - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Performs ball query operation to find neighboring points and their features. - - This method uses the Warp-accelerated ball query implementation to find points - within a specified radius. It can operate in two modes: - - Forward mapping: Find points from x that are near p_grid points (reverse_mapping=False) - - Reverse mapping: Find points from p_grid that are near x points (reverse_mapping=True) - - Args: - x: Tensor of shape (batch_size, num_points, 3+features) containing point coordinates - and their features - p_grid: Tensor of shape (batch_size, grid_x, grid_y, grid_z, 3) containing grid point - coordinates - reverse_mapping: Boolean flag to control the direction of the mapping: - - True: Find p_grid points near x points - - False: Find x points near p_grid points - - Returns: - tuple containing: - - mapping: Tensor containing indices of neighboring points - - outputs: Tensor containing coordinates of the neighboring points - """ - batch_size = x.shape[0] - nx, ny, nz = self.grid_resolution - - p_grid = torch.reshape(p_grid, (batch_size, nx * ny * nz, 3)) - - if reverse_mapping: - mapping, outputs = radius_search( - x[0], - p_grid[0], - self.radius, - self.neighbors_in_radius, - return_points=True, - ) - mapping = mapping.unsqueeze(0) - outputs = outputs.unsqueeze(0) - else: - mapping, outputs = radius_search( - p_grid[0], - x[0], - self.radius, - self.neighbors_in_radius, - return_points=True, - ) - mapping = mapping.unsqueeze(0) - outputs = outputs.unsqueeze(0) - - return mapping, outputs - - -class GeoConvOut(nn.Module): - """ - Geometry layer to project STL geometry data onto regular grids. - """ - - def __init__( - self, - input_features: int, - model_parameters, - grid_resolution=None, - ): - """ - Initialize the GeoConvOut layer. - - Args: - input_features: Number of input feature dimensions - model_parameters: Configuration parameters for the model - grid_resolution: Resolution of the output grid [nx, ny, nz] - """ - super().__init__() - if grid_resolution is None: - grid_resolution = [256, 96, 64] - base_neurons = model_parameters.base_neurons - self.fourier_features = model_parameters.fourier_features - self.num_modes = model_parameters.num_modes - - if self.fourier_features: - input_features_calculated = input_features * (1 + 2 * self.num_modes) - else: - input_features_calculated = input_features - - self.fc1 = nn.Linear(input_features_calculated, base_neurons) - self.fc2 = nn.Linear(base_neurons, base_neurons // 2) - self.fc3 = nn.Linear(base_neurons // 2, model_parameters.base_neurons_in) - - self.grid_resolution = grid_resolution - - self.activation = get_activation(model_parameters.activation) - - if self.fourier_features: - self.register_buffer( - "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes)) - ) - - def forward( - self, - x: torch.Tensor, - grid: torch.Tensor, - radius: float = 0.025, - neighbors_in_radius: int = 10, - ) -> torch.Tensor: - """ - Process and project geometric features onto a 3D grid. - - Args: - x: Input tensor containing coordinates of the neighboring points - (batch_size, nx*ny*nz, 3, n_points) - grid: Input tensor represented as a grid of shape - (batch_size, nx, ny, nz, 3) - - Returns: - Processed geometry features of shape (batch_size, base_neurons_in, nx, ny, nz) - """ - - nx, ny, nz = ( - self.grid_resolution[0], - self.grid_resolution[1], - self.grid_resolution[2], - ) - grid = grid.reshape(1, nx * ny * nz, 3, 1) - x_transposed = torch.transpose(x, 2, 3) - dist_weights = 1.0 / (1e-6 + (x_transposed - grid) ** 2.0) - dist_weights = torch.transpose(dist_weights, 2, 3) - - # x = torch.sum(x * dist_weights, 2) / torch.sum(dist_weights, 2) - # x = torch.sum(x, 2) - mask = abs(x - 0) > 1e-6 - if self.fourier_features: - facets = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), axis=-1) - else: - facets = x - x = self.activation(self.fc1(facets)) - x = self.activation(self.fc2(x)) - x = F.tanh(self.fc3(x)) - - mask = mask[:, :, :, 0:1].expand( - mask.shape[0], mask.shape[1], mask.shape[2], x.shape[-1] - ) - - x = torch.sum(x * mask, 2) - x = rearrange(x, "b (x y z) c -> b c x y z", x=nx, y=ny, z=nz) - return x - - -class GeoProcessor(nn.Module): - """Geometry processing layer using CNNs""" - - def __init__(self, input_filters: int, output_filters: int, model_parameters): - """ - Initialize the GeoProcessor network. - - Args: - input_filters: Number of input channels - model_parameters: Configuration parameters for the model - """ - super().__init__() - base_filters = model_parameters.base_filters - self.conv1 = nn.Conv3d( - input_filters, base_filters, kernel_size=3, padding="same" - ) - self.conv2 = nn.Conv3d( - base_filters, 2 * base_filters, kernel_size=3, padding="same" - ) - self.conv3 = nn.Conv3d( - 2 * base_filters, 4 * base_filters, kernel_size=3, padding="same" - ) - self.conv3_1 = nn.Conv3d( - 4 * base_filters, 4 * base_filters, kernel_size=3, padding="same" - ) - self.conv4 = nn.Conv3d( - 4 * base_filters, 2 * base_filters, kernel_size=3, padding="same" - ) - self.conv5 = nn.Conv3d( - 4 * base_filters, base_filters, kernel_size=3, padding="same" - ) - self.conv6 = nn.Conv3d( - 2 * base_filters, input_filters, kernel_size=3, padding="same" - ) - self.conv7 = nn.Conv3d( - 2 * input_filters, input_filters, kernel_size=3, padding="same" - ) - self.conv8 = nn.Conv3d( - input_filters, output_filters, kernel_size=3, padding="same" - ) - self.avg_pool = torch.nn.AvgPool3d((2, 2, 2)) - self.max_pool = nn.MaxPool3d(2) - self.upsample = nn.Upsample(scale_factor=2, mode="nearest") - self.activation = get_activation(model_parameters.activation) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Process geometry information through the 3D CNN network. - - The network follows an encoder-decoder architecture with skip connections: - 1. Downsampling path (encoder) with three levels of max pooling - 2. Processing loop in the bottleneck - 3. Upsampling path (decoder) with skip connections from the encoder - - Args: - x: Input tensor containing grid-represented geometry of shape - (batch_size, input_filters, nx, ny, nz) - - Returns: - Processed geometry features of shape (batch_size, 1, nx, ny, nz) - """ - # Encoder - x0 = x - x = self.conv1(x) - x = self.activation(x) - x = self.max_pool(x) - - x1 = x - x = self.conv2(x) - x = self.activation(x) - x = self.max_pool(x) - - x2 = x - x = self.conv3(x) - x = self.activation(x) - x = self.max_pool(x) - - # Processor loop - x = self.activation(self.conv3_1(x)) - - # Decoder - x = self.conv4(x) - x = self.activation(x) - x = self.upsample(x) - x = torch.cat((x, x2), dim=1) - - x = self.conv5(x) - x = self.activation(x) - x = self.upsample(x) - x = torch.cat((x, x1), dim=1) - - x = self.conv6(x) - x = self.activation(x) - x = self.upsample(x) - x = torch.cat((x, x0), dim=1) - - x = self.activation(self.conv7(x)) - x = self.conv8(x) - - return x - - -class GeometryRep(nn.Module): - """ - Geometry representation module that processes STL geometry data. - - This module constructs a multiscale representation of geometry by: - 1. Computing multi-scale geometry encoding for local and global context - 2. Processing signed distance field (SDF) data for surface information - - The combined encoding enables the model to reason about both local and global - geometric properties. - """ - - def __init__( - self, - input_features: int, - radii: Sequence[float], - neighbors_in_radius, - hops=1, - model_parameters=None, - ): - """ - Initialize the GeometryRep module. - - Args: - input_features: Number of input feature dimensions - model_parameters: Configuration parameters for the model - """ - super().__init__() - geometry_rep = model_parameters.geometry_rep - self.geo_encoding_type = model_parameters.geometry_encoding_type - self.cross_attention = geometry_rep.geo_processor.cross_attention - self.self_attention = geometry_rep.geo_processor.self_attention - self.activation_conv = get_activation(geometry_rep.geo_conv.activation) - self.activation_processor = geometry_rep.geo_processor.activation - - self.bq_warp = nn.ModuleList() - self.geo_processors = nn.ModuleList() - for j in range(len(radii)): - self.bq_warp.append( - BQWarp( - grid_resolution=model_parameters.interp_res, - radius=radii[j], - neighbors_in_radius=neighbors_in_radius[j], - ) - ) - if geometry_rep.geo_processor.processor_type == "unet": - h = geometry_rep.geo_processor.base_filters - if self.self_attention: - normalization_in_unet = "layernorm" - else: - normalization_in_unet = None - self.geo_processors.append( - UNet( - in_channels=geometry_rep.geo_conv.base_neurons_in, - out_channels=geometry_rep.geo_conv.base_neurons_out, - model_depth=3, - feature_map_channels=[ - h, - 2 * h, - 4 * h, - ], - num_conv_blocks=1, - kernel_size=3, - stride=1, - conv_activation=self.activation_processor, - padding=1, - padding_mode="zeros", - pooling_type="MaxPool3d", - pool_size=2, - normalization=normalization_in_unet, - use_attn_gate=self.self_attention, - attn_decoder_feature_maps=[4 * h, 2 * h], - attn_feature_map_channels=[2 * h, h], - attn_intermediate_channels=4 * h, - gradient_checkpointing=True, - ) - ) - elif geometry_rep.geo_processor.processor_type == "conv": - self.geo_processors.append( - nn.Sequential( - GeoProcessor( - input_filters=geometry_rep.geo_conv.base_neurons_in, - output_filters=geometry_rep.geo_conv.base_neurons_out, - model_parameters=geometry_rep.geo_processor, - ), - GeoProcessor( - input_filters=geometry_rep.geo_conv.base_neurons_in, - output_filters=geometry_rep.geo_conv.base_neurons_out, - model_parameters=geometry_rep.geo_processor, - ), - ) - ) - else: - raise ValueError("Invalid prompt. Specify unet or conv ...") - - self.geo_conv_out = nn.ModuleList() - self.geo_processor_out = nn.ModuleList() - for _ in range(len(radii)): - self.geo_conv_out.append( - GeoConvOut( - input_features=input_features, - model_parameters=geometry_rep.geo_conv, - grid_resolution=model_parameters.interp_res, - ) - ) - self.geo_processor_out.append( - nn.Conv3d( - geometry_rep.geo_conv.base_neurons_out, - 1, - kernel_size=3, - padding="same", - ) - ) - - if geometry_rep.geo_processor.processor_type == "unet": - h = geometry_rep.geo_processor.base_filters - if self.self_attention: - normalization_in_unet = "layernorm" - else: - normalization_in_unet = None - self.geo_processor_sdf = UNet( - in_channels=6, - out_channels=geometry_rep.geo_conv.base_neurons_out, - model_depth=3, - feature_map_channels=[ - h, - 2 * h, - 4 * h, - ], - num_conv_blocks=1, - kernel_size=3, - stride=1, - conv_activation=self.activation_processor, - padding=1, - padding_mode="zeros", - pooling_type="MaxPool3d", - pool_size=2, - normalization=normalization_in_unet, - use_attn_gate=self.self_attention, - attn_decoder_feature_maps=[4 * h, 2 * h], - attn_feature_map_channels=[2 * h, h], - attn_intermediate_channels=4 * h, - gradient_checkpointing=True, - ) - elif geometry_rep.geo_processor.processor_type == "conv": - self.geo_processor_sdf = nn.Sequential( - GeoProcessor( - input_filters=6, - output_filters=geometry_rep.geo_conv.base_neurons_out, - model_parameters=geometry_rep.geo_processor, - ), - GeoProcessor( - input_filters=geometry_rep.geo_conv.base_neurons_out, - output_filters=geometry_rep.geo_conv.base_neurons_out, - model_parameters=geometry_rep.geo_processor, - ), - ) - else: - raise ValueError("Invalid prompt. Specify unet or conv ...") - self.radii = radii - self.hops = hops - - self.geo_processor_sdf_out = nn.Conv3d( - geometry_rep.geo_conv.base_neurons_out, 1, kernel_size=3, padding="same" - ) - - if self.cross_attention: - self.combined_unet = UNet( - in_channels=1 + len(radii), - out_channels=1 + len(radii), - model_depth=3, - feature_map_channels=[ - h, - 2 * h, - 4 * h, - ], - num_conv_blocks=1, - kernel_size=3, - stride=1, - conv_activation=self.activation_processor, - padding=1, - padding_mode="zeros", - pooling_type="MaxPool3d", - pool_size=2, - normalization="layernorm", - use_attn_gate=True, - attn_decoder_feature_maps=[4 * h, 2 * h], - attn_feature_map_channels=[2 * h, h], - attn_intermediate_channels=4 * h, - gradient_checkpointing=True, - ) - - def forward( - self, x: torch.Tensor, p_grid: torch.Tensor, sdf: torch.Tensor - ) -> torch.Tensor: - """ - Process geometry data to create a comprehensive representation. - - This method combines short-range, long-range, and SDF-based geometry - encodings to create a rich representation of the geometry. - - Args: - x: Input tensor containing geometric point data - p_grid: Grid points for sampling - sdf: Signed distance field tensor - - Returns: - Comprehensive geometry encoding that concatenates short-range, - SDF-based, and long-range features - """ - if self.geo_encoding_type == "both" or self.geo_encoding_type == "stl": - # Calculate multi-scale geoemtry dependency - x_encoding = [] - for j in range(len(self.radii)): - mapping, k_short = self.bq_warp[j](x, p_grid) - x_encoding_inter = self.geo_conv_out[j](k_short, p_grid) - # Propagate information in the geometry enclosed BBox - for _ in range(self.hops): - dx = self.geo_processors[j](x_encoding_inter) / self.hops - x_encoding_inter = x_encoding_inter + dx - x_encoding_inter = self.geo_processor_out[j](x_encoding_inter) - x_encoding.append(x_encoding_inter) - x_encoding = torch.cat(x_encoding, dim=1) - - if self.geo_encoding_type == "both" or self.geo_encoding_type == "sdf": - # Expand SDF - sdf = torch.unsqueeze(sdf, 1) - # Scaled sdf to emphasize near surface - scaled_sdf = scale_sdf(sdf) - # Binary sdf - binary_sdf = torch.where(sdf >= 0, 0.0, 1.0) - # Gradients of SDF - sdf_x, sdf_y, sdf_z = torch.gradient(sdf, dim=[2, 3, 4]) - - # Process SDF and its computed features - sdf = torch.cat((sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1) - sdf_encoding = self.geo_processor_sdf(sdf) - sdf_encoding = self.geo_processor_sdf_out(sdf_encoding) - - if self.geo_encoding_type == "both": - # Geometry encoding comprised of short-range, long-range and SDF features - encoding_g = torch.cat((x_encoding, sdf_encoding), 1) - elif self.geo_encoding_type == "sdf": - encoding_g = sdf_encoding - elif self.geo_encoding_type == "stl": - encoding_g = x_encoding - - if self.cross_attention: - encoding_g = self.combined_unet(encoding_g) - - return encoding_g - - -class NNBasisFunctions(nn.Module): - """Basis function layer for point clouds""" - - def __init__(self, input_features: int, model_parameters=None): - super(NNBasisFunctions, self).__init__() - base_layer = model_parameters.base_layer - self.fourier_features = model_parameters.fourier_features - self.num_modes = model_parameters.num_modes - - if self.fourier_features: - input_features_calculated = ( - input_features + input_features * self.num_modes * 2 - ) - else: - input_features_calculated = input_features - - self.fc1 = nn.Linear(input_features_calculated, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer)) - self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - - self.activation = get_activation(model_parameters.activation) - - if self.fourier_features: - self.register_buffer( - "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes)) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Transform point features into a basis function representation. - - Args: - x: Input tensor containing point features - - Returns: - Tensor containing basis function coefficients - """ - if self.fourier_features: - facets = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), dim=-1) - else: - facets = x - facets = self.activation(self.fc1(facets)) - facets = self.activation(self.fc2(facets)) - facets = self.fc3(facets) - - return facets - - -class ParameterModel(nn.Module): - """ - Neural network module to encode simulation parameters. - - This module encodes physical global parameters into a learned - latent representation that can be incorporated into the - model'sprediction process. - """ - - def __init__(self, input_features: int, model_parameters=None): - """ - Initialize the parameter encoding network. - - Args: - input_features: Number of input parameters to encode - model_parameters: Configuration parameters for the model - """ - super(ParameterModel, self).__init__() - self.fourier_features = model_parameters.fourier_features - self.num_modes = model_parameters.num_modes - - if self.fourier_features: - input_features_calculated = ( - input_features + input_features * self.num_modes * 2 - ) - self.register_buffer( - "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes)) - ) - else: - input_features_calculated = input_features - - base_layer = model_parameters.base_layer - self.fc1 = nn.Linear(input_features_calculated, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer)) - self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - - self.activation = get_activation(model_parameters.activation) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Encode physical parameters into a latent representation. - - Args: - x: Input tensor containing physical parameters (e.g., inlet velocity, air density) - - Returns: - Tensor containing encoded parameter representation - """ - if self.fourier_features: - params = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), dim=-1) - else: - params = x - params = self.activation(self.fc1(params)) - params = self.activation(self.fc2(params)) - params = self.fc3(params) - - return params - - -class AggregationModel(nn.Module): - """ - Neural network module to aggregate local geometry encoding with basis functions. - - This module combines basis function representations with geometry encodings - to predict the final output quantities. It serves as the final prediction layer - that integrates all available information sources. - """ - - def __init__( - self, - input_features: int, - output_features: int, - model_parameters=None, - new_change: bool = True, - ): - """ - Initialize the aggregation model. - - Args: - input_features: Number of input feature dimensions - output_features: Number of output feature dimensions - model_parameters: Configuration parameters for the model - new_change: Flag to enable newer implementation (default: True) - """ - super(AggregationModel, self).__init__() - self.input_features = input_features - self.output_features = output_features - self.new_change = new_change - base_layer = model_parameters.base_layer - self.fc1 = nn.Linear(self.input_features, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer)) - self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - self.fc4 = nn.Linear(int(base_layer), int(base_layer)) - self.fc5 = nn.Linear(int(base_layer), self.output_features) - - self.activation = get_activation(model_parameters.activation) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Process the combined input features to predict output quantities. - - This method applies a series of fully connected layers to the input, - which typically contains a combination of basis functions, geometry - encodings, and potentially parameter encodings. - - Args: - x: Input tensor containing combined features - - Returns: - Tensor containing predicted output quantities - """ - out = self.activation(self.fc1(x)) - out = self.activation(self.fc2(out)) - out = self.activation(self.fc3(out)) - out = self.activation(self.fc4(out)) - - out = self.fc5(out) - - return out - - -class LocalPointConv(nn.Module): - """Layer for local geometry point kernel""" - - def __init__( - self, - input_features, - base_layer, - output_features, - model_parameters=None, - ): - super(LocalPointConv, self).__init__() - self.input_features = input_features - self.output_features = output_features - self.fc1 = nn.Linear(self.input_features, base_layer) - self.fc2 = nn.Linear(base_layer, self.output_features) - self.activation = get_activation(model_parameters.activation) - - def forward(self, x): - out = self.activation(self.fc1(x)) - out = self.fc2(out) - - return out - - -class PositionEncoder(nn.Module): - """Positional encoding of point clouds""" - - def __init__(self, input_features: int, model_parameters=None): - super().__init__() - base_layer = model_parameters.base_neurons - self.fourier_features = model_parameters.fourier_features - self.num_modes = model_parameters.num_modes - - if self.fourier_features: - input_features_calculated = ( - input_features + input_features * self.num_modes * 2 - ) - else: - input_features_calculated = input_features - - self.fc1 = nn.Linear(input_features_calculated, base_layer) - self.fc2 = nn.Linear(base_layer, int(base_layer)) - self.fc3 = nn.Linear(int(base_layer), int(base_layer)) - - self.activation = get_activation(model_parameters.activation) - - if self.fourier_features: - self.register_buffer( - "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes)) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Transform point features into a basis function representation. - - Args: - x: Input tensor containing point features - - Returns: - Tensor containing position encoder - """ - if self.fourier_features: - facets = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), axis=-1) - else: - facets = x - facets = self.activation(self.fc1(facets)) - facets = self.activation(self.fc2(facets)) - facets = self.fc3(facets) +import torch +import torch.nn as nn - return facets +from physicsnemo.models.layers import FourierMLP, get_activation +from physicsnemo.models.unet import UNet +from .encodings import ( + MultiGeometryEncoding, +) +from .geometry_rep import GeometryRep, scale_sdf +from .mlps import AggregationModel +from .solutions import SolutionCalculatorSurface, SolutionCalculatorVolume # @dataclass # class MetaData(ModelMetaData): @@ -1010,7 +171,7 @@ def __init__( ValueError: If both output_features_vol and output_features_surf are None """ super().__init__() - self.input_features = input_features + self.output_features_vol = output_features_vol self.output_features_surf = output_features_surf self.num_sample_points_surface = model_parameters.num_neighbors_surface @@ -1106,23 +267,6 @@ def __init__( self.encode_parameters = model_parameters.encode_parameters self.geo_encoding_type = model_parameters.geometry_encoding_type - if hasattr(model_parameters, "num_volume_neighbors"): - self.num_volume_neighbors = model_parameters.num_volume_neighbors - else: - self.num_volume_neighbors = 50 - - if hasattr(model_parameters, "return_volume_neighbors"): - self.return_volume_neighbors = model_parameters.return_volume_neighbors - if ( - self.return_volume_neighbors - and self.solution_calculation_mode == "one-loop" - ): - print( - "'one-loop' solution_calculation mode not supported when return_volume_neighbors is set to true" - ) - print("Overwriting the solution_calculation mode to 'two-loop'") - self.solution_calculation_mode = "two-loop" - if self.use_surface_normals: if not self.use_surface_area: input_features_surface = input_features + 3 @@ -1134,9 +278,12 @@ def __init__( if self.encode_parameters: # Defining the parameter model base_layer_p = model_parameters.parameter_model.base_layer - self.parameter_model = ParameterModel( + self.parameter_model = FourierMLP( input_features=self.global_features, - model_parameters=model_parameters.parameter_model, + fourier_features=model_parameters.parameter_model.fourier_features, + num_modes=model_parameters.parameter_model.num_modes, + base_layer=model_parameters.parameter_model.base_layer, + activation=get_activation(model_parameters.parameter_model.activation), ) else: base_layer_p = 0 @@ -1146,6 +293,7 @@ def __init__( radii=model_parameters.geometry_rep.geo_conv.volume_radii, neighbors_in_radius=model_parameters.geometry_rep.geo_conv.volume_neighbors_in_radius, hops=model_parameters.geometry_rep.geo_conv.volume_hops, + sdf_scaling_factor=model_parameters.geometry_rep.geo_processor.volume_sdf_scaling_factor, model_parameters=model_parameters, ) @@ -1154,13 +302,7 @@ def __init__( radii=model_parameters.geometry_rep.geo_conv.surface_radii, neighbors_in_radius=model_parameters.geometry_rep.geo_conv.surface_neighbors_in_radius, hops=model_parameters.geometry_rep.geo_conv.surface_hops, - model_parameters=model_parameters, - ) - - self.geo_rep_surface1 = GeometryRep( - input_features=input_features, - radii=model_parameters.geometry_rep.geo_conv.volume_radii, - neighbors_in_radius=model_parameters.geometry_rep.geo_conv.volume_neighbors_in_radius, + sdf_scaling_factor=model_parameters.geometry_rep.geo_processor.surface_sdf_scaling_factor, model_parameters=model_parameters, ) @@ -1172,9 +314,15 @@ def __init__( self.num_variables_surf ): # Have the same basis function for each variable self.nn_basis_surf.append( - NNBasisFunctions( + FourierMLP( input_features=input_features_surface, - model_parameters=model_parameters.nn_basis_functions, + base_layer=model_parameters.nn_basis_functions.base_layer, + fourier_features=model_parameters.nn_basis_functions.fourier_features, + num_modes=model_parameters.nn_basis_functions.num_modes, + activation=get_activation( + model_parameters.nn_basis_functions.activation + ), + # model_parameters=model_parameters.nn_basis_functions, ) ) @@ -1184,9 +332,15 @@ def __init__( self.num_variables_vol ): # Have the same basis function for each variable self.nn_basis_vol.append( - NNBasisFunctions( + FourierMLP( input_features=input_features, - model_parameters=model_parameters.nn_basis_functions, + base_layer=model_parameters.nn_basis_functions.base_layer, + fourier_features=model_parameters.nn_basis_functions.fourier_features, + num_modes=model_parameters.nn_basis_functions.num_modes, + activation=get_activation( + model_parameters.nn_basis_functions.activation + ), + # model_parameters=model_parameters.nn_basis_functions, ) ) @@ -1194,117 +348,62 @@ def __init__( position_encoder_base_neurons = model_parameters.position_encoder.base_neurons self.activation = get_activation(model_parameters.activation) self.use_sdf_in_basis_func = model_parameters.use_sdf_in_basis_func + self.sdf_scaling_factor = ( + model_parameters.geometry_rep.geo_processor.volume_sdf_scaling_factor + ) if self.output_features_vol is not None: - if model_parameters.positional_encoding: - inp_pos_vol = 25 if model_parameters.use_sdf_in_basis_func else 12 - else: - inp_pos_vol = 7 if model_parameters.use_sdf_in_basis_func else 3 - - self.fc_p_vol = PositionEncoder( - inp_pos_vol, model_parameters.position_encoder - ) - - if self.output_features_surf is not None: - if model_parameters.positional_encoding: - inp_pos_surf = 12 - else: - inp_pos_surf = 3 - - self.fc_p_surf = PositionEncoder( - inp_pos_surf, model_parameters.position_encoder + inp_pos_vol = ( + 7 + len(self.sdf_scaling_factor) + if model_parameters.use_sdf_in_basis_func + else 3 ) - # BQ for surface - self.surface_neighbors_in_radius = ( - model_parameters.geometry_local.surface_neighbors_in_radius - ) - self.surface_radius = model_parameters.geometry_local.surface_radii - self.surface_bq_warp = nn.ModuleList() - self.surface_local_point_conv = nn.ModuleList() - - for ct in range(len(self.surface_radius)): - if self.geo_encoding_type == "both": - total_neighbors_in_radius = self.surface_neighbors_in_radius[ct] * ( - len(model_parameters.geometry_rep.geo_conv.surface_radii) + 1 - ) - elif self.geo_encoding_type == "stl": - total_neighbors_in_radius = self.surface_neighbors_in_radius[ct] * ( - len(model_parameters.geometry_rep.geo_conv.surface_radii) - ) - elif self.geo_encoding_type == "sdf": - total_neighbors_in_radius = self.surface_neighbors_in_radius[ct] - - self.surface_bq_warp.append( - BQWarp( - grid_resolution=model_parameters.interp_res, - radius=self.surface_radius[ct], - neighbors_in_radius=self.surface_neighbors_in_radius[ct], - ) - ) - self.surface_local_point_conv.append( - LocalPointConv( - input_features=total_neighbors_in_radius, - base_layer=512, - output_features=self.surface_neighbors_in_radius[ct], - model_parameters=model_parameters.local_point_conv, - ) + self.fc_p_vol = FourierMLP( + input_features=inp_pos_vol, + fourier_features=model_parameters.position_encoder.fourier_features, + num_modes=model_parameters.position_encoder.num_modes, + base_layer=model_parameters.position_encoder.base_neurons, + activation=get_activation(model_parameters.position_encoder.activation), ) - # BQ for volume - self.volume_neighbors_in_radius = ( - model_parameters.geometry_local.volume_neighbors_in_radius + if self.output_features_surf is not None: + inp_pos_surf = 3 + + self.fc_p_surf = FourierMLP( + input_features=inp_pos_surf, + fourier_features=model_parameters.position_encoder.fourier_features, + num_modes=model_parameters.position_encoder.num_modes, + base_layer=model_parameters.position_encoder.base_neurons, + activation=get_activation(model_parameters.position_encoder.activation), + ) + + # Create a set of local geometry encodings for the surface data: + self.surface_local_geo_encodings = MultiGeometryEncoding( + radii=model_parameters.geometry_local.surface_radii, + neighbors_in_radius=model_parameters.geometry_local.surface_neighbors_in_radius, + geo_encoding_type=self.geo_encoding_type, + n_upstream_radii=len(model_parameters.geometry_rep.geo_conv.surface_radii), + base_layer=512, + activation=get_activation(model_parameters.local_point_conv.activation), + grid_resolution=self.grid_resolution, ) - self.volume_radius = model_parameters.geometry_local.volume_radii - self.volume_bq_warp = nn.ModuleList() - self.volume_local_point_conv = nn.ModuleList() - - for ct in range(len(self.volume_radius)): - if self.geo_encoding_type == "both": - total_neighbors_in_radius = self.volume_neighbors_in_radius[ct] * ( - len(model_parameters.geometry_rep.geo_conv.volume_radii) + 1 - ) - elif self.geo_encoding_type == "stl": - total_neighbors_in_radius = self.volume_neighbors_in_radius[ct] * ( - len(model_parameters.geometry_rep.geo_conv.volume_radii) - ) - elif self.geo_encoding_type == "sdf": - total_neighbors_in_radius = self.volume_neighbors_in_radius[ct] - - self.volume_bq_warp.append( - BQWarp( - grid_resolution=model_parameters.interp_res, - radius=self.volume_radius[ct], - neighbors_in_radius=self.volume_neighbors_in_radius[ct], - ) - ) - self.volume_local_point_conv.append( - LocalPointConv( - input_features=total_neighbors_in_radius, - base_layer=512, - output_features=self.volume_neighbors_in_radius[ct], - model_parameters=model_parameters.local_point_conv, - ) - ) - # Transmitting surface to volume - self.surf_to_vol_conv1 = nn.Conv3d( - len(model_parameters.geometry_rep.geo_conv.volume_radii) + 1, - 16, - kernel_size=3, - padding="same", - ) - self.surf_to_vol_conv2 = nn.Conv3d( - 16, - len(model_parameters.geometry_rep.geo_conv.volume_radii) + 1, - kernel_size=3, - padding="same", + # Create a set of local geometry encodings for the surface data: + self.volume_local_geo_encodings = MultiGeometryEncoding( + radii=model_parameters.geometry_local.volume_radii, + neighbors_in_radius=model_parameters.geometry_local.volume_neighbors_in_radius, + geo_encoding_type=self.geo_encoding_type, + n_upstream_radii=len(model_parameters.geometry_rep.geo_conv.volume_radii), + base_layer=512, + activation=get_activation(model_parameters.local_point_conv.activation), + grid_resolution=self.grid_resolution, ) # Aggregation model if self.output_features_surf is not None: # Surface base_layer_geo_surf = 0 - for j in self.surface_neighbors_in_radius: + for j in model_parameters.geometry_local.surface_neighbors_in_radius: base_layer_geo_surf += j self.agg_model_surf = nn.ModuleList() @@ -1316,14 +415,30 @@ def __init__( + base_layer_geo_surf + base_layer_p, output_features=1, - model_parameters=model_parameters.aggregation_model, + base_layer=model_parameters.aggregation_model.base_layer, + activation=get_activation( + model_parameters.aggregation_model.activation + ), ) ) + self.solution_calculator_surf = SolutionCalculatorSurface( + num_variables=self.num_variables_surf, + num_sample_points=self.num_sample_points_surface, + use_surface_normals=self.use_surface_normals, + use_surface_area=self.use_surface_area, + encode_parameters=self.encode_parameters, + parameter_model=self.parameter_model + if self.encode_parameters + else None, + aggregation_model=self.agg_model_surf, + nn_basis=self.nn_basis_surf, + ) + if self.output_features_vol is not None: # Volume base_layer_geo_vol = 0 - for j in self.volume_neighbors_in_radius: + for j in model_parameters.geometry_local.volume_neighbors_in_radius: base_layer_geo_vol += j self.agg_model_vol = nn.ModuleList() @@ -1335,539 +450,31 @@ def __init__( + base_layer_geo_vol + base_layer_p, output_features=1, - model_parameters=model_parameters.aggregation_model, - ) - ) - - def position_encoder( - self, - encoding_node: torch.Tensor, - eval_mode: Literal["surface", "volume"] = "volume", - ) -> torch.Tensor: - """ - Compute positional encoding for input points. - - Args: - encoding_node: Tensor containing node position information - eval_mode: Mode of evaluation, either "volume" or "surface" - - Returns: - Tensor containing positional encoding features - """ - if eval_mode == "volume": - x = self.fc_p_vol(encoding_node) - elif eval_mode == "surface": - x = self.fc_p_surf(encoding_node) - else: - raise ValueError( - f"`eval_mode` must be 'surface' or 'volume', got {eval_mode=}" - ) - return x - - def geo_encoding_local( - self, encoding_g, volume_mesh_centers, p_grid, mode="volume" - ): - """Function to calculate local geometry encoding from global encoding""" - - if mode == "volume": - radius = self.volume_radius - bq_warp = self.volume_bq_warp - point_conv = self.volume_local_point_conv - elif mode == "surface": - radius = self.surface_radius - bq_warp = self.surface_bq_warp - point_conv = self.surface_local_point_conv - - batch_size = volume_mesh_centers.shape[0] - nx, ny, nz = ( - self.grid_resolution[0], - self.grid_resolution[1], - self.grid_resolution[2], - ) - - encoding_outer = [] - for p in range(len(radius)): - p_grid = torch.reshape(p_grid, (batch_size, nx * ny * nz, 3)) - mapping, outputs = bq_warp[p]( - volume_mesh_centers, p_grid, reverse_mapping=False - ) - mapping = mapping.type(torch.int64) - mask = mapping != 0 - - encoding_g_inner = [] - for j in range(encoding_g.shape[1]): - geo_encoding = rearrange( - encoding_g[:, j], "b nx ny nz -> b 1 (nx ny nz)" - ) - - geo_encoding_sampled = torch.index_select( - geo_encoding, 2, mapping.flatten() - ) - geo_encoding_sampled = torch.reshape(geo_encoding_sampled, mask.shape) - geo_encoding_sampled = geo_encoding_sampled * mask - - encoding_g_inner.append(geo_encoding_sampled) - encoding_g_inner = torch.cat(encoding_g_inner, dim=2) - encoding_g_inner = point_conv[p](encoding_g_inner) - - encoding_outer.append(encoding_g_inner) - - encoding_g = torch.cat(encoding_outer, dim=-1) - - return encoding_g - - def calculate_solution_with_neighbors( - self, - surface_mesh_centers, - encoding_g, - encoding_node, - surface_mesh_neighbors, - surface_normals, - surface_neighbors_normals, - surface_areas, - surface_neighbors_areas, - global_params_values, - global_params_reference, - num_sample_points=7, - ): - """Function to approximate solution given the neighborhood information""" - num_variables = self.num_variables_surf - nn_basis = self.nn_basis_surf - agg_model = self.agg_model_surf - - if self.encode_parameters: - processed_parameters = [] - for k in range(global_params_values.shape[1]): - param = torch.unsqueeze(global_params_values[:, k, :], 1) - ref = torch.unsqueeze(global_params_reference[:, k, :], 1) - param = param.expand( - param.shape[0], - surface_mesh_centers.shape[1], - param.shape[2], - ) - param = param / ref - processed_parameters.append(param) - processed_parameters = torch.cat(processed_parameters, axis=-1) - param_encoding = self.parameter_model(processed_parameters) - - if self.use_surface_normals: - if not self.use_surface_area: - surface_mesh_centers = torch.cat( - (surface_mesh_centers, surface_normals), - dim=-1, - ) - if num_sample_points > 1: - surface_mesh_neighbors = torch.cat( - ( - surface_mesh_neighbors, - surface_neighbors_normals, - ), - dim=-1, - ) - - else: - surface_mesh_centers = torch.cat( - ( - surface_mesh_centers, - surface_normals, - torch.log(surface_areas) / 10, - ), - dim=-1, - ) - if num_sample_points > 1: - surface_mesh_neighbors = torch.cat( - ( - surface_mesh_neighbors, - surface_neighbors_normals, - torch.log(surface_neighbors_areas) / 10, - ), - dim=-1, - ) - - if self.solution_calculation_mode == "one-loop": - encoding_list = [ - encoding_node.unsqueeze(2).expand(-1, -1, num_sample_points, -1), - encoding_g.unsqueeze(2).expand(-1, -1, num_sample_points, -1), - ] - - for f in range(num_variables): - one_loop_centers_expanded = surface_mesh_centers.unsqueeze(2) - - one_loop_noise = one_loop_centers_expanded - ( - surface_mesh_neighbors + 1e-6 - ) - one_loop_noise = torch.norm(one_loop_noise, dim=-1, keepdim=True) - - # Doing it this way prevents the intermediate one_loop_basis_f from being stored in memory for the rest of the function. - agg_output = agg_model[f]( - torch.cat( - ( - nn_basis[f]( - torch.cat( - ( - one_loop_centers_expanded, - surface_mesh_neighbors + 1e-6, - ), - dim=2, - ) - ), - *encoding_list, + base_layer=model_parameters.aggregation_model.base_layer, + activation=get_activation( + model_parameters.aggregation_model.activation ), - dim=-1, - ) - ) - - one_loop_output_center, one_loop_output_neighbor = torch.split( - agg_output, [1, num_sample_points - 1], dim=2 - ) - one_loop_output_neighbor = one_loop_output_neighbor * ( - 1.0 / one_loop_noise - ) - - one_loop_output_center = one_loop_output_center.squeeze(2) - one_loop_output_neighbor = one_loop_output_neighbor.sum(2) - one_loop_dist_sum = torch.sum(1.0 / one_loop_noise, dim=2) - - # Stop here - if num_sample_points > 1: - one_loop_output_res = ( - 0.5 * one_loop_output_center - + 0.5 * one_loop_output_neighbor / one_loop_dist_sum - ) - else: - one_loop_output_res = one_loop_output_center - if f == 0: - one_loop_output_all = one_loop_output_res - else: - one_loop_output_all = torch.cat( - (one_loop_output_all, one_loop_output_res), dim=-1 - ) - - return one_loop_output_all - - if self.solution_calculation_mode == "two-loop": - for f in range(num_variables): - for p in range(num_sample_points): - if p == 0: - volume_m_c = surface_mesh_centers - else: - volume_m_c = surface_mesh_neighbors[:, :, p - 1] + 1e-6 - noise = surface_mesh_centers - volume_m_c - dist = torch.norm(noise, dim=-1, keepdim=True) - - basis_f = nn_basis[f](volume_m_c) - output = torch.cat((basis_f, encoding_node, encoding_g), dim=-1) - if self.encode_parameters: - output = torch.cat((output, param_encoding), dim=-1) - if p == 0: - output_center = agg_model[f](output) - else: - if p == 1: - output_neighbor = agg_model[f](output) * (1.0 / dist) - dist_sum = 1.0 / dist - else: - output_neighbor += agg_model[f](output) * (1.0 / dist) - dist_sum += 1.0 / dist - if num_sample_points > 1: - output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum - else: - output_res = output_center - if f == 0: - output_all = output_res - else: - output_all = torch.cat((output_all, output_res), dim=-1) - - return output_all - - def sample_sphere(self, center, r, num_points): - """Uniformly sample points in a 3D sphere around the center. - - This method generates random points within a sphere of radius r centered - at each point in the input tensor. The sampling is uniform in volume, - meaning points are more likely to be sampled in the outer regions of the sphere. - - Args: - center: Tensor of shape (batch_size, num_points, 3) containing center coordinates - r: Radius of the sphere for sampling - num_points: Number of points to sample per center - - Returns: - Tensor of shape (batch_size, num_points, num_samples, 3) containing - the sampled points around each center - """ - # Adjust the center points to the final shape: - unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1) - - # Generate directions like the centers: - directions = torch.randn_like(unsqueezed_center) - directions = directions / torch.norm(directions, dim=-1, keepdim=True) - - # Generate radii like the centers: - radii = r * torch.pow(torch.rand_like(unsqueezed_center), 1 / 3) - - output = unsqueezed_center + directions * radii - return output - - def sample_sphere_shell(self, center, r_inner, r_outer, num_points): - """Uniformly sample points in a 3D spherical shell around a center. - - This method generates random points within a spherical shell (annulus) - between inner radius r_inner and outer radius r_outer centered at each - point in the input tensor. The sampling is uniform in volume within the shell. - - Args: - center: Tensor of shape (batch_size, num_points, 3) containing center coordinates - r_inner: Inner radius of the spherical shell - r_outer: Outer radius of the spherical shell - num_points: Number of points to sample per center - - Returns: - Tensor of shape (batch_size, num_points, num_samples, 3) containing - the sampled points within the spherical shell around each center - """ - # directions = torch.randn( - # size=(center.shape[0], center.shape[1], num_points, center.shape[2]), - # device=center.device, - # ) - # directions = directions / torch.norm(directions, dim=-1, keepdim=True) - - unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1) - - # Generate directions like the centers: - directions = torch.randn_like(unsqueezed_center) - directions = directions / torch.norm(directions, dim=-1, keepdim=True) - - radii = ( - torch.rand_like(unsqueezed_center) * (r_outer**3 - r_inner**3) + r_inner**3 - ) - radii = torch.pow(radii, 1 / 3) - - output = unsqueezed_center + directions * radii - - return output - - def calculate_solution( - self, - volume_mesh_centers, - encoding_g, - encoding_node, - global_params_values, - global_params_reference, - eval_mode, - num_sample_points=20, - noise_intensity=50, - return_volume_neighbors=False, - ): - """Function to approximate solution sampling the neighborhood information""" - if eval_mode == "volume": - num_variables = self.num_variables_vol - nn_basis = self.nn_basis_vol - agg_model = self.agg_model_vol - elif eval_mode == "surface": - num_variables = self.num_variables_surf - nn_basis = self.nn_basis_surf - agg_model = self.agg_model_surf - - if self.encode_parameters: - processed_parameters = [] - for k in range(global_params_values.shape[1]): - param = torch.unsqueeze(global_params_values[:, k, :], 1) - ref = torch.unsqueeze(global_params_reference[:, k, :], 1) - param = param.expand( - param.shape[0], - volume_mesh_centers.shape[1], - param.shape[2], - ) - param = param / ref - processed_parameters.append(param) - processed_parameters = torch.cat(processed_parameters, axis=-1) - param_encoding = self.parameter_model(processed_parameters) - - if self.solution_calculation_mode == "one-loop": - # Stretch these out to num_sample_points - one_loop_encoding_node = encoding_node.unsqueeze(0).expand( - num_sample_points, -1, -1, -1 - ) - one_loop_encoding_g = encoding_g.unsqueeze(0).expand( - num_sample_points, -1, -1, -1 - ) - - if self.encode_parameters: - one_loop_other_terms = ( - one_loop_encoding_node, - one_loop_encoding_g, - param_encoding, - ) - else: - one_loop_other_terms = (one_loop_encoding_node, one_loop_encoding_g) - - for f in range(num_variables): - one_loop_volume_mesh_centers_expanded = volume_mesh_centers.unsqueeze( - 0 - ).expand(num_sample_points, -1, -1, -1) - # Bulk_random_noise has shape (num_sample_points, batch_size, num_points, 3) - one_loop_bulk_random_noise = torch.rand_like( - one_loop_volume_mesh_centers_expanded - ) - - one_loop_bulk_random_noise = 2 * (one_loop_bulk_random_noise - 0.5) - one_loop_bulk_random_noise = ( - one_loop_bulk_random_noise / noise_intensity - ) - one_loop_bulk_dist = torch.norm( - one_loop_bulk_random_noise, dim=-1, keepdim=True - ) - - _, one_loop_bulk_dist = torch.split( - one_loop_bulk_dist, [1, num_sample_points - 1], dim=0 - ) - - # Set the first sample point to 0.0: - one_loop_bulk_random_noise[0] = torch.zeros_like( - one_loop_bulk_random_noise[0] - ) - - # Add the noise to the expanded volume_mesh_centers: - one_loop_volume_m_c = volume_mesh_centers + one_loop_bulk_random_noise - # If this looks overly complicated - it is. - # But, this makes sure that the memory used to store the output of both nn_basis[f] - # as well as the output of torch.cat can be deallocated immediately. - # Apply the aggregation model and distance scaling: - one_loop_output = agg_model[f]( - torch.cat( - (nn_basis[f](one_loop_volume_m_c), *one_loop_other_terms), - dim=-1, - ) - ) - - # select off the first, unperturbed term: - one_loop_output_center, one_loop_output_neighbor = torch.split( - one_loop_output, [1, num_sample_points - 1], dim=0 - ) - - # Scale the neighbor terms by the distance: - one_loop_output_neighbor = one_loop_output_neighbor / one_loop_bulk_dist - - one_loop_dist_sum = torch.sum(1.0 / one_loop_bulk_dist, dim=0) - - # Adjust shapes: - one_loop_output_center = one_loop_output_center.squeeze(1) - one_loop_output_neighbor = one_loop_output_neighbor.sum(0) - - # Compare: - if num_sample_points > 1: - one_loop_output_res = ( - 0.5 * one_loop_output_center - + 0.5 * one_loop_output_neighbor / one_loop_dist_sum ) - else: - one_loop_output_res = one_loop_output_center - if f == 0: - one_loop_output_all = one_loop_output_res - else: - one_loop_output_all = torch.cat( - (one_loop_output_all, one_loop_output_res), dim=-1 - ) - - return one_loop_output_all - - if self.solution_calculation_mode == "two-loop": - volume_m_c_perturbed = [volume_mesh_centers.unsqueeze(2)] - - if return_volume_neighbors: - num_hop1 = num_sample_points - num_hop2 = ( - num_sample_points // 2 if num_sample_points != 1 else 1 - ) # This is per 1 hop node - neighbors = defaultdict(list) - - volume_m_c_hop1 = self.sample_sphere( - volume_mesh_centers, 1 / noise_intensity, num_hop1 ) - # 1 hop neighbors - for i in range(num_hop1): - idx = len(volume_m_c_perturbed) - volume_m_c_perturbed.append(volume_m_c_hop1[:, :, i : i + 1, :]) - neighbors[0].append(idx) - - # 2 hop neighbors - for i in range(num_hop1): - parent_idx = ( - i + 1 - ) # Skipping the first point, which is the original - parent_point = volume_m_c_perturbed[parent_idx] - - children = self.sample_sphere_shell( - parent_point.squeeze(2), - 1 / noise_intensity, - 2 / noise_intensity, - num_hop2, - ) - - for c in range(num_hop2): - idx = len(volume_m_c_perturbed) - volume_m_c_perturbed.append(children[:, :, c : c + 1, :]) - neighbors[parent_idx].append(idx) - - volume_m_c_perturbed = torch.cat(volume_m_c_perturbed, dim=2) - neighbors = dict(neighbors) - field_neighbors = {i: [] for i in range(num_variables)} + if hasattr(model_parameters, "return_volume_neighbors"): + return_volume_neighbors = model_parameters.return_volume_neighbors else: - volume_m_c_sample = self.sample_sphere( - volume_mesh_centers, 1 / noise_intensity, num_sample_points - ) - for i in range(num_sample_points): - volume_m_c_perturbed.append(volume_m_c_sample[:, :, i : i + 1, :]) - - volume_m_c_perturbed = torch.cat(volume_m_c_perturbed, dim=2) - - for f in range(num_variables): - for p in range(volume_m_c_perturbed.shape[2]): - volume_m_c = volume_m_c_perturbed[:, :, p, :] - if p != 0: - dist = torch.norm( - volume_m_c - volume_mesh_centers, dim=-1, keepdim=True - ) - basis_f = nn_basis[f](volume_m_c) - output = torch.cat((basis_f, encoding_node, encoding_g), dim=-1) - if self.encode_parameters: - output = torch.cat((output, param_encoding), dim=-1) - if p == 0: - output_center = agg_model[f](output) - else: - if p == 1: - output_neighbor = agg_model[f](output) * (1.0 / dist) - dist_sum = 1.0 / dist - else: - output_neighbor += agg_model[f](output) * (1.0 / dist) - dist_sum += 1.0 / dist - if return_volume_neighbors: - field_neighbors[f].append(agg_model[f](output)) - - if return_volume_neighbors: - field_neighbors[f] = torch.stack(field_neighbors[f], dim=2) + return_volume_neighbors = False - if num_sample_points > 1: - output_res = ( - 0.5 * output_center + 0.5 * output_neighbor / dist_sum - ) # This only applies to the main point, and not the preturbed points - else: - output_res = output_center - if f == 0: - output_all = output_res - else: - output_all = torch.cat((output_all, output_res), axis=-1) - - if return_volume_neighbors: - field_neighbors = torch.cat( - [field_neighbors[i] for i in range(num_variables)], dim=3 - ) - return output_all, volume_m_c_perturbed, field_neighbors, neighbors - else: - return output_all + self.solution_calculator_vol = SolutionCalculatorVolume( + num_variables=self.num_variables_vol, + num_sample_points=self.num_sample_points_volume, + noise_intensity=50, + return_volume_neighbors=return_volume_neighbors, + encode_parameters=self.encode_parameters, + parameter_model=self.parameter_model + if self.encode_parameters + else None, + aggregation_model=self.agg_model_vol, + nn_basis=self.nn_basis_vol, + ) - @profile - def forward(self, data_dict, return_volume_neighbors=False): + def forward(self, data_dict): # Loading STL inputs, bounding box grids, precomputed SDF and scaling factors # STL nodes @@ -1876,9 +483,6 @@ def forward(self, data_dict, return_volume_neighbors=False): # Bounding box grid s_grid = data_dict["surf_grid"] sdf_surf_grid = data_dict["sdf_surf_grid"] - # Scaling factors - surf_max = data_dict["surface_min_max"][:, 1] - surf_min = data_dict["surface_min_max"][:, 0] # Parameters global_params_values = data_dict["global_params_values"] @@ -1890,37 +494,61 @@ def forward(self, data_dict, return_volume_neighbors=False): p_grid = data_dict["grid"] sdf_grid = data_dict["sdf_grid"] # Scaling factors - vol_max = data_dict["volume_min_max"][:, 1] - vol_min = data_dict["volume_min_max"][:, 0] + if "volume_min_max" in data_dict.keys(): + vol_max = data_dict["volume_min_max"][:, 1] + vol_min = data_dict["volume_min_max"][:, 0] - # Normalize based on computational domain - geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + # Normalize based on computational domain + geo_centers_vol = ( + 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1 + ) + else: + geo_centers_vol = geo_centers encoding_g_vol = self.geo_rep_volume(geo_centers_vol, p_grid, sdf_grid) # SDF on volume mesh nodes sdf_nodes = data_dict["sdf_nodes"] + # scaled_sdf_nodes = [] + # for i in range(len(self.sdf_scaling_factor)): + # scaled_sdf_nodes.append(scale_sdf(sdf_nodes, self.sdf_scaling_factor[i])) + scaled_sdf_nodes = [ + scale_sdf(sdf_nodes, scaling) for scaling in self.sdf_scaling_factor + ] + scaled_sdf_nodes = torch.cat(scaled_sdf_nodes, dim=-1) + # Positional encoding based on closest point on surface to a volume node pos_volume_closest = data_dict["pos_volume_closest"] # Positional encoding based on center of mass of geometry to volume node pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"] if self.use_sdf_in_basis_func: encoding_node_vol = torch.cat( - (sdf_nodes, pos_volume_closest, pos_volume_center_of_mass), dim=-1 + ( + sdf_nodes, + scaled_sdf_nodes, + pos_volume_closest, + pos_volume_center_of_mass, + ), + dim=-1, ) else: encoding_node_vol = pos_volume_center_of_mass # Calculate positional encoding on volume nodes - encoding_node_vol = self.position_encoder( - encoding_node_vol, eval_mode="volume" - ) + encoding_node_vol = self.fc_p_vol(encoding_node_vol) if self.output_features_surf is not None: # Represent geometry on bounding box - geo_centers_surf = ( - 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 - ) + # Scaling factors + if "surface_min_max" in data_dict.keys(): + surf_max = data_dict["surface_min_max"][:, 1] + surf_min = data_dict["surface_min_max"][:, 0] + geo_centers_surf = ( + 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1 + ) + else: + geo_centers_surf = geo_centers + encoding_g_surf = self.geo_rep_surface( geo_centers_surf, s_grid, sdf_surf_grid ) @@ -1930,9 +558,7 @@ def forward(self, data_dict, return_volume_neighbors=False): encoding_node_surf = pos_surface_center_of_mass # Calculate positional encoding on surface centers - encoding_node_surf = self.position_encoder( - encoding_node_surf, eval_mode="surface" - ) + encoding_node_surf = self.fc_p_surf(encoding_node_surf) if ( self.output_features_surf is not None @@ -1947,20 +573,19 @@ def forward(self, data_dict, return_volume_neighbors=False): # Calculate local geometry encoding for volume # Sampled points on volume volume_mesh_centers = data_dict["volume_mesh_centers"] - encoding_g_vol = self.geo_encoding_local( - 0.5 * encoding_g_vol, volume_mesh_centers, p_grid, mode="volume" + encoding_g_vol = self.volume_local_geo_encodings( + 0.5 * encoding_g_vol, + volume_mesh_centers, + p_grid, ) # Approximate solution on volume node - output_vol = self.calculate_solution( + output_vol = self.solution_calculator_vol( volume_mesh_centers, encoding_g_vol, encoding_node_vol, global_params_values, global_params_reference, - eval_mode="volume", - num_sample_points=self.num_sample_points_volume, - return_volume_neighbors=return_volume_neighbors, ) else: @@ -1979,12 +604,12 @@ def forward(self, data_dict, return_volume_neighbors=False): surface_areas = torch.unsqueeze(surface_areas, -1) surface_neighbors_areas = torch.unsqueeze(surface_neighbors_areas, -1) # Calculate local geometry encoding for surface - encoding_g_surf = self.geo_encoding_local( - 0.5 * encoding_g_surf, surface_mesh_centers, s_grid, mode="surface" + encoding_g_surf = self.surface_local_geo_encodings( + 0.5 * encoding_g_surf, surface_mesh_centers, s_grid ) # Approximate solution on surface cell center - output_surf = self.calculate_solution_with_neighbors( + output_surf = self.solution_calculator_surf( surface_mesh_centers, encoding_g_surf, encoding_node_surf, @@ -1995,7 +620,6 @@ def forward(self, data_dict, return_volume_neighbors=False): surface_neighbors_areas, global_params_values, global_params_reference, - num_sample_points=self.num_sample_points_surface, ) else: output_surf = None diff --git a/physicsnemo/models/domino/solutions.py b/physicsnemo/models/domino/solutions.py new file mode 100644 index 0000000000..23a7e36f39 --- /dev/null +++ b/physicsnemo/models/domino/solutions.py @@ -0,0 +1,368 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This code contains the DoMINO model architecture. +The DoMINO class contains an architecture to model both surface and +volume quantities together as well as separately (controlled using +the config.yaml file) +""" + +from collections import defaultdict + +import torch +import torch.nn as nn + + +def apply_parameter_encoding( + mesh_centers: torch.Tensor, + global_params_values: torch.Tensor, + global_params_reference: torch.Tensor, +) -> torch.Tensor: + processed_parameters = [] + for k in range(global_params_values.shape[1]): + param = torch.unsqueeze(global_params_values[:, k, :], 1) + ref = torch.unsqueeze(global_params_reference[:, k, :], 1) + param = param.expand( + param.shape[0], + mesh_centers.shape[1], + param.shape[2], + ) + param = param / ref + processed_parameters.append(param) + processed_parameters = torch.cat(processed_parameters, axis=-1) + + return processed_parameters + + +def sample_sphere(center, r, num_points): + """Uniformly sample points in a 3D sphere around the center. + + This method generates random points within a sphere of radius r centered + at each point in the input tensor. The sampling is uniform in volume, + meaning points are more likely to be sampled in the outer regions of the sphere. + + Args: + center: Tensor of shape (batch_size, num_points, 3) containing center coordinates + r: Radius of the sphere for sampling + num_points: Number of points to sample per center + + Returns: + Tensor of shape (batch_size, num_points, num_samples, 3) containing + the sampled points around each center + """ + # Adjust the center points to the final shape: + unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1) + + # Generate directions like the centers: + directions = torch.randn_like(unsqueezed_center) + directions = directions / torch.norm(directions, dim=-1, keepdim=True) + + # Generate radii like the centers: + radii = r * torch.pow(torch.rand_like(unsqueezed_center), 1 / 3) + + output = unsqueezed_center + directions * radii + return output + + +def sample_sphere_shell(center, r_inner, r_outer, num_points): + """Uniformly sample points in a 3D spherical shell around a center. + + This method generates random points within a spherical shell (annulus) + between inner radius r_inner and outer radius r_outer centered at each + point in the input tensor. The sampling is uniform in volume within the shell. + + Args: + center: Tensor of shape (batch_size, num_points, 3) containing center coordinates + r_inner: Inner radius of the spherical shell + r_outer: Outer radius of the spherical shell + num_points: Number of points to sample per center + + Returns: + Tensor of shape (batch_size, num_points, num_samples, 3) containing + the sampled points within the spherical shell around each center + """ + + unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1) + + # Generate directions like the centers: + directions = torch.randn_like(unsqueezed_center) + directions = directions / torch.norm(directions, dim=-1, keepdim=True) + + radii = torch.rand_like(unsqueezed_center) * (r_outer**3 - r_inner**3) + r_inner**3 + radii = torch.pow(radii, 1 / 3) + + output = unsqueezed_center + directions * radii + + return output + + +class SolutionCalculatorVolume(nn.Module): + """ + Module to calculate the output solution of the DoMINO Model for volume data. + """ + + def __init__( + self, + num_variables: int, + num_sample_points: int, + noise_intensity: float, + encode_parameters: bool, + return_volume_neighbors: bool, + parameter_model: nn.Module | None, + aggregation_model: nn.ModuleList, + nn_basis: nn.ModuleList, + ): + super().__init__() + + self.num_variables = num_variables + self.num_sample_points = num_sample_points + self.noise_intensity = noise_intensity + self.encode_parameters = encode_parameters + self.return_volume_neighbors = return_volume_neighbors + self.parameter_model = parameter_model + self.aggregation_model = aggregation_model + self.nn_basis = nn_basis + + if self.encode_parameters: + if self.parameter_model is None: + raise ValueError( + "Parameter model is required when encode_parameters is True" + ) + + def forward( + self, + volume_mesh_centers: torch.Tensor, + encoding_g: torch.Tensor, + encoding_node: torch.Tensor, + global_params_values: torch.Tensor, + global_params_reference: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: + """ + Forward pass of the SolutionCalculator module. + """ + if self.encode_parameters: + param_encoding = apply_parameter_encoding( + volume_mesh_centers, global_params_values, global_params_reference + ) + param_encoding = self.parameter_model(param_encoding) + + volume_m_c_perturbed = [volume_mesh_centers.unsqueeze(2)] + + if self.return_volume_neighbors: + num_hop1 = self.num_sample_points + num_hop2 = ( + self.num_sample_points // 2 if self.num_sample_points != 1 else 1 + ) # This is per 1 hop node + neighbors = defaultdict(list) + + volume_m_c_hop1 = sample_sphere( + volume_mesh_centers, 1 / self.noise_intensity, num_hop1 + ) + # 1 hop neighbors + for i in range(num_hop1): + idx = len(volume_m_c_perturbed) + volume_m_c_perturbed.append(volume_m_c_hop1[:, :, i : i + 1, :]) + neighbors[0].append(idx) + + # 2 hop neighbors + for i in range(num_hop1): + parent_idx = i + 1 # Skipping the first point, which is the original + parent_point = volume_m_c_perturbed[parent_idx] + + children = sample_sphere_shell( + parent_point.squeeze(2), + 1 / self.noise_intensity, + 2 / self.noise_intensity, + num_hop2, + ) + + for c in range(num_hop2): + idx = len(volume_m_c_perturbed) + volume_m_c_perturbed.append(children[:, :, c : c + 1, :]) + neighbors[parent_idx].append(idx) + + volume_m_c_perturbed = torch.cat(volume_m_c_perturbed, dim=2) + neighbors = dict(neighbors) + field_neighbors = {i: [] for i in range(self.num_variables)} + else: + volume_m_c_sample = sample_sphere( + volume_mesh_centers, 1 / self.noise_intensity, self.num_sample_points + ) + for i in range(self.num_sample_points): + volume_m_c_perturbed.append(volume_m_c_sample[:, :, i : i + 1, :]) + + volume_m_c_perturbed = torch.cat(volume_m_c_perturbed, dim=2) + + for f in range(self.num_variables): + for p in range(volume_m_c_perturbed.shape[2]): + volume_m_c = volume_m_c_perturbed[:, :, p, :] + if p != 0: + dist = torch.norm( + volume_m_c - volume_mesh_centers, dim=-1, keepdim=True + ) + basis_f = self.nn_basis[f](volume_m_c) + output = torch.cat((basis_f, encoding_node, encoding_g), dim=-1) + if self.encode_parameters: + output = torch.cat((output, param_encoding), dim=-1) + if p == 0: + output_center = self.aggregation_model[f](output) + else: + if p == 1: + output_neighbor = self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum = 1.0 / dist + else: + output_neighbor += self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum += 1.0 / dist + if self.return_volume_neighbors: + field_neighbors[f].append(self.aggregation_model[f](output)) + + if self.return_volume_neighbors: + field_neighbors[f] = torch.stack(field_neighbors[f], dim=2) + + if self.num_sample_points > 1: + output_res = ( + 0.5 * output_center + 0.5 * output_neighbor / dist_sum + ) # This only applies to the main point, and not the preturbed points + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = torch.cat((output_all, output_res), axis=-1) + + if self.return_volume_neighbors: + field_neighbors = torch.cat( + [field_neighbors[i] for i in range(self.num_variables)], dim=3 + ) + return output_all, volume_m_c_perturbed, field_neighbors, neighbors + else: + return output_all + + +class SolutionCalculatorSurface(nn.Module): + """ + Module to calculate the output solution of the DoMINO Model for surface data. + """ + + def __init__( + self, + num_variables: int, + num_sample_points: int, + encode_parameters: bool, + use_surface_normals: bool, + use_surface_area: bool, + parameter_model: nn.Module | None, + aggregation_model: nn.ModuleList, + nn_basis: nn.ModuleList, + ): + super().__init__() + self.num_variables = num_variables + self.num_sample_points = num_sample_points + self.encode_parameters = encode_parameters + self.use_surface_normals = use_surface_normals + self.use_surface_area = use_surface_area + self.parameter_model = parameter_model + self.aggregation_model = aggregation_model + self.nn_basis = nn_basis + + if self.encode_parameters: + if self.parameter_model is None: + raise ValueError( + "Parameter model is required when encode_parameters is True" + ) + + def forward( + self, + surface_mesh_centers: torch.Tensor, + encoding_g: torch.Tensor, + encoding_node: torch.Tensor, + surface_mesh_neighbors: torch.Tensor, + surface_normals: torch.Tensor, + surface_neighbors_normals: torch.Tensor, + surface_areas: torch.Tensor, + surface_neighbors_areas: torch.Tensor, + global_params_values: torch.Tensor, + global_params_reference: torch.Tensor, + ) -> torch.Tensor: + """Function to approximate solution given the neighborhood information""" + + if self.encode_parameters: + param_encoding = apply_parameter_encoding( + surface_mesh_centers, global_params_values, global_params_reference + ) + param_encoding = self.parameter_model(param_encoding) + + centers_inputs = [ + surface_mesh_centers, + ] + neighbors_inputs = [ + surface_mesh_neighbors, + ] + + if self.use_surface_normals: + centers_inputs.append(surface_normals) + if self.num_sample_points > 1: + neighbors_inputs.append(surface_neighbors_normals) + + if self.use_surface_area: + centers_inputs.append(torch.log(surface_areas) / 10) + if self.num_sample_points > 1: + neighbors_inputs.append(torch.log(surface_neighbors_areas) / 10) + + surface_mesh_centers = torch.cat(centers_inputs, dim=-1) + surface_mesh_neighbors = torch.cat(neighbors_inputs, dim=-1) + + for f in range(self.num_variables): + for p in range(self.num_sample_points): + if p == 0: + volume_m_c = surface_mesh_centers + else: + volume_m_c = surface_mesh_neighbors[:, :, p - 1] + 1e-6 + noise = surface_mesh_centers - volume_m_c + dist = torch.norm(noise, dim=-1, keepdim=True) + + basis_f = self.nn_basis[f](volume_m_c) + output = torch.cat((basis_f, encoding_node, encoding_g), dim=-1) + if self.encode_parameters: + output = torch.cat((output, param_encoding), dim=-1) + if p == 0: + output_center = self.aggregation_model[f](output) + else: + if p == 1: + output_neighbor = self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum = 1.0 / dist + else: + output_neighbor += self.aggregation_model[f](output) * ( + 1.0 / dist + ) + dist_sum += 1.0 / dist + if self.num_sample_points > 1: + output_res = 0.5 * output_center + 0.5 * output_neighbor / dist_sum + else: + output_res = output_center + if f == 0: + output_all = output_res + else: + output_all = torch.cat((output_all, output_res), dim=-1) + + return output_all diff --git a/physicsnemo/models/layers/__init__.py b/physicsnemo/models/layers/__init__.py index 627fa4f07f..cfebf5e38d 100644 --- a/physicsnemo/models/layers/__init__.py +++ b/physicsnemo/models/layers/__init__.py @@ -22,9 +22,16 @@ Stan, get_activation, ) +from .ball_query import BQWarp from .conv_layers import ConvBlock, CubeEmbedding from .dgm_layers import DGMLayer -from .fourier_layers import FourierFilter, FourierLayer, GaborFilter +from .fourier_layers import ( + FourierFilter, + FourierLayer, + FourierMLP, + GaborFilter, + fourier_encode, +) from .fully_connected_layers import ( Conv1dFCLayer, Conv2dFCLayer, diff --git a/physicsnemo/models/layers/ball_query.py b/physicsnemo/models/layers/ball_query.py index ee3e1538a9..77416bd57a 100644 --- a/physicsnemo/models/layers/ball_query.py +++ b/physicsnemo/models/layers/ball_query.py @@ -14,504 +14,107 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +""" +This layer is a compilable, ball-query operation. -import torch -import warp as wp -from torch.overrides import handle_torch_function, has_torch_function +By default, it will project a grid of points to a 1D set of points. +It does not support batch size > 1. +""" -@wp.kernel -def ball_query( - points1: wp.array(dtype=wp.vec3), - points2: wp.array(dtype=wp.vec3), - grid: wp.uint64, - k: wp.int32, - radius: wp.float32, - mapping: wp.array3d(dtype=wp.int32), - num_neighbors: wp.array2d(dtype=wp.int32), -): - """ - Performs ball query operation to find neighboring points within a specified radius. +import torch +import torch.nn as nn +from einops import rearrange - For each point in points1, finds up to k neighboring points from points2 that are - within the specified radius. Uses a hash grid for efficient spatial queries. +from physicsnemo.utils.neighbors import radius_search - Note that the neighbors found are not strictly guaranteed to be the closest k neighbors, - in the event that more than k neighbors are found within the radius. - Args: - points1: Array of query points - points2: Array of points to search - grid: Pre-computed hash grid for accelerated spatial queries - k: Maximum number of neighbors to find for each query point - radius: Maximum search radius for finding neighbors - mapping: Output array to store indices of neighboring points. Should be instantiated as zeros(1, len(points1), k) - num_neighbors: Output array to store the number of neighbors found for each query point. Should be instantiated as zeros(1, len(points1)) +class BQWarp(nn.Module): """ - tid = wp.tid() - - # Get position from points1 - pos = points1[tid] - - # particle contact - neighbors = wp.hash_grid_query(id=grid, point=pos, max_dist=radius) - - # Keep track of the number of neighbors found - neighbors_found = wp.int32(0) - - # loop through neighbors to compute density - for index in neighbors: - # Check if outside the radius - pos2 = points2[index] - if wp.length(pos - pos2) > radius: - continue - - # Add neighbor to the list - mapping[0, tid, neighbors_found] = index - - # Increment the number of neighbors found - neighbors_found += 1 - - # Break if we have found enough neighbors - if neighbors_found == k: - num_neighbors[0, tid] = k - break - - # Set the number of neighbors - num_neighbors[0, tid] = neighbors_found - - -@wp.kernel -def sparse_ball_query( - points2: wp.array(dtype=wp.vec3), - mapping: wp.array3d(dtype=wp.int32), - num_neighbors: wp.array2d(dtype=wp.int32), - outputs: wp.array4d(dtype=wp.float32), -): - tid = wp.tid() - - # Get number of neighbors - k = num_neighbors[0, tid] - - # Loop through neighbors - for _k in range(k): - # Get point2 index - index = mapping[0, tid, _k] - - # Get position from points2 - pos = points2[index] - - # Set the output - outputs[0, tid, _k, 0] = pos[0] - outputs[0, tid, _k, 1] = pos[1] - outputs[0, tid, _k, 2] = pos[2] - + Warp-based ball-query layer for finding neighboring points within a specified radius. -def _ball_query_forward_primitive_( - points1: torch.Tensor, - points2: torch.Tensor, - k: int, - radius: float, - hash_grid: wp.HashGrid, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Create output tensors: - mapping = torch.zeros( - (1, points1.shape[0], k), - dtype=torch.int32, - device=points1.device, - requires_grad=False, - ) - num_neighbors = torch.zeros( - (1, points1.shape[0]), - dtype=torch.int32, - device=points1.device, - requires_grad=False, - ) - outputs = torch.zeros( - (1, points1.shape[0], k, 3), - dtype=torch.float32, - device=points1.device, - requires_grad=(points1.requires_grad or points2.requires_grad), - ) + This layer uses an accelerated ball query implementation to efficiently find points + within a specified radius of query points. - # Convert from torch to warp - points1 = wp.from_torch(points1, dtype=wp.vec3, requires_grad=points1.requires_grad) - points2 = wp.from_torch(points2, dtype=wp.vec3, requires_grad=points2.requires_grad) - - wp_mapping = wp.from_torch(mapping, dtype=wp.int32, requires_grad=False) - wp_num_neighbors = wp.from_torch(num_neighbors, dtype=wp.int32, requires_grad=False) - wp_outputs = wp.from_torch( - outputs, - dtype=wp.float32, - requires_grad=(points1.requires_grad or points2.requires_grad), - ) - - # Build the grid - hash_grid.build(points2, radius) - - # Run the kernel to get mapping - wp.launch( - ball_query, - inputs=[ - points1, - points2, - hash_grid.id, - k, - radius, - ], - outputs=[ - wp_mapping, - wp_num_neighbors, - ], - dim=[points1.shape[0]], - ) - - # Run the kernel to get outputs - wp.launch( - sparse_ball_query, - inputs=[ - points2, - wp_mapping, - wp_num_neighbors, - ], - outputs=[ - wp_outputs, - ], - dim=[points1.shape[0]], - ) - - return mapping, num_neighbors, outputs - - -def _ball_query_backward_primitive_( - points1, - points2, - mapping, - num_neighbors, - outputs, - grad_mapping, - grad_num_neighbors, - grad_outputs, -) -> Tuple[torch.Tensor, torch.Tensor]: - p2_grad = torch.zeros_like(points2) - - # Run the kernel in adjoint mode - wp.launch( - sparse_ball_query, - inputs=[ - wp.from_torch(points2, dtype=wp.vec3, requires_grad=points2.requires_grad), - wp.from_torch(mapping, dtype=wp.int32, requires_grad=False), - wp.from_torch(num_neighbors, dtype=wp.int32, requires_grad=False), - ], - outputs=[ - wp.from_torch(outputs, dtype=wp.float32, requires_grad=False), - ], - adj_inputs=[ - wp.from_torch(p2_grad, dtype=wp.vec3, requires_grad=points2.requires_grad), - wp.from_torch( - grad_mapping, dtype=wp.int32, requires_grad=mapping.requires_grad - ), - wp.from_torch( - grad_num_neighbors, - dtype=wp.int32, - requires_grad=num_neighbors.requires_grad, - ), - ], - adj_outputs=[ - wp.from_torch(grad_outputs, dtype=wp.float32), - ], - dim=[points1.shape[0]], - adjoint=True, - ) - - return p2_grad - - -class BallQuery(torch.autograd.Function): + Only supports batch size 1. """ - Warp based Ball Query. - - Note: only differentiable with respect to points1 and points2. - """ - - @staticmethod - def forward( - ctx, - points1: torch.Tensor, - points2: torch.Tensor, - k: int, - radius: float, - hash_grid: wp.HashGrid, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Only works for batch size 1 - if points1.shape[0] != 1: - raise AssertionError("Ball Query only works for batch size 1") - - # CJA - 5/15/25 - This was added recently, but it looks like I also - # addressed it. The primitive functions below handle device selection - # via compute-follows-data: they will allocate new tensors on the device - # where points1 currently resides (forward) and points2 resides (backward). - # there isn't checking that the devices match, but it will crash if they do not. - # try: - # device = str(wp.get_device()) - # except Exception: - # device = "cuda" - - ctx.k = k - ctx.radius = radius - - # Make grid - ctx.hash_grid = hash_grid - # Apply the primitive. Note the batch index is removed. - mapping, num_neighbors, outputs = _ball_query_forward_primitive_( - points1[0], - points2[0], - k, - radius, - hash_grid, - ) - ctx.save_for_backward(points1, points2, mapping, num_neighbors, outputs) - - return mapping, num_neighbors, outputs - - @staticmethod - def backward(ctx, grad_mapping, grad_num_neighbors, grad_outputs): - points1, points2, mapping, num_neighbors, outputs = ctx.saved_tensors - # Apply the primitive - p2_grad = _ball_query_backward_primitive_( - points1[0], - points2[0], - mapping, - num_neighbors, - outputs, - grad_mapping, - grad_num_neighbors, - grad_outputs, - ) - p2_grad = p2_grad.unsqueeze(0) - - # Return the gradients - return ( - torch.zeros_like(points1), - p2_grad, - None, - None, - None, - ) - - -def ball_query_layer( - points1: torch.Tensor, - points2: torch.Tensor, - k: int, - radius: float, - hash_grid: wp.HashGrid, -): - """ - Wrapper for BallQuery.apply to support a functional interface. - """ - if has_torch_function((points1, points2)): - return handle_torch_function( - ball_query_layer, (points1, points2), points1, points2, k, radius, hash_grid - ) - return BallQuery.apply(points1, points2, k, radius, hash_grid) - - -class BallQueryLayer(torch.nn.Module): - """ - Torch layer for differentiable and accelerated Ball Query - operation using Warp. - Args: - k (int): Number of neighbors. - radius (float): Radius of influence. - grid_size (int): Resolution of the hash grid. (Assumed to be uniform in all dimensions.) - """ + def __init__( + self, + radius: float = 0.25, + neighbors_in_radius: int | None = 10, + ): + """ + Initialize the BQWarp layer. - def __init__(self, k: int, radius: float, grid_size: int = 32): + Args: + radius: Radius for ball query operation + neighbors_in_radius: Maximum number of neighbors to return within radius. If None, all neighbors will be returned. + """ super().__init__() - wp.init() - self.k = k + self.radius = radius - self.hash_grid = wp.HashGrid(grid_size, grid_size, grid_size) + self.neighbors_in_radius = neighbors_in_radius def forward( - self, points1: torch.Tensor, points2: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, x: torch.Tensor, p_grid: torch.Tensor, reverse_mapping: bool = True + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Performs ball query operation to find neighboring points within a specified radius. + Performs ball query operation to find neighboring points and their features. - For each point in points1, finds up to k neighboring points from points2 that are - within the specified radius. Uses a hash grid for efficient spatial queries. + This method uses the Warp-accelerated ball query implementation to find points + within a specified radius. It can operate in two modes: + - Forward mapping: Find points from x that are near p_grid points (reverse_mapping=False) + - Reverse mapping: Find points from p_grid that are near x points (reverse_mapping=True) Args: - points1: Tensor of shape (batch_size, num_points1, 3) containing query points - points2: Tensor of shape (batch_size, num_points2, 3) containing points to search + x: Tensor of shape (batch_size, num_points, 3+features) containing point coordinates + and their features + p_grid: Tensor of shape (batch_size, grid_x, grid_y, grid_z, 3) containing grid point + coordinates + reverse_mapping: Boolean flag to control the direction of the mapping: + - True: Find p_grid points near x points + - False: Find x points near p_grid points Returns: tuple containing: - mapping: Tensor containing indices of neighboring points - - num_neighbors: Tensor containing the number of neighbors found for each query point - - outputs: Tensor containing features or coordinates of the neighboring points + - outputs: Tensor containing coordinates of the neighboring points """ - return ball_query_layer( - points1, - points2, - self.k, - self.radius, - self.hash_grid, - ) - - -if __name__ == "__main__": - # Make function for saving point clouds - import pyvista as pv - - from physicsnemo.utils.neighbors import radius_search - - radius_search = torch.compile(radius_search) - - torch.random.manual_seed(0) - torch.cuda.manual_seed(0) - - def save_point_cloud(points, name): - cloud = pv.PolyData(points.detach().cpu().numpy()) - cloud.save(name) - - # Check forward pass - # Initialize tensors - n = 1 # number of point clouds - p1 = 1600_000 # 100000 # number of points in point cloud 1 - d = 3 # dimension of the points - p2 = 1600_000 # 100000 # number of points in point cloud 2 - points1 = torch.rand(n, p1, d, device="cuda", requires_grad=False) - - points2 = torch.rand(n, p2, d, device="cuda", requires_grad=True) - k = 256 # maximum number of neighbors - radius = 0.1 - - # Make ball query layer - layer = BallQueryLayer(k, radius) - - # Make ball query - - for i in range(5): - mapping, num_neighbors, outputs = layer( - points1, - points2, - ) - indices, points = radius_search( - points=points2[0], - queries=points1[0], - radius=radius, - max_points=k, - return_dists=False, - return_points=True, - ) - - # sorted_bq_indices = torch.sort(mapping[0][0]).values - # sorted_rs_indices = torch.sort(indices[0]).values - - # print(sorted_bq_indices - sorted_rs_indices) - # print(sorted_bq_indices) - # print(sorted_rs_indices) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - for i in range(25): - if i == 5: - start_event.record() - mapping, num_neighbors, outputs = layer( - points1, - points2, - ) - end_event.record() - torch.cuda.synchronize() - print( - f"Ball Query Time taken: {start_event.elapsed_time(end_event) / 20} ms per iteration" - ) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - for i in range(25): - if i == 5: - torch.cuda.synchronize() - start_event.record() - indices, points = radius_search( - points=points2[0], - queries=points1[0], - radius=radius, - max_points=k, - return_dists=False, - return_points=True, - ) - end_event.record() - torch.cuda.synchronize() - print( - f"Radius Search Time taken: {start_event.elapsed_time(end_event) / 20} ms per iteration" - ) - - # Optimize the background points to move to the query points - optimizer = torch.optim.SGD([points2], 0.00) - - # Test optimization - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - target = points1.unsqueeze(2).clone().detach() - for i in range(25): - if i == 5: - start_event.record() - optimizer.zero_grad() - # mapping, num_neighbors, outputs = layer(points1, points2, lengths1, lengths2) - mapping, num_neighbors, outputs = layer(points1, points2) - # print(mapping[0][3]) - # print(torch.where(mapping == 1)) - loss = (points1.unsqueeze(2) - outputs).pow(2).sum() - loss.backward() - # print(f"ball query Points1 grad: {points1.grad}") - optimizer.step() - optimizer.zero_grad() - - end_event.record() - torch.cuda.synchronize() - print( - f"Ball Query + backwards Time taken: {start_event.elapsed_time(end_event) / 20} ms per iteration" - ) - - # Optimize the background points to move to the query points - optimizer = torch.optim.SGD( - [points2], 0.00 - ) # Setting the LR to 0.0 ensures the same gradients each time - - # Test optimization - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start_event.record() - for i in range(25): - if i == 5: - start_event.record() - optimizer.zero_grad() - # mapping, num_neighbors, outputs = layer(points1, points2, lengths1, lengths2) - indexes, points = radius_search( - points=points2[0], - queries=points1[0], - radius=radius, - max_points=k, - return_dists=False, - return_points=True, - ) - loss = (target - points).pow(2).sum() - loss.backward() - optimizer.step() - # print(f"radius search Points1 grad: {points1.grad}") - optimizer.zero_grad() - end_event.record() - torch.cuda.synchronize() - print( - f"radius search + backwards Time taken: {start_event.elapsed_time(end_event) / 20} ms per iteration" - ) + if x.shape[0] != 1 or p_grid.shape[0] != 1: + raise ValueError("BQWarp only supports batch size 1") + + if p_grid.shape[-1] != x.shape[-1] or x.shape[-1] != 3: + raise ValueError("The last dimension of p_grid and x must be 3") + + if p_grid.ndim != 3: + if p_grid.ndim == 4: + p_grid = rearrange(p_grid, "b nx ny c -> b (nx ny) c") + elif p_grid.ndim == 5: + p_grid = rearrange(p_grid, "b nx ny nz c -> b (nx ny nz) c") + else: + raise ValueError("p_grid must be 3D, 4D, 5D only") + + if reverse_mapping: + mapping, outputs = radius_search( + x[0], + p_grid[0], + self.radius, + self.neighbors_in_radius, + return_points=True, + ) + mapping = mapping.unsqueeze(0) + outputs = outputs.unsqueeze(0) + else: + mapping, outputs = radius_search( + p_grid[0], + x[0], + self.radius, + self.neighbors_in_radius, + return_points=True, + ) + mapping = mapping.unsqueeze(0) + outputs = outputs.unsqueeze(0) + + return mapping, outputs diff --git a/physicsnemo/models/layers/fourier_layers.py b/physicsnemo/models/layers/fourier_layers.py index 35cb4d81a1..ba7db24a68 100644 --- a/physicsnemo/models/layers/fourier_layers.py +++ b/physicsnemo/models/layers/fourier_layers.py @@ -21,6 +21,86 @@ import torch.nn as nn from torch import Tensor +from .mlp_layers import Mlp + + +def fourier_encode(coords: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + """Vectorized Fourier feature encoding + + Args: + coords: Tensor containing coordinates, of shape (batch_size, D) + freqs: Tensor containing frequencies, of shape (F,) (num frequencies) + + Returns: + Tensor containing Fourier features, of shape (batch_size, D * 2 * F) + """ + + D = coords.shape[-1] + F = freqs.shape[0] + + freqs = freqs[None, None, :, None] # reshape to [*, F, 1] for broadcasting + + coords = coords.unsqueeze(-2) # [*, 1, D] + scaled = (coords * freqs).reshape(*coords.shape[:-2], D * F) # [*, D, F] + features = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1) # [*, D, 2F] + + return features.reshape(*coords.shape[:-2], D * 2 * F) # [*, D * 2F] + + +class FourierMLP(nn.Module): + """ + This is an MLP that will, optionally, fourier encode the input features. + + The encoded features are concatenated to the original inputs, and then + processed with an MLP. + + Args: + input_features: The number of input features to the MLP. + base_layer: The number of neurons in the hidden layer of the MLP. + fourier_features: Whether to fourier encode the input features. + num_modes: The number of modes to use for the fourier encoding. + activation: The activation function to use in the MLP. + + """ + + def __init__( + self, + input_features: int, + base_layer: int, + fourier_features: bool, + num_modes: int, + activation: nn.Module | str, + ): + super().__init__() + self.fourier_features = fourier_features + + # self.num_modes = model_parameters.num_modes + + if self.fourier_features: + input_features_calculated = input_features + input_features * num_modes * 2 + self.register_buffer( + "freqs", torch.exp(torch.linspace(0, math.pi, num_modes)) + ) + else: + input_features_calculated = input_features + + self.mlp = Mlp( + in_features=input_features_calculated, + hidden_features=[ + base_layer, + base_layer, + ], + out_features=base_layer, + act_layer=activation, + drop=0.0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.fourier_features: + x = torch.cat((x, fourier_encode(x, self.freqs)), dim=-1) + + return self.mlp(x) + class FourierLayer(nn.Module): """Fourier layer used in the Fourier feature network""" diff --git a/physicsnemo/models/layers/mlp_layers.py b/physicsnemo/models/layers/mlp_layers.py index 8e9a18858b..5c8c3348a3 100644 --- a/physicsnemo/models/layers/mlp_layers.py +++ b/physicsnemo/models/layers/mlp_layers.py @@ -17,28 +17,58 @@ import torch from torch import nn +from .activations import get_activation + class Mlp(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, + in_features: int, + hidden_features: int | list[int] | None = None, + out_features: int | None = None, + act_layer: nn.Module | str = nn.GELU, + drop: float = 0.0, ): super().__init__() out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) + if isinstance(hidden_features, int): + hidden_features = [ + hidden_features, + ] + elif hidden_features is None: + hidden_features = [ + in_features, + ] + + # If the activation is a string, get it. + # If it's a type, instantiate it. + # If it's a module, leave it be. + if isinstance(act_layer, str): + act_layer = get_activation(act_layer) + elif isinstance(act_layer, nn.Module): + pass + else: + act_layer = act_layer() + if not isinstance(act_layer, nn.Module): + raise ValueError( + f"Activation layer must be a string or a module, got {type(act_layer)}" + ) + + layers = [] + input_dim = in_features + for hidden_dim in hidden_features: + layers.append(nn.Linear(input_dim, hidden_dim)) + layers.append(act_layer) + if drop != 0: + layers.append(nn.Dropout(drop)) + input_dim = hidden_dim + + # Add the last layers: + layers.append(nn.Linear(input_dim, out_features)) + if drop != 0: + layers.append(nn.Dropout(drop)) + + self.layers = nn.Sequential(*layers) def forward(self, x: torch.Tensor): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x + return self.layers(x) diff --git a/physicsnemo/utils/domino/__init__.py b/physicsnemo/utils/domino/__init__.py new file mode 100644 index 0000000000..b2f171d4ac --- /dev/null +++ b/physicsnemo/utils/domino/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/physicsnemo/utils/domino/utils.py b/physicsnemo/utils/domino/utils.py index 12ebf7cba3..7f67f36e6c 100644 --- a/physicsnemo/utils/domino/utils.py +++ b/physicsnemo/utils/domino/utils.py @@ -19,93 +19,57 @@ This module provides essential utilities for computational fluid dynamics data processing, mesh manipulation, field normalization, and geometric computations. It supports both -CPU (NumPy) and GPU (CuPy) operations with automatic fallbacks. +torch.Tensor operations on either CPU or GPU. """ from pathlib import Path from typing import Any, Sequence -import numpy as np -import vtk -from scipy.spatial import KDTree -from vtk import vtkDataSetTriangleFilter -from vtk.util import numpy_support +import torch -from physicsnemo.utils.profiling import profile +from physicsnemo.utils.neighbors import knn -# Type alias for arrays that can be either NumPy or CuPy -try: - import cupy as cp - - ArrayType = np.ndarray | cp.ndarray -except ImportError: - ArrayType = np.ndarray - - -def array_type(array: ArrayType) -> "type[np] | type[cp]": - """Determine the array module (NumPy or CuPy) for the given array. - - This function enables array-agnostic code by returning the appropriate - array module that can be used for operations on the input array. - - Args: - array: Input array that can be either NumPy or CuPy array. - - Returns: - The array module (numpy or cupy) corresponding to the input array type. - - Examples: - >>> import numpy as np - >>> arr = np.array([1, 2, 3]) - >>> xp = array_type(arr) - >>> result = xp.sum(arr) # Uses numpy.sum - """ - try: - import cupy as cp - - return cp.get_array_module(array) - except ImportError: - return np - - -def calculate_center_of_mass(centers: ArrayType, sizes: ArrayType) -> ArrayType: +def calculate_center_of_mass( + centers: torch.Tensor, sizes: torch.Tensor +) -> torch.Tensor: """Calculate the center of mass for a collection of elements. Computes the volume-weighted centroid of mesh elements, commonly used in computational fluid dynamics for mesh analysis and load balancing. Args: - centers: Array of shape (n_elements, 3) containing the centroid + centers: torch.Tensor of shape (n_elements, 3) containing the centroid coordinates of each element. - sizes: Array of shape (n_elements,) containing the volume + sizes: torch.Tensor of shape (n_elements,) containing the volume or area of each element used as weights. Returns: - Array of shape (1, 3) containing the x, y, z coordinates of the center of mass. + torch.Tensor of shape (1, 3) containing the x, y, z coordinates of the center of mass. Raises: ValueError: If centers and sizes have incompatible shapes. Examples: - >>> import numpy as np - >>> centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) - >>> sizes = np.array([1.0, 2.0, 3.0]) + >>> import torch + >>> centers = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) + >>> sizes = torch.tensor([1.0, 2.0, 3.0]) >>> com = calculate_center_of_mass(centers, sizes) - >>> np.allclose(com, [[4.0/3.0, 4.0/3.0, 4.0/3.0]]) + >>> torch.allclose(com, torch.tensor([[4.0/3.0, 4.0/3.0, 4.0/3.0]])) True """ - xp = array_type(centers) - total_weighted_position = xp.einsum("i,ij->j", sizes, centers) - total_size = xp.sum(sizes) + total_weighted_position = torch.einsum("i,ij->j", sizes, centers) + total_size = torch.sum(sizes) return total_weighted_position[None, ...] / total_size def normalize( - field: ArrayType, max_val: ArrayType | None = None, min_val: ArrayType | None = None -) -> ArrayType: + field: torch.Tensor, + max_val: float | torch.Tensor | None = None, + min_val: float | torch.Tensor | None = None, +) -> torch.Tensor: """Normalize field values to the range [-1, 1]. Applies min-max normalization to scale field values to a symmetric range @@ -113,7 +77,7 @@ def normalize( ensure numerical stability and faster convergence. Args: - field: Input field array to be normalized. + field: Input field tensor to be normalized. max_val: Maximum values for normalization, can be scalar or array. If None, computed from the field data. min_val: Minimum values for normalization, can be scalar or array. @@ -126,30 +90,31 @@ def normalize( ZeroDivisionError: If max_val equals min_val (zero range). Examples: - >>> import numpy as np - >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> import torch + >>> field = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) >>> normalized = normalize(field, 5.0, 1.0) - >>> np.allclose(normalized, [-1.0, -0.5, 0.0, 0.5, 1.0]) + >>> torch.allclose(normalized, torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])) True >>> # Auto-compute min/max >>> normalized_auto = normalize(field) - >>> np.allclose(normalized_auto, [-1.0, -0.5, 0.0, 0.5, 1.0]) + >>> torch.allclose(normalized_auto, torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])) True """ - xp = array_type(field) if max_val is None: - max_val = xp.max(field, axis=0, keepdims=True) + max_val, _ = field.max(axis=0, keepdim=True) if min_val is None: - min_val = xp.min(field, axis=0, keepdims=True) + min_val, _ = field.min(axis=0, keepdim=True) field_range = max_val - min_val return 2.0 * (field - min_val) / field_range - 1.0 def unnormalize( - normalized_field: ArrayType, max_val: ArrayType, min_val: ArrayType -) -> ArrayType: + normalized_field: torch.Tensor, + max_val: float | torch.Tensor, + min_val: float | torch.Tensor, +) -> torch.Tensor: """Reverse the normalization process to recover original field values. Transforms normalized values from the range [-1, 1] back to their original @@ -164,10 +129,12 @@ def unnormalize( Field values restored to their original physical range. Examples: - >>> import numpy as np - >>> normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) - >>> original = unnormalize(normalized, 5.0, 1.0) - >>> np.allclose(original, [1.0, 2.0, 3.0, 4.0, 5.0]) + >>> import torch + >>> normalized = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0]) + >>> max_val = torch.tensor(5.0) + >>> min_val = torch.tensor(1.0) + >>> original = unnormalize(normalized, max_val, min_val) + >>> torch.allclose(original, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])) True """ field_range = max_val - min_val @@ -175,8 +142,10 @@ def unnormalize( def standardize( - field: ArrayType, mean: ArrayType | None = None, std: ArrayType | None = None -) -> ArrayType: + field: torch.Tensor, + mean: float | torch.Tensor | None = None, + std: float | torch.Tensor | None = None, +) -> torch.Tensor: """Standardize field values to have zero mean and unit variance. Applies z-score normalization to center the data around zero with @@ -184,7 +153,7 @@ def standardize( when the data follows a normal distribution. Args: - field: Input field array to be standardized. + field: Input field tensor to be standardized. mean: Mean values for standardization. If None, computed from field data. std: Standard deviation values for standardization. If None, computed from field data. @@ -195,31 +164,34 @@ def standardize( ZeroDivisionError: If std contains zeros. Examples: - >>> import numpy as np - >>> field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) - >>> standardized = standardize(field, 3.0, np.sqrt(2.5)) - >>> np.allclose(standardized, [-1.265, -0.632, 0.0, 0.632, 1.265], atol=1e-3) + >>> import torch + >>> field = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> mean = torch.tensor(3.0) + >>> std = torch.sqrt(torch.tensor(2.5)) + >>> standardized = standardize(field, mean, std) + >>> torch.allclose(standardized, torch.tensor([-1.265, -0.632, 0.0, 0.632, 1.265]), atol=1e-3) True >>> # Auto-compute mean/std >>> standardized_auto = standardize(field) - >>> np.allclose(np.mean(standardized_auto), 0.0) + >>> torch.allclose(torch.mean(standardized_auto), torch.tensor(0.0)) True - >>> np.allclose(np.std(standardized_auto, ddof=0), 1.0) + >>> torch.allclose(torch.std(standardized_auto), torch.tensor(1.0)) True """ - xp = array_type(field) if mean is None: - mean = xp.mean(field, axis=0, keepdims=True) + mean = field.mean(axis=0, keepdim=True) if std is None: - std = xp.std(field, axis=0, keepdims=True) + std = field.std(axis=0, keepdim=True) return (field - mean) / std def unstandardize( - standardized_field: ArrayType, mean: ArrayType, std: ArrayType -) -> ArrayType: + standardized_field: torch.Tensor, + mean: float | torch.Tensor, + std: float | torch.Tensor, +) -> torch.Tensor: """Reverse the standardization process to recover original field values. Transforms standardized values (zero mean, unit variance) back to their @@ -234,370 +206,22 @@ def unstandardize( Field values restored to their original distribution. Examples: - >>> import numpy as np - >>> standardized = np.array([-1.265, -0.632, 0.0, 0.632, 1.265]) - >>> original = unstandardize(standardized, 3.0, np.sqrt(2.5)) - >>> np.allclose(original, [1.0, 2.0, 3.0, 4.0, 5.0], atol=1e-3) + >>> import torch + >>> standardized = torch.tensor([-1.265, -0.632, 0.0, 0.632, 1.265]) + >>> mean = torch.tensor(3.0) + >>> std = torch.sqrt(torch.tensor(2.5)) + >>> original = unstandardize(standardized, mean, std) + >>> torch.allclose(original, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]), atol=1e-3) True """ return standardized_field * std + mean -def write_to_vtp(polydata: "vtk.vtkPolyData", filename: str) -> None: - """Write VTK polydata to a VTP (VTK PolyData) file format. - - VTP files are XML-based and store polygonal data including points, polygons, - and associated field data. This format is commonly used for surface meshes - in computational fluid dynamics visualization. - - Args: - polydata: VTK polydata object containing mesh geometry and fields. - filename: Output filename with .vtp extension. Directory will be created - if it doesn't exist. - - Raises: - RuntimeError: If writing fails due to file permissions or disk space. - - """ - # Ensure output directory exists - output_path = Path(filename) - output_path.parent.mkdir(parents=True, exist_ok=True) - - writer = vtk.vtkXMLPolyDataWriter() - writer.SetFileName(str(output_path)) - writer.SetInputData(polydata) - - if not writer.Write(): - raise RuntimeError(f"Failed to write polydata to {output_path}") - - -def write_to_vtu(unstructured_grid: "vtk.vtkUnstructuredGrid", filename: str) -> None: - """Write VTK unstructured grid to a VTU (VTK Unstructured Grid) file format. - - VTU files store 3D volumetric meshes with arbitrary cell types including - tetrahedra, hexahedra, and pyramids. This format is essential for storing - finite element analysis results. - - Args: - unstructured_grid: VTK unstructured grid object containing volumetric mesh - geometry and field data. - filename: Output filename with .vtu extension. Directory will be created - if it doesn't exist. - - Raises: - RuntimeError: If writing fails due to file permissions or disk space. - - """ - # Ensure output directory exists - output_path = Path(filename) - output_path.parent.mkdir(parents=True, exist_ok=True) - - writer = vtk.vtkXMLUnstructuredGridWriter() - writer.SetFileName(str(output_path)) - writer.SetInputData(unstructured_grid) - - if not writer.Write(): - raise RuntimeError(f"Failed to write unstructured grid to {output_path}") - - -def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> list[int]: - """Extract surface triangle indices from a tetrahedral mesh. - - This function identifies the boundary faces of a 3D tetrahedral mesh and - returns the vertex indices that form triangular faces on the surface. - This is essential for visualization and boundary condition application. - - Args: - tetrahedral_mesh: VTK unstructured grid containing tetrahedral elements. - - Returns: - List of vertex indices forming surface triangles. Every three consecutive - indices define one triangle. - - Raises: - NotImplementedError: If the surface contains non-triangular faces. - - """ - # Extract the surface using VTK filter - surface_filter = vtk.vtkDataSetSurfaceFilter() - surface_filter.SetInputData(tetrahedral_mesh) - surface_filter.Update() - - # Wrap with PyVista for easier manipulation - import pyvista as pv - - surface_mesh = pv.wrap(surface_filter.GetOutput()) - triangle_indices = [] - - # Process faces - PyVista stores faces as [n_vertices, v1, v2, ..., vn] - faces = surface_mesh.faces.reshape((-1, 4)) - for face in faces: - if face[0] == 3: # Triangle (3 vertices) - triangle_indices.extend([face[1], face[2], face[3]]) - else: - raise NotImplementedError( - f"Non-triangular face found with {face[0]} vertices" - ) - - return triangle_indices - - -def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid": - """Convert surface polydata to a tetrahedral volumetric mesh. - - This function performs tetrahedralization of a surface mesh, creating - a 3D volumetric mesh suitable for finite element analysis. The process - fills the interior of the surface with tetrahedral elements. - - Args: - polydata: VTK polydata representing a closed surface mesh. - - Returns: - VTK unstructured grid containing tetrahedral elements filling the - volume enclosed by the input surface. - - Raises: - RuntimeError: If tetrahedralization fails (e.g., non-manifold surface). - - """ - tetrahedral_filter = vtkDataSetTriangleFilter() - tetrahedral_filter.SetInputData(polydata) - tetrahedral_filter.Update() - - tetrahedral_mesh = tetrahedral_filter.GetOutput() - return tetrahedral_mesh - - -def convert_point_data_to_cell_data(input_data: "vtk.vtkDataSet") -> "vtk.vtkDataSet": - """Convert point-based field data to cell-based field data. - - This function transforms field variables defined at mesh vertices (nodes) - to values defined at cell centers. This conversion is often needed when - switching between different numerical methods or visualization requirements. - - Args: - input_data: VTK dataset with point data to be converted. - - Returns: - VTK dataset with the same geometry but field data moved from points to cells. - Values are typically averaged from the surrounding points. - - """ - point_to_cell_filter = vtk.vtkPointDataToCellData() - point_to_cell_filter.SetInputData(input_data) - point_to_cell_filter.Update() - - return point_to_cell_filter.GetOutput() - - -def get_node_to_elem(polydata: "vtk.vtkDataSet") -> "vtk.vtkDataSet": - """Convert point data to cell data for VTK dataset. - - This function transforms field variables defined at mesh vertices to - values defined at cell centers using VTK's built-in conversion filter. - - Args: - polydata: VTK dataset with point data to be converted. - - Returns: - VTK dataset with field data moved from points to cells. - - """ - point_to_cell_filter = vtk.vtkPointDataToCellData() - point_to_cell_filter.SetInputData(polydata) - point_to_cell_filter.Update() - cell_data = point_to_cell_filter.GetOutput() - return cell_data - - -def get_fields_from_cell( - cell_data: "vtk.vtkCellData", variable_names: list[str] -) -> np.ndarray: - """Extract field variables from VTK cell data. - - This function extracts multiple field variables from VTK cell data and - organizes them into a structured NumPy array. Each variable becomes a - column in the output array. - - Args: - cell_data: VTK cell data object containing field variables. - variable_names: List of variable names to extract from the cell data. - - Returns: - NumPy array of shape (n_cells, n_variables) containing the extracted - field data. Variables are ordered according to the input list. - - Raises: - ValueError: If a requested variable name is not found in the cell data. - - """ - extracted_fields = [] - for variable_name in variable_names: - variable_array = cell_data.GetArray(variable_name) - if variable_array is None: - raise ValueError(f"Variable '{variable_name}' not found in cell data") - - num_tuples = variable_array.GetNumberOfTuples() - field_values = [] - for tuple_idx in range(num_tuples): - variable_value = np.array(variable_array.GetTuple(tuple_idx)) - field_values.append(variable_value) - field_values = np.asarray(field_values) - extracted_fields.append(field_values) - - # Transpose to get shape (n_cells, n_variables) - extracted_fields = np.transpose(np.asarray(extracted_fields), (1, 0)) - return extracted_fields - - -def get_fields( - data_attributes: "vtk.vtkDataSetAttributes", variable_names: list[str] -) -> list[np.ndarray]: - """Extract multiple field variables from VTK data attributes. - - This function extracts field variables from VTK data attributes (either - point data or cell data) and returns them as a list of NumPy arrays. - It handles both point and cell data seamlessly. - - Args: - data_attributes: VTK data attributes object (point data or cell data). - variable_names: List of variable names to extract. - - Returns: - List of NumPy arrays, one for each requested variable. Each array - has shape (n_points/n_cells, n_components) where n_components - depends on the variable (1 for scalars, 3 for vectors, etc.). - - Raises: - ValueError: If a requested variable is not found in the data attributes. - - """ - extracted_fields = [] - for variable_name in variable_names: - try: - vtk_array = data_attributes.GetArray(variable_name) - except ValueError as e: - raise ValueError( - f"Failed to get array '{variable_name}' from the data attributes: {e}" - ) - - # Convert VTK array to NumPy array with proper shape - numpy_array = numpy_support.vtk_to_numpy(vtk_array).reshape( - vtk_array.GetNumberOfTuples(), vtk_array.GetNumberOfComponents() - ) - extracted_fields.append(numpy_array) - - return extracted_fields - - -def get_vertices(polydata: "vtk.vtkPolyData") -> np.ndarray: - """Extract vertex coordinates from VTK polydata object. - - This function converts VTK polydata to a NumPy array containing the 3D - coordinates of all vertices in the mesh. - - Args: - polydata: VTK polydata object containing mesh geometry. - - Returns: - NumPy array of shape (n_points, 3) containing [x, y, z] coordinates - for each vertex. - - """ - vtk_points = polydata.GetPoints() - vertices = numpy_support.vtk_to_numpy(vtk_points.GetData()) - return vertices - - -def get_volume_data( - polydata: "vtk.vtkPolyData", variable_names: list[str] -) -> tuple[np.ndarray, list[np.ndarray]]: - """Extract vertices and field data from 3D volumetric mesh. - - This function extracts both geometric information (vertex coordinates) - and field data from a 3D volumetric mesh. It's commonly used for - processing finite element analysis results. - - Args: - polydata: VTK polydata representing a 3D volumetric mesh. - variable_names: List of field variable names to extract. - - Returns: - Tuple containing: - - Vertex coordinates as NumPy array of shape (n_vertices, 3) - - List of field arrays, one per variable - - """ - vertices = get_vertices(polydata) - point_data = polydata.GetPointData() - fields = get_fields(point_data, variable_names) - - return vertices, fields - - -def get_surface_data( - polydata: "vtk.vtkPolyData", variable_names: list[str] -) -> tuple[np.ndarray, list[np.ndarray], list[tuple[int, int]]]: - """Extract surface mesh data including vertices, fields, and edge connectivity. - - This function extracts comprehensive surface mesh information including - vertex coordinates, field data at vertices, and edge connectivity information. - It's commonly used for processing CFD surface results and boundary conditions. - - Args: - polydata: VTK polydata representing a surface mesh. - variable_names: List of field variable names to extract from the mesh. - - Returns: - Tuple containing: - - Vertex coordinates as NumPy array of shape (n_vertices, 3) - - List of field arrays, one per variable - - List of edge tuples representing mesh connectivity - - Raises: - ValueError: If a requested variable is not found or polygon data is missing. - - """ - points = polydata.GetPoints() - vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) - - point_data = polydata.GetPointData() - fields = [] - for array_name in variable_names: - try: - array = point_data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." - ) - array_data = np.zeros( - (points.GetNumberOfPoints(), array.GetNumberOfComponents()) - ) - for j in range(points.GetNumberOfPoints()): - array.GetTuple(j, array_data[j]) - fields.append(array_data) - - polys = polydata.GetPolys() - if polys is None: - raise ValueError("Failed to get polygons from the polydata.") - polys.InitTraversal() - edges = [] - id_list = vtk.vtkIdList() - for _ in range(polys.GetNumberOfCells()): - polys.GetNextCell(id_list) - num_ids = id_list.GetNumberOfIds() - edges = [ - (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) - ] - - return vertices, fields, edges - - def calculate_normal_positional_encoding( - coordinates_a: ArrayType, - coordinates_b: ArrayType | None = None, + coordinates_a: torch.Tensor, + coordinates_b: torch.Tensor | None = None, cell_dimensions: Sequence[float] = (1.0, 1.0, 1.0), -) -> ArrayType: +) -> torch.Tensor: """Calculate sinusoidal positional encoding for 3D coordinates. This function computes transformer-style positional encodings for 3D spatial @@ -606,51 +230,50 @@ def calculate_normal_positional_encoding( unique representations for each spatial position. Args: - coordinates_a: Primary coordinates array of shape (n_points, 3). + coordinates_a: Primary coordinates tensor of shape (n_points, 3). coordinates_b: Optional secondary coordinates for computing relative positions. If provided, the encoding is computed for (coordinates_a - coordinates_b). cell_dimensions: Characteristic length scales for x, y, z dimensions used for normalization. Defaults to unit dimensions. Returns: - Array of shape (n_points, 12) containing positional encodings with + torch.Tensor of shape (n_points, 12) containing positional encodings with 4 encoding dimensions per spatial axis (x, y, z). Examples: - >>> import numpy as np - >>> coords = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + >>> import torch + >>> coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) >>> cell_size = [0.1, 0.1, 0.1] >>> encoding = calculate_normal_positional_encoding(coords, cell_dimensions=cell_size) >>> encoding.shape - (2, 12) + torch.Size([2, 12]) >>> # Relative positioning example - >>> coords_b = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) + >>> coords_b = torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) >>> encoding_rel = calculate_normal_positional_encoding(coords, coords_b, cell_size) >>> encoding_rel.shape - (2, 12) + torch.Size([2, 12]) """ dx, dy, dz = cell_dimensions[0], cell_dimensions[1], cell_dimensions[2] - xp = array_type(coordinates_a) if coordinates_b is not None: normals = coordinates_a - coordinates_b - pos_x = xp.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) - pos_y = xp.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) - pos_z = xp.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) - pos_normals = xp.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + pos_x = torch.cat(calculate_pos_encoding(normals[:, 0] / dx, d=4), dim=-1) + pos_y = torch.cat(calculate_pos_encoding(normals[:, 1] / dy, d=4), dim=-1) + pos_z = torch.cat(calculate_pos_encoding(normals[:, 2] / dz, d=4), dim=-1) + pos_normals = torch.cat((pos_x, pos_y, pos_z), dim=0).reshape(-1, 12) else: normals = coordinates_a - pos_x = xp.asarray(calculate_pos_encoding(normals[:, 0] / dx, d=4)) - pos_y = xp.asarray(calculate_pos_encoding(normals[:, 1] / dy, d=4)) - pos_z = xp.asarray(calculate_pos_encoding(normals[:, 2] / dz, d=4)) - pos_normals = xp.concatenate((pos_x, pos_y, pos_z), axis=0).reshape(-1, 12) + pos_x = torch.cat(calculate_pos_encoding(normals[:, 0] / dx, d=4), dim=-1) + pos_y = torch.cat(calculate_pos_encoding(normals[:, 1] / dy, d=4), dim=-1) + pos_z = torch.cat(calculate_pos_encoding(normals[:, 2] / dz, d=4), dim=-1) + pos_normals = torch.cat((pos_x, pos_y, pos_z), dim=0).reshape(-1, 12) return pos_normals def nd_interpolator( - coordinates: ArrayType, field: ArrayType, grid: ArrayType, k: int = 2 -) -> ArrayType: + coordinates: torch.Tensor, field: torch.Tensor, grid: torch.Tensor, k: int = 2 +) -> torch.Tensor: """Perform n-dimensional interpolation using k-nearest neighbors. This function interpolates field values from scattered points to a regular @@ -658,189 +281,256 @@ def nd_interpolator( fields on regular grids from irregular measurement points. Args: - coordinates: Array of shape (n_points, n_dims) containing source point coordinates. - field: Array of shape (n_points, n_fields) containing field values at source points. - grid: Array of shape (n_field_points, n_dims) containing target grid points for interpolation. + coordinates: torch.Tensor of shape (n_points, n_dims) containing source point coordinates. + field: torch.Tensor of shape (n_points, n_fields) containing field values at source points. + grid: torch.Tensor of shape (n_field_points, n_dims) containing target grid points for interpolation. k: Number of nearest neighbors to use for interpolation. Returns: Interpolated field values at grid points using k-nearest neighbor averaging. - Note: - This function currently uses SciPy's KDTree which only supports CPU arrays. - A future enhancement could add CuML support for GPU acceleration. Examples: - >>> import numpy as np + >>> import torch >>> # Simple 2D interpolation example - >>> coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) - >>> field_vals = np.array([[1.0], [2.0], [3.0], [4.0]]) - >>> grid_points = np.array([[0.5, 0.5]]) - >>> result = nd_interpolator([coords], field_vals, grid_points) + >>> coords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + >>> field_vals = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) + >>> grid_points = torch.tensor([[0.5, 0.5]]) + >>> result = nd_interpolator(coords, field_vals, grid_points) >>> result.shape[0] == 1 # One grid point True """ - # TODO - this function should get updated for cuml if using cupy. - kdtree = KDTree(coordinates[0]) - distances, neighbor_indices = kdtree.query(grid, k=k) + neighbor_indices, distances = knn(coordinates, grid, k=k) field_grid = field[neighbor_indices] - field_grid = np.mean(field_grid, axis=1) + field_grid = torch.mean(field_grid, dim=1) return field_grid -def pad(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: - """Pad 2D array with constant values to reach target size. +def pad(arr: torch.Tensor, n_points: int, pad_value: float = 0.0) -> torch.Tensor: + """Pad 2D tensor with constant values to reach target size. - This function extends a 2D array by adding rows filled with a constant - value. It's commonly used to standardize array sizes in batch processing + This function extends a 2D tensor by adding rows filled with a constant + value. It's commonly used to standardize tensor sizes in batch processing for machine learning applications. Args: - arr: Input array of shape (n_points, n_features) to be padded. + arr: Input tensor of shape (n_points, n_features) to be padded. n_points: Target number of points (rows) after padding. pad_value: Constant value used for padding. Defaults to 0.0. Returns: - Padded array of shape (n_points, n_features). If n_points <= arr.shape[0], - returns the original array unchanged. + Padded tensor of shape (n_points, n_features). If n_points <= arr.shape[0], + returns the original tensor unchanged. Examples: - >>> import numpy as np - >>> arr = np.array([[1.0, 2.0], [3.0, 4.0]]) + >>> import torch + >>> arr = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) >>> padded = pad(arr, 4, -1.0) >>> padded.shape - (4, 2) - >>> np.array_equal(padded[:2], arr) + torch.Size([4, 2]) + >>> torch.allclose(padded[:2], arr) True - >>> bool(np.all(padded[2:] == -1.0)) + >>> bool(torch.all(padded[2:] == -1.0)) True >>> # No padding needed >>> same = pad(arr, 2) - >>> np.array_equal(same, arr) + >>> torch.allclose(same, arr) True """ - xp = array_type(arr) + if n_points <= arr.shape[0]: return arr - arr_pad = pad_value * xp.ones( - (n_points - arr.shape[0], arr.shape[1]), dtype=xp.float32 + n_pad = n_points - arr.shape[0] + arr_padded = torch.nn.functional.pad( + arr, + ( + 0, + 0, + 0, + n_pad, + ), + mode="constant", + value=pad_value, ) - arr_padded = xp.concatenate((arr, arr_pad), axis=0) return arr_padded -def pad_inp(arr: ArrayType, n_points: int, pad_value: float = 0.0) -> ArrayType: - """Pad 3D array with constant values to reach target size. +def pad_inp(arr: torch.Tensor, n_points: int, pad_value: float = 0.0) -> torch.Tensor: + """Pad 3D tensor with constant values to reach target size. - This function extends a 3D array by adding entries along the first dimension + This function extends a 3D tensor by adding entries along the first dimension filled with a constant value. Used for standardizing 3D tensor sizes in batch processing workflows. Args: - arr: Input array of shape (n_points, height, width) to be padded. + arr: Input tensor of shape (n_points, height, width) to be padded. n_points: Target number of points along first dimension after padding. pad_value: Constant value used for padding. Defaults to 0.0. Returns: - Padded array of shape (n_points, height, width). If n_points <= arr.shape[0], - returns the original array unchanged. + Padded tensor of shape (n_points, height, width). If n_points <= arr.shape[0], + returns the original tensor unchanged. Examples: - >>> import numpy as np - >>> arr = np.array([[[1.0, 2.0]], [[3.0, 4.0]]]) + >>> import torch + >>> arr = torch.tensor([[[1.0, 2.0]], [[3.0, 4.0]]]) >>> padded = pad_inp(arr, 4, 0.0) >>> padded.shape - (4, 1, 2) - >>> np.array_equal(padded[:2], arr) + torch.Size([4, 1, 2]) + >>> torch.allclose(padded[:2], arr) True - >>> bool(np.all(padded[2:] == 0.0)) + >>> bool(torch.all(padded[2:] == 0.0)) True """ - xp = array_type(arr) if n_points <= arr.shape[0]: return arr - arr_pad = pad_value * xp.ones( - (n_points - arr.shape[0], arr.shape[1], arr.shape[2]), dtype=xp.float32 + n_pad = n_points - arr.shape[0] + arr_padded = torch.nn.functional.pad( + arr, + ( + 0, + 0, + 0, + 0, + 0, + n_pad, + ), + mode="constant", + value=pad_value, ) - arr_padded = xp.concatenate((arr, arr_pad), axis=0) return arr_padded -@profile def shuffle_array( - arr: ArrayType, + points: torch.Tensor, n_points: int, -) -> tuple[ArrayType, ArrayType]: - """Randomly sample points from array without replacement. + weights: torch.Tensor = None, +): + """ + Randomly sample points from tensor without replacement. - This function performs random sampling from the input array, selecting + This function performs random sampling from the input tensor, selecting n_points points without replacement. It's commonly used for creating training subsets and data augmentation in machine learning workflows. + Optionally, you can provide weights to use in the sampling. + + Note: the implementation with torch.multinomial is constrained to 2^24 points. + If the input is larger than that, it will be split and sampled from each chunk. + Args: - arr: Input array to sample from, shape (n_points, ...). + points: Input tensor to sample from, shape (n_points, ...). n_points: Number of points to sample. If greater than arr.shape[0], all points are returned. + weights: Optional weights for sampling. If None, uniform weights are used. Returns: Tuple containing: - - Sampled array subset + - Sampled tensor subset - Indices of the selected points Examples: - >>> import numpy as np - >>> np.random.seed(42) # For reproducible results - >>> data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> import torch + >>> _ = torch.manual_seed(42) # For reproducible results + >>> data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) >>> subset, indices = shuffle_array(data, 2) >>> subset.shape - (2, 2) + torch.Size([2, 2]) >>> indices.shape - (2,) - >>> len(np.unique(indices)) == 2 # No duplicates + torch.Size([2]) + >>> len(torch.unique(indices)) == 2 # No duplicates True """ - xp = array_type(arr) - if n_points > arr.shape[0]: - # If asking too many points, truncate the ask but still shuffle. - n_points = arr.shape[0] - idx = np.random.choice(arr.shape[0], size=n_points, replace=False) - idx = xp.asarray(idx) - return arr[idx], idx + N_input_points = points.shape[0] + + if N_input_points < n_points: + return points, torch.arange(N_input_points) + + # If there are no weights, use uniform weights: + if weights is None: + weights = torch.ones(points.shape[0], device=points.device) + + # Using torch multinomial for this. + # Multinomial can't work with more than 2^24 input points. + + # So apply chunking and stich back together in that case. + # Assume each chunk gets a number proportional to it's size, + # (but make sure they add up to n_points!) + + max_chunk_size = 2**24 + + N_chunks = (N_input_points // max_chunk_size) + 1 + + # Divide the weights into these chunks + chunk_weights = torch.chunk(weights, N_chunks) + + # Determine how mant points to compute per chunk: + points_per_chunk = [ + round(n_points * c.shape[0] / N_input_points) for c in chunk_weights + ] -def shuffle_array_without_sampling(arr: ArrayType) -> tuple[ArrayType, ArrayType]: - """Shuffle array order without changing the number of elements. + gap = n_points - sum(points_per_chunk) - This function reorders all elements in the array randomly while preserving + if gap > 0: + for g in range(gap): + points_per_chunk[g] += 1 + elif gap < 0: + for g in range(-gap): + points_per_chunk[g] -= 1 + + # Create a list of indexes per chunk: + idx_chunks = [ + torch.multinomial( + w, + p, + replacement=False, + ) + for w, p in zip(chunk_weights, points_per_chunk) + ] + + # Stitch the chunks back together: + idx = torch.cat(idx_chunks) + + # Apply the selection: + points_selected = points[idx] + + return points_selected, idx + + +def shuffle_array_without_sampling( + arr: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Shuffle tensor order without changing the number of elements. + + This function reorders all elements in the tensor randomly while preserving all data points. It's useful for randomizing data order before training while maintaining the complete dataset. Args: - arr: Input array to shuffle, shape (n_points, ...). + arr: Input tensor to shuffle, shape (n_points, ...). Returns: Tuple containing: - - Shuffled array with same shape as input + - Shuffled tensor with same shape as input - Permutation indices used for shuffling Examples: - >>> import numpy as np - >>> np.random.seed(42) # For reproducible results - >>> data = np.array([[1], [2], [3], [4]]) + >>> import torch + >>> _ = torch.manual_seed(42) # For reproducible results + >>> data = torch.tensor([[1], [2], [3], [4]]) >>> shuffled, indices = shuffle_array_without_sampling(data) >>> shuffled.shape - (4, 1) + torch.Size([4, 1]) >>> indices.shape - (4,) - >>> set(indices) == set(range(4)) # All original indices present + torch.Size([4]) + >>> set(indices.tolist()) == set(range(4)) # All original indices present True """ - xp = array_type(arr) - idx = xp.arange(arr.shape[0]) - xp.random.shuffle(idx) + idx = torch.randperm(arr.shape[0]) return arr[idx], idx @@ -892,7 +582,7 @@ def get_filenames(filepath: str | Path, exclude_dirs: bool = False) -> list[str] return filenames -def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: +def calculate_pos_encoding(nx: torch.Tensor, d: int = 8) -> list[torch.Tensor]: """Calculate sinusoidal positional encoding for transformer architectures. This function computes positional encodings using alternating sine and cosine @@ -904,12 +594,12 @@ def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: d: Encoding dimensionality. Must be even number. Defaults to 8. Returns: - List of d arrays containing alternating sine and cosine encodings. + List of d tensors containing alternating sine and cosine encodings. Each pair (sin, cos) uses progressively lower frequencies. Examples: - >>> import numpy as np - >>> positions = np.array([0.0, 1.0, 2.0]) + >>> import torch + >>> positions = torch.tensor([0.0, 1.0, 2.0]) >>> encodings = calculate_pos_encoding(positions, d=4) >>> len(encodings) 4 @@ -917,10 +607,9 @@ def calculate_pos_encoding(nx: ArrayType, d: int = 8) -> list[ArrayType]: True """ vec = [] - xp = array_type(nx) for k in range(int(d / 2)): - vec.append(xp.sin(nx / 10000 ** (2 * k / d))) - vec.append(xp.cos(nx / 10000 ** (2 * k / d))) + vec.append(torch.sin(nx / 10000 ** (2 * k / d))) + vec.append(torch.cos(nx / 10000 ** (2 * k / d))) return vec @@ -957,8 +646,8 @@ def combine_dict(old_dict: dict[Any, Any], new_dict: dict[Any, Any]) -> dict[Any def create_grid( - max_coords: ArrayType, min_coords: ArrayType, resolution: ArrayType -) -> ArrayType: + max_coords: torch.Tensor, min_coords: torch.Tensor, resolution: torch.Tensor +) -> torch.Tensor: """Create a 3D regular grid from coordinate bounds and resolution. This function generates a regular 3D grid spanning from min_coords to @@ -971,46 +660,46 @@ def create_grid( resolution: Number of grid points [nx, ny, nz] in each dimension. Returns: - Grid array of shape (nx, ny, nz, 3) containing 3D coordinates for each + Grid tensor of shape (nx, ny, nz, 3) containing 3D coordinates for each grid point. The last dimension contains [x, y, z] coordinates. Examples: - >>> import numpy as np - >>> min_bounds = np.array([0.0, 0.0, 0.0]) - >>> max_bounds = np.array([1.0, 1.0, 1.0]) - >>> grid_res = np.array([2, 2, 2]) + >>> import torch + >>> min_bounds = torch.tensor([0.0, 0.0, 0.0]) + >>> max_bounds = torch.tensor([1.0, 1.0, 1.0]) + >>> grid_res = torch.tensor([2, 2, 2]) >>> grid = create_grid(max_bounds, min_bounds, grid_res) >>> grid.shape - (2, 2, 2, 3) - >>> np.allclose(grid[0, 0, 0], [0.0, 0.0, 0.0]) + torch.Size([2, 2, 2, 3]) + >>> torch.allclose(grid[0, 0, 0], torch.tensor([0.0, 0.0, 0.0])) True - >>> np.allclose(grid[1, 1, 1], [1.0, 1.0, 1.0]) + >>> torch.allclose(grid[1, 1, 1], torch.tensor([1.0, 1.0, 1.0])) True """ - xp = array_type(max_coords) - - dx = xp.linspace( - min_coords[0], max_coords[0], resolution[0], dtype=max_coords.dtype - ) - dy = xp.linspace( - min_coords[1], max_coords[1], resolution[1], dtype=max_coords.dtype - ) - dz = xp.linspace( - min_coords[2], max_coords[2], resolution[2], dtype=max_coords.dtype - ) + # Linspace to make evenly spaced steps along each axis: + dd = [ + torch.linspace( + min_coords[i], + max_coords[i], + resolution[i], + dtype=max_coords.dtype, + device=max_coords.device, + ) + for i in range(3) + ] - xv, yv, zv = xp.meshgrid(dx, dy, dz) - xv = xp.expand_dims(xv, -1) - yv = xp.expand_dims(yv, -1) - zv = xp.expand_dims(zv, -1) - grid = xp.concatenate((xv, yv, zv), axis=-1) - grid = xp.transpose(grid, (1, 0, 2, 3)) + # Combine them with meshgrid: + xv, yv, zv = torch.meshgrid(*dd, indexing="ij") + xv = xv.unsqueeze(-1) + yv = yv.unsqueeze(-1) + zv = zv.unsqueeze(-1) + grid = torch.concatenate((xv, yv, zv), axis=-1) return grid def mean_std_sampling( - field: ArrayType, mean: ArrayType, std: ArrayType, tolerance: float = 3.0 + field: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, tolerance: float = 3.0 ) -> list[int]: """Identify outlier points based on statistical distance from mean. @@ -1019,7 +708,7 @@ def mean_std_sampling( It's useful for data cleaning and identifying regions of interest in CFD data. Args: - field: Input field array of shape (n_points, n_components). + field: Input field tensor of shape (n_points, n_components). mean: Mean values for each field component, shape (n_components,). std: Standard deviation for each component, shape (n_components,). tolerance: Number of standard deviations to use as outlier threshold. @@ -1029,20 +718,20 @@ def mean_std_sampling( List of indices identifying outlier points that exceed the statistical threshold. Examples: - >>> import numpy as np + >>> import torch >>> # Create test data with outliers - >>> field = np.array([[1.0], [2.0], [3.0], [10.0]]) # 10.0 is outlier - >>> field_mean = np.array([2.0]) - >>> field_std = np.array([1.0]) + >>> field = torch.tensor([[1.0], [2.0], [3.0], [10.0]]) # 10.0 is outlier + >>> field_mean = torch.tensor([2.0]) + >>> field_std = torch.tensor([1.0]) >>> outliers = mean_std_sampling(field, field_mean, field_std, 2.0) >>> 3 in outliers # Index 3 (value 10.0) should be detected as outlier True """ - xp = array_type(field) + idx_all = [] for v in range(field.shape[-1]): fv = field[:, v] - idx = xp.where( + idx = torch.where( (fv > mean[v] + tolerance * std[v]) | (fv < mean[v] - tolerance * std[v]) ) if len(idx[0]) != 0: @@ -1086,16 +775,16 @@ def dict_to_device( def area_weighted_shuffle_array( - arr: ArrayType, n_points: int, area: ArrayType, area_factor: float = 1.0 -) -> tuple[ArrayType, ArrayType]: - """Perform area-weighted random sampling from array. + arr: torch.Tensor, n_points: int, area: torch.Tensor, area_factor: float = 1.0 +) -> tuple[torch.Tensor, torch.Tensor]: + """Perform area-weighted random sampling from tensor. - This function samples points from an array with probability proportional to + This function samples points from a tensor with probability proportional to their associated area weights. This is particularly useful in CFD applications where larger cells or surface elements should have higher sampling probability. Args: - arr: Input array to sample from, shape (n_points, ...). + arr: Input tensor to sample from, shape (n_points, ...). n_points: Number of points to sample. If greater than arr.shape[0], samples all available points. area: Area weights for each point, shape (n_points,). Larger values @@ -1106,64 +795,51 @@ def area_weighted_shuffle_array( Returns: Tuple containing: - - Sampled array subset weighted by area + - Sampled tensor subset weighted by area - Indices of the selected points Note: - For GPU arrays (CuPy), the sampling is performed on CPU due to memory - efficiency considerations. The Alias method could be implemented for - future GPU acceleration. + For GPU tensors, the sampling is performed on the current device. + The sampling uses torch.multinomial for efficient weighted sampling. Examples: - >>> import numpy as np - >>> np.random.seed(42) # For reproducible results - >>> mesh_data = np.array([[1.0], [2.0], [3.0], [4.0]]) - >>> cell_areas = np.array([0.1, 0.1, 0.1, 10.0]) # Last point has much larger area + >>> import torch + >>> _ = torch.manual_seed(42) # For reproducible results + >>> mesh_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) + >>> cell_areas = torch.tensor([0.1, 0.1, 0.1, 10.0]) # Last point has much larger area >>> subset, indices = area_weighted_shuffle_array(mesh_data, 2, cell_areas) >>> subset.shape - (2, 1) + torch.Size([2, 1]) >>> indices.shape - (2,) + torch.Size([2]) >>> # The point with large area (index 3) should likely be selected >>> len(set(indices)) <= 2 # At most 2 unique indices True >>> # Use higher area_factor for stronger bias toward large areas >>> subset_biased, _ = area_weighted_shuffle_array(mesh_data, 2, cell_areas, area_factor=2.0) """ - xp = array_type(arr) + # Calculate area-weighted probabilities sampling_probabilities = area**area_factor - sampling_probabilities /= xp.sum(sampling_probabilities) # Normalize to sum to 1 - - # Ensure we don't request more points than available - n_points = min(n_points, arr.shape[0]) - - # Create index array for all available points - point_indices = xp.arange(arr.shape[0]) - - if xp != np: - point_indices = point_indices.get() - sampling_probabilities = sampling_probabilities.get() - - selected_indices = np.random.choice( - point_indices, n_points, p=sampling_probabilities - ) - selected_indices = xp.asarray(selected_indices) + sampling_probabilities /= sampling_probabilities.sum() # Normalize to sum to 1 - return arr[selected_indices], selected_indices + return shuffle_array(arr, n_points, sampling_probabilities) def solution_weighted_shuffle_array( - arr: ArrayType, n_points: int, solution: ArrayType, scaling_factor: float = 1.0 -) -> tuple[ArrayType, ArrayType]: - """Perform solution-weighted random sampling from array. + arr: torch.Tensor, + n_points: int, + solution: torch.Tensor, + scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + """Perform solution-weighted random sampling from tensor. - This function samples points from an array with probability proportional to + This function samples points from a tensor with probability proportional to their associated solution weights. This is particularly useful in CFD applications where larger cells or surface elements should have higher sampling probability. Args: - arr: Input array to sample from, shape (n_points, ...). + arr: Input tensor to sample from, shape (n_points, ...). n_points: Number of points to sample. If greater than arr.shape[0], samples all available points. solution: Solution weights for each point, shape (n_points,). Larger values @@ -1174,48 +850,100 @@ def solution_weighted_shuffle_array( Returns: Tuple containing: - - Sampled array subset weighted by solution fields + - Sampled tensor subset weighted by solution fields - Indices of the selected points Note: - For GPU arrays (CuPy), the sampling is performed on CPU due to memory - efficiency considerations. The Alias method could be implemented for - future GPU acceleration. + For GPU tensors, the sampling is performed on the current device. + The sampling uses torch.multinomial for efficient weighted sampling. Examples: - >>> import numpy as np - >>> np.random.seed(42) # For reproducible results - >>> mesh_data = np.array([[1.0], [2.0], [3.0], [4.0]]) - >>> solution = np.array([0.1, 0.1, 0.1, 10.0]) # Last point has much larger solution field + >>> import torch + >>> _ = torch.manual_seed(42) # For reproducible results + >>> mesh_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) + >>> solution = torch.tensor([0.1, 0.1, 0.1, 10.0]) # Last point has much larger solution field >>> subset, indices = solution_weighted_shuffle_array(mesh_data, 2, solution) >>> subset.shape - (2, 1) + torch.Size([2, 1]) >>> indices.shape - (2,) + torch.Size([2]) >>> # The point with large area (index 3) should likely be selected >>> len(set(indices)) <= 2 # At most 2 unique indices True >>> # Use higher scaling_factor for stronger bias toward large solution fields >>> subset_biased, _ = solution_weighted_shuffle_array(mesh_data, 2, solution, scaling_factor=2.0) """ - xp = array_type(arr) + # Calculate solution-weighted probabilities sampling_probabilities = solution**scaling_factor - sampling_probabilities /= xp.sum(sampling_probabilities) # Normalize to sum to 1 + sampling_probabilities /= sampling_probabilities.sum() # Normalize to sum to 1 + + return shuffle_array(arr, n_points, sampling_probabilities) + - # Ensure we don't request more points than available - n_points = min(n_points, arr.shape[0]) +def sample_points_on_mesh( + mesh_coordinates: torch.Tensor, + mesh_faces: torch.Tensor, + n_points: int, + mesh_areas: torch.Tensor | None = None, + mesh_normals: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Uniformly sample points on a mesh. + + Will use area-weighted sampling to select mesh regions, then uniform + sampling within those triangles. + """ + + # First, if we don't have the areas, compute them: + faces_reshaped = mesh_faces.reshape(-1, 3) + + if mesh_areas is None or mesh_normals is None: + # We have to do 90% of the work for both of these, + # to get either. So check at the last minute: + faces_reshaped_p0 = faces_reshaped[:, 0] + faces_reshaped_p1 = faces_reshaped[:, 1] + faces_reshaped_p2 = faces_reshaped[:, 2] + d1 = mesh_coordinates[faces_reshaped_p1] - mesh_coordinates[faces_reshaped_p0] + d2 = mesh_coordinates[faces_reshaped_p2] - mesh_coordinates[faces_reshaped_p0] + inferred_mesh_normals = torch.linalg.cross(d1, d2, dim=1) + normals_norm = torch.linalg.norm(inferred_mesh_normals, dim=1) + inferred_mesh_normals = inferred_mesh_normals / normals_norm.unsqueeze(1) + if mesh_normals is None: + mesh_normals = inferred_mesh_normals + if mesh_areas is None: + mesh_areas = 0.5 * normals_norm + + # Next, use the areas to compute a weighted sampling of the triangles: + target_triangles = torch.multinomial( + mesh_areas, + n_points, + replacement=True, + ) + + target_faces = faces_reshaped[target_triangles] - # Create index array for all available points - point_indices = xp.arange(arr.shape[0]) + # Next, generate random points within each selected triangle. + # We'll map two uniform distributions to the points in the triangles. + # See https://stackoverflow.com/questions/47410054/generate-random-locations-within-a-triangular-domain + # and the original reference https://www.cs.princeton.edu/%7Efunk/tog02.pdf + # for more information + r1 = torch.rand((n_points, 1), device=mesh_coordinates.device) + r2 = torch.rand((n_points, 1), device=mesh_coordinates.device) - if xp != np: - point_indices = point_indices.get() - sampling_probabilities = sampling_probabilities.get() + s1 = torch.sqrt(r1) - selected_indices = np.random.choice( - point_indices, n_points, p=sampling_probabilities + local_coords = torch.stack( + (1.0 - s1, (1.0 - r2) * s1, r2 * s1), + dim=1, ) - selected_indices = xp.asarray(selected_indices) - return arr[selected_indices], selected_indices + barycentric_coordinates = torch.sum( + mesh_coordinates[target_faces] * local_coords, dim=1 + ) + + # Apply the selection to the other tensors, too: + target_areas = mesh_areas[target_triangles] + target_normals = mesh_normals[target_triangles] + + return barycentric_coordinates, target_triangles, target_areas, target_normals diff --git a/physicsnemo/utils/domino/vtk_file_utils.py b/physicsnemo/utils/domino/vtk_file_utils.py new file mode 100644 index 0000000000..cdde402f8c --- /dev/null +++ b/physicsnemo/utils/domino/vtk_file_utils.py @@ -0,0 +1,380 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for data processing and training with the DoMINO model architecture. + +This module provides essential utilities for computational fluid dynamics data processing, +mesh manipulation, field normalization, and geometric computations. It supports both +CPU (NumPy) and GPU (CuPy) operations with automatic fallbacks. +""" + +from pathlib import Path + +import numpy as np +import vtk +from vtk import vtkDataSetTriangleFilter +from vtk.util import numpy_support + + +def write_to_vtp(polydata: "vtk.vtkPolyData", filename: str) -> None: + """Write VTK polydata to a VTP (VTK PolyData) file format. + + VTP files are XML-based and store polygonal data including points, polygons, + and associated field data. This format is commonly used for surface meshes + in computational fluid dynamics visualization. + + Args: + polydata: VTK polydata object containing mesh geometry and fields. + filename: Output filename with .vtp extension. Directory will be created + if it doesn't exist. + + Raises: + RuntimeError: If writing fails due to file permissions or disk space. + + """ + # Ensure output directory exists + output_path = Path(filename) + output_path.parent.mkdir(parents=True, exist_ok=True) + + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName(str(output_path)) + writer.SetInputData(polydata) + + if not writer.Write(): + raise RuntimeError(f"Failed to write polydata to {output_path}") + + +def write_to_vtu(unstructured_grid: "vtk.vtkUnstructuredGrid", filename: str) -> None: + """Write VTK unstructured grid to a VTU (VTK Unstructured Grid) file format. + + VTU files store 3D volumetric meshes with arbitrary cell types including + tetrahedra, hexahedra, and pyramids. This format is essential for storing + finite element analysis results. + + Args: + unstructured_grid: VTK unstructured grid object containing volumetric mesh + geometry and field data. + filename: Output filename with .vtu extension. Directory will be created + if it doesn't exist. + + Raises: + RuntimeError: If writing fails due to file permissions or disk space. + + """ + # Ensure output directory exists + output_path = Path(filename) + output_path.parent.mkdir(parents=True, exist_ok=True) + + writer = vtk.vtkXMLUnstructuredGridWriter() + writer.SetFileName(str(output_path)) + writer.SetInputData(unstructured_grid) + + if not writer.Write(): + raise RuntimeError(f"Failed to write unstructured grid to {output_path}") + + +def extract_surface_triangles(tetrahedral_mesh: "vtk.vtkUnstructuredGrid") -> list[int]: + """Extract surface triangle indices from a tetrahedral mesh. + + This function identifies the boundary faces of a 3D tetrahedral mesh and + returns the vertex indices that form triangular faces on the surface. + This is essential for visualization and boundary condition application. + + Args: + tetrahedral_mesh: VTK unstructured grid containing tetrahedral elements. + + Returns: + List of vertex indices forming surface triangles. Every three consecutive + indices define one triangle. + + Raises: + NotImplementedError: If the surface contains non-triangular faces. + + """ + # Extract the surface using VTK filter + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputData(tetrahedral_mesh) + surface_filter.Update() + + # Wrap with PyVista for easier manipulation + import pyvista as pv + + surface_mesh = pv.wrap(surface_filter.GetOutput()) + triangle_indices = [] + + # Process faces - PyVista stores faces as [n_vertices, v1, v2, ..., vn] + faces = surface_mesh.faces.reshape((-1, 4)) + for face in faces: + if face[0] == 3: # Triangle (3 vertices) + triangle_indices.extend([face[1], face[2], face[3]]) + else: + raise NotImplementedError( + f"Non-triangular face found with {face[0]} vertices" + ) + + return triangle_indices + + +def convert_to_tet_mesh(polydata: "vtk.vtkPolyData") -> "vtk.vtkUnstructuredGrid": + """Convert surface polydata to a tetrahedral volumetric mesh. + + This function performs tetrahedralization of a surface mesh, creating + a 3D volumetric mesh suitable for finite element analysis. The process + fills the interior of the surface with tetrahedral elements. + + Args: + polydata: VTK polydata representing a closed surface mesh. + + Returns: + VTK unstructured grid containing tetrahedral elements filling the + volume enclosed by the input surface. + + Raises: + RuntimeError: If tetrahedralization fails (e.g., non-manifold surface). + + """ + tetrahedral_filter = vtkDataSetTriangleFilter() + tetrahedral_filter.SetInputData(polydata) + tetrahedral_filter.Update() + + tetrahedral_mesh = tetrahedral_filter.GetOutput() + return tetrahedral_mesh + + +def convert_point_data_to_cell_data(input_data: "vtk.vtkDataSet") -> "vtk.vtkDataSet": + """Convert point-based field data to cell-based field data. + + This function transforms field variables defined at mesh vertices (nodes) + to values defined at cell centers. This conversion is often needed when + switching between different numerical methods or visualization requirements. + + Args: + input_data: VTK dataset with point data to be converted. + + Returns: + VTK dataset with the same geometry but field data moved from points to cells. + Values are typically averaged from the surrounding points. + + """ + point_to_cell_filter = vtk.vtkPointDataToCellData() + point_to_cell_filter.SetInputData(input_data) + point_to_cell_filter.Update() + + return point_to_cell_filter.GetOutput() + + +def get_node_to_elem(polydata: "vtk.vtkDataSet") -> "vtk.vtkDataSet": + """Convert point data to cell data for VTK dataset. + + This function transforms field variables defined at mesh vertices to + values defined at cell centers using VTK's built-in conversion filter. + + Args: + polydata: VTK dataset with point data to be converted. + + Returns: + VTK dataset with field data moved from points to cells. + + """ + point_to_cell_filter = vtk.vtkPointDataToCellData() + point_to_cell_filter.SetInputData(polydata) + point_to_cell_filter.Update() + cell_data = point_to_cell_filter.GetOutput() + return cell_data + + +def get_fields_from_cell( + cell_data: "vtk.vtkCellData", variable_names: list[str] +) -> np.ndarray: + """Extract field variables from VTK cell data. + + This function extracts multiple field variables from VTK cell data and + organizes them into a structured NumPy array. Each variable becomes a + column in the output array. + + Args: + cell_data: VTK cell data object containing field variables. + variable_names: List of variable names to extract from the cell data. + + Returns: + NumPy array of shape (n_cells, n_variables) containing the extracted + field data. Variables are ordered according to the input list. + + Raises: + ValueError: If a requested variable name is not found in the cell data. + + """ + extracted_fields = [] + for variable_name in variable_names: + variable_array = cell_data.GetArray(variable_name) + if variable_array is None: + raise ValueError(f"Variable '{variable_name}' not found in cell data") + + num_tuples = variable_array.GetNumberOfTuples() + field_values = [] + for tuple_idx in range(num_tuples): + variable_value = np.array(variable_array.GetTuple(tuple_idx)) + field_values.append(variable_value) + field_values = np.asarray(field_values) + extracted_fields.append(field_values) + + # Transpose to get shape (n_cells, n_variables) + extracted_fields = np.transpose(np.asarray(extracted_fields), (1, 0)) + return extracted_fields + + +def get_fields( + data_attributes: "vtk.vtkDataSetAttributes", variable_names: list[str] +) -> list[np.ndarray]: + """Extract multiple field variables from VTK data attributes. + + This function extracts field variables from VTK data attributes (either + point data or cell data) and returns them as a list of NumPy arrays. + It handles both point and cell data seamlessly. + + Args: + data_attributes: VTK data attributes object (point data or cell data). + variable_names: List of variable names to extract. + + Returns: + List of NumPy arrays, one for each requested variable. Each array + has shape (n_points/n_cells, n_components) where n_components + depends on the variable (1 for scalars, 3 for vectors, etc.). + + Raises: + ValueError: If a requested variable is not found in the data attributes. + + """ + extracted_fields = [] + for variable_name in variable_names: + try: + vtk_array = data_attributes.GetArray(variable_name) + except ValueError as e: + raise ValueError( + f"Failed to get array '{variable_name}' from the data attributes: {e}" + ) + + # Convert VTK array to NumPy array with proper shape + numpy_array = numpy_support.vtk_to_numpy(vtk_array).reshape( + vtk_array.GetNumberOfTuples(), vtk_array.GetNumberOfComponents() + ) + extracted_fields.append(numpy_array) + + return extracted_fields + + +def get_vertices(polydata: "vtk.vtkPolyData") -> np.ndarray: + """Extract vertex coordinates from VTK polydata object. + + This function converts VTK polydata to a NumPy array containing the 3D + coordinates of all vertices in the mesh. + + Args: + polydata: VTK polydata object containing mesh geometry. + + Returns: + NumPy array of shape (n_points, 3) containing [x, y, z] coordinates + for each vertex. + + """ + vtk_points = polydata.GetPoints() + vertices = numpy_support.vtk_to_numpy(vtk_points.GetData()) + return vertices + + +def get_volume_data( + polydata: "vtk.vtkPolyData", variable_names: list[str] +) -> tuple[np.ndarray, list[np.ndarray]]: + """Extract vertices and field data from 3D volumetric mesh. + + This function extracts both geometric information (vertex coordinates) + and field data from a 3D volumetric mesh. It's commonly used for + processing finite element analysis results. + + Args: + polydata: VTK polydata representing a 3D volumetric mesh. + variable_names: List of field variable names to extract. + + Returns: + Tuple containing: + - Vertex coordinates as NumPy array of shape (n_vertices, 3) + - List of field arrays, one per variable + + """ + vertices = get_vertices(polydata) + point_data = polydata.GetPointData() + fields = get_fields(point_data, variable_names) + + return vertices, fields + + +def get_surface_data( + polydata: "vtk.vtkPolyData", variable_names: list[str] +) -> tuple[np.ndarray, list[np.ndarray], list[tuple[int, int]]]: + """Extract surface mesh data including vertices, fields, and edge connectivity. + + This function extracts comprehensive surface mesh information including + vertex coordinates, field data at vertices, and edge connectivity information. + It's commonly used for processing CFD surface results and boundary conditions. + + Args: + polydata: VTK polydata representing a surface mesh. + variable_names: List of field variable names to extract from the mesh. + + Returns: + Tuple containing: + - Vertex coordinates as NumPy array of shape (n_vertices, 3) + - List of field arrays, one per variable + - List of edge tuples representing mesh connectivity + + Raises: + ValueError: If a requested variable is not found or polygon data is missing. + + """ + points = polydata.GetPoints() + vertices = np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]) + + point_data = polydata.GetPointData() + fields = [] + for array_name in variable_names: + try: + array = point_data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = np.zeros( + (points.GetNumberOfPoints(), array.GetNumberOfComponents()) + ) + for j in range(points.GetNumberOfPoints()): + array.GetTuple(j, array_data[j]) + fields.append(array_data) + + polys = polydata.GetPolys() + if polys is None: + raise ValueError("Failed to get polygons from the polydata.") + polys.InitTraversal() + edges = [] + id_list = vtk.vtkIdList() + for _ in range(polys.GetNumberOfCells()): + polys.GetNextCell(id_list) + num_ids = id_list.GetNumberOfIds() + edges = [ + (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) + ] + + return vertices, fields, edges diff --git a/physicsnemo/utils/memory.py b/physicsnemo/utils/memory.py new file mode 100644 index 0000000000..54ceb5061c --- /dev/null +++ b/physicsnemo/utils/memory.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch + +try: + import rmm + + RMM_AVAILABLE = True +except ImportError: + RMM_AVAILABLE = False + +try: + import cupy + + CUPY_AVAILABLE = True +except ImportError: + CUPY_AVAILABLE = False + +""" +Using a unifed gpu memory provider, we consolidate the pool into just a +single allocator for cupy/rapids and torch. Ideally, we add warp to this someday. + +To use this, you need to add the following to your code at or near the top +(before allocating any GPU memory): + +```python +from physicsnemo.utils.memory import unified_gpu_memory +``` + +""" + + +def srt2bool(val: str): + if isinstance(val, bool): + return val + if val.lower() in ["true", "1", "yes", "y"]: + return True + elif val.lower() in ["false", "0", "no", "n"]: + return False + else: + raise ValueError(f"Invalid boolean value: {val}") + + +DISABLE_RMM = srt2bool(os.environ.get("PHYSICSNEMO_DISABLE_RMM", False)) + + +def _setup_unified_gpu_memory(): + # Skip if RMM is disabled + if RMM_AVAILABLE and not DISABLE_RMM: + # First, determine the local rank so that we allocate on the right device. + # These are meant to be tested in the same order as DistributedManager + # We can't actually initialize it, though, since we have to unify mallocs + # before torch init. + PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD = os.environ.get( + "PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD", None + ) + if PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD is None: + for method in ["LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK", "SLURM_LOCALID"]: + if os.environ.get(method) is not None: + local_rank = int(os.environ.get(method)) + break + else: + if PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD == "ENV": + local_rank = int(os.environ.get("LOCAL_RANK")) + elif PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD == "SLURM": + local_rank = int(os.environ.get("SLURM_LOCALID")) + elif PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD == "OPENMPI": + local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")) + else: + raise ValueError( + f"Unknown initialization method: {PHYSICSNEMO_DISTRIBUTED_INITIALIZATION_METHOD}" + ) + + # Initialize RMM + rmm.reinitialize( + pool_allocator=True, devices=local_rank, initial_pool_size="1024MB" + ) + + # Set PyTorch allocator if available + from rmm.allocators.torch import rmm_torch_allocator + + if torch.cuda.is_available(): + torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + # Set CuPy allocator if available + if CUPY_AVAILABLE: + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + +# This is what gets executed when someone does "from memory import unified_gpu_memory" + + +def __getattr__(name): + if name == "unified_gpu_memory": + return _setup_unified_gpu_memory() + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/physicsnemo/utils/neighbors/radius_search/_torch_impl.py b/physicsnemo/utils/neighbors/radius_search/_torch_impl.py index c6df0f9e81..2b4c3394c3 100644 --- a/physicsnemo/utils/neighbors/radius_search/_torch_impl.py +++ b/physicsnemo/utils/neighbors/radius_search/_torch_impl.py @@ -56,8 +56,6 @@ def radius_search_impl( dists = torch.empty((0,), device=dists.device, dtype=dists.dtype) else: - print(f"dists shape: {dists.shape}") - # Take the max_points lowest distances for each query closest_points = torch.topk( dists, k=min(max_points, dists.shape[0]), dim=0, largest=False diff --git a/test/datapipes/test_domino_datapipe.py b/test/datapipes/test_domino_datapipe.py index a2f5ad645a..9f665886bd 100644 --- a/test/datapipes/test_domino_datapipe.py +++ b/test/datapipes/test_domino_datapipe.py @@ -18,7 +18,7 @@ import tempfile from dataclasses import dataclass from pathlib import Path -from typing import List, Literal +from typing import List, Literal, Optional, Sequence import numpy as np import pytest @@ -27,6 +27,13 @@ from pytest_utils import import_or_fail from scipy.spatial import ConvexHull +from physicsnemo.datapipes.cae.cae_dataset import CAEDataset +from physicsnemo.datapipes.cae.domino_datapipe import ( + CachedDoMINODataset, + DoMINODataConfig, + DoMINODataPipe, +) + Tensor = torch.Tensor # DEFINING GLOBAL VARIABLES HERE @@ -91,7 +98,7 @@ def synthetic_domino_data( for i in range(n_examples): # We are generating a mesh on a random sphere. stl_points = random_sample_on_unit_sphere(N_mesh_points) - print(f"stl_points.shape: {stl_points.shape}") + # Generate the triangles with ConvexHull: hull = ConvexHull(stl_points) faces = hull.simplices # (M, 3) @@ -236,9 +243,23 @@ def bounding_boxes(): } -def create_basic_dataset(data_dir, model_type, **kwargs): +def create_basic_dataset( + data_dir, + model_type, + gpu_preprocessing: bool = False, + gpu_output: bool = False, + normalize_coordinates: bool = False, + sample_in_bbox: bool = False, + sampling: bool = False, + volume_points_sample: int = 1234, + surface_points_sample: int = 1234, + surface_sampling_algorithm: str = "random", + caching: bool = False, + scaling_type: Optional[Literal["min_max_scaling", "mean_std_scaling"]] = None, + volume_factors: Optional[Sequence] = None, + surface_factors: Optional[Sequence] = None, +): """Helper function to create a basic DoMINODataPipe with default settings.""" - from physicsnemo.datapipes.cae.domino_datapipe import DoMINODataPipe # assert model_type in ["volume", "surface", "combined"] @@ -246,41 +267,80 @@ def create_basic_dataset(data_dir, model_type, **kwargs): bounding_box = bounding_boxes() + keys_to_read = [ + "stl_coordinates", + "stl_faces", + "stl_centers", + "stl_areas", + ] + + if model_type == "volume" or model_type == "combined": + keys_to_read += [ + "volume_mesh_centers", + "volume_fields", + ] + + if model_type == "surface" or model_type == "combined": + keys_to_read += [ + "surface_mesh_centers", + "surface_areas", + "surface_normals", + "surface_fields", + ] + + keys_to_read_if_available = { + "global_params_values": torch.tensor([1.225, 10.0]), + "global_params_reference": torch.tensor([1.225, 10.0]), + } + + dataset = CAEDataset( + data_dir=input_path, + keys_to_read=keys_to_read, + keys_to_read_if_available=keys_to_read_if_available, + output_device=torch.device("cuda") + if gpu_preprocessing + else torch.device("cpu"), + preload_depth=0, + pin_memory=False, + device_mesh=None, + placements=None, + ) + default_kwargs = { "phase": "test", "grid_resolution": [64, 64, 64], - "volume_points_sample": 1234, - "surface_points_sample": 1234, - "geom_points_sample": 2345, + "volume_points_sample": volume_points_sample, + "surface_points_sample": surface_points_sample, + "geom_points_sample": 500, "num_surface_neighbors": 5, "bounding_box_dims": bounding_box["volume"], "bounding_box_dims_surf": bounding_box["surface"], - "normalize_coordinates": True, - "sampling": False, - "sample_in_bbox": False, - "positional_encoding": False, - "scaling_type": None, - "volume_factors": None, - "surface_factors": None, - "caching": False, - "compute_scaling_factors": False, - "gpu_preprocessing": True, - "gpu_output": True, + "normalize_coordinates": normalize_coordinates, + "sampling": sampling, + "sample_in_bbox": sample_in_bbox, + "scaling_type": scaling_type, + "volume_factors": volume_factors, + "surface_factors": surface_factors, + "caching": caching, + "gpu_preprocessing": gpu_preprocessing, + "gpu_output": gpu_output, + "surface_sampling_algorithm": surface_sampling_algorithm, } - default_kwargs.update(kwargs) - - return DoMINODataPipe( + pipe = DoMINODataPipe( input_path=input_path, model_type=model_type, **default_kwargs ) + pipe.set_dataset(dataset) + return pipe + def validate_sample_structure(sample, model_type, gpu_output): """Helper function to validate the structure of a dataset sample.""" assert isinstance(sample, dict) # Common keys that should always be present - expected_keys = ["geometry_coordinates", "length_scale", "surface_min_max"] + expected_keys = ["geometry_coordinates"] # Model-specific keys volume_keys = [ @@ -303,6 +363,7 @@ def validate_sample_structure(sample, model_type, gpu_output): expected_keys.extend(surface_keys) # Check that required keys are present and are torch tensors on correct device + for key in expected_keys: if key in sample: # Some keys may be None if compute_scaling_factors=True if sample[key] is not None: @@ -327,9 +388,14 @@ def test_domino_datapipe_core( """Core test for basic functionality with different device and model configurations.""" data_dir = request.getfixturevalue(data_dir) - print(f"data_dir: {data_dir}") dataset = create_basic_dataset( - data_dir, model_type, gpu_preprocessing=gpu_preprocessing, gpu_output=gpu_output + data_dir, + model_type, + gpu_preprocessing=gpu_preprocessing, + gpu_output=gpu_output, + normalize_coordinates=False, + sample_in_bbox=False, + sampling=False, ) assert len(dataset) > 0 @@ -350,23 +416,41 @@ def test_domino_datapipe_coordinate_normalization( zarr_dataset, model_type, gpu_preprocessing=True, + gpu_output=True, normalize_coordinates=normalize_coordinates, sample_in_bbox=sample_in_bbox, + sampling=False, ) sample = dataset[0] validate_sample_structure(sample, model_type, gpu_output=True) - v_coords = sample["volume_mesh_centers"] - s_coords = sample["surface_mesh_centers"] + # Check all the volume coordinates: + for volume_key in ["volume_mesh_centers"]: + coords = sample[volume_key] + check_tensor_normalization( + coords, normalize_coordinates, sample_in_bbox, is_surface=False + ) - v_min = torch.min(v_coords, dim=0).values - v_max = torch.max(v_coords, dim=0).values - s_min = torch.min(s_coords, dim=0).values - s_max = torch.max(s_coords, dim=0).values + # Check all the surface coordinates: + for surface_key in ["surface_mesh_centers", "surface_mesh_neighbors"]: + coords = sample[surface_key] + if surface_key == "surface_mesh_neighbors": + coords = coords.reshape((1, -1, 3)) + check_tensor_normalization( + coords, normalize_coordinates, sample_in_bbox, is_surface=True + ) + + +def check_tensor_normalization( + tensor, normalize_coordinates, sample_in_bbox, is_surface +): + """Check if a tensor is normalized properly.""" + + # Batch size is 1 here, but in principle this could be a loop: + t_min = torch.min(tensor[0], dim=0).values + t_max = torch.max(tensor[0], dim=0).values - print(f"{normalize_coordinates} v_coords: {v_min} to {v_max}") - print(f"{normalize_coordinates} s_coords: {s_min} to {s_max}") # If normalization is enabled, coordinates should be in [-2, 2] range if normalize_coordinates: if sample_in_bbox: @@ -374,12 +458,12 @@ def test_domino_datapipe_coordinate_normalization( # that were already inside the box should be present. # That means that all values should be between -1 and 1 - assert v_min[0] >= -1 - assert v_min[1] >= -1 - assert v_min[2] >= -1 - assert v_max[0] <= 1 - assert v_max[1] <= 1 - assert v_max[2] <= 1 + assert t_min[0] >= -1 + assert t_min[1] >= -1 + assert t_min[2] >= -1 + assert t_max[0] <= 1 + assert t_max[1] <= 1 + assert t_max[2] <= 1 else: # When normalizing the coordinates, the values of the bbox @@ -394,56 +478,248 @@ def test_domino_datapipe_coordinate_normalization( # So, field_range = (2 - -1) = 3 # new_val = 2 * (5 - -1)/ 3 - 1 = 3 - vol_x_rescale = 1 / (VOL_BBOX_XMAX - VOL_BBOX_XMIN) - vol_y_rescale = 1 / (VOL_BBOX_YMAX - VOL_BBOX_YMIN) - vol_z_rescale = 1 / (VOL_BBOX_ZMAX - VOL_BBOX_ZMIN) - - assert v_min[0] >= 2 * (DATA_XMIN - VOL_BBOX_XMIN) * vol_x_rescale - 1 - assert v_min[1] >= 2 * (DATA_YMIN - VOL_BBOX_YMIN) * vol_y_rescale - 1 - assert v_min[2] >= 2 * (DATA_ZMIN - VOL_BBOX_ZMIN) * vol_z_rescale - 1 - assert v_max[0] <= 2 * (DATA_XMAX - VOL_BBOX_XMIN) * vol_x_rescale - 1 - assert v_max[1] <= 2 * (DATA_YMAX - VOL_BBOX_YMIN) * vol_y_rescale - 1 - assert v_max[2] <= 2 * (DATA_ZMAX - VOL_BBOX_ZMIN) * vol_z_rescale - 1 - - surf_x_rescale = 1 / (SURF_BBOX_XMAX - SURF_BBOX_XMIN) - surf_y_rescale = 1 / (SURF_BBOX_YMAX - SURF_BBOX_YMIN) - surf_z_rescale = 1 / (SURF_BBOX_ZMAX - SURF_BBOX_ZMIN) - - assert s_min[0] >= 2 * (DATA_XMIN - SURF_BBOX_XMIN) * surf_x_rescale - 1 - assert s_min[1] >= 2 * (DATA_YMIN - SURF_BBOX_YMIN) * surf_y_rescale - 1 - assert s_min[2] >= 2 * (DATA_ZMIN - SURF_BBOX_ZMIN) * surf_z_rescale - 1 - assert s_max[0] <= 2 * (DATA_XMAX - SURF_BBOX_XMIN) * surf_x_rescale - 1 - assert s_max[1] <= 2 * (DATA_YMAX - SURF_BBOX_YMIN) * surf_y_rescale - 1 - assert s_max[2] <= 2 * (DATA_ZMAX - SURF_BBOX_ZMIN) * surf_z_rescale - 1 + if is_surface: + x_rescale = 1 / (SURF_BBOX_XMAX - SURF_BBOX_XMIN) + y_rescale = 1 / (SURF_BBOX_YMAX - SURF_BBOX_YMIN) + z_rescale = 1 / (SURF_BBOX_ZMAX - SURF_BBOX_ZMIN) + target_min_x = 2 * (DATA_XMIN - SURF_BBOX_XMIN) * x_rescale - 1 + target_min_y = 2 * (DATA_YMIN - SURF_BBOX_YMIN) * y_rescale - 1 + target_min_z = 2 * (DATA_ZMIN - SURF_BBOX_ZMIN) * z_rescale - 1 + target_max_x = 2 * (DATA_XMAX - SURF_BBOX_XMIN) * x_rescale - 1 + target_max_y = 2 * (DATA_YMAX - SURF_BBOX_YMIN) * y_rescale - 1 + target_max_z = 2 * (DATA_ZMAX - SURF_BBOX_ZMIN) * z_rescale - 1 + else: + x_rescale = 1 / (VOL_BBOX_XMAX - VOL_BBOX_XMIN) + y_rescale = 1 / (VOL_BBOX_YMAX - VOL_BBOX_YMIN) + z_rescale = 1 / (VOL_BBOX_ZMAX - VOL_BBOX_ZMIN) + target_min_x = 2 * (DATA_XMIN - VOL_BBOX_XMIN) * x_rescale - 1 + target_min_y = 2 * (DATA_YMIN - VOL_BBOX_YMIN) * y_rescale - 1 + target_min_z = 2 * (DATA_ZMIN - VOL_BBOX_ZMIN) * z_rescale - 1 + target_max_x = 2 * (DATA_XMAX - VOL_BBOX_XMIN) * x_rescale - 1 + target_max_y = 2 * (DATA_YMAX - VOL_BBOX_YMIN) * y_rescale - 1 + target_max_z = 2 * (DATA_ZMAX - VOL_BBOX_ZMIN) * z_rescale - 1 + + assert t_min[0] >= target_min_x + assert t_min[1] >= target_min_y + assert t_min[2] >= target_min_z + assert t_max[0] <= target_max_x + assert t_max[1] <= target_max_y + assert t_max[2] <= target_max_z else: if sample_in_bbox: # We've sampled in the bbox but NOT normalized. # So, the values should exclusively be in the BBOX ranges: - assert v_min[0] >= VOL_BBOX_XMIN - assert v_min[1] >= VOL_BBOX_YMIN - assert v_min[2] >= VOL_BBOX_ZMIN - assert v_max[0] <= VOL_BBOX_XMAX - assert v_max[1] <= VOL_BBOX_YMAX - assert v_max[2] <= VOL_BBOX_ZMAX - - assert s_min[0] >= SURF_BBOX_XMIN - assert s_min[1] >= SURF_BBOX_YMIN - assert s_min[2] >= SURF_BBOX_ZMIN - assert s_max[0] <= SURF_BBOX_XMAX - assert s_max[1] <= SURF_BBOX_YMAX - assert s_max[2] <= SURF_BBOX_ZMAX + + if is_surface: + assert t_min[0] >= SURF_BBOX_XMIN + assert t_min[1] >= SURF_BBOX_YMIN + assert t_min[2] >= SURF_BBOX_ZMIN + assert t_max[0] <= SURF_BBOX_XMAX + assert t_max[1] <= SURF_BBOX_YMAX + assert t_max[2] <= SURF_BBOX_ZMAX + else: + assert t_min[0] >= VOL_BBOX_XMIN + assert t_min[1] >= VOL_BBOX_YMIN + assert t_min[2] >= VOL_BBOX_ZMIN + assert t_max[0] <= VOL_BBOX_XMAX + assert t_max[1] <= VOL_BBOX_YMAX + assert t_max[2] <= VOL_BBOX_ZMAX else: # Not sampling, and also # Not normalizing, values should be in data range only: - assert v_min[0] >= DATA_XMIN and v_max[0] <= DATA_XMAX - assert v_min[1] >= DATA_YMIN and v_max[1] <= DATA_YMAX - assert v_min[2] >= DATA_ZMIN and v_max[2] <= DATA_ZMAX - assert s_min[0] >= DATA_XMIN and s_max[0] <= DATA_XMAX - assert s_min[1] >= DATA_YMIN and s_max[1] <= DATA_YMAX - # Surface points always should be > 0 - assert s_min[2] >= 0 and s_max[2] <= DATA_ZMAX + assert t_min[0] >= DATA_XMIN and t_max[0] <= DATA_XMAX + assert t_min[1] >= DATA_YMIN and t_max[1] <= DATA_YMAX + + if is_surface: + # Surface points always should be > 0 + assert t_min[2] >= 0 and t_max[2] <= DATA_ZMAX + else: + assert t_min[2] >= DATA_ZMIN and t_max[2] <= DATA_ZMAX + + return True + + +@pytest.mark.parametrize("model_type", ["surface"]) +@pytest.mark.parametrize("normalize_coordinates", [True, False]) +@pytest.mark.parametrize("sample_in_bbox", [True, False]) +def test_domino_datapipe_surface_normalization( + zarr_dataset, pytestconfig, model_type, normalize_coordinates, sample_in_bbox +): + """Test normalization functionality. + + This test is meant to make sure all the peripheral outputs are + normalized properly. FOcus on surface here. + + We could do them all in one test but it gets unweildy, and if there + are failures it helps nail down exactly where. + """ + cuda = torch.cuda.is_available() + + dataset = create_basic_dataset( + zarr_dataset, + model_type, + gpu_preprocessing=cuda, + gpu_output=cuda, + normalize_coordinates=normalize_coordinates, + sampling=True, + sample_in_bbox=sample_in_bbox, + ) + + # Here's a list of values to check, and the behavior we expect: + + # surf_grid - normalized by s_min, s_max + sample = dataset[0] + surf_grid = sample["surf_grid"] + + # If normalizing, surf_grid should be between -1 and 1. + # Otherwise, should be between s_min and s_max + if not normalize_coordinates: + target_min = torch.tensor([SURF_BBOX_XMIN, SURF_BBOX_YMIN, SURF_BBOX_ZMIN]) + target_max = torch.tensor([SURF_BBOX_XMAX, SURF_BBOX_YMAX, SURF_BBOX_ZMAX]) + else: + target_min = torch.tensor([-1.0, -1.0, -1.0]) + target_max = torch.tensor([1.0, 1.0, 1.0]) + + target_min = target_min.to(surf_grid.device) + target_max = target_max.to(surf_grid.device) + + # Flatten all the grid coords: + surf_grid = surf_grid.reshape((-1, 3)) + + assert torch.all(surf_grid >= target_min) + assert torch.all(surf_grid <= target_max) + + # sdf_surf_grid - should have max values less than || s_max - s_min|| + + max_norm_allowed = torch.norm(target_max - target_min) + + sdf_surf_grid = sample["sdf_surf_grid"] + assert torch.all(sdf_surf_grid <= max_norm_allowed) + # (Negative values are ok but we don't really check that.) + + # surface_min_max should only be in the dict if normaliztion is on: + if normalize_coordinates: + assert "surface_min_max" in sample + s_mm = sample["surface_min_max"] + assert s_mm.shape == (1, 2, 3) + + assert torch.allclose( + s_mm[0, 0], + torch.tensor([SURF_BBOX_XMIN, SURF_BBOX_YMIN, SURF_BBOX_ZMIN]).to( + s_mm.device + ), + ) + assert torch.allclose( + s_mm[0, 1], + torch.tensor([SURF_BBOX_XMAX, SURF_BBOX_YMAX, SURF_BBOX_ZMAX]).to( + s_mm.device + ), + ) + + else: + assert "surface_min_max" not in sample + + # For the rest of the values, checks are straightforward: + + assert torch.all(sample["surface_areas"] > 0) + assert torch.all(sample["surface_neighbors_areas"] > 0) + + # No checks implemented on the following, yet: + # - pos_surface_center_of_mass + + +@pytest.mark.parametrize("model_type", ["volume"]) +@pytest.mark.parametrize("normalize_coordinates", [True, False]) +@pytest.mark.parametrize("sample_in_bbox", [True, False]) +def test_domino_datapipe_volume_normalization( + zarr_dataset, pytestconfig, model_type, normalize_coordinates, sample_in_bbox +): + """Test normalization functionality. + + This test is meant to make sure all the peripheral outputs are + normalized properly. FOcus on volume here. + + We could do them all in one test but it gets unweildy, and if there + are failures it helps nail down exactly where. + """ + cuda = torch.cuda.is_available() + + dataset = create_basic_dataset( + zarr_dataset, + model_type, + gpu_preprocessing=cuda, + gpu_output=cuda, + normalize_coordinates=normalize_coordinates, + sampling=True, + sample_in_bbox=sample_in_bbox, + ) + + # Here's a list of values to check, and the behavior we expect: + + # grid - normalized by s_min, s_max + sample = dataset[0] + grid = sample["grid"] + + # If normalizing, surf_grid should be between -1 and 1. + # Otherwise, should be between s_min and s_max + if not normalize_coordinates: + target_min = torch.tensor([VOL_BBOX_XMIN, VOL_BBOX_YMIN, VOL_BBOX_ZMIN]) + target_max = torch.tensor([VOL_BBOX_XMAX, VOL_BBOX_YMAX, VOL_BBOX_ZMAX]) + else: + target_min = torch.tensor([-1.0, -1.0, -1.0]) + target_max = torch.tensor([1.0, 1.0, 1.0]) + + target_min = target_min.to(grid.device) + target_max = target_max.to(grid.device) + + # Flatten all the grid coords: + grid = grid.reshape((-1, 3)) + + assert torch.all(grid >= target_min) + assert torch.all(grid <= target_max) + + # sdf_grid - should have max values less than || s_max - s_min|| + + max_norm_allowed = torch.norm(target_max - target_min) + + sdf_grid = sample["sdf_grid"] + assert torch.all(sdf_grid <= max_norm_allowed) + # (Negative values are ok but we don't really check that.) + + # surface_min_max should only be in the dict if normaliztion is on: + if normalize_coordinates: + assert "volume_min_max" in sample + s_mm = sample["volume_min_max"] + assert s_mm.shape == (1, 2, 3) + + assert torch.allclose( + s_mm[0, 0], + torch.tensor([VOL_BBOX_XMIN, VOL_BBOX_YMIN, VOL_BBOX_ZMIN]).to(s_mm.device), + ) + assert torch.allclose( + s_mm[0, 1], + torch.tensor([VOL_BBOX_XMAX, VOL_BBOX_YMAX, VOL_BBOX_ZMAX]).to(s_mm.device), + ) + + else: + assert "volume_min_max" not in sample + + sdf_nodes = sample["sdf_nodes"] + pos_volume_closest_norm = torch.norm(sample["pos_volume_closest"], dim=-1).reshape( + sdf_nodes.shape + ) + assert torch.allclose(pos_volume_closest_norm, sdf_nodes) + # No checks implemented on the following, yet: + # - pos_volume_center_of_mass + + # The center of mass should be inside the mesh. So, the displacement + # from the center of mass should be exclusively larger than the sdf: + pos_volume_center_of_mass_norm = torch.norm( + sample["pos_volume_center_of_mass"], dim=-1 + ).reshape(sdf_nodes.shape) + assert torch.all(pos_volume_center_of_mass_norm > sdf_nodes) @import_or_fail(["warp", "cupy", "cuml"]) @@ -452,24 +728,30 @@ def test_domino_datapipe_coordinate_normalization( def test_domino_datapipe_sampling(zarr_dataset, model_type, sampling, pytestconfig): """Test point sampling functionality.""" sample_points = 4321 + + use_cuda = torch.cuda.is_available() + dataset = create_basic_dataset( zarr_dataset, model_type, - gpu_preprocessing=False, + gpu_preprocessing=use_cuda, + gpu_output=use_cuda, + normalize_coordinates=False, + sample_in_bbox=False, sampling=sampling, volume_points_sample=sample_points, surface_points_sample=sample_points, ) sample = dataset[0] - validate_sample_structure(sample, model_type, gpu_output=True) + validate_sample_structure(sample, model_type, gpu_output=use_cuda) if model_type in ["volume", "combined"]: for key in ["volume_mesh_centers", "volume_fields"]: if sampling: - assert sample[key].shape[0] == sample_points + assert sample[key].shape[1] == sample_points else: - assert sample[key].shape[0] == sample["volume_mesh_centers"].shape[0] + assert sample[key].shape[1] == sample["volume_mesh_centers"].shape[1] # Model-specific keys if model_type in ["surface", "combined"]: @@ -480,74 +762,60 @@ def test_domino_datapipe_sampling(zarr_dataset, model_type, sampling, pytestconf "surface_fields", ]: if sampling: - assert sample[key].shape[0] == sample_points + assert sample[key].shape[1] == sample_points else: - assert sample[key].shape[0] == sample["surface_mesh_centers"].shape[0] + assert sample[key].shape[1] == sample["surface_mesh_centers"].shape[1] for key in [ "surface_mesh_neighbors", "surface_neighbors_normals", "surface_neighbors_areas", ]: if sampling: - assert sample[key].shape[0] == sample_points - assert sample[key].shape[1] == dataset.config.num_surface_neighbors - 1 + assert sample[key].shape[1] == sample_points + assert sample[key].shape[2] == dataset.config.num_surface_neighbors - 1 else: - assert sample[key].shape[0] == sample["surface_mesh_neighbors"].shape[0] - assert sample[key].shape[1] == dataset.config.num_surface_neighbors - 1 - - -@import_or_fail(["warp", "cupy", "cuml"]) -@pytest.mark.parametrize("model_type", ["combined"]) -@pytest.mark.parametrize( - "positional_encoding", - [ - True, - ], -) -def test_domino_datapipe_positional_encoding( - zarr_dataset, model_type, positional_encoding, pytestconfig -): - """Test positional encoding functionality.""" - dataset = create_basic_dataset( - zarr_dataset, - model_type, - gpu_preprocessing=False, - positional_encoding=positional_encoding, - ) - - sample = dataset[0] - validate_sample_structure(sample, model_type, gpu_output=True) - - # Check for positional encoding keys - if positional_encoding: - pos_keys = ["pos_volume_closest", "pos_volume_center_of_mass"] - for key in pos_keys: - if key in sample: - assert sample[key] is not None + assert sample[key].shape[1] == sample["surface_mesh_neighbors"].shape[1] + assert sample[key].shape[2] == dataset.config.num_surface_neighbors - 1 @import_or_fail(["warp", "cupy", "cuml"]) -@pytest.mark.parametrize("model_type", ["volume"]) +@pytest.mark.parametrize("model_type", ["volume", "surface", "combined"]) @pytest.mark.parametrize("scaling_type", [None, "min_max_scaling", "mean_std_scaling"]) def test_domino_datapipe_scaling(zarr_dataset, model_type, scaling_type, pytestconfig): """Test field scaling functionality.""" - if scaling_type == "min_max_scaling": - volume_factors = [10.0, -10.0] # [max, min] - elif scaling_type == "mean_std_scaling": - volume_factors = [0.0, 1.0] # [mean, std] + use_cuda = torch.cuda.is_available() + + if model_type in ["volume", "combined"]: + volume_factors = torch.tensor( + [ + [10.0, -10.0, 10.0, 10.0, 10.0], + [10.0, -10.0, 10.0, 10.0, 10.0], + ] + ) else: volume_factors = None + if model_type in ["surface", "combined"]: + surface_factors = torch.tensor( + [ + [10.0, -10.0, 10.0, 10.0], + [10.0, -10.0, 10.0, 10.0], + ] + ) + else: + surface_factors = None dataset = create_basic_dataset( zarr_dataset, model_type, - gpu_preprocessing=False, + gpu_preprocessing=use_cuda, + gpu_output=use_cuda, scaling_type=scaling_type, volume_factors=volume_factors, + surface_factors=surface_factors, ) sample = dataset[0] - validate_sample_structure(sample, model_type, gpu_output=True) + validate_sample_structure(sample, model_type, gpu_output=use_cuda) # Caching tests @@ -555,24 +823,23 @@ def test_domino_datapipe_scaling(zarr_dataset, model_type, scaling_type, pytestc @pytest.mark.parametrize("model_type", ["volume"]) def test_domino_datapipe_caching_config(zarr_dataset, model_type, pytestconfig): """Test DoMINODataPipe with caching=True configuration.""" + use_cuda = torch.cuda.is_available() dataset = create_basic_dataset( zarr_dataset, model_type, - gpu_preprocessing=False, + gpu_preprocessing=use_cuda, + gpu_output=use_cuda, caching=True, sampling=False, # Required for caching - compute_scaling_factors=False, # Required for caching - resample_surfaces=False, # Required for caching ) sample = dataset[0] - validate_sample_structure(sample, model_type, gpu_output=True) + validate_sample_structure(sample, model_type, gpu_output=use_cuda) @import_or_fail(["warp", "cupy", "cuml"]) def test_cached_domino_dataset(zarr_dataset, tmp_path, pytestconfig): """Test CachedDoMINODataset functionality.""" - from physicsnemo.datapipes.cae.domino_datapipe import CachedDoMINODataset # Create some mock cached data files for i in range(3): @@ -613,31 +880,22 @@ def test_cached_domino_dataset(zarr_dataset, tmp_path, pytestconfig): def test_domino_datapipe_invalid_caching_config(zarr_dataset, pytestconfig): """Test that invalid caching configurations raise appropriate errors.""" + use_cuda = torch.cuda.is_available() # Test: caching=True with sampling=True should fail with pytest.raises(ValueError, match="Sampling should be False for caching"): - create_basic_dataset(zarr_dataset, "volume", caching=True, sampling=True) - - # Test: caching=True with compute_scaling_factors=True should fail - with pytest.raises( - ValueError, match="Compute scaling factors should be False for caching" - ): - create_basic_dataset( - zarr_dataset, "volume", caching=True, compute_scaling_factors=True - ) - - # Test: caching=True with resample_surfaces=True should fail - with pytest.raises( - ValueError, match="Resample surface should be False for caching" - ): create_basic_dataset( - zarr_dataset, "volume", caching=True, resample_surfaces=True + zarr_dataset, + "volume", + caching=True, + sampling=True, + gpu_preprocessing=use_cuda, + gpu_output=use_cuda, ) @import_or_fail(["warp", "cupy", "cuml"]) def test_domino_datapipe_invalid_phase(pytestconfig): """Test that invalid phase values raise appropriate errors.""" - from physicsnemo.datapipes.cae.domino_datapipe import DoMINODataConfig with pytest.raises(ValueError, match="phase should be one of"): DoMINODataConfig(data_path=tempfile.mkdtemp(), phase="invalid_phase") @@ -646,7 +904,6 @@ def test_domino_datapipe_invalid_phase(pytestconfig): @import_or_fail(["warp", "cupy", "cuml"]) def test_domino_datapipe_invalid_scaling_type(pytestconfig): """Test that invalid scaling_type values raise appropriate errors.""" - from physicsnemo.datapipes.cae.domino_datapipe import DoMINODataConfig with pytest.raises(ValueError, match="scaling_type should be one of"): DoMINODataConfig( @@ -659,12 +916,15 @@ def test_domino_datapipe_file_format_support(zarr_dataset, pytestconfig): """Test support for different file formats (.zarr, .npz, .npy).""" # This test assumes the data directory has files in these formats # If not available, we can mock the file reading - dataset = create_basic_dataset(zarr_dataset, "volume", gpu_preprocessing=False) + use_cuda = torch.cuda.is_available() + dataset = create_basic_dataset( + zarr_dataset, "volume", gpu_preprocessing=use_cuda, gpu_output=use_cuda + ) # Just verify we can load at least one sample assert len(dataset) > 0 sample = dataset[0] - validate_sample_structure(sample, "volume", gpu_output=True) + validate_sample_structure(sample, "volume", gpu_output=use_cuda) # Surface-specific tests (when GPU preprocessing issues are resolved) @@ -674,20 +934,17 @@ def test_domino_datapipe_surface_sampling( zarr_dataset, surface_sampling_algorithm, pytestconfig ): """Test surface sampling algorithms.""" + + gpu = torch.cuda.is_available() + dataset = create_basic_dataset( zarr_dataset, "surface", - gpu_preprocessing=False, # Avoid known GPU issues + gpu_preprocessing=gpu, + gpu_output=gpu, sampling=True, surface_sampling_algorithm=surface_sampling_algorithm, ) sample = dataset[0] validate_sample_structure(sample, "surface", gpu_output=True) - - -if __name__ == "__main__": - out_dir = synthetic_domino_data( - out_format="zarr", - ) - print(out_dir) diff --git a/test/distributed/shard_tensor/ops/test_radius_search.py b/test/distributed/shard_tensor/ops/test_radius_search.py index 0ebaf05536..7c18cd0190 100644 --- a/test/distributed/shard_tensor/ops/test_radius_search.py +++ b/test/distributed/shard_tensor/ops/test_radius_search.py @@ -31,7 +31,6 @@ import torch from physicsnemo.distributed import DistributedManager -from physicsnemo.models.domino.model import BQWarp from physicsnemo.utils.version_check import check_module_requirements try: @@ -138,6 +137,8 @@ def run_radius_search_module(model, data_dict, reverse_mapping): def test_sharded_radius_search_layer_forward( distributed_mesh, shard_points, shard_grid, reverse_mapping ): + from physicsnemo.models.layers.ball_query import BQWarp + dm = DistributedManager() device = dm.device diff --git a/test/models/data/domino_output-conv.pth b/test/models/data/domino_output-conv.pth new file mode 100644 index 0000000000..0a3b7102a4 Binary files /dev/null and b/test/models/data/domino_output-conv.pth differ diff --git a/test/models/data/domino_output-unet.pth b/test/models/data/domino_output-unet.pth new file mode 100644 index 0000000000..9ba6b36de7 Binary files /dev/null and b/test/models/data/domino_output-unet.pth differ diff --git a/test/models/data/domino_output.pth b/test/models/data/domino_output.pth deleted file mode 100644 index 432b105c9d..0000000000 Binary files a/test/models/data/domino_output.pth and /dev/null differ diff --git a/test/models/data/mlp_output.pth b/test/models/data/mlp_output.pth new file mode 100644 index 0000000000..cc2f0ea9de Binary files /dev/null and b/test/models/data/mlp_output.pth differ diff --git a/test/models/domino/__init__.py b/test/models/domino/__init__.py new file mode 100644 index 0000000000..b2f171d4ac --- /dev/null +++ b/test/models/domino/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/test/models/domino/conftest.py b/test/models/domino/conftest.py new file mode 100644 index 0000000000..9d8c8a71d5 --- /dev/null +++ b/test/models/domino/conftest.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Sequence + +import pytest + + +@pytest.fixture(scope="module") +def base_model_params(): + """Base model parameters for testing""" + + @dataclass + class model_params: + @dataclass + class geometry_rep: + @dataclass + class geo_conv: + base_neurons: int = 32 + base_neurons_in: int = 8 + base_neurons_out: int = 8 + surface_hops: int = 1 + volume_hops: int = 1 + volume_radii: Sequence = (0.1, 0.5) + volume_neighbors_in_radius: Sequence = (10, 10) + surface_radii: Sequence = (0.05,) + surface_neighbors_in_radius: Sequence = (10,) + activation: str = "relu" + fourier_features: bool = False + num_modes: int = 5 + + @dataclass + class geo_processor: + base_filters: int = 8 + activation: str = "relu" + processor_type: str = "unet" + self_attention: bool = True + cross_attention: bool = False + + base_filters: int = 8 + geo_conv = geo_conv + geo_processor = geo_processor + + @dataclass + class geometry_local: + base_layer: int = 512 + volume_neighbors_in_radius: Sequence = (128, 128) + surface_neighbors_in_radius: Sequence = (128,) + volume_radii: Sequence = (0.05, 0.1) + surface_radii: Sequence = (0.05,) + + @dataclass + class nn_basis_functions: + base_layer: int = 512 + fourier_features: bool = False + num_modes: int = 5 + activation: str = "relu" + + @dataclass + class local_point_conv: + activation: str = "relu" + + @dataclass + class aggregation_model: + base_layer: int = 512 + activation: str = "relu" + + @dataclass + class position_encoder: + base_neurons: int = 512 + activation: str = "relu" + fourier_features: bool = False + num_modes: int = 5 + + @dataclass + class parameter_model: + base_layer: int = 512 + fourier_features: bool = True + num_modes: int = 5 + activation: str = "relu" + + model_type: str = "combined" + activation: str = "relu" + interp_res: Sequence = (64, 64, 64) # Smaller for testing + use_sdf_in_basis_func: bool = True + positional_encoding: bool = False + surface_neighbors: bool = True + num_neighbors_surface: int = 7 + num_neighbors_volume: int = 7 + use_surface_normals: bool = True + use_surface_area: bool = True + encode_parameters: bool = False + combine_volume_surface: bool = False + geometry_encoding_type: str = "both" + solution_calculation_mode: str = "two-loop" + geometry_rep = geometry_rep + nn_basis_functions = nn_basis_functions + aggregation_model = aggregation_model + position_encoder = position_encoder + geometry_local = geometry_local + + return model_params diff --git a/test/models/test_domino.py b/test/models/domino/test_domino.py similarity index 59% rename from test/models/test_domino.py rename to test/models/domino/test_domino.py index 87110491d0..7e0643b92f 100644 --- a/test/models/test_domino.py +++ b/test/models/domino/test_domino.py @@ -22,9 +22,8 @@ import torch from pytest_utils import import_or_fail -# from . import common -from .common.fwdaccuracy import save_output -from .common.utils import compare_output +from ..common.fwdaccuracy import save_output +from ..common.utils import compare_output def validate_domino( @@ -44,7 +43,7 @@ def validate_domino( if file_name is None: file_name = model.meta.name + "_output.pth" file_name = ( - Path(__file__).parents[0].resolve() / Path("data") / Path(file_name.lower()) + Path(__file__).parents[1].resolve() / Path("data") / Path(file_name.lower()) ) # If file does not exist, we will create it then error # Model should then reproduce it on next pytest run @@ -60,110 +59,118 @@ def validate_domino( return compare_output(output, output_target, rtol, atol) -@import_or_fail("warp") -@pytest.mark.parametrize("device", ["cuda:0"]) -def test_domino_forward(device, pytestconfig): - """Test domino forward pass""" +@dataclass +class model_params: + @dataclass + class geometry_rep: + @dataclass + class geo_conv: + base_neurons: int = 32 + base_neurons_in: int = 1 + base_neurons_out: int = 1 + surface_hops: int = 1 + volume_hops: int = 1 + volume_radii: Sequence = (0.1, 0.5, 1.0, 2.5) + volume_neighbors_in_radius: Sequence = (32, 64, 128, 256) + surface_radii: Sequence = (0.01, 0.05, 1.0) + surface_neighbors_in_radius: Sequence = (8, 16, 128) + activation: str = "gelu" + fourier_features: bool = False + num_modes: int = 5 - from physicsnemo.models.domino.model import DoMINO + @dataclass + class geo_processor: + base_filters: int = 8 + activation: str = "gelu" + processor_type: str = "unet" + self_attention: bool = False + cross_attention: bool = False + volume_sdf_scaling_factor: Sequence = (0.04,) + surface_sdf_scaling_factor: Sequence = (0.01, 0.02, 0.04) - torch.manual_seed(0) + base_filters: int = 8 + geo_conv = geo_conv + geo_processor = geo_processor @dataclass - class model_params: - @dataclass - class geometry_rep: - @dataclass - class geo_conv: - base_neurons: int = 32 - base_neurons_in: int = 8 - base_neurons_out: int = 8 - surface_hops: int = 1 - volume_hops: int = 1 - volume_radii: Sequence = (0.1, 0.5) - volume_neighbors_in_radius: Sequence = (10, 10) - surface_radii: Sequence = (0.05,) - surface_neighbors_in_radius: Sequence = (10,) - activation: str = "relu" - fourier_features: bool = False - num_modes: int = 5 - - @dataclass - class geo_processor: - base_filters: int = 8 - activation: str = "relu" - processor_type: str = "unet" - self_attention: bool = True - cross_attention: bool = False + class geometry_local: + base_layer: int = 512 + volume_neighbors_in_radius: Sequence = (64, 128) + surface_neighbors_in_radius: Sequence = (32, 128) + volume_radii: Sequence = (0.1, 0.25) + surface_radii: Sequence = (0.05, 0.25) - base_filters: int = 8 - geo_conv = geo_conv - geo_processor = geo_processor + @dataclass + class nn_basis_functions: + base_layer: int = 512 + fourier_features: bool = True + num_modes: int = 5 + activation: str = "gelu" - @dataclass - class geometry_local: - base_layer: int = 512 - volume_neighbors_in_radius: Sequence = (128, 128) - surface_neighbors_in_radius: Sequence = (128,) - volume_radii: Sequence = (0.05, 0.1) - surface_radii: Sequence = (0.05,) + @dataclass + class local_point_conv: + activation: str = "gelu" - @dataclass - class nn_basis_functions: - base_layer: int = 512 - fourier_features: bool = False - num_modes: int = 5 - activation: str = "relu" + @dataclass + class aggregation_model: + base_layer: int = 512 + activation: str = "gelu" - @dataclass - class local_point_conv: - activation: str = "relu" + @dataclass + class position_encoder: + base_neurons: int = 512 + activation: str = "gelu" + fourier_features: bool = True + num_modes: int = 5 - @dataclass - class aggregation_model: - base_layer: int = 512 - activation: str = "relu" + @dataclass + class parameter_model: + base_layer: int = 512 + fourier_features: bool = False + num_modes: int = 5 + activation: str = "gelu" + + model_type: str = "combined" + activation: str = "gelu" + interp_res: Sequence = (128, 64, 64) + use_sdf_in_basis_func: bool = True + positional_encoding: bool = False + surface_neighbors: bool = True + num_neighbors_surface: int = 7 + num_neighbors_volume: int = 10 + use_surface_normals: bool = True + use_surface_area: bool = True + encode_parameters: bool = False + combine_volume_surface: bool = False + geometry_encoding_type: str = "both" + solution_calculation_mode: str = "two-loop" + geometry_rep = geometry_rep + nn_basis_functions = nn_basis_functions + aggregation_model = aggregation_model + position_encoder = position_encoder + geometry_local = geometry_local - @dataclass - class position_encoder: - base_neurons: int = 512 - activation: str = "relu" - fourier_features: bool = False - num_modes: int = 5 - @dataclass - class parameter_model: - base_layer: int = 512 - fourier_features: bool = True - num_modes: int = 5 - activation: str = "relu" - - model_type: str = "combined" - activation: str = "relu" - interp_res: Sequence = (128, 128, 128) - use_sdf_in_basis_func: bool = True - positional_encoding: bool = False - surface_neighbors: bool = True - num_neighbors_surface: int = 7 - num_neighbors_volume: int = 7 - use_surface_normals: bool = True - use_surface_area: bool = True - encode_parameters: bool = False - combine_volume_surface: bool = False - geometry_encoding_type: str = "both" - solution_calculation_mode: str = "two-loop" - geometry_rep = geometry_rep - nn_basis_functions = nn_basis_functions - aggregation_model = aggregation_model - position_encoder = position_encoder - geometry_local = geometry_local +@import_or_fail("warp") +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("processor_type", ["unet", "conv"]) +def test_domino_forward(device, processor_type, pytestconfig): + """Test domino forward pass""" + + from physicsnemo.models.domino.model import DoMINO + + torch.manual_seed(0) + + params = model_params() + + params.geometry_rep.geo_processor.processor_type = processor_type model = DoMINO( input_features=3, output_features_vol=4, output_features_surf=5, global_features=2, - model_parameters=model_params, + model_parameters=params, ).to(device) bsize = 1 @@ -214,5 +221,8 @@ class parameter_model: } assert validate_domino( - model, input_dict, file_name="domino_output.pth", device=device + model, + input_dict, + file_name=f"domino_output-{processor_type}.pth", + device=device, ) diff --git a/test/models/domino/test_domino_encodings.py b/test/models/domino/test_domino_encodings.py new file mode 100644 index 0000000000..6570e0a686 --- /dev/null +++ b/test/models/domino/test_domino_encodings.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import pytest +import torch + +from .utils import validate_output_shape_and_values + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("fourier_features", [True, False]) +@pytest.mark.parametrize("num_modes", [3, 5, 10]) +def test_fourier_mlp(device, fourier_features, num_modes): + """Test FourierMLP with various configurations""" + from physicsnemo.models.layers import FourierMLP + + torch.manual_seed(0) + + model = FourierMLP( + input_features=3, + base_layer=64, + fourier_features=fourier_features, + num_modes=num_modes, + activation="relu", + ).to(device) + + x = torch.randn(2, 100, 3).to(device) + output = model(x) + + validate_output_shape_and_values(output, (2, 100, 64)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_fourier_encode_vectorized(device): + """Test fourier encoding function""" + from physicsnemo.models.layers import fourier_encode + + torch.manual_seed(0) + + coords = torch.randn(4, 20, 3).to(device) + freqs = torch.exp(torch.linspace(0, math.pi, 5)).to(device) + + output = fourier_encode(coords, freqs) + + # Output should be [batch, points, D * 2 * F] = [4, 20, 3 * 2 * 5] = [4, 20, 30] + validate_output_shape_and_values(output, (4, 20, 30)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_local_geometry_encoding(device): + """Test LocalGeometryEncoding""" + from physicsnemo.models.domino.encodings import LocalGeometryEncoding + from physicsnemo.models.domino.model import get_activation + + BATCH_SIZE = 1 + + torch.manual_seed(0) + + N_ENCODING_CHANNELS = 3 + N_NEIGHBORS = 32 + N_MESH_POINTS = 50 + GRID_RESOLUTION = (32, 32, 32) + + model = LocalGeometryEncoding( + radius=0.1, + neighbors_in_radius=N_NEIGHBORS, + total_neighbors_in_radius=N_ENCODING_CHANNELS * N_NEIGHBORS, + base_layer=128, + activation=get_activation("relu"), + grid_resolution=GRID_RESOLUTION, + ).to(device) + + encoding_g = torch.randn(BATCH_SIZE, N_ENCODING_CHANNELS, *GRID_RESOLUTION).to( + device + ) + volume_mesh_centers = torch.randn(BATCH_SIZE, N_MESH_POINTS, 3).to(device) + p_grid = torch.randn(BATCH_SIZE, *GRID_RESOLUTION, 3).to(device) + + output = model(encoding_g, volume_mesh_centers, p_grid) + + validate_output_shape_and_values(output, (BATCH_SIZE, N_MESH_POINTS, 32)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("geo_encoding_type", ["both", "stl", "sdf"]) +def test_multi_geometry_encoding(device, geo_encoding_type): + """Test MultiGeometryEncoding with different encoding types""" + from physicsnemo.models.domino.encodings import MultiGeometryEncoding + from physicsnemo.models.domino.model import get_activation + + torch.manual_seed(0) + + BATCH_SIZE = 1 + N_MESH_POINTS = 50 + GRID_RESOLUTION = (32, 32, 32) + + radii = [0.05, 0.1] + neighbors_in_radius = [16, 32] + + model = MultiGeometryEncoding( + radii=radii, + neighbors_in_radius=neighbors_in_radius, + geo_encoding_type=geo_encoding_type, + base_layer=64, + n_upstream_radii=2, + activation=get_activation("relu"), + grid_resolution=GRID_RESOLUTION, + ).to(device) + + if geo_encoding_type == "both": + num_channels = len(radii) + 1 + elif geo_encoding_type == "stl": + num_channels = len(radii) + else: # sdf + num_channels = 1 + + encoding_g = torch.randn(BATCH_SIZE, num_channels, *GRID_RESOLUTION).to(device) + volume_mesh_centers = torch.randn(BATCH_SIZE, N_MESH_POINTS, 3).to(device) + p_grid = torch.randn(BATCH_SIZE, *GRID_RESOLUTION, 3).to(device) + + output = model(encoding_g, volume_mesh_centers, p_grid) + + expected_output_dim = sum(neighbors_in_radius) + + validate_output_shape_and_values( + output, (BATCH_SIZE, N_MESH_POINTS, expected_output_dim) + ) diff --git a/test/models/domino/test_domino_geometry_rep.py b/test/models/domino/test_domino_geometry_rep.py new file mode 100644 index 0000000000..940d64a9df --- /dev/null +++ b/test/models/domino/test_domino_geometry_rep.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import pytest +import torch + +from .utils import validate_output_shape_and_values + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("act", ["relu", "gelu"]) +@pytest.mark.parametrize("fourier_features", [True, False]) +def test_geo_conv_out(device, act, fourier_features): + """Test GeoConvOut layer""" + from physicsnemo.models.domino.geometry_rep import GeoConvOut + + torch.manual_seed(0) + + @dataclass + class TestParams: + base_neurons: int = 32 + base_neurons_in: int = 8 + fourier_features: bool = False + neighbors_in_radius: int = 8 + num_modes: int = 5 + activation: str = act + + params = TestParams() + params.fourier_features = fourier_features + + input_features = 3 + + grid_resolution = [32, 32, 32] + + layer = GeoConvOut( + input_features=input_features, + neighbors_in_radius=params.neighbors_in_radius, + model_parameters=params, + grid_resolution=grid_resolution, + ).to(device) + + x = torch.randn(1, np.prod(grid_resolution), params.neighbors_in_radius, 3).to( + device + ) + grid = torch.randn(1, *grid_resolution, 3).to(device) + + output = layer(x, grid) + + validate_output_shape_and_values( + output, (1, params.base_neurons_in, *grid_resolution) + ) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("act", ["relu", "gelu"]) +def test_geo_processor(device, act): + """Test GeoProcessor CNN""" + from physicsnemo.models.domino.geometry_rep import GeoProcessor + + torch.manual_seed(0) + + @dataclass + class TestParams: + base_filters: int = 8 + activation: str = act + + params = TestParams() + + processor = GeoProcessor( + input_filters=4, output_filters=2, model_parameters=params + ).to(device) + + x = torch.randn(2, 4, 16, 16, 16).to(device) + output = processor(x) + + validate_output_shape_and_values(output, (2, 2, 16, 16, 16)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("geometry_encoding_type", ["both", "stl", "sdf"]) +@pytest.mark.parametrize("processor_type", ["unet", "conv"]) +def test_geometry_rep( + device, geometry_encoding_type, processor_type, base_model_params +): + """Test GeometryRep module with different configurations""" + from physicsnemo.models.domino.geometry_rep import GeometryRep + + torch.manual_seed(0) + + # Modify params for this test + params = base_model_params() + params.geometry_encoding_type = geometry_encoding_type + params.geometry_rep.geo_processor.processor_type = processor_type + params.geometry_rep.geo_processor.self_attention = False + params.geometry_rep.geo_processor.cross_attention = False + params.interp_res = (16, 16, 16) # Smaller for faster testing + + radii = [0.1, 0.2] + neighbors_in_radius = [8, 16] + + geo_rep = GeometryRep( + input_features=3, + radii=radii, + neighbors_in_radius=neighbors_in_radius, + hops=1, + model_parameters=params, + ).to(device) + + # Test inputs + x = torch.randn(1, 20, 3).to(device) + p_grid = torch.randn(1, 16, 16, 16, 3).to(device) + sdf = torch.randn(1, 16, 16, 16).to(device) + + output = geo_rep(x, p_grid, sdf) + + # Determine expected output channels + if geometry_encoding_type == "both": + expected_channels = len(radii) + 1 # STL channels + SDF channel + elif geometry_encoding_type == "stl": + expected_channels = len(radii) + else: # sdf + expected_channels = 1 + + validate_output_shape_and_values(output, (1, expected_channels, 16, 16, 16)) diff --git a/test/models/domino/test_domino_mlps.py b/test/models/domino/test_domino_mlps.py new file mode 100644 index 0000000000..8cf9546c2b --- /dev/null +++ b/test/models/domino/test_domino_mlps.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from .utils import validate_output_shape_and_values + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("activation", ["relu", "gelu"]) +def test_aggregation_model(device, activation): + """Test AggregationModel""" + from physicsnemo.models.domino.mlps import AggregationModel + from physicsnemo.models.domino.model import get_activation + + torch.manual_seed(0) + + model = AggregationModel( + input_features=100, + output_features=1, + base_layer=64, + activation=get_activation(activation), + ).to(device) + + x = torch.randn(2, 30, 100).to(device) + output = model(x) + + validate_output_shape_and_values(output, (2, 30, 1)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("activation", ["relu", "gelu"]) +def test_local_point_conv(device, activation): + """Test LocalPointConv""" + from physicsnemo.models.domino.mlps import LocalPointConv + from physicsnemo.models.domino.model import get_activation + + torch.manual_seed(0) + + model = LocalPointConv( + input_features=50, + base_layer=128, + output_features=32, + activation=get_activation(activation), + ).to(device) + + x = torch.randn(2, 100, 50).to(device) + output = model(x) + + validate_output_shape_and_values(output, (2, 100, 32)) diff --git a/test/models/domino/test_domino_solutions.py b/test/models/domino/test_domino_solutions.py new file mode 100644 index 0000000000..36ddd0d3db --- /dev/null +++ b/test/models/domino/test_domino_solutions.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +import torch.nn as nn + +from .utils import validate_output_shape_and_values + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("num_variables", [1, 3, 5]) +@pytest.mark.parametrize("num_sample_points", [1, 3, 7]) +@pytest.mark.parametrize("encode_parameters", [True, False]) +def test_solution_calculator_volume( + device, num_variables, num_sample_points, encode_parameters +): + """Test SolutionCalculatorVolume with various configurations""" + from physicsnemo.models.domino.mlps import AggregationModel + from physicsnemo.models.domino.solutions import SolutionCalculatorVolume + from physicsnemo.models.layers import FourierMLP, get_activation + + torch.manual_seed(0) + + activation = get_activation("relu") + + # Create parameter model if needed + parameter_model = ( + FourierMLP( + input_features=2, + base_layer=32, + fourier_features=True, + num_modes=3, + activation=activation, + ).to(device) + if encode_parameters + else None + ) + + # Create aggregation models + aggregation_model = nn.ModuleList( + [ + AggregationModel( + input_features=64 + 32 + 32 + (32 if encode_parameters else 0), + output_features=1, + base_layer=64, + activation=activation, + ).to(device) + for _ in range(num_variables) + ] + ) + + # Create basis functions + nn_basis = nn.ModuleList( + [ + FourierMLP( + input_features=3, + base_layer=32, + fourier_features=False, + num_modes=5, + activation=activation, + ).to(device) + for _ in range(num_variables) + ] + ) + + model = SolutionCalculatorVolume( + num_variables=num_variables, + num_sample_points=num_sample_points, + noise_intensity=50.0, + encode_parameters=encode_parameters, + return_volume_neighbors=False, + parameter_model=parameter_model, + aggregation_model=aggregation_model, + nn_basis=nn_basis, + ).to(device) + + # Test data + volume_mesh_centers = torch.randn(2, 30, 3).to(device) + encoding_g = torch.randn(2, 30, 32).to(device) + encoding_node = torch.randn(2, 30, 64).to(device) + global_params_values = torch.randn(2, 2, 1).to(device) + global_params_reference = torch.randn(2, 2, 1).to(device) + + output = model( + volume_mesh_centers, + encoding_g, + encoding_node, + global_params_values, + global_params_reference, + ) + + validate_output_shape_and_values(output, (2, 30, num_variables)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("num_variables", [1, 3, 5]) +@pytest.mark.parametrize("use_surface_normals", [True, False]) +@pytest.mark.parametrize("use_surface_area", [True, False]) +def test_solution_calculator_surface( + device, num_variables, use_surface_normals, use_surface_area +): + """Test SolutionCalculatorSurface with various configurations""" + from physicsnemo.models.domino.mlps import AggregationModel + from physicsnemo.models.domino.solutions import SolutionCalculatorSurface + from physicsnemo.models.layers import FourierMLP, get_activation + + torch.manual_seed(0) + + activation = get_activation("relu") + + # Determine input features based on surface configuration + input_features = 3 + if use_surface_normals: + input_features += 3 + if use_surface_area: + input_features += 1 + + # Create aggregation models + aggregation_model = nn.ModuleList( + [ + AggregationModel( + input_features=64 + 32 + 32, + output_features=1, + base_layer=64, + activation=activation, + ).to(device) + for _ in range(num_variables) + ] + ) + + # Create basis functions + nn_basis = nn.ModuleList( + [ + FourierMLP( + input_features=input_features, + base_layer=32, + fourier_features=False, + num_modes=5, + activation=activation, + ).to(device) + for _ in range(num_variables) + ] + ) + + model = SolutionCalculatorSurface( + num_variables=num_variables, + num_sample_points=3, + encode_parameters=False, + use_surface_normals=use_surface_normals, + use_surface_area=use_surface_area, + parameter_model=None, + aggregation_model=aggregation_model, + nn_basis=nn_basis, + ).to(device) + + # Test data + surface_mesh_centers = torch.randn(2, 30, 3).to(device) + encoding_g = torch.randn(2, 30, 32).to(device) + encoding_node = torch.randn(2, 30, 64).to(device) + surface_mesh_neighbors = torch.randn(2, 30, 2, 3).to(device) + surface_normals = torch.randn(2, 30, 3).to(device) + surface_neighbors_normals = torch.randn(2, 30, 2, 3).to(device) + surface_areas = torch.rand(2, 30, 1).to(device) + 1e-6 + surface_neighbors_areas = torch.rand(2, 30, 2, 1).to(device) + 1e-6 + global_params_values = torch.randn(2, 2, 1).to(device) + global_params_reference = torch.randn(2, 2, 1).to(device) + + output = model( + surface_mesh_centers, + encoding_g, + encoding_node, + surface_mesh_neighbors, + surface_normals, + surface_neighbors_normals, + surface_areas, + surface_neighbors_areas, + global_params_values, + global_params_reference, + ) + + validate_output_shape_and_values(output, (2, 30, num_variables)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +@pytest.mark.parametrize("r", [0.5, 1.0, 2.0]) +@pytest.mark.parametrize("num_points", [10, 50, 100]) +def test_sample_sphere(device, r, num_points): + """Test sphere sampling function""" + from physicsnemo.models.domino.solutions import sample_sphere + + torch.manual_seed(0) + + center = torch.randn(2, 30, 3).to(device) + output = sample_sphere(center, r, num_points) + + validate_output_shape_and_values(output, (2, 30, num_points, 3)) + + # Check that points are within the sphere radius + distances = torch.norm(output - center.unsqueeze(2), dim=-1) + assert (distances <= r + 1e-6).all(), "Some sampled points are outside the sphere" + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_sample_sphere_shell(device): + """Test spherical shell sampling function""" + from physicsnemo.models.domino.solutions import sample_sphere_shell + + torch.manual_seed(0) + + center = torch.randn(2, 30, 3).to(device) + r_inner, r_outer = 0.5, 1.5 + num_points = 50 + + output = sample_sphere_shell(center, r_inner, r_outer, num_points) + + validate_output_shape_and_values(output, (2, 30, num_points, 3)) + + # Check that points are within the shell + distances = torch.norm(output - center.unsqueeze(2), dim=-1) + assert (distances >= r_inner - 1e-6).all(), ( + "Some sampled points are inside inner radius" + ) + assert (distances <= r_outer + 1e-6).all(), ( + "Some sampled points are outside outer radius" + ) diff --git a/test/models/domino/utils.py b/test/models/domino/utils.py new file mode 100644 index 0000000000..6f16a2c9fb --- /dev/null +++ b/test/models/domino/utils.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def generate_test_data(bsize, nx, ny, nz, num_neigh, device): + """Generate test data for DoMINO""" + return { + "pos_volume_closest": torch.randn(bsize, 50, 3).to(device), + "pos_volume_center_of_mass": torch.randn(bsize, 50, 3).to(device), + "pos_surface_center_of_mass": torch.randn(bsize, 50, 3).to(device), + "geometry_coordinates": torch.randn(bsize, 50, 3).to(device), + "grid": torch.randn(bsize, nx, ny, nz, 3).to(device), + "surf_grid": torch.randn(bsize, nx, ny, nz, 3).to(device), + "sdf_grid": torch.randn(bsize, nx, ny, nz).to(device), + "sdf_surf_grid": torch.randn(bsize, nx, ny, nz).to(device), + "sdf_nodes": torch.randn(bsize, 50, 1).to(device), + "surface_mesh_centers": torch.randn(bsize, 50, 3).to(device), + "surface_mesh_neighbors": torch.randn(bsize, 50, num_neigh, 3).to(device), + "surface_normals": torch.randn(bsize, 50, 3).to(device), + "surface_neighbors_normals": torch.randn(bsize, 50, num_neigh, 3).to(device), + "surface_areas": torch.rand(bsize, 50).to(device) + 1e-6, + "surface_neighbors_areas": torch.rand(bsize, 50, num_neigh).to(device) + 1e-6, + "volume_mesh_centers": torch.randn(bsize, 50, 3).to(device), + "volume_min_max": torch.randn(bsize, 2, 3).to(device), + "surface_min_max": torch.randn(bsize, 2, 3).to(device), + "global_params_values": torch.randn(bsize, 2, 1).to(device), + "global_params_reference": torch.randn(bsize, 2, 1).to(device), + } + + +def validate_output_shape_and_values(output, expected_shape, check_finite=True): + """Validate output tensor shape and values""" + if output is not None: + assert output.shape == expected_shape, ( + f"Expected shape {expected_shape}, got {output.shape}" + ) + if check_finite: + assert torch.isfinite(output).all(), "Output contains non-finite values" + assert not torch.isnan(output).any(), "Output contains NaN values" diff --git a/test/models/test_mlp_layers.py b/test/models/test_mlp_layers.py new file mode 100644 index 0000000000..7a943cc51b --- /dev/null +++ b/test/models/test_mlp_layers.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.models.layers import Mlp + +from .common import ( + validate_forward_accuracy, +) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_mlp_forward_accuracy(device): + torch.manual_seed(7) + target_device = torch.device(device) + + model = Mlp(in_features=10, hidden_features=20, out_features=5).to(target_device) + input_tensor = torch.randn(1, 10).to( + target_device + ) # Assuming a batch size of 1 for simplicity + model(input_tensor) + + file_name = "mlp_output.pth" + + # Tack this on for the test, since model is not a physicsnemo Module: + model.device = target_device + + assert validate_forward_accuracy( + model, + (input_tensor,), + file_name=file_name, + atol=1e-3, + ) + + +def test_mlp_activation_and_dropout(): + model = Mlp(in_features=10, hidden_features=20, out_features=5, drop=0.5) + input_tensor = torch.randn(2, 10) # Batch size of 2 + + output_tensor = model(input_tensor) + + assert output_tensor.shape == torch.Size([2, 5]) + + +def test_mlp_different_activation(): + model = Mlp( + in_features=10, hidden_features=20, out_features=7, act_layer=torch.nn.ReLU + ) + input_tensor = torch.randn(3, 10) # Batch size of 3 + + output_tensor = model(input_tensor) + assert output_tensor.shape == torch.Size([3, 7]) + + +def test_multiple_hidden_layers(): + model = Mlp(in_features=10, hidden_features=[20, 30], out_features=5) + input_tensor = torch.randn(4, 10) # Batch size of 4 + + output_tensor = model(input_tensor) + assert output_tensor.shape == torch.Size([4, 5]) diff --git a/test/utils/test_domino_utils.py b/test/utils/test_domino_utils.py index 8a0e03637b..fc10b93688 100644 --- a/test/utils/test_domino_utils.py +++ b/test/utils/test_domino_utils.py @@ -21,7 +21,10 @@ module to ensure that the documented examples work correctly. """ -import numpy as np +import math + +import pytest +import torch from physicsnemo.utils.domino.utils import ( area_weighted_shuffle_array, @@ -45,67 +48,70 @@ def test_calculate_center_of_mass(): """Test calculate_center_of_mass function with docstring example.""" - centers = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) - sizes = np.array([1.0, 2.0, 3.0]) + centers = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) + sizes = torch.tensor([1.0, 2.0, 3.0]) com = calculate_center_of_mass(centers, sizes) - expected = np.array([[4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0]]) - assert np.allclose(com, expected) + expected = torch.tensor([[4.0 / 3.0, 4.0 / 3.0, 4.0 / 3.0]]) + assert torch.allclose(com, expected) def test_normalize(): """Test normalize function with docstring examples.""" # Example 1: With explicit min/max - field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) - normalized = normalize(field, 5.0, 1.0) - expected = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) - assert np.allclose(normalized, expected) + field = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + normalized = normalize(field, max_val=5.0, min_val=1.0) + expected = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0]) + assert torch.allclose(normalized, expected) # Example 2: Auto-compute min/max normalized_auto = normalize(field) - expected_auto = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) - assert np.allclose(normalized_auto, expected_auto) + expected_auto = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0]) + assert torch.allclose(normalized_auto, expected_auto) def test_unnormalize(): """Test unnormalize function with docstring example.""" - normalized = np.array([-1.0, -0.5, 0.0, 0.5, 1.0]) + normalized = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0]) original = unnormalize(normalized, 5.0, 1.0) - expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) - assert np.allclose(original, expected) + expected = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + assert torch.allclose(original, expected) def test_standardize(): """Test standardize function with docstring examples.""" # Example 1: With explicit mean/std - field = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) - standardized = standardize(field, 3.0, np.sqrt(2.5)) - expected = np.array([-1.265, -0.632, 0.0, 0.632, 1.265]) - assert np.allclose(standardized, expected, atol=1e-3) + field = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + standardized = standardize(field, 3.0, math.sqrt(2.5)) + expected = torch.tensor([-1.265, -0.632, 0.0, 0.632, 1.265]) + assert torch.allclose(standardized, expected, atol=1e-3) # Example 2: Auto-compute mean/std standardized_auto = standardize(field) - assert np.allclose(np.mean(standardized_auto), 0.0) - assert np.allclose(np.std(standardized_auto, ddof=0), 1.0) + assert torch.allclose(torch.mean(standardized_auto), torch.tensor(0.0)) + assert torch.allclose(torch.std(standardized_auto, correction=1), torch.tensor(1.0)) def test_unstandardize(): """Test unstandardize function with docstring example.""" - standardized = np.array([-1.265, -0.632, 0.0, 0.632, 1.265]) - original = unstandardize(standardized, 3.0, np.sqrt(2.5)) - expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) - assert np.allclose(original, expected, atol=1e-3) + standardized = torch.tensor([-1.265, -0.632, 0.0, 0.632, 1.265]) + original = unstandardize(standardized, 3.0, math.sqrt(2.5)) + expected = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + assert torch.allclose(original, expected, atol=1e-3) -def test_calculate_normal_positional_encoding(): +@pytest.mark.parametrize("relative", [True, False]) +def test_calculate_normal_positional_encoding(relative): """Test calculate_normal_positional_encoding function with docstring examples.""" # Example 1: Basic coordinates - coords = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) + coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) cell_size = [0.1, 0.1, 0.1] - encoding = calculate_normal_positional_encoding(coords, cell_dimensions=cell_size) - assert encoding.shape == (2, 12) # Example 2: Relative positioning - coords_b = np.array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) + if relative: + coords_b = torch.tensor([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) + else: + coords_b = None + encoding_rel = calculate_normal_positional_encoding(coords, coords_b, cell_size) assert encoding_rel.shape == (2, 12) @@ -113,59 +119,59 @@ def test_calculate_normal_positional_encoding(): def test_nd_interpolator(): """Test nd_interpolator function with docstring example.""" # Simple 2D interpolation example - coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) - field_vals = np.array([[1.0], [2.0], [3.0], [4.0]]) - grid_points = np.array([[0.5, 0.5]]) - result = nd_interpolator([coords], field_vals, grid_points) + coords = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) + field_vals = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) + grid_points = torch.tensor([[0.5, 0.5]]) + result = nd_interpolator(coords, field_vals, grid_points) assert result.shape[0] == 1 # One grid point def test_pad(): """Test pad function with docstring examples.""" # Example 1: Padding needed - arr = np.array([[1.0, 2.0], [3.0, 4.0]]) + arr = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) padded = pad(arr, 4, -1.0) assert padded.shape == (4, 2) - assert np.array_equal(padded[:2], arr) - assert bool(np.all(padded[2:] == -1.0)) + assert torch.allclose(padded[:2], arr) + assert bool(torch.all(padded[2:] == -1.0)) # Example 2: No padding needed same = pad(arr, 2) - assert np.array_equal(same, arr) + assert torch.allclose(same, arr) def test_pad_inp(): """Test pad_inp function with docstring example.""" - arr = np.array([[[1.0, 2.0]], [[3.0, 4.0]]]) + arr = torch.tensor([[[1.0, 2.0]], [[3.0, 4.0]]]) padded = pad_inp(arr, 4, 0.0) assert padded.shape == (4, 1, 2) - assert np.array_equal(padded[:2], arr) - assert bool(np.all(padded[2:] == 0.0)) + assert torch.allclose(padded[:2], arr) + assert bool(torch.all(padded[2:] == 0.0)) def test_shuffle_array(): """Test shuffle_array function with docstring example.""" - np.random.seed(42) # For reproducible results - data = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + torch.manual_seed(42) # For reproducible results + data = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) subset, indices = shuffle_array(data, 2) assert subset.shape == (2, 2) assert indices.shape == (2,) - assert len(np.unique(indices)) == 2 # No duplicates + assert len(torch.unique(indices)) == 2 # No duplicates def test_shuffle_array_without_sampling(): """Test shuffle_array_without_sampling function with docstring example.""" - np.random.seed(42) # For reproducible results - data = np.array([[1], [2], [3], [4]]) + torch.manual_seed(42) # For reproducible results + data = torch.tensor([[1], [2], [3], [4]]) shuffled, indices = shuffle_array_without_sampling(data) assert shuffled.shape == (4, 1) assert indices.shape == (4,) - assert set(indices) == set(range(4)) # All original indices present + assert set(indices.tolist()) == set(range(4)) # All original indices present def test_calculate_pos_encoding(): """Test calculate_pos_encoding function with docstring example.""" - positions = np.array([0.0, 1.0, 2.0]) + positions = torch.tensor([0.0, 1.0, 2.0]) encodings = calculate_pos_encoding(positions, d=4) assert len(encodings) == 4 assert all(enc.shape == (3,) for enc in encodings) @@ -182,30 +188,30 @@ def test_combine_dict(): def test_create_grid(): """Test create_grid function with docstring example.""" - min_bounds = np.array([0.0, 0.0, 0.0]) - max_bounds = np.array([1.0, 1.0, 1.0]) - grid_res = np.array([2, 2, 2]) + min_bounds = torch.tensor([0.0, 0.0, 0.0]) + max_bounds = torch.tensor([1.0, 1.0, 1.0]) + grid_res = torch.tensor([2, 2, 2]) grid = create_grid(max_bounds, min_bounds, grid_res) assert grid.shape == (2, 2, 2, 3) - assert np.allclose(grid[0, 0, 0], [0.0, 0.0, 0.0]) - assert np.allclose(grid[1, 1, 1], [1.0, 1.0, 1.0]) + assert torch.allclose(grid[0, 0, 0], torch.tensor([0.0, 0.0, 0.0])) + assert torch.allclose(grid[1, 1, 1], torch.tensor([1.0, 1.0, 1.0])) def test_mean_std_sampling(): """Test mean_std_sampling function with docstring example.""" # Create test data with outliers - field = np.array([[1.0], [2.0], [3.0], [10.0]]) # 10.0 is outlier - field_mean = np.array([2.0]) - field_std = np.array([1.0]) + field = torch.tensor([[1.0], [2.0], [3.0], [10.0]]) # 10.0 is outlier + field_mean = torch.tensor([2.0]) + field_std = torch.tensor([1.0]) outliers = mean_std_sampling(field, field_mean, field_std, 2.0) assert 3 in outliers # Index 3 (value 10.0) should be detected as outlier def test_area_weighted_shuffle_array(): """Test area_weighted_shuffle_array function with docstring example.""" - np.random.seed(42) # For reproducible results - mesh_data = np.array([[1.0], [2.0], [3.0], [4.0]]) - cell_areas = np.array([0.1, 0.1, 0.1, 10.0]) # Last point has much larger area + torch.manual_seed(42) # For reproducible results + mesh_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) + cell_areas = torch.tensor([0.1, 0.1, 0.1, 10.0]) # Last point has much larger area subset, indices = area_weighted_shuffle_array(mesh_data, 2, cell_areas) assert subset.shape == (2, 1) assert indices.shape == (2,)