[SPARK-53490][CONNECT][SQL] Fix Protobuf conversion in observed metrics
### What changes were proposed in this pull request?
This PR fixes a critical issue in the Protobuf conversion of observed metrics in Spark Connect, specifically when dealing with complex data types such as structs, arrays, and maps. The main changes are:
1. **Modified the `Observation` class to store `Row` objects instead of `Map[String, Any]`**: Changed the internal promise type from `Promise[Map[String, Any]]` to `Promise[Row]` to preserve type information during Protobuf serialization/deserialization.
2. **Enhanced Protobuf conversion for complex types**:
- Added proper handling for struct types by creating `GenericRowWithSchema` objects instead of tuples
- Added support for map type conversion in `LiteralValueProtoConverter`
- Improved data type inference with a new `getDataType` method that properly handles all literal types (see the round-trip sketch after this list)
3. **Fixed observed metrics**: Modified the observed metrics processing to include data type information in the Protobuf conversion, ensuring that complex types are properly serialized and deserialized.
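A minimal sketch of the new round-trip behavior (illustrative, not taken from the patch; it assumes the `toLiteralProto(value, dataType)`, `toScalaValue`, and new `getDataType` helpers are wired up as in this diff):

```scala
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
import org.apache.spark.sql.types._

// The kind of value a struct(...) expression in observed metrics produces.
val schema = StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
val row = new GenericRowWithSchema(Array[Any](10L, 9L), schema)

// Convert to a literal proto with an explicit data type. Previously this path threw
// "literal [10,9] not supported (yet)" because Row values were not handled.
val literal = LiteralValueProtoConverter.toLiteralProto(row, schema)

// Convert back: the value is a GenericRowWithSchema instead of a Tuple2, and the new
// getDataType recovers the struct type directly from the literal proto.
val roundTripped = LiteralValueProtoConverter.toScalaValue(literal)
val dataType = LiteralValueProtoConverter.getDataType(literal) // struct<rows: bigint, maxid: bigint>
```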
### Why are the changes needed?
The previous implementation had two issues:
1. **Data type loss**: Observed metrics lost their original data types during Protobuf conversion, causing errors
2. **Struct handling problems**: The conversion logic did not handle `Row` objects and struct types, so struct-valued metrics failed with `literal [10,9] not supported (yet)`
### Does this PR introduce _any_ user-facing change?
Yes, this PR fixes a bug that prevented users from using observed metrics with complex data types (structs, arrays, maps) in Spark Connect. Users can now:
- Use `struct()` expressions in observed metrics and receive properly typed `GenericRowWithSchema` objects
- Use `array()` expressions in observed metrics and receive properly typed arrays
- Use `map()` expressions in observed metrics and receive properly typed maps
Previously, the code below would fail (the error is shown in the comment); an example of the behavior after the fix follows the snippet.
```scala
val observation = Observation("struct")
spark
  .range(10)
  .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
  .collect()
observation.get
// Below is the error message:
"""
org.apache.spark.SparkUnsupportedOperationException: literal [10,9] not supported (yet).
org.apache.spark.sql.connect.common.LiteralValueProtoConverter$.toLiteralProtoBuilder(LiteralValueProtoConverter.scala:104)
org.apache.spark.sql.connect.common.LiteralValueProtoConverter$.toLiteralProto(LiteralValueProtoConverter.scala:203)
org.apache.spark.sql.connect.execution.SparkConnectPlanExecution$.$anonfun$createObservedMetricsResponse$2(SparkConnectPlanExecution.scala:571)
org.apache.spark.sql.connect.execution.SparkConnectPlanExecution$.$anonfun$createObservedMetricsResponse$2$adapted(SparkConnectPlanExecution.scala:570)
"""
```
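With this fix, the same pattern succeeds. The snippet below mirrors the new `ClientE2ETestSuite` test; the value shown in the comment is the expected result asserted by that test.

```scala
val observation = Observation("struct")
spark
  .range(10)
  .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
  .collect()

// The metric value is now a properly typed Row instead of an unsupported literal:
// observation.get("struct") == Row(10, 9) with schema
//   StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
observation.get("struct")
```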
### How was this patch tested?
Added tests for struct, array, and map values in observed metrics (`ClientE2ETestSuite`) and updated `LiteralExpressionProtoConverterSuite`:
`build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite -- -z SPARK-53490"`
`build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Cursor 1.5.9
Closes #52236 from heyihong/SPARK-53490.
Authored-by: Yihong He <heyihong.cn@gmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index fa427fe..5e27198 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -58,12 +58,12 @@
private val isRegistered = new AtomicBoolean()
- private val promise = Promise[Map[String, Any]]()
+ private val promise = Promise[Row]()
/**
* Future holding the (yet to be completed) observation.
*/
- val future: Future[Map[String, Any]] = promise.future
+ val future: Future[Row] = promise.future
/**
* (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -76,7 +76,10 @@
* interrupted while waiting
*/
@throws[InterruptedException]
- def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf)
+ def get: Map[String, Any] = {
+ val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
+ row.getValuesMap(row.schema.map(_.name))
+ }
/**
* (Java-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -99,7 +102,8 @@
*/
@throws[InterruptedException]
private[sql] def getOrEmpty: Map[String, Any] = {
- Try(SparkThreadUtils.awaitResult(future, 100.millis)).getOrElse(Map.empty)
+ val row = getRowOrEmpty.getOrElse(Row.empty)
+ row.getValuesMap(row.schema.map(_.name))
}
/**
@@ -118,8 +122,17 @@
* `true` if all waiting threads were notified, `false` if otherwise.
*/
private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
- val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
- promise.trySuccess(metricsMap)
+ promise.trySuccess(metrics)
+ }
+
+ /**
+ * Get the observed metrics as a Row.
+ *
+ * @return
+ * the observed metrics as a `Row`, or None if the metrics are not available.
+ */
+ private[sql] def getRowOrEmpty: Option[Row] = {
+ Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
}
}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 2f21665..e221300 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1749,6 +1749,42 @@
val nullRows = nullResult.filter(_.getAs[Long]("id") >= 5)
assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0))
}
+
+ test("SPARK-53490: struct type in observed metrics") {
+ val observation = Observation("struct")
+ spark
+ .range(10)
+ .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
+ .collect()
+ val expectedSchema =
+ StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
+ val expectedValue = new GenericRowWithSchema(Array(10, 9), expectedSchema)
+ assert(observation.get.size === 1)
+ assert(observation.get.contains("struct"))
+ assert(observation.get("struct") === expectedValue)
+ }
+
+ test("SPARK-53490: array type in observed metrics") {
+ val observation = Observation("array")
+ spark
+ .range(10)
+ .observe(observation, array(count(lit(1))).as("array"))
+ .collect()
+ assert(observation.get.size === 1)
+ assert(observation.get.contains("array"))
+ assert(observation.get("array") === Array(10))
+ }
+
+ test("SPARK-53490: map type in observed metrics") {
+ val observation = Observation("map")
+ spark
+ .range(10)
+ .observe(observation, map(lit("count"), count(lit(1))).as("map"))
+ .collect()
+ assert(observation.get.size === 1)
+ assert(observation.get.contains("map"))
+ assert(observation.get("map") === Map("count" -> 10))
+ }
}
private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index a7a8c97..ef55edd 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -383,7 +383,7 @@
(0 until metric.getKeysCount).foreach { i =>
val key = metric.getKeys(i)
val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
- schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
+ schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
values += value
}
new GenericRowWithSchema(values.result(), schema)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 4ff555c..5e45fbb 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -109,7 +109,7 @@
ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
}
- private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
+ private[common] def toCatalystStructType(t: proto.DataType.Struct): StructType = {
val fields = t.getFieldsList.asScala.toSeq.map { protoField =>
val metadata = if (protoField.hasMetadata) {
Metadata.fromJson(protoField.getMetadata)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index ca785ff..0edb865 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -30,12 +30,13 @@
import com.google.protobuf.ByteString
import org.apache.spark.connect.proto
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.SparkClassUtils
object LiteralValueProtoConverter {
@@ -223,52 +224,51 @@
val sb = builder.getStructBuilder
val fields = structType.fields
- scalaValue match {
+ val iter = scalaValue match {
case p: Product =>
- val iter = p.productIterator
- var idx = 0
- if (options.useDeprecatedDataTypeFields) {
- while (idx < structType.size) {
- val field = fields(idx)
- // For backward compatibility, we need the data type for each field.
- val literalProto = toLiteralProtoBuilderInternal(
- iter.next(),
- field.dataType,
- options,
- needDataType = true).build()
- sb.addElements(literalProto)
- idx += 1
- }
- sb.setStructType(toConnectProtoType(structType))
- } else {
- while (idx < structType.size) {
- val field = fields(idx)
- val literalProto =
- toLiteralProtoBuilderInternal(iter.next(), field.dataType, options, needDataType)
- .build()
- sb.addElements(literalProto)
-
- if (needDataType) {
- val fieldBuilder = sb.getDataTypeStructBuilder
- .addFieldsBuilder()
- .setName(field.name)
- .setNullable(field.nullable)
-
- if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
- fieldBuilder.setDataType(toConnectProtoType(field.dataType))
- }
-
- // Set metadata if available
- if (field.metadata != Metadata.empty) {
- fieldBuilder.setMetadata(field.metadata.json)
- }
- }
-
- idx += 1
- }
- }
+ p.productIterator
+ case r: Row =>
+ r.toSeq.iterator
case other =>
- throw new IllegalArgumentException(s"literal $other not supported (yet).")
+ throw new IllegalArgumentException(
+ s"literal ${other.getClass.getName}($other) not supported (yet).")
+ }
+
+ var idx = 0
+ if (options.useDeprecatedDataTypeFields) {
+ while (idx < structType.size) {
+ val field = fields(idx)
+ val literalProto =
+ toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+ sb.addElements(literalProto)
+ idx += 1
+ }
+ sb.setStructType(toConnectProtoType(structType))
+ } else {
+ val dataTypeStruct = proto.DataType.Struct.newBuilder()
+ while (idx < structType.size) {
+ val field = fields(idx)
+ val literalProto =
+ toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+ sb.addElements(literalProto)
+
+ val fieldBuilder = dataTypeStruct
+ .addFieldsBuilder()
+ .setName(field.name)
+ .setNullable(field.nullable)
+
+ if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
+ fieldBuilder.setDataType(toConnectProtoType(field.dataType))
+ }
+
+ // Set metadata if available
+ if (field.metadata != Metadata.empty) {
+ fieldBuilder.setMetadata(field.metadata.json)
+ }
+
+ idx += 1
+ }
+ sb.setDataTypeStruct(dataTypeStruct.build())
}
sb
@@ -721,23 +721,12 @@
private def toScalaStructInternal(
struct: proto.Expression.Literal.Struct,
structType: proto.DataType.Struct): Any = {
- def toTuple[A <: Object](data: Seq[A]): Product = {
- try {
- val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
- tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
- } catch {
- case _: Exception =>
- throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
- }
- }
-
- val size = struct.getElementsCount
- val structData = Seq.tabulate(size) { i =>
+ val structData = Array.tabulate(struct.getElementsCount) { i =>
val element = struct.getElements(i)
val dataType = structType.getFields(i).getDataType
- getConverter(dataType)(element).asInstanceOf[Object]
+ getConverter(dataType)(element)
}
- toTuple(structData)
+ new GenericRowWithSchema(structData, DataTypeProtoConverter.toCatalystStructType(structType))
}
def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
@@ -759,4 +748,77 @@
def toScalaStruct(struct: proto.Expression.Literal.Struct): Any = {
toScalaStructInternal(struct, getProtoStructType(struct))
}
+
+ def getDataType(lit: proto.Expression.Literal): DataType = {
+ lit.getLiteralTypeCase match {
+ case proto.Expression.Literal.LiteralTypeCase.NULL =>
+ DataTypeProtoConverter.toCatalystType(lit.getNull)
+ case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+ BinaryType
+ case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+ BooleanType
+ case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+ ByteType
+ case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+ ShortType
+ case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+ IntegerType
+ case proto.Expression.Literal.LiteralTypeCase.LONG =>
+ LongType
+ case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+ FloatType
+ case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+ DoubleType
+ case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+ val decimal = Decimal.apply(lit.getDecimal.getValue)
+ var precision = decimal.precision
+ if (lit.getDecimal.hasPrecision) {
+ precision = math.max(precision, lit.getDecimal.getPrecision)
+ }
+ var scale = decimal.scale
+ if (lit.getDecimal.hasScale) {
+ scale = math.max(scale, lit.getDecimal.getScale)
+ }
+ DecimalType(math.max(precision, scale), scale)
+ case proto.Expression.Literal.LiteralTypeCase.STRING =>
+ StringType
+ case proto.Expression.Literal.LiteralTypeCase.DATE =>
+ DateType
+ case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+ TimestampType
+ case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+ TimestampNTZType
+ case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+ CalendarIntervalType
+ case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+ YearMonthIntervalType()
+ case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+ DayTimeIntervalType()
+ case proto.Expression.Literal.LiteralTypeCase.TIME =>
+ var precision = TimeType.DEFAULT_PRECISION
+ if (lit.getTime.hasPrecision) {
+ precision = lit.getTime.getPrecision
+ }
+ TimeType(precision)
+ case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+ DataTypeProtoConverter.toCatalystType(
+ proto.DataType.newBuilder
+ .setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
+ .build())
+ case proto.Expression.Literal.LiteralTypeCase.MAP =>
+ DataTypeProtoConverter.toCatalystType(
+ proto.DataType.newBuilder
+ .setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
+ .build())
+ case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+ DataTypeProtoConverter.toCatalystType(
+ proto.DataType.newBuilder
+ .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
+ .build())
+ case _ =>
+ throw InvalidPlanInput(
+ s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
+ s"(${lit.getLiteralTypeCase.getNumber})")
+ }
+ }
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 7c4ad7d..38ed252 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -31,6 +31,7 @@
import org.apache.spark.sql.connect.planner.InvalidInputErrors
import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils
/**
@@ -227,21 +228,22 @@
executeHolder.request.getPlan.getDescriptorForType)
}
- val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
+ val observedMetrics: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
executeHolder.observations.map { case (name, observation) =>
- val values = observation.getOrEmpty.map { case (key, value) =>
- (Some(key), value)
- }.toSeq
+ val values =
+ observation.getRowOrEmpty
+ .map(SparkConnectPlanExecution.toObservedMetricsValues(_))
+ .getOrElse(Seq.empty)
name -> values
}.toMap
}
- val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
+ val accumulatedInPython: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
accumulator.synchronized {
val value = accumulator.value.asScala.toSeq
if (value.nonEmpty) {
accumulator.reset()
- Some("__python_accumulator__" -> value.map(value => (None, value)))
+ Some("__python_accumulator__" -> value.map(value => (None, value, None)))
} else {
None
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 388fd6a..0e89681 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -27,6 +27,7 @@
import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
@@ -38,7 +39,7 @@
import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ThreadUtils
/**
@@ -281,11 +282,7 @@
dataframe: DataFrame): Option[ExecutePlanResponse] = {
val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
case (name, row) if !executeHolder.observations.contains(name) =>
- val values = if (row.schema == null) {
- (0 until row.length).map { i => (None, row(i)) }
- } else {
- (0 until row.length).map { i => (Some(row.schema.fieldNames(i)), row(i)) }
- }
+ val values = SparkConnectPlanExecution.toObservedMetricsValues(row)
name -> values
}
if (observedMetrics.nonEmpty) {
@@ -301,18 +298,34 @@
}
object SparkConnectPlanExecution {
+
+ def toObservedMetricsValues(row: Row): Seq[(Option[String], Any, Option[DataType])] = {
+ if (row.schema == null) {
+ (0 until row.length).map { i => (None, row(i), None) }
+ } else {
+ (0 until row.length).map { i =>
+ (Some(row.schema.fieldNames(i)), row(i), Some(row.schema(i).dataType))
+ }
+ }
+ }
+
def createObservedMetricsResponse(
sessionId: String,
serverSessionId: String,
observationAndPlanIds: Map[String, Long],
- metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
+ metrics: Map[String, Seq[(Option[String], Any, Option[DataType])]]): ExecutePlanResponse = {
val observedMetrics = metrics.map { case (name, values) =>
val metrics = ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
- values.foreach { case (key, value) =>
- metrics.addValues(toLiteralProto(value))
- key.foreach(metrics.addKeys)
+ values.foreach { case (keyOpt, value, dataTypeOpt) =>
+ dataTypeOpt match {
+ case Some(dataType) =>
+ metrics.addValues(toLiteralProto(value, dataType))
+ case None =>
+ metrics.addValues(toLiteralProto(value))
+ }
+ keyOpt.foreach(metrics.addKeys)
}
observationAndPlanIds.get(name).foreach(metrics.setPlanId)
metrics.build()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index 4b308e0..be7d672 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -19,7 +19,7 @@
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter}
+import org.apache.spark.sql.connect.common.{InvalidPlanInput, LiteralValueProtoConverter}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -32,107 +32,87 @@
* Expression
*/
def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
+ val dataType = LiteralValueProtoConverter.getDataType(lit)
lit.getLiteralTypeCase match {
case proto.Expression.Literal.LiteralTypeCase.NULL =>
- expressions.Literal(null, DataTypeProtoConverter.toCatalystType(lit.getNull))
+ expressions.Literal(null, dataType)
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
- expressions.Literal(lit.getBinary.toByteArray, BinaryType)
+ expressions.Literal(lit.getBinary.toByteArray, dataType)
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
- expressions.Literal(lit.getBoolean, BooleanType)
+ expressions.Literal(lit.getBoolean, dataType)
case proto.Expression.Literal.LiteralTypeCase.BYTE =>
- expressions.Literal(lit.getByte.toByte, ByteType)
+ expressions.Literal(lit.getByte.toByte, dataType)
case proto.Expression.Literal.LiteralTypeCase.SHORT =>
- expressions.Literal(lit.getShort.toShort, ShortType)
+ expressions.Literal(lit.getShort.toShort, dataType)
case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
- expressions.Literal(lit.getInteger, IntegerType)
+ expressions.Literal(lit.getInteger, dataType)
case proto.Expression.Literal.LiteralTypeCase.LONG =>
- expressions.Literal(lit.getLong, LongType)
+ expressions.Literal(lit.getLong, dataType)
case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
- expressions.Literal(lit.getFloat, FloatType)
+ expressions.Literal(lit.getFloat, dataType)
case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
- expressions.Literal(lit.getDouble, DoubleType)
+ expressions.Literal(lit.getDouble, dataType)
case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
- val decimal = Decimal.apply(lit.getDecimal.getValue)
- var precision = decimal.precision
- if (lit.getDecimal.hasPrecision) {
- precision = math.max(precision, lit.getDecimal.getPrecision)
- }
- var scale = decimal.scale
- if (lit.getDecimal.hasScale) {
- scale = math.max(scale, lit.getDecimal.getScale)
- }
- expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale))
+ expressions.Literal(Decimal.apply(lit.getDecimal.getValue), dataType)
case proto.Expression.Literal.LiteralTypeCase.STRING =>
- expressions.Literal(UTF8String.fromString(lit.getString), StringType)
+ expressions.Literal(UTF8String.fromString(lit.getString), dataType)
case proto.Expression.Literal.LiteralTypeCase.DATE =>
- expressions.Literal(lit.getDate, DateType)
+ expressions.Literal(lit.getDate, dataType)
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
- expressions.Literal(lit.getTimestamp, TimestampType)
+ expressions.Literal(lit.getTimestamp, dataType)
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
- expressions.Literal(lit.getTimestampNtz, TimestampNTZType)
+ expressions.Literal(lit.getTimestampNtz, dataType)
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
val interval = new CalendarInterval(
lit.getCalendarInterval.getMonths,
lit.getCalendarInterval.getDays,
lit.getCalendarInterval.getMicroseconds)
- expressions.Literal(interval, CalendarIntervalType)
+ expressions.Literal(interval, dataType)
case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
- expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType())
+ expressions.Literal(lit.getYearMonthInterval, dataType)
case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
- expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
+ expressions.Literal(lit.getDayTimeInterval, dataType)
case proto.Expression.Literal.LiteralTypeCase.TIME =>
var precision = TimeType.DEFAULT_PRECISION
if (lit.getTime.hasPrecision) {
precision = lit.getTime.getPrecision
}
- expressions.Literal(lit.getTime.getNano, TimeType(precision))
+ expressions.Literal(lit.getTime.getNano, dataType)
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
val arrayData = LiteralValueProtoConverter.toScalaArray(lit.getArray)
- val dataType = DataTypeProtoConverter.toCatalystType(
- proto.DataType.newBuilder
- .setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
- .build())
expressions.Literal.create(arrayData, dataType)
case proto.Expression.Literal.LiteralTypeCase.MAP =>
val mapData = LiteralValueProtoConverter.toScalaMap(lit.getMap)
- val dataType = DataTypeProtoConverter.toCatalystType(
- proto.DataType.newBuilder
- .setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
- .build())
expressions.Literal.create(mapData, dataType)
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
val structData = LiteralValueProtoConverter.toScalaStruct(lit.getStruct)
- val dataType = DataTypeProtoConverter.toCatalystType(
- proto.DataType.newBuilder
- .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
- .build())
val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
expressions.Literal(convert(structData), dataType)
case _ =>
throw InvalidPlanInput(
- s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
- s"(${lit.getLiteralTypeCase.name})")
+ s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
+ s"(${lit.getLiteralTypeCase.getNumber})")
}
}
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index dfde32c..80c185e 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -20,6 +20,8 @@
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.ToLiteralProtoOptions
@@ -51,7 +53,7 @@
}
}
- Seq(
+ Seq[(Any, DataType)](
(
(1, "string", true),
StructType(
@@ -76,7 +78,6 @@
"b",
StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType))))))),
(Array(true, false, true), ArrayType(BooleanType)),
- (Array(1.toByte, 2.toByte, 3.toByte), ArrayType(ByteType)),
(Array(1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType)),
(Array(1, 2, 3), ArrayType(IntegerType)),
(Array(1L, 2L, 3L), ArrayType(LongType)),
@@ -87,15 +88,16 @@
(
Array(Array(Array(Array(Array(Array(1, 2, 3)))))),
ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))),
- (Array(Map(1 -> 2)), ArrayType(MapType(IntegerType, IntegerType))),
(Map[String, String]("1" -> "2", "3" -> "4"), MapType(StringType, StringType)),
(Map[String, Boolean]("1" -> true, "2" -> false), MapType(StringType, BooleanType)),
(Map[Int, Int](), MapType(IntegerType, IntegerType)),
(Map(1 -> 2, 3 -> 4, 5 -> 6), MapType(IntegerType, IntegerType))).zipWithIndex.foreach {
case ((v, t), idx) =>
+ val convert = CatalystTypeConverters.createToCatalystConverter(t)
+ val expected = expressions.Literal(convert(v), t)
test(s"complex proto value and catalyst value conversion #$idx") {
- assertResult(v)(
- LiteralValueProtoConverter.toScalaValue(
+ assertResult(expected)(
+ LiteralExpressionProtoConverter.toCatalystExpression(
LiteralValueProtoConverter.toLiteralProtoWithOptions(
v,
Some(t),
@@ -103,8 +105,8 @@
}
test(s"complex proto value and catalyst value conversion #$idx - backward compatibility") {
- assertResult(v)(
- LiteralValueProtoConverter.toScalaValue(
+ assertResult(expected)(
+ LiteralExpressionProtoConverter.toCatalystExpression(
LiteralValueProtoConverter.toLiteralProtoWithOptions(
v,
Some(t),
@@ -186,12 +188,12 @@
val result = LiteralValueProtoConverter.toScalaStruct(structProto.getStruct)
val resultType = LiteralValueProtoConverter.getProtoStructType(structProto.getStruct)
- // Verify the result is a tuple with correct values
- assert(result.isInstanceOf[Product])
- val product = result.asInstanceOf[Product]
- assert(product.productArity == 2)
- assert(product.productElement(0) == 1)
- assert(product.productElement(1) == "test")
+ // Verify the result is a GenericRowWithSchema with correct values
+ assert(result.isInstanceOf[GenericRowWithSchema])
+ val row = result.asInstanceOf[GenericRowWithSchema]
+ assert(row.length == 2)
+ assert(row.get(0) == 1)
+ assert(row.get(1) == "test")
// Verify the returned struct type matches the original
assert(resultType.getFieldsCount == 2)