diff --git a/tests/example.ckpt b/tests/example.ckpt index 80f914696..886cd39cf 100644 Binary files a/tests/example.ckpt and b/tests/example.ckpt differ diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 96d93d852..b4c5d0a3f 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import torch -from torch.testing import assert_allclose +from torch.testing import assert_close import pytest from os.path import dirname, join from torchmdnet.calculators import External @@ -11,6 +11,12 @@ from utils import create_example_batch +# Set relative and absolute tolerance values for float32 precision +# The original test used assert_allclose, which is now deprecated. +# assert_close is used instead, with default tolerances of 1e-5 (rtol) and 1.3e-6 (atol) for torch.float32. +# Here, we manually set rtol and atol to match the original test's tolerances. +rtol = 1e-4 +atol = 1e-5 @pytest.mark.parametrize("box", [None, torch.eye(3)]) @pytest.mark.parametrize("use_cuda_graphs", [True, False]) @@ -39,24 +45,24 @@ def test_compare_forward(box, use_cuda_graphs): "precision": 32, } device = "cpu" if not use_cuda_graphs else "cuda" - model = create_model(args).to(device=device) + c_model = load_model(checkpoint).to(device=device) + g_model = load_model(checkpoint, check_errors=not use_cuda_graphs, static_shapes=use_cuda_graphs).to(device=device) z, pos, _ = create_example_batch(multiple_batches=False) z = z.to(device) pos = pos.to(device) - calc = External(checkpoint, z.unsqueeze(0), use_cuda_graph=False, device=device) + calc = External(c_model, z.unsqueeze(0), use_cuda_graph=False, device=device) calc_graph = External( - checkpoint, z.unsqueeze(0), use_cuda_graph=use_cuda_graphs, device=device + g_model, z.unsqueeze(0), use_cuda_graph=use_cuda_graphs, device=device ) - calc.model = model - calc_graph.model = model + if box is not None: box = (box * 2 * args["cutoff_upper"]).unsqueeze(0) + for _ in range(10): e_calc, f_calc = calc.calculate(pos, box) e_pred, f_pred = calc_graph.calculate(pos, box) - assert_allclose(e_calc, e_pred) - assert_allclose(f_calc, f_pred) - + assert_close(e_calc, e_pred, rtol=rtol, atol=atol) + assert_close(f_calc, f_pred, rtol=rtol, atol=atol) def test_compare_forward_multiple(): checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt") @@ -72,5 +78,5 @@ def test_compare_forward_multiple(): torch.cat([torch.zeros(len(z1)), torch.ones(len(z2))]).long(), ) - assert_allclose(e_calc, e_pred) - assert_allclose(f_calc, f_pred.view(-1, len(z1), 3)) + assert_close(e_calc, e_pred, rtol=rtol, atol=atol) + assert_close(f_calc, f_pred.view(-1, len(z1), 3), rtol=rtol, atol=atol) diff --git a/tests/test_examples.py b/tests/test_examples.py index 8cd5155e6..c772776a2 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -27,4 +27,4 @@ def test_example_yamls(fname): z, pos, batch = create_example_batch() model(z, pos, batch) - model(z, pos, batch, q=None, s=None) + model(z, pos, batch, q=None, s=None, extra_args=None) diff --git a/tests/test_model.py b/tests/test_model.py index f606559ef..cde034a08 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -19,16 +19,21 @@ @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("use_batch", [True, False]) @mark.parametrize("explicit_q_s", [True, False]) +@mark.parametrize("explicit_extra_args", [True, False]) @mark.parametrize("precision", [32, 64]) -def test_forward(model_name, use_batch, explicit_q_s, precision): +def test_forward(model_name, use_batch, explicit_q_s, explicit_extra_args, precision): z, pos, batch = create_example_batch() pos = pos.to(dtype=dtype_mapping[precision]) model = create_model( load_example_args(model_name, prior_model=None, precision=precision) ) batch = batch if use_batch else None - if explicit_q_s: + if explicit_q_s and explicit_extra_args: + model(z, pos, batch=batch, q=None, s=None, extra_args=None) + elif explicit_q_s: model(z, pos, batch=batch, q=None, s=None) + elif explicit_extra_args: + model(z, pos, batch=batch, extra_args=None) else: model(z, pos, batch=batch) diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index c647b7d2a..8bbaac2d6 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -49,6 +49,10 @@ def __init__(self, filename, dataset_preload_limit=1024, **kwargs): ("pos", "pos", torch.float32), ("z", "types", torch.long), ] + if "charge" in group: + self.fields.append(("q", "charge", torch.float32)) + if "spin" in group: + self.fields.append(("s", "spin", torch.float32)) if "energy" in group: self.fields.append(("y", "energy", torch.float32)) if "forces" in group: diff --git a/torchmdnet/datasets/memdataset.py b/torchmdnet/datasets/memdataset.py index c84d65590..f25f5a825 100644 --- a/torchmdnet/datasets/memdataset.py +++ b/torchmdnet/datasets/memdataset.py @@ -269,7 +269,7 @@ def get(self, idx): if "q" in self.properties: props["q"] = pt.tensor(self.mmaps["q"][idx], dtype=pt.long) if "pq" in self.properties: - props["pq"] = pt.tensor(self.mmaps["pq"][atoms]) + props["partial_charges"] = pt.tensor(self.mmaps["pq"][atoms]) if "dp" in self.properties: props["dp"] = pt.tensor(self.mmaps["dp"][idx]) # if "mol_idx" in self.properties: diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index b7c8398cc..29ef7a232 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -460,7 +460,8 @@ def forward( pos.requires_grad_(True) # run the potentially wrapped representation model x, v, z, pos, batch = self.representation_model( - z, pos, batch, box=box, q=q, s=s + z, pos, batch, box=box, q=q, s=s, extra_args=extra_args + ) # apply the output network x = self.output_model.pre_reduce(x, v, z, pos, batch) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index 9aec148ca..7bae31d24 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) import torch -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict from torch import Tensor, nn from torchmdnet.models.utils import ( CosineCutoff, @@ -61,6 +61,32 @@ def tensor_norm(tensor): """Computes Frobenius norm.""" return (tensor**2).sum((-2, -1)) +def process_additional_labels(z: torch.Tensor, q: Optional[torch.Tensor], s: Optional[torch.Tensor], extra_args: Optional[Dict[str, torch.Tensor]], batch: torch.Tensor) -> torch.Tensor: + """ + Process additional labels for the model. This function assigns atom-wise properties based on the provided + molecule-wise properties or extra arguments. + Total charge q and spin s are molecule-wise properties. We transform it into an atom-wise property, with all atoms + belonging to the same molecule being assigned the same charge q or spin s. + + Args: + batch (Tensor): Batch tensor indicating the molecule each atom belongs to. + z (Tensor): Atomic numbers tensor. + q (Optional[Tensor]): Total charge tensor for each molecule. + s (Optional[Tensor]): Spin tensor for each molecule. + extra_args (Optional[Dict[str, Tensor]]): Dictionary containing additional properties. + + Returns: + Tensor: Atom-wise property tensor already scaled by 0.1. + """ + if q is not None: + t = q[batch] + elif s is not None: + t = s[batch] + elif extra_args is not None and 'partial_charges' in extra_args: + t = extra_args['partial_charges'] + else: + t = torch.zeros_like(z, device=z.device, dtype=z.dtype) + return t class TensorNet(nn.Module): r"""TensorNet's architecture. From @@ -226,6 +252,7 @@ def forward( box: Optional[Tensor] = None, q: Optional[Tensor] = None, s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: # Obtain graph, with distances and relative position vectors edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) @@ -234,16 +261,13 @@ def forward( edge_vec is not None ), "Distance module did not return directional information" # Distance module returns -1 for non-existing edges, to avoid having to resize the tensors when we want to ensure static shapes (for CUDA graphs) we make all non-existing edges pertain to a ghost atom - # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q - if q is None: - q = torch.zeros_like(z, device=z.device, dtype=z.dtype) - else: - q = q[batch] + + t = process_additional_labels(z, q, s, extra_args, batch) zp = z if self.static_shapes: mask = (edge_index[0] < 0).unsqueeze(0).expand_as(edge_index) zp = torch.cat((z, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) - q = torch.cat((q, torch.zeros(1, device=q.device, dtype=q.dtype)), dim=0) + t = torch.cat((t, torch.zeros(1, device=z.device, dtype=z.dtype)), dim=0) # I trick the model into thinking that the masked edges pertain to the extra atom # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs edge_index = edge_index.masked_fill(mask, z.shape[0]) @@ -258,7 +282,7 @@ def forward( edge_vec = edge_vec / edge_weight.masked_fill(mask, 1).unsqueeze(1) X = self.tensor_embedding(zp, edge_index, edge_weight, edge_vec, edge_attr) for layer in self.layers: - X = layer(X, edge_index, edge_weight, edge_attr, q) + X = layer(X, edge_index, edge_weight, edge_attr, t) I, A, S = decompose_tensor(X) x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1) x = self.out_norm(x) @@ -454,7 +478,7 @@ def forward( edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor, - q: Tensor, + t: Tensor, ) -> Tensor: C = self.cutoff(edge_weight) for linear_scalar in self.linears_scalar: @@ -481,7 +505,7 @@ def forward( if self.equivariance_invariance_group == "O(3)": A = torch.matmul(msg, Y) B = torch.matmul(Y, msg) - I, A, S = decompose_tensor((1 + 0.1 * q[..., None, None, None]) * (A + B)) + I, A, S = decompose_tensor((1 + 0.1 * t[..., None, None, None]) * (A + B)) if self.equivariance_invariance_group == "SO(3)": B = torch.matmul(Y, msg) I, A, S = decompose_tensor(2 * B) @@ -491,5 +515,5 @@ def forward( A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) dX = I + A + S - X = X + dX + (1 + 0.1 * q[..., None, None, None]) * torch.matrix_power(dX, 2) + X = X + dX + (1 + 0.1 * t[..., None, None, None]) * torch.matrix_power(dX, 2) return X diff --git a/torchmdnet/models/torchmd_et.py b/torchmdnet/models/torchmd_et.py index 5ff168d54..4b84b16b8 100644 --- a/torchmdnet/models/torchmd_et.py +++ b/torchmdnet/models/torchmd_et.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -196,6 +196,7 @@ def forward( box: Optional[Tensor] = None, q: Optional[Tensor] = None, s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/torchmd_gn.py b/torchmdnet/models/torchmd_gn.py index 31d68ae03..97939c67e 100644 --- a/torchmdnet/models/torchmd_gn.py +++ b/torchmdnet/models/torchmd_gn.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -196,8 +196,9 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - s: Optional[Tensor] = None, q: Optional[Tensor] = None, + s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/torchmd_t.py b/torchmdnet/models/torchmd_t.py index c11efc080..1360ae432 100644 --- a/torchmdnet/models/torchmd_t.py +++ b/torchmdnet/models/torchmd_t.py @@ -2,7 +2,7 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict import torch from torch import Tensor, nn from torchmdnet.models.utils import ( @@ -190,8 +190,9 @@ def forward( pos: Tensor, batch: Tensor, box: Optional[Tensor] = None, - s: Optional[Tensor] = None, q: Optional[Tensor] = None, + s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: x = self.embedding(z) diff --git a/torchmdnet/models/wrappers.py b/torchmdnet/models/wrappers.py index 444805e06..99c838527 100644 --- a/torchmdnet/models/wrappers.py +++ b/torchmdnet/models/wrappers.py @@ -3,7 +3,7 @@ # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) from abc import abstractmethod, ABCMeta -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict from torch import nn, Tensor @@ -45,8 +45,9 @@ def forward( batch: Tensor, q: Optional[Tensor] = None, s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]: - x, v, z, pos, batch = self.model(z, pos, batch=batch, q=q, s=s) + x, v, z, pos, batch = self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args) n_samples = len(batch.unique()) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 8042df514..f8453af1b 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -59,6 +59,14 @@ def __init__(self, loss_fn, extra_args=None): def __call__(self, x, batch): return self.loss_fn(x, batch, **self.extra_args) +def process_extra_args(extra_args, use_partial_charges): + ''' Process extra arguments to remove those that are not needed by the model, before passing them to the forward function.''' + for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"): + if a in extra_args: + del extra_args[a] + if not use_partial_charges and 'partial_charges' in extra_args: + del extra_args['partial_charges'] + return extra_args class LNNP(LightningModule): """ @@ -77,6 +85,13 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["charge"] = False if "spin" not in hparams: hparams["spin"] = False + if "partial_charges" not in hparams: + hparams["partial_charges"] = False + # Ensure only one of charge, partial_charges, and spin can be True, otherwise raise a ValueError + if sum([hparams["charge"], hparams["partial_charges"], hparams["spin"]]) > 1: + raise ValueError( + "Only one of 'charge', 'partial_charges', and 'spin' can be True." + ) if "train_loss" not in hparams: hparams["train_loss"] = "mse_loss" if "train_loss_arg" not in hparams: @@ -184,9 +199,7 @@ def predict_step(self, batch, batch_idx): with torch.set_grad_enabled(self.hparams.derivative): extra_args = batch.to_dict() - for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"): - if a in extra_args: - del extra_args[a] + extra_args = process_extra_args(extra_args, self.hparams.partial_charges) return self( batch.z, batch.pos, @@ -253,9 +266,7 @@ def step(self, batch, loss_fn_list, stage): batch = self.data_transform(batch) with torch.set_grad_enabled(stage == "train" or self.hparams.derivative): extra_args = batch.to_dict() - for a in ("y", "neg_dy", "z", "pos", "batch", "box", "q", "s"): - if a in extra_args: - del extra_args[a] + extra_args = process_extra_args(extra_args, self.hparams.partial_charges) # TODO: the model doesn't necessarily need to return a derivative once # Union typing works under TorchScript (https://github.com/pytorch/pytorch/pull/53180) y, neg_dy = self( diff --git a/torchmdnet/optimize.py b/torchmdnet/optimize.py index 0c7f56513..a59705e50 100644 --- a/torchmdnet/optimize.py +++ b/torchmdnet/optimize.py @@ -2,7 +2,8 @@ # Distributed under the MIT License. # (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict +from torch import Tensor import torch as pt from NNPOps.CFConv import CFConv from NNPOps.CFConvNeighbors import CFConvNeighbors @@ -58,6 +59,7 @@ def forward( box: Optional[pt.Tensor] = None, q: Optional[pt.Tensor] = None, s: Optional[pt.Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, ) -> Tuple[pt.Tensor, Optional[pt.Tensor], pt.Tensor, pt.Tensor, pt.Tensor]: assert pt.all(batch == 0) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 96b3a8197..f56c3fb5a 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -85,6 +85,7 @@ def get_argparse(): # architectural args parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge. Set this to True if your dataset contains charges and you want them passed down to the model.') parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state. Set this to True if your dataset contains spin states and you want them passed down to the model.') + parser.add_argument('--partial-charges', type=bool, default=False, help='Model needs partial charges. Set this to True if your dataset contains partial charges and you want them passed down to the model.') parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension') parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model') parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')