@@ -76,6 +76,9 @@ use crate::arrow::datatypes::TimeUnit;
76
76
use crate :: execution:: context:: TaskContext ;
77
77
use crate :: physical_plan:: coalesce_batches:: concat_batches;
78
78
use crate :: physical_plan:: PhysicalExpr ;
79
+ use datafusion_expr:: binary_rule:: coerce_types;
80
+ use datafusion_expr:: Operator ;
81
+ use datafusion_physical_expr:: expressions:: try_cast;
79
82
use log:: debug;
80
83
use std:: fmt;
81
84
@@ -295,7 +298,32 @@ impl ExecutionPlan for HashJoinExec {
295
298
partition : usize ,
296
299
context : Arc < TaskContext > ,
297
300
) -> Result < SendableRecordBatchStream > {
298
- let on_left = self . on . iter ( ) . map ( |on| on. 0 . clone ( ) ) . collect :: < Vec < _ > > ( ) ;
301
+ // This is a hacky way to support type coercion for join expressions
302
+ // Without this it would panic later, in build_join_indexes => equal_rows, when it would try to downcast both sides to same primitive type
303
+ // TODO Remove this after rebasing on top of commit ac2e5d15 "Support type coercion for equijoin (#4666)". It was first released at DF 16.0
304
+
305
+ // TODO Rewrite it with iterators on modern toolchain, `impl FromIterator<(AE, BE)> for (A, B)` is not available ATM
306
+ let mut on_left = Vec :: with_capacity ( self . on . len ( ) ) ;
307
+ let mut on_right = Vec :: with_capacity ( self . on . len ( ) ) ;
308
+ for on in & self . on {
309
+ let l = Arc :: new ( on. 0 . clone ( ) ) ;
310
+ let r = Arc :: new ( on. 1 . clone ( ) ) ;
311
+
312
+ let lt = l. data_type ( & self . left . schema ( ) ) ?;
313
+ let rt = r. data_type ( & self . right . schema ( ) ) ?;
314
+ let res_type = coerce_types ( & lt, & Operator :: Eq , & rt) ?;
315
+
316
+ let left_cast = try_cast ( l, & self . left . schema ( ) , res_type. clone ( ) ) ?;
317
+ let right_cast = try_cast ( r, & self . right . schema ( ) , res_type) ?;
318
+
319
+ on_left. push ( left_cast) ;
320
+ on_right. push ( right_cast) ;
321
+ }
322
+
323
+ // Make them immutable
324
+ let on_left = on_left;
325
+ let on_right = on_right;
326
+
299
327
// we only want to compute the build side once for PartitionMode::CollectLeft
300
328
let left_data = {
301
329
match self . mode {
@@ -414,7 +442,6 @@ impl ExecutionPlan for HashJoinExec {
414
442
// over the right that uses this information to issue new batches.
415
443
416
444
let right_stream = self . right . execute ( partition, context. clone ( ) ) . await ?;
417
- let on_right = self . on . iter ( ) . map ( |on| on. 1 . clone ( ) ) . collect :: < Vec < _ > > ( ) ;
418
445
419
446
let num_rows = left_data. 1 . num_rows ( ) ;
420
447
let visited_left_side = match self . join_type {
@@ -473,7 +500,7 @@ impl ExecutionPlan for HashJoinExec {
473
500
/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
474
501
/// assuming that the [RecordBatch] corresponds to the `index`th
475
502
fn update_hash (
476
- on : & [ Column ] ,
503
+ on : & [ Arc < dyn PhysicalExpr > ] ,
477
504
batch : & RecordBatch ,
478
505
hash_map : & mut JoinHashMap ,
479
506
offset : usize ,
@@ -512,9 +539,9 @@ struct HashJoinStream {
512
539
/// Input schema
513
540
schema : Arc < Schema > ,
514
541
/// columns from the left
515
- on_left : Vec < Column > ,
542
+ on_left : Vec < Arc < dyn PhysicalExpr > > ,
516
543
/// columns from the right used to compute the hash
517
- on_right : Vec < Column > ,
544
+ on_right : Vec < Arc < dyn PhysicalExpr > > ,
518
545
/// type of the join
519
546
join_type : JoinType ,
520
547
/// information from the left
@@ -539,8 +566,8 @@ struct HashJoinStream {
539
566
impl HashJoinStream {
540
567
fn new (
541
568
schema : Arc < Schema > ,
542
- on_left : Vec < Column > ,
543
- on_right : Vec < Column > ,
569
+ on_left : Vec < Arc < dyn PhysicalExpr > > ,
570
+ on_right : Vec < Arc < dyn PhysicalExpr > > ,
544
571
join_type : JoinType ,
545
572
left_data : JoinLeftData ,
546
573
right : SendableRecordBatchStream ,
@@ -624,8 +651,8 @@ fn build_batch_from_indices(
624
651
fn build_batch (
625
652
batch : & RecordBatch ,
626
653
left_data : & JoinLeftData ,
627
- on_left : & [ Column ] ,
628
- on_right : & [ Column ] ,
654
+ on_left : & [ Arc < dyn PhysicalExpr > ] ,
655
+ on_right : & [ Arc < dyn PhysicalExpr > ] ,
629
656
join_type : JoinType ,
630
657
schema : & Schema ,
631
658
column_indices : & [ ColumnIndex ] ,
@@ -691,8 +718,8 @@ fn build_join_indexes(
691
718
left_data : & JoinLeftData ,
692
719
right : & RecordBatch ,
693
720
join_type : JoinType ,
694
- left_on : & [ Column ] ,
695
- right_on : & [ Column ] ,
721
+ left_on : & [ Arc < dyn PhysicalExpr > ] ,
722
+ right_on : & [ Arc < dyn PhysicalExpr > ] ,
696
723
random_state : & RandomState ,
697
724
null_equals_null : & bool ,
698
725
) -> Result < ( UInt64Array , UInt32Array ) > {
@@ -2002,8 +2029,8 @@ mod tests {
2002
2029
& left_data,
2003
2030
& right,
2004
2031
JoinType :: Inner ,
2005
- & [ Column :: new ( "a" , 0 ) ] ,
2006
- & [ Column :: new ( "a" , 0 ) ] ,
2032
+ & [ Arc :: new ( Column :: new ( "a" , 0 ) ) ] ,
2033
+ & [ Arc :: new ( Column :: new ( "a" , 0 ) ) ] ,
2007
2034
& random_state,
2008
2035
& false ,
2009
2036
) ?;
0 commit comments