[SPARK-53490][CONNECT][SQL] Fix Protobuf conversion in observed metrics

### What changes were proposed in this pull request?

This PR fixes a critical issue in the protobuf conversion of observed metrics in Spark Connect, specifically when dealing with complex data types like structs, arrays, and maps. The main changes include:
1. **Modified Observation class to store Row objects instead of Map[String, Any]**: Changed the internal promise type from `Promise[Map[String, Any]]` to `Promise[Row]` to preserve type information during protobuf serialization/deserialization.
2. **Enhanced protobuf conversion for complex types**:
   - Added proper handling for struct types by creating `GenericRowWithSchema` objects instead of tuples
   - Added support for map type conversion in `LiteralValueProtoConverter`
   - Improved data type inference with a new `getDataType` method that properly handles all literal types
3. **Fixed observed metrics**: Modified the observed metrics processing to include data type information in the protobuf conversion, ensuring that complex types are properly serialized and deserialized.

### Why are the changes needed?

The previous implementation had several issues:
1. **Data type loss**: Observed metrics were losing their original data types during Protobuf conversion, causing errors
2. **Struct handling problems**: The conversion logic didn't properly handle Row objects and struct types

### Does this PR introduce _any_ user-facing change?

Yes, this PR fixes a bug that was preventing users from successfully using observed metrics with complex data types (structs, arrays, maps) in Spark Connect. Users can now:
- Use `struct()` expressions in observed metrics and receive properly typed `GenericRowWithSchema` objects
- Use `array()` expressions in observed metrics and receive properly typed arrays
- Use `map()` expressions in observed metrics and receive properly typed maps

Previously, the code below would fail.
```scala
val observation = Observation("struct")
spark
  .range(10)
  .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
  .collect()
observation.get
// Below is the error message:
"""
org.apache.spark.SparkUnsupportedOperationException: literal [10,9] not supported (yet).
org.apache.spark.sql.connect.common.LiteralValueProtoConverter$.toLiteralProtoBuilder(LiteralValueProtoConverter.scala:104)
org.apache.spark.sql.connect.common.LiteralValueProtoConverter$.toLiteralProto(LiteralValueProtoConverter.scala:203)
org.apache.spark.sql.connect.execution.SparkConnectPlanExecution$.$anonfun$createObservedMetricsResponse$2(SparkConnectPlanExecution.scala:571)
org.apache.spark.sql.connect.execution.SparkConnectPlanExecution$.$anonfun$createObservedMetricsResponse$2$adapted(SparkConnectPlanExecution.scala:570)
"""
```

### How was this patch tested?

