[SPARK-47681][SQL][FOLLOWUP] Fix variant decimal handling

### What changes were proposed in this pull request?

There are two issues with the current variant decimal handling:
1. The precision and scale of the `BigDecimal` returned by `getDecimal` is not checked. Based on the variant spec, they must be within the corresponding limit for DECIMAL4/8/16. An out-of-range decimal can lead to failure in the downstream Spark operations.
2. The current `schema_of_variant` implementation doesn't correctly handle the case where precision is smaller than scale. Spark's `DecimalType` requires `precision >= scale`.

The Python side requires a similar fix for 1. During the fix, I found that Python error reporting was not correctly implemented (it was never tested either) and I also fixed it.

### Why are the changes needed?

They are bug fixes and are required to process decimals correctly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46338 from chenhao-db/fix_variant_decimal.

Authored-by: Chenhao Li <chenhao.li@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index e4e9cc8..84e3a45 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -392,6 +392,13 @@
     return Double.longBitsToDouble(readLong(value, pos + 1, 8));
   }
 
+  // Check whether the precision and scale of the decimal are within the limit.
+  private static void checkDecimal(BigDecimal d, int maxPrecision) {
+    if (d.precision() > maxPrecision || d.scale() > maxPrecision) {
+      throw malformedVariant();
+    }
+  }
+
   // Get a decimal value from variant value `value[pos...]`.
   // Throw `MALFORMED_VARIANT` if the variant is malformed.
   public static BigDecimal getDecimal(byte[] value, int pos) {
@@ -399,14 +406,18 @@
     int basicType = value[pos] & BASIC_TYPE_MASK;
     int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
     if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL);
-    int scale = value[pos + 1];
+    // Interpret the scale byte as unsigned. If it is a negative byte, the unsigned value must be
+    // greater than `MAX_DECIMAL16_PRECISION` and will trigger an error in `checkDecimal`.
+    int scale = value[pos + 1] & 0xFF;
     BigDecimal result;
     switch (typeInfo) {
       case DECIMAL4:
         result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale);
+        checkDecimal(result, MAX_DECIMAL4_PRECISION);
         break;
       case DECIMAL8:
         result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale);
+        checkDecimal(result, MAX_DECIMAL8_PRECISION);
         break;
       case DECIMAL16:
         checkIndex(pos + 17, value.length);
@@ -417,6 +428,7 @@
           bytes[i] = value[pos + 17 - i];
         }
         result = new BigDecimal(new BigInteger(bytes), scale);
+        checkDecimal(result, MAX_DECIMAL16_PRECISION);
         break;
       default:
         throw unexpectedType(Type.DECIMAL);
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 7771791..906bf78 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -482,6 +482,11 @@
       "<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
     ]
   },
+  "MALFORMED_VARIANT" : {
+    "message" : [
+      "Variant binary is malformed. Please check the data source is valid."
+    ]
+  },
   "MASTER_URL_NOT_SET": {
     "message": [
       "A master URL must be set in your configuration."
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 7d45adb..40eded6 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1577,6 +1577,19 @@
         # check repr
         self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))
 
+        metadata = bytes([1, 0, 0])
+        self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
+        self.assertEqual(str(VariantVal(bytes([32, 1, 2, 0, 0, 0]), metadata)), "0.2")
+        self.assertEqual(str(VariantVal(bytes([32, 2, 3, 0, 0, 0]), metadata)), "0.03")
+        self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
+        self.assertEqual(str(VariantVal(bytes([32, 0, 255, 201, 154, 59]), metadata)), "999999999")
+        self.assertRaises(
+            PySparkValueError, lambda: str(VariantVal(bytes([32, 0, 0, 202, 154, 59]), metadata))
+        )
+        self.assertRaises(
+            PySparkValueError, lambda: str(VariantVal(bytes([32, 10, 1, 0, 0, 0]), metadata))
+        )
+
     def test_from_ddl(self):
         self.assertEqual(DataType.fromDDL("long"), LongType())
         self.assertEqual(
diff --git a/python/pyspark/sql/variant_utils.py b/python/pyspark/sql/variant_utils.py
index 1ee1395..95084fc 100644
--- a/python/pyspark/sql/variant_utils.py
+++ b/python/pyspark/sql/variant_utils.py
@@ -115,6 +115,13 @@
     )
     EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)
 
