Skip to content

Commit 0a26cfc

Browse files
committed
logaddexp2 ufunc
1 parent 5ac0265 commit 0a26cfc

File tree

4 files changed

+193
-1
lines changed

4 files changed

+193
-1
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,49 @@ quad_logaddexp(const Sleef_quad *x, const Sleef_quad *y)
626626
return Sleef_addq1_u05(max_val, log1p_term);
627627
}
628628

629+
static inline Sleef_quad
630+
quad_logaddexp2(const Sleef_quad *x, const Sleef_quad *y)
631+
{
632+
// logaddexp2(x, y) = log2(2^x + 2^y)
633+
// Numerically stable implementation: max(x, y) + log2(1 + 2^(-abs(x - y)))
634+
635+
// Handle NaN
636+
if (Sleef_iunordq1(*x, *y)) {
637+
return Sleef_iunordq1(*x, *x) ? *x : *y;
638+
}
639+
640+
// Handle infinities
641+
// If both are -inf, result is -inf
642+
Sleef_quad neg_inf = Sleef_negq1(QUAD_POS_INF);
643+
if (Sleef_icmpeqq1(*x, neg_inf) && Sleef_icmpeqq1(*y, neg_inf)) {
644+
return neg_inf;
645+
}
646+
647+
// If either is +inf, result is +inf
648+
if (Sleef_icmpeqq1(*x, QUAD_POS_INF) || Sleef_icmpeqq1(*y, QUAD_POS_INF)) {
649+
return QUAD_POS_INF;
650+
}
651+
652+
// If one is -inf, result is the other value
653+
if (Sleef_icmpeqq1(*x, neg_inf)) {
654+
return *y;
655+
}
656+
if (Sleef_icmpeqq1(*y, neg_inf)) {
657+
return *x;
658+
}
659+
660+
// log2(2^x + 2^y) = max(x, y) + log2(1 + 2^(-abs(x - y)))
661+
Sleef_quad diff = Sleef_subq1_u05(*x, *y);
662+
Sleef_quad abs_diff = Sleef_fabsq1(diff);
663+
Sleef_quad neg_abs_diff = Sleef_negq1(abs_diff);
664+
Sleef_quad exp2_term = Sleef_exp2q1_u10(neg_abs_diff);
665+
Sleef_quad one_plus_exp2 = Sleef_addq1_u05(QUAD_ONE, exp2_term);
666+
Sleef_quad log2_term = Sleef_log2q1_u10(one_plus_exp2);
667+
668+
Sleef_quad max_val = Sleef_icmpgtq1(*x, *y) ? *x : *y;
669+
return Sleef_addq1_u05(max_val, log2_term);
670+
}
671+
629672
// Binary long double operations
630673
typedef long double (*binary_op_longdouble_def)(const long double *, const long double *);
631674

@@ -759,6 +802,45 @@ ld_logaddexp(const long double *x, const long double *y)
759802
return max_val + log1pl(expl(-abs_diff));
760803
}
761804

805+
static inline long double
806+
ld_logaddexp2(const long double *x, const long double *y)
807+
{
808+
// logaddexp2(x, y) = log2(2^x + 2^y)
809+
// Numerically stable implementation: max(x, y) + log2(1 + 2^(-abs(x - y)))
810+
811+
// Handle NaN
812+
if (isnan(*x) || isnan(*y)) {
813+
return isnan(*x) ? *x : *y;
814+
}
815+
816+
// Handle infinities
817+
// If both are -inf, result is -inf
818+
if (isinf(*x) && *x < 0 && isinf(*y) && *y < 0) {
819+
return -INFINITY;
820+
}
821+
822+
// If either is +inf, result is +inf
823+
if ((isinf(*x) && *x > 0) || (isinf(*y) && *y > 0)) {
824+
return INFINITY;
825+
}
826+
827+
// If one is -inf, result is the other value
828+
if (isinf(*x) && *x < 0) {
829+
return *y;
830+
}
831+
if (isinf(*y) && *y < 0) {
832+
return *x;
833+
}
834+
835+
// Numerically stable computation
836+
// log2(2^x + 2^y) = max(x, y) + log2(1 + 2^(-abs(x - y)))
837+
long double diff = *x - *y;
838+
long double abs_diff = fabsl(diff);
839+
long double max_val = (*x > *y) ? *x : *y;
840+
// log2(1 + z) = log(1 + z) / log(2)
841+
return max_val + log1pl(exp2l(-abs_diff)) / M_LN2;
842+
}
843+
762844
// comparison quad functions
763845
typedef npy_bool (*cmp_quad_def)(const Sleef_quad *, const Sleef_quad *);
764846

quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,5 +243,8 @@ init_quad_binary_ops(PyObject *numpy)
243243
if (create_quad_binary_ufunc<quad_logaddexp, ld_logaddexp>(numpy, "logaddexp") < 0) {
244244
return -1;
245245
}
246+
if (create_quad_binary_ufunc<quad_logaddexp2, ld_logaddexp2>(numpy, "logaddexp2") < 0) {
247+
return -1;
248+
}
246249
return 0;
247250
}

quaddtype/release_tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
| matmul |||
1212
| divide |||
1313
| logaddexp |||
14-
| logaddexp2 | | |
14+
| logaddexp2 | | |
1515
| true_divide | | |
1616
| floor_divide | | |
1717
| negative |||

quaddtype/tests/test_quaddtype.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,113 @@ def test_logaddexp_special_properties():
528528
np.testing.assert_allclose(float(result1), float(result2), rtol=1e-14)
529529

