-
Notifications
You must be signed in to change notification settings - Fork 2.1k
fix: Preserve integer values in round() for large Int64 and UInt64 inputs #22697
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,13 +17,15 @@ | |
|
|
||
| use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; | ||
|
|
||
| use arrow::array::ArrayRef; | ||
| use arrow::array::{Array, ArrayRef, AsArray}; | ||
| use arrow::datatypes::DataType::{ | ||
| Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, | ||
| Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, Int8, Int16, Int32, | ||
| Int64, UInt8, UInt16, UInt32, UInt64, | ||
| }; | ||
| use arrow::datatypes::{ | ||
| ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, | ||
| Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type, | ||
| Decimal256Type, DecimalType, Float32Type, Float64Type, Int8Type, Int16Type, | ||
| Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, | ||
| }; | ||
| use arrow::datatypes::{Field, FieldRef}; | ||
| use arrow::error::ArrowError; | ||
|
|
@@ -120,6 +122,13 @@ fn calculate_new_precision_scale<T: DecimalType>( | |
| } | ||
| } | ||
|
|
||
| fn is_integer_data_type(data_type: &DataType) -> bool { | ||
| matches!( | ||
| data_type, | ||
| Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | ||
| ) | ||
| } | ||
|
|
||
| fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result<i32> { | ||
| let out_of_range = |value: String| { | ||
| datafusion_common::DataFusionError::Execution(format!( | ||
|
|
@@ -185,6 +194,7 @@ impl RoundFunc { | |
| vec![TypeSignatureClass::Integer], | ||
| NativeType::Int32, | ||
| ); | ||
| let integer = Coercion::new_exact(TypeSignatureClass::Integer); | ||
| let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); | ||
| let float64 = Coercion::new_implicit( | ||
| TypeSignatureClass::Native(logical_float64()), | ||
|
|
@@ -199,6 +209,11 @@ impl RoundFunc { | |
| decimal_places.clone(), | ||
| ]), | ||
| TypeSignature::Coercible(vec![decimal]), | ||
| TypeSignature::Coercible(vec![ | ||
| integer.clone(), | ||
| decimal_places.clone(), | ||
| ]), | ||
| TypeSignature::Coercible(vec![integer]), | ||
| TypeSignature::Coercible(vec![ | ||
| float32.clone(), | ||
| decimal_places.clone(), | ||
|
|
@@ -245,6 +260,7 @@ impl ScalarUDFImpl for RoundFunc { | |
| // extra precision to accommodate potential carry-over. | ||
| let return_type = | ||
| match input_type { | ||
| input_type if is_integer_data_type(input_type) => input_type.clone(), | ||
| Float32 => Float32, | ||
| Decimal32(precision, scale) => calculate_new_precision_scale::< | ||
| Decimal32Type, | ||
|
|
@@ -308,6 +324,9 @@ impl ScalarUDFImpl for RoundFunc { | |
| }; | ||
|
|
||
| match (value_scalar, args.return_type()) { | ||
| (value_scalar, return_type) if is_integer_data_type(return_type) => { | ||
| round_integer_scalar(value_scalar, return_type, dp) | ||
| } | ||
| (ScalarValue::Float32(Some(v)), _) => { | ||
| let rounded = round_float(*v, dp)?; | ||
| Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) | ||
|
|
@@ -468,6 +487,11 @@ fn round_columnar( | |
| let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_)); | ||
|
|
||
| let arr: ArrayRef = match (value_array.data_type(), return_type) { | ||
| (input_type, return_type) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The no-op fast path requires CREATE TABLE t(v BIGINT, dp INT) AS VALUES (125,1),(125,-1);
SELECT round(v, dp) FROM t; -- errored; worked (Float64) before this PR |
||
| if input_type == return_type && is_integer_data_type(return_type) => | ||
| { | ||
| round_integer_array(value_array.as_ref(), decimal_places, return_type)? | ||
| } | ||
| (Float64, _) => { | ||
| let result = calculate_binary_math::<Float64Type, Int32Type, Float64Type, _>( | ||
| value_array.as_ref(), | ||
|
|
@@ -630,6 +654,180 @@ fn round_columnar( | |
| } | ||
| } | ||
|
|
||
| fn round_integer_value(value: i128, decimal_places: i32) -> Result<i128, ArrowError> { | ||
| if decimal_places >= 0 { | ||
| return Ok(value); | ||
| } | ||
|
|
||
| let Some(factor) = 10_i128.checked_pow(decimal_places.unsigned_abs()) else { | ||
| return Ok(0); | ||
| }; | ||
|
|
||
| let remainder = value % factor; | ||
| let threshold = factor / 2; | ||
|
|
||
| if remainder >= threshold { | ||
| value | ||
| .checked_sub(remainder) | ||
| .and_then(|v| v.checked_add(factor)) | ||
| .ok_or_else(|| { | ||
| ArrowError::ComputeError("Overflow while rounding integer".to_string()) | ||
| }) | ||
| } else if remainder <= -threshold { | ||
| value | ||
| .checked_sub(remainder) | ||
| .and_then(|v| v.checked_sub(factor)) | ||
| .ok_or_else(|| { | ||
| ArrowError::ComputeError("Overflow while rounding integer".to_string()) | ||
| }) | ||
| } else { | ||
| value.checked_sub(remainder).ok_or_else(|| { | ||
| ArrowError::ComputeError("Overflow while rounding integer".to_string()) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| fn round_integer_scalar( | ||
| value: &ScalarValue, | ||
| return_type: &DataType, | ||
| decimal_places: i32, | ||
| ) -> Result<ColumnarValue> { | ||
| let rounded = match value { | ||
| ScalarValue::Int8(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| ScalarValue::Int16(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| ScalarValue::Int32(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| ScalarValue::Int64(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| ScalarValue::UInt8(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| ScalarValue::UInt16(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| ScalarValue::UInt32(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| ScalarValue::UInt64(Some(v)) => { | ||
| round_integer_value(i128::from(*v), decimal_places)? | ||
| } | ||
| _ => { | ||
| return internal_err!( | ||
| "Unexpected datatype for integer round: {}", | ||
| value.data_type() | ||
| ); | ||
| } | ||
| }; | ||
|
|
||
| let scalar = match return_type { | ||
| Int8 => ScalarValue::Int8(Some(i8::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding Int8".to_string()) | ||
| })?)), | ||
| Int16 => ScalarValue::Int16(Some(i16::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding Int16".to_string()) | ||
| })?)), | ||
| Int32 => ScalarValue::Int32(Some(i32::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding Int32".to_string()) | ||
| })?)), | ||
| Int64 => ScalarValue::Int64(Some(i64::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding Int64".to_string()) | ||
| })?)), | ||
| UInt8 => ScalarValue::UInt8(Some(u8::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding UInt8".to_string()) | ||
| })?)), | ||
| UInt16 => ScalarValue::UInt16(Some(u16::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding UInt16".to_string()) | ||
| })?)), | ||
| UInt32 => ScalarValue::UInt32(Some(u32::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding UInt32".to_string()) | ||
| })?)), | ||
| UInt64 => ScalarValue::UInt64(Some(u64::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError("Overflow while rounding UInt64".to_string()) | ||
| })?)), | ||
| _ => { | ||
| return internal_err!( | ||
| "Unexpected return type for integer round: {return_type}" | ||
| ); | ||
| } | ||
| }; | ||
|
|
||
| Ok(ColumnarValue::Scalar(scalar)) | ||
| } | ||
|
|
||
| macro_rules! round_integer_array { | ||
| ($ARRAY:expr, $DP:expr, $RETURN_TYPE:expr, $ARRAY_TYPE:ty, $NATIVE:ty) => {{ | ||
| let array = $ARRAY.as_primitive::<$ARRAY_TYPE>(); | ||
|
|
||
| let result = calculate_binary_math::<$ARRAY_TYPE, Int32Type, $ARRAY_TYPE, _>( | ||
| array, | ||
| $DP, | ||
| |v, dp| { | ||
| let rounded = round_integer_value(i128::from(v), dp)?; | ||
| <$NATIVE>::try_from(rounded).map_err(|_| { | ||
| ArrowError::ComputeError(format!( | ||
| "Overflow while rounding {}", | ||
| $RETURN_TYPE | ||
| )) | ||
| }) | ||
| }, | ||
| )?; | ||
|
|
||
| Ok(result as ArrayRef) | ||
| }}; | ||
| } | ||
|
|
||
| fn round_integer_array( | ||
| value_array: &dyn Array, | ||
| decimal_places: &ColumnarValue, | ||
| return_type: &DataType, | ||
| ) -> Result<ArrayRef> { | ||
| match return_type { | ||
| Int8 => { | ||
| round_integer_array!(value_array, decimal_places, return_type, Int8Type, i8) | ||
| } | ||
| Int16 => { | ||
| round_integer_array!(value_array, decimal_places, return_type, Int16Type, i16) | ||
| } | ||
| Int32 => { | ||
| round_integer_array!(value_array, decimal_places, return_type, Int32Type, i32) | ||
| } | ||
| Int64 => { | ||
| round_integer_array!(value_array, decimal_places, return_type, Int64Type, i64) | ||
| } | ||
| UInt8 => { | ||
| round_integer_array!(value_array, decimal_places, return_type, UInt8Type, u8) | ||
| } | ||
| UInt16 => round_integer_array!( | ||
| value_array, | ||
| decimal_places, | ||
| return_type, | ||
| UInt16Type, | ||
| u16 | ||
| ), | ||
| UInt32 => round_integer_array!( | ||
| value_array, | ||
| decimal_places, | ||
| return_type, | ||
| UInt32Type, | ||
| u32 | ||
| ), | ||
| UInt64 => round_integer_array!( | ||
| value_array, | ||
| decimal_places, | ||
| return_type, | ||
| UInt64Type, | ||
| u64 | ||
| ), | ||
| _ => internal_err!("Unexpected return type for integer round: {return_type}"), | ||
| } | ||
| } | ||
|
|
||
| fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError> | ||
| where | ||
| T: num_traits::Float, | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tests only cover non-negative scalar scale. There's no coverage for negative scale, column scale, or negative scale on a column |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
returns the integer type for all scales, but only non-negative scales are handled downstream.
round(arrow_cast(125,'Int64'), -1)now fails, it returned130.0before this PR.