+    MAX_DECIMAL4_PRECISION = 9
+    MAX_DECIMAL4_VALUE = 10**MAX_DECIMAL4_PRECISION
+    MAX_DECIMAL8_PRECISION = 18
+    MAX_DECIMAL8_VALUE = 10**MAX_DECIMAL8_PRECISION
+    MAX_DECIMAL16_PRECISION = 38
+    MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION
+
     @classmethod
     def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
         """
@@ -142,7 +149,7 @@
     @classmethod
     def _check_index(cls, pos: int, length: int) -> None:
         if pos < 0 or pos >= length:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
@@ -162,14 +169,14 @@
         offset_size = ((metadata[0] >> 6) & 0x3) + 1
         dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
         if id >= dict_size:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         string_start = 1 + (dict_size + 2) * offset_size
         offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
         next_offset = cls._read_long(
             metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
         )
         if offset > next_offset:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         cls._check_index(string_start + next_offset - 1, len(metadata))
         return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")
 
@@ -180,7 +187,7 @@
         if basic_type != VariantUtils.PRIMITIVE or (
             type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
         ):
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         return type_info == VariantUtils.TRUE
 
     @classmethod
@@ -188,7 +195,7 @@
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.INT1:
             return cls._read_long(value, pos + 1, 1, signed=True)
         elif type_info == VariantUtils.INT2:
@@ -197,25 +204,25 @@
             return cls._read_long(value, pos + 1, 4, signed=True)
         elif type_info == VariantUtils.INT8:
             return cls._read_long(value, pos + 1, 8, signed=True)
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_date(cls, value: bytes, pos: int) -> datetime.date:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.DATE:
             days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
             return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.TIMESTAMP_NTZ:
             microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
             return VariantUtils.EPOCH_NTZ + datetime.timedelta(
@@ -226,7 +233,7 @@
             return (
                 VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
             ).astimezone(ZoneInfo(zone_id))
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_string(cls, value: bytes, pos: int) -> str:
@@ -245,39 +252,51 @@
                 length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
             cls._check_index(start + length - 1, len(value))
             return value[start : start + length].decode("utf-8")
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_double(cls, value: bytes, pos: int) -> float:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         if type_info == VariantUtils.FLOAT:
             cls._check_index(pos + 4, len(value))
             return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
         elif type_info == VariantUtils.DOUBLE:
             cls._check_index(pos + 8, len(value))
             return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
+
+    @classmethod
+    def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int) -> None:
+        # max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant
+        # computation.
+        if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         scale = value[pos + 1]
         unscaled = 0
         if type_info == VariantUtils.DECIMAL4:
             unscaled = cls._read_long(value, pos + 2, 4, signed=True)
+            cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL4_VALUE, cls.MAX_DECIMAL4_PRECISION)
         elif type_info == VariantUtils.DECIMAL8:
             unscaled = cls._read_long(value, pos + 2, 8, signed=True)
+            cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL8_VALUE, cls.MAX_DECIMAL8_PRECISION)
         elif type_info == VariantUtils.DECIMAL16:
             cls._check_index(pos + 17, len(value))
             unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
+            cls._check_decimal(
+                unscaled, scale, cls.MAX_DECIMAL16_VALUE, cls.MAX_DECIMAL16_PRECISION
+            )
         else:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
 
     @classmethod
@@ -285,7 +304,7 @@
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         start = pos + 1 + VariantUtils.U32_SIZE
         length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
         cls._check_index(start + length - 1, len(value))
@@ -331,7 +350,7 @@
             return datetime.datetime
         elif type_info == VariantUtils.LONG_STR:
             return str
-        raise PySparkValueError(error_class="MALFORMED_VARIANT")
+        raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
@@ -419,7 +438,7 @@
         elif variant_type == datetime.datetime:
             return cls._get_timestamp(value, pos, zone_id)
         else:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
 
     @classmethod
     def _handle_object(
@@ -432,7 +451,7 @@
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.OBJECT:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         large_size = ((type_info >> 4) & 0x1) != 0
         size_bytes = VariantUtils.U32_SIZE if large_size else 1
         num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
@@ -461,7 +480,7 @@
         cls._check_index(pos, len(value))
         basic_type, type_info = cls._get_type_info(value, pos)
         if basic_type != VariantUtils.ARRAY:
-            raise PySparkValueError(error_class="MALFORMED_VARIANT")
+            raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
         large_size = ((type_info >> 2) & 0x1) != 0
         size_bytes = VariantUtils.U32_SIZE if large_size else 1
         num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 43c561e..3dbc724 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -680,7 +680,8 @@
     case Type.DOUBLE => DoubleType
     case Type.DECIMAL =>
       val d = v.getDecimal
-      DecimalType(d.precision(), d.scale())
+      // Spark doesn't allow `DecimalType` to have `precision < scale`.
+      DecimalType(d.precision().max(d.scale()), d.scale())
     case Type.DATE => DateType
     case Type.TIMESTAMP => TimestampType
     case Type.TIMESTAMP_NTZ => TimestampNTZType
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index d001f0ec..f4a6a14 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -58,6 +58,9 @@
     check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
     // DECIMAL16 only has 15 byte content.
     check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata)
+    // 1e38 has a precision of 39. Even if it still fits into 16 bytes, it is not a valid decimal.
+    check(Array[Byte](primitiveHeader(DECIMAL16), 0) ++
+      BigDecimal(1e38).toBigInt.toByteArray.reverse, emptyMetadata)
     // Short string content too short.
     check(Array(shortStrHeader(2), 'x'), emptyMetadata)
     // Long string length too short (requires 4 bytes).
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index c8d267f..3964bf3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -157,6 +157,7 @@
     check("null", "VOID")
     check("1", "BIGINT")
     check("1.0", "DECIMAL(1,0)")
+    check("0.01", "DECIMAL(2,2)")
     check("1E0", "DOUBLE")
     check("true", "BOOLEAN")
     check("\"2000-01-01\"", "STRING")