
Commit 50e00b7

vicennial authored and dongjoon-hyun committed
[SPARK-51425][CONNECT] Add client API to set custom operation_id
### What changes were proposed in this pull request?

Adds an additional optional parameter to the Scala/Python APIs to allow a user to explicitly set an `operation_id`.

### Why are the changes needed?

The Spark Connect [protocol](https://github.com/apache/spark/blob/44e751f243a5a7be8e64c306e40ddc1502f72710/sql/connect/common/src/main/protobuf/spark/connect/base.proto#L318) allows the client to set an optional operation ID. However, there is no API that lets a user set this explicitly (although the client does set it in the case of Reattachable Execution).

### Does this PR introduce _any_ user-facing change?

Yes.

Scala usage:
```scala
client.execute(plan, operationId = Some("10a4c38e-7e87-40ee-9d6f-60ff0751e63b"))
```

Python usage:
```python
req = client._execute_plan_request_with_metadata(operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b")
# continue using the req as usual
```

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50191 from vicennial/customOpId.

Authored-by: vicennial <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
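To make the Python snippet above concrete, here is a minimal sketch of what the client-side call produces (assumptions: the `sc://localhost` connection string is a placeholder, and no RPC is issued while the request is built; the flow mirrors the new unit test in the diff below):

```python
from pyspark.sql.connect.client import SparkConnectClient

# Placeholder endpoint; constructing the client and building the request
# does not contact a server.
client = SparkConnectClient("sc://localhost/;token=abc", use_reattachable_execute=False)

# The custom id is carried on the generated ExecutePlanRequest.
req = client._execute_plan_request_with_metadata(
    operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
)
assert req.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
# A value that is not a valid UUID string is rejected client-side with a
# PySparkValueError (error class INVALID_OPERATION_UUID_ID) before any RPC is made.
```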
1 parent a5b4d81 commit 50e00b7

File tree

4 files changed (+62 −2 lines):

- python/pyspark/sql/connect/client/core.py
- python/pyspark/sql/tests/connect/client/test_client.py
- sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
- sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala


python/pyspark/sql/connect/client/core.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -1228,7 +1228,9 @@ def token(self) -> Optional[str]:
         """
         return self._builder.token
 
-    def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
+    def _execute_plan_request_with_metadata(
+        self, operation_id: Optional[str] = None
+    ) -> pb2.ExecutePlanRequest:
         req = pb2.ExecutePlanRequest(
             session_id=self._session_id,
             client_type=self._builder.userAgent,
@@ -1238,6 +1240,15 @@ def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
             req.client_observed_server_side_session_id = self._server_session_id
         if self._user_id:
             req.user_context.user_id = self._user_id
+        if operation_id is not None:
+            try:
+                uuid.UUID(operation_id, version=4)
+            except ValueError as ve:
+                raise PySparkValueError(
+                    errorClass="INVALID_OPERATION_UUID_ID",
+                    messageParameters={"arg_name": "operation_id", "origin": str(ve)},
+                )
+            req.operation_id = operation_id
         return req
 
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
```
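The validation added above leans on the standard library: `uuid.UUID` raises `ValueError` for malformed ids, which the client surfaces as a `PySparkValueError`. A standalone sketch of the same acceptance rule (the helper name here is invented for illustration):

```python
import uuid


def _looks_like_operation_id(candidate: str) -> bool:
    # Same rule as the code path above: the string must parse via uuid.UUID.
    try:
        uuid.UUID(candidate, version=4)
        return True
    except ValueError:
        return False


print(_looks_like_operation_id("10a4c38e-7e87-40ee-9d6f-60ff0751e63b"))  # True
print(_looks_like_operation_id("not-a-uuid"))                            # False
```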

python/pyspark/sql/tests/connect/client/test_client.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -137,6 +137,7 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
         self.req = req
         resp = proto.ExecutePlanResponse()
         resp.session_id = self._session_id
+        resp.operation_id = req.operation_id
 
         pdf = pd.DataFrame(data={"col1": [1, 2]})
         schema = pa.Schema.from_pandas(pdf)
@@ -255,6 +256,16 @@ def test_channel_builder_with_session(self):
         client = SparkConnectClient(chan)
         self.assertEqual(client._session_id, chan.session_id)
 
+    def test_custom_operation_id(self):
+        client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
+        mock = MockService(client._session_id)
+        client._stub = mock
+        req = client._execute_plan_request_with_metadata(
+            operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+        )
+        for resp in client._stub.ExecutePlan(req, metadata=None):
+            assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectClientReattachTestCase(unittest.TestCase):
```

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala

Lines changed: 20 additions & 0 deletions
```diff
@@ -635,6 +635,26 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     observer.onNext(proto.AddArtifactsRequest.newBuilder().build())
     observer.onCompleted()
   }
+
+  test("client can set a custom operation id for ExecutePlan requests") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .enableReattachableExecute()
+      .build()
+
+    val plan = buildPlan("select * from range(10000000)")
+    val dummyUUID = "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+    val iter = client.execute(plan, operationId = Some(dummyUUID))
+    val reattachableIter =
+      ExecutePlanResponseReattachableIterator.fromIterator(iter)
+    assert(reattachableIter.operationId == dummyUUID)
+    while (reattachableIter.hasNext) {
+      val resp = reattachableIter.next()
+      assert(resp.getOperationId == dummyUUID)
+    }
+  }
 }
 
 class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {
```

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 19 additions & 1 deletion
```diff
@@ -110,14 +110,25 @@ private[sql] class SparkConnectClient(
     bstub.analyzePlan(request)
   }
 
+  private def isValidUUID(uuid: String): Boolean = {
+    try {
+      UUID.fromString(uuid)
+      true
+    } catch {
+      case _: IllegalArgumentException => false
+    }
+  }
+
   /**
    * Execute the plan and return response iterator.
    *
    * It returns CloseableIterator. For resource management it is better to close it once you are
    * done. If you don't close it, it and the underlying data will be cleaned up once the iterator
    * is garbage collected.
    */
-  def execute(plan: proto.Plan): CloseableIterator[proto.ExecutePlanResponse] = {
+  def execute(
+      plan: proto.Plan,
+      operationId: Option[String] = None): CloseableIterator[proto.ExecutePlanResponse] = {
     artifactManager.uploadAllClassFileArtifacts()
     val request = proto.ExecutePlanRequest
       .newBuilder()
@@ -127,6 +138,13 @@ private[sql] class SparkConnectClient(
       .setClientType(userAgent)
       .addAllTags(tags.get.toSeq.asJava)
     serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
+    operationId.foreach { opId =>
+      require(
+        isValidUUID(opId),
+        s"Invalid operationId: $opId. The id must be an UUID string of " +
+          "the format `00112233-4455-6677-8899-aabbccddeeff`")
+      request.setOperationId(opId)
+    }
     if (configuration.useReattachableExecute) {
       bstub.executePlanReattachable(request.build())
     } else {
```
