From bde96143767fbff85fae8055ac76454672e10960 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 11 May 2025 14:50:21 +0800 Subject: [PATCH 01/23] feat: return `ONNXProgram` when exporting with dynamo=True. --- src/lightning/pytorch/core/module.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b8624daac3fa3..3760fa6b47ddd 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -74,8 +74,10 @@ if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh + from torch.onnx import ONNXProgram _ONNX_AVAILABLE = RequirementCache("onnx") +_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript") warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -1360,12 +1362,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None: ) @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None: + def to_onnx( + self, + file_path: Union[str, Path, BytesIO, None] = None, + input_sample: Optional[Any] = None, + **kwargs: Any, + ) -> Union["ONNXProgram", None]: """Saves the model in ONNX format. Args: - file_path: The path of the file the onnx model should be saved to. + file_path: The path of the file the onnx model should be saved to. Default: None (no file saved). input_sample: An input for tracing. Default: None (Use self.example_input_array) + **kwargs: Will be passed to torch.onnx.export function. Example:: @@ -1386,6 +1394,11 @@ def forward(self, x): if not _ONNX_AVAILABLE: raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.") + if kwargs.get("dynamo", False) and not _ONNXSCRIPT_AVAILABLE: + raise ModuleNotFoundError( + f"`{type(self).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed." + ) + mode = self.training if input_sample is None: @@ -1402,8 +1415,9 @@ def forward(self, x): file_path = str(file_path) if isinstance(file_path, Path) else file_path # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but # BytesIO does work, too. - torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore + ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore self.train(mode) + return ret @torch.no_grad() def to_torchscript( From a966c0b553ac5ca54f4d714be8f6a6bc44bedf23 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 11 May 2025 14:58:17 +0800 Subject: [PATCH 02/23] test: add to_onnx(dynamo=True) unittests. --- requirements/pytorch/test.txt | 1 + src/lightning/pytorch/utilities/testing/_runif.py | 7 ++++++- tests/tests_pytorch/models/test_onnx.py | 14 ++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 412a8f270bf47..bcd73f9d91860 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0 numpy >=1.17.2, <1.27.0 onnx >=1.12.0, <1.18.0 onnxruntime >=1.12.0, <1.21.0 +onnxscript >= 0.2.2, <0.2.6 psutil <7.0.1 # for `DeviceStatsMonitor` pandas >1.0, <2.3.0 # needed in benchmarks fastapi # for `ServableModuleValidator` # not setting version as re-defined in App diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 9c46913681143..5bb8f984b2749 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -18,7 +18,7 @@ from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE -from lightning.pytorch.core.module import _ONNX_AVAILABLE +from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") @@ -42,6 +42,7 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, + onnxscript: bool = False, ) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. @@ -64,6 +65,7 @@ def _runif_reasons( psutil: Require that psutil is installed. sklearn: Require that scikit-learn is installed. onnx: Require that onnx is installed. + onnxscript: Require that onnxscript is installed. """ @@ -96,4 +98,7 @@ def _runif_reasons( if onnx and not _ONNX_AVAILABLE: reasons.append("onnx") + if onnxscript and not _ONNXSCRIPT_AVAILABLE: + reasons.append("onnxscript") + return reasons, kwargs diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 81fd5631a3400..37993fe3d739e 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,6 +13,7 @@ # limitations under the License. import operator import os +import warnings from io import BytesIO from pathlib import Path from unittest.mock import patch @@ -167,3 +168,16 @@ def to_numpy(tensor): # compare ONNX Runtime and PyTorch results assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) + + +@RunIf(onnx=True, min_torch="2.7.0", dynamo=True, onnxscript=True) +def test_model_return_type(): + model = BoringModel() + model.example_input_array = torch.randn((1, 32)) + model.eval() + + # Temporarily suppress FutureWarning from onnxscript internal function. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + ret = model.to_onnx(dynamo=True) + assert isinstance(ret, torch.onnx.ONNXProgram) From e7342e38d992ce064c940d629a032a75e60f3d81 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 11 May 2025 15:33:25 +0800 Subject: [PATCH 03/23] fix: add ignore filter in pyproject.toml --- pyproject.toml | 1 + tests/tests_pytorch/models/test_onnx.py | 6 +----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 48439bee75332..d3f6b10673187 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -180,6 +180,7 @@ markers = [ ] filterwarnings = [ "error::FutureWarning", + "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated ] xfail_strict = true junit_duration_report = "call" diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 37993fe3d739e..19490aad36852 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,7 +13,6 @@ # limitations under the License. import operator import os -import warnings from io import BytesIO from pathlib import Path from unittest.mock import patch @@ -176,8 +175,5 @@ def test_model_return_type(): model.example_input_array = torch.randn((1, 32)) model.eval() - # Temporarily suppress FutureWarning from onnxscript internal function. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - ret = model.to_onnx(dynamo=True) + ret = model.to_onnx(dynamo=True) assert isinstance(ret, torch.onnx.ONNXProgram) From 3ee3ea92b976e49f7a38d7e70f9255969f99f328 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 21 May 2025 20:15:56 +0800 Subject: [PATCH 04/23] fix: change the return type annotation of `to_onnx`. --- src/lightning/pytorch/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 3760fa6b47ddd..900105a2e296d 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1367,7 +1367,7 @@ def to_onnx( file_path: Union[str, Path, BytesIO, None] = None, input_sample: Optional[Any] = None, **kwargs: Any, - ) -> Union["ONNXProgram", None]: + ) -> Optional["ONNXProgram"]: """Saves the model in ONNX format. Args: From bc812155ff884df41918d2e87e0b7708b6e623c6 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 21 May 2025 20:23:13 +0800 Subject: [PATCH 05/23] test: add parametrized `dynamo` to test `test_if_inference_output_is_valid`. --- tests/tests_pytorch/models/test_onnx.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 19490aad36852..54c2107aae906 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -139,8 +139,15 @@ def test_error_if_no_input(tmp_path): model.to_onnx(file_path) +@pytest.mark.parametrize( + "dynamo", + [ + False, + pytest.param(True, marks=RunIf(min_torch="2.7.0", dynamo=True, onnxscript=True)), + ], +) @RunIf(onnx=True) -def test_if_inference_output_is_valid(tmp_path): +def test_if_inference_output_is_valid(tmp_path, dynamo): """Test that the output inferred from ONNX model is same as from PyTorch.""" model = BoringModel() model.example_input_array = torch.randn(5, 32) @@ -153,7 +160,7 @@ def test_if_inference_output_is_valid(tmp_path): torch_out = model(model.example_input_array) file_path = os.path.join(tmp_path, "model.onnx") - model.to_onnx(file_path, model.example_input_array, export_params=True) + model.to_onnx(file_path, model.example_input_array, export_params=True, dynamo=dynamo) ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {} ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs) From 236f1a071cbc6cdad19fe988038f139d0bb41fda Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 21 May 2025 22:01:11 +0800 Subject: [PATCH 06/23] test: add difference check in `test_model_return_type`. --- tests/tests_pytorch/models/test_onnx.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 54c2107aae906..4b9fa758a1bb4 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -182,5 +182,10 @@ def test_model_return_type(): model.example_input_array = torch.randn((1, 32)) model.eval() - ret = model.to_onnx(dynamo=True) - assert isinstance(ret, torch.onnx.ONNXProgram) + onnx_pg = model.to_onnx(dynamo=True) + assert isinstance(onnx_pg, torch.onnx.ONNXProgram) + + model_ret = model(model.example_input_array) + inf_ret = onnx_pg(model.example_input_array) + + assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05) From 019125d51de768d953a470164925d3fed065e2b4 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 22 May 2025 02:23:15 +0800 Subject: [PATCH 07/23] fix: fix unittest. --- tests/tests_pytorch/models/test_onnx.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 4b9fa758a1bb4..50ada13b12a8c 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -142,7 +142,8 @@ def test_error_if_no_input(tmp_path): @pytest.mark.parametrize( "dynamo", [ - False, + None, + pytest.param(False, marks=RunIf(min_torch="2.7.0", dynamo=True, onnxscript=True)), pytest.param(True, marks=RunIf(min_torch="2.7.0", dynamo=True, onnxscript=True)), ], ) @@ -160,7 +161,12 @@ def test_if_inference_output_is_valid(tmp_path, dynamo): torch_out = model(model.example_input_array) file_path = os.path.join(tmp_path, "model.onnx") - model.to_onnx(file_path, model.example_input_array, export_params=True, dynamo=dynamo) + kwargs = { + "export_params": True, + } + if dynamo is not None: + kwargs["dynamo"] = dynamo + model.to_onnx(file_path, model.example_input_array, **kwargs) ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {} ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs) From 791d7773dcbd58da93274bbe0a4fd6851ed0f34c Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 2 Jun 2025 22:47:04 +0800 Subject: [PATCH 08/23] deps: bump typing_extension for onnxscript. --- requirements/pytorch/base.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 1b1f743a618b9..9021f4c828a68 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -7,5 +7,5 @@ PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2025.6.0 torchmetrics >=0.7.0, <1.8.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.14.0 +typing-extensions >=4.10.0, <4.14.0 lightning-utilities >=0.10.0, <0.15.0 From acdf3c123cab7fa8ff24a11ac5ddf95c478c279a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 2 Jun 2025 23:01:17 +0800 Subject: [PATCH 09/23] deps: bump typing_extension for onnxscript. --- requirements/fabric/base.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 335742103d078..85072649babf1 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -4,5 +4,5 @@ torch >=2.1.0, <2.8.0 fsspec[http] >=2022.5.0, <2025.6.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.14.0 +typing-extensions >=4.10.0, <4.14.0 lightning-utilities >=0.10.0, <0.15.0 From e046d272e939beed1df9187c529085cf33b68e4b Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Tue, 3 Jun 2025 20:33:05 +0800 Subject: [PATCH 10/23] deps: bump onnxscript upper bound. --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 160c294701976..d38dcbcf74dba 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -11,7 +11,7 @@ scikit-learn >0.22.1, <1.7.0 numpy >=1.17.2, <1.27.0 onnx >=1.12.0, <1.19.0 onnxruntime >=1.12.0, <1.21.0 -onnxscript >= 0.2.2, <0.2.6 +onnxscript >= 0.2.2, <0.3.0 psutil <7.0.1 # for `DeviceStatsMonitor` pandas >1.0, <2.3.0 # needed in benchmarks fastapi # for `ServableModuleValidator` # not setting version as re-defined in App From a0a7d1f29aa37bec6c4800d358a4a8d258af7276 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 6 Jun 2025 05:28:19 +0800 Subject: [PATCH 11/23] test: add test `test_model_onnx_export_missing_onnxscript`. --- tests/tests_pytorch/models/test_onnx.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 50ada13b12a8c..dfe5dfd9cc009 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,6 +13,7 @@ # limitations under the License. import operator import os +import re from io import BytesIO from pathlib import Path from unittest.mock import patch @@ -25,6 +26,7 @@ import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer +from lightning.pytorch.core.module import _ONNXSCRIPT_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf from tests_pytorch.utilities.test_model_summary import UnorderedModel @@ -182,6 +184,20 @@ def to_numpy(tensor): assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) +@RunIf(min_torch="2.7.0", dynamo=True) +@pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.") +def test_model_onnx_export_missing_onnxscript(): + """Test that an error is raised if onnxscript is not available.""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + with pytest.raises( + ModuleNotFoundError, + match=re.escape(f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed."), + ): + model.to_onnx(dynamo=True) + + @RunIf(onnx=True, min_torch="2.7.0", dynamo=True, onnxscript=True) def test_model_return_type(): model = BoringModel() From aca9fd1c138531f33baefcd5b750677270d5a5f1 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 09:50:37 +0800 Subject: [PATCH 12/23] revert typing-extension bump. --- requirements/fabric/base.txt | 2 +- requirements/pytorch/base.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 85072649babf1..335742103d078 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -4,5 +4,5 @@ torch >=2.1.0, <2.8.0 fsspec[http] >=2022.5.0, <2025.6.0 packaging >=20.0, <=25.0 -typing-extensions >=4.10.0, <4.14.0 +typing-extensions >=4.4.0, <4.14.0 lightning-utilities >=0.10.0, <0.15.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 9021f4c828a68..1b1f743a618b9 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -7,5 +7,5 @@ PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2025.6.0 torchmetrics >=0.7.0, <1.8.0 packaging >=20.0, <=25.0 -typing-extensions >=4.10.0, <4.14.0 +typing-extensions >=4.4.0, <4.14.0 lightning-utilities >=0.10.0, <0.15.0 From 1396f3531c38a2d73376396ee1c75b03eb9b41fe Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 09:52:31 +0800 Subject: [PATCH 13/23] lower the min_torch version in unittest. --- tests/tests_pytorch/models/test_onnx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index dfe5dfd9cc009..60e557607eb65 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -145,8 +145,8 @@ def test_error_if_no_input(tmp_path): "dynamo", [ None, - pytest.param(False, marks=RunIf(min_torch="2.7.0", dynamo=True, onnxscript=True)), - pytest.param(True, marks=RunIf(min_torch="2.7.0", dynamo=True, onnxscript=True)), + pytest.param(False, marks=RunIf(min_torch="2.6.0", dynamo=True, onnxscript=True)), + pytest.param(True, marks=RunIf(min_torch="2.6.0", dynamo=True, onnxscript=True)), ], ) @RunIf(onnx=True) @@ -184,7 +184,7 @@ def to_numpy(tensor): assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) -@RunIf(min_torch="2.7.0", dynamo=True) +@RunIf(min_torch="2.6.0", dynamo=True) @pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.") def test_model_onnx_export_missing_onnxscript(): """Test that an error is raised if onnxscript is not available.""" @@ -198,7 +198,7 @@ def test_model_onnx_export_missing_onnxscript(): model.to_onnx(dynamo=True) -@RunIf(onnx=True, min_torch="2.7.0", dynamo=True, onnxscript=True) +@RunIf(onnx=True, min_torch="2.6.0", dynamo=True, onnxscript=True) def test_model_return_type(): model = BoringModel() model.example_input_array = torch.randn((1, 32)) From 8f050ea51e2d96e9142176d2759f76bae1797523 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 16 Jun 2025 23:02:16 +0800 Subject: [PATCH 14/23] feat: enable ONNXProgram export on torch 2.5.0 --- src/lightning/fabric/utilities/imports.py | 1 + src/lightning/pytorch/core/module.py | 6 +++-- tests/tests_pytorch/models/test_onnx.py | 32 ++++++++++++++++++----- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index a618371d7f2b4..1962e336b3eb9 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -34,6 +34,7 @@ _TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") +_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 4aa57817286a2..71f153387ff6f 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -47,6 +47,7 @@ from lightning.fabric.utilities.apply_func import convert_to_tensors from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH from lightning.fabric.wrappers import _FabricOptimizer from lightning.pytorch.callbacks.callback import Callback @@ -1394,9 +1395,10 @@ def forward(self, x): if not _ONNX_AVAILABLE: raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.") - if kwargs.get("dynamo", False) and not _ONNXSCRIPT_AVAILABLE: + if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5): raise ModuleNotFoundError( - f"`{type(self).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed." + f"`{type(self).__name__}.to_onnx(dynamo=True)` " + "requires `onnxscript` and `torch>=2.5.0` to be installed." ) mode = self.training diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 60e557607eb65..57e0db014c2cf 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -145,8 +145,8 @@ def test_error_if_no_input(tmp_path): "dynamo", [ None, - pytest.param(False, marks=RunIf(min_torch="2.6.0", dynamo=True, onnxscript=True)), - pytest.param(True, marks=RunIf(min_torch="2.6.0", dynamo=True, onnxscript=True)), + pytest.param(False, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)), + pytest.param(True, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)), ], ) @RunIf(onnx=True) @@ -184,7 +184,7 @@ def to_numpy(tensor): assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) -@RunIf(min_torch="2.6.0", dynamo=True) +@RunIf(min_torch="2.5.0", dynamo=True) @pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.") def test_model_onnx_export_missing_onnxscript(): """Test that an error is raised if onnxscript is not available.""" @@ -193,21 +193,41 @@ def test_model_onnx_export_missing_onnxscript(): with pytest.raises( ModuleNotFoundError, - match=re.escape(f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed."), + match=re.escape( + f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.", + ), ): model.to_onnx(dynamo=True) -@RunIf(onnx=True, min_torch="2.6.0", dynamo=True, onnxscript=True) +@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True) def test_model_return_type(): model = BoringModel() model.example_input_array = torch.randn((1, 32)) model.eval() onnx_pg = model.to_onnx(dynamo=True) - assert isinstance(onnx_pg, torch.onnx.ONNXProgram) + + onnx_cls = torch.onnx.ONNXProgram if torch.__version__ >= "2.6.0" else torch.onnx._internal.exporter.ONNXProgram + + assert isinstance(onnx_pg, onnx_cls) model_ret = model(model.example_input_array) inf_ret = onnx_pg(model.example_input_array) assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05) + + +@RunIf(max_torch="2.5.0") +def test_model_onnx_export_wrong_torch_version(): + """Test that an error is raised if onnxscript is not available.""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + with pytest.raises( + ModuleNotFoundError, + match=re.escape( + f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.", + ), + ): + model.to_onnx(dynamo=True) From ce3e6b740c9a19f31c393240babfbcaf544d8523 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Mon, 16 Jun 2025 18:43:38 +0200 Subject: [PATCH 15/23] extensions --- requirements/fabric/base.txt | 2 +- requirements/pytorch/base.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 335742103d078..6ff5be6805cd3 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -4,5 +4,5 @@ torch >=2.1.0, <2.8.0 fsspec[http] >=2022.5.0, <2025.6.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.14.0 +typing-extensions >4.4.0, <4.14.0 lightning-utilities >=0.10.0, <0.15.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 1b1f743a618b9..f155e2987c3d6 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -7,5 +7,5 @@ PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2025.6.0 torchmetrics >=0.7.0, <1.8.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.14.0 +typing-extensions >4.4.0, <4.14.0 lightning-utilities >=0.10.0, <0.15.0 From a470fe85391718ad46e2e5c3b478db58815f6382 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Wed, 18 Jun 2025 13:45:27 +0200 Subject: [PATCH 16/23] ds --- src/lightning/fabric/utilities/testing/_runif.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 6f5d933f9dae3..671ebc43bca4f 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -112,9 +112,7 @@ def _runif_reasons( reasons.append("Standalone execution") kwargs["standalone"] = True - if deepspeed and not ( - _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils") - ): + if deepspeed and not (_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4): reasons.append("Deepspeed") if dynamo: From 9e4a4947bdabeb7e8847d88941fc3055643d2b7f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Jun 2025 11:45:51 +0000 Subject: [PATCH 17/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/utilities/testing/_runif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 671ebc43bca4f..a2c605f411f0a 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -17,7 +17,7 @@ from typing import Optional import torch -from lightning_utilities.core.imports import RequirementCache, compare_version +from lightning_utilities.core.imports import compare_version from packaging.version import Version from lightning.fabric.accelerators import XLAAccelerator From 67af423bd75d2da72c0e874e71d22ead1591191a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 18 Jun 2025 22:07:52 +0800 Subject: [PATCH 18/23] dep: test fixing pydantic version. --- requirements/pytorch/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 6c29eb67fe36a..40b14ac7cfcb1 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,3 +18,4 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` +pydantic >= 1.10.22 From 0e4cb807c79b223465a2cd78d400fcbb0c22c77a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 18 Jun 2025 22:51:28 +0800 Subject: [PATCH 19/23] Revert "dep: test fixing pydantic version." This reverts commit 67af423bd75d2da72c0e874e71d22ead1591191a. --- requirements/pytorch/test.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 40b14ac7cfcb1..6c29eb67fe36a 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,3 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` -pydantic >= 1.10.22 From b26072d7030641dc0f16eb4b88227fafee30c0c5 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 19 Jun 2025 00:22:24 +0800 Subject: [PATCH 20/23] dep: add serve deps. --- .actions/assistant.py | 1 + requirements/pytorch/serve.txt | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 requirements/pytorch/serve.txt diff --git a/.actions/assistant.py b/.actions/assistant.py index 7b2d49423d622..7ed6e93d2702e 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -33,6 +33,7 @@ "requirements/pytorch/extra.txt", "requirements/pytorch/strategies.txt", "requirements/pytorch/examples.txt", + "requirements/pytorch/serve.txt", ), "fabric": ( "requirements/fabric/base.txt", diff --git a/requirements/pytorch/serve.txt b/requirements/pytorch/serve.txt new file mode 100644 index 0000000000000..5c96b03502b89 --- /dev/null +++ b/requirements/pytorch/serve.txt @@ -0,0 +1,2 @@ +fastapi >= 0.98.0 +pydantic >= 1.10.22 From d1b859760cd10713ebe1a95e55dd4c125f594517 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 19 Jun 2025 00:50:35 +0800 Subject: [PATCH 21/23] ci: test. --- .github/workflows/ci-tests-fabric.yml | 2 +- .github/workflows/ci-tests-pytorch.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 4e4d2c9eed3cb..1e6b9f67bce11 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -138,7 +138,7 @@ jobs: - name: Install package & dependencies timeout-minutes: 20 run: | - pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ + pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies,${EXTRA_PREFIX}serve]" -U --prefer-binary \ --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list - name: Dump handy wheels diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index d295d5475942a..a99740701a738 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -136,7 +136,7 @@ jobs: - name: Install package & dependencies timeout-minutes: 20 run: | - pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ + pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies,${EXTRA_PREFIX}serve]" -U --prefer-binary \ -r requirements/_integrations/accelerators.txt \ --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list From aa951fd6cce9e017955f66510deae67c34e97dbf Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 26 Jun 2025 21:59:16 +0800 Subject: [PATCH 22/23] update onnxscript upperbound. --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 83534244bea61..d3cf86a48d79d 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -11,7 +11,7 @@ scikit-learn >0.22.1, <1.7.0 numpy >=1.17.2, <1.27.0 onnx >=1.12.0, <1.19.0 onnxruntime >=1.12.0, <1.21.0 -onnxscript >= 0.2.2, <0.3.0 +onnxscript >= 0.2.2, <0.4.0 psutil <7.0.1 # for `DeviceStatsMonitor` pandas >2.0, <2.4.0 # needed in benchmarks fastapi # for `ServableModuleValidator` # not setting version as re-defined in App From 9491953f825d66cd6a12cfbf0d70e2716e4807dd Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 28 Jun 2025 11:48:43 +0800 Subject: [PATCH 23/23] align with ce3e6b7 --- requirements/fabric/base.txt | 2 +- requirements/pytorch/base.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 6a75c42428915..264ab6969a965 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -4,5 +4,5 @@ torch >=2.1.0, <2.8.0 fsspec[http] >=2022.5.0, <2025.6.0 packaging >=20.0, <=25.0 -typing-extensions >=4.5.0, <4.15.0 +typing-extensions >4.5.0, <4.15.0 lightning-utilities >=0.10.0, <0.15.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index c649a2b39ab90..3afd1aef2519c 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -7,5 +7,5 @@ PyYAML >5.4, <6.1.0 fsspec[http] >=2022.5.0, <2025.6.0 torchmetrics >0.7.0, <1.8.0 packaging >=20.0, <=25.0 -typing-extensions >=4.5.0, <4.15.0 +typing-extensions >4.5.0, <4.15.0 lightning-utilities >=0.10.0, <0.15.0