SPARKC-693 Support for Spark 3.3 #1351

Merged
merged 3 commits on Jan 2, 2023
Changes from 1 commit
@@ -247,7 +247,7 @@ trait SparkCassandraITSpecBase

def getCassandraScan(plan: SparkPlan): CassandraScan = {
plan.collectLeaves.collectFirst{
- case BatchScanExec(_, cassandraScan: CassandraScan, _) => cassandraScan
+ case BatchScanExec(_, cassandraScan: CassandraScan, _, _) => cassandraScan
}.getOrElse(throw new IllegalArgumentException("No Cassandra Scan Found"))
}

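Note: the extra underscore in these patterns is needed because Spark 3.3 adds a fourth constructor parameter to BatchScanExec (the three matched by the old pattern being the output attributes, the scan and the runtime filters). A small extractor that matches only on the scan field would insulate call sites from such arity changes; a minimal sketch, not part of this PR, with assumed import paths:

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import com.datastax.spark.connector.datasource.CassandraScan

// Hypothetical helper: extract the CassandraScan from a BatchScanExec without
// depending on how many constructor parameters BatchScanExec has in a given Spark version.
object CassandraBatchScan {
  def unapply(plan: SparkPlan): Option[CassandraScan] = plan match {
    case batch: BatchScanExec =>
      batch.scan match {
        case cs: CassandraScan => Some(cs)
        case _ => None
      }
    case _ => None
  }
}

// The patterns above could then read:
//   plan.collectLeaves.collectFirst { case CassandraBatchScan(scan) => scan }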
@@ -47,7 +47,7 @@ trait SaiBaseSpec extends Matchers with SparkCassandraITSpecBase {

def findCassandraScan(plan: SparkPlan): CassandraScan = {
plan match {
- case BatchScanExec(_, scan: CassandraScan, _) => scan
+ case BatchScanExec(_, scan: CassandraScan, _, _) => scan
case filter: FilterExec => findCassandraScan(filter.child)
case project: ProjectExec => findCassandraScan(project.child)
case _ => throw new NoSuchElementException("RowDataSourceScanExec was not found in the given plan")
@@ -274,10 +274,10 @@ class CassandraDataSourceSpec extends SparkCassandraITFlatSpecBase with DefaultC
if (pushDown)
withClue(s"Given Dataframe plan does not contain CassandraInJoin in its predecessors.\n${df.queryExecution.sparkPlan.toString()}") {
df.queryExecution.executedPlan.collectLeaves().collectFirst{
- case a@BatchScanExec(_, _: CassandraInJoin, _) => a
+ case a@BatchScanExec(_, _: CassandraInJoin, _, _) => a
case b@AdaptiveSparkPlanExec(_, _, _, _, _) =>
b.executedPlan.collectLeaves().collectFirst{
- case a@BatchScanExec(_, _: CassandraInJoin, _) => a
+ case a@BatchScanExec(_, _: CassandraInJoin, _, _) => a
}
} shouldBe defined
}
@@ -288,7 +288,7 @@ class CassandraDataSourceSpec extends SparkCassandraITFlatSpecBase with DefaultC
private def assertOnAbsenceOfCassandraInJoin(df: DataFrame): Unit =
withClue(s"Given Dataframe plan contains CassandraInJoin in its predecessors.\n${df.queryExecution.sparkPlan.toString()}") {
df.queryExecution.executedPlan.collectLeaves().collectFirst{
- case a@BatchScanExec(_, _: CassandraInJoin, _) => a
+ case a@BatchScanExec(_, _: CassandraInJoin, _, _) => a
} shouldBe empty
}

@@ -7,6 +7,6 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
object CatalystUtil {

def findCassandraScan(sparkPlan: SparkPlan): Option[CassandraScan] = {
- sparkPlan.collectFirst{ case BatchScanExec(_, scan: CassandraScan, _) => scan}
+ sparkPlan.collectFirst{ case BatchScanExec(_, scan: CassandraScan, _, _) => scan}
}
}
@@ -190,7 +190,7 @@ class CassandraCatalog extends CatalogPlugin
.asJava
}

- override def dropNamespace(namespace: Array[String]): Boolean = {
+ override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = {
checkNamespace(namespace)
val keyspace = getKeyspaceMeta(connector, namespace)
val dropResult = connector.withSessionDo(session => session.execute(SchemaBuilder.dropKeyspace(keyspace.getName).asCql()))
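The new cascade parameter reflects a Spark 3.3 change to the catalog API: SupportsNamespaces.dropNamespace now receives a flag saying whether a non-empty namespace may be dropped (SQL CASCADE vs the default RESTRICT). A minimal usage sketch, assuming a catalog registered under the name cass and a reachable Cassandra cluster; the names here are illustrative, not taken from this PR:

import org.apache.spark.sql.SparkSession

object DropNamespaceCascadeExample {
  def main(args: Array[String]): Unit = {
    // Assumes spark.sql.catalog.cass=com.datastax.spark.connector.datasource.CassandraCatalog
    // and that Cassandra connection settings are configured for the session.
    val spark = SparkSession.builder()
      .appName("drop-namespace-cascade")
      .getOrCreate()

    // On Spark 3.3 the CASCADE keyword is forwarded to the catalog as
    // dropNamespace(namespace, cascade = true); without CASCADE the flag is false.
    spark.sql("DROP NAMESPACE IF EXISTS cass.test_keyspace CASCADE")

    spark.stop()
  }
}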
@@ -13,8 +13,9 @@ import com.datastax.spark.connector.{ColumnRef, RowCountRef, TTL, WriteTime}
import org.apache.spark.SparkConf
import org.apache.spark.sql.cassandra.CassandraSourceRelation.{AdditionalCassandraPushDownRulesParam, InClauseToJoinWithTableConversionThreshold}
import org.apache.spark.sql.cassandra.{AnalyzedPredicates, Auto, BasicCassandraPredicatePushDown, CassandraPredicateRules, CassandraSourceRelation, DsePredicateRules, DseSearchOptimizationSetting, InClausePredicateRules, Off, On, SolrConstants, SolrPredicateRules, TimeUUIDPredicateRules}
+ import org.apache.spark.sql.connector.expressions.{Expression, Expressions}
import org.apache.spark.sql.connector.read._
- import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning}
+ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning}
import org.apache.spark.sql.sources.{EqualTo, Filter, In}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -307,7 +308,7 @@ case class CassandraScan(
}

override def outputPartitioning(): Partitioning = {
- CassandraPartitioning(tableDef.partitionKey.map(_.columnName).toArray, inputPartitions.length)
+ new CassandraPartitioning(tableDef.partitionKey.map(_.columnName).map(Expressions.identity).toArray, inputPartitions.length)
}

override def description(): String = {
@@ -317,17 +318,7 @@
}
}

- case class CassandraPartitioning(partitionKeys: Array[String], numPartitions: Int) extends Partitioning {
-
- /*
- Currently we only satisfy distributions which rely on all partition key values having identical
- values. In the future we may be able to support some other distributions but Spark doesn't have
- means to support those atm 3.0
- */
- override def satisfy(distribution: Distribution): Boolean = distribution match {
- case cD: ClusteredDistribution => partitionKeys.forall(cD.clusteredColumns.contains)
- case _ => false
- }
+ class CassandraPartitioning(keys: Array[Expression], numPartitions: Int) extends KeyGroupedPartitioning(keys, numPartitions) {
}

case class CassandraInJoin(
@@ -359,7 +350,7 @@ case class CassandraInJoin(
}

override def outputPartitioning(): Partitioning = {
- CassandraPartitioning(tableDef.partitionKey.map(_.columnName).toArray, numPartitions)
+ new CassandraPartitioning(tableDef.partitionKey.map(_.columnName).map(Expressions.identity).toArray, numPartitions)
}
}

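For context: Spark 3.3 drops the Distribution/ClusteredDistribution based Partitioning.satisfy API that the old CassandraPartitioning implemented, and data sources now report their output layout through KeyGroupedPartitioning, whose keys are connector Expressions rather than plain column names (hence the Expressions.identity wrapping above). A standalone sketch of the same mapping, outside the connector:

import org.apache.spark.sql.connector.expressions.{Expression, Expressions}
import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning}

// Sketch: build a key-grouped output partitioning for a source whose rows are
// grouped by the given partition key columns across numPartitions input partitions.
object PartitioningSketch {
  def keyGroupedBy(partitionKeyColumns: Seq[String], numPartitions: Int): Partitioning = {
    val keys: Array[Expression] = partitionKeyColumns.map(Expressions.identity).toArray
    new KeyGroupedPartitioning(keys, numPartitions)
  }
}

// e.g. PartitioningSketch.keyGroupedBy(Seq("user_id", "bucket"), 12)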
@@ -211,7 +211,7 @@ object CassandraSourceRelation extends Logging {
oldPlan.transform {
case ds@DataSourceV2Relation(_: CassandraTable, _, _, _, options) =>
ds.copy(options = applyDirectJoinSetting(options, directJoinSetting))
- case ds@DataSourceV2ScanRelation(_: CassandraTable, scan: CassandraScan, _) =>
+ case ds@DataSourceV2ScanRelation(_: CassandraTable, scan: CassandraScan, _, _) =>
ds.copy(scan = scan.copy(consolidatedConf = applyDirectJoinSetting(scan.consolidatedConf, directJoinSetting)))
}
)
@@ -25,7 +25,7 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
val conf = spark.sqlContext.conf

override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, _, left, right, _)
if hasValidDirectJoin(joinType, leftKeys, rightKeys, condition, left, right) =>

val (otherBranch, joinTargetBranch, buildType) = {
@@ -46,7 +46,7 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
val cassandraScanExec = getScanExec(dataSourceOptimizedPlan).get

joinTargetBranch match {
- case PhysicalOperation(attributes, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _)) =>
+ case PhysicalOperation(attributes, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _, _)) =>
val directJoin =
CassandraDirectJoinExec(
leftKeys,
@@ -147,7 +147,7 @@ object CassandraDirectJoinStrategy extends Logging {
*/
def getScanExec(plan: SparkPlan): Option[BatchScanExec] = {
plan.collectFirst {
- case exec @ BatchScanExec(_, _: CassandraScan, _) => exec
+ case exec @ BatchScanExec(_, _: CassandraScan, _, _) => exec
}
}

@@ -170,7 +170,7 @@ object CassandraDirectJoinStrategy extends Logging {
def getDSV2CassandraRelation(plan: LogicalPlan): Option[DataSourceV2ScanRelation] = {
val children = plan.collectLeaves()
if (children.length == 1) {
- plan.collectLeaves().collectFirst { case ds @ DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _) => ds }
+ plan.collectLeaves().collectFirst { case ds @ DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _, _) => ds }
} else {
None
}
@@ -183,7 +183,7 @@ object CassandraDirectJoinStrategy extends Logging {
def getCassandraTable(plan: LogicalPlan): Option[CassandraTable] = {
val children = plan.collectLeaves()
if (children.length == 1) {
- children.collectFirst { case DataSourceV2ScanRelation(DataSourceV2Relation(table: CassandraTable, _, _, _, _), _, _) => table }
+ children.collectFirst { case DataSourceV2ScanRelation(DataSourceV2Relation(table: CassandraTable, _, _, _, _), _, _, _) => table }
} else {
None
}
@@ -192,7 +192,7 @@ object CassandraDirectJoinStrategy extends Logging {
def getCassandraScan(plan: LogicalPlan): Option[CassandraScan] = {
val children = plan.collectLeaves()
if (children.length == 1) {
- plan.collectLeaves().collectFirst { case DataSourceV2ScanRelation(_: DataSourceV2Relation, cs: CassandraScan, _) => cs }
+ plan.collectLeaves().collectFirst { case DataSourceV2ScanRelation(_: DataSourceV2Relation, cs: CassandraScan, _, _) => cs }
} else {
None
}
@@ -204,8 +204,8 @@ object CassandraDirectJoinStrategy extends Logging {
*/
def hasCassandraChild[T <: QueryPlan[T]](plan: T): Boolean = {
plan.children.size == 1 && plan.children.exists {
- case DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _) => true
- case BatchScanExec(_, _: CassandraScan, _) => true
+ case DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), _, _, _) => true
+ case BatchScanExec(_, _: CassandraScan, _, _) => true
case _ => false
}
}
@@ -235,7 +235,7 @@ object CassandraDirectJoinStrategy extends Logging {
def reorderPlan(plan: SparkPlan, directJoin: CassandraDirectJoinExec): SparkPlan = {
val reordered = plan match {
//This may be the only node in the Plan
- case BatchScanExec(_, _: CassandraScan, _) => directJoin
+ case BatchScanExec(_, _: CassandraScan, _, _) => directJoin
// Plan has children
case normalPlan => normalPlan.transform {
case penultimate if hasCassandraChild(penultimate) =>
@@ -292,7 +292,7 @@ object CassandraDirectJoinStrategy extends Logging {
plan match {
case PhysicalOperation(
attributes, _,
- DataSourceV2ScanRelation(DataSourceV2Relation(cassandraTable: CassandraTable, _, _, _, _), _, _)) =>
+ DataSourceV2ScanRelation(DataSourceV2Relation(cassandraTable: CassandraTable, _, _, _, _), _, _, _)) =>

val joinKeysExprId = joinKeys.collect{ case attributeReference: AttributeReference => attributeReference.exprId }

@@ -324,7 +324,7 @@ object CassandraDirectJoinStrategy extends Logging {
*/
def containsSafePlans(plan: LogicalPlan): Boolean = {
plan match {
- case PhysicalOperation(_, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), scan: CassandraScan, _))
+ case PhysicalOperation(_, _, DataSourceV2ScanRelation(DataSourceV2Relation(_: CassandraTable, _, _, _, _), scan: CassandraScan, _, _))
if getDirectJoinSetting(scan.consolidatedConf) != AlwaysOff => true
case _ => false
}
project/Versions.scala (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@ object Versions {
// and install in a local Maven repository. This is all done automatically, however it will work
// only on Unix/OSX operating system. Windows users have to build and install Spark manually if the
// desired version is not yet published into a public Maven repository.
val ApacheSpark = "3.2.1"
val ApacheSpark = "3.3.1"
val SparkJetty = "9.3.27.v20190418"
val SolrJ = "8.3.0"

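For downstream builds, picking up a connector built against Spark 3.3 is a dependency bump; a build.sbt sketch (the connector version shown is an assumption, check the released artifacts for the exact number):

// build.sbt sketch; versions are illustrative, not taken from this PR
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql" % "3.3.1" % "provided",
  "com.datastax.spark" %% "spark-cassandra-connector" % "3.3.0"
)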