Skip to content

Commit af57451

Browse files
hpkfftwjakob
authored andcommitted
ndarray import from buffer protocol requires integer stride. (#489)
1 parent 937a1df commit af57451

File tree

5 files changed

+42
-5
lines changed

5 files changed

+42
-5
lines changed

docs/api_extra.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -634,21 +634,21 @@ section <ndarrays>`.
634634
.. cpp:function:: size_t itemsize() const
635635

636636
Return the size of a single array element in bytes. The returned value
637-
is rounded to the next full byte in case of bit-level representations
637+
is rounded up to the next full byte in case of bit-level representations
638638
(query :cpp:member:`dtype::bits` for bit-level granularity).
639639

640640
.. cpp:function:: size_t nbytes() const
641641

642642
Return the size of the entire array bytes. The returned value is rounded
643-
to the next full byte in case of bit-level representations.
643+
up to the next full byte in case of bit-level representations.
644644

645645
.. cpp:function:: size_t shape(size_t i) const
646646

647647
Return the size of dimension `i`.
648648

649649
.. cpp:function:: int64_t stride(size_t i) const
650650

651-
Return the stride of dimension `i`.
651+
Return the stride (in number of elements) of dimension `i`.
652652

653653
.. cpp:function:: const int64_t* shape_ptr() const
654654

src/nb_ndarray.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,14 @@ static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) {
262262

263263
scoped_pymalloc<int64_t> strides((size_t) view->ndim);
264264
scoped_pymalloc<int64_t> shape((size_t) view->ndim);
265+
const int64_t itemsize = static_cast<int64_t>(view->itemsize);
265266
for (size_t i = 0; i < (size_t) view->ndim; ++i) {
266-
strides[i] = (int64_t) (view->strides[i] / view->itemsize);
267+
int64_t stride = view->strides[i] / itemsize;
268+
if (stride * itemsize != view->strides[i]) {
269+
PyBuffer_Release(view.get());
270+
return nullptr;
271+
}
272+
strides[i] = stride;
267273
shape[i] = (int64_t) view->shape[i];
268274
}
269275

tests/test_ndarray.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ NB_MODULE(test_ndarray_ext, m) {
4444
return t.nbytes();
4545
}, "array"_a.noconvert());
4646

47+
m.def("get_stride", [](const nb::ndarray<> &t, size_t i) {
48+
return t.stride(i);
49+
}, "array"_a.noconvert(), "i"_a);
50+
4751
m.def("check_shape_ptr", [](const nb::ndarray<> &t) {
4852
std::vector<int64_t> shape(t.ndim());
4953
std::copy(t.shape_ptr(), t.shape_ptr() + t.ndim(), shape.begin());

tests/test_ndarray.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def test28_reference_internal():
558558
assert msg in str(excinfo.value)
559559

560560
@needs_numpy
561-
def test29_force_contig_pytorch():
561+
def test29_force_contig_numpy():
562562
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
563563
b = t.make_contig(a)
564564
assert b is a
@@ -656,3 +656,28 @@ def __dlpack__(self):
656656

657657
arr = DLPackWrapper(np.zeros((1)))
658658
assert t.check(arr)
659+
660+
@needs_numpy
661+
def test37_noninteger_stride():
662+
a = np.array([[1, 2, 3, 4, 0, 0], [5, 6, 7, 8, 0, 0]], dtype=np.float32)
663+
s = a[:, 0:4] # slice
664+
t.pass_float32(s)
665+
assert t.get_stride(s, 0) == 6;
666+
assert t.get_stride(s, 1) == 1;
667+
v = s.view(np.complex64)
668+
t.pass_complex64(v)
669+
assert t.get_stride(v, 0) == 3;
670+
assert t.get_stride(v, 1) == 1;
671+
672+
a = np.array([[1, 2, 3, 4, 0], [5, 6, 7, 8, 0]], dtype=np.float32)
673+
s = a[:, 0:4] # slice
674+
t.pass_float32(s)
675+
assert t.get_stride(s, 0) == 5;
676+
assert t.get_stride(s, 1) == 1;
677+
v = s.view(np.complex64)
678+
with pytest.raises(TypeError) as excinfo:
679+
t.pass_complex64(v)
680+
assert 'incompatible function arguments' in str(excinfo.value)
681+
with pytest.raises(TypeError) as excinfo:
682+
t.get_stride(v, 0);
683+
assert 'incompatible function arguments' in str(excinfo.value)

tests/test_ndarray_ext.pyi.ref

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def get_shape(array: Annotated[ArrayLike, dict(writable=False)]) -> list: ...
6767

6868
def get_size(array: ArrayLike) -> int: ...
6969

70+
def get_stride(array: ArrayLike, i: int) -> int: ...
71+
7072
def implicit(array: Annotated[ArrayLike, dict(dtype='float32', order='C', shape=(2, 2))]) -> int: ...
7173

7274
@overload

0 commit comments

Comments
 (0)