diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index da8bce2d..8de8b324 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -23,6 +23,10 @@ Supported model outputs :py:meth:`ase.Atoms.get_forces`, …); - arbitrary outputs can be computed for any :py:class:`ase.Atoms` using :py:meth:`MetatomicCalculator.run_model`; +- for non-equivariant architectures like + `PET `_, + rotatonally-averaged energies, forces, and stresses can be computed using + :py:class:`metatomic.torch.ase_calculator.SymmetrizedCalculator`. How to install the code ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/torch/reference/ase.rst b/docs/src/torch/reference/ase.rst index f217a3b9..ddb1f49a 100644 --- a/docs/src/torch/reference/ase.rst +++ b/docs/src/torch/reference/ase.rst @@ -17,3 +17,7 @@ not just the energy, through the .. autoclass:: metatomic.torch.ase_calculator.MetatomicCalculator :show-inheritance: :members: + +.. autoclass:: metatomic.torch.ase_calculator.SymmetrizedCalculator + :show-inheritance: + :members: diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 346e8929..cd222323 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import metatensor.torch import numpy as np @@ -33,7 +33,6 @@ all_properties as ALL_ASE_PROPERTIES, ) - FilePath = Union[str, bytes, pathlib.PurePath] LOGGER = logging.getLogger(__name__) @@ -575,6 +574,8 @@ def compute_energy( self, atoms: Union[ase.Atoms, List[ase.Atoms]], compute_forces_and_stresses: bool = False, + *, + compute_energies: bool = False, ) -> Dict[str, Union[Union[float, np.ndarray], List[Union[float, np.ndarray]]]]: """ Compute the energy of the given ``atoms``. @@ -601,8 +602,14 @@ def compute_energy( atoms_list = atoms was_single = False + properties = ["energy"] + energy_per_atom = False + if compute_energies: + energy_per_atom = True + properties.append("energies") + outputs = self._ase_properties_to_metatensor_outputs( - properties=["energy"], + properties=properties, calculate_forces=compute_forces_and_stresses, calculate_stress=compute_forces_and_stresses, calculate_stresses=False, @@ -649,9 +656,26 @@ def compute_energy( ) energies = predictions[self._energy_key] - results_as_numpy_arrays = { - "energy": energies.block().values.detach().cpu().numpy().flatten().tolist() - } + if energy_per_atom: + results_as_numpy_arrays = { + "energies": energies.block().values.squeeze(-1).detach().cpu().numpy(), + "energy": metatensor.torch.sum_over_samples(energies, ["atom"]) + .block() + .values.detach() + .cpu() + .numpy() + .flatten() + .tolist(), + } + split_sizes = [len(system) for system in systems] + split_indices = np.cumsum(split_sizes[:-1]) + results_as_numpy_arrays["energies"] = np.split( + results_as_numpy_arrays["energies"], split_indices, axis=0 + ) + else: + results_as_numpy_arrays = { + "energy": energies.block().values.squeeze(-1).detach().cpu().numpy(), + } if compute_forces_and_stresses: if self.parameters["non_conservative"]: results_as_numpy_arrays["forces"] = ( @@ -664,8 +688,9 @@ def compute_energy( ) # all the forces are concatenated in a single array, so we need to # split them into the original systems - split_sizes = [len(system) for system in systems] - split_indices = np.cumsum(split_sizes[:-1]) + if not energy_per_atom: + split_sizes = [len(system) for system in systems] + split_indices = np.cumsum(split_sizes[:-1]) results_as_numpy_arrays["forces"] = np.split( results_as_numpy_arrays["forces"], split_indices, axis=0 ) @@ -860,3 +885,467 @@ def _full_3x3_to_voigt_6_stress(stress): (stress[0, 1] + stress[1, 0]) / 2.0, ] ) + + +class SymmetrizedCalculator(ase.calculators.calculator.Calculator): + r""" + Take a MetatomicCalculator and average its predictions to make it (approximately) + equivariant. Only predictions for energy, forces and stress are supported. + + The default is to average over a quadrature of the orthogonal group O(3) composed + this way: + + - Lebedev quadrature of the unit sphere (S^2) + - Equispaced sampling of the unit circle (S^1) + - Both proper and improper rotations are taken into account by including the + inversion operation (if ``include_inversion=True``) + + :param base_calculator: the MetatomicCalculator to be symmetrized + :param l_max: the maximum spherical harmonic degree that the model is expected to + be able to represent. This is used to choose the quadrature order. If ``0``, + no rotational averaging will be performed (it can be useful to average only over + the space group, see ``apply_group_symmetry``). + :param batch_size: number of rotated systems to evaluate at once. If ``None``, all + systems will be evaluated at once (this can lead to high memory usage). + :param include_inversion: if ``True``, the inversion operation will be included in + the averaging. This is required to average over the full orthogonal group O(3). + :param apply_space_group_symmetry: if ``True``, the results will be averaged over + discrete space group of rotations for the input system. The group operations are + computed with `spglib `, and the average is + performed after the O(3) averaging (if any). This has no effect for non-periodic + systems. + :param store_rotational_std: if ``True``, the results will contain the standard + deviation over the different rotations for each property (e.g., ``energy_std``). + """ + + implemented_properties = ["energy", "energies", "forces", "stress", "stresses"] + + def __init__( + self, + base_calculator: MetatomicCalculator, + *, + l_max: int = 3, + batch_size: Optional[int] = None, + include_inversion: bool = True, + apply_space_group_symmetry: bool = False, + store_rotational_std: bool = False, + ) -> None: + try: + from scipy.integrate import lebedev_rule # noqa: F401 + except ImportError as e: + raise ImportError( + "scipy is required to use the `SymmetrizedCalculator`, please install " + "it with `pip install scipy` or `conda install scipy`" + ) from e + + super().__init__() + + self.base_calculator = base_calculator + if l_max > 131: + raise ValueError( + f"l_max={l_max} is too large, the maximum supported value is 131" + ) + self.l_max = l_max + self.include_inversion = include_inversion + + if l_max > 0: + lebedev_order, n_inplane_rotations = _choose_quadrature(l_max) + self.quadrature_rotations, self.quadrature_weights = _get_quadrature( + lebedev_order, n_inplane_rotations, include_inversion + ) + else: + # no quadrature + self.quadrature_rotations = np.array([np.eye(3)]) + self.quadrature_weights = np.array([1.0]) + + self.batch_size = ( + batch_size if batch_size is not None else len(self.quadrature_rotations) + ) + + self.store_rotational_std = store_rotational_std + self.apply_space_group_symmetry = apply_space_group_symmetry + + def calculate( + self, atoms: ase.Atoms, properties: List[str], system_changes: List[str] + ) -> None: + """ + Perform the calculation for the given atoms and properties. + + :param atoms: the :py:class:`ase.Atoms` on which to perform the calculation + :param properties: list of properties to compute, among ``energy``, ``forces``, + and ``stress`` + :param system_changes: list of changes to the system since the last call to + ``calculate`` + """ + super().calculate(atoms, properties, system_changes) + self.base_calculator.calculate(atoms, properties, system_changes) + + compute_forces_and_stresses = "forces" in properties or "stress" in properties + compute_energies = "energies" in properties + + if len(self.quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.quadrature_rotations) + batches = [ + rotated_atoms_list[i : i + self.batch_size] + for i in range(0, len(rotated_atoms_list), self.batch_size) + ] + results: Dict[str, np.ndarray] = {} + for batch in batches: + try: + batch_results = self.base_calculator.compute_energy( + batch, + compute_forces_and_stresses, + compute_energies=compute_energies, + ) + for key, value in batch_results.items(): + results.setdefault(key, []) + results[key].extend( + [value] if isinstance(value, float) else value + ) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + "Out of memory error encountered during rotational averaging. " + "Please reduce the batch size or use lower rotational " + "averaging parameters. This can be done by setting the " + "`batch_size` and `l_max` parameters while initializing the " + "calculator." + ) from e + + self.results.update( + _compute_rotational_average( + results, + self.quadrature_rotations, + self.quadrature_weights, + self.store_rotational_std, + ) + ) + + if self.apply_space_group_symmetry: + # Apply the discrete space group of the system a posteriori + Q_list, P_list = _get_group_operations(atoms) + self.results.update(_average_over_group(self.results, Q_list, P_list)) + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = 2 * L_max + 1 + return n, K + + +def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: + """ + Create a list of copies of ``atoms``, rotated by each of the given ``rotations``. + + :param atoms: the :py:class:`ase.Atoms` to be rotated + :param rotations: (N, 3, 3) array of orthogonal matrices + :return: list of N :py:class:`ase.Atoms`, each rotated by the corresponding matrix + """ + rotated_atoms_list = [] + has_cell = atoms.cell is not None and atoms.cell.rank > 0 + for rot in rotations: + new_atoms = atoms.copy() + new_atoms.positions = new_atoms.positions @ rot.T + if has_cell: + new_atoms.set_cell( + new_atoms.cell.array @ rot.T, scale_atoms=False, apply_constraint=False + ) + new_atoms.wrap() + rotated_atoms_list.append(new_atoms) + return rotated_atoms_list + + +def _get_quadrature(lebedev_order: int, n_rotations: int, include_inversion: bool): + """ + Lebedev(S^2) x uniform angle quadrature on SO(3). + If include_inversion=True, extend to O(3) by adding inversion * R. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :param include_inversion: if ``True``, include the inversion operation in the + quadrature + :return: (N, 3, 3) array of orthogonal matrices, and (N,) array of weights + associated to each matrix + """ + from scipy.integrate import lebedev_rule + + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(z) # (M,) + # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + + K = int(n_rotations) + gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + + Rot = _rotations_from_angles(alpha, beta, gamma) + R_so3 = Rot.as_matrix() # (N, 3, 3) + + # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma + w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) + + if not include_inversion: + return R_so3, w_so3 + + # Extend to O(3) by appending inversion * R + P = -np.eye(3) + R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + w_o3 = np.concatenate([0.5 * w_so3, 0.5 * w_so3], axis=0) + + return R_o3, w_o3 + + +def _rotations_from_angles(alpha, beta, gamma): + from scipy.spatial.transform import Rotation + + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + + return Rot + + +def _compute_rotational_average(results, rotations, weights, store_std): + R = rotations + B = R.shape[0] + w = weights + w = w / w.sum() + + def _wreshape(x): + return w.reshape((B,) + (1,) * (x.ndim - 1)) + + def _wmean(x): + return np.sum(_wreshape(x) * x, axis=0) + + def _wstd(x): + mu = _wmean(x) + return np.sqrt(np.sum(_wreshape(x) * (x - mu) ** 2, axis=0)) + + out = {} + + # Energy (B,) + if "energy" in results: + E = np.asarray(results["energy"], dtype=float) # (B,) + out["energy"] = _wmean(E) # () + if store_std: + out["energy_rot_std"] = _wstd(E) # () + + if "energies" in results: + E = np.asarray(results["energies"], dtype=float) # (B,N) + out["energies"] = _wmean(E) # (N,) + if store_std: + out["energies_rot_std"] = _wstd(E) # (N,) + + # Forces (B,N,3) from rotated structures: back-rotate with F' R + if "forces" in results: + F = np.asarray(results["forces"], dtype=float) # (B,N,3) + F_back = F @ R # F' R + out["forces"] = _wmean(F_back) # (N,3) + if store_std: + out["forces_rot_std"] = _wstd(F_back) # (N,3) + + # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R + if "stress" in results: + S = np.asarray(results["stress"], dtype=float) # (B,3,3) + RT = np.swapaxes(R, 1, 2) + S_back = RT @ S @ R # R^T S' R + out["stress"] = _wmean(S_back) # (3,3) + if store_std: + out["stress_rot_std"] = _wstd(S_back) # (3,3) + + if "stresses" in results: + S = np.asarray(results["stresses"], dtype=float) # (B,N,3,3) + RT = np.swapaxes(R, 1, 2) + S_back = RT[:, None, :, :] @ S @ R[:, None, :, :] # R^T S' R + out["stresses"] = _wmean(S_back) # (N,3,3) + if store_std: + out["stresses_rot_std"] = _wstd(S_back) # (N,3,3) + + return out + + +def _get_group_operations( + atoms: ase.Atoms, symprec: float = 1e-6, angle_tolerance: float = -1.0 +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """ + Extract point-group rotations Q_g (Cartesian, 3x3) and the corresponding + atom-index permutations P_g (N x N) induced by the space-group operations. + Returns Q_list, Cartesian rotation matrices of the point group, + and P_list, permutation matrices mapping original indexing -> indexing after (R,t), + + :param atoms: input structure + :param symprec: tolerance for symmetry finding + :param angle_tolerance: tolerance for symmetry finding (in degrees). If less than 0, + a value depending on ``symprec`` will be chosen automatically by spglib. + :return: List of rotation matrices and permutation matrices. + + """ + try: + import spglib + except ImportError as e: + raise ImportError( + "spglib is required to use the SymmetrizedCalculator with " + "`apply_group_symmetry=True`. Please install it with " + "`pip install spglib` or `conda install -c conda-forge spglib`" + ) from e + + # Lattice with column vectors a1,a2,a3 (spglib expects (cell, frac, Z)) + A = atoms.cell.array.T # (3,3) + frac = atoms.get_scaled_positions() # (N,3) in [0,1) + numbers = atoms.numbers + N = len(atoms) + + data = spglib.get_symmetry_dataset( + (atoms.cell.array, frac, numbers), + symprec=symprec, + angle_tolerance=angle_tolerance, + ) + + if data is None: + # No symmetry found + return [], [] + R_frac = data.rotations # (n_ops, 3,3), integer + t_frac = data.translations # (n_ops, 3) + Z = numbers + + # Match fractional coords modulo 1 within a tolerance, respecting chemical species + def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): + d = np.abs(frac_ref - x_new) # (N,3) + d = np.minimum(d, 1.0 - d) # periodic distance + # Mask by identical species + mask = Z_ref == Z_i + if not np.any(mask): + raise RuntimeError("No matching species found while building permutation.") + # Choose argmin over max-norm within species + idx = np.where(mask)[0] + j = idx[np.argmin(np.max(d[idx], axis=1))] + + # Sanity check + if np.max(d[j]) > tol: + pass + return j + + Q_list, P_list = [], [] + seen = set() + Ainv = np.linalg.inv(A) + + for Rf, tf in zip(R_frac, t_frac, strict=False): + # Cartesian rotation: Q = A Rf A^{-1} + Q = A @ Rf @ Ainv + # Deduplicate rotations (point group) by rounding + key = tuple(np.round(Q.flatten(), 12)) + if key in seen: + continue + seen.add(key) + + # Build the permutation P from i to j + P = np.zeros((N, N), dtype=int) + new_frac = (frac @ Rf.T + tf) % 1.0 # images after (Rf,tf) + for i in range(N): + j = _match_index(new_frac[i], frac, Z, Z[i]) + P[j, i] = 1 # column i maps to row j + + Q_list.append(Q.astype(float)) + P_list.append(P) + + return Q_list, P_list + + +def _average_over_group( + results: dict, Q_list: List[np.ndarray], P_list: List[np.ndarray] +) -> dict: + """ + Apply the point-group projector in output space. + + :param results: Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or + 'stress' (3,3). These are predictions for the current structure in the reference + frame. + :param Q_list: Rotation matrices of the point group, from + :py:func:`_get_group_operations` + :param P_list: Permutation matrices of the point group, from + :py:func:`_get_group_operations` + :return out: Projected quantities. + """ + m = len(Q_list) + if m == 0: + return results # nothing to do + + out = {} + # Energy: unchanged by the projector (scalar) + if "energy" in results: + out["energy"] = float(results["energy"]) + + # Forces: (N,3) row-vectors; projector: (1/|G|) \sum_g P_g^T F Q_g + if "forces" in results: + F = np.asarray(results["forces"], float) + if F.ndim != 2 or F.shape[1] != 3: + raise ValueError(f"'forces' must be (N,3), got {F.shape}") + acc = np.zeros_like(F) + for Q, P in zip(Q_list, P_list, strict=False): + acc += P.T @ (F @ Q) + out["forces"] = acc / m + + # Stress: (3,3); projector: (1/|G|) \sum_g Q_g^T S Q_g + if "stress" in results: + S = np.asarray(results["stress"], float) + if S.shape != (3, 3): + raise ValueError(f"'stress' must be (3,3), got {S.shape}") + # S = 0.5 * (S + S.T) # symmetrize just in case + acc = np.zeros_like(S) + for Q in Q_list: + acc += Q.T @ S @ Q + S_pg = acc / m + out["stress"] = S_pg + + return out diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index abc5df43..9d200997 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -277,8 +277,11 @@ def test_run_model(tmpdir, model, atoms): assert outputs["non_conservative_stress"].block().values.shape == (2, 3, 3, 1) -@pytest.mark.parametrize("non_conservative", [True, False]) -def test_compute_energy(tmpdir, model, atoms, non_conservative): +@pytest.mark.parametrize( + "non_conservative, compute_energies", + [(True, True), (False, False), (True, False), (False, True)], +) +def test_compute_energy(tmpdir, model, atoms, non_conservative, compute_energies): ref = atoms.copy() ref.calc = ase.calculators.lj.LennardJones( sigma=SIGMA, epsilon=EPSILON, rc=CUTOFF, ro=CUTOFF, smooth=False @@ -292,23 +295,35 @@ def test_compute_energy(tmpdir, model, atoms, non_conservative): non_conservative=non_conservative, ) - energy = calculator.compute_energy(atoms)["energy"] - assert np.allclose(ref.get_potential_energy(), energy) + results = calculator.compute_energy(atoms, compute_energies=compute_energies) + if compute_energies: + energies = results["energies"] + assert np.allclose(ref.get_potential_energies(), energies) + assert np.allclose(ref.get_potential_energy(), results["energy"]) - results = calculator.compute_energy(atoms, compute_forces_and_stresses=True) + results = calculator.compute_energy( + atoms, compute_forces_and_stresses=True, compute_energies=compute_energies + ) assert np.allclose(ref.get_potential_energy(), results["energy"]) if not non_conservative: assert np.allclose(ref.get_forces(), results["forces"]) assert np.allclose( ref.get_stress(), _full_3x3_to_voigt_6_stress(results["stress"]) ) + if compute_energies: + assert np.allclose(ref.get_potential_energies(), results["energies"]) - energies = calculator.compute_energy([atoms, atoms])["energy"] - assert np.allclose(ref.get_potential_energy(), energies[0]) - assert np.allclose(ref.get_potential_energy(), energies[1]) + results = calculator.compute_energy([atoms, atoms]) + assert np.allclose(ref.get_potential_energy(), results["energy"][0]) + assert np.allclose(ref.get_potential_energy(), results["energy"][1]) + if compute_energies: + assert np.allclose(ref.get_potential_energies(), results["energies"][0]) + assert np.allclose(ref.get_potential_energies(), results["energies"][1]) results = calculator.compute_energy( - [atoms, atoms], compute_forces_and_stresses=True + [atoms, atoms], + compute_forces_and_stresses=True, + compute_energies=compute_energies, ) assert np.allclose(ref.get_potential_energy(), results["energy"][0]) assert np.allclose(ref.get_potential_energy(), results["energy"][1]) @@ -321,6 +336,9 @@ def test_compute_energy(tmpdir, model, atoms, non_conservative): assert np.allclose( ref.get_stress(), _full_3x3_to_voigt_6_stress(results["stress"][1]) ) + if compute_energies: + assert np.allclose(ref.get_potential_energies(), results["energies"][0]) + assert np.allclose(ref.get_potential_energies(), results["energies"][1]) atoms_no_pbc = atoms.copy() atoms_no_pbc.pbc = [False, False, False] diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py new file mode 100644 index 00000000..d7723438 --- /dev/null +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -0,0 +1,598 @@ +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pytest +import torch +from ase import Atoms +from ase.build import bulk, molecule +from metatensor.torch import Labels, TensorBlock, TensorMap + +import metatomic.torch as mta +from metatomic.torch import ( + ModelOutput, + NeighborListOptions, + System, +) +from metatomic.torch.ase_calculator import SymmetrizedCalculator, _get_quadrature + + +def _body_axis_from_system(system: System) -> torch.Tensor: + """ + Return the normalized vector connecting the two farthest atoms. + + :param atoms: Atomic configuration. + :return: Normalized 3D vector defining the body axis. + """ + pos = system.positions + if len(pos) < 2: + return torch.tensor([0.0, 0.0, 1.0], dtype=pos.dtype, device=pos.device) + d2 = torch.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1) + # i, j = torch.unravel_index(torch.argmax(d2), d2.shape) # for newer PyTorch + idx = torch.argmax(d2) + i = idx // d2.shape[1] + j = idx % d2.shape[1] + b = pos[j] - pos[i] + nrm = torch.linalg.norm(b) + return ( + b / nrm + if nrm > 0 + else torch.tensor([0.0, 0.0, 1.0], dtype=pos.dtype, device=pos.device) + ) + + +def _legendre_0_1_2_3(c: float) -> tuple[float, float, float, float]: + """ + Compute Legendre polynomials P0..P3(c). + + :param c: Cosine between the body axis and the lab z-axis. + :return: Tuple (P0, P1, P2, P3). + """ + P0 = 1.0 + P1 = c + P2 = 0.5 * (3 * c**2 - 1.0) + P3 = 0.5 * (5 * c**3 - 3 * c) + return P0, P1, P2, P3 + + +class MockAnisoModel(torch.nn.Module): + """ + Deterministic, rotation-dependent mock for testing SymmetrizedCalculator. + + Components: + - Energy: E_true + a1*P1 + a2*P2 + a3*P3 + - Forces: F_true + (b1*P1 + b2*P2 + b3*P3)*zhat + optional tensor L=2 term + - Stress: p_iso*I + (c2*P2 + c3*P3)*D + + :param a: Coefficients for Legendre P0..P3 in the energy. + :param b: Coefficients for P1..P3 in the forces (spurious vector parts). + :param c: Coefficients for P2,P3 in the stress (spurious deviators). + :param p_iso: Isotropic (true) part of the stress tensor. + :param tensor_forces: If True, add L=2 tensor-coupled force term. + :param tensor_amp: Amplitude of the tensor-coupled force component. + :param dtype: Data type for internal tensors. + :param device: Device for internal tensors. + """ + + def __init__( + self, + a: Tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + b: Tuple[float, float, float] = (0.0, 0.0, 0.0), + c: Tuple[float, float] = (0.0, 0.0), + p_iso: float = 1.0, + tensor_forces: bool = False, + tensor_amp: float = 0.5, + dtype: torch.dtype = torch.float64, + device: Union[str, torch.device] = "cpu", + ) -> None: + super().__init__() + self.a0, self.a1, self.a2, self.a3 = a + self.b1, self.b2, self.b3 = b + self.c2, self.c3 = c + self.p_iso = p_iso + self.tensor_forces = tensor_forces + self.tensor_amp = tensor_amp + self._dtype = dtype + self._device = torch.device(device) + + # Fixed bases + self._zhat = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device) + self._D = torch.diag(torch.tensor([1.0, -1.0, 0.0], dtype=dtype, device=device)) + + @torch.jit.export + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + n_sys = len(systems) + + # Pre-allocate storages + energies: List[torch.Tensor] = [] + stresses: List[torch.Tensor] = [] + forces: List[torch.Tensor] = [] + + for sys in systems: + pos = sys.positions + + # Determine body axis and related scalars + b = _body_axis_from_system(sys).to(dtype=self._dtype, device=self._device) + cval = float(torch.dot(b, self._zhat)) + P0, P1, P2, P3 = _legendre_0_1_2_3(cval) + + # Energy + E_true = torch.sum(pos**2) + E = E_true + self.a0 * P0 + self.a1 * P1 + self.a2 * P2 + self.a3 * P3 + energies.append(E) + + # Forces + F_true = pos.clone() + F_spur = (self.b1 * P1 + self.b2 * P2 + self.b3 * P3) * self._zhat[None, :] + F = F_true + F_spur + if self.tensor_forces: + v = torch.cross(self._zhat, b, dim=0) + s = torch.norm(v) + cth = float(torch.dot(self._zhat, b)) + if s < 1e-15: + R = ( + torch.eye(3, dtype=self._dtype, device=self._device) + if cth > 0 + else -torch.eye(3, dtype=self._dtype, device=self._device) + ) + else: + vx = torch.tensor( + [ + [0.0, -v[2], v[1]], + [v[2], 0.0, -v[0]], + [-v[1], v[0], 0.0], + ], + dtype=self._dtype, + device=self._device, + ) + R = torch.eye(3) + vx + vx @ vx * ((1.0 - cth) / (s**2)) + T = R @ self._D @ R.T + F_tensor = self.tensor_amp * (T @ self._zhat) + F = F + F_tensor[None, :] + forces.append(F) + + # Stress + S = ( + self.p_iso * torch.eye(3, dtype=self._dtype, device=self._device) + + (self.c2 * P2 + self.c3 * P3) * self._D + ) + stresses.append(S) + + result: Dict[str, TensorMap] = {} + key = Labels( + names=["_"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ) + + # Energy + print(torch.stack(energies, dim=0).shape) + energy_block = TensorBlock( + values=torch.stack(energies, dim=0) + .to(dtype=self._dtype, device=self._device) + .unsqueeze(-1), + samples=Labels( + names=["system"], + values=torch.arange( + n_sys, dtype=torch.int64, device=self._device + ).unsqueeze(1), + ), + components=[], + properties=Labels( + names=["energy"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ), + ) + + # Forces + print(torch.cat(forces, dim=0).shape) + force_block = TensorBlock( + values=torch.cat(forces, dim=0) + .to(dtype=self._dtype, device=self._device) + .unsqueeze(-1), + samples=Labels( + names=["system", "atom"], + values=torch.cat( + [ + torch.cartesian_prod(torch.tensor([i]), torch.arange(len(sys))) + for i, sys in enumerate(systems) + ] + ).to(dtype=torch.int64, device=self._device), + ), + components=[ + Labels( + "xyz", + torch.arange(3) + .reshape(-1, 1) + .to(dtype=torch.int64, device=self._device), + ) + ], # vector components + properties=Labels( + names=["non_conservative_forces"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ), + ) + + # Stress + print(torch.stack(stresses, dim=0).shape) + stress_block = TensorBlock( + values=torch.stack(stresses, dim=0) + .to(dtype=self._dtype, device=self._device) + .unsqueeze(-1), + samples=Labels( + names=["system"], + values=torch.arange( + n_sys, dtype=torch.int64, device=self._device + ).unsqueeze(1), + ), + components=[ + Labels( + "xyz_1", + torch.arange(3) + .reshape(-1, 1) + .to(dtype=torch.int64, device=self._device), + ), + Labels( + "xyz_2", + torch.arange(3) + .reshape(-1, 1) + .to(dtype=torch.int64, device=self._device), + ), + ], + properties=Labels( + names=["non_conservative_stress"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ), + ) + + if "energy" in outputs: + result["energy"] = TensorMap(key, [energy_block]) + + if "non_conservative_forces" in outputs: + result["non_conservative_forces"] = TensorMap(key, [force_block]) + + if "non_conservative_stress" in outputs: + result["non_conservative_stress"] = TensorMap(key, [stress_block]) + + return result + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [] + + +def mock_calculator( + a: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + b: tuple[float, float, float] = (0.0, 0.0, 0.0), + c: tuple[float, float] = (0.0, 0.0), + p_iso: float = 1.0, + tensor_forces: bool = False, + tensor_amp: float = 0.5, +) -> mta.ase_calculator.MetatomicCalculator: + model = MockAnisoModel( + a=a, + b=b, + c=c, + p_iso=p_iso, + tensor_forces=tensor_forces, + tensor_amp=tensor_amp, + ) + model.eval() + + atomistic_model = mta.AtomisticModel( + model, + mta.ModelMetadata("mock_aniso", "Mock anisotropic model for testing"), + mta.ModelCapabilities( + { + "energy": mta.ModelOutput(per_atom=False), + "non_conservative_forces": mta.ModelOutput(per_atom=True), + "non_conservative_stress": mta.ModelOutput(per_atom=False), + }, + list(range(1, 102)), + 100, + "angstrom", + ["cpu"], + "float64", + ), + ) + return mta.ase_calculator.MetatomicCalculator( + atomistic_model, + non_conservative=True, + do_gradients_with_energy=False, + additional_outputs={ + "energy": mta.ModelOutput(per_atom=False), + "non_conservative_forces": mta.ModelOutput(per_atom=True), + "non_conservative_stress": mta.ModelOutput(per_atom=False), + }, + ) + + +@pytest.fixture +def dimer() -> Atoms: + """ + Create a small asymmetric geometry with a well-defined body axis. + + :return: ASE Atoms object with the H2 molecule. + """ + return Atoms("H2", positions=[[0, 0, 0], [0.3, 0.2, 1.0]]) + + +@pytest.fixture +def fcc_bulk() -> Atoms: + """ + Create a small FCC bulk structure. + + :return: ASE Atoms object with FCC Cu. + """ + return bulk("Cu", "fcc", cubic=True) + + +def test_quadrature_normalization(): + """Verify normalization and determinant signs of the quadrature.""" + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=True) + assert np.isclose(np.sum(w), 1.0) + dets = np.linalg.det(R) + assert np.all(np.isin(np.round(dets).astype(int), [-1, 1])) + + +@pytest.mark.parametrize("Lmax, expect_removed", [(0, False), (3, True)]) +def test_energy_L_components_removed( + dimer: Atoms, Lmax: int, expect_removed: bool +) -> None: + """ + Verify that spurious energy components vanish once rotational averaging is applied. + For Lmax>0, all use the same minimal Lebedev rule (order=3). + """ + a = (1.0, 1.0, 1.0, 1.0) + base = mock_calculator(a=a) + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + dimer.get_forces() + e = dimer.get_potential_energy() + E_true = float(np.sum(dimer.positions**2)) + if expect_removed: + assert np.isclose(e, E_true + a[0], atol=1e-10) + else: + assert not np.isclose(e, E_true + a[0], atol=1e-10) + + +def test_force_backrotation_exact(dimer: Atoms) -> None: + """ + Check that forces are back-rotated exactly when no spurious terms are present. + + :param dimer: Test atomic structure. + """ + base = mock_calculator(b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + F = dimer.get_forces() + expected_F = dimer.get_positions() + expected_F -= np.mean(expected_F, axis=0) + assert np.allclose(F, expected_F, atol=1e-12) + + +def test_tensorial_L2_force_cancellation(dimer: Atoms) -> None: + """ + Tensor-coupled (L=2) force components must vanish under O(3) averaging. + + Since the minimal Lebedev order used internally is 3, all quadratures + integrate L=2 components exactly; we only check for correct cancellation. + """ + base = mock_calculator(tensor_forces=True, tensor_amp=1.0) + + for Lmax in [1, 2, 3]: + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + F = dimer.get_forces() + expected_F = dimer.get_positions() + expected_F -= np.mean(expected_F, axis=0) + assert np.allclose(F, expected_F, atol=1e-10) + + +def test_stress_isotropization(fcc_bulk: Atoms) -> None: + """ + Check that stress deviatoric parts (L=2,3) vanish under full O(3) averaging. + + :param dimer: Test atomic structure. + """ + base = mock_calculator(c=(2.0, 1.0), p_iso=5.0) + calc = SymmetrizedCalculator(base, l_max=9, include_inversion=True) + fcc_bulk.calc = calc + fcc_bulk.get_forces() + S = fcc_bulk.get_stress(voigt=False) + + fcc_bulk.calc = base + fcc_bulk.get_forces() + + iso = np.trace(S) / 3.0 + assert np.isclose(iso, 5.0, atol=1e-10) + + +def test_cancellation_vs_Lmax(dimer: Atoms) -> None: + """ + Residual anisotropy must vanish once rotational averaging is applied. + All quadratures with Lmax>0 are equivalent (Lebedev order=3). + """ + a = (0.0, 0.0, 1.0, 1.0) + base = mock_calculator(a=a) + E_true = float(np.sum(dimer.positions**2)) + + # No averaging + calc0 = SymmetrizedCalculator(base, l_max=0) + dimer.calc = calc0 + dimer.get_forces() + e0 = dimer.get_potential_energy() + + # Averaged + calc3 = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc3 + dimer.get_forces() + e3 = dimer.get_potential_energy() + + assert not np.isclose(e0, E_true, atol=1e-10) + assert np.isclose(e3, E_true, atol=1e-10) + + +def test_joint_energy_force_consistency(dimer: Atoms) -> None: + """ + Combined test: both energy and forces are consistent and invariant. + + :param dimer: Test atomic structure. + """ + base = mock_calculator(a=(1, 1, 1, 1), b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + f = dimer.get_forces() + e = dimer.get_potential_energy() + expected_F = dimer.get_positions() + expected_F -= np.mean(expected_F, axis=0) + assert np.isclose(e, np.sum(dimer.positions**2) + 1.0, atol=1e-10) + assert np.allclose(f, expected_F, atol=1e-12) + + +def test_rotate_atoms_preserves_geometry(tmp_path): + """Check that _rotate_atoms applies rotations correctly and preserves distances.""" + from scipy.spatial.transform import Rotation + + from metatomic.torch.ase_calculator import _rotate_atoms + + # Build simple cubic cell with 2 atoms along x + atoms = Atoms("H2", positions=[[0, 0, 0], [1, 0, 0]], cell=np.eye(3)) + R = Rotation.from_euler("z", 90, degrees=True).as_matrix()[None, ...] # 90° about z + + rotated = _rotate_atoms(atoms, R)[0] + # Positions should now align along y + assert np.allclose( + rotated.positions[1] - rotated.positions[0], [0, 1, 0], atol=1e-12 + ) + # Cell rotated + assert np.allclose(rotated.cell[0], [0, 1, 0], atol=1e-12) + # Distances preserved + d0 = atoms.get_distance(0, 1) + d1 = rotated.get_distance(0, 1) + assert np.isclose(d0, d1, atol=1e-12) + + +def test_choose_quadrature_rules(): + """Check that _choose_quadrature selects appropriate rules.""" + from metatomic.torch.ase_calculator import _choose_quadrature + + for L in [0, 5, 17, 50]: + lebedev_order, n_gamma = _choose_quadrature(L) + assert lebedev_order >= L + assert n_gamma == 2 * L + 1 + + +def test_get_quadrature_properties(): + """Check properties of the quadrature returned by _get_quadrature.""" + from metatomic.torch.ase_calculator import _get_quadrature + + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=False) + assert np.isclose(np.sum(w), 1.0) + assert np.allclose([np.dot(r.T, r) for r in R], np.eye(3), atol=1e-12) + assert np.allclose(np.linalg.det(R), 1.0, atol=1e-12) + + R_inv, w_inv = _get_quadrature( + lebedev_order=11, n_rotations=5, include_inversion=True + ) + assert len(R_inv) == 2 * len(R) + dets = np.linalg.det(R_inv) + assert np.all(np.isin(np.sign(dets).astype(int), [-1, 1])) + assert np.isclose(np.sum(w_inv), 1.0) + + +def test_compute_rotational_average_identity(): + """Check that _compute_rotational_average produces correct averages.""" + from metatomic.torch.ase_calculator import _compute_rotational_average + + R = np.repeat(np.eye(3)[None, :, :], 3, axis=0) + w = np.ones(3) / 3 + results = { + "energy": np.array([1.0, 2.0, 3.0]), + "forces": np.array([[[1, 0, 0]], [[0, 1, 0]], [[0, 0, 1]]]), + "stress": np.array([np.eye(3), 2 * np.eye(3), 3 * np.eye(3)]), + } + out = _compute_rotational_average(results, R, w, False) + assert np.isclose(out["energy"], np.mean(results["energy"])) + assert np.allclose(out["forces"], np.mean(results["forces"], axis=0)) + assert np.allclose(out["stress"], np.mean(results["stress"], axis=0)) + + out = _compute_rotational_average(results, R, w, True) + assert "energy_rot_std" in out + assert "forces_rot_std" in out + assert "stress_rot_std" in out + + +def test_average_over_fcc_group(fcc_bulk: Atoms): + """ + Check that averaging over the space group of an FCC crystal + produces an isotropic (scalar) stress tensor. + """ + from metatomic.torch.ase_calculator import ( + _average_over_group, + _get_group_operations, + ) + + # FCC conventional cubic cell (4 atoms) + atoms = fcc_bulk + + energy = 0.0 + forces = np.random.normal(0, 1, (4, 3)) + forces -= np.mean(forces, axis=0) # Ensure zero net force + + # Create an intentionally anisotropic stress + stress = np.array([[10.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 1.0]]) + results = {"energy": energy, "forces": forces, "stress": stress} + + Q_list, P_list = _get_group_operations(atoms) + out = _average_over_group(results, Q_list, P_list) + + # Energy must be unchanged + assert np.isclose(out["energy"], energy) + + # Forces must average to zero by symmetry + F_pg = out["forces"] + assert np.allclose(F_pg, np.zeros_like(F_pg)) + + S_pg = out["stress"] + + # The averaged stress must be isotropic: S_pg = (trace/3)*I + iso = np.trace(S_pg) / 3.0 + assert np.allclose(S_pg, np.eye(3) * iso, atol=1e-8) + + +def test_space_group_average_non_periodic(): + """ + Check that averaging over the space group of a non-periodic system leaves the + results unchanged. + """ + from metatomic.torch.ase_calculator import ( + _average_over_group, + _get_group_operations, + ) + + # Methane molecule (Td symmetry) + atoms = molecule("CH4") + + energy = 0.0 + forces = np.random.normal(0, 1, (4, 3)) + forces -= np.mean(forces, axis=0) # Ensure zero net force + + results = {"energy": energy, "forces": forces} + + Q_list, P_list = _get_group_operations(atoms) + + # Check that the operation lists are empty + assert len(Q_list) == 0 + assert len(P_list) == 0 + + out = _average_over_group(results, Q_list, P_list) + + # Energy must be unchanged + assert np.isclose(out["energy"], energy) + + # Forces must be unchanged + F_pg = out["forces"] + assert np.allclose(F_pg, forces) diff --git a/tox.ini b/tox.ini index c9debaa3..00c0907f 100644 --- a/tox.ini +++ b/tox.ini @@ -150,6 +150,9 @@ deps = # for metatensor-lj-test setuptools-scm cmake + # for symmetrized calculator + scipy + spglib changedir = python/metatomic_torch commands =