@@ -1028,6 +1028,7 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
1028
1028
}
1029
1029
dtype = PyArray_DESCR (mps [0 ]);
1030
1030
1031
+ int copy_existing_out = 0 ;
1031
1032
/* Set-up return array */
1032
1033
if (out == NULL ) {
1033
1034
Py_INCREF (dtype );
@@ -1039,10 +1040,6 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
1039
1040
(PyObject * )ap );
1040
1041
}
1041
1042
else {
1042
- int flags = NPY_ARRAY_CARRAY |
1043
- NPY_ARRAY_WRITEBACKIFCOPY |
1044
- NPY_ARRAY_FORCECAST ;
1045
-
1046
1043
if ((PyArray_NDIM (out ) != multi -> nd )
1047
1044
|| !PyArray_CompareLists (PyArray_DIMS (out ),
1048
1045
multi -> dimensions ,
@@ -1052,9 +1049,13 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
1052
1049
goto fail ;
1053
1050
}
1054
1051
1052
+ if (PyArray_FailUnlessWriteable (out , "output array" ) < 0 ) {
1053
+ goto fail ;
1054
+ }
1055
+
1055
1056
for (i = 0 ; i < n ; i ++ ) {
1056
1057
if (arrays_overlap (out , mps [i ])) {
1057
- flags |= NPY_ARRAY_ENSURECOPY ;
1058
+ copy_existing_out = 1 ;
1058
1059
}
1059
1060
}
1060
1061
@@ -1064,10 +1065,25 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
1064
1065
* so the input array is not changed
1065
1066
* before the error is called
1066
1067
*/
1067
- flags |= NPY_ARRAY_ENSURECOPY ;
1068
+ copy_existing_out = 1 ;
1069
+ }
1070
+
1071
+ if (!PyArray_EquivTypes (dtype , PyArray_DESCR (out ))) {
1072
+ copy_existing_out = 1 ;
1073
+ }
1074
+
1075
+ if (copy_existing_out ) {
1076
+ Py_INCREF (dtype );
1077
+ obj = (PyArrayObject * )PyArray_NewFromDescr (& PyArray_Type ,
1078
+ dtype ,
1079
+ multi -> nd ,
1080
+ multi -> dimensions ,
1081
+ NULL , NULL , 0 ,
1082
+ (PyObject * )out );
1083
+ }
1084
+ else {
1085
+ obj = (PyArrayObject * )Py_NewRef (out );
1068
1086
}
1069
- Py_INCREF (dtype );
1070
- obj = (PyArrayObject * )PyArray_FromArray (out , dtype , flags );
1071
1087
}
1072
1088
1073
1089
if (obj == NULL ) {
@@ -1080,12 +1096,13 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
1080
1096
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0 ;
1081
1097
if (PyDataType_REFCHK (dtype )) {
1082
1098
int is_aligned = IsUintAligned (obj );
1099
+ PyArray_Descr * obj_dtype = PyArray_DESCR (obj );
1083
1100
PyArray_GetDTypeTransferFunction (
1084
1101
is_aligned ,
1085
1102
dtype -> elsize ,
1086
- dtype -> elsize ,
1103
+ obj_dtype -> elsize ,
1087
1104
dtype ,
1088
- dtype , 0 , & cast_info ,
1105
+ obj_dtype , 0 , & cast_info ,
1089
1106
& transfer_flags );
1090
1107
}
1091
1108
@@ -1142,11 +1159,13 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
1142
1159
}
1143
1160
Py_DECREF (ap );
1144
1161
PyDataMem_FREE (mps );
1145
- if (out != NULL && out != obj ) {
1146
- Py_INCREF (out );
1147
- PyArray_ResolveWritebackIfCopy (obj );
1162
+ if (copy_existing_out ) {
1163
+ int res = PyArray_CopyInto (out , obj );
1148
1164
Py_DECREF (obj );
1149
- obj = out ;
1165
+ if (res < 0 ) {
1166
+ return NULL ;
1167
+ }
1168
+ return Py_NewRef (out );
1150
1169
}
1151
1170
return (PyObject * )obj ;
1152
1171
0 commit comments