-
-
Notifications
You must be signed in to change notification settings - Fork 13
implement same_value casting for numpy <-> quadtype #161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
0d234bd
5cd17c8
36b7a64
445d544
b35756a
e7143dd
40c1ba9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
#define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API | ||
#define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API | ||
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION | ||
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION | ||
#define NPY_NO_DEPRECATED_API NPY_2_4_API_VERSION | ||
#define NPY_TARGET_VERSION NPY_2_4_API_VERSION | ||
#define NO_IMPORT_ARRAY | ||
#define NO_IMPORT_UFUNC | ||
|
||
|
@@ -157,7 +157,7 @@ void_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta * | |
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2], | ||
npy_intp *view_offset) | ||
{ | ||
PyErr_SetString(PyExc_TypeError, | ||
PyErr_SetString(PyExc_TypeError, | ||
"Void to QuadPrecision cast is not implemented"); | ||
return (NPY_CASTING)-1; | ||
} | ||
|
@@ -401,7 +401,7 @@ to_quad<long double>(long double x, QuadBackendType backend) | |
} | ||
|
||
template <typename T> | ||
static NPY_CASTING | ||
static int | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm... I think this is OK? I didn't see any new warnings, but there are thousands of build warnings when using sleef and I might have missed it. |
||
numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *dtypes[2], | ||
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2], | ||
npy_intp *view_offset) | ||
|
@@ -419,7 +419,11 @@ numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta | |
} | ||
|
||
loop_descrs[0] = PyArray_GetDefaultDescr(dtypes[0]); | ||
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION | ||
return NPY_SAFE_CASTING | NPY_SAME_VALUE_CASTING_FLAG; | ||
#else | ||
return NPY_SAFE_CASTING; | ||
#endif | ||
} | ||
|
||
template <typename T> | ||
|
@@ -666,6 +670,28 @@ from_quad<long double>(quad_value x, QuadBackendType backend) | |
} | ||
} | ||
|
||
template <typename T> | ||
static inline int | ||
from_quad_checked(quad_value x, QuadBackendType backend, typename NpyType<T>::TYPE *ret) { | ||
*ret = from_quad<typename NpyType<T>::TYPE>(x, backend); | ||
quad_value check = to_quad<typename NpyType<T>::TYPE>(*ret, backend); | ||
if (backend == BACKEND_SLEEF) { | ||
if (check.sleef_value == x.sleef_value) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we just compare sleef values here? Should this use equality with NaN = NaN? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I was following the pattern from elsewhere. Is it guaranteed that the sleef backend is used?
Good point! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the SLEEF branch, I was thinking of |
||
return 0; | ||
} | ||
} | ||
else { | ||
if (check.longdouble_value == x.longdouble_value) { | ||
return 0; | ||
} | ||
} | ||
NPY_ALLOW_C_API_DEF; | ||
NPY_ALLOW_C_API; | ||
PyErr_SetString(PyExc_ValueError, "could not cast 'same_value' to QuadType"); | ||
NPY_DISABLE_C_API; | ||
return -1; | ||
} | ||
|
||
template <typename T> | ||
static NPY_CASTING | ||
quad_to_numpy_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *dtypes[2], | ||
|
@@ -685,6 +711,9 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const | |
npy_intp const dimensions[], npy_intp const strides[], | ||
void *NPY_UNUSED(auxdata)) | ||
{ | ||
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION | ||
int same_value_casting = ((context->flags & NPY_SAME_VALUE_CONTEXT_FLAG) == NPY_SAME_VALUE_CONTEXT_FLAG); | ||
#endif | ||
npy_intp N = dimensions[0]; | ||
char *in_ptr = data[0]; | ||
char *out_ptr = data[1]; | ||
|
@@ -694,6 +723,24 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const | |
|
||
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double); | ||
|
||
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION | ||
if (same_value_casting) { | ||
while (N--) { | ||
quad_value in_val; | ||
memcpy(&in_val, in_ptr, elem_size); | ||
typename NpyType<T>::TYPE out_val; | ||
if (from_quad_checked<T>(in_val, backend, &out_val) < 0) { | ||
return -1; | ||
} | ||
memcpy(out_ptr, &out_val, sizeof(typename NpyType<T>::TYPE)); | ||
|
||
in_ptr += strides[0]; | ||
out_ptr += strides[1]; | ||
} | ||
} else { | ||
#else | ||
{ | ||
#endif | ||
while (N--) { | ||
quad_value in_val; | ||
memcpy(&in_val, in_ptr, elem_size); | ||
|
@@ -703,7 +750,7 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const | |
|
||
in_ptr += strides[0]; | ||
out_ptr += strides[1]; | ||
} | ||
}} | ||
return 0; | ||
} | ||
|
||
|
@@ -716,10 +763,36 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d | |
npy_intp N = dimensions[0]; | ||
char *in_ptr = data[0]; | ||
char *out_ptr = data[1]; | ||
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION | ||
int same_value_casting = ((context->flags & NPY_SAME_VALUE_CONTEXT_FLAG) == NPY_SAME_VALUE_CONTEXT_FLAG); | ||
#endif | ||
|
||
QuadPrecDTypeObject *quad_descr = (QuadPrecDTypeObject *)context->descriptors[0]; | ||
QuadBackendType backend = quad_descr->backend; | ||
|
||
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION | ||
if (same_value_casting) { | ||
while (N--) { | ||
quad_value in_val; | ||
if (backend == BACKEND_SLEEF) { | ||
in_val.sleef_value = *(Sleef_quad *)in_ptr; | ||
} | ||
else { | ||
in_val.longdouble_value = *(long double *)in_ptr; | ||
} | ||
|
||
typename NpyType<T>::TYPE out_val; | ||
if (from_quad_checked<T>(in_val, backend, &out_val) < 0) { | ||
return -1; | ||
} | ||
*(typename NpyType<T>::TYPE *)(out_ptr) = out_val; | ||
|
||
in_ptr += strides[0]; | ||
out_ptr += strides[1]; | ||
}} else { | ||
#else | ||
{ | ||
#endif | ||
while (N--) { | ||
quad_value in_val; | ||
if (backend == BACKEND_SLEEF) { | ||
|
@@ -734,7 +807,7 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d | |
|
||
in_ptr += strides[0]; | ||
out_ptr += strides[1]; | ||
} | ||
}} | ||
return 0; | ||
} | ||
|
||
|
@@ -771,7 +844,11 @@ add_cast_from(PyArray_DTypeMeta *to) | |
.name = "cast_QuadPrec_to_NumPy", | ||
.nin = 1, | ||
.nout = 1, | ||
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION | ||
.casting = NPY_SAME_VALUE_CASTING, | ||
#else | ||
.casting = NPY_UNSAFE_CASTING, | ||
#endif | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not clear how this is used, doesn't the resolver actually state which casting is supported? |
||
.flags = NPY_METH_SUPPORTS_UNALIGNED, | ||
.dtypes = dtypes, | ||
.slots = slots, | ||
|
@@ -904,4 +981,4 @@ free_casts(void) | |
} | ||
} | ||
spec_count = 0; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
#include <algorithm> | ||
|
||
#ifndef DISABLE_QUADBLAS | ||
#include "../subprojects/qblas/include/quadblas/quadblas.hpp" | ||
#include "quadblas/quadblas.hpp" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I needed this to build with |
||
#endif // DISABLE_QUADBLAS | ||
|
||
extern "C" { | ||
|
@@ -230,4 +230,4 @@ _quadblas_get_num_threads(void) | |
|
||
#endif // DISABLE_QUADBLAS | ||
|
||
} // extern "C" | ||
} // extern "C" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,6 +76,16 @@ def test_unsupported_astype(dtype): | |
np.array(QuadPrecision(1)).astype(dtype, casting="unsafe") | ||
|
||
|
||
def test_same_value_cast(): | ||
# This will fail if compiled with NPY_TARGET_VERSION NPY<2_4_API_VERSION | ||
|
||
a = np.arange(30, dtype=np.float32) | ||
# upcasting can never fail | ||
b = a.astype(QuadPrecision, casting='same_value') | ||
c = b.astype(np.float32, casting='same_value') | ||
assert np.all(c == a) | ||
with pytest.raises(ValueError, match="could not cast 'same_value'"): | ||
(b + 1e22).astype(np.float32, casting='same_value') | ||
|
||
def test_basic_equality(): | ||
assert QuadPrecision("12") == QuadPrecision( | ||
"12.0") == QuadPrecision("12.00") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit trickier unless you don't care about supporting NumPy <2.4. But if you don't, make sure that the target version is defined in (and matches up) with the file that does the actual NumPy import/loading. (Because that will check runtime compatibility.)
You can lie (I really don't mind), but you cannot use the flag at runtime unless you also check
PyArray_RUNTIME_VERSION
at runtime and not at compile time.As is, if you compile with NumPy 2.4 (which you should!), the casts will be broken on 2.3.
(could also just define the flag here manually. We could include the
PyArray_RUNTIME_CHECK
into the numpy headers as well if you prefer -- to basically simplify exactly this type of code that wants to use it if available. The compat headers have a lot of examples for such code.)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, that is
quaddtype_main.c
.Did you mean
PyArray_GetNDArrayCFeatureVersion() >= NPY_2_4_API_VERSION
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean
PyArray_RUNTIME_VERSION >= NPY_2_4_API_VERSION
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does that work at runtime? It is not exported as a constant from
_multiarray_umath.so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's injected as a constant into your library, next to the API table. Dunno if that was the right choice, but there we are :).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PyArray_RUNTIME_VERSION
is runtime! Not a pre-processor macro.