diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 010df564a948..7a33aa95c56b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -23,8 +23,9 @@ use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate, + SlicesIterator, }; -use arrow::datatypes::{DataType, Schema, UInt32Type}; +use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode}; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; @@ -246,13 +247,26 @@ fn is_cheap_and_infallible(expr: &Arc) -> bool { } /// Creates a [FilterPredicate] from a boolean array. -fn create_filter(predicate: &BooleanArray) -> FilterPredicate { +fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate { let mut filter_builder = FilterBuilder::new(predicate); - // Always optimize the filter since we use them multiple times. - filter_builder = filter_builder.optimize(); + if optimize { + // Always optimize the filter since we use them multiple times. + filter_builder = filter_builder.optimize(); + } filter_builder.build() } +fn multiple_arrays(data_type: &DataType) -> bool { + match data_type { + DataType::Struct(fields) => { + fields.len() > 1 + || fields.len() == 1 && multiple_arrays(fields[0].data_type()) + } + DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(), + _ => false, + } +} + // This should be removed when https://github.com/apache/arrow-rs/pull/8693 // is merged and becomes available. fn filter_record_batch( @@ -290,6 +304,84 @@ fn filter_array( filter.filter(array) } +fn merge( + mask: &BooleanArray, + truthy: ColumnarValue, + falsy: ColumnarValue, +) -> std::result::Result { + let (truthy, truthy_is_scalar) = match truthy { + ColumnarValue::Array(a) => (a, false), + ColumnarValue::Scalar(s) => (s.to_array()?, true), + }; + let (falsy, falsy_is_scalar) = match falsy { + ColumnarValue::Array(a) => (a, false), + ColumnarValue::Scalar(s) => (s.to_array()?, true), + }; + + if truthy_is_scalar && falsy_is_scalar { + return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy)); + } + + let falsy = falsy.to_data(); + let truthy = truthy.to_data(); + + let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len()); + + // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to + // fill with falsy values + + // keep track of how much is filled + let mut filled = 0; + let mut falsy_offset = 0; + let mut truthy_offset = 0; + + SlicesIterator::new(mask).for_each(|(start, end)| { + // the gap needs to be filled with falsy values + if start > filled { + if falsy_is_scalar { + for _ in filled..start { + // Copy the first item from the 'falsy' array into the output buffer. + mutable.extend(1, 0, 1); + } + } else { + let falsy_length = start - filled; + let falsy_end = falsy_offset + falsy_length; + mutable.extend(1, falsy_offset, falsy_end); + falsy_offset = falsy_end; + } + } + // fill with truthy values + if truthy_is_scalar { + for _ in start..end { + // Copy the first item from the 'truthy' array into the output buffer. + mutable.extend(0, 0, 1); + } + } else { + let truthy_length = end - start; + let truthy_end = truthy_offset + truthy_length; + mutable.extend(0, truthy_offset, truthy_end); + truthy_offset = truthy_end; + } + filled = end; + }); + // the remaining part is falsy + if filled < mask.len() { + if falsy_is_scalar { + for _ in filled..mask.len() { + // Copy the first item from the 'falsy' array into the output buffer. + mutable.extend(1, 0, 1); + } + } else { + let falsy_length = mask.len() - filled; + let falsy_end = falsy_offset + falsy_length; + mutable.extend(1, falsy_offset, falsy_end); + } + } + + let data = mutable.freeze(); + Ok(make_array(data)) +} + /// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from /// those values. /// @@ -342,7 +434,7 @@ fn filter_array( /// └───────────┘ └─────────┘ └─────────┘ /// values indices result /// ``` -fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result { +fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result { #[cfg(debug_assertions)] for ix in indices { if let Some(index) = ix.index() { @@ -647,7 +739,7 @@ impl ResultBuilder { } Partial { arrays, indices } => { // Merge partial results into a single array. - Ok(ColumnarValue::Array(merge(&arrays, &indices)?)) + Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?)) } Complete(v) => { // If we have a complete result, we can just return it. @@ -723,6 +815,26 @@ impl CaseExpr { } impl CaseBody { + fn data_type(&self, input_schema: &Schema) -> Result { + // since all then results have the same data type, we can choose any one as the + // return data type except for the null. + let mut data_type = DataType::Null; + for i in 0..self.when_then_expr.len() { + data_type = self.when_then_expr[i].1.data_type(input_schema)?; + if !data_type.equals_datatype(&DataType::Null) { + break; + } + } + // if all then results are null, we use data type of else expr instead if possible. + if data_type.equals_datatype(&DataType::Null) { + if let Some(e) = &self.else_expr { + data_type = e.data_type(input_schema)?; + } + } + + Ok(data_type) + } + /// See [CaseExpr::case_when_with_expr]. fn case_when_with_expr( &self, @@ -767,7 +879,7 @@ impl CaseBody { result_builder.add_branch_result(&remainder_rows, nulls_value)?; } else { // Filter out the null rows and evaluate the else expression for those - let nulls_filter = create_filter(¬(&base_not_nulls)?); + let nulls_filter = create_filter(¬(&base_not_nulls)?, true); let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?; let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; @@ -782,7 +894,7 @@ impl CaseBody { } // Remove the null rows from the remainder batch - let not_null_filter = create_filter(&base_not_nulls); + let not_null_filter = create_filter(&base_not_nulls, true); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?); remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?; @@ -802,8 +914,7 @@ impl CaseBody { compare_with_eq(&a, &base_values, base_value_is_nested) } ColumnarValue::Scalar(s) => { - let scalar = Scalar::new(s.to_array()?); - compare_with_eq(&scalar, &base_values, base_value_is_nested) + compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested) } }?; @@ -829,7 +940,7 @@ impl CaseBody { // for the current branch // Still no need to call `prep_null_mask_filter` since `create_filter` will already do // this unconditionally. - let then_filter = create_filter(&when_value); + let then_filter = create_filter(&when_value, true); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_rows = filter_array(&remainder_rows, &then_filter)?; @@ -852,7 +963,7 @@ impl CaseBody { not(&prep_null_mask_filter(&when_value)) } }?; - let next_filter = create_filter(&next_selection); + let next_filter = create_filter(&next_selection, true); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); remainder_rows = filter_array(&remainder_rows, &next_filter)?; @@ -918,7 +1029,7 @@ impl CaseBody { // for the current branch // Still no need to call `prep_null_mask_filter` since `create_filter` will already do // this unconditionally. - let then_filter = create_filter(when_value); + let then_filter = create_filter(when_value, true); let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; let then_rows = filter_array(&remainder_rows, &then_filter)?; @@ -941,7 +1052,7 @@ impl CaseBody { not(&prep_null_mask_filter(when_value)) } }?; - let next_filter = create_filter(&next_selection); + let next_filter = create_filter(&next_selection, true); remainder_batch = Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); remainder_rows = filter_array(&remainder_rows, &next_filter)?; @@ -964,24 +1075,39 @@ impl CaseBody { &self, batch: &RecordBatch, when_value: &BooleanArray, - return_type: &DataType, ) -> Result { - let then_value = self.when_then_expr[0] - .1 - .evaluate_selection(batch, when_value)? - .into_array(batch.num_rows())?; + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => { + // `prep_null_mask_filter` is required to ensure null is treated as false + Cow::Owned(prep_null_mask_filter(when_value)) + } + }; + + let optimize_filter = batch.num_columns() > 1 + || (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type())); + + let when_filter = create_filter(&when_value, optimize_filter); + let then_batch = filter_record_batch(batch, &when_filter)?; + let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?; + + let else_selection = not(&when_value)?; + let else_filter = create_filter(&else_selection, optimize_filter); + let else_batch = filter_record_batch(batch, &else_filter)?; - // evaluate else expression on the values not covered by when_value - let remainder = not(when_value)?; - let e = self.else_expr.as_ref().unwrap(); // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + let e = self.else_expr.as_ref().unwrap(); + let return_type = self.data_type(&batch.schema())?; + let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) .unwrap_or_else(|_| Arc::clone(e)); - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) + let else_value = else_expr.evaluate(&else_batch)?; + + Ok(ColumnarValue::Array(merge( + &when_value, + then_value, + else_value, + )?)) } } @@ -1113,11 +1239,12 @@ impl CaseExpr { batch: &RecordBatch, projected: &ProjectedCaseBody, ) -> Result { - let return_type = self.data_type(&batch.schema())?; - // evaluate when condition on batch let when_value = self.body.when_then_expr[0].0.evaluate(batch)?; - let when_value = when_value.into_array(batch.num_rows())?; + // `num_rows == 1` is intentional to avoid expanding scalars. + // If the `when_value` is effectively a scalar, the 'all true' and 'all false' checks + // below will avoid incorrectly using the scalar as a merge/zip mask. + let when_value = when_value.into_array(1)?; let when_value = as_boolean_array(&when_value).map_err(|e| { DataFusionError::Context( "WHEN expression did not return a BooleanArray".to_string(), @@ -1125,29 +1252,21 @@ impl CaseExpr { ) })?; - // For the true and false/null selection vectors, bypass `evaluate_selection` and merging - // results. This avoids materializing the array for the other branch which we will discard - // entirely anyway. let true_count = when_value.true_count(); - if true_count == batch.num_rows() { - return self.body.when_then_expr[0].1.evaluate(batch); + if true_count == when_value.len() { + // All input rows are true, just call the 'then' expression + self.body.when_then_expr[0].1.evaluate(batch) } else if true_count == 0 { - return self.body.else_expr.as_ref().unwrap().evaluate(batch); - } - - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - - if projected.projection.len() < batch.num_columns() { + // All input rows are false/null, just call the 'else' expression + self.body.else_expr.as_ref().unwrap().evaluate(batch) + } else if projected.projection.len() < batch.num_columns() { + // The case expressions do not use all the columns of the input batch. + // Project first to reduce time spent filtering. let projected_batch = batch.project(&projected.projection)?; - projected - .body - .expr_or_expr(&projected_batch, &when_value, &return_type) + projected.body.expr_or_expr(&projected_batch, when_value) } else { - self.body.expr_or_expr(batch, &when_value, &return_type) + // All columns are used in the case expressions, so there is no need to project. + self.body.expr_or_expr(batch, when_value) } } } @@ -1159,23 +1278,7 @@ impl PhysicalExpr for CaseExpr { } fn data_type(&self, input_schema: &Schema) -> Result { - // since all then results have the same data type, we can choose any one as the - // return data type except for the null. - let mut data_type = DataType::Null; - for i in 0..self.body.when_then_expr.len() { - data_type = self.body.when_then_expr[i].1.data_type(input_schema)?; - if !data_type.equals_datatype(&DataType::Null) { - break; - } - } - // if all then results are null, we use data type of else expr instead if possible. - if data_type.equals_datatype(&DataType::Null) { - if let Some(e) = &self.body.else_expr { - data_type = e.data_type(input_schema)?; - } - } - - Ok(data_type) + self.body.data_type(input_schema) } fn nullable(&self, input_schema: &Schema) -> Result { @@ -2140,7 +2243,7 @@ mod tests { } #[test] - fn test_merge() { + fn test_merge_n() { let a1 = StringArray::from(vec![Some("A")]).to_data(); let a2 = StringArray::from(vec![Some("B")]).to_data(); let a3 = StringArray::from(vec![Some("C"), Some("D")]).to_data(); @@ -2154,7 +2257,7 @@ mod tests { PartialResultIndex::try_new(2).unwrap(), ]; - let merged = merge(&[a1, a2, a3], &indices).unwrap(); + let merged = merge_n(&[a1, a2, a3], &indices).unwrap(); let merged = merged.as_string::(); assert_eq!(merged.len(), indices.len()); @@ -2169,4 +2272,24 @@ mod tests { assert!(merged.is_valid(5)); assert_eq!(merged.value(5), "D"); } + + #[test] + fn test_merge() { + let a1 = Arc::new(StringArray::from(vec![Some("A"), Some("C")])); + let a2 = Arc::new(StringArray::from(vec![Some("B")])); + + let mask = BooleanArray::from(vec![true, false, true]); + + let merged = + merge(&mask, ColumnarValue::Array(a1), ColumnarValue::Array(a2)).unwrap(); + let merged = merged.as_string::(); + + assert_eq!(merged.len(), mask.len()); + assert!(merged.is_valid(0)); + assert_eq!(merged.value(0), "A"); + assert!(merged.is_valid(1)); + assert_eq!(merged.value(1), "B"); + assert!(merged.is_valid(2)); + assert_eq!(merged.value(2), "C"); + } }