diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs index 83cc5cded8361..bb706aa614dbc 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs @@ -28,5 +28,6 @@ pub use native::Bitmap65536DistinctCountAccumulator; pub use native::Bitmap65536DistinctCountAccumulatorI16; pub use native::BoolArray256DistinctCountAccumulator; pub use native::BoolArray256DistinctCountAccumulatorI8; +pub use native::BooleanDistinctCountAccumulator; pub use native::FloatDistinctCountAccumulator; pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs index fb9cfb379a26e..c7b466d4f0e0c 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs @@ -27,13 +27,14 @@ use std::mem::size_of_val; use std::sync::Arc; use arrow::array::ArrayRef; +use arrow::array::BooleanArray; use arrow::array::PrimitiveArray; use arrow::array::types::ArrowPrimitiveType; use arrow::datatypes::DataType; use datafusion_common::hash_utils::RandomState; use datafusion_common::ScalarValue; -use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_expr_common::accumulator::Accumulator; @@ -518,3 +519,101 @@ impl Accumulator for Bitmap65536DistinctCountAccumulatorI16 { size_of_val(self) + 8192 } } + +/// Optimized COUNT DISTINCT accumulator for `Boolean` whose entire state is two `bool` flags. 
+/// +/// Tracks whether `false` and `true` have been observed; nulls are skipped. +/// Once both flags are set, `update_batch` and `merge_batch` return without inspecting their input. `evaluate` reports the number of set flags as an `Int64`, so the result is always 0, 1, or 2; `state` emits the observed values as a list scalar of booleans that `merge_batch` reads back through `as_list_array`. +#[derive(Debug)] +pub struct BooleanDistinctCountAccumulator { + has_seen_false: bool, + has_seen_true: bool, +} + +impl BooleanDistinctCountAccumulator { + pub fn new() -> Self { + Self { + has_seen_false: false, + has_seen_true: false, + } + } + + #[inline] + fn seen_both(&self) -> bool { + self.has_seen_false && self.has_seen_true + } + + #[inline] + fn count(&self) -> i64 { + (self.has_seen_false as u8 + self.has_seen_true as u8) as i64 + } + + /// Update flags from a `BooleanArray`, short-circuiting per-flag once set: `has_false`/`has_true` is only called for a flag that is still unset. + #[inline] + fn observe(&mut self, arr: &BooleanArray) { + if !self.has_seen_false && arr.has_false() { + self.has_seen_false = true; + } + if !self.has_seen_true && arr.has_true() { + self.has_seen_true = true; + } + } +} + +impl Default for BooleanDistinctCountAccumulator { + fn default() -> Self { + Self::new() + } +} + +impl Accumulator for BooleanDistinctCountAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + if values.is_empty() || self.seen_both() { + return Ok(()); + } + + let arr = as_boolean_array(&values[0])?; + self.observe(arr); + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + if states.is_empty() || self.seen_both() { + return Ok(()); + } + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if self.seen_both() { + return Ok(()); + } + if let Some(list) = maybe_list { + self.observe(as_boolean_array(&list)?); + }; + Ok(()) + }) + } + + fn state(&mut self) -> datafusion_common::Result> { + let mut values: Vec = Vec::with_capacity(2); + if self.has_seen_false { + values.push(false); + } + if self.has_seen_true { + values.push(true); + } + + let arr = Arc::new(BooleanArray::from(values)); + Ok(vec![ + SingleRowListArrayBuilder::new(arr).build_list_scalar(), + ]) + } + + fn evaluate(&mut 
self) -> datafusion_common::Result { + Ok(ScalarValue::Int64(Some(self.count()))) + } + + fn size(&self) -> usize { + size_of_val(self) + } +} diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index cc42b6c22bdbe..df221dbce7154 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -42,6 +42,7 @@ use datafusion_expr::{ use datafusion_functions_aggregate_common::aggregate::count_distinct::{ Bitmap65536DistinctCountAccumulator, Bitmap65536DistinctCountAccumulatorI16, BoolArray256DistinctCountAccumulator, BoolArray256DistinctCountAccumulatorI8, + BooleanDistinctCountAccumulator, }; use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator; use datafusion_macros::user_doc; @@ -336,10 +337,13 @@ impl ApproxDistinct { } #[cold] -fn get_small_int_approx_accumulator( +fn get_fixed_domain_approx_accumulator( data_type: &DataType, ) -> Result> { match data_type { + DataType::Boolean => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: BooleanDistinctCountAccumulator::new(), + })), DataType::UInt8 => Ok(Box::new(ApproxDistinctBitmapWrapper { inner: BoolArray256DistinctCountAccumulator::new(), })), @@ -357,7 +361,10 @@ fn get_small_int_approx_accumulator( } #[cold] -fn get_small_int_state_field(name: &str, data_type: &DataType) -> Result> { +fn get_fixed_domain_state_field( + name: &str, + data_type: &DataType, +) -> Result> { Ok(vec![ Field::new_list( format_state_name(name, "approx_distinct"), @@ -392,9 +399,11 @@ impl AggregateUDFImpl for ApproxDistinct { ) .into(), ]), - DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => { - get_small_int_state_field(args.name, data_type) - } + DataType::Boolean + | DataType::UInt8 + | DataType::Int8 + | DataType::UInt16 + | DataType::Int16 => get_fixed_domain_state_field(args.name, data_type), _ => Ok(vec![ Field::new( format_state_name(args.name, 
"hll_registers"), @@ -410,8 +419,12 @@ impl AggregateUDFImpl for ApproxDistinct { let data_type = acc_args.expr_fields[0].data_type(); let accumulator: Box = match data_type { - DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => { - return get_small_int_approx_accumulator(data_type); + DataType::Boolean + | DataType::UInt8 + | DataType::Int8 + | DataType::UInt16 + | DataType::Int16 => { + return get_fixed_domain_approx_accumulator(data_type); } DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 25b69d16dd035..a18da0cde23b7 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1836,6 +1836,50 @@ SELECT approx_distinct(c14) AS a, approx_distinct(c15) AS b, approx_distinct(arr ---- 18 60 60 60 60 +# approx_distinct over Boolean: exact count via flag-pair accumulator (0..=2). 
+statement ok +CREATE TABLE approx_distinct_bool_test (g INT, b BOOLEAN) AS VALUES + (1, true), (1, true), (1, NULL), + (2, false), (2, false), + (3, true), (3, false), (3, NULL), (3, true), + (4, NULL), (4, NULL); + +query I +SELECT approx_distinct(b) FROM approx_distinct_bool_test WHERE g = 1; +---- +1 + +query I +SELECT approx_distinct(b) FROM approx_distinct_bool_test WHERE g = 2; +---- +1 + +query I +SELECT approx_distinct(b) FROM approx_distinct_bool_test WHERE g = 3; +---- +2 + +query I +SELECT approx_distinct(b) FROM approx_distinct_bool_test WHERE g = 4; +---- +0 + +query II +SELECT g, approx_distinct(b) FROM approx_distinct_bool_test GROUP BY g ORDER BY g; +---- +1 1 +2 1 +3 2 +4 0 + +query I +SELECT approx_distinct(b) FROM approx_distinct_bool_test; +---- +2 + +statement ok +DROP TABLE approx_distinct_bool_test; + ## This test executes the APPROX_PERCENTILE_CONT aggregation against the test ## data, asserting the estimated quantiles are ±5% their actual values. ##