Skip to content

Commit 809c6da

Browse files
authored
Merge pull request #183 from SwayamInSync/expm1
2 parents 6417cd8 + 22ad6aa commit 809c6da

File tree

4 files changed

+85
-1
lines changed

4 files changed

+85
-1
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ quad_exp2(const Sleef_quad *op)
124124
return Sleef_exp2q1_u10(*op);
125125
}
126126

127+
static inline Sleef_quad
128+
quad_expm1(const Sleef_quad *op)
129+
{
130+
return Sleef_expm1q1_u10(*op);
131+
}
132+
127133
static inline Sleef_quad
128134
quad_sin(const Sleef_quad *op)
129135
{
@@ -308,6 +314,12 @@ ld_exp2(const long double *op)
308314
return exp2l(*op);
309315
}
310316

317+
static inline long double
318+
ld_expm1(const long double *op)
319+
{
320+
return expm1l(*op);
321+
}
322+
311323
static inline long double
312324
ld_sin(const long double *op)
313325
{

quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ init_quad_unary_ops(PyObject *numpy)
203203
if (create_quad_unary_ufunc<quad_exp2, ld_exp2>(numpy, "exp2") < 0) {
204204
return -1;
205205
}
206+
if (create_quad_unary_ufunc<quad_expm1, ld_expm1>(numpy, "expm1") < 0) {
207+
return -1;
208+
}
206209
if (create_quad_unary_ufunc<quad_sin, ld_sin>(numpy, "sin") < 0) {
207210
return -1;
208211
}

quaddtype/release_tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
| log |||
3535
| log2 |||
3636
| log10 |||
37-
| expm1 | | |
37+
| expm1 | | |
3838
| log1p |||
3939
| sqrt |||
4040
| square |||

quaddtype/tests/test_quaddtype.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,75 @@ def test_log1p(val):
429429
quad_result), f"Zero sign mismatch for {op}({val})"
430430

431431

432+
@pytest.mark.parametrize("val", [
433+
# Cases close to 0 (where expm1 is most accurate and important)
434+
"0.0", "-0.0",
435+
"1e-10", "-1e-10", "1e-15", "-1e-15", "1e-20", "-1e-20",
436+
"1e-100", "-1e-100", "1e-300", "-1e-300",
437+
# Small values
438+
"0.001", "-0.001", "0.01", "-0.01", "0.1", "-0.1",
439+
# Moderate values
440+
"0.5", "-0.5", "1.0", "-1.0", "2.0", "-2.0",
441+
# Larger values
442+
"5.0", "-5.0", "10.0", "-10.0", "20.0", "-20.0",
443+
# Values that test exp behavior
444+
"50.0", "-50.0", "100.0", "-100.0",
445+
# Large positive values (exp(x) grows rapidly)
446+
"200.0", "500.0", "700.0",
447+
# Large negative values (should approach -1)
448+
"-200.0", "-500.0", "-700.0", "-1000.0",
449+
# Special values
450+
"inf", # Should give inf
451+
"-inf", # Should give -1
452+
"nan", "-nan"
453+
])
454+
def test_expm1(val):
455+
"""Comprehensive test for expm1 function: exp(x) - 1
456+
457+
This function provides greater precision than exp(x) - 1 for small values of x.
458+
"""
459+
quad_val = QuadPrecision(val)
460+
float_val = float(val)
461+
462+
quad_result = np.expm1(quad_val)
463+
float_result = np.expm1(float_val)
464+
465+
# Handle NaN cases
466+
if np.isnan(float_result):
467+
assert np.isnan(
468+
float(quad_result)), f"Expected NaN for expm1({val}), got {float(quad_result)}"
469+
return
470+
471+
# Handle infinity cases
472+
if np.isinf(float_result):
473+
assert np.isinf(
474+
float(quad_result)), f"Expected inf for expm1({val}), got {float(quad_result)}"
475+
assert np.sign(float_result) == np.sign(
476+
float(quad_result)), f"Infinity sign mismatch for expm1({val})"
477+
return
478+
479+
# For finite results
480+
# expm1 is designed for high accuracy near 0, so use tight tolerances for small inputs
481+
if abs(float(val)) < 1e-10:
482+
rtol = 1e-15
483+
atol = 1e-20
484+
elif abs(float_result) < 1:
485+
rtol = 1e-14
486+
atol = 1e-15
487+
else:
488+
# For larger results, use relative tolerance
489+
rtol = 1e-14
490+
atol = 1e-15
491+
492+
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=atol,
493+
err_msg=f"Value mismatch for expm1({val})")
494+
495+
# Check sign for zero results
496+
if float_result == 0.0:
497+
assert np.signbit(float_result) == np.signbit(
498+
quad_result), f"Zero sign mismatch for expm1({val})"
499+
500+
432501
@pytest.mark.parametrize("x", [
433502
# Regular values
434503
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",

0 commit comments

Comments
 (0)