diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala index 7150c81ad64ec..253e23c51e181 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression, VariableReference} -import org.apache.spark.sql.catalyst.plans.logical.{CreateView, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{CreateView, CTEInChildren, LogicalPlan, WithCTE} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.errors.QueryCompilationErrors @@ -59,8 +59,8 @@ class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch] private def apply0( plan: LogicalPlan, - referredTempVars: Option[mutable.ArrayBuffer[Seq[String]]] = None): LogicalPlan = - plan.resolveOperatorsUpWithPruning(_.containsAnyPattern( + referredTempVars: Option[mutable.ArrayBuffer[Seq[String]]] = None): LogicalPlan = { + val resolved = plan.resolveOperatorsUpWithPruning(_.containsAnyPattern( UNRESOLVED_IDENTIFIER, PLAN_WITH_UNRESOLVED_IDENTIFIER)) { case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved => @@ -82,6 +82,13 @@ class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch] IdentifierResolution.evalIdentifierExpr(e.identifierExpr), e.otherExprs) } } + // When `PlanWithUnresolvedIdentifier` materializes into a `CTEInChildren` (e.g. + // `InsertIntoStatement`) inside an outer `WithCTE`, push the CTE defs into the command's + // children - restoring the invariant from `CTESubstitution.withCTEDefs`. + resolved.resolveOperatorsUpWithPruning(_.containsPattern(CTE)) { + case WithCTE(c: CTEInChildren, cteDefs) => c.withCTEDefs(cteDefs) + } + } private def collectTemporaryVariablesInLogicalPlan(child: LogicalPlan): Seq[Seq[String]] = { def collectTempVars(child: LogicalPlan): Seq[Seq[String]] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index d6b22431e854e..e114d9b22fe86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -22,7 +22,7 @@ import java.time.{Instant, LocalDate, LocalDateTime, ZoneId} import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.logical.Limit +import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, Limit, WithCTE} import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.functions.{array, call_function, lit, map, map_from_arrays, map_from_entries, str_to_map, struct} @@ -2460,4 +2460,53 @@ class ParametersSuite extends SharedSparkSession { spark.sql("SELECT 1", Array.empty[Any]), Row(1)) } + + test("WITH ... INSERT OVERWRITE TABLE IDENTIFIER(:p) SELECT ... FROM cte") { + withTable("t_cte_overwrite") { + sql("CREATE TABLE t_cte_overwrite (a INT) USING PARQUET") + sql("INSERT INTO t_cte_overwrite VALUES (10)") + spark.sql( + """WITH transformation AS (SELECT 1 AS a) + |INSERT OVERWRITE TABLE IDENTIFIER(:tname) + |SELECT * FROM transformation""".stripMargin, + Map("tname" -> "t_cte_overwrite")) + checkAnswer(spark.table("t_cte_overwrite"), Row(1)) + } + } + + test("WITH ... INSERT INTO IDENTIFIER(:p) SELECT ... FROM cte") { + withTable("t_cte_into") { + sql("CREATE TABLE t_cte_into (a INT) USING PARQUET") + spark.sql( + """WITH transformation AS (SELECT 7 AS a) + |INSERT INTO IDENTIFIER(:tname) + |SELECT * FROM transformation""".stripMargin, + Map("tname" -> "t_cte_into")) + checkAnswer(spark.table("t_cte_into"), Row(7)) + } + } + + test("Analyzed plan does not leave WithCTE wrapping a CTEInChildren " + + "when IDENTIFIER(:p) is the INSERT target") { + // After analysis, the WithCTE must be pushed into the InsertIntoStatement's query child + // (CTEInChildren placement), not left wrapping the command. The wrapped shape produces an + // orphan CTERelationRef whose CTERelationDef is in a now-detached WithCTE; any downstream + // pass that re-analyses the subtree below the command then trips InlineCTE.buildCTEMap with + // NoSuchElementException: key not found. + withTable("t_cte_shape") { + sql("CREATE TABLE t_cte_shape (a INT) USING PARQUET") + val df = spark.sql( + """WITH transformation AS (SELECT 1 AS a) + |INSERT INTO IDENTIFIER(:tname) + |SELECT * FROM transformation""".stripMargin, + Map("tname" -> "t_cte_shape")) + val analyzed = df.queryExecution.analyzed + analyzed match { + case WithCTE(_: CTEInChildren, _) => + fail(s"WithCTE must be pushed into the CTEInChildren's children, not left " + + s"wrapping the command. Analyzed plan:\n$analyzed") + case _ => // expected + } + } + } }