1
1
import pytest
2
+ import torch
2
3
3
4
from emle .models import *
4
5
@@ -29,18 +30,26 @@ def test_emle(alpha_mode):
29
30
"""
30
31
Check that we can instantiate the default EMLE model.
31
32
"""
33
+ # Instantiate the default EMLE model.
32
34
model = EMLE (alpha_mode = alpha_mode )
33
35
assert model is not None
34
36
37
+ # Make sure the model can be converted to TorchScript.
38
+ model = torch .jit .script (model )
39
+
35
40
36
41
@pytest .mark .parametrize ("alpha_mode" , ["species" , "reference" ])
37
42
def test_ani2x (alpha_mode ):
38
43
"""
39
44
Check that we can instantiate the default ANI2xEMLE model.
40
45
"""
46
+ # Instantiate the ANI2xEMLE model.
41
47
model = ANI2xEMLE (alpha_mode = alpha_mode )
42
48
assert model is not None
43
49
50
+ # Make sure the model can be converted to TorchScript.
51
+ model = torch .jit .script (model )
52
+
44
53
from torchani .models import ANI2x
45
54
46
55
# Try using an existing ANI2x model.
@@ -49,19 +58,24 @@ def test_ani2x(alpha_mode):
49
58
# Create a new ANI2xEMLE model with the existing ANI2x model.
50
59
model = ANI2xEMLE (alpha_mode = alpha_mode , ani2x_model = ani2x )
51
60
61
+ # Make sure the model can be converted to TorchScript.
62
+ model = torch .jit .script (model )
63
+
52
64
53
65
@pytest .mark .skipif (not has_nnpops , reason = "NNPOps not installed" )
54
66
@pytest .mark .parametrize ("alpha_mode" , ["species" , "reference" ])
55
67
def test_ani2x_nnpops (alpha_mode ):
56
68
"""
57
69
Check that we can instantiate the default ANI2xEMLE model with NNPOps.
58
70
"""
59
- import torch
60
-
71
+ # Instantiate the ANI2xEMLE model using NNPOps.
61
72
atomic_numbers = torch .tensor ([1 , 6 , 7 , 8 ])
62
73
model = ANI2xEMLE (alpha_mode = alpha_mode , atomic_numbers = atomic_numbers )
63
74
assert model is not None
64
75
76
+ # Make sure the model can be converted to TorchScript.
77
+ model = torch .jit .script (model )
78
+
65
79
66
80
@pytest .mark .skipif (not has_mace , reason = "mace-torch not installed" )
67
81
@pytest .mark .skipif (not has_e3nn , reason = "e3nn not installed" )
@@ -73,5 +87,9 @@ def test_mace(alpha_mode, mace_model):
73
87
"""
74
88
Check that we can instantiate the default MACE model.
75
89
"""
90
+ # Instantiate the MACEEMLE model.
76
91
model = MACEEMLE (alpha_mode = alpha_mode )
77
92
assert model is not None
93
+
94
+ # Make sure the model can be converted to TorchScript.
95
+ model = torch .jit .script (model )
0 commit comments