Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 130 additions & 66 deletions datafusion/physical-expr/benches/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,29 @@
// under the License.

use arrow::array::builder::{Int32Builder, StringBuilder};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::array::{Array, ArrayRef, Int32Array};
use arrow::datatypes::{Field, Schema};
use arrow::record_batch::RecordBatch;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column, Literal};
use datafusion_physical_expr::expressions::{case, col, lit, BinaryExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
Arc::new(Column::new(name, index))
fn make_x_cmp_y(
x: &Arc<dyn PhysicalExpr>,
op: Operator,
y: i32,
) -> Arc<dyn PhysicalExpr> {
Arc::new(BinaryExpr::new(Arc::clone(x), op, lit(y)))
}

fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
}

fn criterion_benchmark(c: &mut Criterion) {
// create input data
fn make_batch(row_count: usize, column_count: usize) -> RecordBatch {
let mut c1 = Int32Builder::new();
let mut c2 = StringBuilder::new();
let mut c3 = StringBuilder::new();
for i in 0..1000 {
c1.append_value(i);
for i in 0..row_count {
c1.append_value(i as i32);
if i % 7 == 0 {
c2.append_null();
} else {
Expand All @@ -54,69 +53,134 @@ fn criterion_benchmark(c: &mut Criterion) {
let c1 = Arc::new(c1.finish());
let c2 = Arc::new(c2.finish());
let c3 = Arc::new(c3.finish());
let schema = Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
Field::new("c3", DataType::Utf8, true),
]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap();

// use same predicate for all benchmarks
let predicate = Arc::new(BinaryExpr::new(
make_col("c1", 0),
Operator::LtEq,
make_lit_i32(500),
));
let mut columns: Vec<ArrayRef> = vec![c1, c2, c3];
for _ in 3..column_count {
columns.push(Arc::new(Int32Array::from_value(0, row_count)));
}

// CASE WHEN c1 <= 500 THEN 1 ELSE 0 END
c.bench_function("case_when: scalar or scalar", |b| {
let expr = Arc::new(
CaseExpr::try_new(
None,
vec![(predicate.clone(), make_lit_i32(1))],
Some(make_lit_i32(0)),
let fields = columns
.iter()
.enumerate()
.map(|(i, c)| {
Field::new(
format!("c{}", i + 1),
c.data_type().clone(),
c.is_nullable(),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});
})
.collect::<Vec<_>>();

// CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END
c.bench_function("case_when: column or null", |b| {
let expr = Arc::new(
CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None)
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(Arc::clone(&schema), columns).unwrap()
}

fn criterion_benchmark(c: &mut Criterion) {
run_benchmarks(c, &make_batch(8192, 3));
run_benchmarks(c, &make_batch(8192, 50));
run_benchmarks(c, &make_batch(8192, 100));
}

fn run_benchmarks(c: &mut Criterion, batch: &RecordBatch) {
let c1 = col("c1", &batch.schema()).unwrap();
let c2 = col("c2", &batch.schema()).unwrap();
let c3 = col("c3", &batch.schema()).unwrap();

c.bench_function(
format!(
"case_when {}x{}: CASE WHEN c1 <= 500 THEN 1 ELSE 0 END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
None,
vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), lit(1))],
Some(lit(0)),
&batch.schema(),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

// CASE WHEN c1 <= 500 THEN c2 ELSE c3 END
c.bench_function("case_when: expr or expr", |b| {
let expr = Arc::new(
CaseExpr::try_new(
None,
vec![(predicate.clone(), make_col("c2", 1))],
Some(make_col("c3", 2)),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});
c.bench_function(
format!(
"case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
None,
vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))],
None,
&batch.schema(),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

c.bench_function(
format!(
"case_when {}x{}: CASE WHEN c1 <= 500 THEN c2 ELSE c3 END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
None,
vec![(make_x_cmp_y(&c1, Operator::LtEq, 500), Arc::clone(&c2))],
Some(Arc::clone(&c3)),
&batch.schema(),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

// CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END
c.bench_function("case_when: CASE expr", |b| {
c.bench_function(
format!(
"case_when {}x{}: CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END",
batch.num_rows(),
batch.num_columns()
)
.as_str(),
|b| {
let expr = Arc::new(
case(
Some(Arc::clone(&c1)),
vec![(lit(1), Arc::clone(&c2)), (lit(2), Arc::clone(&c3))],
None,
&batch.schema(),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
},
);

c.bench_function(format!("case_when {}x{}: CASE WHEN c1 == 0 THEN 0 WHEN c1 == 1 THEN 1 ... WHEN c1 == n THEN n ELSE n + 1 END", batch.num_rows(), batch.num_columns()).as_str(), |b| {
let when_thens = (0..batch.num_rows() as i32).map(|i| (make_x_cmp_y(&c1, Operator::Eq, i), lit(i))).collect();
let expr = Arc::new(
CaseExpr::try_new(
Some(make_col("c1", 0)),
vec![
(make_lit_i32(1), make_col("c2", 1)),
(make_lit_i32(2), make_col("c3", 2)),
],
case(
None,
when_thens,
Some(lit(batch.num_rows() as i32)),
&batch.schema(),
)
.unwrap(),
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
b.iter(|| black_box(expr.evaluate(black_box(batch)).unwrap()))
});
}

Expand Down
18 changes: 15 additions & 3 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use crate::expressions::try_cast;
use crate::PhysicalExpr;
use crate::{expressions, PhysicalExpr};
use std::borrow::Cow;
use std::hash::Hash;
use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -600,8 +600,20 @@ pub fn case(
expr: Option<Arc<dyn PhysicalExpr>>,
when_thens: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?))
let case_expr = CaseExpr::try_new(expr, when_thens, else_expr)?;

match case_expr.eval_method {
EvalMethod::NoExpression
| EvalMethod::WithExpression
| EvalMethod::ExpressionOrExpression => {
expressions::projected(Arc::new(case_expr), input_schema)
}
EvalMethod::InfallibleExprOrNull | EvalMethod::ScalarOrScalar => {
Ok(Arc::new(case_expr))
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1381,7 +1393,7 @@ mod tests {
Ok((left, right))
}
}?;
case(expr, when_thens, else_expr)
case(expr, when_thens, else_expr, input_schema)
}

fn get_case_common_type(
Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod literal;
mod negative;
mod no_op;
mod not;
mod projected;
mod try_cast;
mod unknown_column;

Expand All @@ -54,5 +55,6 @@ pub use literal::{lit, Literal};
pub use negative::{negative, NegativeExpr};
pub use no_op::NoOp;
pub use not::{not, NotExpr};
pub use projected::{projected, ProjectedExpr};
pub use try_cast::{try_cast, TryCastExpr};
pub use unknown_column::UnKnownColumn;
Loading
Loading