diff --git a/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogMetricsSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogMetricsSpec.scala new file mode 100644 index 000000000..63ceb40fc --- /dev/null +++ b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogMetricsSpec.scala @@ -0,0 +1,79 @@ +/* + * Copyright DataStax, Inc. + * + * Please see the included license file for details. + */ + +package com.datastax.spark.connector.datasource + +import scala.collection.mutable +import com.datastax.spark.connector._ +import com.datastax.spark.connector.cluster.DefaultCluster +import com.datastax.spark.connector.cql.CassandraConnector +import org.scalatest.BeforeAndAfterEach +import com.datastax.spark.connector.datasource.CassandraCatalog +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import com.datastax.spark.connector.cql.CassandraConnector +import org.apache.spark.sql.SparkSession + + +class CassandraCatalogMetricsSpec extends SparkCassandraITFlatSpecBase with DefaultCluster with BeforeAndAfterEach { + + override lazy val conn = CassandraConnector(defaultConf) + + override lazy val spark = SparkSession.builder() + .config(sparkConf + // Enable Codahale/Dropwizard metrics + .set("spark.metrics.conf.executor.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource") + .set("spark.metrics.conf.driver.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource") + .set("spark.sql.sources.useV1SourceList", "") + .set("spark.sql.defaultCatalog", "cassandra") + .set("spark.sql.catalog.cassandra", classOf[CassandraCatalog].getCanonicalName) + ) + .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession() + + override def beforeClass { + conn.withSessionDo { session => + session.execute(s"CREATE KEYSPACE IF NOT EXISTS $ks WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }") + session.execute(s"CREATE TABLE IF NOT EXISTS $ks.leftjoin (key INT, x INT, PRIMARY KEY (key))") + for (i <- 1 to 1000 * 10) { + session.execute(s"INSERT INTO $ks.leftjoin (key, x) values ($i, $i)") + } + } + } + + var readRowCount: Long = 0 + var readByteCount: Long = 0 + + it should "update Codahale read metrics for SELECT queries" in { + val df = spark.sql(s"SELECT x FROM $ks.leftjoin LIMIT 2") + val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter => + val tc = org.apache.spark.TaskContext.get() + val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc) + Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount)) + } + + val metrics = metricsRDD.collect() + readRowCount = metrics.map(_._1).sum - readRowCount + readByteCount = metrics.map(_._2).sum - readByteCount + + assert(readRowCount > 0) + assert(readByteCount == readRowCount * 4) // 4 bytes per INT result + } + + it should "update Codahale read metrics for COUNT queries" in { + val df = spark.sql(s"SELECT COUNT(*) FROM $ks.leftjoin") + val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter => + val tc = org.apache.spark.TaskContext.get() + val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc) + Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount)) + } + + val metrics = metricsRDD.collect() + readRowCount = metrics.map(_._1).sum - readRowCount + readByteCount = metrics.map(_._2).sum - readByteCount + + assert(readRowCount > 0) + assert(readByteCount == readRowCount * 8) // 8 bytes per COUNT result + } +} diff --git a/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala index cd154a703..8fc31ca0f 100644 --- a/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala +++ b/connector/src/it/scala/com/datastax/spark/connector/datasource/CassandraCatalogTableReadSpec.scala @@ -60,17 +60,17 @@ class CassandraCatalogTableReadSpec extends CassandraCatalogSpecBase { it should "handle count pushdowns" in { setupBasicTable() val request = spark.sql(s"""SELECT COUNT(*) from $defaultKs.$testTable""") - val reader = request + var factory = request .queryExecution .executedPlan .collectFirst { - case batchScanExec: BatchScanExec=> batchScanExec.readerFactory.createReader(EmptyInputPartition) - case adaptiveSparkPlanExec: AdaptiveSparkPlanExec => adaptiveSparkPlanExec.executedPlan.collectLeaves().collectFirst{ - case batchScanExec: BatchScanExec=> batchScanExec.readerFactory.createReader(EmptyInputPartition) - }.get + case batchScanExec: BatchScanExec=> batchScanExec.readerFactory + case adaptiveSparkPlanExec: AdaptiveSparkPlanExec => adaptiveSparkPlanExec.executedPlan.collectLeaves().collectFirst{ + case batchScanExec: BatchScanExec=> batchScanExec.readerFactory + }.get } - reader.get.isInstanceOf[CassandraCountPartitionReader] should be (true) + factory.get.asInstanceOf[CassandraScanPartitionReaderFactory].isCountQuery should be (true) request.collect()(0).get(0) should be (101) } diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala index 8af63070c..72b9c3a67 100644 --- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala +++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraInJoinReaderFactory.scala @@ -12,6 +12,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.sources.In import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.metrics.InputMetricsUpdater +import org.apache.spark.TaskContext import scala.util.{Failure, Success} @@ -62,6 +64,7 @@ abstract class CassandraBaseInJoinReader( protected val maybeRateLimit = JoinHelper.maybeRateLimit(readConf) protected val requestsPerSecondRateLimiter = JoinHelper.requestsPerSecondRateLimiter(readConf) + protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf) protected def pairWithRight(left: CassandraRow): SettableFuture[Iterator[(CassandraRow, InternalRow)]] = { val resultFuture = SettableFuture.create[Iterator[(CassandraRow, InternalRow)]] val leftSide = Iterator.continually(left) @@ -69,9 +72,10 @@ abstract class CassandraBaseInJoinReader( queryExecutor.executeAsync(bsb.bind(left).executeAs(readConf.executeAs)).onComplete { case Success(rs) => val resultSet = new PrefetchingResultSetIterator(rs) + val iteratorWithMetrics = resultSet.map(metricsUpdater.updateMetrics) /* This is a much less than ideal place to actually rate limit, we are buffering these futures this means we will most likely exceed our threshold*/ - val throttledIterator = resultSet.map(maybeRateLimit) + val throttledIterator = iteratorWithMetrics.map(maybeRateLimit) val rightSide = throttledIterator.map(rowReader.read(_, rowMetadata)) resultFuture.set(leftSide.zip(rightSide)) case Failure(throwable) => @@ -103,6 +107,7 @@ abstract class CassandraBaseInJoinReader( override def get(): InternalRow = currentRow override def close(): Unit = { + metricsUpdater.finish() session.close() } } diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala index 6bea2c2f4..f44ce4011 100644 --- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala +++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraScanPartitionReaderFactory.scala @@ -12,6 +12,8 @@ import com.datastax.spark.connector.util.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.metrics.InputMetricsUpdater +import org.apache.spark.TaskContext case class CassandraScanPartitionReaderFactory( connector: CassandraConnector, @@ -20,10 +22,12 @@ case class CassandraScanPartitionReaderFactory( readConf: ReadConf, queryParts: CqlQueryParts) extends PartitionReaderFactory { + def isCountQuery: Boolean = queryParts.selectedColumnRefs.contains(RowCountRef) + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val cassandraPartition = partition.asInstanceOf[CassandraPartition[Any, _ <: Token[Any]]] - if (queryParts.selectedColumnRefs.contains(RowCountRef)) { + if (isCountQuery) { //Count Pushdown CassandraCountPartitionReader( connector, @@ -61,6 +65,8 @@ abstract class CassandraPartitionReaderBase protected val rowIterator = getIterator() protected var lastRow: InternalRow = InternalRow() + protected val metricsUpdater = InputMetricsUpdater(TaskContext.get(), readConf) + override def next(): Boolean = { if (rowIterator.hasNext) { lastRow = rowIterator.next() @@ -73,6 +79,7 @@ abstract class CassandraPartitionReaderBase override def get(): InternalRow = lastRow override def close(): Unit = { + metricsUpdater.finish() scanner.close() } @@ -107,7 +114,8 @@ abstract class CassandraPartitionReaderBase tokenRanges.iterator.flatMap { range => val scanResult = ScanHelper.fetchTokenRange(scanner, tableDef, queryParts, range, readConf.consistencyLevel, readConf.fetchSizeInRows) val meta = scanResult.metadata - scanResult.rows.map(rowReader.read(_, meta)) + val iteratorWithMetrics = scanResult.rows.map(metricsUpdater.updateMetrics) + iteratorWithMetrics.map(rowReader.read(_, meta)) } } diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala index e1bb329a9..bbbc465f1 100644 --- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala +++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CasssandraDriverDataWriterFactory.scala @@ -7,6 +7,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory, WriterCommitMessage} import org.apache.spark.sql.types.StructType +import org.apache.spark.metrics.OutputMetricsUpdater +import org.apache.spark.TaskContext case class CassandraDriverDataWriterFactory( connector: CassandraConnector, @@ -36,22 +38,31 @@ case class CassandraDriverDataWriter( private val columns = SomeColumns(inputSchema.fieldNames.map(name => ColumnName(name)): _*) - private val writer = + private val metricsUpdater = OutputMetricsUpdater(TaskContext.get(), writeConf) + + private val asycWriter = TableWriter(connector, tableDef, columns, writeConf, false)(unsafeRowWriterFactory) .getAsyncWriter() + private val writer = asycWriter.copy( + successHandler = Some(metricsUpdater.batchFinished(success = true, _, _, _)), + failureHandler = Some(metricsUpdater.batchFinished(success = false, _, _, _))) + override def write(record: InternalRow): Unit = writer.write(record) override def commit(): WriterCommitMessage = { + metricsUpdater.finish() writer.close() CassandraCommitMessage() } override def abort(): Unit = { + metricsUpdater.finish() writer.close() } override def close(): Unit = { + metricsUpdater.finish() //Our proxy Session Handler handles double closes by ignoring them so this is fine writer.close() }