From fd1c21bd3ea2bd4ac3914dfde958774ba498f514 Mon Sep 17 00:00:00 2001 From: leimao Date: Wed, 26 Nov 2025 20:15:58 -0800 Subject: [PATCH] Allow Model Export Test Parallelism --- tests/py/dynamo/models/test_export_serde.py | 98 +++++++++++---------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index c5b007e34b..e4f1e63a99 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -17,12 +17,12 @@ if importlib.util.find_spec("torchvision"): import torchvision.models as models -trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep") + @pytest.mark.unit @pytest.mark.critical -def test_base_full_compile(ir): +def test_base_full_compile(ir, tmp_path): """ This tests export serde functionality on a base model which is fully TRT convertible @@ -56,9 +56,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() # Check Pyt and TRT exported program outputs cos_sim = cosine_similarity(model(input), trt_module(input)[0]) assertions.assertTrue( @@ -76,7 +76,7 @@ def forward(self, x): @pytest.mark.unit @pytest.mark.critical -def test_base_full_compile_multiple_outputs(ir): +def test_base_full_compile_multiple_outputs(ir, tmp_path): """ This tests export serde functionality on a base model with multiple outputs which is fully TRT convertible @@ -111,9 +111,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -136,7 +136,7 @@ def forward(self, x): @pytest.mark.unit @pytest.mark.critical -def test_no_compile(ir): +def test_no_compile(ir, tmp_path): """ This tests export serde functionality on a model which won't convert to TRT because of min_block_size=5 constraint @@ -170,9 +170,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -195,7 +195,7 @@ def forward(self, x): @pytest.mark.unit @pytest.mark.critical -def test_hybrid_relu_fallback(ir): +def test_hybrid_relu_fallback(ir, tmp_path): """ This tests export save and load functionality on a hybrid model with Pytorch and TRT segments. Relu (unweighted) layer is forced to @@ -232,9 +232,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = model(input) outputs_trt = trt_module(input) for idx in range(len(outputs_pyt)): @@ -258,7 +258,7 @@ def forward(self, x): not importlib.util.find_spec("torchvision"), "torchvision is not installed", ) -def test_resnet18(ir): +def test_resnet18(ir, tmp_path): """ This tests export save and load functionality on Resnet18 model """ @@ -279,9 +279,9 @@ def test_resnet18(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = model(input) outputs_trt = trt_module(input) cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) @@ -303,7 +303,7 @@ def test_resnet18(ir): not importlib.util.find_spec("torchvision"), "torchvision is not installed", ) -def test_resnet18_cpu_offload(ir): +def test_resnet18_cpu_offload(ir, tmp_path): """ This tests export save and load functionality on Resnet18 model """ @@ -331,9 +331,9 @@ def test_resnet18_cpu_offload(ir): msg="Model should be offloaded to CPU", ) model.cuda() - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = model(input) outputs_trt = trt_module(input) cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) @@ -355,7 +355,7 @@ def test_resnet18_cpu_offload(ir): not importlib.util.find_spec("torchvision"), "torchvision is not installed", ) -def test_resnet18_dynamic(ir): +def test_resnet18_dynamic(ir, tmp_path): """ This tests export save and load functionality on Resnet18 model """ @@ -380,9 +380,10 @@ def test_resnet18_dynamic(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) # TODO: Enable this serialization issues are fixed - # deser_trt_module = torchtrt.load(trt_ep_path).module() + # deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = model(input) outputs_trt = trt_module(input) cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) @@ -395,7 +396,7 @@ def test_resnet18_dynamic(ir): @unittest.skipIf( not importlib.util.find_spec("torchvision"), "torchvision not installed" ) -def test_resnet18_torch_exec_ops_serde(ir): +def test_resnet18_torch_exec_ops_serde(ir, tmp_path): """ This tests export save and load functionality on Resnet18 model """ @@ -413,8 +414,9 @@ def test_resnet18_torch_exec_ops_serde(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = deser_trt_module(input) outputs_trt = trt_module(input) cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) @@ -426,7 +428,7 @@ def test_resnet18_torch_exec_ops_serde(ir): @pytest.mark.unit @pytest.mark.critical -def test_hybrid_conv_fallback(ir): +def test_hybrid_conv_fallback(ir, tmp_path): """ This tests export save and load functionality on a hybrid model where a conv (a weighted layer) has been forced to fallback to Pytorch. @@ -463,9 +465,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -487,7 +489,7 @@ def forward(self, x): @pytest.mark.unit @pytest.mark.critical -def test_hybrid_conv_fallback_cpu_offload(ir): +def test_hybrid_conv_fallback_cpu_offload(ir, tmp_path): """ This tests export save and load functionality on a hybrid model where a conv (a weighted layer) has been forced to fallback to Pytorch. @@ -525,9 +527,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) model.cuda() - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -549,7 +551,7 @@ def forward(self, x): @pytest.mark.unit @pytest.mark.critical -def test_arange_export(ir): +def test_arange_export(ir, tmp_path): """ This tests export save and load functionality on a arange static graph Here the arange output is a static constant (which is registered as input to the graph) @@ -584,9 +586,9 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) - torchtrt.save(trt_module, trt_ep_path, retrace=False) - - deser_trt_module = torchtrt.load(trt_ep_path).module() + trt_ep_path = tmp_path / "trt.ep" + torchtrt.save(trt_module, str(trt_ep_path), retrace=False) + deser_trt_module = torchtrt.load(str(trt_ep_path)).module() outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -607,7 +609,7 @@ def forward(self, x): @pytest.mark.unit -def test_save_load_ts(ir): +def test_save_load_ts(ir, tmp_path): """ This tests save/load API on Torchscript format (model still compiled using dynamo workflow) """ @@ -640,10 +642,12 @@ def forward(self, x): msg=f"test_save_load_ts output type does not match with torch.fx.GraphModule", ) outputs_trt = trt_gm(input) - # Save it as torchscript representation - torchtrt.save(trt_gm, "./trt.ts", output_format="torchscript", inputs=[input]) - trt_ts_module = torchtrt.load("./trt.ts") + # Save it as torchscript representation in the test's tmp_path + trt_ts_path = tmp_path / "trt.ts" + torchtrt.save(trt_gm, str(trt_ts_path), output_format="torchscript", inputs=[input]) + + trt_ts_module = torchtrt.load(str(trt_ts_path)) outputs_trt_deser = trt_ts_module(input) cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)