|
1 | 1 | #pragma once
|
2 | 2 | #include <array>
|
| 3 | +#include <numpy/arrayobject.h> |
| 4 | + |
3 | 5 | #include "../pyref.hpp"
|
| 6 | +#include "../numpy_proxy.hpp" |
| 7 | +#include "../py_converter.hpp" |
| 8 | +#include "./vector.hpp" |
4 | 9 |
|
5 | 10 | namespace cpp2py {
|
6 | 11 |
|
7 |
| - template <typename T, size_t R> struct py_converter<std::array<T, R>> { |
8 |
| - // -------------------------------------- |
| 12 | + template <typename T, size_t R> numpy_proxy make_numpy_proxy_from_heap_array(std::array<T, R> *arr_heap) { |
9 | 13 |
|
10 |
| - static PyObject *c2py(std::array<T, R> const &v) { |
11 |
| - PyObject *list = PyList_New(0); |
12 |
| - for (auto const &x : v) { |
13 |
| - pyref y = py_converter<T>::c2py(x); |
14 |
| - if (y.is_null() or (PyList_Append(list, y) == -1)) { |
15 |
| - Py_DECREF(list); |
16 |
| - return NULL; |
17 |
| - } // error |
18 |
| - } |
19 |
| - return list; |
| 14 | + auto delete_pycapsule = [](PyObject *capsule) { |
| 15 | + auto *ptr = static_cast<std::array<T, R> *>(PyCapsule_GetPointer(capsule, "guard")); |
| 16 | + delete ptr; |
| 17 | + }; |
| 18 | + PyObject *capsule = PyCapsule_New(arr_heap, "guard", delete_pycapsule); |
| 19 | + |
| 20 | + return numpy_proxy{1, // rank |
| 21 | + npy_type<std::decay_t<T>>, |
| 22 | + (void *)arr_heap->data(), |
| 23 | + std::is_const_v<T>, |
| 24 | + std::vector<long>{long(R)}, // extents |
| 25 | + std::vector<long>{sizeof(T)}, // strides |
| 26 | + capsule}; |
| 27 | + }; |
| 28 | + |
| 29 | + template <typename T, size_t R> numpy_proxy make_numpy_proxy_from_array(std::array<T, R> const &arr) { |
| 30 | + |
| 31 | + if constexpr (has_npy_type<T>) { |
| 32 | + auto *arr_heap = new std::array<T, R>{arr}; |
| 33 | + return make_numpy_proxy_from_heap_array(arr_heap); |
| 34 | + } else { |
| 35 | + auto *arr_heap = new std::array<pyref, R>{}; |
| 36 | + std::transform(begin(arr), end(arr), begin(*arr_heap), [](T const &x) { return py_converter<std::decay_t<T>>::c2py(x); }); |
| 37 | + return make_numpy_proxy_from_heap_array(arr_heap); |
20 | 38 | }
|
| 39 | + } |
| 40 | + |
| 41 | + // Make a new array from numpy view |
| 42 | + template <typename T, size_t R> std::array<T, R> make_array_from_numpy_proxy(numpy_proxy const &p) { |
| 43 | + EXPECTS(p.extents.size() == 1); |
| 44 | + EXPECTS(p.extents[0] == R); |
| 45 | + |
| 46 | + std::array<T, R> arr; |
| 47 | + |
| 48 | + if (p.element_type == npy_type<pyref>) { |
| 49 | + auto *data = static_cast<pyref *>(p.data); |
| 50 | + std::transform(data, data + R, begin(arr), [](PyObject *o) { return py_converter<std::decay_t<T>>::py2c(o); }); |
| 51 | + } else { |
| 52 | + EXPECTS(p.strides == std::vector<long>{sizeof(T)}); |
| 53 | + T *data = static_cast<T *>(p.data); |
| 54 | + std::copy(data, data + R, begin(arr)); |
| 55 | + } |
| 56 | + |
| 57 | + return arr; |
| 58 | + } |
| 59 | + |
| 60 | + // -------------------------------------- |
| 61 | + |
| 62 | + template <typename T, size_t R> struct py_converter<std::array<T, R>> { |
| 63 | + |
| 64 | + static PyObject *c2py(std::array<T, R> const &a) { return make_numpy_proxy_from_array(a).to_python(); } |
21 | 65 |
|
22 | 66 | // --------------------------------------
|
23 | 67 |
|
24 | 68 | static bool is_convertible(PyObject *ob, bool raise_exception) {
|
25 |
| - if (!PySequence_Check(ob)) goto _false; |
26 |
| - { |
27 |
| - pyref seq = PySequence_Fast(ob, "expected a sequence"); |
28 |
| - int len = PySequence_Size(ob); |
29 |
| - if (len != R) { |
30 |
| - if (raise_exception) { |
31 |
| - auto s = std::string{"Convertion to std::array<T, R> failed : the length of the sequence ( = "} + std::to_string(len) |
32 |
| - + " does not match R = " + std::to_string(R); |
33 |
| - PyErr_SetString(PyExc_TypeError, s.c_str()); |
34 |
| - } |
| 69 | + _import_array(); |
| 70 | + |
| 71 | + // Special case: 1-d ndarray of builtin type |
| 72 | + if (PyArray_Check(ob)) { |
| 73 | + PyArrayObject *arr = (PyArrayObject *)(ob); |
| 74 | +#ifdef PYTHON_NUMPY_VERSION_LT_17 |
| 75 | + int rank = arr->nd; |
| 76 | +#else |
| 77 | + int rank = PyArray_NDIM(arr); |
| 78 | +#endif |
| 79 | + if (PyArray_TYPE(arr) == npy_type<T> and rank == 1) return true; |
| 80 | + } |
| 81 | + |
| 82 | + if (!PySequence_Check(ob)) { |
| 83 | + if (raise_exception) { PyErr_SetString(PyExc_TypeError, ("Cannot convert "s + to_string(ob) + " to std::array as it is not a sequence"s).c_str()); } |
| 84 | + return false; |
| 85 | + } |
| 86 | + |
| 87 | + pyref seq = PySequence_Fast(ob, "expected a sequence"); |
| 88 | + int len = PySequence_Size(ob); |
| 89 | + if (len != R) { |
| 90 | + if (raise_exception) { |
| 91 | + auto s = std::string{"Convertion to std::array<T, R> failed : the length of the sequence ( = "} + std::to_string(len) |
| 92 | + + " does not match R = " + std::to_string(R); |
| 93 | + PyErr_SetString(PyExc_TypeError, s.c_str()); |
| 94 | + } |
| 95 | + return false; |
| 96 | + } |
| 97 | + for (int i = 0; i < len; i++) { |
| 98 | + if (!py_converter<std::decay_t<T>>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) { |
| 99 | + if (PyErr_Occurred()) PyErr_Print(); |
35 | 100 | return false;
|
36 | 101 | }
|
37 |
| - for (int i = 0; i < len; i++) |
38 |
| - if (!py_converter<T>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) goto _false; // borrowed ref |
39 |
| - |
40 |
| - return true; |
41 | 102 | }
|
42 |
| - _false: |
43 |
| - if (raise_exception) { PyErr_SetString(PyExc_TypeError, ("Cannot convert "s + to_string(ob) + " to std::array"s).c_str()); } |
44 |
| - return false; |
| 103 | + return true; |
45 | 104 | }
|
46 | 105 |
|
47 | 106 | // --------------------------------------
|
48 | 107 |
|
49 | 108 | static std::array<T, R> py2c(PyObject *ob) {
|
50 |
| - pyref seq = PySequence_Fast(ob, "expected a sequence"); |
| 109 | + _import_array(); |
| 110 | + |
| 111 | + // Special case: 1-d ndarray of builtin type |
| 112 | + if (PyArray_Check(ob)) { |
| 113 | + PyArrayObject *arr = (PyArrayObject *)(ob); |
| 114 | +#ifdef PYTHON_NUMPY_VERSION_LT_17 |
| 115 | + int rank = arr->nd; |
| 116 | +#else |
| 117 | + int rank = PyArray_NDIM(arr); |
| 118 | +#endif |
| 119 | + if (rank == 1) return make_array_from_numpy_proxy<T, R>(make_numpy_proxy(ob)); |
| 120 | + } |
| 121 | + |
| 122 | + ASSERT(PySequence_Check(ob)); |
51 | 123 | std::array<T, R> res;
|
52 |
| - for (int i = 0; i < R; i++) res[i] = py_converter<T>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i)); // borrowed ref |
| 124 | + pyref seq = PySequence_Fast(ob, "expected a sequence"); |
| 125 | + for (int i = 0; i < R; i++) res[i] = py_converter<std::decay_t<T>>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i)); // borrowed ref |
53 | 126 | return res;
|
54 | 127 | }
|
55 | 128 | };
|
|
0 commit comments