Skip to content

Commit cdb44f7

Browse files
MaanasAroraseberg
andauthored
BUG: Remove unnecessary copying and casting from out array in choose (numpy#28206)
This PR does not copy the existing out array in np.choose, instead creating the new output array from scratch and copying it back manually. --------- Co-authored-by: Sebastian Berg <[email protected]>
1 parent 7d81f1e commit cdb44f7

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

numpy/_core/src/multiarray/item_selection.c

+33-14
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,7 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
10281028
}
10291029
dtype = PyArray_DESCR(mps[0]);
10301030

1031+
int copy_existing_out = 0;
10311032
/* Set-up return array */
10321033
if (out == NULL) {
10331034
Py_INCREF(dtype);
@@ -1039,10 +1040,6 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
10391040
(PyObject *)ap);
10401041
}
10411042
else {
1042-
int flags = NPY_ARRAY_CARRAY |
1043-
NPY_ARRAY_WRITEBACKIFCOPY |
1044-
NPY_ARRAY_FORCECAST;
1045-
10461043
if ((PyArray_NDIM(out) != multi->nd)
10471044
|| !PyArray_CompareLists(PyArray_DIMS(out),
10481045
multi->dimensions,
@@ -1052,9 +1049,13 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
10521049
goto fail;
10531050
}
10541051

1052+
if (PyArray_FailUnlessWriteable(out, "output array") < 0) {
1053+
goto fail;
1054+
}
1055+
10551056
for (i = 0; i < n; i++) {
10561057
if (arrays_overlap(out, mps[i])) {
1057-
flags |= NPY_ARRAY_ENSURECOPY;
1058+
copy_existing_out = 1;
10581059
}
10591060
}
10601061

@@ -1064,10 +1065,25 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
10641065
* so the input array is not changed
10651066
* before the error is called
10661067
*/
1067-
flags |= NPY_ARRAY_ENSURECOPY;
1068+
copy_existing_out = 1;
1069+
}
1070+
1071+
if (!PyArray_EquivTypes(dtype, PyArray_DESCR(out))) {
1072+
copy_existing_out = 1;
1073+
}
1074+
1075+
if (copy_existing_out) {
1076+
Py_INCREF(dtype);
1077+
obj = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
1078+
dtype,
1079+
multi->nd,
1080+
multi->dimensions,
1081+
NULL, NULL, 0,
1082+
(PyObject *)out);
1083+
}
1084+
else {
1085+
obj = (PyArrayObject *)Py_NewRef(out);
10681086
}
1069-
Py_INCREF(dtype);
1070-
obj = (PyArrayObject *)PyArray_FromArray(out, dtype, flags);
10711087
}
10721088

10731089
if (obj == NULL) {
@@ -1080,12 +1096,13 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
10801096
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
10811097
if (PyDataType_REFCHK(dtype)) {
10821098
int is_aligned = IsUintAligned(obj);
1099+
PyArray_Descr *obj_dtype = PyArray_DESCR(obj);
10831100
PyArray_GetDTypeTransferFunction(
10841101
is_aligned,
10851102
dtype->elsize,
1086-
dtype->elsize,
1103+
obj_dtype->elsize,
10871104
dtype,
1088-
dtype, 0, &cast_info,
1105+
obj_dtype, 0, &cast_info,
10891106
&transfer_flags);
10901107
}
10911108

@@ -1142,11 +1159,13 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
11421159
}
11431160
Py_DECREF(ap);
11441161
PyDataMem_FREE(mps);
1145-
if (out != NULL && out != obj) {
1146-
Py_INCREF(out);
1147-
PyArray_ResolveWritebackIfCopy(obj);
1162+
if (copy_existing_out) {
1163+
int res = PyArray_CopyInto(out, obj);
11481164
Py_DECREF(obj);
1149-
obj = out;
1165+
if (res < 0) {
1166+
return NULL;
1167+
}
1168+
return Py_NewRef(out);
11501169
}
11511170
return (PyObject *)obj;
11521171

numpy/_core/tests/test_multiarray.py

+16
Original file line numberDiff line numberDiff line change
@@ -1980,6 +1980,12 @@ def test_choose(self):
19801980
y = np.choose([0, 0, 0], [x[:3], x[:3], x[:3]], out=x[1:4], mode='wrap')
19811981
assert_equal(y, np.array([0, 1, 2]))
19821982

1983+
# gh_28206 check fail when out not writeable
1984+
x = np.arange(3)
1985+
out = np.zeros(3)
1986+
out.setflags(write=False)
1987+
assert_raises(ValueError, np.choose, [0, 1, 2], [x, x, x], out=out)
1988+
19831989
def test_prod(self):
19841990
ba = [1, 2, 10, 11, 6, 5, 4]
19851991
ba2 = [[1, 2, 3, 4], [5, 6, 7, 9], [10, 3, 4, 5]]
@@ -10287,6 +10293,16 @@ def test_gh_24459():
1028710293
np.choose(a, [3, -1])
1028810294

1028910295

10296+
def test_gh_28206():
10297+
a = np.arange(3)
10298+
b = np.ones((3, 3), dtype=np.int64)
10299+
out = np.array([np.nan, np.nan, np.nan])
10300+
10301+
with warnings.catch_warnings():
10302+
warnings.simplefilter("error", RuntimeWarning)
10303+
np.choose(a, b, out=out)
10304+
10305+
1029010306
@pytest.mark.parametrize("N", np.arange(2, 512))
1029110307
@pytest.mark.parametrize("dtype", [np.int16, np.uint16,
1029210308
np.int32, np.uint32, np.int64, np.uint64])

0 commit comments

Comments
 (0)