From 97a72947f27b8ec32655c4cd51fb75ecde384983 Mon Sep 17 00:00:00 2001 From: pchintar <89355405+pchintar@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:39:13 +0530 Subject: [PATCH] Preserve integer precision in round() --- datafusion/functions/src/math/round.rs | 204 +++++++++++++++++- datafusion/spark/src/function/math/round.rs | 50 +++-- datafusion/sqllogictest/test_files/scalar.slt | 30 +++ datafusion/sqllogictest/test_files/select.slt | 11 +- .../test_files/spark/math/round.slt | 24 +++ 5 files changed, 289 insertions(+), 30 deletions(-) diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 78016c0f52f71..651d748da388e 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -17,13 +17,15 @@ use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef, AsArray}; use arrow::datatypes::DataType::{ - Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, + Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, Int8, Int16, Int32, + Int64, UInt8, UInt16, UInt32, UInt64, }; use arrow::datatypes::{ ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, - Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; use arrow::datatypes::{Field, FieldRef}; use arrow::error::ArrowError; @@ -120,6 +122,13 @@ fn calculate_new_precision_scale( } } +fn is_integer_data_type(data_type: &DataType) -> bool { + matches!( + data_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) +} + fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result { let out_of_range = |value: String| { datafusion_common::DataFusionError::Execution(format!( @@ -185,6 +194,7 @@ impl RoundFunc { vec![TypeSignatureClass::Integer], NativeType::Int32, ); + let integer = Coercion::new_exact(TypeSignatureClass::Integer); let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); let float64 = Coercion::new_implicit( TypeSignatureClass::Native(logical_float64()), @@ -199,6 +209,11 @@ impl RoundFunc { decimal_places.clone(), ]), TypeSignature::Coercible(vec![decimal]), + TypeSignature::Coercible(vec![ + integer.clone(), + decimal_places.clone(), + ]), + TypeSignature::Coercible(vec![integer]), TypeSignature::Coercible(vec![ float32.clone(), decimal_places.clone(), @@ -245,6 +260,7 @@ impl ScalarUDFImpl for RoundFunc { // extra precision to accommodate potential carry-over. let return_type = match input_type { + input_type if is_integer_data_type(input_type) => input_type.clone(), Float32 => Float32, Decimal32(precision, scale) => calculate_new_precision_scale::< Decimal32Type, @@ -308,6 +324,9 @@ impl ScalarUDFImpl for RoundFunc { }; match (value_scalar, args.return_type()) { + (value_scalar, return_type) if is_integer_data_type(return_type) => { + round_integer_scalar(value_scalar, return_type, dp) + } (ScalarValue::Float32(Some(v)), _) => { let rounded = round_float(*v, dp)?; Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) @@ -468,6 +487,11 @@ fn round_columnar( let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_)); let arr: ArrayRef = match (value_array.data_type(), return_type) { + (input_type, return_type) + if input_type == return_type && is_integer_data_type(return_type) => + { + round_integer_array(value_array.as_ref(), decimal_places, return_type)? + } (Float64, _) => { let result = calculate_binary_math::( value_array.as_ref(), @@ -630,6 +654,180 @@ fn round_columnar( } } +fn round_integer_value(value: i128, decimal_places: i32) -> Result { + if decimal_places >= 0 { + return Ok(value); + } + + let Some(factor) = 10_i128.checked_pow(decimal_places.unsigned_abs()) else { + return Ok(0); + }; + + let remainder = value % factor; + let threshold = factor / 2; + + if remainder >= threshold { + value + .checked_sub(remainder) + .and_then(|v| v.checked_add(factor)) + .ok_or_else(|| { + ArrowError::ComputeError("Overflow while rounding integer".to_string()) + }) + } else if remainder <= -threshold { + value + .checked_sub(remainder) + .and_then(|v| v.checked_sub(factor)) + .ok_or_else(|| { + ArrowError::ComputeError("Overflow while rounding integer".to_string()) + }) + } else { + value.checked_sub(remainder).ok_or_else(|| { + ArrowError::ComputeError("Overflow while rounding integer".to_string()) + }) + } +} + +fn round_integer_scalar( + value: &ScalarValue, + return_type: &DataType, + decimal_places: i32, +) -> Result { + let rounded = match value { + ScalarValue::Int8(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + ScalarValue::Int16(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + ScalarValue::Int32(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + ScalarValue::Int64(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + ScalarValue::UInt8(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + ScalarValue::UInt16(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + ScalarValue::UInt32(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + ScalarValue::UInt64(Some(v)) => { + round_integer_value(i128::from(*v), decimal_places)? + } + _ => { + return internal_err!( + "Unexpected datatype for integer round: {}", + value.data_type() + ); + } + }; + + let scalar = match return_type { + Int8 => ScalarValue::Int8(Some(i8::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding Int8".to_string()) + })?)), + Int16 => ScalarValue::Int16(Some(i16::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding Int16".to_string()) + })?)), + Int32 => ScalarValue::Int32(Some(i32::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding Int32".to_string()) + })?)), + Int64 => ScalarValue::Int64(Some(i64::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding Int64".to_string()) + })?)), + UInt8 => ScalarValue::UInt8(Some(u8::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding UInt8".to_string()) + })?)), + UInt16 => ScalarValue::UInt16(Some(u16::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding UInt16".to_string()) + })?)), + UInt32 => ScalarValue::UInt32(Some(u32::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding UInt32".to_string()) + })?)), + UInt64 => ScalarValue::UInt64(Some(u64::try_from(rounded).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding UInt64".to_string()) + })?)), + _ => { + return internal_err!( + "Unexpected return type for integer round: {return_type}" + ); + } + }; + + Ok(ColumnarValue::Scalar(scalar)) +} + +macro_rules! round_integer_array { + ($ARRAY:expr, $DP:expr, $RETURN_TYPE:expr, $ARRAY_TYPE:ty, $NATIVE:ty) => {{ + let array = $ARRAY.as_primitive::<$ARRAY_TYPE>(); + + let result = calculate_binary_math::<$ARRAY_TYPE, Int32Type, $ARRAY_TYPE, _>( + array, + $DP, + |v, dp| { + let rounded = round_integer_value(i128::from(v), dp)?; + <$NATIVE>::try_from(rounded).map_err(|_| { + ArrowError::ComputeError(format!( + "Overflow while rounding {}", + $RETURN_TYPE + )) + }) + }, + )?; + + Ok(result as ArrayRef) + }}; +} + +fn round_integer_array( + value_array: &dyn Array, + decimal_places: &ColumnarValue, + return_type: &DataType, +) -> Result { + match return_type { + Int8 => { + round_integer_array!(value_array, decimal_places, return_type, Int8Type, i8) + } + Int16 => { + round_integer_array!(value_array, decimal_places, return_type, Int16Type, i16) + } + Int32 => { + round_integer_array!(value_array, decimal_places, return_type, Int32Type, i32) + } + Int64 => { + round_integer_array!(value_array, decimal_places, return_type, Int64Type, i64) + } + UInt8 => { + round_integer_array!(value_array, decimal_places, return_type, UInt8Type, u8) + } + UInt16 => round_integer_array!( + value_array, + decimal_places, + return_type, + UInt16Type, + u16 + ), + UInt32 => round_integer_array!( + value_array, + decimal_places, + return_type, + UInt32Type, + u32 + ), + UInt64 => round_integer_array!( + value_array, + decimal_places, + return_type, + UInt64Type, + u64 + ), + _ => internal_err!("Unexpected return type for integer round: {return_type}"), + } +} + fn round_float(value: T, decimal_places: i32) -> Result where T: num_traits::Float, diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs index 05745666183d3..93cc74e8878e0 100644 --- a/datafusion/spark/src/function/math/round.rs +++ b/datafusion/spark/src/function/math/round.rs @@ -462,18 +462,22 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { - let array = array.as_primitive::(); - let result: PrimitiveArray = array.try_unary(|x| { - let v_i64 = i64::try_from(x).map_err(|_| { - (exec_err!( - "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded" - ) as Result<(), _>) - .unwrap_err() + if scale >= 0 { + Ok(args[0].clone()) + } else { + let array = array.as_primitive::(); + let result: PrimitiveArray = array.try_unary(|x| { + let v_i64 = i64::try_from(x).map_err(|_| { + (exec_err!( + "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + round_integer(v_i64, scale, enable_ansi_mode) + .map(|v| v as u64) })?; - round_integer(v_i64, scale, enable_ansi_mode) - .map(|v| v as u64) - })?; - Ok(ColumnarValue::Array(Arc::new(result))) + Ok(ColumnarValue::Array(Arc::new(result))) + } } // Float types @@ -588,16 +592,20 @@ fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { - let v_i64 = i64::try_from(*v).map_err(|_| { - (exec_err!( - "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" - ) as Result<(), _>) - .unwrap_err() - })?; - let result = round_integer(v_i64, scale, enable_ansi_mode)?; - Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( - result as u64, - )))) + if scale >= 0 { + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some(*v)))) + } else { + let v_i64 = i64::try_from(*v).map_err(|_| { + (exec_err!( + "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + let result = round_integer(v_i64, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + result as u64, + )))) + } } // Float scalars diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 9dbf8f16d85ab..784773bf9c46c 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -940,6 +940,36 @@ select round(a), round(b), round(c) from small_floats; 0 0 1 1 0 0 +# round int64 should preserve exact values above Float64 precision range +query TI +select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'))), + round(arrow_cast(9007199254740993, 'Int64')); +---- +Int64 9007199254740993 + +# round int64 with positive decimal_places should preserve exact values above Float64 precision range +query TI +select arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'), 2)), + round(arrow_cast(9007199254740993, 'Int64'), 2); +---- +Int64 9007199254740993 + +# round int64 with negative decimal_places +query TI +select arrow_typeof(round(arrow_cast(125, 'Int64'), -1)), + round(arrow_cast(125, 'Int64'), -1); +---- +Int64 130 + +# round int64 with column decimal_places +query I +select round(v, dp) +from (values (arrow_cast(125, 'Int64'), 1), + (arrow_cast(125, 'Int64'), -1)) as t(v, dp); +---- +125 +130 + # round with too large # max Int32 is 2147483647 query error round decimal_places 2147483648 is out of supported i32 range diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index c7e5ed12fc0af..3a9ae30d04275 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1591,16 +1591,15 @@ WHERE CAST(ROUND(b) as INT) = a ORDER BY CAST(ROUND(b) as INT); ---- logical_plan -01)Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC NULLS LAST -02)--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a -03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a] +01)Sort: CAST(round(annotated_data_finite2.b) AS Int32) ASC NULLS LAST +02)--Filter: CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a +03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(annotated_data_finite2.b) AS Int32) = annotated_data_finite2.a] physical_plan -01)SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS LAST] -02)--FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1 +01)SortPreservingMergeExec: [round(b@2) ASC NULLS LAST] +02)--FilterExec: round(b@2) = a@1 03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1, maintains_sort_order=true 04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], file_type=csv, has_header=true - statement ok drop table annotated_data_finite2; diff --git a/datafusion/sqllogictest/test_files/spark/math/round.slt b/datafusion/sqllogictest/test_files/spark/math/round.slt index 91c5bdf0506f5..49956846ac814 100644 --- a/datafusion/sqllogictest/test_files/spark/math/round.slt +++ b/datafusion/sqllogictest/test_files/spark/math/round.slt @@ -222,6 +222,18 @@ SELECT round(25::bigint, -1::int); ---- 30 +# round(bigint) should preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(9007199254740993, 'Int64')), arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'))); +---- +9007199254740993 Int64 + +# round(bigint, positive scale) should also preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(9007199254740993, 'Int64'), 2::int), arrow_typeof(round(arrow_cast(9007199254740993, 'Int64'), 2::int)); +---- +9007199254740993 Int64 + # round(smallint, -1) query I SELECT round(25::smallint, -1::int); @@ -268,6 +280,18 @@ SELECT round(arrow_cast(25, 'UInt64'), -1::int); ---- 30 +# round(uint64) should preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(18446744073709551615, 'UInt64')), arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64'))); +---- +18446744073709551615 UInt64 + +# round(uint64, positive scale) should also preserve exact values above Float64's exact integer range +query IT +SELECT round(arrow_cast(18446744073709551615, 'UInt64'), 2::int), arrow_typeof(round(arrow_cast(18446744073709551615, 'UInt64'), 2::int)); +---- +18446744073709551615 UInt64 + # round(uint32, positive scale) — no-op for integers query I SELECT round(arrow_cast(42, 'UInt32'), 2::int);