Skip to content

Commit 76ebb61

Browse files
committed
Fix log(0.0::float8) should error, not return -inf
1 parent 77240f9 commit 76ebb61

2 files changed

Lines changed: 55 additions & 29 deletions

File tree

datafusion/functions/src/math/log.rs

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,17 @@ fn is_valid_integer_base(base: f64) -> bool {
106106
base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64
107107
}
108108

109+
#[inline]
110+
fn validate_log_value(value: f64) -> Result<(), ArrowError> {
111+
if value == 0.0 {
112+
Err(ArrowError::ComputeError(
113+
"cannot take logarithm of zero".to_string(),
114+
))
115+
} else {
116+
Ok(())
117+
}
118+
}
119+
109120
/// Calculate logarithm for Decimal32 values.
110121
/// For integer bases >= 2 with zero scale, return an exact integer log when the
111122
/// value is a perfect power of the base. Otherwise falls back to f64 computation.
@@ -121,7 +132,10 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
121132
return Ok(int_log as f64);
122133
}
123134
}
124-
decimal_to_f64(value, scale).map(|v| v.log(base))
135+
decimal_to_f64(value, scale).and_then(|v| {
136+
validate_log_value(v)?;
137+
Ok(v.log(base))
138+
})
125139
}
126140

127141
/// Calculate logarithm for Decimal64 values.
@@ -139,7 +153,10 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
139153
return Ok(int_log as f64);
140154
}
141155
}
142-
decimal_to_f64(value, scale).map(|v| v.log(base))
156+
decimal_to_f64(value, scale).and_then(|v| {
157+
validate_log_value(v)?;
158+
Ok(v.log(base))
159+
})
143160
}
144161

145162
/// Calculate logarithm for Decimal128 values.
@@ -157,7 +174,10 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError>
157174
return Ok(int_log as f64);
158175
}
159176
}
160-
decimal_to_f64(value, scale).map(|v| v.log(base))
177+
decimal_to_f64(value, scale).and_then(|v| {
178+
validate_log_value(v)?;
179+
Ok(v.log(base))
180+
})
161181
}
162182

163183
/// Convert a scaled decimal value to f64.
@@ -180,7 +200,9 @@ fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError>
180200
ArrowError::ComputeError(format!("Cannot convert {value} to f64"))
181201
})?;
182202
let scale_factor = 10f64.powi(scale as i32);
183-
Ok((value_f64 / scale_factor).log(base))
203+
let value = value_f64 / scale_factor;
204+
validate_log_value(value)?;
205+
Ok(value.log(base))
184206
}
185207
}
186208
}
@@ -247,27 +269,33 @@ impl ScalarUDFImpl for LogFunc {
247269
let value = value.to_array(args.number_rows)?;
248270

249271
let output: ArrayRef = match value.data_type() {
250-
DataType::Float16 => {
251-
calculate_binary_math::<Float16Type, Float16Type, Float16Type, _>(
252-
&value,
253-
&base,
254-
|value, base| Ok(value.log(base)),
255-
)?
256-
}
257-
DataType::Float32 => {
258-
calculate_binary_math::<Float32Type, Float32Type, Float32Type, _>(
259-
&value,
260-
&base,
261-
|value, base| Ok(value.log(base)),
262-
)?
263-
}
264-
DataType::Float64 => {
265-
calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
266-
&value,
267-
&base,
268-
|value, base| Ok(value.log(base)),
269-
)?
270-
}
272+
DataType::Float16 => calculate_binary_math::<
273+
Float16Type,
274+
Float16Type,
275+
Float16Type,
276+
_,
277+
>(&value, &base, |value, base| {
278+
validate_log_value(value.to_f64())?;
279+
Ok(value.log(base))
280+
})?,
281+
DataType::Float32 => calculate_binary_math::<
282+
Float32Type,
283+
Float32Type,
284+
Float32Type,
285+
_,
286+
>(&value, &base, |value, base| {
287+
validate_log_value(value as f64)?;
288+
Ok(value.log(base))
289+
})?,
290+
DataType::Float64 => calculate_binary_math::<
291+
Float64Type,
292+
Float64Type,
293+
Float64Type,
294+
_,
295+
>(&value, &base, |value, base| {
296+
validate_log_value(value)?;
297+
Ok(value.log(base))
298+
})?,
271299
DataType::Decimal32(_, scale) => {
272300
calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>(
273301
&value,

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -625,10 +625,8 @@ select log(2, 2.0/3) a, log(10, 2.0/3) b;
625625

626626
# log scalar ops with zero edgecases
627627
# please see https://github.com/apache/datafusion/pull/5245#issuecomment-1426828382
628-
query RR rowsort
629-
select log(0) a, log(1, 64) b;
630-
----
631-
-Infinity Infinity
628+
query error cannot take logarithm of zero
629+
select log(0) a;
632630

633631
# log with columns #1
634632
query RRR rowsort

0 commit comments

Comments
 (0)