Skip to content

Commit 1ee46f9

Browse files
committed
Test that each model can be converted to TorchScript.
1 parent 8f6dcae commit 1ee46f9

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

tests/test_models.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import torch
23

34
from emle.models import *
45

@@ -29,18 +30,26 @@ def test_emle(alpha_mode):
2930
"""
3031
Check that we can instantiate the default EMLE model.
3132
"""
33+
# Instantiate the default EMLE model.
3234
model = EMLE(alpha_mode=alpha_mode)
3335
assert model is not None
3436

37+
# Make sure the model can be converted to TorchScript.
38+
model = torch.jit.script(model)
39+
3540

3641
@pytest.mark.parametrize("alpha_mode", ["species", "reference"])
3742
def test_ani2x(alpha_mode):
3843
"""
3944
Check that we can instantiate the default ANI2xEMLE model.
4045
"""
46+
# Instantiate the ANI2xEMLE model.
4147
model = ANI2xEMLE(alpha_mode=alpha_mode)
4248
assert model is not None
4349

50+
# Make sure the model can be converted to TorchScript.
51+
model = torch.jit.script(model)
52+
4453
from torchani.models import ANI2x
4554

4655
# Try using an existing ANI2x model.
@@ -49,19 +58,24 @@ def test_ani2x(alpha_mode):
4958
# Create a new ANI2xEMLE model with the existing ANI2x model.
5059
model = ANI2xEMLE(alpha_mode=alpha_mode, ani2x_model=ani2x)
5160

61+
# Make sure the model can be converted to TorchScript.
62+
model = torch.jit.script(model)
63+
5264

5365
@pytest.mark.skipif(not has_nnpops, reason="NNPOps not installed")
5466
@pytest.mark.parametrize("alpha_mode", ["species", "reference"])
5567
def test_ani2x_nnpops(alpha_mode):
5668
"""
5769
Check that we can instantiate the default ANI2xEMLE model with NNPOps.
5870
"""
59-
import torch
60-
71+
# Instantiate the ANI2xEMLE model using NNPOps.
6172
atomic_numbers = torch.tensor([1, 6, 7, 8])
6273
model = ANI2xEMLE(alpha_mode=alpha_mode, atomic_numbers=atomic_numbers)
6374
assert model is not None
6475

76+
# Make sure the model can be converted to TorchScript.
77+
model = torch.jit.script(model)
78+
6579

6680
@pytest.mark.skipif(not has_mace, reason="mace-torch not installed")
6781
@pytest.mark.skipif(not has_e3nn, reason="e3nn not installed")
@@ -73,5 +87,9 @@ def test_mace(alpha_mode, mace_model):
7387
"""
7488
Check that we can instantiate the default MACE model.
7589
"""
90+
# Instantiate the MACEEMLE model.
7691
model = MACEEMLE(alpha_mode=alpha_mode)
7792
assert model is not None
93+
94+
# Make sure the model can be converted to TorchScript.
95+
model = torch.jit.script(model)

0 commit comments

Comments
 (0)