From 36ad0bd78aa9f872ef993738bc5ce6e07e12f75c Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Wed, 19 Nov 2025 17:53:45 -0800 Subject: [PATCH 1/7] implement EddyFormer --- examples/cfd/isotropic_eddyformer/README.md | 95 +++++++ examples/cfd/isotropic_eddyformer/config.yaml | 23 ++ .../isotropic_eddyformer/download_dataset.sh | 1 + .../cfd/isotropic_eddyformer/requirements.txt | 2 + .../train_ef_isotropic.py | 110 +++++++++ physicsnemo/models/eddyformer/__init__.py | 5 + physicsnemo/models/eddyformer/_basis.py | 112 +++++++++ physicsnemo/models/eddyformer/_datatype.py | 233 ++++++++++++++++++ physicsnemo/models/eddyformer/eddyformer.py | 181 ++++++++++++++ physicsnemo/models/eddyformer/sem_attn.py | 74 ++++++ physicsnemo/models/eddyformer/sem_conv.py | 150 +++++++++++ pyproject.toml | 1 + 12 files changed, 987 insertions(+) create mode 100644 examples/cfd/isotropic_eddyformer/README.md create mode 100644 examples/cfd/isotropic_eddyformer/config.yaml create mode 100644 examples/cfd/isotropic_eddyformer/download_dataset.sh create mode 100644 examples/cfd/isotropic_eddyformer/requirements.txt create mode 100644 examples/cfd/isotropic_eddyformer/train_ef_isotropic.py create mode 100644 physicsnemo/models/eddyformer/__init__.py create mode 100644 physicsnemo/models/eddyformer/_basis.py create mode 100644 physicsnemo/models/eddyformer/_datatype.py create mode 100644 physicsnemo/models/eddyformer/eddyformer.py create mode 100644 physicsnemo/models/eddyformer/sem_attn.py create mode 100644 physicsnemo/models/eddyformer/sem_conv.py diff --git a/examples/cfd/isotropic_eddyformer/README.md b/examples/cfd/isotropic_eddyformer/README.md new file mode 100644 index 0000000000..1e22dd9485 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/README.md @@ -0,0 +1,95 @@ +# EddyFormer for 3D Isotropic Turbulence + +This example demonstrates how to use the EddyFormer model for simulating +a three-dimensional isotropic turbulence. This example runs on a single GPU. + +## Problem Overview + +This example focuses on **three-dimensional homogeneous isotropic turbulence (HIT)** sustained by large-scale forcing. The flow is governed by the incompressible Navier–Stokes equations with an external forcing term: + +\[ +\frac{\partial \mathbf{u}}{\partial t} + \mathbf{u} \cdot \nabla \mathbf{u} += \nu \nabla^2 \mathbf{u} + \mathbf{f}(\mathbf{x}) +\] + +where: + +- **\(\mathbf{u}(\mathbf{x}, t)\)** — velocity field in a 3D periodic domain +- **\(\nu = 0.01\)** — kinematic viscosity +- **\(\mathbf{f}(\mathbf{x})\)** — isotropic forcing applied at the largest scales + +### Forcing Mechanism + +To maintain statistically steady turbulence, a **constant-power forcing** is applied to the lowest Fourier modes (\(|\mathbf{k}| \le 1\)). The forcing injects a prescribed amount of energy \(P_{\text{in}} = 1.0\) into the system: + +\[ +\mathbf{f}(\mathbf{x}) = +\frac{P_{\text{in}}}{E_1} +\sum_{\substack{|\mathbf{k}| \le 1 \\ \mathbf{k} \neq 0}} +\hat{\mathbf{u}}_{\mathbf{k}} e^{i \mathbf{k} \cdot \mathbf{x}} +\] + +where: + +\[ +E_1 = \frac{1}{2} +\sum_{|\mathbf{k}| \le 1} +\hat{\mathbf{u}}_{\mathbf{k}} \cdot \hat{\mathbf{u}}_{\mathbf{k}}^{*} +\] + +is the kinetic energy contained in the forced low-wavenumber modes. + +Under this forcing, the flow reaches a **statistically steady state** with a Taylor-scale Reynolds number of: + +**\(\mathrm{Re}_\lambda \approx 94\)** + +### Task Description + +The objective of this example is to **predict the future velocity field** of the turbulent flow. 
Given \(\mathbf{u}(\mathbf{x}, t)\), the task is: + +> **Predict the velocity field \(\mathbf{u}(\mathbf{x}, t + \Delta t)\) with \(\Delta t = 0.5\).** + +This requires modeling nonlinear, chaotic, multi-scale turbulent dynamics, including: + +- energy injection at large scales +- nonlinear transfer across the inertial range +- dissipation at the smallest scales + +### Dataset Summary + +- **DNS resolution:** \(384^3\) (used to generate the dataset) +- **Stored dataset resolution:** \(96^3\) +- **Kolmogorov scale resolution:** ~0.5 η +- **Forcing:** applied to modes with \(|\mathbf{k}| \le 1\) +- **Viscosity:** \(\nu = 0.01\) +- **Input power:** \(P_{\text{in}} = 1.0\) +- **Flow regime:** statistically steady HIT at \(\mathrm{Re}_\lambda \approx 94\) + +## Prerequisites + +Install the required dependencies by running below: + +```bash +pip install -r requirements.txt +``` + +## Download the Dataset + +The dataset is publicly available at [Huggingface](https://huggingface.co/datasets/ydu11/re94). +To download the dataset, run (you might need to install the Huggingface CLI): + +```bash +bash download_dataset.sh +``` + +## Getting Started + +To train the model, run + +```bash +python train_ef_isotropic.py +``` + +## References + +- [EddyFormer: EddyFormer: Accelerated Neural Simulations of Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173) diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml new file mode 100644 index 0000000000..e7018f54d0 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/config.yaml @@ -0,0 +1,23 @@ +model: + idim: 3 + odim: 3 + hdim: 32 + num_layers: 4 + layer_config: + basis: legendre + mesh: [8, 8, 8] + mode: [10, 10, 10] + mode_les: [5, 5, 5] + kernel_size: [2, 2, 2] + kernel_size_les: [2, 2, 2] + ffn_dim: 128 + activation: GELU + num_heads: 4 + heads_dim: 32 + +training: + dataset: data/ns3d-re94 + t: 0.5 + batch_size: 4 + num_epochs: 100 + learning_rate: 1e-3 diff --git a/examples/cfd/isotropic_eddyformer/download_dataset.sh b/examples/cfd/isotropic_eddyformer/download_dataset.sh new file mode 100644 index 0000000000..7b50328c92 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/download_dataset.sh @@ -0,0 +1 @@ +hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94} \ No newline at end of file diff --git a/examples/cfd/isotropic_eddyformer/requirements.txt b/examples/cfd/isotropic_eddyformer/requirements.txt new file mode 100644 index 0000000000..001dc23f09 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/requirements.txt @@ -0,0 +1,2 @@ +hydra-core>=1.2.0 +termcolor>=2.1.1 diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py new file mode 100644 index 0000000000..6546d20ca7 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -0,0 +1,110 @@ +import hydra +from typing import Tuple +from torch import Tensor +from omegaconf import DictConfig + +import os +import numpy as np + +import torch +from torch.nn import MSELoss +from torch.optim import Adam +from torch.utils.data import Dataset, DataLoader + +from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils import StaticCaptureTraining +from physicsnemo.launch.logging import PythonLogger, LaunchLogger + + +class Re94(Dataset): + + root: str + t: float + + n: int = 50 + dt: float = 0.1 + + def __init__(self, root: str, 
split: str, *, t: float = 0.5) -> None: + """ + """ + super().__init__() + self.root = root + self.t = t + + self.file = [] + for fname in sorted(os.listdir(root)): + if fname.startswith(split): + self.file.append(fname) + + @property + def stride(self) -> int: + k = int(self.t / self.dt) + assert self.dt * k == self.t + return k + + @property + def samples_per_file(self) -> int: + return self.n - self.stride + 1 + + def __len__(self) -> int: + return len(self.file) * self.samples_per_file + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]: + file_idx, time_idx = divmod(idx, self.samples_per_file) + + data = np.load(f"{self.root}/{self.file[file_idx]}", allow_pickle=True).item() + return torch.from_numpy(data["u"][time_idx]), torch.from_numpy(data["u"][time_idx + self.stride]) + +@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml") +def isotropic_trainer(cfg: DictConfig) -> None: + """ + """ + DistributedManager.initialize() # Only call this once in the entire script! + dist = DistributedManager() # call if required elsewhere + + # initialize monitoring + log = PythonLogger(name="re94_ef") + log.file_logging() + LaunchLogger.initialize() # PhysicsNeMo launch logger + + # define model, loss, optimiser, scheduler, data loader + model = EddyFormer( + idim=cfg.model.idim, + odim=cfg.model.odim, + hdim=cfg.model.hdim, + num_layers=cfg.model.num_layers, + cfg=EddyFormerConfig(**cfg.model.layer_config), + ).to(dist.device) + loss_fun = MSELoss(reduction="mean") + optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate) + dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t) + + # define forward passes for training and inference + @StaticCaptureTraining( + model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False + ) + def training_step(input, target): + pred = torch.vmap(model)(input) + loss = loss_fun(pred, target) + return loss + + for epoch in range(cfg.training.num_epochs): + + dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) + + for input, target in dataloader: + + input = input.to(dist.device) + target = target.to(dist.device) + with torch.autograd.set_detect_anomaly(True): + loss = training_step(input, target) + + with LaunchLogger("train", epoch=epoch) as logger: + logger.log_minibatch({"Training loss": loss.item()}) + + log.success("Training completed") + + +if __name__ == "__main__": + isotropic_trainer() diff --git a/physicsnemo/models/eddyformer/__init__.py b/physicsnemo/models/eddyformer/__init__.py new file mode 100644 index 0000000000..db0569fda6 --- /dev/null +++ b/physicsnemo/models/eddyformer/__init__.py @@ -0,0 +1,5 @@ +from ._basis import Legendre +from ._datatype import SEM +from .eddyformer import EddyFormer, EddyFormerLayer + +EddyFormerConfig = EddyFormerLayer.Config diff --git a/physicsnemo/models/eddyformer/_basis.py b/physicsnemo/models/eddyformer/_basis.py new file mode 100644 index 0000000000..e3906ca529 --- /dev/null +++ b/physicsnemo/models/eddyformer/_basis.py @@ -0,0 +1,112 @@ +from typing import Protocol +from torch import Tensor + +import torch +import torch.nn as nn + +import numpy as np +import functools + +class Basis(Protocol): + + grid: Tensor + quad: Tensor + + m: int + f: Tensor + + def fn(self, xs: Tensor) -> Tensor: + """ + Evaluate basis functions at given points. + """ + + def at(self, coef: Tensor, xs: Tensor) -> Tensor: + """ + Evaluate basis expansion at given points. 
+ """ + return torch.tensordot(self.fn(xs), coef, dims=1) + + def modal(self, vals: Tensor) -> Tensor: + """ + Convert nodal values to modal coefficients. + """ + + def nodal(self, coef: Tensor) -> Tensor: + """ + Convert modal coefficients to nodal values. + """ + +class Element(Basis): + + def __init__(self, base: Basis): + """ + """ + +# ---------------------------------------------------------------------------- # +# LEGENDRE # +# ---------------------------------------------------------------------------- # + +from numpy.polynomial import legendre + +@functools.cache +class Legendre(nn.Module, Basis): + + """ + Shifted Legendre polynomials: + - `(1 - x^2) Pn''(x) - 2 x Pn(x) + n (n + 1) Pn(x) = 0` + - `Pn^~(x) = Pn(2 x - 1)` + """ + + def extra_repr(self) -> str: + return f"m={self.m}" + + def __init__(self, m: int, endpoint: bool = False): + """ + """ + super().__init__() + self.m = m + + if endpoint: m -= 1 + c = (0, ) * m + (1, ) + dc = legendre.legder(c) + + x = legendre.legroots(dc if endpoint else c) + y = legendre.legval(x, c if endpoint else dc) + + if endpoint: + x = np.concatenate([[-1], x, [1]]) + y = np.concatenate([[1], y, [1]]) + + w = 1 / y ** 2 + if endpoint: w /= m * (m + 1) + else: w /= 1 - x ** 2 + + self.register_buffer("grid", torch.tensor((1 + x) / 2, dtype=torch.float)) + self.register_buffer("quad", torch.tensor(w, dtype=torch.float)) + + self.register_buffer("f", self.fn(self.grid)) + + def fn(self, xs: Tensor) -> Tensor: + """ + """ + P = torch.ones_like(xs), 2 * xs - 1 + + for i in range(2, self.m): + a, b = (i * 2 - 1) / i, (i - 1) / i + P += a * P[-1] * P[1] - b * P[-2], + + return torch.stack(P, dim=-1) + +# --------------------------------- TRANSFORM -------------------------------- # + + def modal(self, vals: Tensor) -> Tensor: + """ + """ + norm = 2 * torch.arange(self.m, device=vals.device) + 1 + coef = self.f * norm * self.quad[:, None] + return torch.tensordot(coef.T, vals, dims=1) + + def nodal(self, coef: Tensor) -> Tensor: + """ + """ + return self.at(coef, self.grid) diff --git a/physicsnemo/models/eddyformer/_datatype.py b/physicsnemo/models/eddyformer/_datatype.py new file mode 100644 index 0000000000..ea1e5514bf --- /dev/null +++ b/physicsnemo/models/eddyformer/_datatype.py @@ -0,0 +1,233 @@ +from typing import Tuple +from torch import Tensor + +import torch +import torch.nn.functional as F + +from dataclasses import dataclass, replace +from functools import cached_property + +from ._basis import Basis, Legendre + +def interp1d(value: Tensor, xs: Tensor, method: str) -> Tensor: + """ + Interpolate from 1D regular grid to a target points. + + Args: + value: Values on a uniform grid along the first axis. + xs: Resolution or an array normalized by the domain size. + method: Interpolation method. One of "fft", "linear", or + f"lag{n}" for n-point Lagrangian interpolation. + """ + if method == "fft": + coef = torch.fft.rfft(value, dim=0, norm="forward") + + k = 2 * torch.pi * torch.arange(len(coef)) + f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 + return torch.tensordot(f.real, coef.real, dims=1) \ + - torch.tensordot(f.imag, coef.imag, dims=1) + + if method.startswith("lag"): + n_points = int(method[3:]) + + assert n_points % 2 == 0 + r = n_points // 2 - 1 + + n = len(value) + + i = (xs * (N := n - 1)).int() + i = torch.clip(i, r, n - n_points + r) + + # 1. pad the input grid + + v_pad = value, value[:r+2] + + if r > 0: v_pad = (value[-r:], ) + v_pad + value = torch.concatenate(v_pad, dim=0) + + # 2. 
construct polynomials + + out = 0 + + for j in range(n_points): + lag = value[i + j] + + for k in range(n_points): + if j == k: continue + fac = xs - (i + k - r) / N + while fac.ndim < lag.ndim: + fac = fac.unsqueeze(-1) + lag *= fac * N / (j - k) + + out += lag + return out + + raise ValueError(f"invalid interpolation {method=}") + +# ---------------------------------------------------------------------------- # +# SPECTRAL ELEMENT # +# ---------------------------------------------------------------------------- # + +@dataclass +class SEM: + + """ + Spectral element expansion. The sub-domain partition is + given by the `mesh` attribute. The spectral coefficients + of each element is stored in the first channel dimension, + whose size must equal to the number of elements. + """ + + T_: str + + # Mesh + + size: Tensor + mesh: Tuple[int] + + # Data + + mode_: Tuple[int] = None + nodal: Tensor = None + + @property + def ndim(self) -> int: + return len(self.mesh) + + @property + def mode(self) -> Tuple[int]: + if self.mode_: return self.mode_ + return self.nodal.shape[:self.ndim] + + @property + def use_elem(self) -> bool: + return self.T_.endswith("elem") + + @staticmethod + def basis(T_: str) -> Basis: + if T_.startswith("leg"): return Legendre + raise ValueError(f"invalid basis {T_=}") + + @cached_property + def T(self) -> Tuple[Basis]: + """ + Basis on each dimension. + """ + T = self.basis(self.T_) + return tuple(map(T, self.mode)) + + def to(self, mode: Tuple[int]) -> "SEM": + """ + Resample to another mode. + + Args: + mode: Number of modes. + """ + out = SEM(self.T_, self.size, self.mesh, mode) + + value = self.nodal + for n in range(self.ndim): + coef = self.T[n].modal(value) + if (pad:=out.mode[n] - self.mode[n]) <= 0: coef = coef[:mode[n]] + else: coef = torch.concat([coef, torch.zeros(pad, *coef.shape[1:], device=coef.device)], dim=0) + value = out.T[n].nodal(coef).movedim(0, self.ndim - 1) + + return out.new(value) + + def at(self, *xs: Tensor, uniform: bool = False) -> Tensor: + """ + Evaluate on rectilinear grids. + + Args: + xs: Coordinate of each dimension. + uniform: Whether `xs` are uniformly spaced. 
+ """ + value = self.nodal + for n in range(self.ndim): + x = xs[n] / self.size[n] + coef = self.T[n].modal(value) + + # indices of each global coordinate `x` + idx = torch.floor(x * float(self.mesh[n])).int() + idx = torch.minimum(idx, torch.tensor(self.mesh[n] - 1)) + + # global coordinate to local coordinate + ys = x * float(self.mesh[n]) - torch.arange(self.mesh[n], device=x.device)[idx] + + if not uniform: + + # coefficients where each `x` belongs + coef = coef.movedim(self.ndim, 0)[idx] + + # evaluate at each coordonate and move the output axis to the last dimension + # after `ndim` iterations, the axes are automatically rolled to the correct order + value = torch.vmap(self.T[n].at, out_dims=self.ndim - 1)(coef, ys) + + else: + + # coordinates within each element + ys = ys.reshape(self.mesh[n], -1) + + # batched evaluation of all coordinates + value = torch.vmap(self.T[n].at, (self.ndim, 0))(coef, ys) + value = torch.movedim(value.flatten(end_dim=1), 0, self.ndim - 1) + + return value + +# ---------------------------------- COORDS ---------------------------------- # + + @cached_property + def grid(self) -> Tensor: + axes = [self.T[n].grid.to(self.size.device) for n in range(self.ndim)] + return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1) + + @cached_property + def coords(self) -> Tensor: + local = self.grid + for _ in range(self.ndim): + local = local.unsqueeze(self.ndim) + return self.origins + local * self.lengths + + @cached_property + def origins(self) -> Tensor: + left = [torch.arange(m, device=self.size.device) / m for m in self.mesh] + return torch.stack(torch.meshgrid(*left, indexing="ij"), dim=-1) * self.size + + @cached_property + def lengths(self) -> Tensor: + ns = torch.tensor(self.mesh, device=self.size.device) + return self.size / ns.float() + +# --------------------------------- DATA TYPE -------------------------------- # + + def new(self, nodal: Tensor) -> "SEM": + assert nodal.shape[:self.ndim] == self.mode + return replace(self, mode_=None, nodal=nodal) + + def eval(self, resolution: Tuple[int]) -> Tensor: + xs = [torch.linspace(0, s, n, device=self.size.device) for n, s in zip(resolution, self.size)] + return self.at(*xs, uniform=False) + + def from_grid(self, value: Tensor, method: str) -> "SEM": + """ + Interpolate grid values to a target datatype. + + Args: + out: Target datatype. + method: Interpolation method along each axis. + See `interp1d::method` for details. + """ + xs = self.coords / self.size + for n in range(self.ndim): + + # interpolate at each collocation points. `idx` is the + # index of the elements along the `n`'th dimension. + idx = [slice(None) if i == n else 0 for i in range(self.ndim)] + value = interp1d(value, xs[tuple(idx * 2 + [n])], method) + + # roll the output. The interpolated values have shape `(mode, mesh)`, + # which are moved to the middle (`ndim - 1`) and the end (`ndim + n`) of + # the dimensions. After `ndim` iterations, all axes are ordered correctly. 
+ value = torch.moveaxis(value, (0, 1), (self.ndim - 1, self.ndim + n)) + + return self.new(value) diff --git a/physicsnemo/models/eddyformer/eddyformer.py b/physicsnemo/models/eddyformer/eddyformer.py new file mode 100644 index 0000000000..569eb95857 --- /dev/null +++ b/physicsnemo/models/eddyformer/eddyformer.py @@ -0,0 +1,181 @@ +from typing import Tuple, Union +from torch import Tensor + +import torch +import torch.nn as nn + +from dataclasses import dataclass +from functools import partial + +from ..module import Module +from ..meta import ModelMetaData +from ..layers.mlp_layers import Mlp + +from ._datatype import SEM +from .sem_conv import SEMConv +from .sem_attn import SEMAttn + +# Layer + +class EddyFormerLayer(nn.Module): + + @dataclass + class Config: + + basis: str + mesh: Tuple[int] + mode: Tuple[int] + + # SGS STREAM + kernel_size: Tuple[int] + + ffn_dim: int + activation: str + + # LES STREAM + mode_les: Tuple[int] + kernel_size_les: Tuple[int] + + num_heads: int + heads_dim: int + + @property + def ffn(self) -> partial[Mlp]: + return partial(Mlp, + hidden_features=self.ffn_dim, + act_layer=getattr(nn, self.activation), + ) + + @property + def attn(self) -> partial[SEMAttn]: + return partial(SEMAttn, + mode=self.mode_les, + num_heads=self.num_heads, + heads_dim=self.heads_dim, + ) + + def conv(self, stream: str) -> partial[SEMConv]: + return partial(SEMConv, + kernel_mode=(mode:=self.mode if stream == "sgs" else self.mode_les), + kernel_size=self.kernel_size if stream == "sgs" else self.kernel_size_les, + T=tuple(map(SEM.basis(self.basis), mode)), + ) + + def __init__(self, hdim: int, cfg: Config, *, layer_scale: float = 1e-7): + """ + EddyFormer layer. + """ + super().__init__() + + self.mode = cfg.mode + self.mode_les = cfg.mode_les + + self.eps = nn.Parameter(torch.ones(hdim) * layer_scale) + self.ffn_les, self.ffn_sgs = cfg.ffn(hdim), cfg.ffn(hdim) + + self.sem_conv_sgs = cfg.conv("sgs")(hdim, hdim) + self.sem_conv_les = cfg.conv("les")(hdim, hdim) + self.sem_attn = cfg.attn(hdim, hdim, conv=cfg.conv("les")) + + def __call__(self, les: SEM, sgs: SEM) -> Tuple[SEM, SEM]: + """ + """ + les.nodal = les.nodal + self.sem_attn(les).nodal + les.nodal = les.nodal + self.ffn_les(self.sem_conv_les(les).nodal) + + sgs.nodal = sgs.nodal + self.eps * les.to(self.mode).nodal + sgs.nodal = sgs.nodal + self.ffn_sgs(self.sem_conv_sgs(sgs).nodal) + + return les, sgs + +# Model + +@dataclass +class MetaData(ModelMetaData): + name: str = "EddyFormer" + # Optimization + jit: bool = True + cuda_graphs: bool = True + amp: bool = False + # Inference + onnx_cpu: bool = False + onnx_gpu: bool = False + onnx_runtime: bool = False + # Physics informed + var_dim: int = 1 + func_torch: bool = False + auto_grad: bool = False + +class EddyFormer(Module): + + cfg: EddyFormerLayer.Config + + lift_les: nn.Linear + lift_sgs: nn.Linear + + layers: nn.ModuleList + + proj_les: Mlp + proj_sgs: Mlp + + scale: nn.Parameter + + def __init__(self, + idim: int, + odim: int, + hdim: int, + num_layers: int, + cfg: EddyFormerLayer.Config): + """ + EddyFormer model. 
+ """ + super().__init__(meta=MetaData()) + + self.cfg = cfg + self.ndim = len(cfg.mesh) + + self.lift_les = nn.Linear(idim + self.ndim, hdim) + self.lift_sgs = nn.Linear(idim + self.ndim, hdim) + + self.layers = nn.ModuleList() + for _ in range(num_layers): + layer = EddyFormerLayer(hdim, cfg) + self.layers.append(layer) + + self.proj_les = cfg.ffn(hdim, out_features=odim) + self.proj_sgs = cfg.ffn(hdim, out_features=odim) + + self.scale = nn.Parameter(torch.zeros(odim)) + + def __call__(self, input: Union[SEM, Tensor], return_sem: bool = False) -> Union[SEM, Tensor]: + """ + """ + if isinstance(input, Tensor): + size = 2 * torch.pi * torch.ones(self.ndim, device=input.device) + ϕ = SEM(self.cfg.basis, size, self.cfg.mesh, self.cfg.mode) \ + .from_grid(input, "lag8") # default interpolation method + else: + ϕ = input + + x = ϕ.grid.to(ϕ.nodal) + for n, mesh in enumerate(ϕ.mesh): + x = x.unsqueeze(dim:=self.ndim + n) + x = torch.repeat_interleave(x, mesh, dim) + x = torch.concatenate(torch.broadcast_tensors(ϕ.nodal, x), dim=-1) + + sgs = ϕ.new(x) + les = sgs.to(self.cfg.mode_les) + + sgs.nodal = self.lift_sgs(sgs.nodal) + les.nodal = self.lift_les(les.nodal) + + for layer in self.layers: + les, sgs = layer(les, sgs) + + sgs.nodal = self.proj_sgs(sgs.nodal) + les.nodal = self.proj_les(les.nodal) + + out = ϕ.new(les.to(ϕ.mode).nodal + sgs.nodal) + if not return_sem: out = out.eval(input.shape[:-1]) + + return out diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py new file mode 100644 index 0000000000..a9b5cc3674 --- /dev/null +++ b/physicsnemo/models/eddyformer/sem_attn.py @@ -0,0 +1,74 @@ +from typing import Tuple +from torch import Tensor + +import torch +import torch.nn as nn + +from functools import partial + +from ._datatype import SEM +from .sem_conv import SEMConv + +class SEMAttn(nn.Module): + + proj: nn.ModuleDict + bias: nn.ParameterDict + norm: nn.ModuleDict + + out: nn.Linear + + def __init__(self, + idim: int, + odim: int, + mode: Tuple[int], + num_heads: int, + heads_dim: int, + *, + conv: partial[SEMConv], + bias_init = torch.zeros): + """ + """ + super().__init__() + + self.proj = nn.ModuleDict() + self.bias = nn.ParameterDict() + self.norm = nn.ModuleDict() + + for name in "QKV": + self.proj[name] = conv(idim, (num_heads, heads_dim)) + + for n in range(len(mode)): + self.bias[f"{name}{n}"] = nn.Parameter(bias_init((num_heads, heads_dim))) + self.norm[f"{name}{n}"] = nn.LayerNorm(heads_dim) + + self.out = nn.Linear(num_heads * heads_dim * len(mode), odim) + + def project(self, ϕ: SEM, name: str) -> Tensor: + """ + Project the features to attention space. + """ + xs = [] + + for n in range(ϕ.ndim): + x = self.proj[name].factor(ϕ, n).nodal + + if name in ["Q", "K"]: + x = x + self.bias[f"{name}{n}"] + + f, g = torch.split(self.norm[f"{name}{n}"](x), x.shape[-1] // 2, dim=-1) + k = ϕ.coords[..., None, [n]] * torch.arange(f.shape[-1], device=x.device) + + f, g = torch.cos(k) * f - torch.sin(k) * g, torch.sin(k) * f + torch.cos(k) * g + x = torch.concatenate([torch.cos(k) + f, torch.sin(k) + g], dim=-1) + + xs.append(x.reshape(ϕ.mode + (-1, ) + x.shape[-2:])) + return torch.concatenate(xs, dim=-1) + + def __call__(self, ϕ: SEM) -> SEM: + """ + Self-attention on SEM features. 
+ """ + q, k, v = (self.project(ϕ, name) for name in "QKV") + + attn = nn.functional.scaled_dot_product_attention(q, k, v) + return ϕ.new(self.out(attn.reshape(*ϕ.mode, *ϕ.mesh, -1))) diff --git a/physicsnemo/models/eddyformer/sem_conv.py b/physicsnemo/models/eddyformer/sem_conv.py new file mode 100644 index 0000000000..2dc4c9a42f --- /dev/null +++ b/physicsnemo/models/eddyformer/sem_conv.py @@ -0,0 +1,150 @@ +from typing import Tuple, Union +from torch import Tensor + +import torch +import torch.nn as nn + +import numpy as np +from functools import partial, cache +from scipy import integrate + +from ._basis import Basis +from ._datatype import SEM + +class SEMConv(nn.Module): + + odim: Tuple[int] + kernel: nn.ParameterList + + def __init__(self, + idim: int, + odim: Union[int, Tuple[int]], + T: Tuple[Basis], + kernel_mode: Tuple[int], + kernel_size: Tuple[int], + kernel_init_std: float = 1e-7): + """ + """ + super().__init__() + self.T = nn.ModuleList(T) + + if isinstance(odim, int): + self.odim = (odim, ) + else: + self.odim = odim + odim = np.prod(odim) + + self.kernel = nn.ParameterList() + for n, (m, s) in enumerate(zip(kernel_mode, kernel_size)): + self.kernel.append(nn.Parameter(coef:=torch.empty(s * m, idim, odim))) + + torch.nn.init.normal_(coef, std=kernel_init_std) + self.register_buffer(f"ws_{n}", weight(T[n], s)) + + def factor(self, ϕ: SEM, dim: int) -> SEM: + """ + Factorized SEM convolution. + + Args: + ϕ: Input SEM feature field. + dim: Dimension to convolve over. + """ + coef, ws = self.kernel[dim], getattr(self, f"ws_{dim}") + out = sem_conv(ϕ.nodal, coef, ws, T=ϕ.T[dim], ndim=ϕ.ndim, dim=dim) + return ϕ.new(out.reshape(out.shape[:-1] + self.odim)) + + def __call__(self, ϕ: SEM) -> SEM: + return ϕ.new(sum(self.factor(ϕ, n).nodal for n in range(ϕ.ndim))) + +# ---------------------------------------------------------------------------- # +# CONVOLUTION # +# ---------------------------------------------------------------------------- # + +def kernel(coef: Tensor, xs: Tensor) -> Tensor: + """ + Evaluate the Fourier kernel. + + Args: + coef: Fourier coefficients. + xs: Query coordinates. 
+ """ + r, i_ = torch.split(coef, (m:=(n:=len(coef)) // 2 + 1, n - m)) + i = torch.zeros_like(r); i[1:n-m+1] = torch.flip(i_, dims=[0]) + + k = 2 * torch.pi * torch.arange(m, device=xs.device) + f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 + + return torch.tensordot(f.real, r, 1) \ + - torch.tensordot(f.imag, i, 1) + +@cache +def weight(T: Basis, s: int, use_mp: bool = True) -> Tensor: + """ + """ + print(f"Pre-computing weights for `{T=}` and `{s=}`...") + + eps = torch.finfo(torch.float).eps + ab = T.grid[..., None] + torch.tensor([-s/2, s/2]) + + map_ = map + if use_mp: + from concurrent.futures import ThreadPoolExecutor + map_ = (pool := ThreadPoolExecutor()).map + + def quad(T: Basis, m: int, a: float, b: float) -> Tensor: + f = lambda x: T.fn(torch.tensor(x))[m] + y, e = integrate.quad(f, a, b) + return y + + ws = [] + for i in range(-s//2, s//2 + 1): + ws.append(w:=[]) + + from tqdm import tqdm + for a, b in tqdm(ab, f"{i=}"): + a = torch.clip(a - i, -eps, 1 + eps) + b = torch.clip(b - i, -eps, 1 + eps) + + q = torch.tensor(list(map_(partial(quad, T, a=a, b=b), range(T.m)))) + w.append(torch.linalg.solve(T.fn(T.grid).T, q).tolist()) + + if use_mp: pool.shutdown() + return torch.tensor(ws) + +def sem_conv(nodal: Tensor, coef: Tensor, ws: Tensor, *, T: Basis, ndim: int, dim: int): + """ + Args: + w: An (s + 1, m, n) array where s is the window size, m is the number + of quadrature nodes, and n is the number of output nodes. + """ + n = ndim + dim # mesh dim + + ns = "".join(map(chr, range(120, 120 + ndim))) + ms = ns.replace(i:=ns[dim], o:=ns[dim].upper()) + + pad_r = nodal.index_select(n, torch.arange(0, r:=len(ws)//2, device=nodal.device)) + pad_l = nodal.index_select(n, torch.arange(nodal.shape[n]-r, nodal.shape[n], device=nodal.device)) + + # pad_r = torch.slice_copy(nodal, n, 0, r:=len(ws)//2) + # pad_l = torch.slice_copy(nodal, n, -r, end=None) + + f = torch.concatenate([pad_l, nodal, pad_r], dim=n) + out = [] + # out = torch.zeros(*nodal.shape[:-1], coef.shape[-1], device=nodal.device) + + for k, w in enumerate(ws): + + x = T.grid + k - r + xy = T.grid[:, None] - x + + fx = torch.narrow(f, n, k, nodal.shape[n]) + gxy = kernel(coef, xy / (len(ws) - 1)) + + eqn = f"{ns}...i, {o}{i}io, {o}{i} -> {ms}...o" + # print(f"{eqn}: {tuple(fx.shape)}, {tuple(gxy.shape)}, {tuple(w.shape)}") + + # print(out.shape, torch.einsum(eqn, fx, gxy, w).shape) + # out += torch.einsum(eqn, fx, gxy, w) + out.append(torch.einsum(eqn, fx, gxy, w)) + + return sum(out) diff --git a/pyproject.toml b/pyproject.toml index 5415416427..af58333f31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ all = [ "ruamel.yaml>=0.17.22", "scikit-learn>=1.0.2", "scikit-image>=0.24.0", + "scipy>=1.15.0", "warp-lang>=1.0", "vtk>=9.2.6", "pyvista>=0.40.1", From 9cb780a57d488ee247761d96669eabdf1a81bbef Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Thu, 20 Nov 2025 21:15:50 -0800 Subject: [PATCH 2/7] fix format issue --- examples/cfd/isotropic_eddyformer/README.md | 2 +- .../isotropic_eddyformer/download_dataset.sh | 2 +- .../isotropic_eddyformer/train_ef_isotropic.py | 18 +++++++++++------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/cfd/isotropic_eddyformer/README.md b/examples/cfd/isotropic_eddyformer/README.md index 1e22dd9485..027b668367 100644 --- a/examples/cfd/isotropic_eddyformer/README.md +++ b/examples/cfd/isotropic_eddyformer/README.md @@ -92,4 +92,4 @@ python train_ef_isotropic.py ## References -- [EddyFormer: EddyFormer: Accelerated Neural Simulations of 
Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173) +- [EddyFormer: Accelerated Neural Simulations of Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173) diff --git a/examples/cfd/isotropic_eddyformer/download_dataset.sh b/examples/cfd/isotropic_eddyformer/download_dataset.sh index 7b50328c92..52da8b034d 100644 --- a/examples/cfd/isotropic_eddyformer/download_dataset.sh +++ b/examples/cfd/isotropic_eddyformer/download_dataset.sh @@ -1 +1 @@ -hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94} \ No newline at end of file +hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94} diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py index 6546d20ca7..a433e60bc1 100644 --- a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -68,7 +68,7 @@ def isotropic_trainer(cfg: DictConfig) -> None: log.file_logging() LaunchLogger.initialize() # PhysicsNeMo launch logger - # define model, loss, optimiser, scheduler, data loader + # define model, loss, optimizer model = EddyFormer( idim=cfg.model.idim, odim=cfg.model.odim, @@ -78,11 +78,18 @@ def isotropic_trainer(cfg: DictConfig) -> None: ).to(dist.device) loss_fun = MSELoss(reduction="mean") optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate) + + # define dataset and dataloader dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t) + dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) - # define forward passes for training and inference + # define forward passes for training @StaticCaptureTraining( - model=model, optim=optimizer, logger=log, use_amp=False, use_graphs=False + model=model, + optim=optimizer, + logger=log, + use_amp=False, + use_graphs=False ) def training_step(input, target): pred = torch.vmap(model)(input) @@ -91,14 +98,11 @@ def training_step(input, target): for epoch in range(cfg.training.num_epochs): - dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) - for input, target in dataloader: input = input.to(dist.device) target = target.to(dist.device) - with torch.autograd.set_detect_anomaly(True): - loss = training_step(input, target) + loss = training_step(input, target) with LaunchLogger("train", epoch=epoch) as logger: logger.log_minibatch({"Training loss": loss.item()}) From a4d7a6505353cbcb7cb3cf8fdc79c1241ed93605 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Thu, 20 Nov 2025 21:18:03 -0800 Subject: [PATCH 3/7] verify rope dimension --- physicsnemo/models/eddyformer/sem_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py index a9b5cc3674..2294d0c47b 100644 --- a/physicsnemo/models/eddyformer/sem_attn.py +++ b/physicsnemo/models/eddyformer/sem_attn.py @@ -54,6 +54,7 @@ def project(self, ϕ: SEM, name: str) -> Tensor: if name in ["Q", "K"]: x = x + self.bias[f"{name}{n}"] + assert x.shape[-1] % 2 == 0 f, g = torch.split(self.norm[f"{name}{n}"](x), x.shape[-1] // 2, dim=-1) k = ϕ.coords[..., None, [n]] * torch.arange(f.shape[-1], device=x.device) From 6bf561773c6cbd598a719cb58c460a22f64fdc87 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Sat, 22 Nov 2025 22:12:02 -0800 Subject: [PATCH 4/7] fix device and docstring --- physicsnemo/models/eddyformer/_datatype.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git 
a/physicsnemo/models/eddyformer/_datatype.py b/physicsnemo/models/eddyformer/_datatype.py index ea1e5514bf..e2a248f7d4 100644 --- a/physicsnemo/models/eddyformer/_datatype.py +++ b/physicsnemo/models/eddyformer/_datatype.py @@ -2,7 +2,6 @@ from torch import Tensor import torch -import torch.nn.functional as F from dataclasses import dataclass, replace from functools import cached_property @@ -22,7 +21,7 @@ def interp1d(value: Tensor, xs: Tensor, method: str) -> Tensor: if method == "fft": coef = torch.fft.rfft(value, dim=0, norm="forward") - k = 2 * torch.pi * torch.arange(len(coef)) + k = 2 * torch.pi * torch.arange(len(coef), device=xs.device) f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 return torch.tensordot(f.real, coef.real, dims=1) \ - torch.tensordot(f.imag, coef.imag, dims=1) @@ -149,7 +148,7 @@ def at(self, *xs: Tensor, uniform: bool = False) -> Tensor: # indices of each global coordinate `x` idx = torch.floor(x * float(self.mesh[n])).int() - idx = torch.minimum(idx, torch.tensor(self.mesh[n] - 1)) + idx = torch.minimum(idx, torch.tensor(self.mesh[n] - 1, device=idx.device)) # global coordinate to local coordinate ys = x * float(self.mesh[n]) - torch.arange(self.mesh[n], device=x.device)[idx] @@ -159,7 +158,7 @@ def at(self, *xs: Tensor, uniform: bool = False) -> Tensor: # coefficients where each `x` belongs coef = coef.movedim(self.ndim, 0)[idx] - # evaluate at each coordonate and move the output axis to the last dimension + # evaluate at each coordinate and move the output axis to the last dimension # after `ndim` iterations, the axes are automatically rolled to the correct order value = torch.vmap(self.T[n].at, out_dims=self.ndim - 1)(coef, ys) @@ -213,7 +212,7 @@ def from_grid(self, value: Tensor, method: str) -> "SEM": Interpolate grid values to a target datatype. Args: - out: Target datatype. + value: Input tensor (include boundary points). method: Interpolation method along each axis. See `interp1d::method` for details. 
""" From 5ca494cbaf93b72acf182735b4ed0c7413135037 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Sat, 22 Nov 2025 22:13:49 -0800 Subject: [PATCH 5/7] fix import and remove comments --- physicsnemo/models/eddyformer/sem_conv.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/physicsnemo/models/eddyformer/sem_conv.py b/physicsnemo/models/eddyformer/sem_conv.py index 2dc4c9a42f..fa51e5006c 100644 --- a/physicsnemo/models/eddyformer/sem_conv.py +++ b/physicsnemo/models/eddyformer/sem_conv.py @@ -3,10 +3,11 @@ import torch import torch.nn as nn - import numpy as np + from functools import partial, cache from scipy import integrate +from tqdm import tqdm from ._basis import Basis from ._datatype import SEM @@ -100,7 +101,6 @@ def quad(T: Basis, m: int, a: float, b: float) -> Tensor: for i in range(-s//2, s//2 + 1): ws.append(w:=[]) - from tqdm import tqdm for a, b in tqdm(ab, f"{i=}"): a = torch.clip(a - i, -eps, 1 + eps) b = torch.clip(b - i, -eps, 1 + eps) @@ -125,12 +125,8 @@ def sem_conv(nodal: Tensor, coef: Tensor, ws: Tensor, *, T: Basis, ndim: int, di pad_r = nodal.index_select(n, torch.arange(0, r:=len(ws)//2, device=nodal.device)) pad_l = nodal.index_select(n, torch.arange(nodal.shape[n]-r, nodal.shape[n], device=nodal.device)) - # pad_r = torch.slice_copy(nodal, n, 0, r:=len(ws)//2) - # pad_l = torch.slice_copy(nodal, n, -r, end=None) - f = torch.concatenate([pad_l, nodal, pad_r], dim=n) out = [] - # out = torch.zeros(*nodal.shape[:-1], coef.shape[-1], device=nodal.device) for k, w in enumerate(ws): @@ -141,10 +137,6 @@ def sem_conv(nodal: Tensor, coef: Tensor, ws: Tensor, *, T: Basis, ndim: int, di gxy = kernel(coef, xy / (len(ws) - 1)) eqn = f"{ns}...i, {o}{i}io, {o}{i} -> {ms}...o" - # print(f"{eqn}: {tuple(fx.shape)}, {tuple(gxy.shape)}, {tuple(w.shape)}") - - # print(out.shape, torch.einsum(eqn, fx, gxy, w).shape) - # out += torch.einsum(eqn, fx, gxy, w) out.append(torch.einsum(eqn, fx, gxy, w)) return sum(out) From b839de8f708321a2db05a54376dd33f1cc17385b Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Mon, 24 Nov 2025 12:30:22 -0800 Subject: [PATCH 6/7] use ddp; change to rel l2 loss; add checkpointing --- examples/cfd/isotropic_eddyformer/config.yaml | 5 ++- .../train_ef_isotropic.py | 45 ++++++++++++++----- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml index e7018f54d0..8c0198c4d1 100644 --- a/examples/cfd/isotropic_eddyformer/config.yaml +++ b/examples/cfd/isotropic_eddyformer/config.yaml @@ -3,6 +3,7 @@ model: odim: 3 hdim: 32 num_layers: 4 + use_scale: true layer_config: basis: legendre mesh: [8, 8, 8] @@ -17,7 +18,9 @@ model: training: dataset: data/ns3d-re94 + result_dir: outputs/ef-re94 t: 0.5 batch_size: 4 - num_epochs: 100 + num_epochs: 1 learning_rate: 1e-3 + ckpt_every: 1000 diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py index a433e60bc1..d246b41883 100644 --- a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -7,13 +7,14 @@ import numpy as np import torch -from torch.nn import MSELoss from torch.optim import Adam from torch.utils.data import Dataset, DataLoader +from torch.nn.parallel import DistributedDataParallel from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig from physicsnemo.distributed import DistributedManager from physicsnemo.utils import 
StaticCaptureTraining +from physicsnemo.launch.utils import save_checkpoint from physicsnemo.launch.logging import PythonLogger, LaunchLogger @@ -65,25 +66,43 @@ def isotropic_trainer(cfg: DictConfig) -> None: # initialize monitoring log = PythonLogger(name="re94_ef") - log.file_logging() + log.file_logging(f"{cfg.training.result_dir}/log.txt") LaunchLogger.initialize() # PhysicsNeMo launch logger - # define model, loss, optimizer + # define model and optimizer model = EddyFormer( idim=cfg.model.idim, odim=cfg.model.odim, hdim=cfg.model.hdim, num_layers=cfg.model.num_layers, + use_scale=cfg.model.use_scale, cfg=EddyFormerConfig(**cfg.model.layer_config), ).to(dist.device) - loss_fun = MSELoss(reduction="mean") + + if dist.distributed: + ddps = torch.cuda.Stream() + with torch.cuda.stream(ddps): + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + ) + torch.cuda.current_stream().wait_stream(ddps) + log.success("Initialized DDP training") + optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate) # define dataset and dataloader dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t) dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) - # define forward passes for training + # define relative l2 error as the loss function + def loss_fun(pred: Tensor, target: Tensor) -> Tensor: + return torch.linalg.norm(pred - target) / torch.linalg.norm(target) + + # define training step @StaticCaptureTraining( model=model, optim=optimizer, @@ -91,14 +110,16 @@ def isotropic_trainer(cfg: DictConfig) -> None: use_amp=False, use_graphs=False ) - def training_step(input, target): + def training_step(input: Tensor, target: Tensor) -> Tensor: pred = torch.vmap(model)(input) - loss = loss_fun(pred, target) - return loss + loss = torch.vmap(loss_fun)(pred, target) + return torch.mean(loss) - for epoch in range(cfg.training.num_epochs): + it = 0 + log.info("Training started") - for input, target in dataloader: + for epoch in range(cfg.training.num_epochs): + for it, (input, target) in enumerate(dataloader, it): input = input.to(dist.device) target = target.to(dist.device) @@ -107,7 +128,11 @@ def training_step(input, target): with LaunchLogger("train", epoch=epoch) as logger: logger.log_minibatch({"Training loss": loss.item()}) + if it and it % cfg.training.ckpt_every == 0 and dist.rank == 0: + save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer, epoch=it) + log.success("Training completed") + save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer) if __name__ == "__main__": From ff2947c02f0d441b6e7582c8304cf1d9eb7f49e9 Mon Sep 17 00:00:00 2001 From: Yiheng Du Date: Mon, 24 Nov 2025 12:33:51 -0800 Subject: [PATCH 7/7] switch to physicsnemo.Module; add use_scale; separate EddyFormerConfig class --- physicsnemo/models/eddyformer/__init__.py | 4 +- physicsnemo/models/eddyformer/eddyformer.py | 111 ++++++++++++-------- physicsnemo/models/eddyformer/sem_attn.py | 10 +- physicsnemo/models/eddyformer/sem_conv.py | 3 +- 4 files changed, 76 insertions(+), 52 deletions(-) diff --git a/physicsnemo/models/eddyformer/__init__.py b/physicsnemo/models/eddyformer/__init__.py index db0569fda6..63144343ba 100644 --- a/physicsnemo/models/eddyformer/__init__.py +++ b/physicsnemo/models/eddyformer/__init__.py @@ -1,5 +1,3 @@ from ._basis import Legendre from ._datatype import SEM -from .eddyformer import 
EddyFormer, EddyFormerLayer - -EddyFormerConfig = EddyFormerLayer.Config +from .eddyformer import EddyFormer, EddyFormerConfig diff --git a/physicsnemo/models/eddyformer/eddyformer.py b/physicsnemo/models/eddyformer/eddyformer.py index 569eb95857..ec65c3a5ff 100644 --- a/physicsnemo/models/eddyformer/eddyformer.py +++ b/physicsnemo/models/eddyformer/eddyformer.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Optional from torch import Tensor import torch @@ -15,53 +15,72 @@ from .sem_conv import SEMConv from .sem_attn import SEMAttn -# Layer - -class EddyFormerLayer(nn.Module): +class EddyFormerConfig(Module): - @dataclass - class Config: + basis: str + mesh: Tuple[int] + mode: Tuple[int] - basis: str - mesh: Tuple[int] - mode: Tuple[int] + # SGS STREAM + kernel_size: Tuple[int] - # SGS STREAM - kernel_size: Tuple[int] + ffn_dim: int + activation: str - ffn_dim: int - activation: str + # LES STREAM + mode_les: Tuple[int] + kernel_size_les: Tuple[int] - # LES STREAM - mode_les: Tuple[int] - kernel_size_les: Tuple[int] + num_heads: int + heads_dim: int - num_heads: int - heads_dim: int + def __init__(self, basis: str, mesh: Tuple[int], mode: Tuple[int], + kernel_size: Tuple[int], ffn_dim: int, activation: str, + mode_les: Tuple[int], kernel_size_les: Tuple[int], num_heads: int, heads_dim: int): + """ + """ + super().__init__() - @property - def ffn(self) -> partial[Mlp]: - return partial(Mlp, - hidden_features=self.ffn_dim, - act_layer=getattr(nn, self.activation), - ) + self.basis = basis + self.mesh = mesh + self.mode = mode + + self.kernel_size = kernel_size + self.ffn_dim = ffn_dim + self.activation = activation + + self.mode_les = mode_les + self.kernel_size_les = kernel_size_les + self.num_heads = num_heads + self.heads_dim = heads_dim + + @property + def ffn(self) -> partial[Mlp]: + return partial(Mlp, + hidden_features=self.ffn_dim, + act_layer=getattr(nn, self.activation), + ) + + @property + def attn(self) -> partial[SEMAttn]: + return partial(SEMAttn, + mode=self.mode_les, + num_heads=self.num_heads, + heads_dim=self.heads_dim, + ) + + def conv(self, stream: str) -> partial[SEMConv]: + return partial(SEMConv, + kernel_mode=(mode:=self.mode if stream == "sgs" else self.mode_les), + kernel_size=self.kernel_size if stream == "sgs" else self.kernel_size_les, + T=tuple(map(SEM.basis(self.basis), mode)), + ) - @property - def attn(self) -> partial[SEMAttn]: - return partial(SEMAttn, - mode=self.mode_les, - num_heads=self.num_heads, - heads_dim=self.heads_dim, - ) +# Layer - def conv(self, stream: str) -> partial[SEMConv]: - return partial(SEMConv, - kernel_mode=(mode:=self.mode if stream == "sgs" else self.mode_les), - kernel_size=self.kernel_size if stream == "sgs" else self.kernel_size_les, - T=tuple(map(SEM.basis(self.basis), mode)), - ) +class EddyFormerLayer(Module): - def __init__(self, hdim: int, cfg: Config, *, layer_scale: float = 1e-7): + def __init__(self, hdim: int, cfg: EddyFormerConfig, *, layer_scale: float = 1e-7): """ EddyFormer layer. """ @@ -108,7 +127,7 @@ class MetaData(ModelMetaData): class EddyFormer(Module): - cfg: EddyFormerLayer.Config + cfg: EddyFormerConfig lift_les: nn.Linear lift_sgs: nn.Linear @@ -118,14 +137,16 @@ class EddyFormer(Module): proj_les: Mlp proj_sgs: Mlp - scale: nn.Parameter + scale: Optional[nn.Parameter] def __init__(self, idim: int, odim: int, hdim: int, num_layers: int, - cfg: EddyFormerLayer.Config): + *, + use_scale: bool = True, + cfg: EddyFormerConfig): """ EddyFormer model. 
""" @@ -145,7 +166,7 @@ def __init__(self, self.proj_les = cfg.ffn(hdim, out_features=odim) self.proj_sgs = cfg.ffn(hdim, out_features=odim) - self.scale = nn.Parameter(torch.zeros(odim)) + self.scale = nn.Parameter(torch.zeros(odim)) if use_scale else None def __call__(self, input: Union[SEM, Tensor], return_sem: bool = False) -> Union[SEM, Tensor]: """ @@ -175,7 +196,9 @@ def __call__(self, input: Union[SEM, Tensor], return_sem: bool = False) -> Union sgs.nodal = self.proj_sgs(sgs.nodal) les.nodal = self.proj_les(les.nodal) - out = ϕ.new(les.to(ϕ.mode).nodal + sgs.nodal) - if not return_sem: out = out.eval(input.shape[:-1]) + scale = self.scale if self.scale is not None else 1. + out = ϕ.new(les.to(ϕ.mode).nodal + scale * sgs.nodal) + if not return_sem: + out = out.eval(input.shape[:-1]) return out diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py index 2294d0c47b..261a6ea8c5 100644 --- a/physicsnemo/models/eddyformer/sem_attn.py +++ b/physicsnemo/models/eddyformer/sem_attn.py @@ -6,10 +6,11 @@ from functools import partial +from ..module import Module from ._datatype import SEM from .sem_conv import SEMConv -class SEMAttn(nn.Module): +class SEMAttn(Module): proj: nn.ModuleDict bias: nn.ParameterDict @@ -37,9 +38,10 @@ def __init__(self, for name in "QKV": self.proj[name] = conv(idim, (num_heads, heads_dim)) - for n in range(len(mode)): - self.bias[f"{name}{n}"] = nn.Parameter(bias_init((num_heads, heads_dim))) - self.norm[f"{name}{n}"] = nn.LayerNorm(heads_dim) + if name in ["Q", "K"]: + for n in range(len(mode)): + self.bias[f"{name}{n}"] = nn.Parameter(bias_init((num_heads, heads_dim))) + self.norm[f"{name}{n}"] = nn.LayerNorm(heads_dim) self.out = nn.Linear(num_heads * heads_dim * len(mode), odim) diff --git a/physicsnemo/models/eddyformer/sem_conv.py b/physicsnemo/models/eddyformer/sem_conv.py index fa51e5006c..6d8b55d021 100644 --- a/physicsnemo/models/eddyformer/sem_conv.py +++ b/physicsnemo/models/eddyformer/sem_conv.py @@ -9,10 +9,11 @@ from scipy import integrate from tqdm import tqdm +from ..module import Module from ._basis import Basis from ._datatype import SEM -class SEMConv(nn.Module): +class SEMConv(Module): odim: Tuple[int] kernel: nn.ParameterList