|
1 | 1 | #pragma once
|
2 |
| -//#include <vector> |
3 |
| -//#include <numeric> |
| 2 | +#include <vector> |
| 3 | +#include <numpy/arrayobject.h> |
| 4 | + |
| 5 | +#include "../numpy_proxy.hpp" |
4 | 6 |
|
5 | 7 | namespace cpp2py {
|
6 | 8 |
|
| 9 | + template <typename T> |
| 10 | + static void delete_pycapsule(PyObject *capsule) { |
| 11 | + auto *ptr = static_cast<std::unique_ptr<T[]> *>(PyCapsule_GetPointer(capsule, "guard")); |
| 12 | + delete ptr; |
| 13 | + } |
| 14 | + |
| 15 | + // Convert vector to numpy_proxy, WARNING: Deep Copy |
| 16 | + template <typename T> numpy_proxy make_numpy_proxy_from_vector(std::vector<T> const &v) { |
| 17 | + |
| 18 | + auto * data_ptr = new std::unique_ptr<T[]>{new T[v.size()]}; |
| 19 | + std::copy(begin(v), end(v), data_ptr->get()); |
| 20 | + auto capsule = PyCapsule_New(data_ptr, "guard", &delete_pycapsule<T>); |
| 21 | + |
| 22 | + return {1, // rank |
| 23 | + npy_type<std::remove_const_t<T>>, |
| 24 | + (void *)data_ptr->get(), |
| 25 | + std::is_const_v<T>, |
| 26 | + v_t{static_cast<long>(v.size())}, // extents |
| 27 | + v_t{sizeof(T)}, // strides |
| 28 | + capsule}; |
| 29 | + } |
| 30 | + |
| 31 | + // Make a new vector from numpy view |
| 32 | + template <typename T> |
| 33 | + std::vector<T> make_vector_from_numpy_proxy(numpy_proxy const &p) { |
| 34 | + //EXPECTS(p.extents.size() == 1); |
| 35 | + //EXPECTS(p.strides == v_t{sizeof(T)}); |
| 36 | + |
| 37 | + T * data = static_cast<T *>(p.data); |
| 38 | + long size = p.extents[0]; |
| 39 | + |
| 40 | + std::vector<T> v(size); |
| 41 | + std::copy(data, data + size, begin(v)); |
| 42 | + return v; |
| 43 | + } |
| 44 | + |
| 45 | + // -------------------------------------- |
| 46 | + |
7 | 47 | template <typename T> struct py_converter<std::vector<T>> {
|
8 |
| - |
9 |
| - // -------------------------------------- |
10 |
| - |
11 |
| - static PyObject *c2py(std::vector<T> const &v) { |
12 |
| - PyObject *list = PyList_New(0); |
13 |
| - for (auto const &x : v) { |
14 |
| - pyref y = py_converter<T>::c2py(x); |
15 |
| - if (y.is_null() or (PyList_Append(list, y) == -1)) { |
16 |
| - Py_DECREF(list); |
17 |
| - return NULL; |
18 |
| - } // error |
| 48 | + |
| 49 | + static PyObject *c2py(std::vector<T> const &v) { |
| 50 | + |
| 51 | + if constexpr (has_npy_type<T>) { |
| 52 | + return make_numpy_proxy_from_vector(v).to_python(); |
| 53 | + } |
| 54 | + else{ // Convert to Python List |
| 55 | + PyObject *list = PyList_New(0); |
| 56 | + for (auto const &x : v) { |
| 57 | + pyref y = py_converter<T>::c2py(x); |
| 58 | + if (y.is_null() or (PyList_Append(list, y) == -1)) { |
| 59 | + Py_DECREF(list); |
| 60 | + return NULL; |
| 61 | + } // error |
| 62 | + } |
| 63 | + return list; |
19 | 64 | }
|
20 |
| - return list; |
21 | 65 | }
|
22 | 66 |
|
23 |
| - // -------------------------------------- |
| 67 | + // -------------------------------------- |
24 | 68 |
|
25 |
| - static bool is_convertible(PyObject *ob, bool raise_exception) { |
26 |
| - if (!PySequence_Check(ob)) goto _false; |
27 |
| - { |
| 69 | + static bool is_convertible(PyObject *ob, bool raise_exception) { |
| 70 | + if (PySequence_Check(ob)) { |
28 | 71 | pyref seq = PySequence_Fast(ob, "expected a sequence");
|
29 | 72 | int len = PySequence_Size(ob);
|
30 | 73 | for (int i = 0; i < len; i++)
|
31 | 74 | if (!py_converter<T>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) goto _false; //borrowed ref
|
32 | 75 | return true;
|
| 76 | + } else if (PyArray_Check(ob)) { |
| 77 | + PyArrayObject *arr = (PyArrayObject *)(ob); |
| 78 | + if (PyArray_TYPE(arr) != npy_type<T>) goto _false; |
| 79 | +#ifdef PYTHON_NUMPY_VERSION_LT_17 |
| 80 | + int rank = arr->nd; |
| 81 | +#else |
| 82 | + int rank = PyArray_NDIM(arr); |
| 83 | +#endif |
| 84 | + if (rank != 1) goto _false; |
| 85 | + return true; |
33 | 86 | }
|
34 | 87 | _false:
|
35 | 88 | if (raise_exception) { PyErr_SetString(PyExc_TypeError, "Cannot convert to std::vector"); }
|
36 | 89 | return false;
|
37 | 90 | }
|
38 | 91 |
|
39 |
| - // -------------------------------------- |
40 |
| - |
41 |
| - static std::vector<T> py2c(PyObject *ob) { |
42 |
| - pyref seq = PySequence_Fast(ob, "expected a sequence"); |
43 |
| - std::vector<T> res; |
44 |
| - int len = PySequence_Size(ob); |
45 |
| - for (int i = 0; i < len; i++) res.push_back(py_converter<T>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i))); //borrowed ref |
46 |
| - return res; |
| 92 | + // -------------------------------------- |
| 93 | + |
| 94 | + static std::vector<T> py2c(PyObject *ob) { |
| 95 | + if (PySequence_Check(ob)) { |
| 96 | + std::vector<T> res; |
| 97 | + pyref seq = PySequence_Fast(ob, "expected a sequence"); |
| 98 | + int len = PySequence_Size(ob); |
| 99 | + for (int i = 0; i < len; i++) res.push_back(py_converter<T>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i))); //borrowed ref |
| 100 | + return res; |
| 101 | + } |
| 102 | + //ASSERT(PyArray_Check(ob)); |
| 103 | + return make_vector_from_numpy_proxy<T>(make_numpy_proxy(ob)); |
47 | 104 | }
|
48 | 105 | };
|
49 | 106 |
|
|
0 commit comments