Skip to content

Commit 238b695

Browse files
authored
[stubgen] Added a new --exclude-values flag (#1185)
The flag forces all values to be rendered as ..., which is usually what you want in a .pyi file. The motivating use-case in JAX was version attributes, which are currently rendered as foo_version: int = 42 so every version bump will be unnecessarily reflected in the .pyis.
1 parent 0435861 commit 238b695

File tree

6 files changed

+38
-9
lines changed

6 files changed

+38
-9
lines changed

cmake/nanobind-config.cmake

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ endfunction()
590590
# ---------------------------------------------------------------------------
591591

592592
function (nanobind_add_stub name)
593-
cmake_parse_arguments(PARSE_ARGV 1 ARG "VERBOSE;INCLUDE_PRIVATE;EXCLUDE_DOCSTRINGS;INSTALL_TIME;RECURSIVE;EXCLUDE_FROM_ALL" "MODULE;COMPONENT;PATTERN_FILE;OUTPUT_PATH" "PYTHON_PATH;DEPENDS;MARKER_FILE;OUTPUT")
593+
cmake_parse_arguments(PARSE_ARGV 1 ARG "VERBOSE;INCLUDE_PRIVATE;EXCLUDE_DOCSTRINGS;EXCLUDE_VALUES;INSTALL_TIME;RECURSIVE;EXCLUDE_FROM_ALL" "MODULE;COMPONENT;PATTERN_FILE;OUTPUT_PATH" "PYTHON_PATH;DEPENDS;MARKER_FILE;OUTPUT")
594594

595595
if (EXISTS ${NB_DIR}/src/stubgen.py)
596596
set(NB_STUBGEN "${NB_DIR}/src/stubgen.py")
@@ -614,6 +614,10 @@ function (nanobind_add_stub name)
614614
list(APPEND NB_STUBGEN_ARGS -D)
615615
endif()
616616

617+
if (ARG_EXCLUDE_VALUES)
618+
list(APPEND NB_STUBGEN_ARGS --exclude-values)
619+
endif()
620+
617621
if (ARG_RECURSIVE)
618622
list(APPEND NB_STUBGEN_ARGS -r)
619623
endif()

docs/typing.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ The program has the following command line options:
540540
.. code-block:: text
541541
542542
usage: python -m nanobind.stubgen [-h] [-o FILE] [-O PATH] [-i PATH] [-m MODULE]
543-
[-r] [-M FILE] [-P] [-D] [-q]
543+
[-r] [-M FILE] [-P] [-D] [--exclude-values] [-q]
544544
545545
Generate stubs for nanobind-based extensions.
546546
@@ -559,6 +559,7 @@ The program has the following command line options:
559559
-P, --include-private include private members (with single leading or
560560
trailing underscore)
561561
-D, --exclude-docstrings exclude docstrings from the generated stub
562+
--exclude-values force the use of ... for values
562563
-q, --quiet do not generate any output in the absence of failures
563564
564565

src/stubgen.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,12 +1016,14 @@ def expr_str(self, e: Any, abbrev: bool = True) -> Optional[str]:
10161016
"""
10171017
tp = type(e)
10181018
if issubclass(tp, (bool, int, type(None), type(builtins.Ellipsis))):
1019-
return repr(e)
1019+
s = repr(e)
1020+
if len(s) < self.max_expr_length or not abbrev:
1021+
return s
10201022
elif issubclass(tp, float):
10211023
s = repr(e)
10221024
if "inf" in s or "nan" in s:
1023-
return f"float('{s}')"
1024-
else:
1025+
s = f"float('{s}')"
1026+
if len(s) < self.max_expr_length or not abbrev:
10251027
return s
10261028
elif issubclass(tp, type) or typing.get_origin(e):
10271029
return self.type_str(e)
@@ -1041,13 +1043,17 @@ def expr_str(self, e: Any, abbrev: bool = True) -> Optional[str]:
10411043
tv = self.import_object("typing", "TypeVar")
10421044
s = f'{tv}("{e.__name__}"'
10431045
for v in getattr(e, "__constraints__", ()):
1044-
v = self.expr_str(v)
1046+
v = self.type_str(v)
10451047
assert v
10461048
s += ", " + v
1047-
for k in ["contravariant", "covariant", "bound", "infer_variance"]:
1049+
if v := getattr(e, "__bound__", None):
1050+
v = self.type_str(v)
1051+
assert v
1052+
s += ", bound=" + v
1053+
for k in ["contravariant", "covariant", "infer_variance"]:
10481054
v = getattr(e, f"__{k}__", None)
10491055
if v:
1050-
v = self.expr_str(v)
1056+
v = self.expr_str(v, abbrev=False)
10511057
if v is None:
10521058
return None
10531059
s += f", {k}=" + v
@@ -1335,6 +1341,14 @@ def parse_options(args: List[str]) -> argparse.Namespace:
13351341
help="exclude docstrings from the generated stub",
13361342
)
13371343

1344+
parser.add_argument(
1345+
"--exclude-values",
1346+
dest="exclude_values",
1347+
default=False,
1348+
action="store_true",
1349+
help="force the use of ... for values",
1350+
)
1351+
13381352
parser.add_argument(
13391353
"-q",
13401354
"--quiet",
@@ -1479,6 +1493,7 @@ def main(args: Optional[List[str]] = None) -> None:
14791493
recursive=opt.recursive,
14801494
include_docstrings=opt.include_docstrings,
14811495
include_private=opt.include_private,
1496+
max_expr_length=0 if opt.exclude_values else 50,
14821497
patterns=patterns,
14831498
output_file=file
14841499
)

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ foreach (NAME functions classes ndarray jax tensorflow stl enum typing make_iter
104104
set(EXTRA
105105
MARKER_FILE py.typed
106106
PATTERN_FILE "${CMAKE_CURRENT_SOURCE_DIR}/pattern_file.nb"
107+
EXCLUDE_VALUES
107108
)
108109
set(EXTRA_DEPENDS "${OUT_DIR}/py_stub_test.py")
109110
else()

tests/test_typing.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ NB_MODULE(test_typing_ext, m) {
114114
m.def("list_front", [](nb::list l) { return l[0]; },
115115
nb::sig("def list_front[T](arg: list[T], /) -> T"));
116116

117+
// Type variables with constraints and a bound.
118+
m.attr("T2") = nb::type_var("T2", "bound"_a = nb::type<Foo>());
119+
m.attr("T3") = nb::type_var("T3", *nb::make_tuple(nb::type<Foo>(), nb::type<Wrapper>()));
120+
117121
// Some statements that will be modified by the pattern file
118122
m.def("remove_me", []{});
119123
m.def("tweak_me", [](nb::object o) { return o; }, "prior docstring\nremains preserved");

tests/test_typing_ext.pyi.ref

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class CustomSignature(Iterable[int]):
4444
def value(self, value: Optional[int], /) -> None:
4545
"""docstring for setter"""
4646

47-
pytree: dict = {'a' : ('b', [123])}
47+
pytree: dict = ...
4848

4949
T = TypeVar("T", contravariant=True)
5050

@@ -63,6 +63,10 @@ class WrapperTypeParam[T]:
6363

6464
def list_front[T](arg: list[T], /) -> T: ...
6565

66+
T2 = TypeVar("T2", bound=Foo)
67+
68+
T3 = TypeVar("T3", Foo, Wrapper)
69+
6670
def tweak_me(arg: int):
6771
"""
6872
prior docstring

0 commit comments

Comments
 (0)