From 920f29a454913003d298dc8f02d0960e9d1f945e Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 28 Apr 2022 14:53:33 +0200 Subject: [PATCH 1/3] Implementing math power function for SQL (#2324) * Implementing POWER function * Delete pv.yaml * Delete build-ballista-docker.sh * Delete ballista.dockerfile * aligining with latest upstream changes * Readding docker files * Formatting * Leaving only 64bit types * Adding tests, remove type conversion * fix for cast * Update functions.rs (cherry picked from commit c3c02cf5e881c9c5020ad87417716dd932e69a69) Can drop this after rebase on commit c3c02cf "Implementing math power function for SQL (#2324)", first released in 8.0.0 # Conflicts: # datafusion/core/src/logical_plan/mod.rs # datafusion/core/src/physical_plan/functions.rs # datafusion/core/tests/sql/functions.rs # datafusion/cube_ext/Cargo.toml # datafusion/expr/src/built_in_function.rs # datafusion/expr/src/function.rs # datafusion/proto/proto/datafusion.proto # datafusion/proto/src/from_proto.rs # datafusion/proto/src/to_proto.rs # dev/docker/ballista.dockerfile --- datafusion/core/src/logical_plan/mod.rs | 12 +- .../core/src/physical_plan/functions.rs | 4 + datafusion/core/tests/sql/functions.rs | 126 ++++++++++++++++++ datafusion/expr/src/built_in_function.rs | 4 + datafusion/expr/src/expr_fn.rs | 1 + datafusion/expr/src/function.rs | 12 ++ .../physical-expr/src/math_expressions.rs | 58 +++++++- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/from_proto.rs | 11 +- datafusion/proto/src/to_proto.rs | 1 + 10 files changed, 220 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index f1badd1e4198..03b780e96468 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -48,12 +48,12 @@ pub use expr::{ count, count_distinct, create_udaf, create_udf, create_udtf, date_part, date_trunc, digest, exp, exprlist_to_fields, exprlist_to_fields_from_schema, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, - max, md5, min, now, now_expr, nullif, octet_length, or, pi, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, - sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, - to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, - trim, trunc, unalias, upper, when, Column, Expr, ExprSchema, GroupingSet, Like, - Literal, + max, md5, min, now, now_expr, nullif, octet_length, or, pi, power, random, + regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, + sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, + substr, sum, tan, to_hex, to_timestamp_micros, to_timestamp_millis, + to_timestamp_seconds, translate, trim, trunc, unalias, upper, when, Column, Expr, + ExprSchema, GroupingSet, Like, Literal, }; pub use expr_rewriter::{ normalize_col, normalize_cols, replace_col, replace_col_to_expr, diff --git a/datafusion/core/src/physical_plan/functions.rs b/datafusion/core/src/physical_plan/functions.rs index 7256467155c6..7fbc6a219214 100644 --- a/datafusion/core/src/physical_plan/functions.rs +++ b/datafusion/core/src/physical_plan/functions.rs @@ -312,6 +312,10 @@ pub fn create_physical_fun( BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), + BuiltinScalarFunction::Power => { + Arc::new(|args| make_scalar_function(math_expressions::power)(args)) + } + BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi), // string functions BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::array), diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index cfb16169da7b..5d234d4b658c 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -555,6 +555,132 @@ async fn case_builtin_math_expression() { } } +#[tokio::test] +async fn test_power() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("i32", DataType::Int16, true), + Field::new("i64", DataType::Int64, true), + Field::new("f32", DataType::Float32, true), + Field::new("f64", DataType::Float64, true), + ])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int16Array::from(vec![ + Some(2), + Some(5), + Some(0), + Some(-14), + None, + ])), + Arc::new(Int64Array::from(vec![ + Some(2), + Some(5), + Some(0), + Some(-14), + None, + ])), + Arc::new(Float32Array::from(vec![ + Some(1.0), + Some(2.5), + Some(0.0), + Some(-14.5), + None, + ])), + Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(2.5), + Some(0.0), + Some(-14.5), + None, + ])), + ], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let ctx = SessionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = r"SELECT power(i32, exp_i) as power_i32, + power(i64, exp_f) as power_i64, + power(f32, exp_i) as power_f32, + power(f64, exp_f) as power_f64, + power(2, 3) as power_int_scalar, + power(2.5, 3.0) as power_float_scalar + FROM (select test.*, 3 as exp_i, 3.0 as exp_f from test) a"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+-----------+-----------+-----------+------------------+--------------------+", + "| power_i32 | power_i64 | power_f32 | power_f64 | power_int_scalar | power_float_scalar |", + "+-----------+-----------+-----------+-----------+------------------+--------------------+", + "| 8 | 8 | 1 | 1 | 8 | 15.625 |", + "| 125 | 125 | 15.625 | 15.625 | 8 | 15.625 |", + "| 0 | 0 | 0 | 0 | 8 | 15.625 |", + "| -2744 | -2744 | -3048.625 | -3048.625 | 8 | 15.625 |", + "| | | | | 8 | 15.625 |", + "+-----------+-----------+-----------+-----------+------------------+--------------------+", + ]; + assert_batches_eq!(expected, &actual); + //dbg!(actual[0].schema().fields()); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_i32") + .unwrap() + .data_type() + .to_owned(), + DataType::Int64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_i64") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_f32") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_f64") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_int_scalar") + .unwrap() + .data_type() + .to_owned(), + DataType::Int64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_float_scalar") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + + Ok(()) +} + // #[tokio::test] // async fn case_sensitive_identifiers_aggregates() { // let ctx = SessionContext::new(); diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 7cfb3e4e9b08..1805942f6fa4 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -54,6 +54,8 @@ pub enum BuiltinScalarFunction { Log10, /// log2 Log2, + /// power + Power, /// pi Pi, /// round @@ -196,6 +198,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Log10 => Volatility::Immutable, BuiltinScalarFunction::Log2 => Volatility::Immutable, + BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Pi => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, BuiltinScalarFunction::Signum => Volatility::Immutable, @@ -284,6 +287,7 @@ impl FromStr for BuiltinScalarFunction { "log" => BuiltinScalarFunction::Log, "log10" => BuiltinScalarFunction::Log10, "log2" => BuiltinScalarFunction::Log2, + "power" => BuiltinScalarFunction::Power, "pi" => BuiltinScalarFunction::Pi, "round" => BuiltinScalarFunction::Round, "signum" => BuiltinScalarFunction::Signum, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8d6f0d08e65a..77fcb70f43fd 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -266,6 +266,7 @@ unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); unary_scalar_expr!(Ln, ln); unary_scalar_expr!(NullIf, nullif); +scalar_expr!(Power, power, base, exponent); // string functions scalar_expr!(Ascii, ascii, string); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index e6cdfa428f7b..f947e44fe871 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -244,6 +244,11 @@ pub fn return_type( } }), + BuiltinScalarFunction::Power => match &input_expr_types[0] { + DataType::Int64 => Ok(DataType::Int64), + _ => Ok(DataType::Float64), + }, + BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin @@ -550,6 +555,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { ), BuiltinScalarFunction::Pi => Signature::exact(vec![], fun.volatility()), BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), + BuiltinScalarFunction::Power => Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), + ], + fun.volatility(), + ), BuiltinScalarFunction::Log => Signature::one_of( vec![ TypeSignature::Exact(vec![DataType::Float64]), diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index fea5902f3bef..d5cf7853054a 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -17,12 +17,14 @@ //! Math expressions -use arrow::array::{Float32Array, Float64Array}; +use arrow::array::ArrayRef; +use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use rand::{thread_rng, Rng}; +use std::any::type_name; use std::iter; use std::sync::Arc; @@ -86,6 +88,33 @@ macro_rules! math_unary_function { }; } +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + +macro_rules! make_function_inputs2 { + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; +} + math_unary_function!("sqrt", sqrt); math_unary_function!("sin", sin); math_unary_function!("cos", cos); @@ -131,6 +160,33 @@ pub fn random(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +pub fn power(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "base", + "exponent", + Float64Array, + { f64::powf } + )) as ArrayRef), + + DataType::Int64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "base", + "exponent", + Int64Array, + { i64::pow } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function power", + other + ))), + } +} + #[cfg(test)] mod tests { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d159e1511f87..37ee5ae20251 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -190,6 +190,7 @@ enum ScalarFunction { Upper=62; Coalesce=63; // Upstream + Power=64; CurrentDate=70; Pi=80; // Cubesql diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 893e011f9c3b..b042958803dc 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -25,9 +25,9 @@ use datafusion::{ logical_plan::{ abs, acos, ascii, asin, atan, ceil, character_length, chr, concat_expr, concat_ws_expr, cos, digest, exp, floor, left, ln, log10, log2, now_expr, nullif, - pi, random, regexp_replace, repeat, replace, reverse, right, round, signum, sin, - split_part, sqrt, starts_with, strpos, substr, tan, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_seconds, translate, trunc, + pi, power, random, regexp_replace, repeat, replace, reverse, right, round, + signum, sin, split_part, sqrt, starts_with, strpos, substr, tan, to_hex, + to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trunc, window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, Column, DFField, DFSchema, DFSchemaRef, Expr, Like, Operator, }, @@ -430,6 +430,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Translate => Self::Translate, ScalarFunction::RegexpMatch => Self::RegexpMatch, ScalarFunction::Coalesce => Self::Coalesce, + ScalarFunction::Power => Self::Power, ScalarFunction::Pi => Self::Pi, // Cube SQL ScalarFunction::UtcTimestamp => Self::UtcTimestamp, @@ -1232,6 +1233,10 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::Power => Ok(power( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Pi => Ok(pi()), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 90bc1b3e050d..b64aa6bafb7d 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -1079,6 +1079,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::RegexpMatch => Self::RegexpMatch, BuiltinScalarFunction::Coalesce => Self::Coalesce, + BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Pi => Self::Pi, // Cube SQL BuiltinScalarFunction::UtcTimestamp => Self::UtcTimestamp, From 459bd0e880d26d376caa45d844f60c7544716fd4 Mon Sep 17 00:00:00 2001 From: Mikhail Cheshkov Date: Mon, 12 May 2025 19:12:25 +0200 Subject: [PATCH 2/3] Support round() function with two parameters (#5807) Can drop this after rebase on commit 771c20c "Support round() function with two parameters (#5807)", first released in 22.0.0 --- .../core/src/optimizer/projection_drop_out.rs | 19 ++-- .../core/src/physical_plan/functions.rs | 4 +- datafusion/expr/src/expr_fn.rs | 5 +- .../physical-expr/src/math_expressions.rs | 106 +++++++++++++++++- datafusion/proto/src/from_proto.rs | 7 +- 5 files changed, 128 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/optimizer/projection_drop_out.rs b/datafusion/core/src/optimizer/projection_drop_out.rs index 479c9ca917f2..61e38db19078 100644 --- a/datafusion/core/src/optimizer/projection_drop_out.rs +++ b/datafusion/core/src/optimizer/projection_drop_out.rs @@ -631,7 +631,7 @@ mod tests { )? .project_with_alias( vec![ - round(col("id")).alias("first"), + round(vec![col("id")]).alias("first"), col("n").alias("second"), lit(2).alias("third"), ], @@ -649,7 +649,7 @@ mod tests { // select * from (select id first, a second, 2 third from (select round(a) id, 1 num from table) a) x; let plan = LogicalPlanBuilder::from(table_scan) .project_with_alias( - vec![round(col("a")).alias("id"), lit(1).alias("n")], + vec![round(vec![col("a")]).alias("id"), lit(1).alias("n")], Some("a".to_string()), )? .project_with_alias( @@ -748,7 +748,7 @@ mod tests { )? .project_with_alias( vec![ - round(col("id")).alias("first"), + round(vec![col("id")]).alias("first"), col("n").alias("second"), lit(2).alias("third"), ], @@ -826,7 +826,10 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project_with_alias(vec![col("a").alias("id")], Some("a".to_string()))? .project_with_alias( - vec![round(col("id")).alias("first"), lit(2).alias("second")], + vec![ + round(vec![col("id")]).alias("first"), + lit(2).alias("second"), + ], Some("b".to_string()), )? .sort(vec![col("first")])? @@ -1019,7 +1022,7 @@ mod tests { .project_with_alias(vec![col("a").alias("num")], Some("a".to_string()))? .project_with_alias(vec![col("num")], Some("b".to_string()))? .filter(col("num").gt(lit(0)))? - .aggregate(vec![round(col("num"))], Vec::::new())? + .aggregate(vec![round(vec![col("num")])], Vec::::new())? .project(vec![col("Round(b.num)")])? .sort(vec![col("Round(b.num)")])? .build()?; @@ -1044,7 +1047,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan.clone()) .project_with_alias(vec![col("a").alias("num")], Some("a".to_string()))? .project_with_alias(vec![col("num")], Some("b".to_string()))? - .aggregate(vec![round(col("num"))], Vec::::new())? + .aggregate(vec![round(vec![col("num")])], Vec::::new())? .project(vec![col("Round(b.num)")])? .sort(vec![col("Round(b.num)")])? .build()?; @@ -1061,7 +1064,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project_with_alias(vec![col("a").alias("num")], Some("a".to_string()))? .project_with_alias(vec![col("num")], Some("b".to_string()))? - .aggregate(vec![round(col("num"))], Vec::::new())? + .aggregate(vec![round(vec![col("num")])], Vec::::new())? .project(vec![col("Round(b.num)")])? .sort(vec![col("Round(b.num)")])? .project_with_alias(vec![col("Round(b.num)")], Some("x".to_string()))? @@ -1099,7 +1102,7 @@ mod tests { .project_with_alias(vec![col("num")], Some("x".to_string()))? .join(&right, JoinType::Left, (vec!["num"], vec!["a"]))? .project_with_alias( - vec![col("num"), col("a"), round(col("num"))], + vec![col("num"), col("a"), round(vec![col("num")])], Some("b".to_string()), )? .build()?; diff --git a/datafusion/core/src/physical_plan/functions.rs b/datafusion/core/src/physical_plan/functions.rs index 7fbc6a219214..3a3a7f0ad913 100644 --- a/datafusion/core/src/physical_plan/functions.rs +++ b/datafusion/core/src/physical_plan/functions.rs @@ -306,7 +306,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), BuiltinScalarFunction::Random => Arc::new(math_expressions::random), - BuiltinScalarFunction::Round => Arc::new(math_expressions::round), + BuiltinScalarFunction::Round => { + Arc::new(|args| make_scalar_function(math_expressions::round)(args)) + } BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 77fcb70f43fd..be392cdcf758 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -257,7 +257,7 @@ unary_scalar_expr!(Atan, atan); unary_scalar_expr!(Floor, floor); unary_scalar_expr!(Ceil, ceil); unary_scalar_expr!(Now, now); -unary_scalar_expr!(Round, round); +nary_scalar_expr!(Round, round); unary_scalar_expr!(Trunc, trunc); unary_scalar_expr!(Abs, abs); unary_scalar_expr!(Signum, signum); @@ -418,7 +418,8 @@ mod test { test_unary_scalar_expr!(Floor, floor); test_unary_scalar_expr!(Ceil, ceil); test_unary_scalar_expr!(Now, now); - test_unary_scalar_expr!(Round, round); + test_nary_scalar_expr!(Round, round, input); + test_nary_scalar_expr!(Round, round, input, decimal_places); test_unary_scalar_expr!(Trunc, trunc); test_unary_scalar_expr!(Abs, abs); test_unary_scalar_expr!(Signum, signum); diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index d5cf7853054a..f7eb90f66e1d 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -113,6 +113,18 @@ macro_rules! make_function_inputs2 { }) .collect::<$ARRAY_TYPE>() }}; + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE1>() + }}; } math_unary_function!("sqrt", sqrt); @@ -124,7 +136,6 @@ math_unary_function!("acos", acos); math_unary_function!("atan", atan); math_unary_function!("floor", floor); math_unary_function!("ceil", ceil); -math_unary_function!("round", round); math_unary_function!("trunc", trunc); math_unary_function!("abs", abs); math_unary_function!("signum", signum); @@ -160,6 +171,59 @@ pub fn random(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// Round SQL function +pub fn round(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "round function requires one or two arguments, got {}", + args.len() + ))); + } + + let mut decimal_places = + &(Arc::new(Int64Array::from_value(0, args[0].len())) as ArrayRef); + + if args.len() == 2 { + decimal_places = &args[1]; + } + + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float64Array, + Int64Array, + { + |value: f64, decimal_places: i64| { + (value * 10.0_f64.powi(decimal_places.try_into().unwrap())).round() + / 10.0_f64.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float32Array, + Int64Array, + { + |value: f32, decimal_places: i64| { + (value * 10.0_f32.powi(decimal_places.try_into().unwrap())).round() + / 10.0_f32.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function round" + ))), + } +} + pub fn power(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Float64 => Ok(Arc::new(make_function_inputs2!( @@ -202,4 +266,44 @@ mod tests { assert_eq!(floats.len(), 1); assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); } + + #[test] + fn test_round_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = result + .as_any() + .downcast_ref::() + .expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = result + .as_any() + .downcast_ref::() + .expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } } diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index b042958803dc..3b6445d1452d 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -1089,7 +1089,12 @@ pub fn parse_expr( ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)), ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)), ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)), - ScalarFunction::Round => Ok(round(parse_expr(&args[0], registry)?)), + ScalarFunction::Round => Ok(round( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], registry)?)), ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], registry)?)), ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)), From ce4e9b322138d2f2ee2093c2d5f49afbec5439ed Mon Sep 17 00:00:00 2001 From: Mikhail Cheshkov Date: Mon, 12 May 2025 23:53:54 +0200 Subject: [PATCH 3/3] Bump actions/cache to v3 Can drop this after rebase on commit 49e072a "Update actions (#2678)", first released in 9.0.0 --- .github/workflows/rust.yml | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 89456cd77f59..168da334276a 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -40,14 +40,14 @@ jobs: steps: - uses: actions/checkout@v2 - name: Cache Cargo - uses: actions/cache@v2 + uses: actions/cache@v3 with: # these represent dependencies downloaded by cargo # and thus do not depend on the OS, arch nor rust version. path: /github/home/.cargo key: cargo-cache- - name: Cache Rust dependencies - uses: actions/cache@v2 + uses: actions/cache@v3 with: # these represent compiled steps of both dependencies and arrow # and thus are specific for a particular OS, arch and rust version. @@ -103,13 +103,13 @@ jobs: with: submodules: true - name: Cache Cargo - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /github/home/.cargo # this key equals the ones on `linux-build-lib` for re-use key: cargo-cache- - name: Cache Rust dependencies - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /github/home/target # this key equals the ones on `linux-build-lib` for re-use @@ -250,13 +250,13 @@ jobs: # with: # submodules: true # - name: Cache Cargo - # uses: actions/cache@v2 + # uses: actions/cache@v3 # with: # path: /github/home/.cargo # # this key equals the ones on `linux-build-lib` for re-use # key: cargo-cache- # - name: Cache Rust dependencies - # uses: actions/cache@v2 + # uses: actions/cache@v3 # with: # path: /github/home/target # # this key equals the ones on `linux-build-lib` for re-use @@ -315,13 +315,13 @@ jobs: with: submodules: true - name: Cache Cargo - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /github/home/.cargo # this key equals the ones on `linux-build-lib` for re-use key: cargo-cache- - name: Cache Rust dependencies - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /github/home/target # this key equals the ones on `linux-build-lib` for re-use @@ -360,13 +360,13 @@ jobs: with: submodules: true - name: Cache Cargo - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /github/home/.cargo # this key equals the ones on `linux-build-lib` for re-use key: cargo-cache- - name: Cache Rust dependencies - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /github/home/target # this key equals the ones on `linux-build-lib` for re-use @@ -419,13 +419,13 @@ jobs: # with: # submodules: true # - name: Cache Cargo -# uses: actions/cache@v2 +# uses: actions/cache@v3 # with: # path: /github/home/.cargo # # this key equals the ones on `linux-build-lib` for re-use # key: cargo-cache- # - name: Cache Rust dependencies -# uses: actions/cache@v2 +# uses: actions/cache@v3 # with: # path: /github/home/target # # this key equals the ones on `linux-build-lib` for re-use @@ -466,13 +466,13 @@ jobs: # with: # submodules: true # - name: Cache Cargo -# uses: actions/cache@v2 +# uses: actions/cache@v3 # with: # path: /home/runner/.cargo # # this key is not equal because the user is different than on a container (runner vs github) # key: cargo-coverage-cache- # - name: Cache Rust dependencies -# uses: actions/cache@v2 +# uses: actions/cache@v3 # with: # path: /home/runner/target # # this key is not equal because coverage uses different compilation flags.