Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 84 additions & 7 deletions quaddtype/numpy_quaddtype/src/casts.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#define PY_ARRAY_UNIQUE_SYMBOL QuadPrecType_ARRAY_API
#define PY_UFUNC_UNIQUE_SYMBOL QuadPrecType_UFUNC_API
#define NPY_NO_DEPRECATED_API NPY_2_0_API_VERSION
#define NPY_TARGET_VERSION NPY_2_0_API_VERSION
#define NPY_NO_DEPRECATED_API NPY_2_4_API_VERSION
#define NPY_TARGET_VERSION NPY_2_4_API_VERSION
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit trickier unless you don't care about supporting NumPy <2.4. But if you don't, make sure that the target version is defined in (and matches up) with the file that does the actual NumPy import/loading. (Because that will check runtime compatibility.)

You can lie (I really don't mind), but you cannot use the flag at runtime unless you also check PyArray_RUNTIME_VERSION at runtime and not at compile time.
As is, if you compile with NumPy 2.4 (which you should!), the casts will be broken on 2.3.

(could also just define the flag here manually. We could include the PyArray_RUNTIME_CHECK into the numpy headers as well if you prefer -- to basically simplify exactly this type of code that wants to use it if available. The compat headers have a lot of examples for such code.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sure that the target version is defined in (and matches up) with the file that does the actual NumPy import/loading

Good point, that is quaddtype_main.c.

but you cannot use the flag at runtime unless you also check PyArray_RUNTIME_VERSION at runtime

Did you mean PyArray_GetNDArrayCFeatureVersion() >= NPY_2_4_API_VERSION?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean PyArray_RUNTIME_VERSION >= NPY_2_4_API_VERSION.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does that work at runtime? It is not exported as a constant from _multiarray_umath.so

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's injected as a constant into your library, next to the API table. Dunno if that was the right choice, but there we are :).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyArray_RUNTIME_VERSION is runtime! Not a pre-processor macro.

#define NO_IMPORT_ARRAY
#define NO_IMPORT_UFUNC

Expand Down Expand Up @@ -157,7 +157,7 @@ void_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
npy_intp *view_offset)
{
PyErr_SetString(PyExc_TypeError,
PyErr_SetString(PyExc_TypeError,
"Void to QuadPrecision cast is not implemented");
return (NPY_CASTING)-1;
}
Expand Down Expand Up @@ -401,7 +401,7 @@ to_quad<long double>(long double x, QuadBackendType backend)
}

template <typename T>
static NPY_CASTING
static int
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... I think this is OK? I didn't see any new warnings, but there are thousands of build warnings when using sleef and I might have missed it.

numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *dtypes[2],
PyArray_Descr *given_descrs[2], PyArray_Descr *loop_descrs[2],
npy_intp *view_offset)
Expand All @@ -419,7 +419,11 @@ numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta
}

loop_descrs[0] = PyArray_GetDefaultDescr(dtypes[0]);
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
return NPY_SAFE_CASTING | NPY_SAME_VALUE_CASTING_FLAG;
#else
return NPY_SAFE_CASTING;
#endif
}

template <typename T>
Expand Down Expand Up @@ -666,6 +670,28 @@ from_quad<long double>(quad_value x, QuadBackendType backend)
}
}

template <typename T>
static inline int
from_quad_checked(quad_value x, QuadBackendType backend, typename NpyType<T>::TYPE *ret) {
*ret = from_quad<typename NpyType<T>::TYPE>(x, backend);
quad_value check = to_quad<typename NpyType<T>::TYPE>(*ret, backend);
if (backend == BACKEND_SLEEF) {
if (check.sleef_value == x.sleef_value) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just compare sleef values here? Should this use equality with NaN = NaN?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just compare sleef values here?

I was following the pattern from elsewhere. Is it guaranteed that the sleef backend is used?

Should this use equality with NaN = NaN?

Good point!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the SLEEF branch, I was thinking of Sleef_icmpeqq1(a, b) | (Sleef_iunordq1(a, a) & Sleef_iunordq1(b, b))

return 0;
}
}
else {
if (check.longdouble_value == x.longdouble_value) {
return 0;
}
}
NPY_ALLOW_C_API_DEF;
NPY_ALLOW_C_API;
PyErr_SetString(PyExc_ValueError, "could not cast 'same_value' to QuadType");
NPY_DISABLE_C_API;
return -1;
}

template <typename T>
static NPY_CASTING
quad_to_numpy_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *dtypes[2],
Expand All @@ -685,6 +711,9 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
npy_intp const dimensions[], npy_intp const strides[],
void *NPY_UNUSED(auxdata))
{
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
int same_value_casting = ((context->flags & NPY_SAME_VALUE_CONTEXT_FLAG) == NPY_SAME_VALUE_CONTEXT_FLAG);
#endif
npy_intp N = dimensions[0];
char *in_ptr = data[0];
char *out_ptr = data[1];
Expand All @@ -694,6 +723,24 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const

size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);

