[SPARK-35113][SQL] Support ANSI intervals in the Hash expression
### What changes were proposed in this pull request?
Support ANSI intervals (`DayTimeIntervalType` and `YearMonthIntervalType`) in `HashExpression` and add a unit test.
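
Internally, a day-time interval value is stored as a `long` of microseconds and a year-month interval value as an `int` of months, so hashing reuses the existing long/int code paths. A minimal sketch (not part of the patch) illustrating that equivalence at the expression level, assuming the catalyst APIs used in the new test:

```scala
import java.time.{Duration, Period}

import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash}
import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, LongType, YearMonthIntervalType}

val dt = Literal.create(Duration.ofSeconds(1237123123), DayTimeIntervalType)
val ym = Literal.create(Period.ofMonths(1234), YearMonthIntervalType)

// A day-time interval should hash like the long of its microseconds,
// and a year-month interval like the int of its months.
val micros = Literal(1237123123L * 1000000L, LongType)
val months = Literal(1234, IntegerType)

assert(Murmur3Hash(Seq(dt), 42).eval() == Murmur3Hash(Seq(micros), 42).eval())
assert(Murmur3Hash(Seq(ym), 42).eval() == Murmur3Hash(Seq(months), 42).eval())
```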
### Why are the changes needed?
ANSI interval types should be supported by the hash expressions (`Murmur3Hash`, `XxHash64`, `HiveHash`) in the same way as the other atomic types.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now pass ANSI interval values to the hash expressions and functions.
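
For example, a user-facing sketch (assuming a spark-shell session on a build with this change; the column names and values are only illustrative):

```scala
import java.time.{Duration, Period}

import org.apache.spark.sql.functions.{hash, xxhash64}
import spark.implicits._  // `spark` is the SparkSession provided by spark-shell

// Duration maps to a day-time interval column, Period to a year-month interval column.
val df = Seq((Duration.ofHours(36), Period.ofMonths(14))).toDF("dt", "ym")

// With this change, hash and xxhash64 accept ANSI interval columns.
df.select(hash($"dt"), hash($"ym"), xxhash64($"dt"), xxhash64($"ym")).show()
```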
### How was this patch tested?
Added a unit test to `HashExpressionsSuite`.
Closes #32259 from AngersZhuuuu/SPARK-35113.
Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index f23c1e5..f3a8274 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -238,24 +238,30 @@
* is not exposed to users and should only be set inside spark SQL.
*
* The hash value for an expression depends on its type and seed:
- * - null: seed
- * - boolean: turn boolean into int, 1 for true, 0 for false, and then use murmur3 to
- * hash this int with seed.
- * - byte, short, int: use murmur3 to hash the input as int with seed.
- * - long: use murmur3 to hash the long input with seed.
- * - float: turn it into int: java.lang.Float.floatToIntBits(input), and hash it.
- * - double: turn it into long: java.lang.Double.doubleToLongBits(input), and hash it.
- * - decimal: if it's a small decimal, i.e. precision <= 18, turn it into long and hash
- * it. Else, turn it into bytes and hash it.
- * - calendar interval: hash `microseconds` first, and use the result as seed to hash `months`.
- * - binary: use murmur3 to hash the bytes with seed.
- * - string: get the bytes of string and hash it.
- * - array: The `result` starts with seed, then use `result` as seed, recursively
- * calculate hash value for each element, and assign the element hash value
- * to `result`.
- * - struct: The `result` starts with seed, then use `result` as seed, recursively
- * calculate hash value for each field, and assign the field hash value to
- * `result`.
+ * - null: seed
+ * - boolean: turn boolean into int, 1 for true, 0 for false,
+ * and then use murmur3 to hash this int with seed.
+ * - byte, short, int: use murmur3 to hash the input as int with seed.
+ * - long: use murmur3 to hash the long input with seed.
+ * - float: turn it into int: java.lang.Float.floatToIntBits(input), and hash it.
+ * - double: turn it into long: java.lang.Double.doubleToLongBits(input),
+ * and hash it.
+ * - decimal: if it's a small decimal, i.e. precision <= 18, turn it into long
+ * and hash it. Else, turn it into bytes and hash it.
+ * - calendar interval: hash `microseconds` first, and use the result as seed
+ * to hash `months`.
+ * - interval day to second: it stores the long value of `microseconds`; use murmur3 to hash
+ * the long input with seed.
+ * - interval year to month: it stores the int value of `months`; use murmur3 to hash the int
+ * input with seed.
+ * - binary: use murmur3 to hash the bytes with seed.
+ * - string: get the bytes of string and hash it.
+ * - array: The `result` starts with seed, then use `result` as seed, recursively
+ * calculate hash value for each element, and assign the element hash
+ * value to `result`.
+ * - struct: The `result` starts with seed, then use `result` as seed, recursively
+ * calculate hash value for each field, and assign the field hash value
+ * to `result`.
*
* Finally we aggregate the hash values for each expression by the same way of struct.
*/
@@ -475,6 +481,8 @@
case DoubleType => genHashDouble(input, result)
case d: DecimalType => genHashDecimal(ctx, d, input, result)
case CalendarIntervalType => genHashCalendarInterval(input, result)
+ case DayTimeIntervalType => genHashLong(input, result)
+ case YearMonthIntervalType => genHashInt(input, result)
case BinaryType => genHashBytes(input, result)
case StringType => genHashString(input, result)
case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index af6e5a3..858d8f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
-import java.time.{ZoneId, ZoneOffset}
+import java.time.{Duration, Period, ZoneId, ZoneOffset}
import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
@@ -697,6 +697,17 @@
checkEvaluation(XxHash64(Seq(literal), 100L), XxHash64(Seq(literal), 100).eval())
}
+ test("SPARK-35113: HashExpression support DayTimeIntervalType/YearMonthIntervalType") {
+ val dayTime = Literal.create(Duration.ofSeconds(1237123123), DayTimeIntervalType)
+ val yearMonth = Literal.create(Period.ofMonths(1234), YearMonthIntervalType)
+ checkEvaluation(Murmur3Hash(Seq(dayTime), 10), -428664612)
+ checkEvaluation(Murmur3Hash(Seq(yearMonth), 10), -686520021)
+ checkEvaluation(XxHash64(Seq(dayTime), 10), 8228802290839366895L)
+ checkEvaluation(XxHash64(Seq(yearMonth), 10), -1774215319882784110L)
+ checkEvaluation(HiveHash(Seq(dayTime)), 743331816)
+ checkEvaluation(HiveHash(Seq(yearMonth)), 1234)
+ }
+
private def testHash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
val toRow = RowEncoder(inputSchema).createSerializer()