Skip to content

Commit 98359f9

Browse files
committed
feat: Support PERCENTILE_CONT planning
1 parent dcf3e4a commit 98359f9

File tree

13 files changed

+311
-12
lines changed

13 files changed

+311
-12
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true }
4444
ordered-float = "2.10"
4545
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "a03d4eef5640e05dddf99fc2357ad6d58b5337cb", features = ["arrow"], optional = true }
4646
pyo3 = { version = "0.16", optional = true }
47-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
47+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "3a3a7e582f51576c4d2ac2350512564633fe02dd" }

datafusion/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ pin-project-lite= "^0.2.7"
7979
pyo3 = { version = "0.16", optional = true }
8080
rand = "0.8"
8181
smallvec = { version = "1.6", features = ["union"] }
82-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
82+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "3a3a7e582f51576c4d2ac2350512564633fe02dd" }
8383
tempfile = "3"
8484
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
8585
tokio-stream = "0.1"

datafusion/core/src/physical_plan/aggregates.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ pub fn create_aggregate_expr(
239239
.to_string(),
240240
));
241241
}
242+
(AggregateFunction::PercentileCont, _) => {
243+
Arc::new(expressions::PercentileCont::new(
244+
// Forward the output name, all coerced argument expressions, and the return type
245+
name,
246+
coerced_phy_exprs,
247+
return_type,
248+
)?)
249+
}
242250
(AggregateFunction::ApproxMedian, false) => {
243251
Arc::new(expressions::ApproxMedian::new(
244252
coerced_phy_exprs[0].clone(),

datafusion/core/src/sql/planner.rs

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ use datafusion_expr::expr::GroupingSet;
5656
use sqlparser::ast::{
5757
ArrayAgg, BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr,
5858
Fetch, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
59-
ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator,
60-
ShowStatementFilter, TableFactor, TableWithJoins, TrimWhereField, UnaryOperator,
61-
Value, Values as SQLValues,
59+
ObjectName, Offset as SQLOffset, PercentileCont, Query, Select, SelectItem, SetExpr,
60+
SetOperator, ShowStatementFilter, TableFactor, TableWithJoins, TrimWhereField,
61+
UnaryOperator, Value, Values as SQLValues,
6262
};
6363
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
6464
use sqlparser::ast::{ObjectType, OrderByExpr, Statement};
@@ -1437,22 +1437,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
14371437

14381438
let order_by_rex = order_by
14391439
.into_iter()
1440-
.map(|e| self.order_by_to_sort_expr(e, plan.schema()))
1440+
.map(|e| self.order_by_to_sort_expr(e, plan.schema(), true))
14411441
.collect::<Result<Vec<_>>>()?;
14421442

14431443
LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build()
14441444
}
14451445

14461446
/// convert sql OrderByExpr to Expr::Sort
1447-
fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema) -> Result<Expr> {
1447+
fn order_by_to_sort_expr(
1448+
&self,
1449+
e: OrderByExpr,
1450+
schema: &DFSchema,
1451+
parse_indexes: bool,
1452+
) -> Result<Expr> {
14481453
let OrderByExpr {
14491454
asc,
14501455
expr,
14511456
nulls_first,
14521457
} = e;
14531458

14541459
let expr = match expr {
1455-
SQLExpr::Value(Value::Number(v, _)) => {
1460+
SQLExpr::Value(Value::Number(v, _)) if parse_indexes => {
14561461
let field_index = v
14571462
.parse::<usize>()
14581463
.map_err(|err| DataFusionError::Plan(err.to_string()))?;
@@ -2310,7 +2315,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
23102315
let order_by = window
23112316
.order_by
23122317
.into_iter()
2313-
.map(|e| self.order_by_to_sort_expr(e, schema))
2318+
.map(|e| self.order_by_to_sort_expr(e, schema, true))
23142319
.collect::<Result<Vec<_>>>()?;
23152320
let window_frame = window
23162321
.window_frame
@@ -2438,6 +2443,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
24382443

24392444
SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema),
24402445

2446+
SQLExpr::PercentileCont(percentile_cont) => self.parse_percentile_cont(percentile_cont, schema),
2447+
24412448
_ => Err(DataFusionError::NotImplemented(format!(
24422449
"Unsupported ast node {:?} in sqltorel",
24432450
sql
@@ -2491,6 +2498,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
24912498
})
24922499
}
24932500

2501+
/// Plans a SQL `PERCENTILE_CONT(expr) WITHIN GROUP (ORDER BY ...)` node as an
/// `Expr::AggregateFunction` using `AggregateFunction::PercentileCont`.
///
/// The resulting aggregate takes four arguments, in this order: the planned
/// `expr` from the AST node, the expression of the `WITHIN GROUP` ordering,
/// and two boolean literals carrying that ordering's `asc` and `nulls_first`
/// flags.
///
/// # Errors
///
/// Propagates any error from planning `expr` or the ordering, and returns
/// `DataFusionError::Internal` if the planned ordering is not an `Expr::Sort`.
fn parse_percentile_cont(
2502+
&self,
2503+
percentile_cont: PercentileCont,
2504+
input_schema: &DFSchema,
2505+
) -> Result<Expr> {
2506+
let PercentileCont { expr, within_group } = percentile_cont;
2507+
2508+
// Some dialects have special syntax for percentile_cont. DataFusion only supports it like a function.
2509+
let expr = self.sql_expr_to_logical_expr(*expr, input_schema)?;
2510+
// `parse_indexes` is false here, so a numeric literal in the ordering is not
// taken through the positional-index branch of `order_by_to_sort_expr`.
let (order_by_expr, asc, nulls_first) =
2511+
match self.order_by_to_sort_expr(*within_group, input_schema, false)? {
2512+
Expr::Sort {
2513+
expr,
2514+
asc,
2515+
nulls_first,
2516+
} => (expr, asc, nulls_first),
2517+
_ => {
2518+
return Err(DataFusionError::Internal(
2519+
"PercentileCont expected Sort expression in ORDER BY".to_string(),
2520+
))
2521+
}
2522+
};
2523+
// The sort flags are passed along as literal arguments of the aggregate call.
let asc_expr = Expr::Literal(ScalarValue::Boolean(Some(asc)));
2524+
let nulls_first_expr = Expr::Literal(ScalarValue::Boolean(Some(nulls_first)));
2525+
2526+
let args = vec![expr, *order_by_expr, asc_expr, nulls_first_expr];
2527+
// Build the aggregate call; `distinct` is always false for this function.
2528+
let fun = aggregates::AggregateFunction::PercentileCont;
2529+
2530+
Ok(Expr::AggregateFunction {
2531+
fun,
2532+
distinct: false,
2533+
args,
2534+
})
2535+
}
2536+
24942537
fn function_args_to_expr(
24952538
&self,
24962539
args: Vec<FunctionArg>,
@@ -4130,6 +4173,15 @@ mod tests {
41304173
quick_test(sql, expected);
41314174
}
41324175

4176+
// Checks the logical plan produced for `percentile_cont(..) WITHIN GROUP (ORDER BY ..)`:
// the aggregate's arguments are the percentile literal, the ordering column, and
// two boolean literals for the ordering flags.
#[test]
4177+
fn select_percentile_cont() {
4178+
let sql = "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY age) FROM person";
4179+
let expected = "Projection: #PERCENTILECONT(Float64(0.5),person.age,Boolean(true),Boolean(false))\
4180+
\n Aggregate: groupBy=[[]], aggr=[[PERCENTILECONT(Float64(0.5), #person.age, Boolean(true), Boolean(false))]]\
4181+
\n TableScan: person projection=None";
4182+
// Compare the planned query's textual form against `expected`.
quick_test(sql, expected);
4183+
}
4184+
41334185
#[test]
41344186
fn select_scalar_func() {
41354187
let sql = "SELECT sqrt(age) FROM person";

datafusion/expr/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ path = "src/lib.rs"
3838
ahash = { version = "0.7", default-features = false }
3939
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "a03d4eef5640e05dddf99fc2357ad6d58b5337cb", features = ["prettyprint"] }
4040
datafusion-common = { path = "../common", version = "7.0.0" }
41-
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
41+
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "3a3a7e582f51576c4d2ac2350512564633fe02dd" }

datafusion/expr/src/aggregate_function.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ pub enum AggregateFunction {
8484
ApproxPercentileCont,
8585
/// Approximate continuous percentile function with weight
8686
ApproxPercentileContWithWeight,
87+
/// Continuous percentile function
88+
PercentileCont,
8789
/// ApproxMedian
8890
ApproxMedian,
8991
/// BoolAnd
@@ -124,6 +126,7 @@ impl FromStr for AggregateFunction {
124126
"approx_percentile_cont_with_weight" => {
125127
AggregateFunction::ApproxPercentileContWithWeight
126128
}
129+
"percentile_cont" => AggregateFunction::PercentileCont,
127130
"approx_median" => AggregateFunction::ApproxMedian,
128131
"bool_and" => AggregateFunction::BoolAnd,
129132
"bool_or" => AggregateFunction::BoolOr,
@@ -178,6 +181,7 @@ pub fn return_type(
178181
AggregateFunction::ApproxPercentileContWithWeight => {
179182
Ok(coerced_data_types[0].clone())
180183
}
184+
AggregateFunction::PercentileCont => Ok(coerced_data_types[1].clone()),
181185
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
182186
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => Ok(DataType::Boolean),
183187
}
@@ -324,6 +328,33 @@ pub fn coerce_types(
324328
}
325329
Ok(input_types.to_vec())
326330
}
331+
AggregateFunction::PercentileCont => {
332+
if !matches!(input_types[0], DataType::Float64) {
333+
return Err(DataFusionError::Plan(format!(
334+
"The percentile argument for {:?} must be Float64, not {:?}.",
335+
agg_fun, input_types[0]
336+
)));
337+
}
338+
if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
339+
return Err(DataFusionError::Plan(format!(
340+
"The function {:?} does not support inputs of type {:?}.",
341+
agg_fun, input_types[1]
342+
)));
343+
}
344+
if !matches!(input_types[2], DataType::Boolean) {
345+
return Err(DataFusionError::Plan(format!(
346+
"The asc argument for {:?} must be Boolean, not {:?}.",
347+
agg_fun, input_types[2]
348+
)));
349+
}
350+
if !matches!(input_types[3], DataType::Boolean) {
351+
return Err(DataFusionError::Plan(format!(
352+
"The nulls_first argument for {:?} must be Boolean, not {:?}.",
353+
agg_fun, input_types[3]
354+
)));
355+
}
356+
Ok(input_types.to_vec())
357+
}
327358
AggregateFunction::ApproxMedian => {
328359
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
329360
return Err(DataFusionError::Plan(format!(
@@ -395,6 +426,21 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
395426
.collect(),
396427
Volatility::Immutable,
397428
),
429+
AggregateFunction::PercentileCont => Signature::one_of(
430+
// Accept a float64 percentile paired with any numeric value, plus bool values
431+
NUMERICS
432+
.iter()
433+
.map(|t| {
434+
TypeSignature::Exact(vec![
435+
DataType::Float64,
436+
t.clone(),
437+
DataType::Boolean,
438+
DataType::Boolean,
439+
])
440+
})
441+
.collect(),
442+
Volatility::Immutable,
443+
),
398444
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
399445
Signature::exact(vec![DataType::Boolean], Volatility::Immutable)
400446
}

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod not;
4949
mod nth_value;
5050
mod nullif;
5151
mod outer_column;
52+
mod percentile_cont;
5253
mod rank;
5354
mod row_number;
5455
mod stats;
@@ -95,6 +96,7 @@ pub use not::{not, NotExpr};
9596
pub use nth_value::NthValue;
9697
pub use nullif::nullif_func;
9798
pub use outer_column::OuterColumn;
99+
pub use percentile_cont::PercentileCont;
98100
pub use rank::{dense_rank, percent_rank, rank};
99101
pub use row_number::RowNumber;
100102
pub use stats::StatsType;

0 commit comments

Comments
 (0)