Skip to content

Commit dd350fe

Browse files
authored
Per-module precomputed constants, improve enum casting perf. (#1184)
Creating runtime constants, particularly strings, has a significant cost on Python. This commit adds a per-module stash of precomputed constants, allowing us to avoid this cost and improve the performance of casting enums from Python in the process.
1 parent d442daa commit dd350fe

File tree

9 files changed

+216
-17
lines changed

9 files changed

+216
-17
lines changed

include/nanobind/nb_defs.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,22 @@
2828
# define NB_INLINE __forceinline
2929
# define NB_NOINLINE __declspec(noinline)
3030
# define NB_INLINE_LAMBDA
31+
# define NB_NOUNROLL
3132
#else
3233
# define NB_EXPORT __attribute__ ((visibility("default")))
3334
# define NB_IMPORT NB_EXPORT
3435
# define NB_INLINE inline __attribute__((always_inline))
3536
# define NB_NOINLINE __attribute__((noinline))
3637
# if defined(__clang__)
3738
# define NB_INLINE_LAMBDA __attribute__((always_inline))
39+
# define NB_NOUNROLL _Pragma("nounroll")
3840
# else
3941
# define NB_INLINE_LAMBDA
42+
# if defined(__GNUC__)
43+
# define NB_NOUNROLL _Pragma("GCC unroll 0")
44+
# else
45+
# define NB_NOUNROLL
46+
# endif
4047
# endif
4148
#endif
4249

@@ -202,12 +209,14 @@
202209
X(const X &) = delete; \
203210
X &operator=(const X &) = delete;
204211

212+
#define NB_MOD_STATE_SIZE 80
213+
205214
// Helper macros to ensure macro arguments are expanded before token pasting/stringification
206215
#define NB_MODULE_IMPL(name, variable) NB_MODULE_IMPL2(name, variable)
207216
#define NB_MODULE_IMPL2(name, variable) \
208217
static void nanobind_##name##_exec_impl(nanobind::module_); \
209218
static int nanobind_##name##_exec(PyObject *m) { \
210-
nanobind::detail::init(NB_DOMAIN_STR); \
219+
nanobind::detail::nb_module_exec(NB_DOMAIN_STR, m); \
211220
try { \
212221
nanobind_##name##_exec_impl( \
213222
nanobind::borrow<nanobind::module_>(m)); \
@@ -227,8 +236,9 @@
227236
NB_MODULE_SLOTS_2 \
228237
}; \
229238
static struct PyModuleDef nanobind_##name##_module = { \
230-
PyModuleDef_HEAD_INIT, #name, nullptr, 0, nullptr, \
231-
nanobind_##name##_slots, nullptr, nullptr, nullptr \
239+
PyModuleDef_HEAD_INIT, #name, nullptr, NB_MOD_STATE_SIZE, nullptr, \
240+
nanobind_##name##_slots, nanobind::detail::nb_module_traverse, \
241+
nanobind::detail::nb_module_clear, nanobind::detail::nb_module_free \
232242
}; \
233243
extern "C" [[maybe_unused]] NB_EXPORT PyObject *PyInit_##name(void); \
234244
extern "C" PyObject *PyInit_##name(void) { \

include/nanobind/nb_lib.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,17 @@
99

1010
NAMESPACE_BEGIN(NB_NAMESPACE)
1111

12+
NAMESPACE_BEGIN(dlpack)
13+
14+
// The version of DLPack that is supported by libnanobind
15+
static constexpr uint32_t major_version = 0;
16+
static constexpr uint32_t minor_version = 0;
17+
1218
// Forward declarations for types in ndarray.h (1)
13-
namespace dlpack { struct dltensor; struct dtype; }
19+
struct dltensor;
20+
struct dtype;
21+
22+
NAMESPACE_END(dlpack)
1423

1524
NAMESPACE_BEGIN(detail)
1625

@@ -107,7 +116,10 @@ NB_CORE void raise_next_overload_if_null(void *p);
107116

108117
// ========================================================================
109118

110-
NB_CORE void init(const char *domain);
119+
NB_CORE void nb_module_exec(const char *domain, PyObject *m);
120+
NB_CORE int nb_module_traverse(PyObject *m, visitproc visit, void *arg);
121+
NB_CORE int nb_module_clear(PyObject *m);
122+
NB_CORE void nb_module_free(void *m);
111123

