Feature/gotennet#41

Open
MarkKon wants to merge 4 commits into main from feature/gotennet

Conversation

Contributor

@MarkKon MarkKon commented Nov 14, 2025

Summary

Adds the GotenNet model and reaction wrappers, ported from https://github.com/sarpaykent/GotenNet.

Type of change

  • Feature
  • Bug fix
  • Documentation
  • Refactor / Maintenance
  • CI / Build
  • Tests
  • Other: ___

Changes

  • Added core GotenNet model
  • Added reaction wrappers for reaction_3d_graph representation
  • Added core model and experiment confs

Breaking changes

  • None
  • Yes (describe impact and migration notes):

How I tested this

  • Unit tests pass
  • Integration / e2e checks
  • Docs build locally
  • Manual verification (brief steps):

Checklist

  • Relevant labels added (e.g., feature/bug/documentation/maintenance/tests/ci)
  • Branch name follows policy (feature/bugfix/hotfix/docs/chore/ci/tests/perf/build)
  • Changelog consideration:
    • Not applicable (internal-only/no user impact)
    • I added/updated CHANGELOG.md, or this will be captured in release notes
  • If new dependency or config change, rationale and docs are included

Reviewer notes (optional)

Anything that would help reviewers focus (areas of risk, follow-ups, docs to check).

@MarkKon MarkKon requested a review from a team as a code owner November 14, 2025 13:45
Copilot AI review requested due to automatic review settings November 14, 2025 13:45

Copilot AI left a comment


Pull Request Overview

This PR adds the GotenNet model, a Graph Attention Transformer Network for atomic systems, along with reaction wrappers and configuration files. The implementation is ported from an external repository (https://github.com/sarpaykent/GotenNet).

Key changes:

  • Core GotenNet architecture with GATA (Graph Attention Transformer Architecture) and EQFF (Equivariant Feed-Forward) layers
  • Reaction prediction wrappers (GotenNetChemTorchR, GotennetReaction) for processing reactant and transition state structures
  • Configuration files for model instantiation and experiments

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 10 comments.

Summary per file:

  • src/chemtorch/components/model/gotennet.py: Complete GotenNet model implementation with attention mechanisms, message passing layers, and reaction wrappers
  • pyproject.toml: Added dependencies for e3nn (equivariant neural networks) and einops
  • conf/model/gotennet.yaml: Configuration for GotenNetChemTorchR model
  • conf/model/gotenreaction.yaml: Configuration for GotennetReaction model
  • conf/experiment/gotennet.yaml: Experiment setup for GotenNet
  • conf/experiment/gotenreaction.yaml: Experiment setup for reaction prediction


Comment on lines +25 to +40
def get_split_sizes_from_lmax(lmax):
    """
    Return split sizes for torch.split based on lmax.

    Calculates the dimensions of spherical harmonic components for each
    angular momentum value from 1 to lmax.

    Args:
        lmax: Maximum angular momentum value.

    Returns:
        List[int]: List of split sizes for torch.split.
    """
    return [2 * l + 1 for l in range(1, lmax + 1)]



Copilot AI Nov 14, 2025


The function get_split_sizes_from_lmax is defined twice in this file (lines 25-38 and lines 1050-1064). The second definition at line 1050 has an additional start parameter with default value 1. This duplication is confusing and can lead to maintenance issues. Consider keeping only the more flexible version (with the start parameter) and updating all call sites accordingly, or rename one of them to reflect different purposes.

Suggested change

    -def get_split_sizes_from_lmax(lmax):
    -    """
    -    Return split sizes for torch.split based on lmax.
    -    Calculates the dimensions of spherical harmonic components for each
    -    angular momentum value from 1 to lmax.
    -    Args:
    -        lmax: Maximum angular momentum value.
    -    Returns:
    -        List[int]: List of split sizes for torch.split.
    -    """
    -    return [2 * l + 1 for l in range(1, lmax + 1)]
    +# Removed duplicate definition of get_split_sizes_from_lmax. See later in file for the version with the 'start' parameter.

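A single merged version along the lines Copilot suggests could look like this (the `start` parameter is assumed from the review's description of the second definition; this is a sketch, not the repository's exact code):

```python
from typing import List

def get_split_sizes_from_lmax(lmax: int, start: int = 1) -> List[int]:
    """Return split sizes for torch.split: one entry of size 2l+1 per
    angular momentum l in [start, lmax]."""
    return [2 * l + 1 for l in range(start, lmax + 1)]

# With the default start=1 this matches the first definition:
print(get_split_sizes_from_lmax(3))            # [3, 5, 7]
# start=0 additionally includes the scalar (l=0) channel:
print(get_split_sizes_from_lmax(3, start=0))   # [1, 3, 5, 7]
```

Keeping only this version and updating the call sites that relied on the old default would resolve the duplication without changing behavior.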
Comment on lines +118 to +122
print(f"Exponent p={p} has to be >= 2.")
print("Exiting code.")

exit()


Copilot AI Nov 14, 2025


Using print() statements and exit() for error handling is inappropriate in a library. Replace with proper exception raising (e.g., `raise ValueError(f"Exponent p={p} has to be >= 2.")`) so calling code can handle the error appropriately. The comment on line 117 also acknowledges this should be changed to use a logger.

Suggested change

    -print(f"Exponent p={p} has to be >= 2.")
    -print("Exiting code.")
    -
    -exit()
    +raise ValueError(f"Exponent p={p} has to be >= 2.")

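A minimal sketch of the exception-based pattern the reviewer describes (the function name and surrounding logic are illustrative, not the file's actual code):

```python
def validate_exponent(p: int) -> int:
    """Validate a cutoff exponent, raising instead of printing and exiting,
    so callers can catch and handle the error."""
    if p < 2:
        raise ValueError(f"Exponent p={p} has to be >= 2.")
    return p

try:
    validate_exponent(1)
except ValueError as err:
    print(err)  # Exponent p=1 has to be >= 2.
```

Unlike `exit()`, the raised exception propagates a catchable error with the offending value attached, which also plays well with test frameworks and logging.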
raise ValueError(f"Unknown initialization {init_str}")


# train.py -m label=mu,alpha,homo,lumo,r2,zpve,U0,U,H,G,Cv name='${label_str}_int6_glo-ort_3090' hydra.sweeper.n_jobs=1 model.representation.n_interactions=6 model.representation.weight_init=glo_orthogonal

Copilot AI Nov 14, 2025


This commented-out training command appears to be leftover development code and should be removed. It doesn't provide value in the production codebase and clutters the code.

Suggested change

    -# train.py -m label=mu,alpha,homo,lumo,r2,zpve,U0,U,H,G,Cv name='${label_str}_int6_glo-ort_3090' hydra.sweeper.n_jobs=1 model.representation.n_interactions=6 model.representation.weight_init=glo_orthogonal

Comment on lines +504 to +506
if inspect.isclass(activation):
    self.activation = activation()
self.activation = activation

Copilot AI Nov 14, 2025


The activation assignment on line 506 unconditionally overwrites the instantiated activation from line 505. This means if activation is a class, line 505 instantiates it correctly, but line 506 immediately overwrites self.activation with the class itself (not an instance). Line 506 should be inside an else block, or the logic should be: self.activation = activation() if inspect.isclass(activation) else activation.

Suggested change

    -if inspect.isclass(activation):
    -    self.activation = activation()
    -self.activation = activation
    +self.activation = activation() if inspect.isclass(activation) else activation

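The class-or-instance resolution pattern can be demonstrated in isolation (the `ReLU` stand-in and `resolve_activation` helper here are illustrative, not the module's code):

```python
import inspect

class ReLU:  # stand-in for an activation class such as torch.nn.ReLU
    def __call__(self, x):
        return max(x, 0.0)

def resolve_activation(activation):
    """Instantiate `activation` if a class was passed, otherwise use it as-is."""
    return activation() if inspect.isclass(activation) else activation

act = resolve_activation(ReLU)       # class -> new instance
assert isinstance(act, ReLU)
act = resolve_activation(ReLU())     # instance -> passed through unchanged
print(act(-3.0), act(2.0))           # 0.0 2.0
```

The conditional expression avoids the overwrite bug: exactly one assignment happens, and a callable instance is never replaced by its class.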
    # keep as is
    edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
else:
    edge_vec = pos[edge_index[0]] - pos[edge_index[1]]

Copilot AI Nov 14, 2025


Both branches of the conditional compute edge_vec identically, making the direction parameter ineffective. The else branch should likely compute the reverse direction: edge_vec = pos[edge_index[1]] - pos[edge_index[0]] to match the 'target_to_source' direction.

Suggested change

    -edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
    +edge_vec = pos[edge_index[1]] - pos[edge_index[0]]

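The two direction conventions differ only by a sign. A minimal pure-Python sketch of the semantics (names and data layout are illustrative, not the model's tensor code):

```python
def edge_vectors(pos, edge_index, direction="source_to_target"):
    """Compute per-edge difference vectors; the two directions are negatives
    of each other, so identical branches would make the flag a no-op."""
    src, dst = edge_index
    if direction == "source_to_target":
        return [tuple(a - b for a, b in zip(pos[s], pos[d])) for s, d in zip(src, dst)]
    # 'target_to_source': reversed subtraction, as the review suggests
    return [tuple(b - a for a, b in zip(pos[s], pos[d])) for s, d in zip(src, dst)]

pos = [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0)]
edge_index = ([0], [1])
print(edge_vectors(pos, edge_index))                       # [(-1.0, -2.0, -3.0)]
print(edge_vectors(pos, edge_index, "target_to_source"))   # [(1.0, 2.0, 3.0)]
```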
self.reset_parameters()

@classmethod
def load_from_checkpoint(cls, checkpoint_path: str, device="cpu") -> None:

Copilot AI Nov 14, 2025


The return type annotation is None, but the method actually returns a GotenNet instance on line 1962 (return gotennet). The return type should be changed to -> 'GotenNet' or -> Self (if using Python 3.11+).

Suggested change

    -def load_from_checkpoint(cls, checkpoint_path: str, device="cpu") -> None:
    +def load_from_checkpoint(cls, checkpoint_path: str, device="cpu") -> 'GotenNet':

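A minimal sketch of the annotated pattern (the class body here is a stub for illustration, not the repository's model):

```python
class GotenNet:
    """Minimal stub standing in for the real model class."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "GotenNet":
        # A real implementation would load weights from checkpoint_path;
        # here we only demonstrate the corrected return annotation.
        return cls(device=device)

model = GotenNet.load_from_checkpoint("ckpt.pt")
print(type(model).__name__)  # GotenNet
```

On Python 3.11+, `typing.Self` is the more idiomatic annotation for a classmethod constructor, since it also types subclasses correctly.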
h.unsqueeze_(1)
t_ij = t_ij_init
for _i, (gata, eqff) in enumerate(
zip(self.gata_list, self.eqff_list, strict=False)
Copy link

Copilot AI Nov 14, 2025


The strict=False parameter is explicitly set, but both gata_list and eqff_list are constructed with the same length (range(self.n_interactions) on lines 1900 and 1914). Using strict=True would be safer as it ensures the lists remain synchronized and would catch any future bugs where they become mismatched.

Suggested change

    -zip(self.gata_list, self.eqff_list, strict=False)
    +zip(self.gata_list, self.eqff_list, strict=True)

Comment on lines +1938 to +1948
assert "hyper_parameters" in checkpoint, (
    "Checkpoint must contain 'hyper_parameters' key."
)
hyper_parameters = checkpoint["hyper_parameters"]
assert "representation" in hyper_parameters, (
    "Hyperparameters must contain 'representation' key."
)
representation_config = hyper_parameters["representation"]
_ = representation_config.pop("_target_", None)

assert "state_dict" in checkpoint, "Checkpoint must contain 'state_dict' key."

Copilot AI Nov 14, 2025


Using assert statements for input validation in production code is not recommended, as assertions can be disabled with Python's -O flag. Replace with explicit if checks that raise appropriate exceptions (e.g., KeyError or ValueError) to ensure validation always occurs.

Suggested change

    -assert "hyper_parameters" in checkpoint, (
    -    "Checkpoint must contain 'hyper_parameters' key."
    -)
    -hyper_parameters = checkpoint["hyper_parameters"]
    -assert "representation" in hyper_parameters, (
    -    "Hyperparameters must contain 'representation' key."
    -)
    -representation_config = hyper_parameters["representation"]
    -_ = representation_config.pop("_target_", None)
    -
    -assert "state_dict" in checkpoint, "Checkpoint must contain 'state_dict' key."
    +if "hyper_parameters" not in checkpoint:
    +    raise KeyError("Checkpoint must contain 'hyper_parameters' key.")
    +hyper_parameters = checkpoint["hyper_parameters"]
    +if "representation" not in hyper_parameters:
    +    raise KeyError("Hyperparameters must contain 'representation' key.")
    +representation_config = hyper_parameters["representation"]
    +_ = representation_config.pop("_target_", None)
    +
    +if "state_dict" not in checkpoint:
    +    raise KeyError("Checkpoint must contain 'state_dict' key.")

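One way to express this as reusable validation, sketched with a hypothetical `require_key` helper and dummy checkpoint data (not the repository's code):

```python
def require_key(mapping, key, context="Checkpoint"):
    """Explicit validation that survives `python -O`, unlike assert."""
    if key not in mapping:
        raise KeyError(f"{context} must contain '{key}' key.")
    return mapping[key]

checkpoint = {
    "hyper_parameters": {"representation": {"_target_": "pkg.Model", "hidden": 128}},
    "state_dict": {},
}
hyper_parameters = require_key(checkpoint, "hyper_parameters")
representation_config = dict(require_key(hyper_parameters, "representation"))
representation_config.pop("_target_", None)  # drop the instantiation target
state_dict = require_key(checkpoint, "state_dict")
print(representation_config)  # {'hidden': 128}
```

The helper also keeps the error messages in one place, so a malformed checkpoint fails with a consistent, catchable `KeyError` regardless of interpreter flags.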
Comment on lines +2042 to +2059
def forward(self, z, pos, batch) -> Tuple[Tensor, Tensor]:
    """
    Compute atomic representations/embeddings.

    Args:
        inputs: Dictionary of input tensors containing atomic_numbers, pos, batch,
            edge_index, r_ij, and dir_ij. Shape information:
            - atomic_numbers: [num_nodes]
            - pos: [num_nodes, 3]
            - batch: [num_nodes]
            - edge_index: [2, num_edges]

    Returns:
        Tuple containing:
        - Atomic representation [num_nodes, hidden_dims]
        - High-degree steerable features [num_nodes, (L_max ** 2) - 1, hidden_dims]
    """
    edge_index, edge_diff, edge_vec = self.distance(pos, batch)

Copilot AI Nov 14, 2025


This method requires 4 positional arguments, whereas overridden GotenNet.forward requires 5.

Suggested change

    -def forward(self, z, pos, batch) -> Tuple[Tensor, Tensor]:
    -    """
    -    Compute atomic representations/embeddings.
    -    Args:
    -        inputs: Dictionary of input tensors containing atomic_numbers, pos, batch,
    -            edge_index, r_ij, and dir_ij. Shape information:
    -            - atomic_numbers: [num_nodes]
    -            - pos: [num_nodes, 3]
    -            - batch: [num_nodes]
    -            - edge_index: [2, num_edges]
    -    Returns:
    -        Tuple containing:
    -        - Atomic representation [num_nodes, hidden_dims]
    -        - High-degree steerable features [num_nodes, (L_max ** 2) - 1, hidden_dims]
    -    """
    -    edge_index, edge_diff, edge_vec = self.distance(pos, batch)
    +def forward(self, z, pos, batch, edge_index=None, edge_diff=None, edge_vec=None) -> Tuple[Tensor, Tensor]:
    +    """
    +    Compute atomic representations/embeddings.
    +    Args:
    +        z: Atomic numbers or input features [num_nodes]
    +        pos: Positions [num_nodes, 3]
    +        batch: Batch indices [num_nodes]
    +        edge_index: [2, num_edges] (optional)
    +        edge_diff: [num_edges, 1] (optional)
    +        edge_vec: [num_edges, 3] (optional)
    +    Returns:
    +        Tuple containing:
    +        - Atomic representation [num_nodes, hidden_dims]
    +        - High-degree steerable features [num_nodes, (L_max ** 2) - 1, hidden_dims]
    +    """
    +    if edge_index is None or edge_diff is None or edge_vec is None:
    +        edge_index, edge_diff, edge_vec = self.distance(pos, batch)

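The optional-argument pattern can be sketched independently of the model (the toy 1-D `compute_edges` stands in for the real distance module; all names are illustrative):

```python
from typing import List, Optional, Tuple

def compute_edges(pos: List[float], cutoff: float = 1.5) -> List[Tuple[int, int]]:
    """Toy 1-D stand-in for the model's distance computation."""
    return [(i, j) for i in range(len(pos)) for j in range(len(pos))
            if i != j and abs(pos[i] - pos[j]) <= cutoff]

def forward(z, pos, batch, edge_index: Optional[List[Tuple[int, int]]] = None):
    """Optional edge arguments keep the signature compatible with callers
    that precompute edges and with those that do not."""
    if edge_index is None:
        edge_index = compute_edges(pos)
    return len(edge_index)  # stand-in for the real representation

pos = [0.0, 1.0, 3.0]
print(forward(None, pos, None))                       # 2 (edges within cutoff)
print(forward(None, pos, None, edge_index=[(0, 1)]))  # 1 (precomputed edges win)
```

This keeps the override substitutable for the base `forward` (same leading positional arguments) while still allowing the extra edge inputs.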
import math
import inspect
from functools import partial
from typing import Callable, List, Mapping, Optional, Tuple, Union

Copilot AI Nov 14, 2025


Import of 'Mapping' is not used.

Suggested change

    -from typing import Callable, List, Mapping, Optional, Tuple, Union
    +from typing import Callable, List, Optional, Tuple, Union

@tonyzamyatin
Collaborator

Hi, please add an integration test (see "Reproducibility Tests" in the advanced guide and "Testing" in the developer guide in the ChemTorch documentation).

I will add functionality to automatically enumerate all config components next week, so we don't need to do this manually in the future.

In the meantime, please reuse one of the existing experiment configs and override the model by adding `model=gotennet` to the overrides of your test set in test_registry.yaml.
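The requested override might look roughly like this (the exact schema of test_registry.yaml is assumed here, as is the reused experiment name; check the ChemTorch developer guide for the real structure):

```yaml
# test_registry.yaml (sketch; structure and experiment name are assumptions)
gotennet_integration:
  experiment: existing_experiment   # hypothetical: reuse an existing experiment config
  overrides:
    - model=gotennet                # swap in the new model config
```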

@tonyzamyatin
Collaborator

I will review the PR on Monday :)

@tonyzamyatin tonyzamyatin self-assigned this Dec 22, 2025
@tonyzamyatin tonyzamyatin added enhancement New feature or request new component Adding a new component to the ChemTorch library labels Dec 22, 2025
