Skip to content

Commit d20b6d1

Browse files
authored
fix: unparsing left/ right semi/mark join (apache#15212)
* fix: unparse semi/mark join * recursive * fix use * update * stackoverflow * update stack size * update test * fix test * format * refine * refine ci based on goldmedal's suggestion
1 parent 3dd3d73 commit d20b6d1

File tree

5 files changed

+280
-22
lines changed

5 files changed

+280
-22
lines changed

.github/workflows/extended.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ jobs:
8181
- name: Run tests (excluding doctests)
8282
env:
8383
RUST_BACKTRACE: 1
84-
run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace,extended_tests
84+
run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace,extended_tests,recursive_protection
8585
- name: Verify Working Directory Clean
8686
run: git diff --exit-code
8787
- name: Cleanup

datafusion/sql/src/unparser/ast.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
// under the License.
1717

1818
use core::fmt;
19+
use std::ops::ControlFlow;
1920

20-
use sqlparser::ast;
2121
use sqlparser::ast::helpers::attached_token::AttachedToken;
22+
use sqlparser::ast::{self, visit_expressions_mut};
2223

2324
#[derive(Clone)]
2425
pub struct QueryBuilder {
@@ -176,6 +177,37 @@ impl SelectBuilder {
176177
self.lateral_views = value;
177178
self
178179
}
180+
181+
/// Replaces the selection with a new value.
182+
///
183+
/// This function is used to replace a specific expression within the selection.
184+
/// Unlike the `selection` method which combines existing and new selections with AND,
185+
/// this method searches for and replaces occurrences of a specific expression.
186+
///
187+
/// This method is primarily used to modify LEFT MARK JOIN expressions.
188+
/// When processing a LEFT MARK JOIN, we need to replace the placeholder expression
189+
/// with the actual join condition in the selection clause.
190+
///
191+
/// # Arguments
192+
///
193+
/// * `existing_expr` - The expression to replace
194+
/// * `value` - The new expression to set as the selection
195+
pub fn replace_mark(
196+
&mut self,
197+
existing_expr: &ast::Expr,
198+
value: &ast::Expr,
199+
) -> &mut Self {
200+
if let Some(selection) = &mut self.selection {
201+
visit_expressions_mut(selection, |expr| {
202+
if expr == existing_expr {
203+
*expr = value.clone();
204+
}
205+
ControlFlow::<()>::Continue(())
206+
});
207+
}
208+
self
209+
}
210+
179211
pub fn selection(&mut self, value: Option<ast::Expr>) -> &mut Self {
180212
// With filter pushdown optimization, the LogicalPlan can have filters defined as part of `TableScan` and `Filter` nodes.
181213
// To avoid overwriting one of the filters, we combine the existing filter with the additional filter.

datafusion/sql/src/unparser/expr.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ impl Unparser<'_> {
9494
Ok(root_expr)
9595
}
9696

97+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
9798
fn expr_to_sql_inner(&self, expr: &Expr) -> Result<ast::Expr> {
9899
match expr {
99100
Expr::InList(InList {
@@ -674,7 +675,7 @@ impl Unparser<'_> {
674675
}
675676
}
676677

677-
fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
678+
pub fn col_to_sql(&self, col: &Column) -> Result<ast::Expr> {
678679
if let Some(table_ref) = &col.relation {
679680
let mut id = if self.dialect.full_qualified_col() {
680681
table_ref.to_vec()

datafusion/sql/src/unparser/plan.rs

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ impl Unparser<'_> {
322322
}
323323
}
324324

325+
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
325326
fn select_to_sql_recursively(
326327
&self,
327328
plan: &LogicalPlan,
@@ -566,14 +567,20 @@ impl Unparser<'_> {
566567
}
567568
LogicalPlan::Join(join) => {
568569
let mut table_scan_filters = vec![];
570+
let (left_plan, right_plan) = match join.join_type {
571+
JoinType::RightSemi | JoinType::RightAnti => {
572+
(&join.right, &join.left)
573+
}
574+
_ => (&join.left, &join.right),
575+
};
569576

570577
let left_plan =
571-
match try_transform_to_simple_table_scan_with_filters(&join.left)? {
578+
match try_transform_to_simple_table_scan_with_filters(left_plan)? {
572579
Some((plan, filters)) => {
573580
table_scan_filters.extend(filters);
574581
Arc::new(plan)
575582
}
576-
None => Arc::clone(&join.left),
583+
None => Arc::clone(left_plan),
577584
};
578585

579586
self.select_to_sql_recursively(
@@ -584,12 +591,12 @@ impl Unparser<'_> {
584591
)?;
585592

586593
let right_plan =
587-
match try_transform_to_simple_table_scan_with_filters(&join.right)? {
594+
match try_transform_to_simple_table_scan_with_filters(right_plan)? {
588595
Some((plan, filters)) => {
589596
table_scan_filters.extend(filters);
590597
Arc::new(plan)
591598
}
592-
None => Arc::clone(&join.right),
599+
None => Arc::clone(right_plan),
593600
};
594601

595602
let mut right_relation = RelationBuilder::default();
@@ -641,19 +648,70 @@ impl Unparser<'_> {
641648
&mut right_relation,
642649
)?;
643650

644-
let Ok(Some(relation)) = right_relation.build() else {
645-
return internal_err!("Failed to build right relation");
646-
};
647-
648-
let ast_join = ast::Join {
649-
relation,
650-
global: false,
651-
join_operator: self
652-
.join_operator_to_sql(join.join_type, join_constraint)?,
651+
match join.join_type {
652+
JoinType::LeftSemi
653+
| JoinType::LeftAnti
654+
| JoinType::LeftMark
655+
| JoinType::RightSemi
656+
| JoinType::RightAnti => {
657+
let mut query_builder = QueryBuilder::default();
658+
let mut from = TableWithJoinsBuilder::default();
659+
let mut exists_select: SelectBuilder = SelectBuilder::default();
660+
from.relation(right_relation);
661+
exists_select.push_from(from);
662+
if let Some(filter) = &join.filter {
663+
exists_select.selection(Some(self.expr_to_sql(filter)?));
664+
}
665+
for (left, right) in &join.on {
666+
exists_select.selection(Some(
667+
self.expr_to_sql(&left.clone().eq(right.clone()))?,
668+
));
669+
}
670+
exists_select.projection(vec![ast::SelectItem::UnnamedExpr(
671+
ast::Expr::Value(ast::Value::Number("1".to_string(), false)),
672+
)]);
673+
query_builder.body(Box::new(SetExpr::Select(Box::new(
674+
exists_select.build()?,
675+
))));
676+
677+
let negated = match join.join_type {
678+
JoinType::LeftSemi
679+
| JoinType::RightSemi
680+
| JoinType::LeftMark => false,
681+
JoinType::LeftAnti | JoinType::RightAnti => true,
682+
_ => unreachable!(),
683+
};
684+
let exists_expr = ast::Expr::Exists {
685+
subquery: Box::new(query_builder.build()?),
686+
negated,
687+
};
688+
if join.join_type == JoinType::LeftMark {
689+
let (table_ref, _) = right_plan.schema().qualified_field(0);
690+
let column = self
691+
.col_to_sql(&Column::new(table_ref.cloned(), "mark"))?;
692+
select.replace_mark(&column, &exists_expr);
693+
} else {
694+
select.selection(Some(exists_expr));
695+
}
696+
}
697+
JoinType::Inner
698+
| JoinType::Left
699+
| JoinType::Right
700+
| JoinType::Full => {
701+
let Ok(Some(relation)) = right_relation.build() else {
702+
return internal_err!("Failed to build right relation");
703+
};
704+
let ast_join = ast::Join {
705+
relation,
706+
global: false,
707+
join_operator: self
708+
.join_operator_to_sql(join.join_type, join_constraint)?,
709+
};
710+
let mut from = select.pop_from().unwrap();
711+
from.push_join(ast_join);
712+
select.push_from(from);
713+
}
653714
};
654-
let mut from = select.pop_from().unwrap();
655-
from.push_join(ast_join);
656-
select.push_from(from);
657715

658716
Ok(())
659717
}

0 commit comments

Comments
 (0)