Skip to content

Commit

Permalink
Test that each model can be converted to TorchScript.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Aug 14, 2024
1 parent 8f6dcae commit 1ee46f9
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import torch

from emle.models import *

Expand Down Expand Up @@ -29,18 +30,26 @@ def test_emle(alpha_mode):
"""
Check that we can instantiate the default EMLE model.
"""
# Instantiate the default EMLE model.
model = EMLE(alpha_mode=alpha_mode)
assert model is not None

# Make sure the model can be converted to TorchScript.
model = torch.jit.script(model)


@pytest.mark.parametrize("alpha_mode", ["species", "reference"])
def test_ani2x(alpha_mode):
"""
Check that we can instantiate the default ANI2xEMLE model.
"""
# Instantiate the ANI2xEMLE model.
model = ANI2xEMLE(alpha_mode=alpha_mode)
assert model is not None

# Make sure the model can be converted to TorchScript.
model = torch.jit.script(model)

from torchani.models import ANI2x

# Try using an existing ANI2x model.
Expand All @@ -49,19 +58,24 @@ def test_ani2x(alpha_mode):
# Create a new ANI2xEMLE model with the existing ANI2x model.
model = ANI2xEMLE(alpha_mode=alpha_mode, ani2x_model=ani2x)

# Make sure the model can be converted to TorchScript.
model = torch.jit.script(model)


@pytest.mark.skipif(not has_nnpops, reason="NNPOps not installed")
@pytest.mark.parametrize("alpha_mode", ["species", "reference"])
def test_ani2x_nnpops(alpha_mode):
"""
Check that we can instantiate the default ANI2xEMLE model with NNPOps.
"""
import torch

# Instantiate the ANI2xEMLE model using NNPOps.
atomic_numbers = torch.tensor([1, 6, 7, 8])
model = ANI2xEMLE(alpha_mode=alpha_mode, atomic_numbers=atomic_numbers)
assert model is not None

# Make sure the model can be converted to TorchScript.
model = torch.jit.script(model)


@pytest.mark.skipif(not has_mace, reason="mace-torch not installed")
@pytest.mark.skipif(not has_e3nn, reason="e3nn not installed")
Expand All @@ -73,5 +87,9 @@ def test_mace(alpha_mode, mace_model):
"""
Check that we can instantiate the default MACE model.
"""
# Instantiate the MACEEMLE model.
model = MACEEMLE(alpha_mode=alpha_mode)
assert model is not None

# Make sure the model can be converted to TorchScript.
model = torch.jit.script(model)

0 comments on commit 1ee46f9

Please sign in to comment.