Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions quaddtype/numpy_quaddtype/src/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,40 @@ quad_sqrt(const Sleef_quad *op)
return Sleef_sqrtq1_u05(*op);
}

static inline Sleef_quad
quad_cbrt(const Sleef_quad *op)
{
// SLEEF doesn't provide cbrt, so we implement it using pow
// cbrt(x) = x^(1/3)
// For negative values: cbrt(-x) = -cbrt(x)

// Handle special cases
if (Sleef_iunordq1(*op, *op)) {
return *op; // NaN
}
if (Sleef_icmpeqq1(*op, QUAD_ZERO)) {
return *op; // ±0
}
// Check if op is ±inf: isinf(x) = abs(x) == inf
if (Sleef_icmpeqq1(Sleef_fabsq1(*op), QUAD_POS_INF)) {
return *op; // ±inf
}

// Compute 1/3 as a quad precision constant
Sleef_quad three = Sleef_cast_from_int64q1(3);
Sleef_quad one_third = Sleef_divq1_u05(QUAD_ONE, three);

// Handle negative values: cbrt(-x) = -cbrt(x)
if (Sleef_icmpltq1(*op, QUAD_ZERO)) {
Sleef_quad abs_val = Sleef_fabsq1(*op);
Sleef_quad result = Sleef_powq1_u10(abs_val, one_third);
return Sleef_negq1(result);
}

// Positive values
return Sleef_powq1_u10(*op, one_third);
}

static inline Sleef_quad
quad_square(const Sleef_quad *op)
{
Expand Down Expand Up @@ -260,6 +294,12 @@ ld_sqrt(const long double *op)
return sqrtl(*op);
}

static inline long double
ld_cbrt(const long double *op)
{
return cbrtl(*op);
}

static inline long double
ld_square(const long double *op)
{
Expand Down
3 changes: 3 additions & 0 deletions quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ init_quad_unary_ops(PyObject *numpy)
if (create_quad_unary_ufunc<quad_sqrt, ld_sqrt>(numpy, "sqrt") < 0) {
return -1;
}
if (create_quad_unary_ufunc<quad_cbrt, ld_cbrt>(numpy, "cbrt") < 0) {
return -1;
}
if (create_quad_unary_ufunc<quad_square, ld_square>(numpy, "square") < 0) {
return -1;
}
Expand Down
2 changes: 1 addition & 1 deletion quaddtype/release_tracker.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
| log1p | ✅ | ✅ |
| sqrt | ✅ | ✅ |
| square | ✅ | ✅ |
| cbrt | | |
| cbrt | | ✅ |
| reciprocal | ✅ | ✅ |
| gcd | | |
| lcm | | |
Expand Down
75 changes: 75 additions & 0 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,81 @@ def test_rint_near_halfway():
assert np.rint(QuadPrecision("7.5")) == 8


@pytest.mark.parametrize("val", [
# Perfect cubes
"1.0", "8.0", "27.0", "64.0", "125.0", "1000.0",
# Negative perfect cubes
"-1.0", "-8.0", "-27.0", "-64.0", "-125.0", "-1000.0",
# Small positive values
"0.001", "0.008", "0.027", "1e-9", "1e-15", "1e-100",
# Small negative values
"-0.001", "-0.008", "-0.027", "-1e-9", "-1e-15", "-1e-100",
# Large positive values
"1e10", "1e15", "1e100", "1e300",
# Large negative values
"-1e10", "-1e15", "-1e100", "-1e300",
# Fractional values
"0.5", "2.5", "3.5", "10.5", "100.5",
"-0.5", "-2.5", "-3.5", "-10.5", "-100.5",
# Edge cases
"0.0", "-0.0",
# Special values
"inf", "-inf", "nan", "-nan"
])
def test_cbrt(val):
"""Comprehensive test for cube root function"""
quad_val = QuadPrecision(val)
float_val = float(val)

quad_result = np.cbrt(quad_val)
float_result = np.cbrt(float_val)

# Handle NaN cases
if np.isnan(float_result):
assert np.isnan(
float(quad_result)), f"Expected NaN for cbrt({val}), got {float(quad_result)}"
return

# Handle infinity cases
if np.isinf(float_result):
assert np.isinf(
float(quad_result)), f"Expected inf for cbrt({val}), got {float(quad_result)}"
assert np.sign(float_result) == np.sign(
float(quad_result)), f"Infinity sign mismatch for cbrt({val})"
return

# For finite results, check value and sign
# Use relative tolerance for cbrt
if float_result != 0.0:
rtol = 1e-14 if abs(float_result) < 1e100 else 1e-10
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=1e-15,
err_msg=f"Value mismatch for cbrt({val})")
else:
# For zero results
assert float(quad_result) == 0.0, f"Expected 0 for cbrt({val}), got {float(quad_result)}"
assert np.signbit(float_result) == np.signbit(
quad_result), f"Zero sign mismatch for cbrt({val})"


def test_cbrt_accuracy():
"""Test that cbrt gives accurate results for perfect cubes"""
# Test perfect cubes
for i in [1, 2, 3, 4, 5, 10, 100]:
val = QuadPrecision(i ** 3)
result = np.cbrt(val)
expected = QuadPrecision(i)
np.testing.assert_allclose(float(result), float(expected), rtol=1e-14, atol=1e-15,
err_msg=f"cbrt({i}^3) should equal {i}")

# Test negative perfect cubes
for i in [1, 2, 3, 4, 5, 10, 100]:
val = QuadPrecision(-(i ** 3))
result = np.cbrt(val)
expected = QuadPrecision(-i)
np.testing.assert_allclose(float(result), float(expected), rtol=1e-14, atol=1e-15,
err_msg=f"cbrt(-{i}^3) should equal -{i}")


@pytest.mark.parametrize("op", ["exp", "exp2"])
@pytest.mark.parametrize("val", [
# Basic cases
Expand Down
Loading