|
1 | 1 | #pragma once
|
2 | 2 | #include <array>
|
| 3 | +#include <numpy/arrayobject.h> |
| 4 | + |
| 5 | +#include "../numpy_proxy.hpp" |
| 6 | +#include "../py_converter.hpp" |
| 7 | +#include "./vector.hpp" |
3 | 8 |
|
4 | 9 | namespace cpp2py {
|
5 | 10 |
|
6 |
| - template <typename T, size_t R> struct py_converter<std::array<T, R>> { |
7 |
| - // -------------------------------------- |
| 11 | + template <typename T, size_t R> numpy_proxy make_numpy_proxy_from_heap_array(std::array<T, R> *arr_heap) { |
8 | 12 |
|
9 |
| - static PyObject *c2py(std::array<T, R> const &v) { |
10 |
| - PyObject *list = PyList_New(0); |
11 |
| - for (auto const &x : v) { |
12 |
| - pyref y = py_converter<T>::c2py(x); |
13 |
| - if (y.is_null() or (PyList_Append(list, y) == -1)) { |
14 |
| - Py_DECREF(list); |
15 |
| - return NULL; |
16 |
| - } // error |
17 |
| - } |
18 |
| - return list; |
| 13 | + auto delete_pycapsule = [](PyObject *capsule) { |
| 14 | + auto *ptr = static_cast<std::array<T, R> *>(PyCapsule_GetPointer(capsule, "guard")); |
| 15 | + delete ptr; |
| 16 | + }; |
| 17 | + PyObject *capsule = PyCapsule_New(arr_heap, "guard", delete_pycapsule); |
| 18 | + |
| 19 | + return numpy_proxy{1, // rank |
| 20 | + npy_type<std::decay_t<T>>, |
| 21 | + (void *)arr_heap->data(), |
| 22 | + std::is_const_v<T>, |
| 23 | + std::vector<long>{long(R)}, // extents |
| 24 | + std::vector<long>{sizeof(T)}, // strides |
| 25 | + capsule}; |
| 26 | + }; |
| 27 | + |
| 28 | + template <typename T, size_t R> numpy_proxy make_numpy_proxy_from_array(std::array<T, R> const &arr) { |
| 29 | + |
| 30 | + if constexpr (has_npy_type<T>) { |
| 31 | + auto *arr_heap = new std::array<T, R>{arr}; |
| 32 | + return make_numpy_proxy_from_heap_array(arr_heap); |
| 33 | + } else { |
| 34 | + auto *arr_heap = new std::array<pyref, R>{}; |
| 35 | + std::transform(begin(arr), end(arr), begin(*arr_heap), [](T const &x) { return py_converter<std::decay_t<T>>::c2py(x); }); |
| 36 | + return make_numpy_proxy_from_heap_array(arr_heap); |
19 | 37 | }
|
| 38 | + } |
| 39 | + |
| 40 | + // Make a new array from numpy view |
| 41 | + template <typename T, size_t R> std::array<T, R> make_array_from_numpy_proxy(numpy_proxy const &p) { |
| 42 | + EXPECTS(p.extents.size() == 1); |
| 43 | + EXPECTS(p.extents[0] == R); |
| 44 | + |
| 45 | + std::array<T, R> arr; |
| 46 | + |
| 47 | + if (p.element_type == npy_type<pyref>) { |
| 48 | + auto *data = static_cast<pyref *>(p.data); |
| 49 | + std::transform(data, data + R, begin(arr), [](PyObject *o) { return py_converter<std::decay_t<T>>::py2c(o); }); |
| 50 | + } else { |
| 51 | + EXPECTS(p.strides == std::vector<long>{sizeof(T)}); |
| 52 | + T *data = static_cast<T *>(p.data); |
| 53 | + std::copy(data, data + R, begin(arr)); |
| 54 | + } |
| 55 | + |
| 56 | + return arr; |
| 57 | + } |
| 58 | + |
| 59 | + // -------------------------------------- |
| 60 | + |
| 61 | + template <typename T, size_t R> struct py_converter<std::array<T, R>> { |
| 62 | + |
| 63 | + static PyObject *c2py(std::array<T, R> const &a) { return make_numpy_proxy_from_array(a).to_python(); } |
20 | 64 |
|
21 | 65 | // --------------------------------------
|
22 | 66 |
|
23 | 67 | static bool is_convertible(PyObject *ob, bool raise_exception) {
|
24 |
| - if (!PySequence_Check(ob)) goto _false; |
25 |
| - { |
26 |
| - pyref seq = PySequence_Fast(ob, "expected a sequence"); |
27 |
| - int len = PySequence_Size(ob); |
28 |
| - if (len != R) { |
29 |
| - if (raise_exception) { |
30 |
| - auto s = std::string{"Convertion to std::array<T, R> failed : the length of the sequence ( = "} + std::to_string(len) |
31 |
| - + " does not match R = " + std::to_string(R); |
32 |
| - PyErr_SetString(PyExc_TypeError, s.c_str()); |
33 |
| - } |
| 68 | + _import_array(); |
| 69 | + |
| 70 | + // Special case: 1-d ndarray of builtin type |
| 71 | + if (PyArray_Check(ob)) { |
| 72 | + PyArrayObject *arr = (PyArrayObject *)(ob); |
| 73 | +#ifdef PYTHON_NUMPY_VERSION_LT_17 |
| 74 | + int rank = arr->nd; |
| 75 | +#else |
| 76 | + int rank = PyArray_NDIM(arr); |
| 77 | +#endif |
| 78 | + if (PyArray_TYPE(arr) == npy_type<T> and rank == 1) return true; |
| 79 | + } |
| 80 | + |
| 81 | + if (!PySequence_Check(ob)) { |
| 82 | + if (raise_exception) { PyErr_SetString(PyExc_TypeError, "Cannot convert a non-sequence to std::array"); } |
| 83 | + return false; |
| 84 | + } |
| 85 | + |
| 86 | + pyref seq = PySequence_Fast(ob, "expected a sequence"); |
| 87 | + int len = PySequence_Size(ob); |
| 88 | + if (len != R) { |
| 89 | + if (raise_exception) { |
| 90 | + auto s = std::string{"Convertion to std::array<T, R> failed : the length of the sequence ( = "} + std::to_string(len) |
| 91 | + + " does not match R = " + std::to_string(R); |
| 92 | + PyErr_SetString(PyExc_TypeError, s.c_str()); |
| 93 | + } |
| 94 | + return false; |
| 95 | + } |
| 96 | + for (int i = 0; i < len; i++) { |
| 97 | + if (!py_converter<std::decay_t<T>>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) { |
| 98 | + if (PyErr_Occurred()) PyErr_Print(); |
34 | 99 | return false;
|
35 | 100 | }
|
36 |
| - for (int i = 0; i < len; i++) |
37 |
| - if (!py_converter<T>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) goto _false; // borrowed ref |
38 |
| - |
39 |
| - return true; |
40 | 101 | }
|
41 |
| - _false: |
42 |
| - if (raise_exception) { PyErr_SetString(PyExc_TypeError, "Cannot convert to std::array"); } |
43 |
| - return false; |
| 102 | + return true; |
44 | 103 | }
|
45 | 104 |
|
46 | 105 | // --------------------------------------
|
47 | 106 |
|
48 | 107 | static std::array<T, R> py2c(PyObject *ob) {
|
49 |
| - pyref seq = PySequence_Fast(ob, "expected a sequence"); |
| 108 | + _import_array(); |
| 109 | + |
| 110 | + // Special case: 1-d ndarray of builtin type |
| 111 | + if (PyArray_Check(ob)) { |
| 112 | + PyArrayObject *arr = (PyArrayObject *)(ob); |
| 113 | +#ifdef PYTHON_NUMPY_VERSION_LT_17 |
| 114 | + int rank = arr->nd; |
| 115 | +#else |
| 116 | + int rank = PyArray_NDIM(arr); |
| 117 | +#endif |
| 118 | + if (rank == 1) return make_array_from_numpy_proxy<T, R>(make_numpy_proxy(ob)); |
| 119 | + } |
| 120 | + |
| 121 | + ASSERT(PySequence_Check(ob)); |
50 | 122 | std::array<T, R> res;
|
51 |
| - for (int i = 0; i < R; i++) res[i] = py_converter<T>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i)); // borrowed ref |
| 123 | + pyref seq = PySequence_Fast(ob, "expected a sequence"); |
| 124 | + for (int i = 0; i < R; i++) res[i] = py_converter<std::decay_t<T>>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i)); // borrowed ref |
52 | 125 | return res;
|
53 | 126 | }
|
54 | 127 | };
|
|
0 commit comments