8 changes: 8 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -539,6 +539,7 @@ def __init__(
semantic_hash: Optional[int],
storage_level: Optional[StorageLevel],
ddl_string: Optional[str],
num_partitions: Optional[int],
):
self.schema = schema
self.explain_string = explain_string
@@ -552,6 +553,7 @@ def __init__(
self.semantic_hash = semantic_hash
self.storage_level = storage_level
self.ddl_string = ddl_string
self.num_partitions = num_partitions

@classmethod
def fromProto(cls, pb: Any) -> "AnalyzeResult":
@@ -567,6 +569,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
semantic_hash: Optional[int] = None
storage_level: Optional[StorageLevel] = None
ddl_string: Optional[str] = None
num_partitions: Optional[int] = None

if pb.HasField("schema"):
schema = types.proto_schema_to_pyspark_data_type(pb.schema.schema)
@@ -596,6 +599,8 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
storage_level = proto_to_storage_level(pb.get_storage_level.storage_level)
elif pb.HasField("json_to_ddl"):
ddl_string = pb.json_to_ddl.ddl_string
elif pb.HasField("get_num_partitions"):
num_partitions = pb.get_num_partitions.num_partitions
else:
raise SparkConnectException("No analyze result found!")

@@ -612,6 +617,7 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult":
semantic_hash,
storage_level,
ddl_string,
num_partitions,
)


@@ -1440,6 +1446,8 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
req.get_storage_level.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
elif method == "json_to_ddl":
req.json_to_ddl.json_string = cast(str, kwargs.get("json_string"))
elif method == "get_num_partitions":
req.get_num_partitions.plan.CopyFrom(cast(pb2.Plan, kwargs.get("plan")))
else:
raise PySparkValueError(
errorClass="UNSUPPORTED_OPERATION",
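For orientation, here is a minimal sketch of the round trip these `core.py` changes enable, using only the internal names visible in this diff (`_plan.to_proto`, `_analyze`, `num_partitions`); it is illustrative, not a supported public API:

```python
# Illustrative sketch of the new "get_num_partitions" analyze call from the
# client side. `spark` is assumed to be an active Spark Connect session.
df = spark.range(10).repartition(4)

# Serialize the DataFrame's logical plan to its protobuf form, then ask the
# server to analyze it; the response's num_partitions field is populated by
# AnalyzeResult.fromProto above.
plan = df._plan.to_proto(df._session.client)
result = df._session.client._analyze(method="get_num_partitions", plan=plan)
assert result.num_partitions == 4
```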
8 changes: 8 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -2044,6 +2044,14 @@ def inputFiles(self) -> List[str]:
assert result is not None
return result

def getNumPartitions(self) -> int:
query = self._plan.to_proto(self._session.client)
result = self._session.client._analyze(
method="get_num_partitions", plan=query
).num_partitions
assert result is not None
return result

def to(self, schema: StructType) -> ParentDataFrame:
assert schema is not None
res = DataFrame(
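A quick usage sketch for the new method (assuming an active Spark Connect session bound to `spark`). Classic `DataFrame` has no such method, as the compatibility-test change below records; this mirrors the classic `df.rdd.getNumPartitions()` pattern, but resolves the count via a server-side analyze request:

```python
# getNumPartitions on a Spark Connect DataFrame asks the server for the
# partition count of the physical plan.
df = spark.range(10)
assert df.repartition(4).getNumPartitions() == 4
assert df.coalesce(1).getNumPartitions() == 1
```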
396 changes: 200 additions & 196 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

52 changes: 51 additions & 1 deletion python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -561,6 +561,23 @@ class AnalyzePlanRequest(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["relation", b"relation"]
) -> None: ...

class GetNumPartitions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

PLAN_FIELD_NUMBER: builtins.int
@property
def plan(self) -> global___Plan:
"""(Required) The logical plan to be analyzed."""
def __init__(
self,
*,
plan: global___Plan | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["plan", b"plan"]
) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["plan", b"plan"]) -> None: ...

class JsonToDDL(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -594,6 +611,7 @@
UNPERSIST_FIELD_NUMBER: builtins.int
GET_STORAGE_LEVEL_FIELD_NUMBER: builtins.int
JSON_TO_DDL_FIELD_NUMBER: builtins.int
GET_NUM_PARTITIONS_FIELD_NUMBER: builtins.int
session_id: builtins.str
"""(Required)

@@ -644,6 +662,8 @@
def get_storage_level(self) -> global___AnalyzePlanRequest.GetStorageLevel: ...
@property
def json_to_ddl(self) -> global___AnalyzePlanRequest.JsonToDDL: ...
@property
def get_num_partitions(self) -> global___AnalyzePlanRequest.GetNumPartitions: ...
def __init__(
self,
*,
@@ -665,6 +685,7 @@
unpersist: global___AnalyzePlanRequest.Unpersist | None = ...,
get_storage_level: global___AnalyzePlanRequest.GetStorageLevel | None = ...,
json_to_ddl: global___AnalyzePlanRequest.JsonToDDL | None = ...,
get_num_partitions: global___AnalyzePlanRequest.GetNumPartitions | None = ...,
) -> None: ...
def HasField(
self,
@@ -683,6 +704,8 @@
b"ddl_parse",
"explain",
b"explain",
"get_num_partitions",
b"get_num_partitions",
"get_storage_level",
b"get_storage_level",
"input_files",
@@ -728,6 +751,8 @@
b"ddl_parse",
"explain",
b"explain",
"get_num_partitions",
b"get_num_partitions",
"get_storage_level",
b"get_storage_level",
"input_files",
@@ -788,6 +813,7 @@
"unpersist",
"get_storage_level",
"json_to_ddl",
"get_num_partitions",
]
| None
): ...
@@ -797,7 +823,7 @@ global___AnalyzePlanRequest = AnalyzePlanRequest
class AnalyzePlanResponse(google.protobuf.message.Message):
"""Response to performing analysis of the query. Contains relevant metadata to be able to
reason about the performance.
Next ID: 16
Next ID: 18
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -985,6 +1011,21 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["storage_level", b"storage_level"]
) -> None: ...

class GetNumPartitions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

NUM_PARTITIONS_FIELD_NUMBER: builtins.int
num_partitions: builtins.int
"""The number of partitions in the physical execution plan."""
def __init__(
self,
*,
num_partitions: builtins.int = ...,
) -> None: ...
def ClearField(
self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
) -> None: ...

class JsonToDDL(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

@@ -1015,6 +1056,7 @@
UNPERSIST_FIELD_NUMBER: builtins.int
GET_STORAGE_LEVEL_FIELD_NUMBER: builtins.int
JSON_TO_DDL_FIELD_NUMBER: builtins.int
GET_NUM_PARTITIONS_FIELD_NUMBER: builtins.int
session_id: builtins.str
server_side_session_id: builtins.str
"""Server-side generated idempotency key that the client can use to assert that the server side
@@ -1048,6 +1090,8 @@
def get_storage_level(self) -> global___AnalyzePlanResponse.GetStorageLevel: ...
@property
def json_to_ddl(self) -> global___AnalyzePlanResponse.JsonToDDL: ...
@property
def get_num_partitions(self) -> global___AnalyzePlanResponse.GetNumPartitions: ...
def __init__(
self,
*,
Expand All @@ -1067,6 +1111,7 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
unpersist: global___AnalyzePlanResponse.Unpersist | None = ...,
get_storage_level: global___AnalyzePlanResponse.GetStorageLevel | None = ...,
json_to_ddl: global___AnalyzePlanResponse.JsonToDDL | None = ...,
get_num_partitions: global___AnalyzePlanResponse.GetNumPartitions | None = ...,
) -> None: ...
def HasField(
self,
@@ -1075,6 +1120,8 @@
b"ddl_parse",
"explain",
b"explain",
"get_num_partitions",
b"get_num_partitions",
"get_storage_level",
b"get_storage_level",
"input_files",
@@ -1110,6 +1157,8 @@
b"ddl_parse",
"explain",
b"explain",
"get_num_partitions",
b"get_num_partitions",
"get_storage_level",
b"get_storage_level",
"input_files",
@@ -1160,6 +1209,7 @@
"unpersist",
"get_storage_level",
"json_to_ddl",
"get_num_partitions",
]
| None
): ...
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -613,6 +613,10 @@ def test_is_streaming(self):
self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming)
self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming)

def test_get_num_partitions(self):
self.assertEqual(self.connect.range(10).repartition(4).getNumPartitions(), 4)
self.assertEqual(self.connect.range(10).coalesce(1).getNumPartitions(), 1)

def test_input_files(self):
# SPARK-41216: Test input files
tmpPath = tempfile.mkdtemp()
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_connect_compatibility.py
@@ -230,7 +230,7 @@ def test_dataframe_compatibility(self):
expected_missing_connect_properties = {"sql_ctx"}
expected_missing_classic_properties = {"is_cached"}
expected_missing_connect_methods = set()
expected_missing_classic_methods = set()
expected_missing_classic_methods = {"getNumPartitions"}
self.check_compatibility(
ClassicDataFrame,
ConnectDataFrame,
@@ -629,6 +629,8 @@ class ClientE2ETestSuite
assert(!df.isStreaming)
assert(df.toString.contains("[id: bigint]"))
assert(df.inputFiles.isEmpty)
assert(df.repartition(4).getNumPartitions === 4)
assert(df.coalesce(1).getNumPartitions === 1)
}

test("Dataset schema") {
14 changes: 13 additions & 1 deletion sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -115,6 +115,7 @@ message AnalyzePlanRequest {
Unpersist unpersist = 15;
GetStorageLevel get_storage_level = 16;
JsonToDDL json_to_ddl = 18;
GetNumPartitions get_num_partitions = 19;
}

message Schema {
@@ -221,6 +222,11 @@ message AnalyzePlanRequest {
Relation relation = 1;
}

message GetNumPartitions {
// (Required) The logical plan to be analyzed.
Plan plan = 1;
}

message JsonToDDL {
// (Required) The JSON formatted string to be converted to DDL.
string json_string = 1;
@@ -229,7 +235,7 @@

// Response to performing analysis of the query. Contains relevant metadata to be able to
// reason about the performance.
// Next ID: 16
// Next ID: 18
message AnalyzePlanResponse {
string session_id = 1;
// Server-side generated idempotency key that the client can use to assert that the server side
Expand All @@ -251,6 +257,7 @@ message AnalyzePlanResponse {
Unpersist unpersist = 13;
GetStorageLevel get_storage_level = 14;
JsonToDDL json_to_ddl = 16;
GetNumPartitions get_num_partitions = 17;
}

message Schema {
@@ -303,6 +310,11 @@ message AnalyzePlanResponse {
StorageLevel storage_level = 1;
}

message GetNumPartitions {
// The number of partitions in the physical execution plan.
int32 num_partitions = 1;
}

message JsonToDDL {
string ddl_string = 1;
}
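To make the new wire fields concrete, here is a hedged sketch using the generated Python stubs touched earlier in this PR; the empty `Plan` and the session id are placeholders, since a real request carries a full relation tree:

```python
from pyspark.sql.connect.proto import base_pb2 as pb2

# Request side: get_num_partitions is field 19 in the analyze oneof.
req = pb2.AnalyzePlanRequest()
req.session_id = "00000000-0000-0000-0000-000000000000"  # placeholder id
req.get_num_partitions.plan.CopyFrom(pb2.Plan())          # placeholder plan

# Response side: get_num_partitions is field 17 in the result oneof and
# carries a single int32 with the physical plan's partition count.
resp = pb2.AnalyzePlanResponse()
resp.get_num_partitions.num_partitions = 4
assert resp.HasField("get_num_partitions")
```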
@@ -257,6 +257,17 @@ class Dataset[T] private[sql] (
.getIsStreaming
.getIsStreaming

/**
* Returns the number of partitions of this Dataset.
*
* @group basic
* @since 4.2.0
*/
def getNumPartitions: Int = sparkSession
.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.GET_NUM_PARTITIONS)
.getGetNumPartitions
.getNumPartitions

/** @inheritdoc */
// scalastyle:off println
def show(numRows: Int, truncate: Boolean): Unit = {
@@ -408,6 +408,13 @@ private[sql] class SparkConnectClient(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION =>
builder.setSparkVersion(proto.AnalyzePlanRequest.SparkVersion.newBuilder().build())
case proto.AnalyzePlanRequest.AnalyzeCase.GET_NUM_PARTITIONS =>
assert(maybeCompressedPlan.isDefined)
builder.setGetNumPartitions(
proto.AnalyzePlanRequest.GetNumPartitions
.newBuilder()
.setPlan(maybeCompressedPlan.get)
.build())
case other => throw new IllegalArgumentException(s"Unknown Analyze request $other")
}
analyze(builder)
@@ -193,6 +193,16 @@ class RequestDecompressionInterceptor extends ServerInterceptor with Logging {
.build())
(req, Seq(size))

case proto.AnalyzePlanRequest.AnalyzeCase.GET_NUM_PARTITIONS =>
val (req, size) = decompress(
request,
request.getGetNumPartitions.getPlan,
p =>
request.toBuilder
.setGetNumPartitions(request.getGetNumPartitions.toBuilder.setPlan(p))
.build())
(req, Seq(size))

case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
val (req, size) = decompress(
request,
@@ -240,6 +240,16 @@ private[connect] class SparkConnectAnalyzeHandler(
.setDdlString(ddl)
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.GET_NUM_PARTITIONS =>
val rel = transformRelationPlan(request.getGetNumPartitions.getPlan)
val numPartitions =
getDataFrameWithoutExecuting(rel).queryExecution.executedPlan.execute().getNumPartitions
builder.setGetNumPartitions(
proto.AnalyzePlanResponse.GetNumPartitions
.newBuilder()
.setNumPartitions(numPartitions)
.build())

// NOTE: When adding a new AnalyzePlanRequest case here, also update
// RequestDecompressionInterceptor.decompressAnalyzePlanRequest() to handle
// this case. The interceptor has a default case that throws UnsupportedOperationException
@@ -161,6 +161,29 @@ class SparkConnectServiceSuite
val response6 = handler.process(request6, sparkSessionHolder)
assert(response6.hasInputFiles)
assert(response6.getInputFiles.getFilesCount === 0)

val repartitionPlan = proto.Plan
.newBuilder()
.setRoot(
proto.Relation
.newBuilder()
.setRepartition(
proto.Repartition
.newBuilder()
.setInput(plan.getRoot)
.setNumPartitions(4)
.setShuffle(true)
.build())
.build())
.build()
val request7 = proto.AnalyzePlanRequest
.newBuilder()
.setGetNumPartitions(
proto.AnalyzePlanRequest.GetNumPartitions.newBuilder().setPlan(repartitionPlan).build())
.build()
val response7 = handler.process(request7, sparkSessionHolder)
assert(response7.hasGetNumPartitions)
assert(response7.getGetNumPartitions.getNumPartitions === 4)
}
}
