
Commit 8dca5fc

[SPARK-53656][SQL] Refactor MemoryStream to use SparkSession instead of SQLContext

1 parent: 71c67b0
22 files changed: +106, -86 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala

Lines changed: 10 additions & 10 deletions

@@ -24,7 +24,7 @@ import javax.annotation.concurrent.GuardedBy
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Encoder, SQLContext}
+import org.apache.spark.sql.{Encoder, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -47,28 +47,28 @@ object MemoryStream {
   protected val currentBlockId = new AtomicInteger(0)
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
-    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+  def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
 
-  def apply[A : Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
-    new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, Some(numPartitions))
+  def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] =
+    new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions))
 }
 
 /**
  * A base class for memory stream implementations. Supports adding data and resetting.
  */
-abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream {
+abstract class MemoryStreamBase[A : Encoder](sparkSession: SparkSession) extends SparkDataStream {
   val encoder = encoderFor[A]
   protected val attributes = toAttributes(encoder.schema)
 
   protected lazy val toRow: ExpressionEncoder.Serializer[A] = encoder.createSerializer()
 
   def toDS(): Dataset[A] = {
-    Dataset[A](sqlContext.sparkSession, logicalPlan)
+    Dataset[A](sparkSession, logicalPlan)
   }
 
   def toDF(): DataFrame = {
-    Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
+    Dataset.ofRows(sparkSession, logicalPlan)
   }
 
   def addData(data: A*): OffsetV2 = {
@@ -156,9 +156,9 @@ class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder w
  */
 case class MemoryStream[A : Encoder](
     id: Int,
-    sqlContext: SQLContext,
+    sparkSession: SparkSession,
     numPartitions: Option[Int] = None)
-  extends MemoryStreamBase[A](sqlContext)
+  extends MemoryStreamBase[A](sparkSession)
   with MicroBatchStream
   with SupportsTriggerAvailableNow
   with Logging {
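
For context, a minimal usage sketch of what callers look like after this change. This is not part of the commit: the local-session setup is hypothetical, and the import path assumes MemoryStream is still exposed from org.apache.spark.sql.execution.streaming even though the file now lives under runtime/.

// Minimal sketch, assuming the package shown in the import and a plain local session.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

object MemoryStreamUsageSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session, only for illustration.
    implicit val spark: SparkSession = SparkSession.builder()
      .master("local[2]")
      .appName("memory-stream-sketch")
      .getOrCreate()
    import spark.implicits._

    // MemoryStream.apply now resolves an implicit SparkSession
    // instead of an implicit SQLContext.
    val input = MemoryStream[Int]
    input.addData(1, 2, 3)
    val streamingDf = input.toDF()
    streamingDf.printSchema()

    spark.stop()
  }
}

Under the old signature the same code needed an implicit SQLContext (typically implicit val sqlContext = spark.sqlContext) in scope; binding the session itself is the only change callers see.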

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala

Lines changed: 10 additions & 7 deletions

@@ -27,7 +27,7 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.sql.{Encoder, SQLContext}
+import org.apache.spark.sql.{Encoder, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.read.InputPartition
@@ -44,8 +44,11 @@ import org.apache.spark.util.RpcUtils
  * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at
 * the specified offset within the list, or null if that offset doesn't yet have a record.
 */
-class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
-  extends MemoryStreamBase[A](sqlContext) with ContinuousStream {
+class ContinuousMemoryStream[A : Encoder](
+    id: Int,
+    sparkSession: SparkSession,
+    numPartitions: Int = 2)
+  extends MemoryStreamBase[A](sparkSession) with ContinuousStream {
 
   private implicit val formats: Formats = Serialization.formats(NoTypeHints)
 
@@ -112,11 +115,11 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
 object ContinuousMemoryStream {
   protected val memoryStreamId = new AtomicInteger(0)
 
-  def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
-    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+  def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
 
-  def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
-    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
+  def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
 }
 
 /**
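
The continuous variant follows the same pattern. A rough sketch, not taken from this commit: the import path and the console-sink/continuous-trigger wiring are assumptions for illustration only.

// Rough sketch, assuming the package org.apache.spark.sql.execution.streaming.sources
// and that the console sink is acceptable for a continuous-trigger demo.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.streaming.Trigger

object ContinuousMemoryStreamSketch {
  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession = SparkSession.builder()
      .master("local[2]")
      .appName("continuous-memory-stream-sketch")
      .getOrCreate()
    import spark.implicits._

    // singlePartition now takes the implicit SparkSession as well.
    val input = ContinuousMemoryStream.singlePartition[Int]
    input.addData(1, 2, 3)

    val query = input.toDF()
      .writeStream
      .format("console")
      .trigger(Trigger.Continuous("1 second"))
      .start()

    query.awaitTermination(5000L)
    query.stop()
    spark.stop()
  }
}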

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -1012,7 +1012,7 @@ class DatasetSuite extends QueryTest
     assert(err.getMessage.contains("An Observation can be used with a Dataset only once"))
 
     // streaming datasets are not supported
-    val streamDf = new MemoryStream[Int](0, sqlContext).toDF()
+    val streamDf = new MemoryStream[Int](0, spark).toDF()
     val streamObservation = Observation("stream")
     val streamErr = intercept[IllegalArgumentException] {
       streamDf.observe(streamObservation, avg($"value").cast("int").as("avg_val"))

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala

Lines changed: 22 additions & 22 deletions

@@ -67,9 +67,9 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
 
 class MemoryStreamCapture[A: Encoder](
     id: Int,
-    sqlContext: SQLContext,
+    sparkSession: SparkSession,
     numPartitions: Option[Int] = None)
-  extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+  extends MemoryStream[A](id, sparkSession, numPartitions = numPartitions) {
 
   val commits = new ListBuffer[streaming.Offset]()
   val commitThreads = new ListBuffer[Thread]()
@@ -136,7 +136,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   test("async WAL commits recovery") {
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
 
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDF()
 
     var index = 0
@@ -204,7 +204,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("async WAL commits turn on and off") {
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDS()
 
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -308,7 +308,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("Fail with once trigger") {
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDF()
 
     val e = intercept[IllegalArgumentException] {
@@ -323,7 +323,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
 
   test("Fail with available now trigger") {
 
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDF()
 
     val e = intercept[IllegalArgumentException] {
@@ -339,7 +339,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   test("switching between async wal commit enabled and trigger once") {
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
 
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDF()
 
     var index = 0
@@ -500,7 +500,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   test("switching between async wal commit enabled and available now") {
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
 
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDF()
 
     var index = 0
@@ -669,7 +669,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   def testAsyncWriteErrorsAlreadyExists(path: String): Unit = {
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDS()
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
 
@@ -720,7 +720,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   def testAsyncWriteErrorsPermissionsIssue(path: String): Unit = {
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDS()
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     val commitDir = new File(checkpointLocation + path)
@@ -778,7 +778,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
 
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
 
-    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStreamCapture[Int](id = 0, spark)
 
     val ds = inputData.toDF()
 
@@ -852,7 +852,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("interval commits and recovery") {
-    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStreamCapture[Int](id = 0, spark)
     val ds = inputData.toDS()
 
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -934,7 +934,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("recovery when first offset is not zero and not commit log entries") {
-    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStreamCapture[Int](id = 0, spark)
     val ds = inputData.toDS()
 
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -961,7 +961,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
     /**
     * start new stream
     */
-    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, spark)
     val ds2 = inputData2.toDS()
     testStream(ds2, extraOptions = Map(
       ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
@@ -995,7 +995,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("recovery non-contiguous log") {
-    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStreamCapture[Int](id = 0, spark)
     val ds = inputData.toDS()
 
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -1088,7 +1088,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("Fail on pipelines using unsupported sinks") {
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStream[Int](id = 0, spark)
     val ds = inputData.toDF()
 
     val e = intercept[IllegalArgumentException] {
@@ -1109,7 +1109,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
 
     withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") {
       withTempDir { checkpointLocation =>
-        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val inputData = new MemoryStreamCapture[Int](id = 0, spark)
         val ds = inputData.toDS()
 
         val clock = new StreamManualClock
@@ -1243,7 +1243,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   test("with async log purging") {
     withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
       withTempDir { checkpointLocation =>
-        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val inputData = new MemoryStreamCapture[Int](id = 0, spark)
         val ds = inputData.toDS()
 
         val clock = new StreamManualClock
@@ -1381,7 +1381,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("test multiple gaps in offset and commit logs") {
-    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStreamCapture[Int](id = 0, spark)
     val ds = inputData.toDS()
 
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -1427,7 +1427,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
     /**
     * start new stream
     */
-    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, spark)
     val ds2 = inputData2.toDS()
     testStream(ds2, extraOptions = Map(
       ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
@@ -1460,7 +1460,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   }
 
   test("recovery when gaps exist in offset and commit log") {
-    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData = new MemoryStreamCapture[Int](id = 0, spark)
     val ds = inputData.toDS()
 
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -1494,7 +1494,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
     /**
     * start new stream
     */
-    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, spark)
     val ds2 = inputData2.toDS()
     testStream(ds2, extraOptions = Map(
      ASYNC_PROGRESS_TRACKING_ENABLED -> "true",

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -54,7 +54,7 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Match
   test("async log purging") {
     withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
       withTempDir { checkpointLocation =>
-        val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+        val inputData = new MemoryStream[Int](id = 0, spark)
         val ds = inputData.toDS()
         testStream(ds)(
           StartStream(checkpointLocation = checkpointLocation.getCanonicalPath),
@@ -99,7 +99,7 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Match
   test("error notifier test") {
     withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
      withTempDir { checkpointLocation =>
-        val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+        val inputData = new MemoryStream[Int](id = 0, spark)
         val ds = inputData.toDS()
         val e = intercept[StreamingQueryException] {

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala

Lines changed: 6 additions & 7 deletions

@@ -123,11 +123,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
   test("query stop deactivates related store providers") {
     var coordRef: StateStoreCoordinatorRef = null
     try {
-      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      implicit val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate()
       SparkSession.setActiveSession(spark)
       import spark.implicits._
       coordRef = spark.streams.stateStoreCoordinator
-      implicit val sqlContext = spark.sqlContext
       spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")
 
       // Start a query and run a batch to load state stores
@@ -254,7 +253,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
       case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sqlContext = spark.sqlContext
+        implicit val sparkSession: SparkSession = spark
         val inputData = MemoryStream[Int]
         val query = setUpStatefulQuery(inputData, "query")
         // Add, commit, and wait multiple times to force snapshot versions and time difference
@@ -290,7 +289,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
       case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sqlContext = spark.sqlContext
+        implicit val sparkSession: SparkSession = spark
         // Start a join query and run some data to force snapshot uploads
         val input1 = MemoryStream[Int]
         val input2 = MemoryStream[Int]
@@ -333,7 +332,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
       case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sqlContext = spark.sqlContext
+        implicit val sparkSession: SparkSession = spark
         // Start and run two queries together with some data to force snapshot uploads
         val input1 = MemoryStream[Int]
         val input2 = MemoryStream[Int]
@@ -400,7 +399,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
       case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sqlContext = spark.sqlContext
+        implicit val sparkSession: SparkSession = spark
         // Start a query and run some data to force snapshot uploads
         val inputData = MemoryStream[Int]
         val query = setUpStatefulQuery(inputData, "query")
@@ -444,7 +443,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     ) {
      case (coordRef, spark) =>
         import spark.implicits._
-        implicit val sqlContext = spark.sqlContext
+        implicit val sparkSession: SparkSession = spark
         // Start a query and run some data to force snapshot uploads
         val inputData = MemoryStream[Int]
         val query = setUpStatefulQuery(inputData, "query")
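
The pattern above, binding the harness-provided spark value as an implicit SparkSession right before MemoryStream[Int], could also be factored into a small helper. A hypothetical sketch, not part of this commit; the trait and member names are invented:

// Hypothetical helper, not in this commit: suites that previously exposed an
// implicit SQLContext can expose the session itself so MemoryStream[A]
// resolves its implicit SparkSession without per-test boilerplate.
import org.apache.spark.sql.SparkSession

trait ImplicitSparkSessionForStreams {
  // Assumed to be supplied by the surrounding test harness (e.g. a shared-session trait).
  def spark: SparkSession

  // The implicit binding is the only point of this trait.
  protected implicit def sparkSessionForMemoryStream: SparkSession = spark
}

With such a trait mixed in, val inputData = MemoryStream[Int] compiles unchanged against the new signature.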

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala

Lines changed: 1 addition & 2 deletions

@@ -1206,9 +1206,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   test("SPARK-21145: Restarted queries create new provider instances") {
     try {
       val checkpointLocation = Utils.createTempDir().getAbsoluteFile
-      val spark = SparkSession.builder().master("local[2]").getOrCreate()
+      implicit val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate()
       SparkSession.setActiveSession(spark)
-      implicit val sqlContext = spark.sqlContext
       spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")
       import spark.implicits._
       val inputData = MemoryStream[Int]
