diff --git a/.actions/assistant.py b/.actions/assistant.py index 7b2d49423d622..7ed6e93d2702e 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -33,6 +33,7 @@ "requirements/pytorch/extra.txt", "requirements/pytorch/strategies.txt", "requirements/pytorch/examples.txt", + "requirements/pytorch/serve.txt", ), "fabric": ( "requirements/fabric/base.txt", diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 4e4d2c9eed3cb..1e6b9f67bce11 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -138,7 +138,7 @@ jobs: - name: Install package & dependencies timeout-minutes: 20 run: | - pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ + pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies,${EXTRA_PREFIX}serve]" -U --prefer-binary \ --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list - name: Dump handy wheels diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index d295d5475942a..a99740701a738 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -136,7 +136,7 @@ jobs: - name: Install package & dependencies timeout-minutes: 20 run: | - pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ + pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies,${EXTRA_PREFIX}serve]" -U --prefer-binary \ -r requirements/_integrations/accelerators.txt \ --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list diff --git a/pyproject.toml b/pyproject.toml index b45f60489c6fe..a63da5f246392 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -180,6 +180,7 @@ markers = [ ] filterwarnings = [ "error::FutureWarning", + "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated ] xfail_strict = true junit_duration_report = "call" diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index cb70b24ae26c4..9c17d44d95e8d 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -4,5 +4,5 @@ torch >=2.1.0, <2.8.0 fsspec[http] >=2022.5.0, <2025.6.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.15.0 +typing-extensions >4.4.0, <4.15.0 lightning-utilities >=0.10.0, <0.15.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index e77ecc8a1baeb..904d889e58f96 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -7,5 +7,5 @@ PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2025.6.0 torchmetrics >=0.7.0, <1.8.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.15.0 +typing-extensions >4.4.0, <4.15.0 lightning-utilities >=0.10.0, <0.15.0 diff --git a/requirements/pytorch/serve.txt b/requirements/pytorch/serve.txt new file mode 100644 index 0000000000000..5c96b03502b89 --- /dev/null +++ b/requirements/pytorch/serve.txt @@ -0,0 +1,2 @@ +fastapi >= 0.98.0 +pydantic >= 1.10.22 diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 865109c87b140..6c29eb67fe36a 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0 numpy >=1.17.2, <1.27.0 onnx >=1.12.0, <1.19.0 onnxruntime >=1.12.0, <1.21.0 +onnxscript >= 0.2.2, <0.3.0 psutil <7.0.1 # for `DeviceStatsMonitor` pandas >2.0, <2.4.0 # needed in benchmarks fastapi # for `ServableModuleValidator` # not setting version as re-defined in App diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index a618371d7f2b4..1962e336b3eb9 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -34,6 +34,7 @@ _TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") +_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 6f5d933f9dae3..a2c605f411f0a 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -17,7 +17,7 @@ from typing import Optional import torch -from lightning_utilities.core.imports import RequirementCache, compare_version +from lightning_utilities.core.imports import compare_version from packaging.version import Version from lightning.fabric.accelerators import XLAAccelerator @@ -112,9 +112,7 @@ def _runif_reasons( reasons.append("Standalone execution") kwargs["standalone"] = True - if deepspeed and not ( - _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils") - ): + if deepspeed and not (_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4): reasons.append("Deepspeed") if dynamo: diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 7df0cb7757f81..101cceedfaaf0 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -47,6 +47,7 @@ from lightning.fabric.utilities.apply_func import convert_to_tensors from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH from lightning.fabric.wrappers import _FabricOptimizer from lightning.pytorch.callbacks.callback import Callback @@ -74,8 +75,10 @@ if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh + from torch.onnx import ONNXProgram _ONNX_AVAILABLE = RequirementCache("onnx") +_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript") warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -1386,12 +1389,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None: ) @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None: + def to_onnx( + self, + file_path: Union[str, Path, BytesIO, None] = None, + input_sample: Optional[Any] = None, + **kwargs: Any, + ) -> Optional["ONNXProgram"]: """Saves the model in ONNX format. Args: - file_path: The path of the file the onnx model should be saved to. + file_path: The path of the file the onnx model should be saved to. Default: None (no file saved). input_sample: An input for tracing. Default: None (Use self.example_input_array) + **kwargs: Will be passed to torch.onnx.export function. Example:: @@ -1412,6 +1421,12 @@ def forward(self, x): if not _ONNX_AVAILABLE: raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.") + if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5): + raise ModuleNotFoundError( + f"`{type(self).__name__}.to_onnx(dynamo=True)` " + "requires `onnxscript` and `torch>=2.5.0` to be installed." + ) + mode = self.training if input_sample is None: @@ -1428,8 +1443,9 @@ def forward(self, x): file_path = str(file_path) if isinstance(file_path, Path) else file_path # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but # BytesIO does work, too. - torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore + ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore self.train(mode) + return ret @torch.no_grad() def to_torchscript( diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 9c46913681143..5bb8f984b2749 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -18,7 +18,7 @@ from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE -from lightning.pytorch.core.module import _ONNX_AVAILABLE +from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") @@ -42,6 +42,7 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, + onnxscript: bool = False, ) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. @@ -64,6 +65,7 @@ def _runif_reasons( psutil: Require that psutil is installed. sklearn: Require that scikit-learn is installed. onnx: Require that onnx is installed. + onnxscript: Require that onnxscript is installed. """ @@ -96,4 +98,7 @@ def _runif_reasons( if onnx and not _ONNX_AVAILABLE: reasons.append("onnx") + if onnxscript and not _ONNXSCRIPT_AVAILABLE: + reasons.append("onnxscript") + return reasons, kwargs diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 81fd5631a3400..57e0db014c2cf 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,6 +13,7 @@ # limitations under the License. import operator import os +import re from io import BytesIO from pathlib import Path from unittest.mock import patch @@ -25,6 +26,7 @@ import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer +from lightning.pytorch.core.module import _ONNXSCRIPT_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf from tests_pytorch.utilities.test_model_summary import UnorderedModel @@ -139,8 +141,16 @@ def test_error_if_no_input(tmp_path): model.to_onnx(file_path) +@pytest.mark.parametrize( + "dynamo", + [ + None, + pytest.param(False, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)), + pytest.param(True, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)), + ], +) @RunIf(onnx=True) -def test_if_inference_output_is_valid(tmp_path): +def test_if_inference_output_is_valid(tmp_path, dynamo): """Test that the output inferred from ONNX model is same as from PyTorch.""" model = BoringModel() model.example_input_array = torch.randn(5, 32) @@ -153,7 +163,12 @@ def test_if_inference_output_is_valid(tmp_path): torch_out = model(model.example_input_array) file_path = os.path.join(tmp_path, "model.onnx") - model.to_onnx(file_path, model.example_input_array, export_params=True) + kwargs = { + "export_params": True, + } + if dynamo is not None: + kwargs["dynamo"] = dynamo + model.to_onnx(file_path, model.example_input_array, **kwargs) ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {} ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs) @@ -167,3 +182,52 @@ def to_numpy(tensor): # compare ONNX Runtime and PyTorch results assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) + + +@RunIf(min_torch="2.5.0", dynamo=True) +@pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.") +def test_model_onnx_export_missing_onnxscript(): + """Test that an error is raised if onnxscript is not available.""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + with pytest.raises( + ModuleNotFoundError, + match=re.escape( + f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.", + ), + ): + model.to_onnx(dynamo=True) + + +@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True) +def test_model_return_type(): + model = BoringModel() + model.example_input_array = torch.randn((1, 32)) + model.eval() + + onnx_pg = model.to_onnx(dynamo=True) + + onnx_cls = torch.onnx.ONNXProgram if torch.__version__ >= "2.6.0" else torch.onnx._internal.exporter.ONNXProgram + + assert isinstance(onnx_pg, onnx_cls) + + model_ret = model(model.example_input_array) + inf_ret = onnx_pg(model.example_input_array) + + assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05) + + +@RunIf(max_torch="2.5.0") +def test_model_onnx_export_wrong_torch_version(): + """Test that an error is raised if onnxscript is not available.""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + with pytest.raises( + ModuleNotFoundError, + match=re.escape( + f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.", + ), + ): + model.to_onnx(dynamo=True)