Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 201 additions & 3 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

use crate::utils::{calculate_binary_decimal_math, calculate_binary_math};

use arrow::array::ArrayRef;
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::datatypes::DataType::{
Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, Int8, Int16, Int32,
Int64, UInt8, UInt16, UInt32, UInt64,
};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int8Type, Int16Type,
Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
};
use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
Expand Down Expand Up @@ -120,6 +122,13 @@ fn calculate_new_precision_scale<T: DecimalType>(
}
}

fn is_integer_data_type(data_type: &DataType) -> bool {
matches!(
data_type,
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
)
}

fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> {
let out_of_range = |value: String| {
datafusion_common::DataFusionError::Execution(format!(
Expand Down Expand Up @@ -185,6 +194,7 @@ impl RoundFunc {
vec![TypeSignatureClass::Integer],
NativeType::Int32,
);
let integer = Coercion::new_exact(TypeSignatureClass::Integer);
let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
let float64 = Coercion::new_implicit(
TypeSignatureClass::Native(logical_float64()),
Expand All @@ -199,6 +209,11 @@ impl RoundFunc {
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![decimal]),
TypeSignature::Coercible(vec![
integer.clone(),
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![integer]),
TypeSignature::Coercible(vec![
float32.clone(),
decimal_places.clone(),
Expand Down Expand Up @@ -245,6 +260,7 @@ impl ScalarUDFImpl for RoundFunc {
// extra precision to accommodate potential carry-over.
let return_type =
match input_type {
input_type if is_integer_data_type(input_type) => input_type.clone(),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

returns the integer type for all scales, but only non-negative scales are handled downstream. round(arrow_cast(125,'Int64'), -1) now fails, it returned 130.0 before this PR.

Float32 => Float32,
Decimal32(precision, scale) => calculate_new_precision_scale::<
Decimal32Type,
Expand Down Expand Up @@ -308,6 +324,9 @@ impl ScalarUDFImpl for RoundFunc {
};

match (value_scalar, args.return_type()) {
(value_scalar, return_type) if is_integer_data_type(return_type) => {
round_integer_scalar(value_scalar, return_type, dp)
}
(ScalarValue::Float32(Some(v)), _) => {
let rounded = round_float(*v, dp)?;
Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
Expand Down Expand Up @@ -468,6 +487,11 @@ fn round_columnar(
let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_));

let arr: ArrayRef = match (value_array.data_type(), return_type) {
(input_type, return_type)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The no-op fast path requires decimal_places to be a non-negative scalar literal. When the scale is a column (or a negative literal), this guard is false and there's no other integer arm, so it hits exec_err!

    CREATE TABLE t(v BIGINT, dp INT) AS VALUES (125,1),(125,-1);
    SELECT round(v, dp) FROM t;   -- errored; worked (Float64) before this PR

if input_type == return_type && is_integer_data_type(return_type) =>
{
round_integer_array(value_array.as_ref(), decimal_places, return_type)?
}
(Float64, _) => {
let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>(
value_array.as_ref(),
Expand Down Expand Up @@ -630,6 +654,180 @@ fn round_columnar(
}
}

fn round_integer_value(value: i128, decimal_places: i32) -> Result<i128, ArrowError> {
if decimal_places >= 0 {
return Ok(value);
}

let Some(factor) = 10_i128.checked_pow(decimal_places.unsigned_abs()) else {
return Ok(0);
};

let remainder = value % factor;
let threshold = factor / 2;

if remainder >= threshold {
value
.checked_sub(remainder)
.and_then(|v| v.checked_add(factor))
.ok_or_else(|| {
ArrowError::ComputeError("Overflow while rounding integer".to_string())
})
} else if remainder <= -threshold {
value
.checked_sub(remainder)
.and_then(|v| v.checked_sub(factor))
.ok_or_else(|| {
ArrowError::ComputeError("Overflow while rounding integer".to_string())
})
} else {
value.checked_sub(remainder).ok_or_else(|| {
ArrowError::ComputeError("Overflow while rounding integer".to_string())
})
}
}

fn round_integer_scalar(
value: &ScalarValue,
return_type: &DataType,
decimal_places: i32,
) -> Result<ColumnarValue> {
let rounded = match value {
ScalarValue::Int8(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
ScalarValue::Int16(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
ScalarValue::Int32(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
ScalarValue::Int64(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
ScalarValue::UInt8(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
ScalarValue::UInt16(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
ScalarValue::UInt32(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
ScalarValue::UInt64(Some(v)) => {
round_integer_value(i128::from(*v), decimal_places)?
}
_ => {
return internal_err!(
"Unexpected datatype for integer round: {}",
value.data_type()
);
}
};

let scalar = match return_type {
Int8 => ScalarValue::Int8(Some(i8::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding Int8".to_string())
})?)),
Int16 => ScalarValue::Int16(Some(i16::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding Int16".to_string())
})?)),
Int32 => ScalarValue::Int32(Some(i32::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding Int32".to_string())
})?)),
Int64 => ScalarValue::Int64(Some(i64::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding Int64".to_string())
})?)),
UInt8 => ScalarValue::UInt8(Some(u8::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding UInt8".to_string())
})?)),
UInt16 => ScalarValue::UInt16(Some(u16::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding UInt16".to_string())
})?)),
UInt32 => ScalarValue::UInt32(Some(u32::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding UInt32".to_string())
})?)),
UInt64 => ScalarValue::UInt64(Some(u64::try_from(rounded).map_err(|_| {
ArrowError::ComputeError("Overflow while rounding UInt64".to_string())
})?)),
_ => {
return internal_err!(
"Unexpected return type for integer round: {return_type}"
);
}
};

Ok(ColumnarValue::Scalar(scalar))
}

macro_rules! round_integer_array {
($ARRAY:expr, $DP:expr, $RETURN_TYPE:expr, $ARRAY_TYPE:ty, $NATIVE:ty) => {{
let array = $ARRAY.as_primitive::<$ARRAY_TYPE>();

let result = calculate_binary_math::<$ARRAY_TYPE, Int32Type, $ARRAY_TYPE, _>(
array,
$DP,
|v, dp| {
let rounded = round_integer_value(i128::from(v), dp)?;
<$NATIVE>::try_from(rounded).map_err(|_| {
ArrowError::ComputeError(format!(
"Overflow while rounding {}",
$RETURN_TYPE
))
})
},
)?;

Ok(result as ArrayRef)
}};
}

fn round_integer_array(
value_array: &dyn Array,
decimal_places: &ColumnarValue,
return_type: &DataType,
) -> Result<ArrayRef> {
match return_type {
Int8 => {
round_integer_array!(value_array, decimal_places, return_type, Int8Type, i8)
}
Int16 => {
round_integer_array!(value_array, decimal_places, return_type, Int16Type, i16)
}
Int32 => {
round_integer_array!(value_array, decimal_places, return_type, Int32Type, i32)
}
Int64 => {
round_integer_array!(value_array, decimal_places, return_type, Int64Type, i64)
}
UInt8 => {
round_integer_array!(value_array, decimal_places, return_type, UInt8Type, u8)
}
UInt16 => round_integer_array!(
value_array,
decimal_places,
return_type,
UInt16Type,
u16
),
UInt32 => round_integer_array!(
value_array,
decimal_places,
return_type,
UInt32Type,
u32
),
UInt64 => round_integer_array!(
value_array,
decimal_places,
return_type,
UInt64Type,
u64
),
_ => internal_err!("Unexpected return type for integer round: {return_type}"),
}
}

fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
where
T: num_traits::Float,
Expand Down
50 changes: 29 additions & 21 deletions datafusion/spark/src/function/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,18 +462,22 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result<Columna
impl_integer_array_round!(array, UInt32Type, scale, enable_ansi_mode)
}
DataType::UInt64 => {
let array = array.as_primitive::<UInt64Type>();
let result: PrimitiveArray<UInt64Type> = array.try_unary(|x| {
let v_i64 = i64::try_from(x).map_err(|_| {
(exec_err!(
"round: UInt64 value {x} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
if scale >= 0 {
Ok(args[0].clone())
} else {
let array = array.as_primitive::<UInt64Type>();
let result: PrimitiveArray<UInt64Type> = array.try_unary(|x| {
let v_i64 = i64::try_from(x).map_err(|_| {
(exec_err!(
"round: UInt64 value {x} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
round_integer(v_i64, scale, enable_ansi_mode)
.map(|v| v as u64)
})?;
round_integer(v_i64, scale, enable_ansi_mode)
.map(|v| v as u64)
})?;
Ok(ColumnarValue::Array(Arc::new(result)))
Ok(ColumnarValue::Array(Arc::new(result)))
}
}

// Float types
Expand Down Expand Up @@ -588,16 +592,20 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result<Columna
Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(result))))
}
ScalarValue::UInt64(Some(v)) => {
let v_i64 = i64::try_from(*v).map_err(|_| {
(exec_err!(
"round: UInt64 value {v} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
let result = round_integer(v_i64, scale, enable_ansi_mode)?;
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
result as u64,
))))
if scale >= 0 {
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(*v))))
} else {
let v_i64 = i64::try_from(*v).map_err(|_| {
(exec_err!(
"round: UInt64 value {v} exceeds i64::MAX and cannot be rounded"
) as Result<(), _>)
.unwrap_err()
})?;
let result = round_integer(v_i64, scale, enable_ansi_mode)?;
Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(
result as u64,
))))
}
}

// Float scalars
Expand Down
30 changes: 30 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests only cover non-negative scalar scale. There's no coverage for negative scale, column scale, or negative scale on a column

Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,36 @@ select round(a), round(b), round(c) from small_floats;
0 0 1
1 0 0

# round int64 should preserve exact values above Float64 precision range
query TI
select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'))),
round(arrow_cast(9007199254740993, 'Int64'));
----
Int64 9007199254740993

# round int64 with positive decimal_places should preserve exact values above Float64 precision range
query TI
select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'), 2)),
round(arrow_cast(9007199254740993, 'Int64'), 2);
----
Int64 9007199254740993

# round int64 with negative decimal_places
query TI
select arrow_typeof(round(arrow_cast(125, 'Int64'), -1)),
round(arrow_cast(125, 'Int64'), -1);
----
Int64 130

# round int64 with column decimal_places
query I
select round(v, dp)
from (values (arrow_cast(125, 'Int64'), 1),
(arrow_cast(125, 'Int64'), -1)) as t(v, dp);
----
125
130

# round with too large
# max Int32 is 2147483647
query error round decimal_places 2147483648 is out of supported i32 range
Expand Down
11 changes: 5 additions & 6 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1591,16 +1591,15 @@ WHERE CAST(ROUND(b) as INT) = a
ORDER BY CAST(ROUND(b) as INT);
----
logical_plan
01)Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC NULLS LAST
02)--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a
03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a]
01)Sort: CAST(round(annotated_data_finite2.b) AS Int32) ASC NULLS LAST
02)--Filter: CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a
03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a]
physical_plan
01)SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS LAST]
02)--FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1
01)SortPreservingMergeExec: [round(b@2) ASC NULLS LAST]
02)--FilterExec: round(b@2) = a@1
03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1, maintains_sort_order=true
04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true


statement ok
drop table annotated_data_finite2;

Expand Down
Loading
Loading