`build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite -- -z SPARK-53490"`
`build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.5.9

Closes #52236 from heyihong/SPARK-53490.

Authored-by: Yihong He <heyihong.cn@gmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index fa427fe..5e27198 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -58,12 +58,12 @@
 
   private val isRegistered = new AtomicBoolean()
 
-  private val promise = Promise[Map[String, Any]]()
+  private val promise = Promise[Row]()
 
   /**
    * Future holding the (yet to be completed) observation.
    */
-  val future: Future[Map[String, Any]] = promise.future
+  val future: Future[Row] = promise.future
 
   /**
    * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -76,7 +76,10 @@
    *   interrupted while waiting
    */
   @throws[InterruptedException]
-  def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf)
+  def get: Map[String, Any] = {
+    val row = SparkThreadUtils.awaitResult(future, Duration.Inf)
+    row.getValuesMap(row.schema.map(_.name))
+  }
 
   /**
    * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish its
@@ -99,7 +102,8 @@
    */
   @throws[InterruptedException]
   private[sql] def getOrEmpty: Map[String, Any] = {
-    Try(SparkThreadUtils.awaitResult(future, 100.millis)).getOrElse(Map.empty)
+    val row = getRowOrEmpty.getOrElse(Row.empty)
+    row.getValuesMap(row.schema.map(_.name))
   }
 
   /**
@@ -118,8 +122,17 @@
    *   `true` if all waiting threads were notified, `false` if otherwise.
    */
   private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
-    val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
-    promise.trySuccess(metricsMap)
+    promise.trySuccess(metrics)
+  }
+
+  /**
+   * Get the observed metrics as a Row.
+   *
+   * @return
+   *   the observed metrics as a `Row`, or None if the metrics are not available.
+   */
+  private[sql] def getRowOrEmpty: Option[Row] = {
+    Try(SparkThreadUtils.awaitResult(future, 100.millis)).toOption
   }
 }
 
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 2f21665..e221300 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1749,6 +1749,42 @@
     val nullRows = nullResult.filter(_.getAs[Long]("id") >= 5)
     assert(nullRows.forall(_.getAs[Int]("actual_p_id") == 0))
   }
+
+  test("SPARK-53490: struct type in observed metrics") {
+    val observation = Observation("struct")
+    spark
+      .range(10)
+      .observe(observation, struct(count(lit(1)).as("rows"), max("id").as("maxid")).as("struct"))
+      .collect()
+    val expectedSchema =
+      StructType(Seq(StructField("rows", LongType), StructField("maxid", LongType)))
+    val expectedValue = new GenericRowWithSchema(Array(10, 9), expectedSchema)
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("struct"))
+    assert(observation.get("struct") === expectedValue)
+  }
+
+  test("SPARK-53490: array type in observed metrics") {
+    val observation = Observation("array")
+    spark
+      .range(10)
+      .observe(observation, array(count(lit(1))).as("array"))
+      .collect()
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("array"))
+    assert(observation.get("array") === Array(10))
+  }
+
+  test("SPARK-53490: map type in observed metrics") {
+    val observation = Observation("map")
+    spark
+      .range(10)
+      .observe(observation, map(lit("count"), count(lit(1))).as("map"))
+      .collect()
+    assert(observation.get.size === 1)
+    assert(observation.get.contains("map"))
+    assert(observation.get("map") === Map("count" -> 10))
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index a7a8c97..ef55edd 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -383,7 +383,7 @@
     (0 until metric.getKeysCount).foreach { i =>
       val key = metric.getKeys(i)
       val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
-      schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass))
+      schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
       values += value
     }
     new GenericRowWithSchema(values.result(), schema)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 4ff555c..5e45fbb 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -109,7 +109,7 @@
     ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
   }
 
-  private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
+  private[common] def toCatalystStructType(t: proto.DataType.Struct): StructType = {
     val fields = t.getFieldsList.asScala.toSeq.map { protoField =>
       val metadata = if (protoField.hasMetadata) {
         Metadata.fromJson(protoField.getMetadata)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index ca785ff..0edb865 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -30,12 +30,13 @@
 import com.google.protobuf.ByteString
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.SparkClassUtils
 
 object LiteralValueProtoConverter {
 
@@ -223,52 +224,51 @@
       val sb = builder.getStructBuilder
       val fields = structType.fields
 
-      scalaValue match {
+      val iter = scalaValue match {
         case p: Product =>
-          val iter = p.productIterator
-          var idx = 0
-          if (options.useDeprecatedDataTypeFields) {
-            while (idx < structType.size) {
-              val field = fields(idx)
-              // For backward compatibility, we need the data type for each field.
-              val literalProto = toLiteralProtoBuilderInternal(
-                iter.next(),
-                field.dataType,
-                options,
-                needDataType = true).build()
-              sb.addElements(literalProto)
-              idx += 1
-            }
-            sb.setStructType(toConnectProtoType(structType))
-          } else {
-            while (idx < structType.size) {
-              val field = fields(idx)
-              val literalProto =
-                toLiteralProtoBuilderInternal(iter.next(), field.dataType, options, needDataType)
-                  .build()
-              sb.addElements(literalProto)
-
-              if (needDataType) {
-                val fieldBuilder = sb.getDataTypeStructBuilder
-                  .addFieldsBuilder()
-                  .setName(field.name)
-                  .setNullable(field.nullable)
-
-                if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
-                  fieldBuilder.setDataType(toConnectProtoType(field.dataType))
-                }
-
-                // Set metadata if available
-                if (field.metadata != Metadata.empty) {
-                  fieldBuilder.setMetadata(field.metadata.json)
-                }
-              }
-
-              idx += 1
-            }
-          }
+          p.productIterator
+        case r: Row =>
+          r.toSeq.iterator
         case other =>
-          throw new IllegalArgumentException(s"literal $other not supported (yet).")
+          throw new IllegalArgumentException(
+            s"literal ${other.getClass.getName}($other) not supported (yet).")
+      }
+
+      var idx = 0
+      if (options.useDeprecatedDataTypeFields) {
+        while (idx < structType.size) {
+          val field = fields(idx)
+          val literalProto =
+            toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+          sb.addElements(literalProto)
+          idx += 1
+        }
+        sb.setStructType(toConnectProtoType(structType))
+      } else {
+        val dataTypeStruct = proto.DataType.Struct.newBuilder()
+        while (idx < structType.size) {
+          val field = fields(idx)
+          val literalProto =
+            toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options)
+          sb.addElements(literalProto)
+
+          val fieldBuilder = dataTypeStruct
+            .addFieldsBuilder()
+            .setName(field.name)
+            .setNullable(field.nullable)
+
+          if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
+            fieldBuilder.setDataType(toConnectProtoType(field.dataType))
+          }
+
+          // Set metadata if available
+          if (field.metadata != Metadata.empty) {
+            fieldBuilder.setMetadata(field.metadata.json)
+          }
+
+          idx += 1
+        }
+        sb.setDataTypeStruct(dataTypeStruct.build())
       }
 
       sb
@@ -721,23 +721,12 @@
   private def toScalaStructInternal(
       struct: proto.Expression.Literal.Struct,
       structType: proto.DataType.Struct): Any = {
-    def toTuple[A <: Object](data: Seq[A]): Product = {
-      try {
-        val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
-        tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product]
-      } catch {
-        case _: Exception =>
-          throw InvalidPlanInput(s"Unsupported Literal: ${data.mkString("Array(", ", ", ")")})")
-      }
-    }
-
-    val size = struct.getElementsCount
-    val structData = Seq.tabulate(size) { i =>
+    val structData = Array.tabulate(struct.getElementsCount) { i =>
       val element = struct.getElements(i)
       val dataType = structType.getFields(i).getDataType
-      getConverter(dataType)(element).asInstanceOf[Object]
+      getConverter(dataType)(element)
     }
-    toTuple(structData)
+    new GenericRowWithSchema(structData, DataTypeProtoConverter.toCatalystStructType(structType))
   }
 
   def getProtoStructType(struct: proto.Expression.Literal.Struct): proto.DataType.Struct = {
@@ -759,4 +748,77 @@
   def toScalaStruct(struct: proto.Expression.Literal.Struct): Any = {
     toScalaStructInternal(struct, getProtoStructType(struct))
   }
+
+  def getDataType(lit: proto.Expression.Literal): DataType = {
+    lit.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.NULL =>
+        DataTypeProtoConverter.toCatalystType(lit.getNull)
+      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+        BinaryType
+      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+        BooleanType
+      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+        ByteType
+      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+        ShortType
+      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+        IntegerType
+      case proto.Expression.Literal.LiteralTypeCase.LONG =>
+        LongType
+      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+        FloatType
+      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+        DoubleType
+      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+        val decimal = Decimal.apply(lit.getDecimal.getValue)
+        var precision = decimal.precision
+        if (lit.getDecimal.hasPrecision) {
+          precision = math.max(precision, lit.getDecimal.getPrecision)
+        }
+        var scale = decimal.scale
+        if (lit.getDecimal.hasScale) {
+          scale = math.max(scale, lit.getDecimal.getScale)
+        }
+        DecimalType(math.max(precision, scale), scale)
+      case proto.Expression.Literal.LiteralTypeCase.STRING =>
+        StringType
+      case proto.Expression.Literal.LiteralTypeCase.DATE =>
+        DateType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+        TimestampType
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+        TimestampNTZType
+      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+        CalendarIntervalType
+      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+        YearMonthIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+        DayTimeIntervalType()
+      case proto.Expression.Literal.LiteralTypeCase.TIME =>
+        var precision = TimeType.DEFAULT_PRECISION
+        if (lit.getTime.hasPrecision) {
+          precision = lit.getTime.getPrecision
+        }
+        TimeType(precision)
+      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.MAP =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
+            .build())
+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        DataTypeProtoConverter.toCatalystType(
+          proto.DataType.newBuilder
+            .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
+            .build())
+      case _ =>
+        throw InvalidPlanInput(
+          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
+            s"(${lit.getLiteralTypeCase.getNumber})")
+    }
+  }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 7c4ad7d..38ed252 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -31,6 +31,7 @@
 import org.apache.spark.sql.connect.planner.InvalidInputErrors
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
 
 /**
@@ -227,21 +228,22 @@
             executeHolder.request.getPlan.getDescriptorForType)
       }
 
-      val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
+      val observedMetrics: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
         executeHolder.observations.map { case (name, observation) =>
-          val values = observation.getOrEmpty.map { case (key, value) =>
-            (Some(key), value)
-          }.toSeq
+          val values =
+            observation.getRowOrEmpty
+              .map(SparkConnectPlanExecution.toObservedMetricsValues(_))
+              .getOrElse(Seq.empty)
           name -> values
         }.toMap
       }
-      val accumulatedInPython: Map[String, Seq[(Option[String], Any)]] = {
+      val accumulatedInPython: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
         executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
           accumulator.synchronized {
             val value = accumulator.value.asScala.toSeq
             if (value.nonEmpty) {
               accumulator.reset()
-              Some("__python_accumulator__" -> value.map(value => (None, value)))
+              Some("__python_accumulator__" -> value.map(value => (None, value, None)))
             } else {
               None
             }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 388fd6a..0e89681 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -27,6 +27,7 @@
 import org.apache.spark.SparkEnv
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
@@ -38,7 +39,7 @@
 import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -281,11 +282,7 @@
       dataframe: DataFrame): Option[ExecutePlanResponse] = {
     val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
       case (name, row) if !executeHolder.observations.contains(name) =>
-        val values = if (row.schema == null) {
-          (0 until row.length).map { i => (None, row(i)) }
-        } else {
-          (0 until row.length).map { i => (Some(row.schema.fieldNames(i)), row(i)) }
-        }
+        val values = SparkConnectPlanExecution.toObservedMetricsValues(row)
         name -> values
     }
     if (observedMetrics.nonEmpty) {
@@ -301,18 +298,34 @@
 }
 
 object SparkConnectPlanExecution {
+
+  def toObservedMetricsValues(row: Row): Seq[(Option[String], Any, Option[DataType])] = {
+    if (row.schema == null) {
+      (0 until row.length).map { i => (None, row(i), None) }
+    } else {
+      (0 until row.length).map { i =>
+        (Some(row.schema.fieldNames(i)), row(i), Some(row.schema(i).dataType))
+      }
+    }
+  }
+
   def createObservedMetricsResponse(
       sessionId: String,
       serverSessionId: String,
       observationAndPlanIds: Map[String, Long],
-      metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
+      metrics: Map[String, Seq[(Option[String], Any, Option[DataType])]]): ExecutePlanResponse = {
     val observedMetrics = metrics.map { case (name, values) =>
       val metrics = ExecutePlanResponse.ObservedMetrics
         .newBuilder()
         .setName(name)
-      values.foreach { case (key, value) =>
-        metrics.addValues(toLiteralProto(value))
-        key.foreach(metrics.addKeys)
+      values.foreach { case (keyOpt, value, dataTypeOpt) =>
+        dataTypeOpt match {
+          case Some(dataType) =>
+            metrics.addValues(toLiteralProto(value, dataType))
+          case None =>
+            metrics.addValues(toLiteralProto(value))
+        }
+        keyOpt.foreach(metrics.addKeys)
       }
       observationAndPlanIds.get(name).foreach(metrics.setPlanId)
       metrics.build()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index 4b308e0..be7d672 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -19,7 +19,7 @@
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter}
+import org.apache.spark.sql.connect.common.{InvalidPlanInput, LiteralValueProtoConverter}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -32,107 +32,87 @@
    *   Expression
    */
   def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
+    val dataType = LiteralValueProtoConverter.getDataType(lit)
     lit.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.NULL =>
-        expressions.Literal(null, DataTypeProtoConverter.toCatalystType(lit.getNull))
+        expressions.Literal(null, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.BINARY =>
-        expressions.Literal(lit.getBinary.toByteArray, BinaryType)
+        expressions.Literal(lit.getBinary.toByteArray, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
-        expressions.Literal(lit.getBoolean, BooleanType)
+        expressions.Literal(lit.getBoolean, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.BYTE =>
-        expressions.Literal(lit.getByte.toByte, ByteType)
+        expressions.Literal(lit.getByte.toByte, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.SHORT =>
-        expressions.Literal(lit.getShort.toShort, ShortType)
+        expressions.Literal(lit.getShort.toShort, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
-        expressions.Literal(lit.getInteger, IntegerType)
+        expressions.Literal(lit.getInteger, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.LONG =>
-        expressions.Literal(lit.getLong, LongType)
+        expressions.Literal(lit.getLong, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
-        expressions.Literal(lit.getFloat, FloatType)
+        expressions.Literal(lit.getFloat, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
-        expressions.Literal(lit.getDouble, DoubleType)
+        expressions.Literal(lit.getDouble, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
-        val decimal = Decimal.apply(lit.getDecimal.getValue)
-        var precision = decimal.precision
-        if (lit.getDecimal.hasPrecision) {
-          precision = math.max(precision, lit.getDecimal.getPrecision)
-        }
-        var scale = decimal.scale
-        if (lit.getDecimal.hasScale) {
-          scale = math.max(scale, lit.getDecimal.getScale)
-        }
-        expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale))
+        expressions.Literal(Decimal.apply(lit.getDecimal.getValue), dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.STRING =>
-        expressions.Literal(UTF8String.fromString(lit.getString), StringType)
+        expressions.Literal(UTF8String.fromString(lit.getString), dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.DATE =>
-        expressions.Literal(lit.getDate, DateType)
+        expressions.Literal(lit.getDate, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
-        expressions.Literal(lit.getTimestamp, TimestampType)
+        expressions.Literal(lit.getTimestamp, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
-        expressions.Literal(lit.getTimestampNtz, TimestampNTZType)
+        expressions.Literal(lit.getTimestampNtz, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
         val interval = new CalendarInterval(
           lit.getCalendarInterval.getMonths,
           lit.getCalendarInterval.getDays,
           lit.getCalendarInterval.getMicroseconds)
-        expressions.Literal(interval, CalendarIntervalType)
+        expressions.Literal(interval, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
-        expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType())
+        expressions.Literal(lit.getYearMonthInterval, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
-        expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
+        expressions.Literal(lit.getDayTimeInterval, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.TIME =>
         var precision = TimeType.DEFAULT_PRECISION
         if (lit.getTime.hasPrecision) {
           precision = lit.getTime.getPrecision
         }
-        expressions.Literal(lit.getTime.getNano, TimeType(precision))
+        expressions.Literal(lit.getTime.getNano, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
         val arrayData = LiteralValueProtoConverter.toScalaArray(lit.getArray)
-        val dataType = DataTypeProtoConverter.toCatalystType(
-          proto.DataType.newBuilder
-            .setArray(LiteralValueProtoConverter.getProtoArrayType(lit.getArray))
-            .build())
         expressions.Literal.create(arrayData, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.MAP =>
         val mapData = LiteralValueProtoConverter.toScalaMap(lit.getMap)
-        val dataType = DataTypeProtoConverter.toCatalystType(
-          proto.DataType.newBuilder
-            .setMap(LiteralValueProtoConverter.getProtoMapType(lit.getMap))
-            .build())
         expressions.Literal.create(mapData, dataType)
 
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         val structData = LiteralValueProtoConverter.toScalaStruct(lit.getStruct)
-        val dataType = DataTypeProtoConverter.toCatalystType(
-          proto.DataType.newBuilder
-            .setStruct(LiteralValueProtoConverter.getProtoStructType(lit.getStruct))
-            .build())
         val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
         expressions.Literal(convert(structData), dataType)
 
       case _ =>
         throw InvalidPlanInput(
-          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
-            s"(${lit.getLiteralTypeCase.name})")
+          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
+            s"(${lit.getLiteralTypeCase.getNumber})")
     }
   }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index dfde32c..80c185e 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -20,6 +20,8 @@
 import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.ToLiteralProtoOptions
@@ -51,7 +53,7 @@
     }
   }
 
-  Seq(
+  Seq[(Any, DataType)](
     (
       (1, "string", true),
       StructType(
@@ -76,7 +78,6 @@
             "b",
             StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType))))))),
     (Array(true, false, true), ArrayType(BooleanType)),
-    (Array(1.toByte, 2.toByte, 3.toByte), ArrayType(ByteType)),
     (Array(1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType)),
     (Array(1, 2, 3), ArrayType(IntegerType)),
     (Array(1L, 2L, 3L), ArrayType(LongType)),
@@ -87,15 +88,16 @@
     (
       Array(Array(Array(Array(Array(Array(1, 2, 3)))))),
       ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))),
-    (Array(Map(1 -> 2)), ArrayType(MapType(IntegerType, IntegerType))),
     (Map[String, String]("1" -> "2", "3" -> "4"), MapType(StringType, StringType)),
     (Map[String, Boolean]("1" -> true, "2" -> false), MapType(StringType, BooleanType)),
     (Map[Int, Int](), MapType(IntegerType, IntegerType)),
     (Map(1 -> 2, 3 -> 4, 5 -> 6), MapType(IntegerType, IntegerType))).zipWithIndex.foreach {
     case ((v, t), idx) =>
+      val convert = CatalystTypeConverters.createToCatalystConverter(t)
+      val expected = expressions.Literal(convert(v), t)
       test(s"complex proto value and catalyst value conversion #$idx") {
-        assertResult(v)(
-          LiteralValueProtoConverter.toScalaValue(
+        assertResult(expected)(
+          LiteralExpressionProtoConverter.toCatalystExpression(
             LiteralValueProtoConverter.toLiteralProtoWithOptions(
               v,
               Some(t),
@@ -103,8 +105,8 @@
       }
 
       test(s"complex proto value and catalyst value conversion #$idx - backward compatibility") {
-        assertResult(v)(
-          LiteralValueProtoConverter.toScalaValue(
+        assertResult(expected)(
+          LiteralExpressionProtoConverter.toCatalystExpression(
             LiteralValueProtoConverter.toLiteralProtoWithOptions(
               v,
               Some(t),
@@ -186,12 +188,12 @@
     val result = LiteralValueProtoConverter.toScalaStruct(structProto.getStruct)
     val resultType = LiteralValueProtoConverter.getProtoStructType(structProto.getStruct)
 
-    // Verify the result is a tuple with correct values
-    assert(result.isInstanceOf[Product])
-    val product = result.asInstanceOf[Product]
-    assert(product.productArity == 2)
-    assert(product.productElement(0) == 1)
-    assert(product.productElement(1) == "test")
+    // Verify the result is a GenericRowWithSchema with correct values
+    assert(result.isInstanceOf[GenericRowWithSchema])
+    val row = result.asInstanceOf[GenericRowWithSchema]
+    assert(row.length == 2)
+    assert(row.get(0) == 1)
+    assert(row.get(1) == "test")
 
     // Verify the returned struct type matches the original
     assert(resultType.getFieldsCount == 2)