#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
if (same_value_casting) {
while (N--) {
quad_value in_val;
memcpy(&in_val, in_ptr, elem_size);
typename NpyType<T>::TYPE out_val;
if (from_quad_checked<T>(in_val, backend, &out_val) < 0) {
return -1;
}
memcpy(out_ptr, &out_val, sizeof(typename NpyType<T>::TYPE));

in_ptr += strides[0];
out_ptr += strides[1];
}
} else {
#else
{
#endif
while (N--) {
quad_value in_val;
memcpy(&in_val, in_ptr, elem_size);
Expand All @@ -703,7 +750,7 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const

in_ptr += strides[0];
out_ptr += strides[1];
}
}}
return 0;
}

Expand All @@ -716,10 +763,36 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
npy_intp N = dimensions[0];
char *in_ptr = data[0];
char *out_ptr = data[1];
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
int same_value_casting = ((context->flags & NPY_SAME_VALUE_CONTEXT_FLAG) == NPY_SAME_VALUE_CONTEXT_FLAG);
#endif

QuadPrecDTypeObject *quad_descr = (QuadPrecDTypeObject *)context->descriptors[0];
QuadBackendType backend = quad_descr->backend;

#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
if (same_value_casting) {
while (N--) {
quad_value in_val;
if (backend == BACKEND_SLEEF) {
in_val.sleef_value = *(Sleef_quad *)in_ptr;
}
else {
in_val.longdouble_value = *(long double *)in_ptr;
}

typename NpyType<T>::TYPE out_val;
if (from_quad_checked<T>(in_val, backend, &out_val) < 0) {
return -1;
}
*(typename NpyType<T>::TYPE *)(out_ptr) = out_val;

in_ptr += strides[0];
out_ptr += strides[1];
}} else {
#else
{
#endif
while (N--) {
quad_value in_val;
if (backend == BACKEND_SLEEF) {
Expand All @@ -734,7 +807,7 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d

in_ptr += strides[0];
out_ptr += strides[1];
}
}}
return 0;
}

Expand Down Expand Up @@ -771,7 +844,11 @@ add_cast_from(PyArray_DTypeMeta *to)
.name = "cast_QuadPrec_to_NumPy",
.nin = 1,
.nout = 1,
#if NPY_FEATURE_VERSION > NPY_2_3_API_VERSION
.casting = NPY_SAME_VALUE_CASTING,
#else
.casting = NPY_UNSAFE_CASTING,
#endif
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not clear how this is used, doesn't the resolver actually state which casting is supported?

.flags = NPY_METH_SUPPORTS_UNALIGNED,
.dtypes = dtypes,
.slots = slots,
Expand Down Expand Up @@ -904,4 +981,4 @@ free_casts(void)
}
}
spec_count = 0;
}
}
4 changes: 2 additions & 2 deletions quaddtype/numpy_quaddtype/src/quadblas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <algorithm>

#ifndef DISABLE_QUADBLAS
#include "../subprojects/qblas/include/quadblas/quadblas.hpp"
#include "quadblas/quadblas.hpp"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I needed this to build with -e, i.e. pip install -Csetup-args="-Dbuildtype=debug" -e . --no-build-isolation 2>&1 | tee /tmp/build.txt. Otherwise the build failed.

#endif // DISABLE_QUADBLAS

extern "C" {
Expand Down Expand Up @@ -230,4 +230,4 @@ _quadblas_get_num_threads(void)

#endif // DISABLE_QUADBLAS

} // extern "C"
} // extern "C"
10 changes: 10 additions & 0 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def test_unsupported_astype(dtype):
np.array(QuadPrecision(1)).astype(dtype, casting="unsafe")


def test_same_value_cast():
# This will fail if compiled with NPY_TARGET_VERSION NPY<2_4_API_VERSION
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this text check the numpy version and be ignored for <2.4?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs both installed numpy.__version__ >=2.23 and that numpy_quadtype is built with the 2.24+ C-API

a = np.arange(30, dtype=np.float32)
# upcasting can never fail
b = a.astype(QuadPrecision, casting='same_value')
c = b.astype(np.float32, casting='same_value')
assert np.all(c == a)
with pytest.raises(ValueError, match="could not cast 'same_value'"):
(b + 1e22).astype(np.float32, casting='same_value')

def test_basic_equality():
assert QuadPrecision("12") == QuadPrecision(
"12.0") == QuadPrecision("12.00")
Expand Down
Loading