From ba897a707f68d5213b882b4a536752c1345e52d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Wed, 7 May 2025 22:52:55 +0200 Subject: [PATCH 1/7] Add PrimitiveDistinctCountGroupsAccumulator --- datafusion/functions-aggregate/Cargo.toml | 1 + datafusion/functions-aggregate/src/count.rs | 331 +++++++++++++++++++- 2 files changed, 330 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index ec6e6b633bb8..b2db621e5184 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -49,6 +49,7 @@ datafusion-macros = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } +hashbrown = {workspace = true } log = { workspace = true } paste = "1.0.14" diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2d995b4a4179..3e1091f11e4a 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,13 +16,20 @@ // under the License. use ahash::RandomState; +use arrow::array::{ArrowPrimitiveType, ListArray}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::stats::Precision; use datafusion_expr::expr::WindowFunction; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions; +use hashbrown::hash_table::Entry; +use hashbrown::HashTable; use std::collections::HashSet; use std::fmt::Debug; +use std::hash::Hash; use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; use std::sync::Arc; @@ -347,15 +354,108 @@ impl AggregateUDFImpl for Count { // groups accumulator only supports `COUNT(c1)`, not // `COUNT(c1, c2)`, etc if args.is_distinct { - return false; + return args.exprs.len() == 1 + && args.exprs[0].data_type(args.schema).unwrap().is_primitive(); } args.exprs.len() == 1 } fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { + if args.is_distinct { + if args.exprs.len() > 1 { + return not_impl_err!("COUNT DISTINCT with multiple arguments"); + } + + let data_type = &args.exprs[0].data_type(args.schema)?; + return Ok(match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int8Type, + >::new(data_type.clone())), + DataType::Int16 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int16Type, + >::new(data_type.clone())), + DataType::Int32 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int32Type, + >::new(data_type.clone())), + DataType::Int64 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int64Type, + >::new(data_type.clone())), + DataType::UInt8 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt8Type, + >::new(data_type.clone())), + DataType::UInt16 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt16Type, + >::new(data_type.clone())), + DataType::UInt32 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt32Type, + >::new(data_type.clone())), + DataType::UInt64 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt64Type, + >::new(data_type.clone())), + DataType::Decimal128(_, _) => Box::new( + PrimitiveDistinctCountGroupsAccumulator::::new( + data_type.clone(), + ), + ), + DataType::Decimal256(_, _) => Box::new( + PrimitiveDistinctCountGroupsAccumulator::::new( + data_type.clone(), + ), + ), + + DataType::Date32 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Date32Type, + >::new(data_type.clone())), + DataType::Date64 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Date64Type, + >::new(data_type.clone())), + DataType::Time32(TimeUnit::Millisecond) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Time32MillisecondType, + >::new(data_type.clone())) + } + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountGroupsAccumulator::::new( + data_type.clone(), + ), + ), + DataType::Time64(TimeUnit::Microsecond) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Time64MicrosecondType, + >::new(data_type.clone())) + } + DataType::Time64(TimeUnit::Nanosecond) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Time64NanosecondType, + >::new(data_type.clone())) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + TimestampMicrosecondType, + >::new(data_type.clone())) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + TimestampMillisecondType, + >::new(data_type.clone())) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + TimestampNanosecondType, + >::new(data_type.clone())) + } + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountGroupsAccumulator::::new( + data_type.clone(), + ), + ), + _ => unimplemented!("COUNT DISTINCT for {data_type:?}"), + }); + } // instantiate specialized accumulator Ok(Box::new(CountGroupsAccumulator::new())) } @@ -751,7 +851,234 @@ impl Accumulator for DistinctCountAccumulator { } } } +/// A specialized GroupsAccumulator for count distinct operations with primitive types +/// This is more efficient than the general DistinctCountGroupsAccumulator for primitive types +#[derive(Debug)] +pub struct PrimitiveDistinctCountGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + /// One HashSet per group to track distinct values + distinct_sets: Vec>, + data_type: DataType, + random_state: RandomState, +} + +impl PrimitiveDistinctCountGroupsAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + pub fn new(data_type: DataType) -> Self { + Self { + distinct_sets: vec![], + data_type, + random_state: RandomState::new(), + } + } + + fn ensure_sets(&mut self, total_num_groups: usize) { + if self.distinct_sets.len() < total_num_groups { + self.distinct_sets + .resize_with(total_num_groups, HashTable::default); + } + } + + fn add_value_to_set(&mut self, val: T::Native, group_idx: usize) { + // let val = data[row_idx]; + let hash = self.random_state.hash_one(val); + let entry = + self.distinct_sets[group_idx].entry(hash, |&(v, _)| val == v, |&(_, h)| h); + if let Entry::Vacant(v) = entry { + v.insert((val, hash)); + } + } +} + +impl GroupsAccumulator for PrimitiveDistinctCountGroupsAccumulator +where + T: ArrowPrimitiveType + Send + Debug, + T::Native: Eq + Hash, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument"); + self.ensure_sets(total_num_groups); + + let array = as_primitive_array::(&values[0])?; + let data = array.values(); + + // Implement a manual iteration rather than using accumulate_indices with a closure + // that needs row_index + match (array.logical_nulls(), opt_filter) { + (None, None) => { + // No nulls, no filter - process all rows + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + self.add_value_to_set(data[row_idx], group_idx); + } + } + (Some(nulls), None) => { + // Has nulls, no filter + for (row_idx, (&group_idx, is_valid)) in + group_indices.iter().zip(nulls.iter()).enumerate() + { + if is_valid { + self.add_value_to_set(data[row_idx], group_idx); + } + } + } + (None, Some(filter)) => { + // No nulls, has filter + for (row_idx, (&group_idx, filter_value)) in + group_indices.iter().zip(filter.iter()).enumerate() + { + if let Some(true) = filter_value { + self.add_value_to_set(data[row_idx], group_idx); + } + } + } + (Some(nulls), Some(filter)) => { + // Has nulls and filter + let iter = filter + .iter() + .zip(group_indices.iter()) + .zip(nulls.iter()) + .enumerate(); + + for (row_idx, ((filter_value, &group_idx), is_valid)) in iter { + if is_valid && filter_value == Some(true) { + self.add_value_to_set(data[row_idx], group_idx); + } + } + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let distinct_sets = emit_to.take_needed(&mut self.distinct_sets); + + let counts = distinct_sets + .iter() + .map(|set| set.len() as i64) + .collect::>(); + + Ok(Arc::new(Int64Array::from(counts))) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!( + values.len(), + 1, + "COUNT DISTINCT merge expects a single state array" + ); + self.ensure_sets(total_num_groups); + + let list_array = as_list_array(&values[0])?; + + // For each group in the incoming batch + for (i, &group_idx) in group_indices.iter().enumerate() { + if i < list_array.len() { + let inner_array = list_array.value(i); + if !inner_array.is_empty() { + // Get the primitive array from the list and extend our set with its values + let primitive_array = as_primitive_array::(&inner_array)?; + for v in primitive_array.values() { + self.add_value_to_set(*v, group_idx); + } + } + } + } + Ok(()) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let distinct_sets = emit_to.take_needed(&mut self.distinct_sets); + + let mut offsets = Vec::with_capacity(distinct_sets.len() + 1); + offsets.push(0); + let mut values = + Vec::with_capacity(distinct_sets.iter().map(|set| set.len()).sum()); + + // Create the values array by flattening all sets + for set in distinct_sets { + values.extend(set.into_iter().map(|x| x.0)); + offsets.push(values.len() as i32); + } + // Create the primitive array from the flattened values + let values_array = Arc::new( + PrimitiveArray::::new(values.into(), None) + .with_data_type(self.data_type.clone()), + ) as ArrayRef; + + // Create list array with the offsets + let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets)); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(self.data_type.clone(), true)), + offset_buffer, + values_array, + None, + ); + + Ok(vec![Arc::new(list_array) as _]) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // For a single distinct value per row, create a list array with that value + assert_eq!(values.len(), 1, "COUNT DISTINCT expects a single argument"); + let values = ArrayRef::clone(&values[0]); + + let offsets = + OffsetBuffer::new(ScalarBuffer::from_iter(0..values.len() as i32 + 1)); + let nulls = filtered_null_mask(opt_filter, &values); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(values.data_type().clone(), true)), + offsets, + values, + nulls, + ); + + Ok(vec![Arc::new(list_array)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + let mut total_size = std::mem::size_of::(); + + // Size of vector container + total_size += std::mem::size_of::>>(); + + // Size of actual sets and their contents + for set in &self.distinct_sets { + let set_size = std::mem::size_of::>() + + set.capacity() * std::mem::size_of::(); + total_size += set_size; + } + + total_size + } +} #[cfg(test)] mod tests { use super::*; From 66a0b6334d8fa74b3f63780fb7ef8a1dd99b80c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Wed, 7 May 2025 22:59:18 +0200 Subject: [PATCH 2/7] Add PrimitiveDistinctCountGroupsAccumulator --- Cargo.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.lock b/Cargo.lock index ba5136b32706..37887f472501 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2292,6 +2292,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-expr-common", "half", + "hashbrown 0.14.5", "log", "paste", "rand 0.8.5", From ee694846e418cb3461e577f0482a6c9fa2c5d074 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 8 May 2025 00:15:42 +0200 Subject: [PATCH 3/7] Cleanup --- datafusion/functions-aggregate/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index b2db621e5184..c3462c69448d 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -49,7 +49,7 @@ datafusion-macros = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } -hashbrown = {workspace = true } +hashbrown = { workspace = true } log = { workspace = true } paste = "1.0.14" From 580e262f3cdaf31f69325ab16fe9e401c1fc4b43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 8 May 2025 00:55:33 +0200 Subject: [PATCH 4/7] Cleanup --- datafusion/functions-aggregate/src/count.rs | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 3e1091f11e4a..75e6679609d6 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -19,6 +19,7 @@ use ahash::RandomState; use arrow::array::{ArrowPrimitiveType, ListArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::hash_utils::HashValue; use datafusion_common::stats::Precision; use datafusion_expr::expr::WindowFunction; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; @@ -29,7 +30,6 @@ use hashbrown::hash_table::Entry; use hashbrown::HashTable; use std::collections::HashSet; use std::fmt::Debug; -use std::hash::Hash; use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; use std::sync::Arc; @@ -396,6 +396,15 @@ impl AggregateUDFImpl for Count { DataType::UInt64 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< UInt64Type, >::new(data_type.clone())), + DataType::Float32 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Float32Type, + >::new(data_type.clone())), + DataType::Float64 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Float64Type, + >::new(data_type.clone())), + DataType::Float16 => Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Float16Type, + >::new(data_type.clone())), DataType::Decimal128(_, _) => Box::new( PrimitiveDistinctCountGroupsAccumulator::::new( data_type.clone(), @@ -857,7 +866,7 @@ impl Accumulator for DistinctCountAccumulator { pub struct PrimitiveDistinctCountGroupsAccumulator where T: ArrowPrimitiveType + Send, - T::Native: Eq + Hash, + T::Native: PartialEq + HashValue, { /// One HashSet per group to track distinct values distinct_sets: Vec>, @@ -868,7 +877,7 @@ where impl PrimitiveDistinctCountGroupsAccumulator where T: ArrowPrimitiveType + Send, - T::Native: Eq + Hash, + T::Native: PartialEq + HashValue, { pub fn new(data_type: DataType) -> Self { Self { @@ -887,7 +896,7 @@ where fn add_value_to_set(&mut self, val: T::Native, group_idx: usize) { // let val = data[row_idx]; - let hash = self.random_state.hash_one(val); + let hash = val.hash_one(&self.random_state); let entry = self.distinct_sets[group_idx].entry(hash, |&(v, _)| val == v, |&(_, h)| h); if let Entry::Vacant(v) = entry { @@ -899,7 +908,7 @@ where impl GroupsAccumulator for PrimitiveDistinctCountGroupsAccumulator where T: ArrowPrimitiveType + Send + Debug, - T::Native: Eq + Hash, + T::Native: PartialEq + HashValue, { fn update_batch( &mut self, From 81d04080880d277186fd78f19ba3656e112a3a0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 8 May 2025 01:10:16 +0200 Subject: [PATCH 5/7] Cleanup --- datafusion/functions-aggregate/src/count.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 75e6679609d6..adfef1fc971a 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -18,6 +18,7 @@ use ahash::RandomState; use arrow::array::{ArrowPrimitiveType, ListArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType}; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::hash_utils::HashValue; use datafusion_common::stats::Precision; @@ -462,6 +463,17 @@ impl AggregateUDFImpl for Count { data_type.clone(), ), ), + DataType::Interval(IntervalUnit::YearMonth) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + IntervalYearMonthType, + >::new(data_type.clone())) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Box::new(PrimitiveDistinctCountGroupsAccumulator::< + IntervalMonthDayNanoType, + >::new(data_type.clone())) + } + _ => unimplemented!("COUNT DISTINCT for {data_type:?}"), }); } From 050e585bb05b330ab4f6824002a21525e97d2d72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 8 May 2025 08:21:07 +0200 Subject: [PATCH 6/7] Cleanup --- datafusion/functions-aggregate/src/count.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index adfef1fc971a..c7412f1231a5 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -18,7 +18,9 @@ use ahash::RandomState; use arrow::array::{ArrowPrimitiveType, ListArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; -use arrow::datatypes::{IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType}; +use arrow::datatypes::{ + IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, +}; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::hash_utils::HashValue; use datafusion_common::stats::Precision; From 148c487aae4cf23ba3685df8ec81b696b37bf9b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 8 May 2025 08:50:28 +0200 Subject: [PATCH 7/7] Cleanup --- datafusion/functions-aggregate/src/count.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index c7412f1231a5..adfef1fc971a 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -18,9 +18,7 @@ use ahash::RandomState; use arrow::array::{ArrowPrimitiveType, ListArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; -use arrow::datatypes::{ - IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, -}; +use arrow::datatypes::{IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType}; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::hash_utils::HashValue; use datafusion_common::stats::Precision;