to_onnx returns ONNXProgram #20811

Open · wants to merge 12 commits into master · Changes from 3 commits
1 change: 1 addition & 0 deletions pyproject.toml
@@ -180,6 +180,7 @@ markers = [
 ]
 filterwarnings = [
     "error::FutureWarning",
+    "ignore::FutureWarning:onnxscript",  # Temporary ignore until onnxscript is updated
 ]
 xfail_strict = true
 junit_duration_report = "call"
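
For context, entries in pytest's filterwarnings list use the standard warnings-filter syntax (action::category:module), so the new line behaves roughly like the following runtime call (a sketch for illustration, not part of the diff):

```python
import warnings

# Rough equivalent of "ignore::FutureWarning:onnxscript": suppress
# FutureWarnings raised from modules whose name matches "onnxscript".
warnings.filterwarnings("ignore", category=FutureWarning, module="onnxscript")
```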
1 change: 1 addition & 0 deletions requirements/pytorch/test.txt
@@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0
 numpy >=1.17.2, <1.27.0
 onnx >=1.12.0, <1.18.0
 onnxruntime >=1.12.0, <1.21.0
+onnxscript >=0.2.2, <0.2.6
 psutil <7.0.1  # for `DeviceStatsMonitor`
 pandas >1.0, <2.3.0  # needed in benchmarks
 fastapi  # for `ServableModuleValidator` # not setting version as re-defined in App
20 changes: 17 additions & 3 deletions src/lightning/pytorch/core/module.py
@@ -74,8 +74,10 @@

 if TYPE_CHECKING:
     from torch.distributed.device_mesh import DeviceMesh
+    from torch.onnx import ONNXProgram

 _ONNX_AVAILABLE = RequirementCache("onnx")
+_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")

 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
@@ -1360,12 +1362,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
         )

     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
+    def to_onnx(
+        self,
+        file_path: Union[str, Path, BytesIO, None] = None,
+        input_sample: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> Union["ONNXProgram", None]:
Contributor:
It's technically correct, but since there's only one type besides None, what do you think about Optional["ONNXProgram"] for readability?

Contributor Author:
Yes, I've changed it in the last commit. Thanks.
"""Saves the model in ONNX format.

Args:
file_path: The path of the file the onnx model should be saved to.
file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
input_sample: An input for tracing. Default: None (Use self.example_input_array)

**kwargs: Will be passed to torch.onnx.export function.

Example::
@@ -1386,6 +1394,11 @@ def forward(self, x):
         if not _ONNX_AVAILABLE:
             raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")

+        if kwargs.get("dynamo", False) and not _ONNXSCRIPT_AVAILABLE:
+            raise ModuleNotFoundError(
+                f"`{type(self).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed."
+            )
+
         mode = self.training

         if input_sample is None:
@@ -1402,8 +1415,9 @@ def forward(self, x):
         file_path = str(file_path) if isinstance(file_path, Path) else file_path
         # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
         # BytesIO does work, too.
-        torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
+        ret = torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
         self.train(mode)
+        return ret

     @torch.no_grad()
     def to_torchscript(
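
To illustrate the behavior change, here is a minimal sketch of both export paths after this patch, using the BoringModel demo module only as a stand-in (not part of the diff):

```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.example_input_array = torch.randn(1, 32)

# Classic exporter: writes the file and returns None, as before.
model.to_onnx("model.onnx")

# Dynamo-based exporter: file_path is now optional, and the returned
# torch.onnx.ONNXProgram can be inspected or saved later.
program = model.to_onnx(dynamo=True)
program.save("model_dynamo.onnx")
```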
7 changes: 6 additions & 1 deletion src/lightning/pytorch/utilities/testing/_runif.py
@@ -18,7 +18,7 @@
 from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
 from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
 from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
-from lightning.pytorch.core.module import _ONNX_AVAILABLE
+from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE

 _SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
@@ -42,6 +42,7 @@ def _runif_reasons(
     psutil: bool = False,
     sklearn: bool = False,
     onnx: bool = False,
+    onnxscript: bool = False,
 ) -> tuple[list[str], dict[str, bool]]:
     """Construct reasons for pytest skipif.

@@ -64,6 +65,7 @@
         psutil: Require that psutil is installed.
         sklearn: Require that scikit-learn is installed.
         onnx: Require that onnx is installed.
+        onnxscript: Require that onnxscript is installed.

     """

@@ -96,4 +98,7 @@ def _runif_reasons(
     if onnx and not _ONNX_AVAILABLE:
         reasons.append("onnx")

+    if onnxscript and not _ONNXSCRIPT_AVAILABLE:
+        reasons.append("onnxscript")
+
     return reasons, kwargs
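
For reference, a sketch of how these reasons are consumed: the RunIf marker used in the tests below wraps _runif_reasons roughly like this (paraphrased from the test helpers, not part of this diff):

```python
import pytest

from lightning.pytorch.utilities.testing._runif import _runif_reasons


def RunIf(**kwargs):
    # Any unmet requirement becomes a skip reason for the decorated test.
    reasons, marker_kwargs = _runif_reasons(**kwargs)
    return pytest.mark.skipif(
        condition=len(reasons) > 0,
        reason=f"Requires: [{' + '.join(reasons)}]",
        **marker_kwargs,
    )
```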
10 changes: 10 additions & 0 deletions tests/tests_pytorch/models/test_onnx.py
@@ -167,3 +167,13 @@ def to_numpy(tensor):

     # compare ONNX Runtime and PyTorch results
     assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+
+
+@RunIf(onnx=True, min_torch="2.7.0", dynamo=True, onnxscript=True)
+def test_model_return_type():
+    model = BoringModel()
+    model.example_input_array = torch.randn((1, 32))
+    model.eval()
+
+    ret = model.to_onnx(dynamo=True)
+    assert isinstance(ret, torch.onnx.ONNXProgram)
Contributor:
Do you think it would be worth adding a check that verifies the output of the exported ONNX model matches the PyTorch model's output (within tolerance)?

Contributor Author:
I think we could reuse the existing test test_if_inference_output_is_valid for this purpose. I added a parametrized dynamo option to that test in the last commit as well.

Thank you ~.
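
For reference, a hedged sketch of the parity check discussed above, reusing the tolerances from the existing ONNX Runtime comparison in this file (the export and onnxruntime calls are illustrative assumptions, not the committed test):

```python
import numpy as np
import onnxruntime
import torch

from lightning.pytorch.demos.boring_classes import BoringModel

# Export via the dynamo path, run through ONNX Runtime, and compare
# against the eager PyTorch output.
model = BoringModel()
model.example_input_array = torch.randn(1, 32)
model.eval()

program = model.to_onnx(dynamo=True)
program.save("model.onnx")

x = model.example_input_array
with torch.no_grad():
    torch_out = model(x)

session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
ort_out = session.run(None, {input_name: x.numpy()})[0]

# Same rtol/atol as the existing comparison above.
assert np.allclose(torch_out.numpy(), ort_out, rtol=1e-03, atol=1e-05)
```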
