Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MAINT] Run array API conformity with 2024.12 spec #2021

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ jobs:
cd /home/runner/work/array-api-tests
${CONDA_PREFIX}/bin/python -c "import dpctl; dpctl.lsplatform()"
export ARRAY_API_TESTS_MODULE=dpctl.tensor
export ARRAY_API_TESTS_VERSION=2024.12
${CONDA_PREFIX}/bin/python -m pytest --json-report --json-report-file=$FILE --disable-deadline --skips-file ${GITHUB_WORKSPACE}/.github/workflows/array-api-skips.txt array_api_tests/ || true
- name: Set Github environment variables
shell: bash -l {0}
Expand Down
75 changes: 61 additions & 14 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
import builtins
import operator
from numbers import Integral

import numpy as np

Expand Down Expand Up @@ -819,15 +820,26 @@ def _take_multi_index(ary, inds, p, mode=0):
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)
any_usmarray = False
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
if isinstance(ind, dpt.usm_ndarray):
any_usmarray = True
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) "
"type"
)
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
elif not isinstance(ind, Integral):
raise TypeError(
"all elements of `ind` expected to be usm_ndarrays "
"or integers"
)
if not any_usmarray:
raise TypeError(
"at least one element of `ind` expected to be a usm_ndarray"
)
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is None:
Expand All @@ -838,6 +850,18 @@ def _take_multi_index(ary, inds, p, mode=0):
"be associated with the same queue."
)
if len(inds) > 1:
inds = tuple(
map(
lambda ind: (
ind
if isinstance(ind, dpt.usm_ndarray)
else dpt.asarray(
ind, usm_type=res_usm_type, sycl_queue=exec_q
)
),
inds,
)
)
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
Expand Down Expand Up @@ -968,15 +992,26 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
]
if not isinstance(inds, (list, tuple)):
inds = (inds,)
any_usmarray = False
for ind in inds:
if not isinstance(ind, dpt.usm_ndarray):
raise TypeError("all elements of `ind` expected to be usm_ndarrays")
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) type"
if isinstance(ind, dpt.usm_ndarray):
any_usmarray = True
if ind.dtype.kind not in "ui":
raise IndexError(
"arrays used as indices must be of integer (or boolean) "
"type"
)
queues_.append(ind.sycl_queue)
usm_types_.append(ind.usm_type)
elif not isinstance(ind, Integral):
raise TypeError(
"all elements of `ind` expected to be usm_ndarrays "
"or integers"
)
if not any_usmarray:
raise TypeError(
"at least one element of `ind` expected to be a usm_ndarray"
)
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is not None:
Expand All @@ -994,6 +1029,18 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
"be associated with the same queue."
)
if len(inds) > 1:
inds = tuple(
map(
lambda ind: (
ind
if isinstance(ind, dpt.usm_ndarray)
else dpt.asarray(
ind, usm_type=vals_usm_type, sycl_queue=exec_q
)
),
inds,
)
)
ind_dt = dpt.result_type(*inds)
# ind arrays have been checked to be of integer dtype
if ind_dt.kind not in "ui":
Expand Down
61 changes: 41 additions & 20 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ cdef Py_ssize_t _slice_len(
cdef bint _is_integral(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if x.ndim > 0:
return False
if x.dtype.kind not in "ui":
return False
return True
return False
if isinstance(x, bool):
return False
if isinstance(x, int):
Expand Down Expand Up @@ -179,10 +175,12 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
if array_streak_started:
array_streak_interrupted = True
elif _is_integral(i):
explicit_index += 1
axes_referenced += 1
if array_streak_started:
array_streak_interrupted = True
if array_streak_started and not array_streak_interrupted:
# integers converted to arrays in this case
array_count += 1
else:
explicit_index += 1
elif isinstance(i, usm_ndarray):
if not seen_arrays_yet:
seen_arrays_yet = True
Expand All @@ -196,7 +194,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
dt_k = i.dtype.kind
if dt_k == "b" and i.ndim > 0:
axes_referenced += i.ndim
elif dt_k in "ui" and i.ndim > 0:
elif dt_k in "ui":
axes_referenced += 1
else:
raise IndexError(
Expand Down Expand Up @@ -229,6 +227,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
advanced_start_pos_set = False
new_offset = offset
is_empty = False
array_streak = False
for i in range(len(ind)):
ind_i = ind[i]
if (ind_i is Ellipsis):
Expand All @@ -239,9 +238,13 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
is_empty = True
new_offset = offset
k = k_new
if array_streak:
array_streak = False
elif ind_i is None:
new_shape.append(1)
new_strides.append(0)
if array_streak:
array_streak = False
elif isinstance(ind_i, slice):
k_new = k + 1
sl_start, sl_stop, sl_step = ind_i.indices(shape[k])
Expand All @@ -255,26 +258,44 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
is_empty = True
new_offset = offset
k = k_new
if array_streak:
array_streak = False
elif _is_boolean(ind_i):
new_shape.append(1 if ind_i else 0)
new_strides.append(0)
if array_streak:
array_streak = False
elif _is_integral(ind_i):
ind_i = ind_i.__index__()
if 0 <= ind_i < shape[k]:
if array_streak:
# integer will be converted to an array, still raise if OOB
if not (0 <= ind_i < shape[k] or -shape[k] <= ind_i < 0):
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
new_advanced_ind.append(ind_i)
k_new = k + 1
if not is_empty:
new_offset = new_offset + ind_i * strides[k]
k = k_new
elif -shape[k] <= ind_i < 0:
k_new = k + 1
if not is_empty:
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
new_shape.extend(shape[k:k_new])
new_strides.extend(strides[k:k_new])
k = k_new
else:
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
if 0 <= ind_i < shape[k]:
k_new = k + 1
if not is_empty:
new_offset = new_offset + ind_i * strides[k]
k = k_new
elif -shape[k] <= ind_i < 0:
k_new = k + 1
if not is_empty:
new_offset = new_offset + (shape[k] + ind_i) * strides[k]
k = k_new
else:
raise IndexError(
("Index {0} is out of range for "
"axes {1} with size {2}").format(ind_i, k, shape[k]))
elif isinstance(ind_i, usm_ndarray):
if not array_streak:
array_streak = True
if not advanced_start_pos_set:
new_advanced_start_pos = len(new_shape)
advanced_start_pos_set = True
Expand Down
45 changes: 23 additions & 22 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue)
ev = self_queue.submit_barrier()
stream.submit_barrier(dependent_events=[ev])


cdef class usm_ndarray:
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
offset=0, order="C", buffer_ctor_kwargs=dict(), \
Expand Down Expand Up @@ -962,28 +961,30 @@ cdef class usm_ndarray:
return res

from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
key_ = adv_ind[0]
adv_ind_end_p = key_.ndim + adv_ind_start_p
if adv_ind_end_p > res.ndim:
raise IndexError("too many indices for the array")
key_shape = key_.shape
arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p]
for i in range(key_.ndim):
if matching:
if not key_shape[i] == arr_shape[i] and key_shape[i] > 0:
matching = 0
if not matching:
raise IndexError("boolean index did not match indexed array in dimensions")
res = _extract_impl(res, key_, axis=adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

if any(ind.dtype == dpt_bool for ind in adv_ind):
# if len(adv_ind == 1), the (only) element is always an array
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
key_ = adv_ind[0]
adv_ind_end_p = key_.ndim + adv_ind_start_p
if adv_ind_end_p > res.ndim:
raise IndexError("too many indices for the array")
key_shape = key_.shape
arr_shape = res.shape[adv_ind_start_p:adv_ind_end_p]
for i in range(key_.ndim):
if matching:
if not key_shape[i] == arr_shape[i] and key_shape[i] > 0:
matching = 0
if not matching:
raise IndexError("boolean index did not match indexed array in dimensions")
res = _extract_impl(res, key_, axis=adv_ind_start_p)
res.flags_ = _copy_writable(res.flags_, self.flags_)
return res

if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
adv_ind_int = list()
for ind in adv_ind:
if ind.dtype == dpt_bool:
adv_ind_int.extend(_nonzero_impl(ind))
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
res = _take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
Expand Down Expand Up @@ -1433,10 +1434,10 @@ cdef class usm_ndarray:
_place_impl(Xv, adv_ind[0], rhs, axis=adv_ind_start_p)
return

if any(ind.dtype == dpt_bool for ind in adv_ind):
if any((isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool) for ind in adv_ind):
adv_ind_int = list()
for ind in adv_ind:
if ind.dtype == dpt_bool:
if isinstance(ind, usm_ndarray) and ind.dtype == dpt_bool:
adv_ind_int.extend(_nonzero_impl(ind))
else:
adv_ind_int.append(ind)
Expand Down
10 changes: 8 additions & 2 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,14 @@ def test_advanced_slice5():
q = get_queue_or_skip()
ii = dpt.asarray([1, 2], sycl_queue=q)
x = _make_3d("i4", q)
with pytest.raises(IndexError):
x[ii, 0, ii]
y = x[ii, 0, ii]
assert isinstance(y, dpt.usm_ndarray)
# 0 broadcast to [0, 0] per array API
assert y.shape == ii.shape
assert _all_equal(
(x[ii[i], 0, ii[i]] for i in range(ii.shape[0])),
(y[i] for i in range(ii.shape[0])),
)


def test_advanced_slice6():
Expand Down
Loading