236 changes: 119 additions & 117 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.

40 changes: 39 additions & 1 deletion python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1094,24 +1094,33 @@ class ExecutePlanRequest(google.protobuf.message.Message):

REATTACH_OPTIONS_FIELD_NUMBER: builtins.int
RESULT_CHUNKING_OPTIONS_FIELD_NUMBER: builtins.int
ACCEPT_RESPONSE_OPTIONS_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def reattach_options(self) -> global___ReattachOptions: ...
@property
def result_chunking_options(self) -> global___ResultChunkingOptions: ...
@property
def accept_response_options(self) -> global___AcceptResponseOptions:
"""Options to describe what responses (e.g. using a new field in the response)
can be accepted.
"""
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""Extension type for request options"""
def __init__(
self,
*,
reattach_options: global___ReattachOptions | None = ...,
result_chunking_options: global___ResultChunkingOptions | None = ...,
accept_response_options: global___AcceptResponseOptions | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"accept_response_options",
b"accept_response_options",
"extension",
b"extension",
"reattach_options",
@@ -1125,6 +1134,8 @@ class ExecutePlanRequest(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
"accept_response_options",
b"accept_response_options",
"extension",
b"extension",
"reattach_options",
@@ -1138,7 +1149,12 @@ class ExecutePlanRequest(google.protobuf.message.Message):
def WhichOneof(
self, oneof_group: typing_extensions.Literal["request_option", b"request_option"]
) -> (
typing_extensions.Literal["reattach_options", "result_chunking_options", "extension"]
typing_extensions.Literal[
"reattach_options",
"result_chunking_options",
"accept_response_options",
"extension",
]
| None
): ...

@@ -3049,6 +3065,28 @@ class ResultChunkingOptions(google.protobuf.message.Message):

global___ResultChunkingOptions = ResultChunkingOptions

class AcceptResponseOptions(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

ACCEPT_LITERAL_DATA_TYPE_FIELD_FIELD_NUMBER: builtins.int
accept_literal_data_type_field: builtins.bool
"""When true, the client indicates it can handle Literal messages in responses
that include the data_type field.
"""
def __init__(
self,
*,
accept_literal_data_type_field: builtins.bool = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"accept_literal_data_type_field", b"accept_literal_data_type_field"
],
) -> None: ...

global___AcceptResponseOptions = AcceptResponseOptions

class ReattachExecuteRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

9 changes: 9 additions & 0 deletions sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -334,6 +334,9 @@ message ExecutePlanRequest {
oneof request_option {
ReattachOptions reattach_options = 1;
ResultChunkingOptions result_chunking_options = 2;
// Options to describe what responses (e.g. using a new field in the response)
// can be accepted.
AcceptResponseOptions accept_response_options = 3;
// Extension type for request options
google.protobuf.Any extension = 999;
}
@@ -846,6 +849,12 @@ message ResultChunkingOptions {
optional int64 preferred_arrow_chunk_size = 2;
}

message AcceptResponseOptions {
// When true, the client indicates it can handle Literal messages in responses
// that include the data_type field.
bool accept_literal_data_type_field = 1;
}

message ReattachExecuteRequest {
// (Required)
//
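
The client advertises this capability as one entry in the repeated `request_options` list. As a minimal sketch, assuming only the generated bindings in `org.apache.spark.connect.proto` (the same builder calls the client diff below wires in):

```scala
import org.apache.spark.connect.proto

// Request option advertising that this client understands Literal messages
// whose type is carried in the new data_type field.
val acceptOption = proto.ExecutePlanRequest.RequestOption
  .newBuilder()
  .setAcceptResponseOptions(
    proto.AcceptResponseOptions.newBuilder().setAcceptLiteralDataTypeField(true).build())
  .build()

// request_options is repeated, so this composes with ReattachOptions etc.
val request = proto.ExecutePlanRequest
  .newBuilder()
  .addRequestOptions(acceptOption)
  .build()
```
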
@@ -138,6 +138,7 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
.addRequestOptions(SparkConnectClient.ACCEPT_RESPONSE_OPTIONS)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
operationId.foreach { opId =>
require(
@@ -425,6 +426,12 @@ object SparkConnectClient {
private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] =
Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER)

private val ACCEPT_RESPONSE_OPTIONS = proto.ExecutePlanRequest.RequestOption
.newBuilder()
.setAcceptResponseOptions(
proto.AcceptResponseOptions.newBuilder().setAcceptLiteralDataTypeField(true).build())
.build()

// for internal tests
private[sql] def apply(channel: ManagedChannel): SparkConnectClient = {
new SparkConnectClient(Configuration(), channel)
@@ -257,7 +257,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
executeHolder.sessionHolder.sessionId,
executeHolder.sessionHolder.serverSessionId,
executeHolder.allObservationAndPlanIds,
-          observedMetrics ++ accumulatedInPython))
+          observedMetrics ++ accumulatedInPython,
+          executeHolder.acceptLiteralDataTypeFieldInResponses))
}

// State transition should be atomic to prevent a situation in which a client of reattachable
@@ -31,7 +31,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoWithOptions, ToLiteralProtoOptions}
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_ARROW_MAX_BATCH_SIZE, CONNECT_SESSION_RESULT_CHUNKING_MAX_CHUNK_SIZE}
import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
import org.apache.spark.sql.connect.service.ExecuteHolder
@@ -331,7 +331,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
sessionId,
sessionHolder.serverSessionId,
observationAndPlanIds,
-        observedMetrics))
+        observedMetrics,
+        executeHolder.acceptLiteralDataTypeFieldInResponses))
} else None
}
}
@@ -352,17 +353,21 @@ object SparkConnectPlanExecution {
sessionId: String,
serverSessionId: String,
observationAndPlanIds: Map[String, Long],
-      metrics: Map[String, Seq[(Option[String], Any, Option[DataType])]]): ExecutePlanResponse = {
+      metrics: Map[String, Seq[(Option[String], Any, Option[DataType])]],
+      acceptLiteralDataTypeFieldInResponses: Boolean): ExecutePlanResponse = {
+    val toLiteralProtoOptions =
+      ToLiteralProtoOptions(useDeprecatedDataTypeFields = !acceptLiteralDataTypeFieldInResponses)
val observedMetrics = metrics.map { case (name, values) =>
val metrics = ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
values.foreach { case (keyOpt, value, dataTypeOpt) =>
dataTypeOpt match {
case Some(dataType) =>
-            metrics.addValues(toLiteralProto(value, dataType))
+            metrics.addValues(
+              toLiteralProtoWithOptions(value, Some(dataType), toLiteralProtoOptions))
case None =>
-            metrics.addValues(toLiteralProto(value))
+            metrics.addValues(toLiteralProtoWithOptions(value, None, toLiteralProtoOptions))
}
keyOpt.foreach(metrics.addKeys)
}
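
The negotiation bottoms out in the converter options: unless the client opted in, the server keeps emitting literals with the deprecated per-type fields. A minimal sketch of that selection, reusing only the two names imported above (`toLiteralProtoWithOptions`, `ToLiteralProtoOptions`); the wrapper function itself is hypothetical:

```scala
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoWithOptions, ToLiteralProtoOptions}
import org.apache.spark.sql.types.DataType

// Hypothetical helper: old clients (accept = false) get literals encoded with
// the deprecated data-type fields; opted-in clients get the data_type field.
def encodeMetricValue(value: Any, dataType: Option[DataType], accept: Boolean) = {
  val options = ToLiteralProtoOptions(useDeprecatedDataTypeFields = !accept)
  toLiteralProtoWithOptions(value, dataType, options)
}
```
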
@@ -104,6 +104,15 @@ private[connect] class ExecuteHolder(
}
}

/**
* Whether the client can handle Literal messages in responses that include the data_type field.
*/
lazy val acceptLiteralDataTypeFieldInResponses: Boolean = {
request.getRequestOptionsList.asScala.exists { option =>
option.getAcceptResponseOptions.getAcceptLiteralDataTypeField
}
}

val responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse] =
new ExecuteResponseObserver[proto.ExecutePlanResponse](this)

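
A subtlety that makes the `exists` scan above safe: on a `RequestOption` whose oneof is set to something else, `getAcceptResponseOptions` returns the default `AcceptResponseOptions` instance, whose flag is false. A small sketch of that behavior, using the generated bindings:

```scala
import org.apache.spark.connect.proto

// A request option that sets a different member of the oneof.
val reattachOnly = proto.ExecutePlanRequest.RequestOption
  .newBuilder()
  .setReattachOptions(proto.ReattachOptions.newBuilder().setReattachable(true))
  .build()

// The unset oneof member comes back as a default instance, so the
// capability check simply evaluates to false for this option.
assert(!reattachOnly.getAcceptResponseOptions.getAcceptLiteralDataTypeField)
```
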
@@ -135,33 +135,30 @@ package object dsl {
.build()
}

-      def proto_min(e: Expression): Expression =
+      private def unresolvedFunction(functionName: String, e: Expression): Expression =
Expression
.newBuilder()
.setUnresolvedFunction(
Expression.UnresolvedFunction.newBuilder().setFunctionName("min").addArguments(e))
Expression.UnresolvedFunction
.newBuilder()
.setFunctionName(functionName)
.addArguments(e))
.build()

+      def proto_struct(e: Expression): Expression =
+        unresolvedFunction("struct", e)
+
+      def proto_min(e: Expression): Expression =
+        unresolvedFunction("min", e)

def proto_max(e: Expression): Expression =
-        Expression
-          .newBuilder()
-          .setUnresolvedFunction(
-            Expression.UnresolvedFunction.newBuilder().setFunctionName("max").addArguments(e))
-          .build()
+        unresolvedFunction("max", e)

def proto_sum(e: Expression): Expression =
-        Expression
-          .newBuilder()
-          .setUnresolvedFunction(
-            Expression.UnresolvedFunction.newBuilder().setFunctionName("sum").addArguments(e))
-          .build()
+        unresolvedFunction("sum", e)

def proto_explode(e: Expression): Expression =
-        Expression
-          .newBuilder()
-          .setUnresolvedFunction(
-            Expression.UnresolvedFunction.newBuilder().setFunctionName("explode").addArguments(e))
-          .build()
+        unresolvedFunction("explode", e)

/**
* Create an unresolved function from name parts.
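
The dsl change is a pure deduplication: the single-argument helpers now share one `unresolvedFunction` builder, and `proto_struct` is new. A short usage sketch, assuming the dsl helpers are in scope as in the surrounding file; the attribute construction is ordinary generated-builder code:

```scala
import org.apache.spark.connect.proto.Expression

// Plain unresolved attribute to feed the helpers.
val col = Expression
  .newBuilder()
  .setUnresolvedAttribute(
    Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("a"))
  .build()

// Both calls route through the shared unresolvedFunction helper and produce
// the same protos as the hand-rolled builders they replace.
val minOfA = proto_min(col)
val structOfMax = proto_struct(proto_max(col))
```
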