Skip to content

to_onnx return ONNXProgram #20811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ markers = [
]
filterwarnings = [
"error::FutureWarning",
"ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated
]
xfail_strict = true
junit_duration_report = "call"
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
torch >=2.1.0, <2.8.0
fsspec[http] >=2022.5.0, <2025.6.0
packaging >=20.0, <=25.0
typing-extensions >=4.4.0, <4.14.0
typing-extensions >=4.10.0, <4.14.0
lightning-utilities >=0.10.0, <0.15.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2025.6.0
torchmetrics >=0.7.0, <1.8.0
packaging >=20.0, <=25.0
typing-extensions >=4.4.0, <4.14.0
typing-extensions >=4.10.0, <4.14.0
lightning-utilities >=0.10.0, <0.15.0
1 change: 1 addition & 0 deletions requirements/pytorch/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0
numpy >=1.17.2, <1.27.0
onnx >=1.12.0, <1.19.0
onnxruntime >=1.12.0, <1.21.0
onnxscript >=0.2.2, <0.3.0
psutil <7.0.1 # for `DeviceStatsMonitor`
pandas >1.0, <2.3.0 # needed in benchmarks
fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
Expand Down
20 changes: 17 additions & 3 deletions src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.onnx import ONNXProgram

_ONNX_AVAILABLE = RequirementCache("onnx")
_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")

warning_cache = WarningCache()
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -1360,12 +1362,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
)

@torch.no_grad()
def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
def to_onnx(
self,
file_path: Union[str, Path, BytesIO, None] = None,
input_sample: Optional[Any] = None,
**kwargs: Any,
) -> Optional["ONNXProgram"]:
"""Saves the model in ONNX format.

Args:
file_path: The path of the file the onnx model should be saved to.
file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
input_sample: An input for tracing. Default: None (Use self.example_input_array)

**kwargs: Will be passed to torch.onnx.export function.

Example::
Expand All @@ -1386,6 +1394,11 @@ def forward(self, x):
if not _ONNX_AVAILABLE:
raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")

if kwargs.get("dynamo", False) and not _ONNXSCRIPT_AVAILABLE:
raise ModuleNotFoundError(
f"`{type(self).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed."
)

mode = self.training

if input_sample is None:
Expand All @@ -1402,8 +1415,9 @@ def forward(self, x):
file_path = str(file_path) if isinstance(file_path, Path) else file_path
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
# BytesIO does work, too.
torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore
self.train(mode)
return ret

@torch.no_grad()
def to_torchscript(
Expand Down
7 changes: 6 additions & 1 deletion src/lightning/pytorch/utilities/testing/_runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE

_SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
Expand All @@ -42,6 +42,7 @@ def _runif_reasons(
psutil: bool = False,
sklearn: bool = False,
onnx: bool = False,
onnxscript: bool = False,
) -> tuple[list[str], dict[str, bool]]:
"""Construct reasons for pytest skipif.

Expand All @@ -64,6 +65,7 @@ def _runif_reasons(
psutil: Require that psutil is installed.
sklearn: Require that scikit-learn is installed.
onnx: Require that onnx is installed.
onnxscript: Require that onnxscript is installed.

"""

Expand Down Expand Up @@ -96,4 +98,7 @@ def _runif_reasons(
if onnx and not _ONNX_AVAILABLE:
reasons.append("onnx")

if onnxscript and not _ONNXSCRIPT_AVAILABLE:
reasons.append("onnxscript")

return reasons, kwargs
32 changes: 30 additions & 2 deletions tests/tests_pytorch/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,16 @@ def test_error_if_no_input(tmp_path):
model.to_onnx(file_path)


@pytest.mark.parametrize(
"dynamo",
[
None,
pytest.param(False, marks=RunIf(min_torch="2.7.0", dynamo=True, onnxscript=True)),
pytest.param(True, marks=RunIf(min_torch="2.7.0", dynamo=True, onnxscript=True)),
],
)
@RunIf(onnx=True)
def test_if_inference_output_is_valid(tmp_path):
def test_if_inference_output_is_valid(tmp_path, dynamo):
"""Test that the output inferred from ONNX model is same as from PyTorch."""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
Expand All @@ -153,7 +161,12 @@ def test_if_inference_output_is_valid(tmp_path):
torch_out = model(model.example_input_array)

file_path = os.path.join(tmp_path, "model.onnx")
model.to_onnx(file_path, model.example_input_array, export_params=True)
kwargs = {
"export_params": True,
}
if dynamo is not None:
kwargs["dynamo"] = dynamo
model.to_onnx(file_path, model.example_input_array, **kwargs)

ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {}
ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs)
Expand All @@ -167,3 +180,18 @@ def to_numpy(tensor):

# compare ONNX Runtime and PyTorch results
assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)


@RunIf(onnx=True, min_torch="2.7.0", dynamo=True, onnxscript=True)
def test_model_return_type():
    """``to_onnx(dynamo=True)`` should return a ``torch.onnx.ONNXProgram``.

    Also checks that running the exported program reproduces the eager model's
    output on the same input (within float tolerance).
    """
    boring = BoringModel()
    boring.example_input_array = torch.randn((1, 32))
    # Export in inference mode, as to_onnx traces the module as-is.
    boring.eval()

    # No file_path: the dynamo exporter hands back the ONNXProgram directly.
    program = boring.to_onnx(dynamo=True)
    assert isinstance(program, torch.onnx.ONNXProgram)

    eager_out = boring(boring.example_input_array)
    exported_out = program(boring.example_input_array)

    # ONNXProgram.__call__ returns a sequence of outputs; compare the first.
    assert torch.allclose(eager_out, exported_out[0], rtol=1e-03, atol=1e-05)
Loading