112124
// ========================================================================
113125

src/nb_enum.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ bool enum_from_python(const std::type_info *tp, PyObject *o, int64_t *out, uint8
190190
return false;
191191

192192
if ((t->flags & (uint32_t) enum_flags::is_flag) != 0 && Py_TYPE(o) == t->type_py) {
193-
PyObject *value_o = PyObject_GetAttrString(o, "value");
193+
PyObject *value_o =
194+
PyObject_GetAttr(o, static_pyobjects[pyobj_name::value_str]);
194195
if (value_o == nullptr) {
195196
PyErr_Clear();
196197
return false;

src/nb_internals.cpp

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,71 @@ void default_exception_translator(const std::exception_ptr &p, void *) {
164164
nb_internals *internals = nullptr;
165165
PyTypeObject *nb_meta_cache = nullptr;
166166

167+
168+
static const char* interned_c_strs[pyobj_name::string_count] {
169+
"value",
170+
"copy",
171+
"from_dlpack",
172+
"__dlpack__",
173+
"max_version",
174+
"dl_device",
175+
};
176+
177+
PyObject **static_pyobjects = nullptr;
178+
179+
static bool init_pyobjects(PyObject* m) {
180+
PyObject** pyobjects = (PyObject**) PyModule_GetState(m);
181+
if (!pyobjects)
182+
return false;
183+
184+
NB_NOUNROLL
185+
for (int i = 0; i < pyobj_name::string_count; ++i)
186+
pyobjects[i] = PyUnicode_InternFromString(interned_c_strs[i]);
187+
188+
pyobjects[pyobj_name::copy_tpl] =
189+
PyTuple_Pack(1, pyobjects[pyobj_name::copy_str]);
190+
pyobjects[pyobj_name::max_version_tpl] =
191+
PyTuple_Pack(1, pyobjects[pyobj_name::max_version_str]);
192+
193+
PyObject* one = PyLong_FromLong(1);
194+
PyObject* zero = PyLong_FromLong(0);
195+
pyobjects[pyobj_name::dl_cpu_tpl] = PyTuple_Pack(2, one, zero);
196+
Py_DECREF(zero);
197+
Py_DECREF(one);
198+
199+
PyObject* major = PyLong_FromLong(dlpack::major_version);
200+
PyObject* minor = PyLong_FromLong(dlpack::minor_version);
201+
pyobjects[pyobj_name::dl_version_tpl] = PyTuple_Pack(2, major, minor);
202+
Py_DECREF(minor);
203+
Py_DECREF(major);
204+
205+
static_pyobjects = pyobjects;
206+
207+
return true;
208+
}
209+
210+
NB_NOINLINE int nb_module_traverse(PyObject *m, visitproc visit, void *arg) {
211+
PyObject** pyobjects = (PyObject**) PyModule_GetState(m);
212+
NB_NOUNROLL
213+
for (int i = 0; i < pyobj_name::total_count; ++i)
214+
Py_VISIT(pyobjects[i]);
215+
return 0;
216+
}
217+
218+
NB_NOINLINE int nb_module_clear(PyObject *m) {
219+
PyObject** pyobjects = (PyObject**) PyModule_GetState(m);
220+
NB_NOUNROLL
221+
for (int i = 0; i < pyobj_name::total_count; ++i)
222+
Py_CLEAR(pyobjects[i]);
223+
return 0;
224+
}
225+
226+
void nb_module_free(void *m) {
227+
// Allow nanobind_##name##_exec to omit calling nb_module_clear on error.
228+
(void) nb_module_clear((PyObject *) m);
229+
}
230+
231+
167232
static bool is_alive_value = false;
168233
static bool *is_alive_ptr = &is_alive_value;
169234
bool is_alive() noexcept { return *is_alive_ptr; }
@@ -317,29 +382,34 @@ static void internals_cleanup() {
317382
#endif
318383
}
319384

320-
NB_NOINLINE void init(const char *name) {
385+
NB_NOINLINE void nb_module_exec(const char *name, PyObject *m) {
321386
if (internals)
322387
return;
323388

389+
check(init_pyobjects(m), "nanobind::detail::nb_module_exec(): "
390+
"could not initialize module state!");
391+
324392
#if defined(PYPY_VERSION)
325393
PyObject *dict = PyEval_GetBuiltins();
326394
#elif PY_VERSION_HEX < 0x03090000
327395
PyObject *dict = PyInterpreterState_GetDict(_PyInterpreterState_Get());
328396
#else
329397
PyObject *dict = PyInterpreterState_GetDict(PyInterpreterState_Get());
330398
#endif
331-
check(dict, "nanobind::detail::init(): could not access internals dictionary!");
399+
check(dict, "nanobind::detail::nb_module_exec(): "
400+
"could not access internals dictionary!");
332401

333402
PyObject *key = PyUnicode_FromFormat("__nb_internals_%s_%s__",
334403
abi_tag(), name ? name : "");
335-
check(key, "nanobind::detail::init(): could not create dictionary key!");
404+
check(key, "nanobind::detail::nb_module_exec(): "
405+
"could not create dictionary key!");
336406

337407
PyObject *capsule = dict_get_item_ref_or_fail(dict, key);
338408
if (capsule) {
339409
Py_DECREF(key);
340410
internals = (nb_internals *) PyCapsule_GetPointer(capsule, "nb_internals");
341-
check(internals,
342-
"nanobind::detail::internals_fetch(): capsule pointer is NULL!");
411+
check(internals, "nanobind::detail::nb_module_exec(): "
412+
"capsule pointer is NULL!");
343413
nb_meta_cache = internals->nb_meta;
344414
is_alive_ptr = internals->is_alive_ptr;
345415
Py_DECREF(capsule);
@@ -381,7 +451,7 @@ NB_NOINLINE void init(const char *name) {
381451

382452
check(p->nb_module && p->nb_meta && p->nb_type_dict && p->nb_func &&
383453
p->nb_method && p->nb_bound_method,
384-
"nanobind::detail::init(): initialization failed!");
454+
"nanobind::detail::nb_module_exec(): initialization failed!");
385455

386456
#if PY_VERSION_HEX < 0x03090000
387457
p->nb_func->tp_flags |= NB_HAVE_VECTORCALL;
@@ -476,7 +546,7 @@ NB_NOINLINE void init(const char *name) {
476546
capsule = PyCapsule_New(p, "nb_internals", nullptr);
477547
int rv = PyDict_SetItem(dict, key, capsule);
478548
check(!rv && capsule,
479-
"nanobind::detail::init(): capsule creation failed!");
549+
"nanobind::detail::nb_module_exec(): capsule creation failed!");
480550
Py_DECREF(capsule);
481551
Py_DECREF(key);
482552
internals = p;

src/nb_internals.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,30 @@ struct nb_internals {
420420
size_t shard_count = 1;
421421
};
422422

423+
// Names for the PyObject* entries in the per-module state array.
424+
// These names are scoped, but will implicitly convert to int.
425+
struct pyobj_name {
426+
enum : int {
427+
value_str = 0, // string "value"
428+
copy_str, // string "copy"
429+
from_dlpack_str, // string "from_dlpack"
430+
dunder_dlpack_str, // string "__dlpack__"
431+
max_version_str, // string "max_version"
432+
dl_device_str, // string "dl_device"
433+
string_count,
434+
435+
copy_tpl = string_count, // tuple ("copy")
436+
max_version_tpl, // tuple ("max_version")
437+
dl_cpu_tpl, // tuple (1, 0), which corresponds to nb::device::cpu
438+
dl_version_tpl, // tuple (dlpack::major_version, dlpack::minor_version)
439+
total_count
440+
};
441+
};
442+
443+
static_assert(pyobj_name::total_count * sizeof(PyObject*) == NB_MOD_STATE_SIZE);
444+
445+
extern PyObject **static_pyobjects;
446+
423447
/// Convenience macro to potentially access cached functions
424448
#if defined(Py_LIMITED_API)
425449
# define NB_SLOT(type, name) internals->type##_##name

tests/inter_module.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Shared create_shared() {
44
return { 123 };
55
}
66

7-
bool check_shared(const Shared &shared) {
8-
return shared.value == 123;
7+
bool check_shared(const Shared &shared, int expected) {
8+
return shared.value == expected;
9+
}
10+
11+
void increment_shared(Shared &shared) {
12+
++shared.value;
913
}

tests/inter_module.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ struct EXPORT_SHARED Shared {
1111
};
1212

1313
extern EXPORT_SHARED Shared create_shared();
14-
extern EXPORT_SHARED bool check_shared(const Shared &shared);
14+
extern EXPORT_SHARED bool check_shared(const Shared &shared, int expected);
15+
extern EXPORT_SHARED void increment_shared(Shared &shared);

tests/test_inter_module.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,86 @@
44
import pytest
55
from common import xfail_on_pypy_darwin
66

7+
try:
8+
from concurrent import interpreters # Added in Python 3.14
9+
def needs_interpreters(x):
10+
return x
11+
except:
12+
needs_interpreters = pytest.mark.skip(reason="interpreters required")
13+
14+
715
@xfail_on_pypy_darwin
816
def test01_inter_module():
917
s = t1.create_shared()
10-
assert t2.check_shared(s)
18+
assert t2.check_shared(s, 123)
19+
t2.increment_shared(s)
20+
assert t2.check_shared(s, 124)
1121
with pytest.raises(TypeError) as excinfo:
1222
assert t3.check_shared(s)
1323
assert 'incompatible function arguments' in str(excinfo.value)
24+
25+
26+
@xfail_on_pypy_darwin
27+
def test02_reload_module():
28+
s1 = t1.create_shared()
29+
s2 = t1.create_shared()
30+
assert s2 is not s1
31+
assert type(s2) is type(s1)
32+
t2.increment_shared(s2)
33+
import importlib
34+
new_t1 = importlib.reload(t1)
35+
assert new_t1 is t1
36+
s3 = new_t1.create_shared()
37+
assert type(s3) is type(s1)
38+
new_t2 = importlib.reload(t2)
39+
assert new_t2 is t2
40+
s4 = new_t1.create_shared()
41+
assert type(s4) is type(s1)
42+
assert new_t2.check_shared(s2, 124)
43+
44+
45+
@xfail_on_pypy_darwin
46+
def test03_reimport_module():
47+
s1 = t1.create_shared()
48+
s2 = t1.create_shared()
49+
t2.increment_shared(s2)
50+
import sys
51+
del sys.modules['test_inter_module_1_ext']
52+
import test_inter_module_1_ext as new_t1
53+
assert new_t1 is not t1
54+
s3 = new_t1.create_shared()
55+
assert type(s3) is type(s1)
56+
del sys.modules['test_inter_module_2_ext']
57+
with pytest.warns(RuntimeWarning, match="'Shared' was already registered"):
58+
import test_inter_module_2_ext as new_t2
59+
assert new_t2 is not t2
60+
s4 = new_t1.create_shared()
61+
assert type(s4) is type(s1)
62+
assert new_t2.check_shared(s2, 124)
63+
64+
65+
def run():
66+
import sys
67+
if 'tests' not in sys.path[0]:
68+
import os
69+
builddir = sys.path[0]
70+
sys.path.insert(0, os.path.join(builddir, 'tests', 'Release'))
71+
sys.path.insert(0, os.path.join(builddir, 'tests', 'Debug'))
72+
sys.path.insert(0, os.path.join(builddir, 'tests'))
73+
import test_inter_module_1_ext as new_t1
74+
import test_inter_module_2_ext as new_t2
75+
success = True
76+
s = new_t1.create_shared()
77+
success &= new_t2.check_shared(s, 123)
78+
new_t2.increment_shared(s)
79+
success &= new_t2.check_shared(s, 124)
80+
return success
81+
82+
@needs_interpreters
83+
def test04_subinterpreters():
84+
assert run()
85+
interp = interpreters.create()
86+
with pytest.raises(interpreters.ExecutionFailed) as excinfo:
87+
assert interp.call(run)
88+
assert 'does not support loading in subinterpreters' in str(excinfo.value)
89+
interp.close()

tests/test_inter_module_2.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ namespace nb = nanobind;
66
NB_MODULE(test_inter_module_2_ext, m) {
77
nb::class_<Shared>(m, "Shared");
88
m.def("check_shared", &check_shared);
9+
m.def("increment_shared", &increment_shared);
910
}

0 commit comments

Comments
 (0)