Skip to content

Commit d167165

Browse files
committed
flt_pow impl
1 parent db6b84e commit d167165

File tree

3 files changed

+155
-1
lines changed

3 files changed

+155
-1
lines changed

quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,11 @@ init_quad_binary_ops(PyObject *numpy)
224224
if (create_quad_binary_ufunc<quad_pow, ld_pow>(numpy, "power") < 0) {
225225
return -1;
226226
}
227+
// float_power uses the same implementation as power for floating-point types
228+
// The only difference is that float_power promotes integer inputs to float (quaddtype already float)
229+
if (create_quad_binary_ufunc<quad_pow, ld_pow>(numpy, "float_power") < 0) {
230+
return -1;
231+
}
227232
if (create_quad_binary_ufunc<quad_mod, ld_mod>(numpy, "mod") < 0) {
228233
return -1;
229234
}

quaddtype/release_tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
| negative |||
1818
| positive |||
1919
| power |||
20-
| float_power | | |
20+
| float_power | | |
2121
| remainder |||
2222
| mod |||
2323
| fmod | | |

quaddtype/tests/test_quaddtype.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,3 +1072,152 @@ def test_fill_function(func, args, expected):
10721072
assert len(arr) == len(expected)
10731073
for i, exp_val in enumerate(expected):
10741074
np.testing.assert_allclose(float(arr[i]), float(exp_val), rtol=1e-15, atol=1e-15)
1075+
1076+
@pytest.mark.parametrize("base,exponent", [
1077+
# Basic integer powers
1078+
(2.0, 3.0), (3.0, 2.0), (10.0, 5.0), (5.0, 10.0),
1079+
1080+
# Fractional powers
1081+
(4.0, 0.5), (9.0, 0.5), (27.0, 1.0/3.0), (16.0, 0.25),
1082+
(8.0, 2.0/3.0), (100.0, 0.5),
1083+
1084+
# Negative bases with integer exponents
1085+
(-2.0, 3.0), (-3.0, 2.0), (-2.0, 4.0), (-5.0, 3.0),
1086+
1087+
# Negative bases with fractional exponents (should return NaN)
1088+
(-1.0, 0.5), (-4.0, 0.5), (-1.0, 1.5), (-4.0, 1.5),
1089+
(-2.0, 0.25), (-8.0, 1.0/3.0), (-5.0, 2.5), (-10.0, 0.75),
1090+
(-1.0, -0.5), (-4.0, -1.5), (-2.0, -2.5),
1091+
1092+
# Zero base cases
1093+
(0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (0.0, 10.0),
1094+
(0.0, 0.5), (0.0, -0.0),
1095+
1096+
# Negative zero base
1097+
(-0.0, 0.0), (-0.0, 1.0), (-0.0, 2.0), (-0.0, 3.0),
1098+
1099+
# Base of 1
1100+
(1.0, 0.0), (1.0, 1.0), (1.0, 100.0), (1.0, -100.0),
1101+
(1.0, float('inf')), (1.0, float('-inf')), (1.0, float('nan')),
1102+
1103+
# Base of -1
1104+
(-1.0, 0.0), (-1.0, 1.0), (-1.0, 2.0), (-1.0, 3.0),
1105+
(-1.0, float('inf')), (-1.0, float('-inf')),
1106+
1107+
# Exponent of 0
1108+
(2.0, 0.0), (100.0, 0.0), (-5.0, 0.0), (0.5, 0.0),
1109+
(float('inf'), 0.0), (float('-inf'), 0.0), (float('nan'), 0.0),
1110+
1111+
# Exponent of 1
1112+
(2.0, 1.0), (100.0, 1.0), (-5.0, 1.0), (0.5, 1.0),
1113+
(float('inf'), 1.0), (float('-inf'), 1.0),
1114+
1115+
# Negative exponents
1116+
(2.0, -1.0), (2.0, -2.0), (10.0, -3.0), (0.5, -1.0),
1117+
(4.0, -0.5), (9.0, -0.5),
1118+
1119+
# Infinity base
1120+
(float('inf'), 0.0), (float('inf'), 1.0), (float('inf'), 2.0),
1121+
(float('inf'), -1.0), (float('inf'), -2.0), (float('inf'), 0.5),
1122+
(float('inf'), float('inf')), (float('inf'), float('-inf')),
1123+
1124+
# Negative infinity base
1125+
(float('-inf'), 0.0), (float('-inf'), 1.0), (float('-inf'), 2.0),
1126+
(float('-inf'), 3.0), (float('-inf'), -1.0), (float('-inf'), -2.0),
1127+
(float('-inf'), float('inf')), (float('-inf'), float('-inf')),
1128+
1129+
# Infinity exponent
1130+
(2.0, float('inf')), (0.5, float('inf')), (1.5, float('inf')),
1131+
(2.0, float('-inf')), (0.5, float('-inf')), (1.5, float('-inf')),
1132+
(0.0, float('inf')), (0.0, float('-inf')),
1133+
1134+
# NaN cases
1135+
(float('nan'), 0.0), (float('nan'), 1.0), (float('nan'), 2.0),
1136+
(2.0, float('nan')), (0.0, float('nan')),
1137+
(float('nan'), float('nan')), (float('nan'), float('inf')),
1138+
(float('inf'), float('nan')),
1139+
1140+
# Small and large values
1141+
(1e-10, 2.0), (1e10, 2.0), (1e-10, 0.5), (1e10, 0.5),
1142+
(2.0, 100.0), (2.0, -100.0), (0.5, 100.0), (0.5, -100.0),
1143+
])
1144+
def test_float_power(base, exponent):
1145+
"""
1146+
Comprehensive test for float_power ufunc.
1147+
1148+
float_power differs from power in that it always promotes to floating point.
1149+
For floating-point dtypes like QuadPrecDType, it should behave identically to power.
1150+
"""
1151+
quad_base = QuadPrecision(str(base)) if not (np.isnan(base) or np.isinf(base)) else QuadPrecision(base)
1152+
quad_exp = QuadPrecision(str(exponent)) if not (np.isnan(exponent) or np.isinf(exponent)) else QuadPrecision(exponent)
1153+
1154+
float_base = np.float64(base)
1155+
float_exp = np.float64(exponent)
1156+
1157+
quad_result = np.float_power(quad_base, quad_exp)
1158+
float_result = np.float_power(float_base, float_exp)
1159+
1160+
# Handle NaN cases
1161+
if np.isnan(float_result):
1162+
assert np.isnan(float(quad_result)), \
1163+
f"Expected NaN for float_power({base}, {exponent}), got {float(quad_result)}"
1164+
return
1165+
1166+
# Handle infinity cases
1167+
if np.isinf(float_result):
1168+
assert np.isinf(float(quad_result)), \
1169+
f"Expected inf for float_power({base}, {exponent}), got {float(quad_result)}"
1170+
assert np.sign(float_result) == np.sign(float(quad_result)), \
1171+
f"Infinity sign mismatch for float_power({base}, {exponent})"
1172+
return
1173+
1174+
# For finite results
1175+
np.testing.assert_allclose(
1176+
float(quad_result), float_result,
1177+
rtol=1e-13, atol=1e-15,
1178+
err_msg=f"Value mismatch for float_power({base}, {exponent})"
1179+
)
1180+
1181+
# Check sign for zero results
1182+
if float_result == 0.0:
1183+
assert np.signbit(float_result) == np.signbit(quad_result), \
1184+
f"Zero sign mismatch for float_power({base}, {exponent})"
1185+
1186+
1187+
@pytest.mark.parametrize("base,exponent", [
1188+
# Test that float_power works with integer inputs (promotes to float)
1189+
(2, 3),
1190+
(4, 2),
1191+
(10, 5),
1192+
(-2, 3),
1193+
])
1194+
def test_float_power_integer_promotion(base, exponent):
1195+
"""
1196+
Test that float_power works with integer inputs and promotes them to QuadPrecDType.
1197+
This is the key difference from power - float_power always returns float types.
1198+
"""
1199+
# Create arrays with integer inputs
1200+
base_arr = np.array([base], dtype=QuadPrecDType())
1201+
exp_arr = np.array([exponent], dtype=QuadPrecDType())
1202+
1203+
result = np.float_power(base_arr, exp_arr)
1204+
1205+
# Result should be QuadPrecDType
1206+
assert result.dtype.name == "QuadPrecDType128"
1207+
1208+
# Check the value
1209+
expected = float(base) ** float(exponent)
1210+
np.testing.assert_allclose(float(result[0]), expected, rtol=1e-13)
1211+
1212+
1213+
def test_float_power_array():
1214+
"""Test float_power with arrays"""
1215+
bases = np.array([2.0, 4.0, 9.0, 16.0], dtype=QuadPrecDType())
1216+
exponents = np.array([3.0, 0.5, 2.0, 0.25], dtype=QuadPrecDType())
1217+
1218+
result = np.float_power(bases, exponents)
1219+
expected = np.array([8.0, 2.0, 81.0, 2.0], dtype=np.float64)
1220+
1221+
assert result.dtype.name == "QuadPrecDType128"
1222+
for i in range(len(result)):
1223+
np.testing.assert_allclose(float(result[i]), expected[i], rtol=1e-13)

0 commit comments

Comments
 (0)