From b3f78561a66ac956901dffa4017ec312086187f1 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 2 Nov 2025 13:21:35 -0600 Subject: [PATCH] Refactor state management in `HashJoinExec` and use CASE expressions to evaluate pushed down filters only for the given partition. --- .../physical_optimizer/filter_pushdown/mod.rs | 10 +- .../physical-plan/src/joins/hash_join/exec.rs | 94 ++-- .../physical-plan/src/joins/hash_join/mod.rs | 1 + .../joins/hash_join/partitioned_hash_eval.rs | 158 +++++++ .../src/joins/hash_join/shared_bounds.rs | 407 +++++++++++------- .../src/joins/hash_join/stream.rs | 54 ++- .../physical-plan/src/repartition/mod.rs | 11 +- 7 files changed, 509 insertions(+), 226 deletions(-) create mode 100644 datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs index b91c1732260c..0b3f279ff718 100644 --- a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -278,7 +278,7 @@ async fn test_dynamic_filter_pushdown_through_hash_join_with_topk() { - SortExec: TopK(fetch=2), expr=[e@4 ASC], preserve_partitioning=[false], filter=[e@4 IS NULL OR e@4 < bb] - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= aa AND d@0 <= ab ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 1 WHEN 0 THEN d@0 >= aa AND d@0 <= ab ELSE false END ] AND DynamicFilter [ e@1 IS NULL OR e@1 < bb ] " ); } @@ -1309,7 +1309,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb OR a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 WHEN 2 THEN a@0 >= ab AND a@0 <= ab AND b@1 >= bb AND b@1 <= bb WHEN 4 THEN a@0 >= aa AND a@0 <= aa AND b@1 >= ba AND b@1 <= ba ELSE false END ] " ); @@ -1326,7 +1326,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() { - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true - CoalesceBatchesExec: target_batch_size=8192 - RepartitionExec: partitioning=Hash([a@0, b@1], 12), input_partitions=1 - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ] + - DataSourceExec: file_groups={1 group: 
[[test.parquet]]}, projection=[a, b, e], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 12 WHEN 0 THEN a@0 >= aa AND a@0 <= ab AND b@1 >= ba AND b@1 <= bb ELSE false END ] " ); @@ -1671,8 +1671,8 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() { - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, b@0)] - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ b@0 >= aa AND b@0 <= ab ] - - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ d@0 >= ca AND d@0 <= cb ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 1 WHEN 0 THEN b@0 >= aa AND b@0 <= ab ELSE false END ] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilter [ CASE hash_repartition % 1 WHEN 0 THEN d@0 >= ca AND d@0 <= cb ELSE false END ] " ); } diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index b5fe5ee5cda1..5b0a632c2a6c 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -26,7 +26,9 @@ use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; -use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator}; +use crate::joins::hash_join::shared_bounds::{ + ColumnBounds, PartitionBounds, SharedBuildAccumulator, +}; use crate::joins::hash_join::stream::{ BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState, }; @@ -40,6 +42,7 @@ use crate::projection::{ try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, ProjectionExec, }; +use crate::repartition::REPARTITION_HASH_SEED; use crate::spill::get_record_batch_memory_size; use crate::ExecutionPlanProperties; use crate::{ @@ -87,7 +90,8 @@ const HASH_JOIN_SEED: RandomState = /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` - pub(super) hash_map: Box, + /// Arc is used to allow sharing with SharedBuildAccumulator for hash map pushdown + pub(super) hash_map: Arc, /// The input rows for the build side batch: RecordBatch, /// The build side on expressions values @@ -102,32 +106,13 @@ pub(super) struct JoinLeftData { /// This could hide potential out-of-memory issues, especially when upstream operators increase their memory consumption. /// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle. _reservation: MemoryReservation, - /// Bounds computed from the build side for dynamic filter pushdown - pub(super) bounds: Option>, + /// Bounds computed from the build side for dynamic filter pushdown. + /// If the partition is empty (no rows) this will be None. + /// If the partition has some rows this will be Some with the bounds for each join key column. 
+ pub(super) bounds: Option, } impl JoinLeftData { - /// Create a new `JoinLeftData` from its parts - pub(super) fn new( - hash_map: Box, - batch: RecordBatch, - values: Vec, - visited_indices_bitmap: SharedBitmapBuilder, - probe_threads_counter: AtomicUsize, - reservation: MemoryReservation, - bounds: Option>, - ) -> Self { - Self { - hash_map, - batch, - values, - visited_indices_bitmap, - probe_threads_counter, - _reservation: reservation, - bounds, - } - } - /// return a reference to the hash map pub(super) fn hash_map(&self) -> &dyn JoinHashMapType { &*self.hash_map @@ -364,9 +349,9 @@ pub struct HashJoinExec { struct HashJoinExecDynamicFilter { /// Dynamic filter that we'll update with the results of the build side once that is done. filter: Arc, - /// Bounds accumulator to keep track of the min/max bounds on the join keys for each partition. + /// Build accumulator to collect build-side information (hash maps and/or bounds) from each partition. /// It is lazily initialized during execution to make sure we use the actual execution time partition counts. - bounds_accumulator: OnceLock>, + build_accumulator: OnceLock>, } impl fmt::Debug for HashJoinExec { @@ -977,8 +962,15 @@ impl ExecutionPlan for HashJoinExec { let batch_size = context.session_config().batch_size(); - // Initialize bounds_accumulator lazily with runtime partition counts (only if enabled) - let bounds_accumulator = enable_dynamic_filter_pushdown + // Initialize build_accumulator lazily with runtime partition counts (only if enabled) + // Use RepartitionExec's random state (seeds: 0,0,0,0) for partition routing + let repartition_random_state = RandomState::with_seeds( + REPARTITION_HASH_SEED[0], + REPARTITION_HASH_SEED[1], + REPARTITION_HASH_SEED[2], + REPARTITION_HASH_SEED[3], + ); + let build_accumulator = enable_dynamic_filter_pushdown .then(|| { self.dynamic_filter.as_ref().map(|df| { let filter = Arc::clone(&df.filter); @@ -987,13 +979,14 @@ impl ExecutionPlan for HashJoinExec { .iter() .map(|(_, right_expr)| Arc::clone(right_expr)) .collect::>(); - Some(Arc::clone(df.bounds_accumulator.get_or_init(|| { - Arc::new(SharedBoundsAccumulator::new_from_partition_mode( + Some(Arc::clone(df.build_accumulator.get_or_init(|| { + Arc::new(SharedBuildAccumulator::new_from_partition_mode( self.mode, self.left.as_ref(), self.right.as_ref(), filter, on_right, + repartition_random_state, )) }))) }) @@ -1036,7 +1029,7 @@ impl ExecutionPlan for HashJoinExec { batch_size, vec![], self.right.output_ordering().is_some(), - bounds_accumulator, + build_accumulator, self.mode, ))) } @@ -1197,7 +1190,7 @@ impl ExecutionPlan for HashJoinExec { cache: self.cache.clone(), dynamic_filter: Some(HashJoinExecDynamicFilter { filter: dynamic_filter, - bounds_accumulator: OnceLock::new(), + build_accumulator: OnceLock::new(), }), }); result = result.with_updated_node(new_node as Arc); @@ -1346,7 +1339,7 @@ impl BuildSideState { /// When `should_compute_bounds` is true, this function computes the min/max bounds /// for each join key column but does NOT update the dynamic filter. Instead, the /// bounds are stored in the returned `JoinLeftData` and later coordinated by -/// `SharedBoundsAccumulator` to ensure all partitions contribute their bounds +/// `SharedBuildAccumulator` to ensure all partitions contribute their bounds /// before updating the filter exactly once. 
/// /// # Returns @@ -1417,6 +1410,7 @@ async fn collect_left_input( // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the // `u64` indice variant + // Arc is used instead of Box to allow sharing with SharedBuildAccumulator for hash map pushdown let mut hashmap: Box = if num_rows > u32::MAX as usize { let estimated_hashtable_size = estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; @@ -1452,15 +1446,15 @@ async fn collect_left_input( offset += batch.num_rows(); } // Merge all batches into a single batch, so we can directly index into the arrays - let single_batch = concat_batches(&schema, batches_iter)?; + let batch = concat_batches(&schema, batches_iter)?; // Reserve additional memory for visited indices bitmap and create shared builder let visited_indices_bitmap = if with_visited_indices_bitmap { - let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + let bitmap_size = bit_util::ceil(batch.num_rows(), 8); reservation.try_grow(bitmap_size)?; metrics.build_mem_used.add(bitmap_size); - let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + let mut bitmap_buffer = BooleanBufferBuilder::new(batch.num_rows()); bitmap_buffer.append_n(num_rows, false); bitmap_buffer } else { @@ -1469,10 +1463,7 @@ async fn collect_left_input( let left_values = on_left .iter() - .map(|c| { - c.evaluate(&single_batch)? - .into_array(single_batch.num_rows()) - }) + .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows())) .collect::>>()?; // Compute bounds for dynamic filter if enabled @@ -1482,20 +1473,23 @@ async fn collect_left_input( .into_iter() .map(CollectLeftAccumulator::evaluate) .collect::>>()?; - Some(bounds) + Some(PartitionBounds::new(bounds)) } _ => None, }; - let data = JoinLeftData::new( - hashmap, - single_batch, - left_values.clone(), - Mutex::new(visited_indices_bitmap), - AtomicUsize::new(probe_threads_count), - reservation, + // Convert Box to Arc for sharing with SharedBuildAccumulator + let hash_map: Arc = hashmap.into(); + + let data = JoinLeftData { + hash_map, + batch, + values: left_values, + visited_indices_bitmap: Mutex::new(visited_indices_bitmap), + probe_threads_counter: AtomicUsize::new(probe_threads_count), + _reservation: reservation, bounds, - ); + }; Ok(data) } diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 7f1e5cae13a3..6c073e7a9cff 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -20,5 +20,6 @@ pub use exec::HashJoinExec; mod exec; +mod partitioned_hash_eval; mod shared_bounds; mod stream; diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs new file mode 100644 index 000000000000..527642ade07e --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Hash computation and hash table lookup expressions for dynamic filtering
+
+use std::{any::Any, fmt::Display, hash::Hash, sync::Arc};
+
+use ahash::RandomState;
+use arrow::{
+    array::UInt64Array,
+    datatypes::{DataType, Schema},
+};
+use datafusion_common::Result;
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr_common::physical_expr::{
+    DynHash, PhysicalExpr, PhysicalExprRef,
+};
+
+use crate::hash_utils::create_hashes;
+
+/// Physical expression that computes hash values for a set of columns
+///
+/// This expression computes the hash of join key columns using a specific RandomState.
+/// It returns a UInt64Array containing the hash values.
+///
+/// This is used for:
+/// - Computing routing hashes (with RepartitionExec's 0,0,0,0 seeds)
+/// - Computing lookup hashes (with HashJoin's 'J','O','I','N' seeds)
+pub(super) struct HashExpr {
+    /// Columns to hash
+    on_columns: Vec<PhysicalExprRef>,
+    /// Random state for hashing
+    random_state: RandomState,
+    /// Description for display
+    description: String,
+}
+
+impl HashExpr {
+    /// Create a new HashExpr
+    ///
+    /// # Arguments
+    /// * `on_columns` - Columns to hash
+    /// * `random_state` - RandomState for hashing
+    /// * `description` - Description for debugging (e.g., "hash_repartition", "hash_join")
+    pub(super) fn new(
+        on_columns: Vec<PhysicalExprRef>,
+        random_state: RandomState,
+        description: String,
+    ) -> Self {
+        Self {
+            on_columns,
+            random_state,
+            description,
+        }
+    }
+}
+
+impl std::fmt::Debug for HashExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let cols = self
+            .on_columns
+            .iter()
+            .map(|e| e.to_string())
+            .collect::<Vec<_>>()
+            .join(", ");
+        write!(f, "{}({})", self.description, cols)
+    }
+}
+
+impl Hash for HashExpr {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.on_columns.dyn_hash(state);
+        self.description.hash(state);
+    }
+}
+
+impl PartialEq for HashExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.on_columns == other.on_columns && self.description == other.description
+    }
+}
+
+impl Eq for HashExpr {}
+
+impl Display for HashExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.description)
+    }
+}
+
+impl PhysicalExpr for HashExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        self.on_columns.iter().collect()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(HashExpr::new(
+            children,
+            self.random_state.clone(),
+            self.description.clone(),
+        )))
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::UInt64)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(false)
+    }
+
+    fn evaluate(
+        &self,
+        batch: &arrow::record_batch::RecordBatch,
+    ) -> Result<ColumnarValue> {
+        let num_rows = batch.num_rows();
+
+        // Evaluate columns
+        let keys_values = self
+            .on_columns
+            .iter()
+            .map(|c| c.evaluate(batch)?.into_array(num_rows))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Compute hashes
+        let mut hashes_buffer = vec![0; num_rows];
+        create_hashes(&keys_values, &self.random_state, &mut hashes_buffer)?;
+
+        
Ok(ColumnarValue::Array(Arc::new(UInt64Array::from( + hashes_buffer, + )))) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.description) + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 25f7a0de31ac..5b1009cc47d3 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -15,22 +15,25 @@ // specific language governing permissions and limitations // under the License. -//! Utilities for shared bounds. Used in dynamic filter pushdown in Hash Joins. +//! Utilities for shared build-side information. Used in dynamic filter pushdown in Hash Joins. // TODO: include the link to the Dynamic Filter blog post. use std::fmt; use std::sync::Arc; +use crate::joins::hash_join::partitioned_hash_eval::HashExpr; use crate::joins::PartitionMode; use crate::ExecutionPlan; use crate::ExecutionPlanProperties; +use ahash::RandomState; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::expressions::{ + lit, BinaryExpr, CaseExpr, DynamicFilterPhysicalExpr, +}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; -use itertools::Itertools; use parking_lot::Mutex; use tokio::sync::Barrier; @@ -54,23 +57,14 @@ impl ColumnBounds { /// This contains the min/max values computed from one partition's build-side data. #[derive(Debug, Clone)] pub(crate) struct PartitionBounds { - /// Partition identifier for debugging and determinism (not strictly necessary) - partition: usize, /// Min/max bounds for each join key column in this partition. /// Index corresponds to the join key expression index. column_bounds: Vec, } impl PartitionBounds { - pub(crate) fn new(partition: usize, column_bounds: Vec) -> Self { - Self { - partition, - column_bounds, - } - } - - pub(crate) fn len(&self) -> usize { - self.column_bounds.len() + pub(crate) fn new(column_bounds: Vec) -> Self { + Self { column_bounds } } pub(crate) fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> { @@ -78,18 +72,69 @@ impl PartitionBounds { } } -/// Coordinates dynamic filter bounds collection across multiple partitions +/// Creates a bounds predicate from partition bounds. +/// +/// Returns `None` if no column bounds are available. +/// Returns a combined predicate (col >= min AND col <= max) for all columns with bounds. 
+fn create_bounds_predicate(
+    on_right: &[PhysicalExprRef],
+    bounds: &PartitionBounds,
+) -> Option<Arc<dyn PhysicalExpr>> {
+    let mut column_predicates = Vec::new();
+
+    for (col_idx, right_expr) in on_right.iter().enumerate() {
+        if let Some(column_bounds) = bounds.get_column_bounds(col_idx) {
+            // Create predicate: col >= min AND col <= max
+            let min_expr = Arc::new(BinaryExpr::new(
+                Arc::clone(right_expr),
+                Operator::GtEq,
+                lit(column_bounds.min.clone()),
+            )) as Arc<dyn PhysicalExpr>;
+            let max_expr = Arc::new(BinaryExpr::new(
+                Arc::clone(right_expr),
+                Operator::LtEq,
+                lit(column_bounds.max.clone()),
+            )) as Arc<dyn PhysicalExpr>;
+            let range_expr = Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr))
+                as Arc<dyn PhysicalExpr>;
+            column_predicates.push(range_expr);
+        }
+    }
+
+    if column_predicates.is_empty() {
+        None
+    } else {
+        Some(
+            column_predicates
+                .into_iter()
+                .reduce(|acc, pred| {
+                    Arc::new(BinaryExpr::new(acc, Operator::And, pred))
+                        as Arc<dyn PhysicalExpr>
+                })
+                .unwrap(),
+        )
+    }
+}
+
+/// Coordinates build-side information collection across multiple partitions
 ///
-/// This structure ensures that dynamic filters are built with complete information from all
-/// relevant partitions before being applied to probe-side scans. Incomplete filters would
+/// This structure collects information from the build side (currently the min/max bounds of
+/// the join keys computed by each build partition) and ensures that dynamic filters are built
+/// with complete information from all relevant partitions before being applied to probe-side scans. Incomplete filters would
 /// incorrectly eliminate valid join results.
 ///
 /// ## Synchronization Strategy
 ///
-/// 1. Each partition computes bounds from its build-side data
-/// 2. Bounds are stored in the shared vector
-/// 3. A barrier tracks how many partitions have reported their bounds
-/// 4. When the last partition reports, bounds are merged and the filter is updated exactly once
+/// 1. Each partition computes information from its build-side data (currently min/max bounds)
+/// 2. Information is stored in the shared state
+/// 3. A barrier tracks how many partitions have reported
+/// 4. When the last partition reports, information is merged and the filter is updated exactly once
+///
+/// ## Filter Shape per Partition Mode
+///
+/// - **Partitioned mode**: collects min/max bounds from every build partition and combines them into a
+///   `CASE hash_repartition(join_keys) % num_partitions WHEN <partition> THEN <bounds> ELSE false END`
+///   expression, so each probe row is checked only against the bounds of the partition it routes to.
+/// - **CollectLeft mode**: collects a single set of min/max bounds and creates a plain range predicate.
 ///
 /// ## Partition Counting
 ///
@@ -101,25 +146,57 @@ impl PartitionBounds {
 ///
 /// All fields use a single mutex to ensure correct coordination between concurrent
 /// partition executions.
-pub(crate) struct SharedBoundsAccumulator {
-    /// Shared state protected by a single mutex to avoid ordering concerns
-    inner: Mutex<SharedBoundsState>,
+pub(crate) struct SharedBuildAccumulator {
+    /// Build-side data protected by a single mutex to avoid ordering concerns
+    inner: Mutex<AccumulatedBuildData>,
     barrier: Barrier,
     /// Dynamic filter for pushdown to probe side
     dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
-    /// Right side join expressions needed for creating filter bounds
+    /// Right side join expressions needed for creating filter expressions
    on_right: Vec<PhysicalExprRef>,
+    /// Random state for partitioning (RepartitionExec's hash function with 0,0,0,0 seeds)
+    /// Used by the partition-routing `HashExpr` when building the CASE filter
+    repartition_random_state: RandomState,
+}
+
+#[derive(Clone)]
+pub(crate) enum PartitionBuildDataReport {
+    Partitioned {
+        partition_id: usize,
+        /// Bounds computed from this partition's build side.
+ /// If the partition is empty (no rows) this will be None. + bounds: Option, + }, + CollectLeft { + /// Bounds computed from the collected build side. + /// If the build side is empty (no rows) this will be None. + bounds: Option, + }, } -/// State protected by SharedBoundsAccumulator's mutex -struct SharedBoundsState { - /// Bounds from completed partitions. - /// Each element represents the column bounds computed by one partition. - bounds: Vec, +#[derive(Clone)] +struct PartitionedBuildData { + partition_id: usize, + bounds: PartitionBounds, } -impl SharedBoundsAccumulator { - /// Creates a new SharedBoundsAccumulator configured for the given partition mode +#[derive(Clone)] +struct CollectLeftBuildData { + bounds: PartitionBounds, +} + +/// Build-side data organized by partition mode +enum AccumulatedBuildData { + Partitioned { + partitions: Vec>, + }, + CollectLeft { + data: Option, + }, +} + +impl SharedBuildAccumulator { + /// Creates a new SharedBuildAccumulator configured for the given partition mode /// /// This method calculates how many times `collect_build_side` will be called based on the /// partition mode's execution pattern. This count is critical for determining when we have @@ -137,12 +214,12 @@ impl SharedBoundsAccumulator { /// `collect_build_side` once. Expected calls = number of build partitions. /// /// - **Auto**: Placeholder mode resolved during optimization. Uses 1 as safe default since - /// the actual mode will be determined and a new bounds_accumulator created before execution. + /// the actual mode will be determined and a new accumulator created before execution. /// /// ## Why This Matters /// /// We cannot build a partial filter from some partitions - it would incorrectly eliminate - /// valid join results. We must wait until we have complete bounds information from ALL + /// valid join results. We must wait until we have complete information from ALL /// relevant partitions before updating the dynamic filter. pub(crate) fn new_from_partition_mode( partition_mode: PartitionMode, @@ -150,6 +227,7 @@ impl SharedBoundsAccumulator { right_child: &dyn ExecutionPlan, dynamic_filter: Arc, on_right: Vec, + repartition_random_state: RandomState, ) -> Self { // Troubleshooting: If partition counts are incorrect, verify this logic matches // the actual execution pattern in collect_build_side() @@ -165,140 +243,171 @@ impl SharedBoundsAccumulator { // Default value, will be resolved during optimization (does not exist once `execute()` is called; will be replaced by one of the other two) PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), }; + + let mode_data = match partition_mode { + PartitionMode::Partitioned => AccumulatedBuildData::Partitioned { + partitions: vec![None; left_child.output_partitioning().partition_count()], + }, + PartitionMode::CollectLeft => AccumulatedBuildData::CollectLeft { + data: None, + }, + PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), + }; + Self { - inner: Mutex::new(SharedBoundsState { - bounds: Vec::with_capacity(expected_calls), - }), + inner: Mutex::new(mode_data), barrier: Barrier::new(expected_calls), dynamic_filter, on_right, + repartition_random_state, } } - /// Create a filter expression from individual partition bounds using OR logic. 
+ /// Report build-side data from a partition /// - /// This creates a filter where each partition's bounds form a conjunction (AND) - /// of column range predicates, and all partitions are combined with OR. - /// - /// For example, with 2 partitions and 2 columns: - /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1) - /// OR - /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1)) - pub(crate) fn create_filter_from_partition_bounds( - &self, - bounds: &[PartitionBounds], - ) -> Result> { - if bounds.is_empty() { - return Ok(lit(true)); - } - - // Create a predicate for each partition - let mut partition_predicates = Vec::with_capacity(bounds.len()); - - for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) { - // Create range predicates for each join key in this partition - let mut column_predicates = Vec::with_capacity(partition_bounds.len()); - - for (col_idx, right_expr) in self.on_right.iter().enumerate() { - if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) { - // Create predicate: col >= min AND col <= max - let min_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::GtEq, - lit(column_bounds.min.clone()), - )) as Arc; - let max_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::LtEq, - lit(column_bounds.max.clone()), - )) as Arc; - let range_expr = - Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) - as Arc; - column_predicates.push(range_expr); - } - } - - // Combine all column predicates for this partition with AND - if !column_predicates.is_empty() { - let partition_predicate = column_predicates - .into_iter() - .reduce(|acc, pred| { - Arc::new(BinaryExpr::new(acc, Operator::And, pred)) - as Arc - }) - .unwrap(); - partition_predicates.push(partition_predicate); - } - } - - // Combine all partition predicates with OR - let combined_predicate = partition_predicates - .into_iter() - .reduce(|acc, pred| { - Arc::new(BinaryExpr::new(acc, Operator::Or, pred)) - as Arc - }) - .unwrap_or_else(|| lit(true)); - - Ok(combined_predicate) - } - - /// Report bounds from a completed partition and update dynamic filter if all partitions are done - /// - /// This method coordinates the dynamic filter updates across all partitions. It stores the - /// bounds from the current partition, increments the completion counter, and when all - /// partitions have reported, creates an OR'd filter from individual partition bounds. - /// - /// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions - /// to report their bounds. Once that occurs, the method will resolve for all callers and the - /// dynamic filter will be updated exactly once. - /// - /// # Note - /// - /// As barriers are reusable, it is likely an error to call this method more times than the - /// total number of partitions - as it can lead to pending futures that never resolve. We rely - /// on correct usage from the caller rather than imposing additional checks here. If this is a concern, - /// consider making the resulting future shared so the ready result can be reused. + /// This unified method handles both CollectLeft and Partitioned modes. 
When all partitions
+    /// have reported (barrier wait), the leader builds the appropriate filter expression:
+    /// - CollectLeft: simple conjunction of range predicates built from the collected bounds
+    /// - Partitioned: CASE expression routing to per-partition filters
     ///
     /// # Arguments
-    /// * `left_side_partition_id` - The identifier for the **left-side** partition reporting its bounds
-    /// * `partition_bounds` - The bounds computed by this partition (if any)
+    /// * `data` - Build-side bounds reported for this partition (Partitioned) or for the
+    ///   shared build side (CollectLeft)
     ///
     /// # Returns
-    /// * `Result<()>` - Ok if successful, Err if filter update failed
-    pub(crate) async fn report_partition_bounds(
+    /// * `Result<()>` - Ok if successful, Err if filter update failed or mode mismatch
+    pub(crate) async fn report_build_data(
         &self,
-        left_side_partition_id: usize,
-        partition_bounds: Option<Vec<ColumnBounds>>,
+        data: PartitionBuildDataReport,
     ) -> Result<()> {
-        // Store bounds in the accumulator - this runs once per partition
-        if let Some(bounds) = partition_bounds {
+        // Store data in the accumulator
+        {
            let mut guard = self.inner.lock();
-            let should_push = if let Some(last_bound) = guard.bounds.last() {
-                // In `PartitionMode::CollectLeft`, all streams on the left side share the same partition id (0).
-                // Since this function can be called multiple times for that same partition, we must deduplicate
-                // by checking against the last recorded bound.
-                last_bound.partition != left_side_partition_id
-            } else {
-                true
-            };
-
-            if should_push {
-                guard
-                    .bounds
-                    .push(PartitionBounds::new(left_side_partition_id, bounds));
+            match (data, &mut *guard) {
+                // Partitioned mode
+                (
+                    PartitionBuildDataReport::Partitioned {
+                        partition_id,
+                        bounds,
+                    },
+                    AccumulatedBuildData::Partitioned { partitions },
+                ) => {
+                    if let Some(bounds) = bounds {
+                        partitions[partition_id] = Some(PartitionedBuildData {
+                            partition_id,
+                            bounds,
+                        });
+                    }
+                }
+                // CollectLeft mode (store once, deduplicate across partitions)
+                (
+                    PartitionBuildDataReport::CollectLeft { bounds },
+                    AccumulatedBuildData::CollectLeft { data },
+                ) => {
+                    match (bounds, data) {
+                        (None, _) | (_, Some(_)) => {
+                            // No bounds reported or already reported; do nothing
+                        }
+                        (Some(new_bounds), data) => {
+                            // First report, store the bounds
+                            *data = Some(CollectLeftBuildData { bounds: new_bounds });
+                        }
+                    }
+                }
+                // Mismatched modes - should never happen
+                _ => {
+                    return datafusion_common::internal_err!(
+                        "Build data mode mismatch in report_build_data"
+                    );
+                }
            }
        }
 
+        // Wait for all partitions to report
         if self.barrier.wait().await.is_leader() {
-            // All partitions have reported, so we can update the filter
+            // All partitions have reported, so we can create and update the filter
            let inner = self.inner.lock();
-            if !inner.bounds.is_empty() {
-                let filter_expr =
-                    self.create_filter_from_partition_bounds(&inner.bounds)?;
-                self.dynamic_filter.update(filter_expr)?;
+
+            match &*inner {
+                // CollectLeft: simple conjunction of range predicates from the collected bounds
+                AccumulatedBuildData::CollectLeft { data } => {
+                    if let Some(partition_data) = data {
+                        // Create bounds check expression (if bounds available)
+                        let Some(filter_expr) = create_bounds_predicate(
+                            &self.on_right,
+                            &partition_data.bounds,
+                        ) else {
+                            // No bounds available, nothing to update
+                            return Ok(());
+                        };
+
+                        self.dynamic_filter.update(filter_expr)?;
+                    }
+                }
+                // Partitioned: CASE expression routing to per-partition filters
+                AccumulatedBuildData::Partitioned { partitions } => {
+                    // Collect all partition data, skipping empty partitions
+                    let
partition_data: Vec<_> = + partitions.iter().filter_map(|p| p.as_ref()).collect(); + + if partition_data.is_empty() { + // All partitions are empty: no rows can match, skip the probe side entirely + self.dynamic_filter.update(lit(false))?; + return Ok(()); + } + + // Build a CASE expression that combines range checks AND membership checks + // CASE (hash_repartition(join_keys) % num_partitions) + // WHEN 0 THEN (col >= min_0 AND col <= max_0 AND ...) + // WHEN 1 THEN (col >= min_1 AND col <= max_1 AND ...) + // ... + // ELSE false + // END + + let num_partitions = partitions.len(); + + // Create base expression: hash_repartition(join_keys) % num_partitions + let routing_hash_expr = Arc::new(HashExpr::new( + self.on_right.clone(), + self.repartition_random_state.clone(), + "hash_repartition".to_string(), + )) + as Arc; + + let modulo_expr = Arc::new(BinaryExpr::new( + routing_hash_expr, + Operator::Modulo, + lit(ScalarValue::UInt64(Some(num_partitions as u64))), + )) as Arc; + + // Create WHEN branches for each partition + let when_then_branches: Vec<( + Arc, + Arc, + )> = partition_data + .into_iter() + .map(|pdata| -> Result<_> { + // WHEN partition_id + let when_expr = + lit(ScalarValue::UInt64(Some(pdata.partition_id as u64))); + + // Create bounds check expression for this partition (if bounds available) + let bounds_expr = + create_bounds_predicate(&self.on_right, &pdata.bounds) + .unwrap_or_else(|| lit(true)); // No bounds means all rows pass + + Ok((when_expr, bounds_expr)) + }) + .collect::>>()?; + + let case_expr = Arc::new(CaseExpr::try_new( + Some(modulo_expr), + when_then_branches, + Some(lit(false)), // ELSE false + )?) as Arc; + + self.dynamic_filter.update(case_expr)?; + } } } @@ -306,8 +415,8 @@ impl SharedBoundsAccumulator { } } -impl fmt::Debug for SharedBoundsAccumulator { +impl fmt::Debug for SharedBuildAccumulator { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SharedBoundsAccumulator") + write!(f, "SharedBuildAccumulator") } } diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 88c50c2eb2ce..c1468567f6c0 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -24,7 +24,9 @@ use std::sync::Arc; use std::task::Poll; use crate::joins::hash_join::exec::JoinLeftData; -use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator; +use crate::joins::hash_join::shared_bounds::{ + PartitionBuildDataReport, SharedBuildAccumulator, +}; use crate::joins::utils::{ equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut, }; @@ -206,11 +208,11 @@ pub(super) struct HashJoinStream { hashes_buffer: Vec, /// Specifies whether the right side has an ordering to potentially preserve right_side_ordered: bool, - /// Shared bounds accumulator for coordinating dynamic filter updates (optional) - bounds_accumulator: Option>, - /// Optional future to signal when bounds have been reported by all partitions + /// Shared build accumulator for coordinating dynamic filter updates (collects hash maps and/or bounds, optional) + build_accumulator: Option>, + /// Optional future to signal when build information has been reported by all partitions /// and the dynamic filter has been updated - bounds_waiter: Option>, + build_waiter: Option>, /// Partitioning mode to use mode: PartitionMode, @@ -315,7 +317,7 @@ impl HashJoinStream { batch_size: usize, hashes_buffer: Vec, right_side_ordered: bool, - 
bounds_accumulator: Option>, + build_accumulator: Option>, mode: PartitionMode, ) -> Self { Self { @@ -334,8 +336,8 @@ impl HashJoinStream { batch_size, hashes_buffer, right_side_ordered, - bounds_accumulator, - bounds_waiter: None, + build_accumulator, + build_waiter: None, mode, } } @@ -370,12 +372,12 @@ impl HashJoinStream { } } - /// Optional step to wait until bounds have been reported by all partitions. - /// This state is only entered if a bounds accumulator is present. + /// Optional step to wait until build-side information (hash maps or bounds) has been reported by all partitions. + /// This state is only entered if a build accumulator is present. /// /// ## Why wait? /// - /// The dynamic filter is only built once all partitions have reported their bounds. + /// The dynamic filter is only built once all partitions have reported their information (hash maps or bounds). /// If we do not wait here, the probe-side scan may start before the filter is ready. /// This can lead to the probe-side scan missing the opportunity to apply the filter /// and skip reading unnecessary data. @@ -383,7 +385,7 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>>> { - if let Some(ref mut fut) = self.bounds_waiter { + if let Some(ref mut fut) = self.build_waiter { ready!(fut.get_shared(cx))?; } self.state = HashJoinStreamState::FetchProbeBatch; @@ -406,12 +408,13 @@ impl HashJoinStream { .get_shared(cx))?; build_timer.done(); - // Handle dynamic filter bounds accumulation + // Handle dynamic filter build-side information accumulation // // Dynamic filter coordination between partitions: - // Report bounds to the accumulator which will handle synchronization and filter updates - if let Some(ref bounds_accumulator) = self.bounds_accumulator { - let bounds_accumulator = Arc::clone(bounds_accumulator); + // Report hash maps (Partitioned mode) or bounds (CollectLeft mode) to the accumulator + // which will handle synchronization and filter updates + if let Some(ref build_accumulator) = self.build_accumulator { + let build_accumulator = Arc::clone(build_accumulator); let left_side_partition_id = match self.mode { PartitionMode::Partitioned => self.partition, @@ -419,11 +422,20 @@ impl HashJoinStream { PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. 
This is a bug in DataFusion, please report it!"), }; - let left_data_bounds = left_data.bounds.clone(); - self.bounds_waiter = Some(OnceFut::new(async move { - bounds_accumulator - .report_partition_bounds(left_side_partition_id, left_data_bounds) - .await + let build_data = match self.mode { + PartitionMode::Partitioned => PartitionBuildDataReport::Partitioned { + partition_id: left_side_partition_id, + bounds: left_data.bounds.clone(), + }, + PartitionMode::CollectLeft => PartitionBuildDataReport::CollectLeft { + bounds: left_data.bounds.clone(), + }, + PartitionMode::Auto => unreachable!( + "PartitionMode::Auto should not be present at execution time" + ), + }; + self.build_waiter = Some(OnceFut::new(async move { + build_accumulator.report_build_data(build_data).await })); self.state = HashJoinStreamState::WaitPartitionBoundsReport; } else { diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 8174f71c31af..f633dde665f8 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -315,6 +315,10 @@ enum BatchPartitionerState { }, } +/// Fixed seed used for hash repartitioning to ensure consistent behavior across +/// executions and runs. +pub const REPARTITION_HASH_SEED: [u64; 4] = [0u64; 4]; + impl BatchPartitioner { /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`] /// @@ -331,7 +335,12 @@ impl BatchPartitioner { exprs, num_partitions, // Use fixed random hash - random_state: ahash::RandomState::with_seeds(0, 0, 0, 0), + random_state: ahash::RandomState::with_seeds( + REPARTITION_HASH_SEED[0], + REPARTITION_HASH_SEED[1], + REPARTITION_HASH_SEED[2], + REPARTITION_HASH_SEED[3], + ), hash_buffer: vec![], }, other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
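
In Partitioned mode, the dynamic filter pushed down by `SharedBuildAccumulator` has the shape `CASE hash_repartition(join_keys) % num_partitions WHEN <partition> THEN <bounds> ELSE false END`, as seen in the updated plan snapshots above. The sketch below shows how such an expression could be assembled from the same public building blocks the patch uses (`BinaryExpr`, `lit`, `CaseExpr::try_new`). It is an illustration only, simplified to a single join key: the helper name `partition_routing_filter` and the `routing_hash` argument are hypothetical stand-ins for the crate-private `HashExpr`, which the patch seeds with `REPARTITION_HASH_SEED` so the selector matches `RepartitionExec`'s routing.

```rust
use std::sync::Arc;

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{lit, BinaryExpr, CaseExpr};
use datafusion_physical_expr::PhysicalExpr;

/// Hypothetical helper (not part of the patch): build
/// `CASE routing_hash % N WHEN i THEN key >= min_i AND key <= max_i ... ELSE false END`
/// for one probe-side join key. `bounds[i]` is the (min, max) pair reported by build
/// partition `i`, or `None` if that partition produced no rows.
fn partition_routing_filter(
    routing_hash: Arc<dyn PhysicalExpr>,
    key: Arc<dyn PhysicalExpr>,
    bounds: &[Option<(ScalarValue, ScalarValue)>],
) -> Result<Arc<dyn PhysicalExpr>> {
    // Selector: which build partition this probe row would be routed to.
    let selector = Arc::new(BinaryExpr::new(
        routing_hash,
        Operator::Modulo,
        lit(ScalarValue::UInt64(Some(bounds.len() as u64))),
    )) as Arc<dyn PhysicalExpr>;

    // One WHEN arm per non-empty partition: key >= min AND key <= max.
    let when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = bounds
        .iter()
        .enumerate()
        .filter_map(|(i, b)| b.as_ref().map(|(min, max)| (i, min, max)))
        .map(|(i, min, max)| {
            let ge = Arc::new(BinaryExpr::new(
                Arc::clone(&key),
                Operator::GtEq,
                lit(min.clone()),
            )) as Arc<dyn PhysicalExpr>;
            let le = Arc::new(BinaryExpr::new(
                Arc::clone(&key),
                Operator::LtEq,
                lit(max.clone()),
            )) as Arc<dyn PhysicalExpr>;
            let range = Arc::new(BinaryExpr::new(ge, Operator::And, le))
                as Arc<dyn PhysicalExpr>;
            (lit(ScalarValue::UInt64(Some(i as u64))), range)
        })
        .collect();

    // Rows routed to an empty partition cannot find a match: ELSE false.
    Ok(Arc::new(CaseExpr::try_new(Some(selector), when_then, Some(lit(false)))?)
        as Arc<dyn PhysicalExpr>)
}
```

Probe rows that hash to a partition with no reported bounds fall through to the `ELSE false` arm, which is also why an entirely empty build side lets the accumulator collapse the filter to `lit(false)` and skip the probe-side scan.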