diff --git a/examples/cfd/isotropic_eddyformer/README.md b/examples/cfd/isotropic_eddyformer/README.md
new file mode 100644
index 0000000000..027b668367
--- /dev/null
+++ b/examples/cfd/isotropic_eddyformer/README.md
@@ -0,0 +1,95 @@
+# EddyFormer for 3D Isotropic Turbulence
+
+This example demonstrates how to use the EddyFormer model to simulate
+three-dimensional isotropic turbulence. The example runs on a single GPU.
+
+## Problem Overview
+
+This example focuses on **three-dimensional homogeneous isotropic turbulence (HIT)** sustained by large-scale forcing. The flow is governed by the incompressible Navier–Stokes equations with an external forcing term:
+
+\[
+\frac{\partial \mathbf{u}}{\partial t} + \mathbf{u} \cdot \nabla \mathbf{u}
+= -\nabla p + \nu \nabla^2 \mathbf{u} + \mathbf{f}(\mathbf{x}),
+\qquad \nabla \cdot \mathbf{u} = 0
+\]
+
+where:
+
+- **\(\mathbf{u}(\mathbf{x}, t)\)** — velocity field in a 3D periodic domain
+- **\(p(\mathbf{x}, t)\)** — pressure (per unit density), enforcing incompressibility
+- **\(\nu = 0.01\)** — kinematic viscosity
+- **\(\mathbf{f}(\mathbf{x})\)** — isotropic forcing applied at the largest scales
+
+### Forcing Mechanism
+
+To maintain statistically steady turbulence, a **constant-power forcing** is applied to the lowest Fourier modes (\(|\mathbf{k}| \le 1\)). The forcing injects energy at a prescribed rate \(P_{\text{in}} = 1.0\):
+
+\[
+\mathbf{f}(\mathbf{x}) =
+\frac{P_{\text{in}}}{E_1}
+\sum_{\substack{|\mathbf{k}| \le 1 \\ \mathbf{k} \neq 0}}
+\hat{\mathbf{u}}_{\mathbf{k}} e^{i \mathbf{k} \cdot \mathbf{x}}
+\]
+
+where
+
+\[
+E_1 = \frac{1}{2}
+\sum_{|\mathbf{k}| \le 1}
+\hat{\mathbf{u}}_{\mathbf{k}} \cdot \hat{\mathbf{u}}_{\mathbf{k}}^{*}
+\]
+
+is the kinetic energy contained in the forced low-wavenumber modes.
+
+Under this forcing, the flow reaches a **statistically steady state** with a Taylor-scale Reynolds number of:
+
+**\(\mathrm{Re}_\lambda \approx 94\)**
+
+### Task Description
+
+The objective of this example is to **predict the future velocity field** of the turbulent flow. Given \(\mathbf{u}(\mathbf{x}, t)\), the task is:
+
+> **Predict the velocity field \(\mathbf{u}(\mathbf{x}, t + \Delta t)\) with \(\Delta t = 0.5\).**
+
+This requires modeling nonlinear, chaotic, multi-scale turbulent dynamics, including:
+
+- energy injection at large scales
+- nonlinear transfer across the inertial range
+- dissipation at the smallest scales
+
+### Dataset Summary
+
+- **DNS resolution:** \(384^3\) (used to generate the dataset)
+- **Stored dataset resolution:** \(96^3\)
+- **Kolmogorov-scale resolution:** ~0.5 η
+- **Forcing:** applied to modes with \(|\mathbf{k}| \le 1\)
+- **Viscosity:** \(\nu = 0.01\)
+- **Input power:** \(P_{\text{in}} = 1.0\)
+- **Flow regime:** statistically steady HIT at \(\mathrm{Re}_\lambda \approx 94\)
+
+## Prerequisites
+
+Install the required dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+## Download the Dataset
+
+The dataset is publicly available on [Hugging Face](https://huggingface.co/datasets/ydu11/re94).
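+Once the download below has finished, you can sanity-check a trajectory file
+with a few lines of NumPy. This is a minimal sketch based on the `Re94`
+dataset class in `train_ef_isotropic.py`, which reads each file as a pickled
+dictionary whose `"u"` entry holds 50 velocity snapshots spaced 0.1 time
+units apart; the glob pattern and array layout are assumptions, so adjust
+them to what you actually downloaded.
+
+```python
+import glob
+
+import numpy as np
+
+# Pick one trajectory file from the download directory (pattern is an assumption).
+files = sorted(glob.glob("data/ns3d-re94/train*"))
+
+# Each file stores a pickled dict; the velocity snapshots live under the key "u".
+data = np.load(files[0], allow_pickle=True).item()
+print(list(data.keys()))
+
+# Expect 50 snapshots of the 96^3 velocity field with 3 components.
+print(data["u"].shape)
+```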
+To download the dataset, run (you might need to install the Huggingface CLI): + +```bash +bash download_dataset.sh +``` + +## Getting Started + +To train the model, run + +```bash +python train_ef_isotropic.py +``` + +## References + +- [EddyFormer: Accelerated Neural Simulations of Three-Dimensional Turbulence at Scale](https://arxiv.org/abs/2510.24173) diff --git a/examples/cfd/isotropic_eddyformer/config.yaml b/examples/cfd/isotropic_eddyformer/config.yaml new file mode 100644 index 0000000000..8c0198c4d1 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/config.yaml @@ -0,0 +1,26 @@ +model: + idim: 3 + odim: 3 + hdim: 32 + num_layers: 4 + use_scale: true + layer_config: + basis: legendre + mesh: [8, 8, 8] + mode: [10, 10, 10] + mode_les: [5, 5, 5] + kernel_size: [2, 2, 2] + kernel_size_les: [2, 2, 2] + ffn_dim: 128 + activation: GELU + num_heads: 4 + heads_dim: 32 + +training: + dataset: data/ns3d-re94 + result_dir: outputs/ef-re94 + t: 0.5 + batch_size: 4 + num_epochs: 1 + learning_rate: 1e-3 + ckpt_every: 1000 diff --git a/examples/cfd/isotropic_eddyformer/download_dataset.sh b/examples/cfd/isotropic_eddyformer/download_dataset.sh new file mode 100644 index 0000000000..52da8b034d --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/download_dataset.sh @@ -0,0 +1 @@ +hf download --repo-type dataset ydu11/re94 --local-dir ${1:-data/ns3d-re94} diff --git a/examples/cfd/isotropic_eddyformer/requirements.txt b/examples/cfd/isotropic_eddyformer/requirements.txt new file mode 100644 index 0000000000..001dc23f09 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/requirements.txt @@ -0,0 +1,2 @@ +hydra-core>=1.2.0 +termcolor>=2.1.1 diff --git a/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py new file mode 100644 index 0000000000..d246b41883 --- /dev/null +++ b/examples/cfd/isotropic_eddyformer/train_ef_isotropic.py @@ -0,0 +1,139 @@ +import hydra +from typing import Tuple +from torch import Tensor +from omegaconf import DictConfig + +import os +import numpy as np + +import torch +from torch.optim import Adam +from torch.utils.data import Dataset, DataLoader +from torch.nn.parallel import DistributedDataParallel + +from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils import StaticCaptureTraining +from physicsnemo.launch.utils import save_checkpoint +from physicsnemo.launch.logging import PythonLogger, LaunchLogger + + +class Re94(Dataset): + + root: str + t: float + + n: int = 50 + dt: float = 0.1 + + def __init__(self, root: str, split: str, *, t: float = 0.5) -> None: + """ + """ + super().__init__() + self.root = root + self.t = t + + self.file = [] + for fname in sorted(os.listdir(root)): + if fname.startswith(split): + self.file.append(fname) + + @property + def stride(self) -> int: + k = int(self.t / self.dt) + assert self.dt * k == self.t + return k + + @property + def samples_per_file(self) -> int: + return self.n - self.stride + 1 + + def __len__(self) -> int: + return len(self.file) * self.samples_per_file + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]: + file_idx, time_idx = divmod(idx, self.samples_per_file) + + data = np.load(f"{self.root}/{self.file[file_idx]}", allow_pickle=True).item() + return torch.from_numpy(data["u"][time_idx]), torch.from_numpy(data["u"][time_idx + self.stride]) + +@hydra.main(version_base="1.3", config_path=".", config_name="config.yaml") +def isotropic_trainer(cfg: 
DictConfig) -> None: + """ + """ + DistributedManager.initialize() # Only call this once in the entire script! + dist = DistributedManager() # call if required elsewhere + + # initialize monitoring + log = PythonLogger(name="re94_ef") + log.file_logging(f"{cfg.training.result_dir}/log.txt") + LaunchLogger.initialize() # PhysicsNeMo launch logger + + # define model and optimizer + model = EddyFormer( + idim=cfg.model.idim, + odim=cfg.model.odim, + hdim=cfg.model.hdim, + num_layers=cfg.model.num_layers, + use_scale=cfg.model.use_scale, + cfg=EddyFormerConfig(**cfg.model.layer_config), + ).to(dist.device) + + if dist.distributed: + ddps = torch.cuda.Stream() + with torch.cuda.stream(ddps): + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + output_device=dist.device, + broadcast_buffers=dist.broadcast_buffers, + find_unused_parameters=dist.find_unused_parameters, + ) + torch.cuda.current_stream().wait_stream(ddps) + log.success("Initialized DDP training") + + optimizer = Adam(model.parameters(), lr=cfg.training.learning_rate) + + # define dataset and dataloader + dataset = Re94(root=cfg.training.dataset, split="train", t=cfg.training.t) + dataloader = DataLoader(dataset, cfg.training.batch_size, shuffle=True) + + # define relative l2 error as the loss function + def loss_fun(pred: Tensor, target: Tensor) -> Tensor: + return torch.linalg.norm(pred - target) / torch.linalg.norm(target) + + # define training step + @StaticCaptureTraining( + model=model, + optim=optimizer, + logger=log, + use_amp=False, + use_graphs=False + ) + def training_step(input: Tensor, target: Tensor) -> Tensor: + pred = torch.vmap(model)(input) + loss = torch.vmap(loss_fun)(pred, target) + return torch.mean(loss) + + it = 0 + log.info("Training started") + + for epoch in range(cfg.training.num_epochs): + for it, (input, target) in enumerate(dataloader, it): + + input = input.to(dist.device) + target = target.to(dist.device) + loss = training_step(input, target) + + with LaunchLogger("train", epoch=epoch) as logger: + logger.log_minibatch({"Training loss": loss.item()}) + + if it and it % cfg.training.ckpt_every == 0 and dist.rank == 0: + save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer, epoch=it) + + log.success("Training completed") + save_checkpoint(f"{cfg.training.result_dir}/ckpt.pt", model, optimizer) + + +if __name__ == "__main__": + isotropic_trainer() diff --git a/physicsnemo/models/eddyformer/__init__.py b/physicsnemo/models/eddyformer/__init__.py new file mode 100644 index 0000000000..63144343ba --- /dev/null +++ b/physicsnemo/models/eddyformer/__init__.py @@ -0,0 +1,3 @@ +from ._basis import Legendre +from ._datatype import SEM +from .eddyformer import EddyFormer, EddyFormerConfig diff --git a/physicsnemo/models/eddyformer/_basis.py b/physicsnemo/models/eddyformer/_basis.py new file mode 100644 index 0000000000..e3906ca529 --- /dev/null +++ b/physicsnemo/models/eddyformer/_basis.py @@ -0,0 +1,112 @@ +from typing import Protocol +from torch import Tensor + +import torch +import torch.nn as nn + +import numpy as np +import functools + +class Basis(Protocol): + + grid: Tensor + quad: Tensor + + m: int + f: Tensor + + def fn(self, xs: Tensor) -> Tensor: + """ + Evaluate basis functions at given points. + """ + + def at(self, coef: Tensor, xs: Tensor) -> Tensor: + """ + Evaluate basis expansion at given points. + """ + return torch.tensordot(self.fn(xs), coef, dims=1) + + def modal(self, vals: Tensor) -> Tensor: + """ + Convert nodal values to modal coefficients. 
+ """ + + def nodal(self, coef: Tensor) -> Tensor: + """ + Convert modal coefficients to nodal values. + """ + +class Element(Basis): + + def __init__(self, base: Basis): + """ + """ + +# ---------------------------------------------------------------------------- # +# LEGENDRE # +# ---------------------------------------------------------------------------- # + +from numpy.polynomial import legendre + +@functools.cache +class Legendre(nn.Module, Basis): + + """ + Shifted Legendre polynomials: + - `(1 - x^2) Pn''(x) - 2 x Pn(x) + n (n + 1) Pn(x) = 0` + - `Pn^~(x) = Pn(2 x - 1)` + """ + + def extra_repr(self) -> str: + return f"m={self.m}" + + def __init__(self, m: int, endpoint: bool = False): + """ + """ + super().__init__() + self.m = m + + if endpoint: m -= 1 + c = (0, ) * m + (1, ) + dc = legendre.legder(c) + + x = legendre.legroots(dc if endpoint else c) + y = legendre.legval(x, c if endpoint else dc) + + if endpoint: + x = np.concatenate([[-1], x, [1]]) + y = np.concatenate([[1], y, [1]]) + + w = 1 / y ** 2 + if endpoint: w /= m * (m + 1) + else: w /= 1 - x ** 2 + + self.register_buffer("grid", torch.tensor((1 + x) / 2, dtype=torch.float)) + self.register_buffer("quad", torch.tensor(w, dtype=torch.float)) + + self.register_buffer("f", self.fn(self.grid)) + + def fn(self, xs: Tensor) -> Tensor: + """ + """ + P = torch.ones_like(xs), 2 * xs - 1 + + for i in range(2, self.m): + a, b = (i * 2 - 1) / i, (i - 1) / i + P += a * P[-1] * P[1] - b * P[-2], + + return torch.stack(P, dim=-1) + +# --------------------------------- TRANSFORM -------------------------------- # + + def modal(self, vals: Tensor) -> Tensor: + """ + """ + norm = 2 * torch.arange(self.m, device=vals.device) + 1 + coef = self.f * norm * self.quad[:, None] + return torch.tensordot(coef.T, vals, dims=1) + + def nodal(self, coef: Tensor) -> Tensor: + """ + """ + return self.at(coef, self.grid) diff --git a/physicsnemo/models/eddyformer/_datatype.py b/physicsnemo/models/eddyformer/_datatype.py new file mode 100644 index 0000000000..e2a248f7d4 --- /dev/null +++ b/physicsnemo/models/eddyformer/_datatype.py @@ -0,0 +1,232 @@ +from typing import Tuple +from torch import Tensor + +import torch + +from dataclasses import dataclass, replace +from functools import cached_property + +from ._basis import Basis, Legendre + +def interp1d(value: Tensor, xs: Tensor, method: str) -> Tensor: + """ + Interpolate from 1D regular grid to a target points. + + Args: + value: Values on a uniform grid along the first axis. + xs: Resolution or an array normalized by the domain size. + method: Interpolation method. One of "fft", "linear", or + f"lag{n}" for n-point Lagrangian interpolation. + """ + if method == "fft": + coef = torch.fft.rfft(value, dim=0, norm="forward") + + k = 2 * torch.pi * torch.arange(len(coef), device=xs.device) + f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 + return torch.tensordot(f.real, coef.real, dims=1) \ + - torch.tensordot(f.imag, coef.imag, dims=1) + + if method.startswith("lag"): + n_points = int(method[3:]) + + assert n_points % 2 == 0 + r = n_points // 2 - 1 + + n = len(value) + + i = (xs * (N := n - 1)).int() + i = torch.clip(i, r, n - n_points + r) + + # 1. pad the input grid + + v_pad = value, value[:r+2] + + if r > 0: v_pad = (value[-r:], ) + v_pad + value = torch.concatenate(v_pad, dim=0) + + # 2. 
construct polynomials + + out = 0 + + for j in range(n_points): + lag = value[i + j] + + for k in range(n_points): + if j == k: continue + fac = xs - (i + k - r) / N + while fac.ndim < lag.ndim: + fac = fac.unsqueeze(-1) + lag *= fac * N / (j - k) + + out += lag + return out + + raise ValueError(f"invalid interpolation {method=}") + +# ---------------------------------------------------------------------------- # +# SPECTRAL ELEMENT # +# ---------------------------------------------------------------------------- # + +@dataclass +class SEM: + + """ + Spectral element expansion. The sub-domain partition is + given by the `mesh` attribute. The spectral coefficients + of each element is stored in the first channel dimension, + whose size must equal to the number of elements. + """ + + T_: str + + # Mesh + + size: Tensor + mesh: Tuple[int] + + # Data + + mode_: Tuple[int] = None + nodal: Tensor = None + + @property + def ndim(self) -> int: + return len(self.mesh) + + @property + def mode(self) -> Tuple[int]: + if self.mode_: return self.mode_ + return self.nodal.shape[:self.ndim] + + @property + def use_elem(self) -> bool: + return self.T_.endswith("elem") + + @staticmethod + def basis(T_: str) -> Basis: + if T_.startswith("leg"): return Legendre + raise ValueError(f"invalid basis {T_=}") + + @cached_property + def T(self) -> Tuple[Basis]: + """ + Basis on each dimension. + """ + T = self.basis(self.T_) + return tuple(map(T, self.mode)) + + def to(self, mode: Tuple[int]) -> "SEM": + """ + Resample to another mode. + + Args: + mode: Number of modes. + """ + out = SEM(self.T_, self.size, self.mesh, mode) + + value = self.nodal + for n in range(self.ndim): + coef = self.T[n].modal(value) + if (pad:=out.mode[n] - self.mode[n]) <= 0: coef = coef[:mode[n]] + else: coef = torch.concat([coef, torch.zeros(pad, *coef.shape[1:], device=coef.device)], dim=0) + value = out.T[n].nodal(coef).movedim(0, self.ndim - 1) + + return out.new(value) + + def at(self, *xs: Tensor, uniform: bool = False) -> Tensor: + """ + Evaluate on rectilinear grids. + + Args: + xs: Coordinate of each dimension. + uniform: Whether `xs` are uniformly spaced. 
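+ Returns:
+ Values evaluated at the tensor product of the coordinates in `xs`.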
+ """ + value = self.nodal + for n in range(self.ndim): + x = xs[n] / self.size[n] + coef = self.T[n].modal(value) + + # indices of each global coordinate `x` + idx = torch.floor(x * float(self.mesh[n])).int() + idx = torch.minimum(idx, torch.tensor(self.mesh[n] - 1, device=idx.device)) + + # global coordinate to local coordinate + ys = x * float(self.mesh[n]) - torch.arange(self.mesh[n], device=x.device)[idx] + + if not uniform: + + # coefficients where each `x` belongs + coef = coef.movedim(self.ndim, 0)[idx] + + # evaluate at each coordinate and move the output axis to the last dimension + # after `ndim` iterations, the axes are automatically rolled to the correct order + value = torch.vmap(self.T[n].at, out_dims=self.ndim - 1)(coef, ys) + + else: + + # coordinates within each element + ys = ys.reshape(self.mesh[n], -1) + + # batched evaluation of all coordinates + value = torch.vmap(self.T[n].at, (self.ndim, 0))(coef, ys) + value = torch.movedim(value.flatten(end_dim=1), 0, self.ndim - 1) + + return value + +# ---------------------------------- COORDS ---------------------------------- # + + @cached_property + def grid(self) -> Tensor: + axes = [self.T[n].grid.to(self.size.device) for n in range(self.ndim)] + return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1) + + @cached_property + def coords(self) -> Tensor: + local = self.grid + for _ in range(self.ndim): + local = local.unsqueeze(self.ndim) + return self.origins + local * self.lengths + + @cached_property + def origins(self) -> Tensor: + left = [torch.arange(m, device=self.size.device) / m for m in self.mesh] + return torch.stack(torch.meshgrid(*left, indexing="ij"), dim=-1) * self.size + + @cached_property + def lengths(self) -> Tensor: + ns = torch.tensor(self.mesh, device=self.size.device) + return self.size / ns.float() + +# --------------------------------- DATA TYPE -------------------------------- # + + def new(self, nodal: Tensor) -> "SEM": + assert nodal.shape[:self.ndim] == self.mode + return replace(self, mode_=None, nodal=nodal) + + def eval(self, resolution: Tuple[int]) -> Tensor: + xs = [torch.linspace(0, s, n, device=self.size.device) for n, s in zip(resolution, self.size)] + return self.at(*xs, uniform=False) + + def from_grid(self, value: Tensor, method: str) -> "SEM": + """ + Interpolate grid values to a target datatype. + + Args: + value: Input tensor (include boundary points). + method: Interpolation method along each axis. + See `interp1d::method` for details. + """ + xs = self.coords / self.size + for n in range(self.ndim): + + # interpolate at each collocation points. `idx` is the + # index of the elements along the `n`'th dimension. + idx = [slice(None) if i == n else 0 for i in range(self.ndim)] + value = interp1d(value, xs[tuple(idx * 2 + [n])], method) + + # roll the output. The interpolated values have shape `(mode, mesh)`, + # which are moved to the middle (`ndim - 1`) and the end (`ndim + n`) of + # the dimensions. After `ndim` iterations, all axes are ordered correctly. 
+ value = torch.moveaxis(value, (0, 1), (self.ndim - 1, self.ndim + n)) + + return self.new(value) diff --git a/physicsnemo/models/eddyformer/eddyformer.py b/physicsnemo/models/eddyformer/eddyformer.py new file mode 100644 index 0000000000..ec65c3a5ff --- /dev/null +++ b/physicsnemo/models/eddyformer/eddyformer.py @@ -0,0 +1,204 @@ +from typing import Tuple, Union, Optional +from torch import Tensor + +import torch +import torch.nn as nn + +from dataclasses import dataclass +from functools import partial + +from ..module import Module +from ..meta import ModelMetaData +from ..layers.mlp_layers import Mlp + +from ._datatype import SEM +from .sem_conv import SEMConv +from .sem_attn import SEMAttn + +class EddyFormerConfig(Module): + + basis: str + mesh: Tuple[int] + mode: Tuple[int] + + # SGS STREAM + kernel_size: Tuple[int] + + ffn_dim: int + activation: str + + # LES STREAM + mode_les: Tuple[int] + kernel_size_les: Tuple[int] + + num_heads: int + heads_dim: int + + def __init__(self, basis: str, mesh: Tuple[int], mode: Tuple[int], + kernel_size: Tuple[int], ffn_dim: int, activation: str, + mode_les: Tuple[int], kernel_size_les: Tuple[int], num_heads: int, heads_dim: int): + """ + """ + super().__init__() + + self.basis = basis + self.mesh = mesh + self.mode = mode + + self.kernel_size = kernel_size + self.ffn_dim = ffn_dim + self.activation = activation + + self.mode_les = mode_les + self.kernel_size_les = kernel_size_les + self.num_heads = num_heads + self.heads_dim = heads_dim + + @property + def ffn(self) -> partial[Mlp]: + return partial(Mlp, + hidden_features=self.ffn_dim, + act_layer=getattr(nn, self.activation), + ) + + @property + def attn(self) -> partial[SEMAttn]: + return partial(SEMAttn, + mode=self.mode_les, + num_heads=self.num_heads, + heads_dim=self.heads_dim, + ) + + def conv(self, stream: str) -> partial[SEMConv]: + return partial(SEMConv, + kernel_mode=(mode:=self.mode if stream == "sgs" else self.mode_les), + kernel_size=self.kernel_size if stream == "sgs" else self.kernel_size_les, + T=tuple(map(SEM.basis(self.basis), mode)), + ) + +# Layer + +class EddyFormerLayer(Module): + + def __init__(self, hdim: int, cfg: EddyFormerConfig, *, layer_scale: float = 1e-7): + """ + EddyFormer layer. 
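+ Each layer updates the LES stream with SEM attention followed by an SEM
+ convolution and a feed-forward block, then injects the LES features
+ (resampled to the SGS mode and scaled by a learnable per-channel factor
+ initialized to `layer_scale`) into the SGS stream ahead of the SGS
+ convolution and feed-forward block.
+
+ Args:
+ hdim: Hidden feature dimension shared by both streams.
+ cfg: Layer configuration (basis, modes, kernel sizes, attention and FFN settings).
+ layer_scale: Initial value of the LES-to-SGS injection scale.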
+ """ + super().__init__() + + self.mode = cfg.mode + self.mode_les = cfg.mode_les + + self.eps = nn.Parameter(torch.ones(hdim) * layer_scale) + self.ffn_les, self.ffn_sgs = cfg.ffn(hdim), cfg.ffn(hdim) + + self.sem_conv_sgs = cfg.conv("sgs")(hdim, hdim) + self.sem_conv_les = cfg.conv("les")(hdim, hdim) + self.sem_attn = cfg.attn(hdim, hdim, conv=cfg.conv("les")) + + def __call__(self, les: SEM, sgs: SEM) -> Tuple[SEM, SEM]: + """ + """ + les.nodal = les.nodal + self.sem_attn(les).nodal + les.nodal = les.nodal + self.ffn_les(self.sem_conv_les(les).nodal) + + sgs.nodal = sgs.nodal + self.eps * les.to(self.mode).nodal + sgs.nodal = sgs.nodal + self.ffn_sgs(self.sem_conv_sgs(sgs).nodal) + + return les, sgs + +# Model + +@dataclass +class MetaData(ModelMetaData): + name: str = "EddyFormer" + # Optimization + jit: bool = True + cuda_graphs: bool = True + amp: bool = False + # Inference + onnx_cpu: bool = False + onnx_gpu: bool = False + onnx_runtime: bool = False + # Physics informed + var_dim: int = 1 + func_torch: bool = False + auto_grad: bool = False + +class EddyFormer(Module): + + cfg: EddyFormerConfig + + lift_les: nn.Linear + lift_sgs: nn.Linear + + layers: nn.ModuleList + + proj_les: Mlp + proj_sgs: Mlp + + scale: Optional[nn.Parameter] + + def __init__(self, + idim: int, + odim: int, + hdim: int, + num_layers: int, + *, + use_scale: bool = True, + cfg: EddyFormerConfig): + """ + EddyFormer model. + """ + super().__init__(meta=MetaData()) + + self.cfg = cfg + self.ndim = len(cfg.mesh) + + self.lift_les = nn.Linear(idim + self.ndim, hdim) + self.lift_sgs = nn.Linear(idim + self.ndim, hdim) + + self.layers = nn.ModuleList() + for _ in range(num_layers): + layer = EddyFormerLayer(hdim, cfg) + self.layers.append(layer) + + self.proj_les = cfg.ffn(hdim, out_features=odim) + self.proj_sgs = cfg.ffn(hdim, out_features=odim) + + self.scale = nn.Parameter(torch.zeros(odim)) if use_scale else None + + def __call__(self, input: Union[SEM, Tensor], return_sem: bool = False) -> Union[SEM, Tensor]: + """ + """ + if isinstance(input, Tensor): + size = 2 * torch.pi * torch.ones(self.ndim, device=input.device) + ϕ = SEM(self.cfg.basis, size, self.cfg.mesh, self.cfg.mode) \ + .from_grid(input, "lag8") # default interpolation method + else: + ϕ = input + + x = ϕ.grid.to(ϕ.nodal) + for n, mesh in enumerate(ϕ.mesh): + x = x.unsqueeze(dim:=self.ndim + n) + x = torch.repeat_interleave(x, mesh, dim) + x = torch.concatenate(torch.broadcast_tensors(ϕ.nodal, x), dim=-1) + + sgs = ϕ.new(x) + les = sgs.to(self.cfg.mode_les) + + sgs.nodal = self.lift_sgs(sgs.nodal) + les.nodal = self.lift_les(les.nodal) + + for layer in self.layers: + les, sgs = layer(les, sgs) + + sgs.nodal = self.proj_sgs(sgs.nodal) + les.nodal = self.proj_les(les.nodal) + + scale = self.scale if self.scale is not None else 1. 
+ out = ϕ.new(les.to(ϕ.mode).nodal + scale * sgs.nodal) + + if not return_sem: + out = out.eval(input.shape[:-1]) + return out diff --git a/physicsnemo/models/eddyformer/sem_attn.py b/physicsnemo/models/eddyformer/sem_attn.py new file mode 100644 index 0000000000..261a6ea8c5 --- /dev/null +++ b/physicsnemo/models/eddyformer/sem_attn.py @@ -0,0 +1,77 @@ +from typing import Tuple +from torch import Tensor + +import torch +import torch.nn as nn + +from functools import partial + +from ..module import Module +from ._datatype import SEM +from .sem_conv import SEMConv + +class SEMAttn(Module): + + proj: nn.ModuleDict + bias: nn.ParameterDict + norm: nn.ModuleDict + + out: nn.Linear + + def __init__(self, + idim: int, + odim: int, + mode: Tuple[int], + num_heads: int, + heads_dim: int, + *, + conv: partial[SEMConv], + bias_init = torch.zeros): + """ + """ + super().__init__() + + self.proj = nn.ModuleDict() + self.bias = nn.ParameterDict() + self.norm = nn.ModuleDict() + + for name in "QKV": + self.proj[name] = conv(idim, (num_heads, heads_dim)) + + if name in ["Q", "K"]: + for n in range(len(mode)): + self.bias[f"{name}{n}"] = nn.Parameter(bias_init((num_heads, heads_dim))) + self.norm[f"{name}{n}"] = nn.LayerNorm(heads_dim) + + self.out = nn.Linear(num_heads * heads_dim * len(mode), odim) + + def project(self, ϕ: SEM, name: str) -> Tensor: + """ + Project the features to attention space. + """ + xs = [] + + for n in range(ϕ.ndim): + x = self.proj[name].factor(ϕ, n).nodal + + if name in ["Q", "K"]: + x = x + self.bias[f"{name}{n}"] + assert x.shape[-1] % 2 == 0 + + f, g = torch.split(self.norm[f"{name}{n}"](x), x.shape[-1] // 2, dim=-1) + k = ϕ.coords[..., None, [n]] * torch.arange(f.shape[-1], device=x.device) + + f, g = torch.cos(k) * f - torch.sin(k) * g, torch.sin(k) * f + torch.cos(k) * g + x = torch.concatenate([torch.cos(k) + f, torch.sin(k) + g], dim=-1) + + xs.append(x.reshape(ϕ.mode + (-1, ) + x.shape[-2:])) + return torch.concatenate(xs, dim=-1) + + def __call__(self, ϕ: SEM) -> SEM: + """ + Self-attention on SEM features. + """ + q, k, v = (self.project(ϕ, name) for name in "QKV") + + attn = nn.functional.scaled_dot_product_attention(q, k, v) + return ϕ.new(self.out(attn.reshape(*ϕ.mode, *ϕ.mesh, -1))) diff --git a/physicsnemo/models/eddyformer/sem_conv.py b/physicsnemo/models/eddyformer/sem_conv.py new file mode 100644 index 0000000000..6d8b55d021 --- /dev/null +++ b/physicsnemo/models/eddyformer/sem_conv.py @@ -0,0 +1,143 @@ +from typing import Tuple, Union +from torch import Tensor + +import torch +import torch.nn as nn +import numpy as np + +from functools import partial, cache +from scipy import integrate +from tqdm import tqdm + +from ..module import Module +from ._basis import Basis +from ._datatype import SEM + +class SEMConv(Module): + + odim: Tuple[int] + kernel: nn.ParameterList + + def __init__(self, + idim: int, + odim: Union[int, Tuple[int]], + T: Tuple[Basis], + kernel_mode: Tuple[int], + kernel_size: Tuple[int], + kernel_init_std: float = 1e-7): + """ + """ + super().__init__() + self.T = nn.ModuleList(T) + + if isinstance(odim, int): + self.odim = (odim, ) + else: + self.odim = odim + odim = np.prod(odim) + + self.kernel = nn.ParameterList() + for n, (m, s) in enumerate(zip(kernel_mode, kernel_size)): + self.kernel.append(nn.Parameter(coef:=torch.empty(s * m, idim, odim))) + + torch.nn.init.normal_(coef, std=kernel_init_std) + self.register_buffer(f"ws_{n}", weight(T[n], s)) + + def factor(self, ϕ: SEM, dim: int) -> SEM: + """ + Factorized SEM convolution. 
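+ Applies the learnable kernel along `dim` only; `__call__` sums the
+ per-dimension factors over all spatial dimensions to assemble the full convolution.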
+ + Args: + ϕ: Input SEM feature field. + dim: Dimension to convolve over. + """ + coef, ws = self.kernel[dim], getattr(self, f"ws_{dim}") + out = sem_conv(ϕ.nodal, coef, ws, T=ϕ.T[dim], ndim=ϕ.ndim, dim=dim) + return ϕ.new(out.reshape(out.shape[:-1] + self.odim)) + + def __call__(self, ϕ: SEM) -> SEM: + return ϕ.new(sum(self.factor(ϕ, n).nodal for n in range(ϕ.ndim))) + +# ---------------------------------------------------------------------------- # +# CONVOLUTION # +# ---------------------------------------------------------------------------- # + +def kernel(coef: Tensor, xs: Tensor) -> Tensor: + """ + Evaluate the Fourier kernel. + + Args: + coef: Fourier coefficients. + xs: Query coordinates. + """ + r, i_ = torch.split(coef, (m:=(n:=len(coef)) // 2 + 1, n - m)) + i = torch.zeros_like(r); i[1:n-m+1] = torch.flip(i_, dims=[0]) + + k = 2 * torch.pi * torch.arange(m, device=xs.device) + f = torch.exp(1j * k * xs[..., None]); f[..., 1:-1] *= 2 + + return torch.tensordot(f.real, r, 1) \ + - torch.tensordot(f.imag, i, 1) + +@cache +def weight(T: Basis, s: int, use_mp: bool = True) -> Tensor: + """ + """ + print(f"Pre-computing weights for `{T=}` and `{s=}`...") + + eps = torch.finfo(torch.float).eps + ab = T.grid[..., None] + torch.tensor([-s/2, s/2]) + + map_ = map + if use_mp: + from concurrent.futures import ThreadPoolExecutor + map_ = (pool := ThreadPoolExecutor()).map + + def quad(T: Basis, m: int, a: float, b: float) -> Tensor: + f = lambda x: T.fn(torch.tensor(x))[m] + y, e = integrate.quad(f, a, b) + return y + + ws = [] + for i in range(-s//2, s//2 + 1): + ws.append(w:=[]) + + for a, b in tqdm(ab, f"{i=}"): + a = torch.clip(a - i, -eps, 1 + eps) + b = torch.clip(b - i, -eps, 1 + eps) + + q = torch.tensor(list(map_(partial(quad, T, a=a, b=b), range(T.m)))) + w.append(torch.linalg.solve(T.fn(T.grid).T, q).tolist()) + + if use_mp: pool.shutdown() + return torch.tensor(ws) + +def sem_conv(nodal: Tensor, coef: Tensor, ws: Tensor, *, T: Basis, ndim: int, dim: int): + """ + Args: + w: An (s + 1, m, n) array where s is the window size, m is the number + of quadrature nodes, and n is the number of output nodes. + """ + n = ndim + dim # mesh dim + + ns = "".join(map(chr, range(120, 120 + ndim))) + ms = ns.replace(i:=ns[dim], o:=ns[dim].upper()) + + pad_r = nodal.index_select(n, torch.arange(0, r:=len(ws)//2, device=nodal.device)) + pad_l = nodal.index_select(n, torch.arange(nodal.shape[n]-r, nodal.shape[n], device=nodal.device)) + + f = torch.concatenate([pad_l, nodal, pad_r], dim=n) + out = [] + + for k, w in enumerate(ws): + + x = T.grid + k - r + xy = T.grid[:, None] - x + + fx = torch.narrow(f, n, k, nodal.shape[n]) + gxy = kernel(coef, xy / (len(ws) - 1)) + + eqn = f"{ns}...i, {o}{i}io, {o}{i} -> {ms}...o" + out.append(torch.einsum(eqn, fx, gxy, w)) + + return sum(out) diff --git a/pyproject.toml b/pyproject.toml index 96848109c9..41fc28189f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ all = [ "ruamel.yaml>=0.17.22", "scikit-learn>=1.0.2", "scikit-image>=0.24.0", + "scipy>=1.15.0", "warp-lang>=1.0", "vtk>=9.2.6", "pyvista>=0.40.1",
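To try the new module in isolation, the following is a minimal sketch that instantiates `EddyFormer` with the hyperparameters from `examples/cfd/isotropic_eddyformer/config.yaml` and runs a single forward pass. The `(96, 96, 96, 3)` input shape is an assumption based on the 96³ stored resolution and the three velocity components (the training script feeds unbatched samples through `torch.vmap(model)`), so adjust it to the actual data layout.

```python
import torch

from physicsnemo.models.eddyformer import EddyFormer, EddyFormerConfig

# Layer configuration mirroring examples/cfd/isotropic_eddyformer/config.yaml.
cfg = EddyFormerConfig(
    basis="legendre",
    mesh=(8, 8, 8),
    mode=(10, 10, 10),
    mode_les=(5, 5, 5),
    kernel_size=(2, 2, 2),
    kernel_size_les=(2, 2, 2),
    ffn_dim=128,
    activation="GELU",
    num_heads=4,
    heads_dim=32,
)
model = EddyFormer(idim=3, odim=3, hdim=32, num_layers=4, use_scale=True, cfg=cfg)

# One velocity snapshot on the 96^3 periodic grid (shape is an assumption).
u = torch.randn(96, 96, 96, 3)
with torch.no_grad():
    v = model(u)  # same spatial resolution, odim=3 output channels
print(v.shape)
```

Note that constructing the model triggers a one-time quadrature-weight precomputation in `sem_conv.py` via `scipy.integrate`, which is why `scipy>=1.15.0` is added to `pyproject.toml` in this PR.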