98 changes: 51 additions & 47 deletions tests/py/dynamo/models/test_export_serde.py
@@ -17,12 +17,12 @@
if importlib.util.find_spec("torchvision"):
import torchvision.models as models

trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")



@pytest.mark.unit
@pytest.mark.critical
def test_base_full_compile(ir):
def test_base_full_compile(ir, tmp_path):
"""
This tests export serde functionality on a base model
which is fully TRT convertible
@@ -56,9 +56,9 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
# Check Pyt and TRT exported program outputs
cos_sim = cosine_similarity(model(input), trt_module(input)[0])
assertions.assertTrue(
@@ -76,7 +76,7 @@ def forward(self, x):

@pytest.mark.unit
@pytest.mark.critical
def test_base_full_compile_multiple_outputs(ir):
def test_base_full_compile_multiple_outputs(ir, tmp_path):
"""
This tests export serde functionality on a base model
with multiple outputs which is fully TRT convertible
@@ -111,9 +111,9 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
# Check Pyt and TRT exported program outputs
outputs_pyt = model(input)
outputs_trt = trt_module(input)
@@ -136,7 +136,7 @@ def forward(self, x):

@pytest.mark.unit
@pytest.mark.critical
def test_no_compile(ir):
def test_no_compile(ir, tmp_path):
"""
This tests export serde functionality on a model
which won't convert to TRT because of min_block_size=5 constraint
@@ -170,9 +170,9 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
# Check Pyt and TRT exported program outputs
outputs_pyt = model(input)
outputs_trt = trt_module(input)
@@ -195,7 +195,7 @@ def forward(self, x):

@pytest.mark.unit
@pytest.mark.critical
def test_hybrid_relu_fallback(ir):
def test_hybrid_relu_fallback(ir, tmp_path):
"""
This tests export save and load functionality on a hybrid
model with Pytorch and TRT segments. Relu (unweighted) layer is forced to
@@ -232,9 +232,9 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
for idx in range(len(outputs_pyt)):
@@ -258,7 +258,7 @@ def forward(self, x):
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
def test_resnet18(ir):
def test_resnet18(ir, tmp_path):
"""
This tests export save and load functionality on Resnet18 model
"""
@@ -279,9 +279,9 @@ def test_resnet18(ir):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
@@ -303,7 +303,7 @@ def test_resnet18(ir):
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
def test_resnet18_cpu_offload(ir):
def test_resnet18_cpu_offload(ir, tmp_path):
"""
This tests export save and load functionality on Resnet18 model
"""
@@ -331,9 +331,9 @@ def test_resnet18_cpu_offload(ir):
msg="Model should be offloaded to CPU",
)
model.cuda()
torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
@@ -355,7 +355,7 @@ def test_resnet18_cpu_offload(ir):
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
def test_resnet18_dynamic(ir):
def test_resnet18_dynamic(ir, tmp_path):
"""
This tests export save and load functionality on Resnet18 model
"""
@@ -380,9 +380,10 @@ def test_resnet18_dynamic(ir):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, retrace=False)
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
# TODO: Enable this once serialization issues are fixed
# deser_trt_module = torchtrt.load(trt_ep_path).module()
# deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
@@ -395,7 +396,7 @@ def test_resnet18_dynamic(ir):
@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
def test_resnet18_torch_exec_ops_serde(ir):
def test_resnet18_torch_exec_ops_serde(ir, tmp_path):
"""
This tests export save and load functionality on Resnet18 model
"""
@@ -413,8 +414,9 @@ def test_resnet18_torch_exec_ops_serde(ir):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, retrace=False)
deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = deser_trt_module(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
@@ -426,7 +428,7 @@ def test_resnet18_torch_exec_ops_serde(ir):

@pytest.mark.unit
@pytest.mark.critical
def test_hybrid_conv_fallback(ir):
def test_hybrid_conv_fallback(ir, tmp_path):
"""
This tests export save and load functionality on a hybrid
model where a conv (a weighted layer) has been forced to fall back to PyTorch.
@@ -463,9 +465,9 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)

torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)

@@ -487,7 +489,7 @@ def forward(self, x):

@pytest.mark.unit
@pytest.mark.critical
def test_hybrid_conv_fallback_cpu_offload(ir):
def test_hybrid_conv_fallback_cpu_offload(ir, tmp_path):
"""
This tests export save and load functionality on a hybrid
model where a conv (a weighted layer) has been forced to fall back to PyTorch.
@@ -525,9 +527,9 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
model.cuda()
torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)

@@ -549,7 +551,7 @@ def forward(self, x):

@pytest.mark.unit
@pytest.mark.critical
def test_arange_export(ir):
def test_arange_export(ir, tmp_path):
"""
This tests export save and load functionality on an arange static graph
Here the arange output is a static constant (which is registered as input to the graph)
@@ -584,9 +586,9 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)

torchtrt.save(trt_module, trt_ep_path, retrace=False)

deser_trt_module = torchtrt.load(trt_ep_path).module()
trt_ep_path = tmp_path / "trt.ep"
torchtrt.save(trt_module, str(trt_ep_path), retrace=False)
deser_trt_module = torchtrt.load(str(trt_ep_path)).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)

@@ -607,7 +609,7 @@ def forward(self, x):


@pytest.mark.unit
def test_save_load_ts(ir):
def test_save_load_ts(ir, tmp_path):
"""
This tests save/load API on Torchscript format (model still compiled using dynamo workflow)
"""
@@ -640,10 +642,12 @@ def forward(self, x):
msg=f"test_save_load_ts output type does not match with torch.fx.GraphModule",
)
outputs_trt = trt_gm(input)
# Save it as torchscript representation
torchtrt.save(trt_gm, "./trt.ts", output_format="torchscript", inputs=[input])

trt_ts_module = torchtrt.load("./trt.ts")
# Save it as torchscript representation in the test's tmp_path
trt_ts_path = tmp_path / "trt.ts"
torchtrt.save(trt_gm, str(trt_ts_path), output_format="torchscript", inputs=[input])

trt_ts_module = torchtrt.load(str(trt_ts_path))
outputs_trt_deser = trt_ts_module(input)

cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)
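For context on the pattern this diff adopts: pytest's built-in tmp_path fixture gives each test its own pathlib.Path temporary directory, so artifacts such as "trt.ep" written by one test cannot collide with another test's output the way a shared tempfile.gettempdir() location can. A minimal standalone sketch of the fixture, separate from the PR itself (the test name and payload below are illustrative, not taken from this file):

def test_artifact_roundtrip(tmp_path):
    # tmp_path is a unique pathlib.Path that pytest creates for this test run
    artifact = tmp_path / "trt.ep"
    artifact.write_text("placeholder for a serialized engine")
    # the file lives only in this test's private directory, so parallel or
    # repeated runs never step on each other's saved artifacts
    assert artifact.read_text() == "placeholder for a serialized engine"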