@@ -14,7 +14,43 @@ use pyo3::{
1414
1515use crate :: npyffi:: * ;
1616
17- pub ( crate ) const MOD_NAME : & str = "numpy._core.multiarray" ;
17+ pub ( crate ) fn numpy_core_name ( py : Python < ' _ > ) -> PyResult < & ' static str > {
18+ static MOD_NAME : GILOnceCell < & ' static str > = GILOnceCell :: new ( ) ;
19+
20+ MOD_NAME
21+ . get_or_try_init ( py, || {
22+ // numpy 2 renamed to numpy._core
23+
24+ // strategy mirrored from https://github.com/pybind/pybind11/blob/af67e87393b0f867ccffc2702885eea12de063fc/include/pybind11/numpy.h#L175-L195
25+
26+ let numpy = PyModule :: import_bound ( py, "numpy" ) ?;
27+ let version_string = numpy. getattr ( "__version__" ) ?;
28+
29+ let numpy_lib = PyModule :: import_bound ( py, "numpy.lib" ) ?;
30+ let numpy_version = numpy_lib
31+ . getattr ( "NumpyVersion" ) ?
32+ . call1 ( ( version_string, ) ) ?;
33+ let major_version: u8 = numpy_version. getattr ( "major" ) ?. extract ( ) ?;
34+
35+ Ok ( if major_version >= 2 {
36+ "numpy._core"
37+ } else {
38+ "numpy.core"
39+ } )
40+ } )
41+ . copied ( )
42+ }
43+
44+ pub ( crate ) fn mod_name ( py : Python < ' _ > ) -> PyResult < & ' static str > {
45+ static MOD_NAME : GILOnceCell < String > = GILOnceCell :: new ( ) ;
46+ MOD_NAME
47+ . get_or_try_init ( py, || {
48+ let numpy_core = numpy_core_name ( py) ?;
49+ Ok ( format ! ( "{}.multiarray" , numpy_core) )
50+ } )
51+ . map ( String :: as_str)
52+ }
53+
1854const CAPSULE_NAME : & str = "_ARRAY_API" ;
1955
2056/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
@@ -49,7 +85,7 @@ impl PyArrayAPI {
4985 unsafe fn get < ' py > ( & self , py : Python < ' py > , offset : isize ) -> * const * const c_void {
5086 let api = self
5187 . 0
52- . get_or_try_init ( py, || get_numpy_api ( py, MOD_NAME , CAPSULE_NAME ) )
88+ . get_or_try_init ( py, || get_numpy_api ( py, mod_name ( py ) ? , CAPSULE_NAME ) )
5389 . expect ( "Failed to access NumPy array API capsule" ) ;
5490
5591 api. offset ( offset)
0 commit comments