530530

531+
@pytest.mark.parametrize("x", [
532+
# Regular values
533+
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",
534+
# Large values (test numerical stability)
535+
"100.0", "1000.0", "-100.0", "-1000.0",
536+
# Small values
537+
"1e-10", "-1e-10", "1e-20", "-1e-20",
538+
# Special values
539+
"inf", "-inf", "nan", "-nan", "-0.0"
540+
])
541+
@pytest.mark.parametrize("y", [
542+
# Regular values
543+
"0.0", "1.0", "2.0", "-1.0", "-2.0", "0.5", "-0.5",
544+
# Large values
545+
"100.0", "1000.0", "-100.0", "-1000.0",
546+
# Small values
547+
"1e-10", "-1e-10", "1e-20", "-1e-20",
548+
# Special values
549+
"inf", "-inf", "nan", "-nan", "-0.0"
550+
])
551+
def test_logaddexp2(x, y):
552+
"""Comprehensive test for logaddexp2 function: log2(2^x + 2^y)"""
553+
quad_x = QuadPrecision(x)
554+
quad_y = QuadPrecision(y)
555+
float_x = float(x)
556+
float_y = float(y)
557+
558+
quad_result = np.logaddexp2(quad_x, quad_y)
559+
float_result = np.logaddexp2(float_x, float_y)
560+
561+
# Handle NaN cases
562+
if np.isnan(float_result):
563+
assert np.isnan(float(quad_result)), \
564+
f"Expected NaN for logaddexp2({x}, {y}), got {float(quad_result)}"
565+
return
566+
567+
# Handle infinity cases
568+
if np.isinf(float_result):
569+
assert np.isinf(float(quad_result)), \
570+
f"Expected inf for logaddexp2({x}, {y}), got {float(quad_result)}"
571+
if not np.isnan(float_result):
572+
assert np.sign(float_result) == np.sign(float(quad_result)), \
573+
f"Infinity sign mismatch for logaddexp2({x}, {y})"
574+
return
575+
576+
# For finite results, check with appropriate tolerance
577+
# logaddexp2 is numerically sensitive, especially for large differences
578+
if abs(float_x - float_y) > 50:
579+
# When values differ greatly, result should be close to max(x, y)
580+
rtol = 1e-10
581+
atol = 1e-10
582+
else:
583+
rtol = 1e-13
584+
atol = 1e-15
585+
586+
np.testing.assert_allclose(
587+
float(quad_result), float_result,
588+
rtol=rtol, atol=atol,
589+
err_msg=f"Value mismatch for logaddexp2({x}, {y})"
590+
)
591+
592+
593+
def test_logaddexp2_special_properties():
594+
"""Test special mathematical properties of logaddexp2"""
595+
# logaddexp2(x, x) = x + 1 (since log2(2^x + 2^x) = log2(2 * 2^x) = log2(2) + log2(2^x) = 1 + x)
596+
x = QuadPrecision("2.0")
597+
result = np.logaddexp2(x, x)
598+
expected = float(x) + 1.0
599+
np.testing.assert_allclose(float(result), expected, rtol=1e-14)
600+
601+
# logaddexp2(x, -inf) = x
602+
x = QuadPrecision("5.0")
603+
result = np.logaddexp2(x, QuadPrecision("-inf"))
604+
np.testing.assert_allclose(float(result), float(x), rtol=1e-14)
605+
606+
# logaddexp2(-inf, x) = x
607+
result = np.logaddexp2(QuadPrecision("-inf"), x)
608+
np.testing.assert_allclose(float(result), float(x), rtol=1e-14)
609+
610+
# logaddexp2(-inf, -inf) = -inf
611+
result = np.logaddexp2(QuadPrecision("-inf"), QuadPrecision("-inf"))
612+
assert np.isinf(float(result)) and float(result) < 0
613+
614+
# logaddexp2(inf, anything) = inf
615+
result = np.logaddexp2(QuadPrecision("inf"), QuadPrecision("100.0"))
616+
assert np.isinf(float(result)) and float(result) > 0
617+
618+
# logaddexp2(anything, inf) = inf
619+
result = np.logaddexp2(QuadPrecision("100.0"), QuadPrecision("inf"))
620+
assert np.isinf(float(result)) and float(result) > 0
621+
622+
# Commutativity: logaddexp2(x, y) = logaddexp2(y, x)
623+
x = QuadPrecision("3.0")
624+
y = QuadPrecision("5.0")
625+
result1 = np.logaddexp2(x, y)
626+
result2 = np.logaddexp2(y, x)
627+
np.testing.assert_allclose(float(result1), float(result2), rtol=1e-14)
628+
629+
# Relationship with logaddexp: logaddexp2(x, y) = logaddexp(x*ln2, y*ln2) / ln2
630+
x = QuadPrecision("2.0")
631+
y = QuadPrecision("3.0")
632+
result_logaddexp2 = np.logaddexp2(x, y)
633+
ln2 = np.log(2.0)
634+
result_logaddexp = np.logaddexp(float(x) * ln2, float(y) * ln2) / ln2
635+
np.testing.assert_allclose(float(result_logaddexp2), result_logaddexp, rtol=1e-13)
636+
637+
531638
def test_inf():
532639
assert QuadPrecision("inf") > QuadPrecision("1e1000")
533640
assert np.signbit(QuadPrecision("inf")) == 0

0 commit comments

Comments
 (0)