Skip to content

Commit 80ec453

Browse files
committed
Convert std::vector and std::array of non-basic types to NPY arrays
1 parent ca17d77 commit 80ec453

File tree

2 files changed

+146
-71
lines changed

2 files changed

+146
-71
lines changed

c++/cpp2py/converters/std_array.hpp

+104-31
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,128 @@
11
#pragma once
22
#include <array>
3+
#include <numpy/arrayobject.h>
4+
35
#include "../pyref.hpp"
6+
#include "../numpy_proxy.hpp"
7+
#include "../py_converter.hpp"
8+
#include "./vector.hpp"
49

510
namespace cpp2py {
611

7-
template <typename T, size_t R> struct py_converter<std::array<T, R>> {
8-
// --------------------------------------
12+
template <typename T, size_t R> numpy_proxy make_numpy_proxy_from_heap_array(std::array<T, R> *arr_heap) {
913

10-
static PyObject *c2py(std::array<T, R> const &v) {
11-
PyObject *list = PyList_New(0);
12-
for (auto const &x : v) {
13-
pyref y = py_converter<T>::c2py(x);
14-
if (y.is_null() or (PyList_Append(list, y) == -1)) {
15-
Py_DECREF(list);
16-
return NULL;
17-
} // error
18-
}
19-
return list;
14+
auto delete_pycapsule = [](PyObject *capsule) {
15+
auto *ptr = static_cast<std::array<T, R> *>(PyCapsule_GetPointer(capsule, "guard"));
16+
delete ptr;
17+
};
18+
PyObject *capsule = PyCapsule_New(arr_heap, "guard", delete_pycapsule);
19+
20+
return numpy_proxy{1, // rank
21+
npy_type<std::decay_t<T>>,
22+
(void *)arr_heap->data(),
23+
std::is_const_v<T>,
24+
std::vector<long>{long(R)}, // extents
25+
std::vector<long>{sizeof(T)}, // strides
26+
capsule};
27+
};
28+
29+
template <typename T, size_t R> numpy_proxy make_numpy_proxy_from_array(std::array<T, R> const &arr) {
30+
31+
if constexpr (has_npy_type<T>) {
32+
auto *arr_heap = new std::array<T, R>{arr};
33+
return make_numpy_proxy_from_heap_array(arr_heap);
34+
} else {
35+
auto *arr_heap = new std::array<pyref, R>{};
36+
std::transform(begin(arr), end(arr), begin(*arr_heap), [](T const &x) { return py_converter<std::decay_t<T>>::c2py(x); });
37+
return make_numpy_proxy_from_heap_array(arr_heap);
2038
}
39+
}
40+
41+
// Make a new array from numpy view
42+
template <typename T, size_t R> std::array<T, R> make_array_from_numpy_proxy(numpy_proxy const &p) {
43+
EXPECTS(p.extents.size() == 1);
44+
EXPECTS(p.extents[0] == R);
45+
46+
std::array<T, R> arr;
47+
48+
if (p.element_type == npy_type<pyref>) {
49+
auto *data = static_cast<pyref *>(p.data);
50+
std::transform(data, data + R, begin(arr), [](PyObject *o) { return py_converter<std::decay_t<T>>::py2c(o); });
51+
} else {
52+
EXPECTS(p.strides == std::vector<long>{sizeof(T)});
53+
T *data = static_cast<T *>(p.data);
54+
std::copy(data, data + R, begin(arr));
55+
}
56+
57+
return arr;
58+
}
59+
60+
// --------------------------------------
61+
62+
template <typename T, size_t R> struct py_converter<std::array<T, R>> {
63+
64+
static PyObject *c2py(std::array<T, R> const &a) { return make_numpy_proxy_from_array(a).to_python(); }
2165

2266
// --------------------------------------
2367

2468
static bool is_convertible(PyObject *ob, bool raise_exception) {
25-
if (!PySequence_Check(ob)) goto _false;
26-
{
27-
pyref seq = PySequence_Fast(ob, "expected a sequence");
28-
int len = PySequence_Size(ob);
29-
if (len != R) {
30-
if (raise_exception) {
31-
auto s = std::string{"Convertion to std::array<T, R> failed : the length of the sequence ( = "} + std::to_string(len)
32-
+ " does not match R = " + std::to_string(R);
33-
PyErr_SetString(PyExc_TypeError, s.c_str());
34-
}
69+
_import_array();
70+
71+
// Special case: 1-d ndarray of builtin type
72+
if (PyArray_Check(ob)) {
73+
PyArrayObject *arr = (PyArrayObject *)(ob);
74+
#ifdef PYTHON_NUMPY_VERSION_LT_17
75+
int rank = arr->nd;
76+
#else
77+
int rank = PyArray_NDIM(arr);
78+
#endif
79+
if (PyArray_TYPE(arr) == npy_type<T> and rank == 1) return true;
80+
}
81+
82+
if (!PySequence_Check(ob)) {
83+
if (raise_exception) { PyErr_SetString(PyExc_TypeError, ("Cannot convert "s + to_string(ob) + " to std::array as it is not a sequence"s).c_str()); }
84+
return false;
85+
}
86+
87+
pyref seq = PySequence_Fast(ob, "expected a sequence");
88+
int len = PySequence_Size(ob);
89+
if (len != R) {
90+
if (raise_exception) {
91+
auto s = std::string{"Convertion to std::array<T, R> failed : the length of the sequence ( = "} + std::to_string(len)
92+
+ " does not match R = " + std::to_string(R);
93+
PyErr_SetString(PyExc_TypeError, s.c_str());
94+
}
95+
return false;
96+
}
97+
for (int i = 0; i < len; i++) {
98+
if (!py_converter<std::decay_t<T>>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) {
99+
if (PyErr_Occurred()) PyErr_Print();
35100
return false;
36101
}
37-
for (int i = 0; i < len; i++)
38-
if (!py_converter<T>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) goto _false; // borrowed ref
39-
40-
return true;
41102
}
42-
_false:
43-
if (raise_exception) { PyErr_SetString(PyExc_TypeError, ("Cannot convert "s + to_string(ob) + " to std::array"s).c_str()); }
44-
return false;
103+
return true;
45104
}
46105

47106
// --------------------------------------
48107

49108
static std::array<T, R> py2c(PyObject *ob) {
50-
pyref seq = PySequence_Fast(ob, "expected a sequence");
109+
_import_array();
110+
111+
// Special case: 1-d ndarray of builtin type
112+
if (PyArray_Check(ob)) {
113+
PyArrayObject *arr = (PyArrayObject *)(ob);
114+
#ifdef PYTHON_NUMPY_VERSION_LT_17
115+
int rank = arr->nd;
116+
#else
117+
int rank = PyArray_NDIM(arr);
118+
#endif
119+
if (rank == 1) return make_array_from_numpy_proxy<T, R>(make_numpy_proxy(ob));
120+
}
121+
122+
ASSERT(PySequence_Check(ob));
51123
std::array<T, R> res;
52-
for (int i = 0; i < R; i++) res[i] = py_converter<T>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i)); // borrowed ref
124+
pyref seq = PySequence_Fast(ob, "expected a sequence");
125+
for (int i = 0; i < R; i++) res[i] = py_converter<std::decay_t<T>>::py2c(PySequence_Fast_GET_ITEM((PyObject *)seq, i)); // borrowed ref
53126
return res;
54127
}
55128
};

c++/cpp2py/converters/vector.hpp

+42-40
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "../pyref.hpp"
88
#include "../macros.hpp"
99
#include "../numpy_proxy.hpp"
10+
#include "../py_converter.hpp"
1011

1112
namespace cpp2py {
1213

@@ -15,35 +16,54 @@ namespace cpp2py {
1516
static_assert(is_instantiation_of_v<std::vector, std::decay_t<V>>, "Logic error");
1617
using value_type = typename std::remove_reference_t<V>::value_type;
1718

18-
auto *vec_heap = new std::vector<value_type>{std::forward<V>(v)};
19-
auto delete_pycapsule = [](PyObject *capsule) {
20-
auto *ptr = static_cast<std::vector<value_type> *>(PyCapsule_GetPointer(capsule, "guard"));
21-
delete ptr;
22-
};
23-
PyObject *capsule = PyCapsule_New(vec_heap, "guard", delete_pycapsule);
24-
25-
return {1, // rank
26-
npy_type<value_type>,
27-
(void *)vec_heap->data(),
28-
std::is_const_v<value_type>,
29-
std::vector<long>{long(vec_heap->size())}, // extents
30-
std::vector<long>{sizeof(value_type)}, // strides
31-
capsule};
19+
if constexpr (has_npy_type<value_type>) {
20+
auto *vec_heap = new std::vector<value_type>{std::forward<V>(v)};
21+
auto delete_pycapsule = [](PyObject *capsule) {
22+
auto *ptr = static_cast<std::vector<value_type> *>(PyCapsule_GetPointer(capsule, "guard"));
23+
delete ptr;
24+
};
25+
PyObject *capsule = PyCapsule_New(vec_heap, "guard", delete_pycapsule);
26+
27+
return {1, // rank
28+
npy_type<value_type>,
29+
(void *)vec_heap->data(),
30+
std::is_const_v<value_type>,
31+
std::vector<long>{long(vec_heap->size())}, // extents
32+
std::vector<long>{sizeof(value_type)}, // strides
33+
capsule};
34+
} else {
35+
std::vector<pyref> vobj(v.size());
36+
std::transform(begin(v), end(v), begin(vobj), [](auto &&x) {
37+
if constexpr (std::is_reference_v<V>) {
38+
return convert_to_python(x);
39+
} else { // vector passed as rvalue
40+
return convert_to_python(std::move(x));
41+
}
42+
});
43+
return make_numpy_proxy_from_vector(std::move(vobj));
44+
}
3245
}
3346

3447
// Make a new vector from numpy view
3548
template <typename T> std::vector<T> make_vector_from_numpy_proxy(numpy_proxy const &p) {
3649
EXPECTS(p.extents.size() == 1);
37-
EXPECTS(p.strides[0] % sizeof(T) == 0);
3850

3951
long size = p.extents[0];
40-
long step = p.strides[0] / sizeof(T);
4152

4253
std::vector<T> v(size);
4354

44-
T *data = static_cast<T *>(p.data);
45-
for(long i = 0; i < size; ++i)
46-
v[i] = *(data + i * step);
55+
if (p.element_type == npy_type<pyref>) {
56+
long step = p.strides[0] / sizeof(pyref);
57+
auto **data = static_cast<PyObject **>(p.data);
58+
for(long i = 0; i < size; ++i)
59+
v[i] = py_converter<std::decay_t<T>>::py2c(data[i * step]);
60+
} else {
61+
EXPECTS(p.strides[0] % sizeof(T) == 0);
62+
long step = p.strides[0] / sizeof(T);
63+
T *data = static_cast<T *>(p.data);
64+
for(long i = 0; i < size; ++i)
65+
v[i] = data[i * step];
66+
}
4767

4868
return v;
4969
}
@@ -54,26 +74,7 @@ namespace cpp2py {
5474

5575
template <typename V> static PyObject *c2py(V &&v) {
5676
static_assert(is_instantiation_of_v<std::vector, std::decay_t<V>>, "Logic error");
57-
using value_type = typename std::remove_reference_t<V>::value_type;
58-
59-
if constexpr (has_npy_type<value_type>) {
60-
return make_numpy_proxy_from_vector(std::forward<V>(v)).to_python();
61-
} else { // Convert to Python List
62-
PyObject *list = PyList_New(0);
63-
for (auto &x : v) {
64-
pyref y;
65-
if constexpr(std::is_reference_v<V>){
66-
y = py_converter<value_type>::c2py(x);
67-
} else { // Vector passed as rvalue
68-
y = py_converter<value_type>::c2py(std::move(x));
69-
}
70-
if (y.is_null() or (PyList_Append(list, y) == -1)) {
71-
Py_DECREF(list);
72-
return NULL;
73-
} // error
74-
}
75-
return list;
76-
}
77+
return make_numpy_proxy_from_vector(std::forward<V>(v)).to_python();
7778
}
7879

7980
// --------------------------------------
@@ -100,7 +101,8 @@ namespace cpp2py {
100101
pyref seq = PySequence_Fast(ob, "expected a sequence");
101102
int len = PySequence_Size(ob);
102103
for (int i = 0; i < len; i++) {
103-
if (!py_converter<T>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) { // borrowed ref
104+
if (!py_converter<std::decay_t<T>>::is_convertible(PySequence_Fast_GET_ITEM((PyObject *)seq, i), raise_exception)) { // borrowed ref
105+
if (PyErr_Occurred()) PyErr_Print();
104106
return false;
105107
}
106108
}

0 commit comments

Comments
 (0)