Skip to content

Commit dcf3e4a

Browse files
committed
Crude hack to introduce type coercion for hash join keys
Remove this after rebasing on top of commit ac2e5d1 "Support type coercion for equijoin (apache#4666)", which was first released in DF 16.0. ARROW-11838: fix offset buffer in golden file (#60)
1 parent 400fa0d commit dcf3e4a

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

datafusion/core/src/physical_plan/hash_join.rs

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ use crate::arrow::datatypes::TimeUnit;
7676
use crate::execution::context::TaskContext;
7777
use crate::physical_plan::coalesce_batches::concat_batches;
7878
use crate::physical_plan::PhysicalExpr;
79+
use datafusion_expr::binary_rule::coerce_types;
80+
use datafusion_expr::Operator;
81+
use datafusion_physical_expr::expressions::try_cast;
7982
use log::debug;
8083
use std::fmt;
8184

@@ -295,7 +298,32 @@ impl ExecutionPlan for HashJoinExec {
295298
partition: usize,
296299
context: Arc<TaskContext>,
297300
) -> Result<SendableRecordBatchStream> {
298-
let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
301+
// This is a hacky way to support type coercion for join expressions
302+
// Without this, it would panic later in `build_join_indexes` => `equal_rows`, which tries to downcast both sides to the same primitive type
303+
// TODO Remove this after rebasing on top of commit ac2e5d15 "Support type coercion for equijoin (#4666)", which was first released in DF 16.0
304+
305+
// TODO Rewrite it with iterators on a modern toolchain; `impl FromIterator<(AE, BE)> for (A, B)` is not available ATM
306+
let mut on_left = Vec::with_capacity(self.on.len());
307+
let mut on_right = Vec::with_capacity(self.on.len());
308+
for on in &self.on {
309+
let l = Arc::new(on.0.clone());
310+
let r = Arc::new(on.1.clone());
311+
312+
let lt = l.data_type(&self.left.schema())?;
313+
let rt = r.data_type(&self.right.schema())?;
314+
let res_type = coerce_types(&lt, &Operator::Eq, &rt)?;
315+
316+
let left_cast = try_cast(l, &self.left.schema(), res_type.clone())?;
317+
let right_cast = try_cast(r, &self.right.schema(), res_type)?;
318+
319+
on_left.push(left_cast);
320+
on_right.push(right_cast);
321+
}
322+
323+
// Make them immutable
324+
let on_left = on_left;
325+
let on_right = on_right;
326+
299327
// we only want to compute the build side once for PartitionMode::CollectLeft
300328
let left_data = {
301329
match self.mode {
@@ -414,7 +442,6 @@ impl ExecutionPlan for HashJoinExec {
414442
// over the right that uses this information to issue new batches.
415443

416444
let right_stream = self.right.execute(partition, context.clone()).await?;
417-
let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
418445

419446
let num_rows = left_data.1.num_rows();
420447
let visited_left_side = match self.join_type {
@@ -473,7 +500,7 @@ impl ExecutionPlan for HashJoinExec {
473500
/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
474501
/// assuming that the [RecordBatch] corresponds to the `index`th
475502
fn update_hash(
476-
on: &[Column],
503+
on: &[Arc<dyn PhysicalExpr>],
477504
batch: &RecordBatch,
478505
hash_map: &mut JoinHashMap,
479506
offset: usize,
@@ -512,9 +539,9 @@ struct HashJoinStream {
512539
/// Input schema
513540
schema: Arc<Schema>,
514541
/// columns from the left
515-
on_left: Vec<Column>,
542+
on_left: Vec<Arc<dyn PhysicalExpr>>,
516543
/// columns from the right used to compute the hash
517-
on_right: Vec<Column>,
544+
on_right: Vec<Arc<dyn PhysicalExpr>>,
518545
/// type of the join
519546
join_type: JoinType,
520547
/// information from the left
@@ -539,8 +566,8 @@ struct HashJoinStream {
539566
impl HashJoinStream {
540567
fn new(
541568
schema: Arc<Schema>,
542-
on_left: Vec<Column>,
543-
on_right: Vec<Column>,
569+
on_left: Vec<Arc<dyn PhysicalExpr>>,
570+
on_right: Vec<Arc<dyn PhysicalExpr>>,
544571
join_type: JoinType,
545572
left_data: JoinLeftData,
546573
right: SendableRecordBatchStream,
@@ -624,8 +651,8 @@ fn build_batch_from_indices(
624651
fn build_batch(
625652
batch: &RecordBatch,
626653
left_data: &JoinLeftData,
627-
on_left: &[Column],
628-
on_right: &[Column],
654+
on_left: &[Arc<dyn PhysicalExpr>],
655+
on_right: &[Arc<dyn PhysicalExpr>],
629656
join_type: JoinType,
630657
schema: &Schema,
631658
column_indices: &[ColumnIndex],
@@ -691,8 +718,8 @@ fn build_join_indexes(
691718
left_data: &JoinLeftData,
692719
right: &RecordBatch,
693720
join_type: JoinType,
694-
left_on: &[Column],
695-
right_on: &[Column],
721+
left_on: &[Arc<dyn PhysicalExpr>],
722+
right_on: &[Arc<dyn PhysicalExpr>],
696723
random_state: &RandomState,
697724
null_equals_null: &bool,
698725
) -> Result<(UInt64Array, UInt32Array)> {
@@ -2002,8 +2029,8 @@ mod tests {
20022029
&left_data,
20032030
&right,
20042031
JoinType::Inner,
2005-
&[Column::new("a", 0)],
2006-
&[Column::new("a", 0)],
2032+
&[Arc::new(Column::new("a", 0))],
2033+
&[Arc::new(Column::new("a", 0))],
20072034
&random_state,
20082035
&false,
20092036
)?;

0 commit comments

Comments
 (0)