Conversation
Pull Request Overview
This PR adds the GotenNet model, a Graph Attention Transformer Network for atomic systems, along with reaction wrappers and configuration files. The implementation is ported from an external repository (https://github.com/sarpaykent/GotenNet).
Key changes:
- Core GotenNet architecture with GATA (Graph Attention Transformer Architecture) and EQFF (Equivariant Feed-Forward) layers
- Reaction prediction wrappers (GotenNetChemTorchR, GotennetReaction) for processing reactant and transition state structures
- Configuration files for model instantiation and experiments
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| src/chemtorch/components/model/gotennet.py | Complete GotenNet model implementation with attention mechanisms, message passing layers, and reaction wrappers |
| pyproject.toml | Added dependencies for e3nn (equivariant neural networks) and einops |
| conf/model/gotennet.yaml | Configuration for GotenNetChemTorchR model |
| conf/model/gotenreaction.yaml | Configuration for GotennetReaction model |
| conf/experiment/gotennet.yaml | Experiment setup for GotenNet |
| conf/experiment/gotenreaction.yaml | Experiment setup for reaction prediction |
```python
def get_split_sizes_from_lmax(lmax):
    """
    Return split sizes for torch.split based on lmax.

    Calculates the dimensions of spherical harmonic components for each
    angular momentum value from 1 to lmax.

    Args:
        lmax: Maximum angular momentum value.

    Returns:
        List[int]: List of split sizes for torch.split.
    """
    return [2 * l + 1 for l in range(1, lmax + 1)]
```
The function get_split_sizes_from_lmax is defined twice in this file (lines 25-38 and lines 1050-1064). The second definition at line 1050 has an additional start parameter with default value 1. This duplication is confusing and can lead to maintenance issues. Consider keeping only the more flexible version (with the start parameter) and updating all call sites accordingly, or rename one of them to reflect different purposes.
Suggested change:

```diff
-def get_split_sizes_from_lmax(lmax):
-    """
-    Return split sizes for torch.split based on lmax.
-    Calculates the dimensions of spherical harmonic components for each
-    angular momentum value from 1 to lmax.
-    Args:
-        lmax: Maximum angular momentum value.
-    Returns:
-        List[int]: List of split sizes for torch.split.
-    """
-    return [2 * l + 1 for l in range(1, lmax + 1)]
+# Removed duplicate definition of get_split_sizes_from_lmax. See later in file for the version with the 'start' parameter.
```
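If only the flexible variant is kept, the merged helper might look like the following sketch. The `start` parameter and its default of 1 are inferred from the review comment; the second definition itself is not shown in this diff, so this is an assumption, not the PR's actual code.

```python
def get_split_sizes_from_lmax(lmax, start=1):
    """Return split sizes for torch.split based on lmax.

    Each angular momentum degree l contributes a spherical harmonic
    block of size 2*l + 1, for l from `start` to `lmax` inclusive.
    """
    return [2 * l + 1 for l in range(start, lmax + 1)]
```

With `start=1` this reproduces the first definition's behavior, e.g. `get_split_sizes_from_lmax(2)` returns `[3, 5]`, while `start=0` also includes the scalar (l=0) block.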
| print(f"Exponent p={p} has to be >= 2.") | ||
| print("Exiting code.") | ||
|
|
||
| exit() | ||
|
|
Using print() statements and exit() for error handling is inappropriate in a library. Replace with proper exception raising (e.g., `raise ValueError(f"Exponent p={p} has to be >= 2.")`) so calling code can handle the error appropriately. The comment on line 117 also acknowledges this should be changed to use a logger.
| print(f"Exponent p={p} has to be >= 2.") | |
| print("Exiting code.") | |
| exit() | |
| raise ValueError(f"Exponent p={p} has to be >= 2.") |
| raise ValueError(f"Unknown initialization {init_str}") | ||
|
|
||
|
|
||
| # train.py -m label=mu,alpha,homo,lumo,r2,zpve,U0,U,H,G,Cv name='${label_str}_int6_glo-ort_3090' hydra.sweeper.n_jobs=1 model.representation.n_interactions=6 model.representation.weight_init=glo_orthogonal |
This commented-out training command appears to be leftover development code and should be removed. It doesn't provide value in the production codebase and clutters the code.
Suggested change:

```diff
-# train.py -m label=mu,alpha,homo,lumo,r2,zpve,U0,U,H,G,Cv name='${label_str}_int6_glo-ort_3090' hydra.sweeper.n_jobs=1 model.representation.n_interactions=6 model.representation.weight_init=glo_orthogonal
```
```python
if inspect.isclass(activation):
    self.activation = activation()
self.activation = activation
```
The activation assignment on line 506 unconditionally overwrites the instantiated activation from line 505. This means if activation is a class, line 505 instantiates it correctly, but line 506 immediately overwrites self.activation with the class itself (not an instance). Line 506 should be inside an else block, or the logic should be: self.activation = activation() if inspect.isclass(activation) else activation.
Suggested change:

```diff
-if inspect.isclass(activation):
-    self.activation = activation()
-self.activation = activation
+self.activation = activation() if inspect.isclass(activation) else activation
```
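The one-liner's behavior for classes versus instances can be checked in isolation. A self-contained sketch with a stand-in activation (`SiLUStandIn` and `resolve_activation` are illustrative names, not code from the PR):

```python
import inspect

class SiLUStandIn:  # stand-in for an activation class such as torch.nn.SiLU
    def __call__(self, x):
        return x

def resolve_activation(activation):
    # Instantiate only when a class is passed; pass instances through unchanged.
    return activation() if inspect.isclass(activation) else activation
```

Passing the class yields a fresh instance, while an already-constructed instance is returned as-is, which is exactly what the buggy unconditional assignment failed to do.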
```python
    # keep as is
    edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
else:
    edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
```
Both branches of the conditional compute edge_vec identically, making the direction parameter ineffective. The else branch should likely compute the reverse direction: edge_vec = pos[edge_index[1]] - pos[edge_index[0]] to match the 'target_to_source' direction.
Suggested change:

```diff
-    edge_vec = pos[edge_index[0]] - pos[edge_index[1]]
+    edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
```
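The intended effect of a direction flag on the edge vectors can be sketched in plain Python (no torch); the function and flag names here are illustrative, not the PR's API:

```python
def edge_vectors(pos, edge_index, direction="source_to_target"):
    src, dst = edge_index
    if direction == "source_to_target":
        # pos[src] - pos[dst]: vector from the target position toward the source
        return [[pa - pb for pa, pb in zip(pos[s], pos[d])]
                for s, d in zip(src, dst)]
    # 'target_to_source' flips the sign of every edge vector
    return [[pb - pa for pa, pb in zip(pos[s], pos[d])]
            for s, d in zip(src, dst)]
```

With identical branches, as in the current code, the two settings would be indistinguishable; flipping the subtraction order is what makes the flag meaningful.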
```python
self.reset_parameters()

@classmethod
def load_from_checkpoint(cls, checkpoint_path: str, device="cpu") -> None:
```
The return type annotation is None, but the method actually returns a GotenNet instance on line 1962 (return gotennet). The return type should be changed to -> 'GotenNet' or -> Self (if using Python 3.11+).
Suggested change:

```diff
-def load_from_checkpoint(cls, checkpoint_path: str, device="cpu") -> None:
+def load_from_checkpoint(cls, checkpoint_path: str, device="cpu") -> 'GotenNet':
```
```python
h.unsqueeze_(1)
t_ij = t_ij_init
for _i, (gata, eqff) in enumerate(
    zip(self.gata_list, self.eqff_list, strict=False)
```
The strict=False parameter is explicitly set, but both gata_list and eqff_list are constructed with the same length (range(self.n_interactions) on lines 1900 and 1914). Using strict=True would be safer as it ensures the lists remain synchronized and would catch any future bugs where they become mismatched.
Suggested change:

```diff
-    zip(self.gata_list, self.eqff_list, strict=False)
+    zip(self.gata_list, self.eqff_list, strict=True)
```
| assert "hyper_parameters" in checkpoint, ( | ||
| "Checkpoint must contain 'hyper_parameters' key." | ||
| ) | ||
| hyper_parameters = checkpoint["hyper_parameters"] | ||
| assert "representation" in hyper_parameters, ( | ||
| "Hyperparameters must contain 'representation' key." | ||
| ) | ||
| representation_config = hyper_parameters["representation"] | ||
| _ = representation_config.pop("_target_", None) | ||
|
|
||
| assert "state_dict" in checkpoint, "Checkpoint must contain 'state_dict' key." |
Using assert statements for input validation in production code is not recommended, as assertions can be disabled with Python's -O flag. Replace with explicit if checks that raise appropriate exceptions (e.g., KeyError or ValueError) to ensure validation always occurs.
| assert "hyper_parameters" in checkpoint, ( | |
| "Checkpoint must contain 'hyper_parameters' key." | |
| ) | |
| hyper_parameters = checkpoint["hyper_parameters"] | |
| assert "representation" in hyper_parameters, ( | |
| "Hyperparameters must contain 'representation' key." | |
| ) | |
| representation_config = hyper_parameters["representation"] | |
| _ = representation_config.pop("_target_", None) | |
| assert "state_dict" in checkpoint, "Checkpoint must contain 'state_dict' key." | |
| if "hyper_parameters" not in checkpoint: | |
| raise KeyError("Checkpoint must contain 'hyper_parameters' key.") | |
| hyper_parameters = checkpoint["hyper_parameters"] | |
| if "representation" not in hyper_parameters: | |
| raise KeyError("Hyperparameters must contain 'representation' key.") | |
| representation_config = hyper_parameters["representation"] | |
| _ = representation_config.pop("_target_", None) | |
| if "state_dict" not in checkpoint: | |
| raise KeyError("Checkpoint must contain 'state_dict' key.") |
```python
def forward(self, z, pos, batch) -> Tuple[Tensor, Tensor]:
    """
    Compute atomic representations/embeddings.

    Args:
        inputs: Dictionary of input tensors containing atomic_numbers, pos, batch,
            edge_index, r_ij, and dir_ij. Shape information:
            - atomic_numbers: [num_nodes]
            - pos: [num_nodes, 3]
            - batch: [num_nodes]
            - edge_index: [2, num_edges]

    Returns:
        Tuple containing:
        - Atomic representation [num_nodes, hidden_dims]
        - High-degree steerable features [num_nodes, (L_max ** 2) - 1, hidden_dims]
    """
    edge_index, edge_diff, edge_vec = self.distance(pos, batch)
```
This method requires 4 positional arguments, whereas overridden GotenNet.forward requires 5.
Suggested change:

```diff
-def forward(self, z, pos, batch) -> Tuple[Tensor, Tensor]:
-    """
-    Compute atomic representations/embeddings.
-    Args:
-        inputs: Dictionary of input tensors containing atomic_numbers, pos, batch,
-            edge_index, r_ij, and dir_ij. Shape information:
-            - atomic_numbers: [num_nodes]
-            - pos: [num_nodes, 3]
-            - batch: [num_nodes]
-            - edge_index: [2, num_edges]
-    Returns:
-        Tuple containing:
-        - Atomic representation [num_nodes, hidden_dims]
-        - High-degree steerable features [num_nodes, (L_max ** 2) - 1, hidden_dims]
-    """
-    edge_index, edge_diff, edge_vec = self.distance(pos, batch)
+def forward(self, z, pos, batch, edge_index=None, edge_diff=None, edge_vec=None) -> Tuple[Tensor, Tensor]:
+    """
+    Compute atomic representations/embeddings.
+    Args:
+        z: Atomic numbers or input features [num_nodes]
+        pos: Positions [num_nodes, 3]
+        batch: Batch indices [num_nodes]
+        edge_index: [2, num_edges] (optional)
+        edge_diff: [num_edges, 1] (optional)
+        edge_vec: [num_edges, 3] (optional)
+    Returns:
+        Tuple containing:
+        - Atomic representation [num_nodes, hidden_dims]
+        - High-degree steerable features [num_nodes, (L_max ** 2) - 1, hidden_dims]
+    """
+    if edge_index is None or edge_diff is None or edge_vec is None:
+        edge_index, edge_diff, edge_vec = self.distance(pos, batch)
```
```python
import math
import inspect
from functools import partial
from typing import Callable, List, Mapping, Optional, Tuple, Union
```
Import of 'Mapping' is not used.
Suggested change:

```diff
-from typing import Callable, List, Mapping, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
```
|
Hi, please add an integration test (see "Reproducibility Tests" in the advanced guide and "Testing" in the developer guide in the ChemTorch documentation). I will add functionality to automatically enumerate all config components next week so we don't need to do this manually in the future. In the meantime, please reuse one of the existing experiment configs and override the model by adding
|
I will review the PR on Monday :)
Summary
Adding the GotenNet model and reaction wrappers, ported from https://github.com/sarpaykent/GotenNet
Type of change
Changes
Breaking changes
How I tested this
Checklist
Reviewer notes (optional)
Anything that would help reviewers focus (areas of risk, follow-ups, docs to check).