/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import java.nio.charset.StandardCharsets
import java.time.{Duration, Period, ZoneId, ZoneOffset}

import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions

import org.apache.commons.codec.digest.DigestUtils
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.types.{ArrayType, StructType, _}
import org.apache.spark.unsafe.types.UTF8String

class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val random = new scala.util.Random
implicit def stringToUTF8Str(str: String): UTF8String = UTF8String.fromString(str)
test("md5") {
checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
"902fbdd2b1df0c4f70b4a5d23525e932")
checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
"6ac1e56bc78f031059be7be854522c4c")
checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)
}
test("sha1") {
checkEvaluation(Sha1(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
"3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
"5d211bad8f4ee70e16c7d343a838fc344a1ed961")
checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
checkEvaluation(Sha1(Literal("".getBytes(StandardCharsets.UTF_8))),
"da39a3ee5e6b4b0d3255bfef95601890afd80709")
checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)
}
test("sha2") {
checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)), Literal(256)),
DigestUtils.sha256Hex("ABC"))
checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
// unsupported bit length
checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
checkEvaluation(Sha2(Literal("ABC".getBytes(StandardCharsets.UTF_8)),
Literal.create(null, IntegerType)), null)
checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
}
test("crc32") {
checkEvaluation(Crc32(Literal("ABC".getBytes(StandardCharsets.UTF_8))), 2743272264L)
checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
2180413220L)
checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
// Note: all expected hashes need to be computed using Hive 1.2.1
val actual = HiveHashFunction.hash(input, dataType, seed = 0)
withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
assert(actual == expected)
}
}
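// For the integral types below (byte/short/int), Hive's hash of a value is the
// (sign-extended) value itself, so every expected hash is simply the input.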
def checkHiveHashForIntegralType(dataType: DataType): Unit = {
// corner cases
checkHiveHash(null, dataType, 0)
checkHiveHash(1, dataType, 1)
checkHiveHash(0, dataType, 0)
checkHiveHash(-1, dataType, -1)
checkHiveHash(Int.MaxValue, dataType, Int.MaxValue)
checkHiveHash(Int.MinValue, dataType, Int.MinValue)
// random values
for (_ <- 0 until 10) {
val input = random.nextInt()
checkHiveHash(input, dataType, input)
}
}
test("hive-hash for null") {
checkHiveHash(null, NullType, 0)
}
test("hive-hash for boolean") {
checkHiveHash(true, BooleanType, 1)
checkHiveHash(false, BooleanType, 0)
}
test("hive-hash for byte") {
checkHiveHashForIntegralType(ByteType)
}
test("hive-hash for short") {
checkHiveHashForIntegralType(ShortType)
}
test("hive-hash for int") {
checkHiveHashForIntegralType(IntegerType)
}
test("hive-hash for long") {
checkHiveHash(1L, LongType, 1L)
checkHiveHash(0L, LongType, 0L)
checkHiveHash(-1L, LongType, 0L)
checkHiveHash(Long.MaxValue, LongType, -2147483648)
// Hive fails to parse this, but the hashing function itself can handle the input
checkHiveHash(Long.MinValue, LongType, -2147483648)
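// Hive folds a long into an int the same way java.lang.Long.hashCode does:
// (input >>> 32) ^ input, truncated to 32 bits, as the loop below verifies.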
for (_ <- 0 until 10) {
val input = random.nextLong()
checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt)
}
}
test("hive-hash for float") {
checkHiveHash(0F, FloatType, 0)
checkHiveHash(0.0F, FloatType, 0)
checkHiveHash(1.1F, FloatType, 1066192077L)
checkHiveHash(-1.1F, FloatType, -1081291571)
checkHiveHash(99999999.99999999999F, FloatType, 1287568416L)
checkHiveHash(Float.MaxValue, FloatType, 2139095039)
checkHiveHash(Float.MinValue, FloatType, -8388609)
}
test("hive-hash for double") {
checkHiveHash(0, DoubleType, 0)
checkHiveHash(0.0, DoubleType, 0)
checkHiveHash(1.1, DoubleType, -1503133693)
checkHiveHash(-1.1, DoubleType, 644349955)
checkHiveHash(1000000000.000001, DoubleType, 1104006509)
checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501)
checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676)
checkHiveHash(Double.MaxValue, DoubleType, -2146435072)
checkHiveHash(Double.MinValue, DoubleType, 1048576)
}
test("hive-hash for string") {
checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L)
checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L)
checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L)
checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L)
// scalastyle:off nonascii
checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L)
checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L)
// scalastyle:on nonascii
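// A minimal reference sketch of the assumed algorithm: a fold over the UTF-8 bytes,
// analogous to java.lang.String.hashCode but byte-wise. `hiveStringHashSketch` is
// illustrative only, not a Spark API.
def hiveStringHashSketch(s: String): Int =
  s.getBytes(StandardCharsets.UTF_8).foldLeft(0)((h, b) => 31 * h + b)
assert(hiveStringHashSketch("apache spark") == 1142704523)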
}
test("hive-hash for date type") {
def checkHiveHashForDateType(dateString: String, expected: Long): Unit = {
checkHiveHash(
DateTimeUtils.stringToDate(UTF8String.fromString(dateString), ZoneOffset.UTC).get,
DateType,
expected)
}
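// Hive hashes a date as its days-since-epoch int, so 1970-01-01 maps to 0 and
// 2017-01-01 to 17167.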
// basic case
checkHiveHashForDateType("2017-01-01", 17167)
// boundary cases
checkHiveHashForDateType("0000-01-01", -719528)
checkHiveHashForDateType("9999-12-31", 2932896)
// epoch
checkHiveHashForDateType("1970-01-01", 0)
// before epoch
checkHiveHashForDateType("1800-01-01", -62091)
// Invalid input: bad date string. Hive returns 0 for such cases
intercept[NoSuchElementException](checkHiveHashForDateType("0-0-0", 0))
intercept[NoSuchElementException](checkHiveHashForDateType("-1212-01-01", 0))
intercept[NoSuchElementException](checkHiveHashForDateType("2016-99-99", 0))
// Invalid input: Empty string. Hive returns 0 for this case
intercept[NoSuchElementException](checkHiveHashForDateType("", 0))
// Invalid input: February 30th of a leap year. Hive supports this but Spark doesn't
intercept[NoSuchElementException](checkHiveHashForDateType("2016-02-30", 16861))
}
test("hive-hash for timestamp type") {
def checkHiveHashForTimestampType(
timestamp: String,
expected: Long,
zoneId: ZoneId = ZoneOffset.UTC): Unit = {
checkHiveHash(
DateTimeUtils.stringToTimestamp(UTF8String.fromString(timestamp), zoneId).get,
TimestampType,
expected)
}
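// Assumption on the shape of the algorithm (mirroring Hive's TimestampWritable):
// seconds and nanoseconds are packed into a single long, which is then folded into
// an int like LongType above. The expected values were produced by Hive 1.2.1.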
// basic case
checkHiveHashForTimestampType("2017-02-24 10:56:29", 1445725271)
// with higher precision
checkHiveHashForTimestampType("2017-02-24 10:56:29.111111", 1353936655)
// with different timezone
checkHiveHashForTimestampType("2017-02-24 10:56:29", 1445732471,
DateTimeUtils.getZoneId("US/Pacific"))
// boundary cases
checkHiveHashForTimestampType("0001-01-01 00:00:00", 1645969984)
checkHiveHashForTimestampType("9999-01-01 00:00:00", -1081818240)
// epoch
checkHiveHashForTimestampType("1970-01-01 00:00:00", 0)
// before epoch
checkHiveHashForTimestampType("1800-01-01 03:12:45", -267420885)
// Invalid input: bad timestamp string. Hive returns 0 for such cases
intercept[NoSuchElementException](checkHiveHashForTimestampType("0-0-0 0:0:0", 0))
intercept[NoSuchElementException](checkHiveHashForTimestampType("-99-99-99 99:99:45", 0))
intercept[NoSuchElementException](checkHiveHashForTimestampType("555555-55555-5555", 0))
// Invalid input: Empty string. Hive returns 0 for this case
intercept[NoSuchElementException](checkHiveHashForTimestampType("", 0))
// Invalid input: February 30th of a leap year. Hive supports this but Spark doesn't
intercept[NoSuchElementException](checkHiveHashForTimestampType("2016-02-30 00:00:00", 0))
// Invalid input: Hive accepts up to 9 decimal places of precision but Spark uses up to 6
intercept[TestFailedException](checkHiveHashForTimestampType("2017-02-24 10:56:29.11111111", 0))
}
test("hive-hash for CalendarInterval type") {
def checkHiveHashForIntervalType(interval: String, expected: Long): Unit = {
checkHiveHash(IntervalUtils.stringToInterval(UTF8String.fromString(interval)),
CalendarIntervalType, expected)
}
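// Note that every zero-length interval below hashes to the same base value, 23273,
// however many zero units spell it out: only the normalized duration matters.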
// ----- MICROSEC -----
// basic case
checkHiveHashForIntervalType("interval 1 microsecond", 24273)
// negative
checkHiveHashForIntervalType("interval -1 microsecond", 22273)
// edge / boundary cases
checkHiveHashForIntervalType("interval 0 microsecond", 23273)
checkHiveHashForIntervalType("interval 999 microsecond", 1022273)
checkHiveHashForIntervalType("interval -999 microsecond", -975727)
// ----- MILLISEC -----
// basic case
checkHiveHashForIntervalType("interval 1 millisecond", 1023273)
// negative
checkHiveHashForIntervalType("interval -1 millisecond", -976727)
// edge / boundary cases
checkHiveHashForIntervalType("interval 0 millisecond", 23273)
checkHiveHashForIntervalType("interval 999 millisecond", 999023273)
checkHiveHashForIntervalType("interval -999 millisecond", -998976727)
// ----- SECOND -----
// basic case
checkHiveHashForIntervalType("interval 1 second", 23310)
// negative
checkHiveHashForIntervalType("interval -1 second", 23273)
// edge / boundary cases
checkHiveHashForIntervalType("interval 0 second", 23273)
checkHiveHashForIntervalType("interval 2147483647 second", -2147460412)
checkHiveHashForIntervalType("interval -2147483648 second", -2147460412)
// Out of range for both Hive and Spark
// Hive throws an exception; Spark overflows and returns an incorrect result
// checkHiveHashForIntervalType("interval 9999999999 second", 0)
// ----- MINUTE -----
// basic cases
checkHiveHashForIntervalType("interval 1 minute", 25493)
// negative
checkHiveHashForIntervalType("interval -1 minute", 25456)
// edge / boundary cases
checkHiveHashForIntervalType("interval 0 minute", 23273)
checkHiveHashForIntervalType("interval 2147483647 minute", 21830)
checkHiveHashForIntervalType("interval -2147483648 minute", 22163)
// Out of range for both Hive and Spark
// Hive throws an exception; Spark overflows and returns an incorrect result
// checkHiveHashForIntervalType("interval 9999999999 minute", 0)
// ----- HOUR -----
// basic case
checkHiveHashForIntervalType("interval 1 hour", 156473)
// negative
checkHiveHashForIntervalType("interval -1 hour", 156436)
// edge / boundary cases
checkHiveHashForIntervalType("interval 0 hour", 23273)
checkHiveHashForIntervalType("interval 2147483647 hour", -62308)
checkHiveHashForIntervalType("interval -2147483648 hour", -43327)
// Out of range for both Hive and Spark
// Hive throws an exception; Spark overflows and returns an incorrect result
// checkHiveHashForIntervalType("interval 9999999999 hour", 0)
// ----- DAY -----
// basic cases
checkHiveHashForIntervalType("interval 1 day", 3220073)
// negative
checkHiveHashForIntervalType("interval -1 day", 3220036)
// edge / boundary cases
checkHiveHashForIntervalType("interval 0 day", 23273)
checkHiveHashForIntervalType("interval 106751991 day", -451506760)
checkHiveHashForIntervalType("interval -106751991 day", -451514123)
// Hive supports `day` for a longer range but Spark's range is smaller
// The check for range is done at the parser level so this does not fail in Spark
// checkHiveHashForIntervalType("interval -2147483648 day", -1575127)
// checkHiveHashForIntervalType("interval 2147483647 day", -4767228)
// Out of range for both Hive and Spark
// Hive throws an exception; Spark overflows and returns an incorrect result
// checkHiveHashForIntervalType("interval 9999999999 day", 0)
// ----- MIX -----
checkHiveHashForIntervalType("interval 0 day 0 hour", 23273)
checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute", 23273)
checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute 0 second", 23273)
checkHiveHashForIntervalType("interval 0 day 0 hour 0 minute 0 second 0 millisecond", 23273)
checkHiveHashForIntervalType(
"interval 0 day 0 hour 0 minute 0 second 0 millisecond 0 microsecond", 23273)
checkHiveHashForIntervalType("interval 6 day 15 hour", 21202073)
checkHiveHashForIntervalType("interval 5 day 4 hour 8 minute", 16557833)
checkHiveHashForIntervalType("interval -23 day 56 hour -1111113 minute 9898989 second",
-2128468593)
checkHiveHashForIntervalType("interval 66 day 12 hour 39 minute 23 second 987 millisecond",
1199697904)
checkHiveHashForIntervalType(
"interval 66 day 12 hour 39 minute 23 second 987 millisecond 123 microsecond", 1199820904)
}
test("hive-hash for array") {
// empty array
checkHiveHash(
input = new GenericArrayData(Array[Int]()),
dataType = ArrayType(IntegerType, containsNull = false),
expected = 0)
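// Hive combines array elements like java.util.List.hashCode, but seeded with 0:
// h = 31 * h + hash(elem), with nulls contributing 0. For the basic case below,
// 31 * (31 * (31 * 0 + 1) + 10000) + Int.MaxValue overflows to -2147172688.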
// basic case
checkHiveHash(
input = new GenericArrayData(Array(1, 10000, Int.MaxValue)),
dataType = ArrayType(IntegerType, containsNull = false),
expected = -2147172688L)
// with negative values
checkHiveHash(
input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)),
dataType = ArrayType(LongType, containsNull = false),
expected = -2147452680L)
// with nulls only
val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true)
checkHiveHash(
input = new GenericArrayData(Array(null, null)),
dataType = arrayTypeWithNull,
expected = 0)
// mix with null
checkHiveHash(
input = new GenericArrayData(Array(-12221, 89, null, 767)),
dataType = arrayTypeWithNull,
expected = -363989515)
// nested with array
checkHiveHash(
input = new GenericArrayData(
Array(
new GenericArrayData(Array(1234L, -9L, 67L)),
new GenericArrayData(Array(null, null)),
new GenericArrayData(Array(55L, -100L, -2147452680L))
)),
dataType = ArrayType(ArrayType(LongType)),
expected = -1007531064)
// nested with map
checkHiveHash(
input = new GenericArrayData(
Array(
new ArrayBasedMapData(
new GenericArrayData(Array(-99, 1234)),
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
new ArrayBasedMapData(
new GenericArrayData(Array(67)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
)),
dataType = ArrayType(MapType(IntegerType, StringType)),
expected = 1139205955)
}
test("hive-hash for map") {
val mapType = MapType(IntegerType, StringType)
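// Hive combines map entries like java.util.Map.hashCode: the result is the sum of
// hash(key) ^ hash(value) over all entries, with a null value hashing to 0. For the
// basic case below, (1 ^ hash("foo")) + (2 ^ hash("bar")) == 101575 + 97297 == 198872.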
// empty map
checkHiveHash(
input = new ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())),
dataType = mapType,
expected = 0)
// basic case
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(1, 2)),
new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))),
dataType = mapType,
expected = 198872)
// with null value
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(55, -99)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))),
dataType = mapType,
expected = 1142704473)
// nesting (only values can be nested, since keys must be a primitive datatype)
val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType))
checkHiveHash(
input = new ArrayBasedMapData(
new GenericArrayData(Array(1, -100)),
new GenericArrayData(
Array(
new ArrayBasedMapData(
new GenericArrayData(Array(-99, 1234)),
new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
new ArrayBasedMapData(
new GenericArrayData(Array(67)),
new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
))),
dataType = nestedMapType,
expected = -1142817416)
}
test("hive-hash for struct") {
// basic
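// Hive folds struct fields exactly like array elements: h = 31 * h + hash(field),
// so for the row (1, 2, 3): 31 * (31 * (31 * 0 + 1) + 2) + 3 == 1026.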
val row = new GenericInternalRow(Array[Any](1, 2, 3))
checkHiveHash(
input = row,
dataType =
new StructType()
.add("col1", IntegerType)
.add("col2", IntegerType)
.add("col3", IntegerType),
expected = 1026)
// mix of several datatypes
val structType = new StructType()
.add("null", NullType)
.add("boolean", BooleanType)
.add("byte", ByteType)
.add("short", ShortType)
.add("int", IntegerType)
.add("long", LongType)
.add("arrayOfString", arrayOfString)
.add("mapOfString", mapOfString)
val rowValues = new ArrayBuffer[Any]()
rowValues += null
rowValues += true
rowValues += 1
rowValues += 2
rowValues += Int.MaxValue
rowValues += Long.MinValue
rowValues += new GenericArrayData(Array(
UTF8String.fromString("apache spark"),
UTF8String.fromString("hello world")
))
rowValues += new ArrayBasedMapData(
new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))),
new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))
)
val row2 = new GenericInternalRow(rowValues.toArray)
checkHiveHash(
input = row2,
dataType = structType,
expected = -2119012447)
}
private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
private val arrayOfNull = ArrayType(NullType)
private val mapOfString = MapType(StringType, StringType)
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
testHash(
new StructType()
.add("null", NullType)
.add("boolean", BooleanType)
.add("byte", ByteType)
.add("short", ShortType)
.add("int", IntegerType)
.add("long", LongType)
.add("float", FloatType)
.add("double", DoubleType)
.add("bigDecimal", DecimalType.SYSTEM_DEFAULT)
.add("smallDecimal", DecimalType.USER_DEFAULT)
.add("string", StringType)
.add("binary", BinaryType)
.add("date", DateType)
.add("timestamp", TimestampType)
.add("udt", new ExamplePointUDT))
testHash(
new StructType()
.add("arrayOfNull", arrayOfNull)
.add("arrayOfString", arrayOfString)
.add("arrayOfArrayOfString", ArrayType(arrayOfString))
.add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
.add("arrayOfStruct", ArrayType(structOfString))
.add("arrayOfUDT", arrayOfUDT))
testHash(
new StructType()
.add("structOfString", structOfString)
.add("structOfStructOfString", new StructType().add("struct", structOfString))
.add("structOfArray", new StructType().add("array", arrayOfString))
.add("structOfUDT", structOfUDT))
test("hive-hash for decimal") {
def checkHiveHashForDecimal(
input: String,
precision: Int,
scale: Int,
expected: Long): Unit = {
val decimalType = DataTypes.createDecimalType(precision, scale)
val decimal = {
val value = Decimal.apply(new java.math.BigDecimal(input))
if (value.changePrecision(precision, scale)) value else null
}
checkHiveHash(decimal, decimalType, expected)
}
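// Hive normalizes a decimal before hashing (trailing zeros are stripped), which is
// why the same digits hash identically across different scales below once the scale
// is wide enough to represent them.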
checkHiveHashForDecimal("18", 38, 0, 558)
checkHiveHashForDecimal("-18", 38, 0, -558)
checkHiveHashForDecimal("-18", 38, 12, -558)
checkHiveHashForDecimal("18446744073709001000", 38, 19, 0)
checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0)
checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656)
checkHiveHashForDecimal("00000.00000000000", 38, 34, 0)
checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0)
checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974)
checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252)
checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234)
checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136)
checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136)
checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666)
}
test("SPARK-18207: Compute hash for a lot of expressions") {
def checkResult(schema: StructType, input: InternalRow): Unit = {
val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
BoundReference(i, f.dataType, true)
}
val murmur3HashExpr = Murmur3Hash(exprs, 42)
val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
val murmur3HashEval = Murmur3Hash(exprs, 42).eval(input)
assert(murmur3HashPlan(input).getInt(0) == murmur3HashEval)
val xxHash64Expr = XxHash64(exprs, 42)
val xxHash64Plan = GenerateMutableProjection.generate(Seq(xxHash64Expr))
val xxHash64Eval = XxHash64(exprs, 42).eval(input)
assert(xxHash64Plan(input).getLong(0) == xxHash64Eval)
val hiveHashExpr = HiveHash(exprs)
val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
val hiveHashEval = HiveHash(exprs).eval(input)
assert(hiveHashPlan(input).getInt(0) == hiveHashEval)
}
val N = 1000
val wideRow = new GenericInternalRow(
Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any])
val schema = StructType((1 to N).map(i => StructField(i.toString, StringType)))
checkResult(schema, wideRow)
val nestedRow = InternalRow(wideRow)
val nestedSchema = new StructType().add("nested", schema)
checkResult(nestedSchema, nestedRow)
}
test("SPARK-22284: Compute hash for nested structs") {
val M = 80
val N = 10
val L = M * N
val O = 50
val seed = 42
val wideRow = new GenericInternalRow(Seq.tabulate(O)(k =>
new GenericInternalRow(Seq.tabulate(M)(j =>
new GenericInternalRow(Seq.tabulate(N)(i =>
new GenericInternalRow(Array[Any](
UTF8String.fromString((k * L + j * N + i).toString))))
.toArray[Any])).toArray[Any])).toArray[Any])
val inner = new StructType(
(0 until N).map(_ => StructField("structOfString", structOfString)).toArray)
val outer = new StructType(
(0 until M).map(_ => StructField("structOfStructOfString", inner)).toArray)
val schema = new StructType(
(0 until O).map(_ => StructField("structOfStructOfStructOfString", outer)).toArray)
val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
BoundReference(i, f.dataType, true)
}
val murmur3HashExpr = Murmur3Hash(exprs, 42)
val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr))
val murmur3HashEval = Murmur3Hash(exprs, 42).eval(wideRow)
assert(murmur3HashPlan(wideRow).getInt(0) == murmur3HashEval)
}
test("SPARK-30633: xxHash with different type seeds") {
val literal = Literal.create(42L, LongType)
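// The seed may be given as either an Int or a Long; both forms must produce the
// same hash (including at each type's extremes), as the mixed checks at the end show.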
val longSeeds = Seq(
Long.MinValue,
Integer.MIN_VALUE.toLong - 1L,
0L,
Integer.MAX_VALUE.toLong + 1L,
Long.MaxValue
)
for (seed <- longSeeds) {
checkEvaluation(XxHash64(Seq(literal), seed), XxHash64(Seq(literal), seed).eval())
}
val intSeeds = Seq(
Integer.MIN_VALUE,
0,
Integer.MAX_VALUE
)
for (seed <- intSeeds) {
checkEvaluation(XxHash64(Seq(literal), seed), XxHash64(Seq(literal), seed).eval())
}
checkEvaluation(XxHash64(Seq(literal), 100), XxHash64(Seq(literal), 100L).eval())
checkEvaluation(XxHash64(Seq(literal), 100L), XxHash64(Seq(literal), 100).eval())
}
test("SPARK-35113: HashExpression support DayTimeIntervalType/YearMonthIntervalType") {
val dayTime = Literal.create(Duration.ofSeconds(1237123123), DayTimeIntervalType)
val yearMonth = Literal.create(Period.ofMonths(1234), YearMonthIntervalType)
checkEvaluation(Murmur3Hash(Seq(dayTime), 10), -428664612)
checkEvaluation(Murmur3Hash(Seq(yearMonth), 10), -686520021)
checkEvaluation(XxHash64(Seq(dayTime), 10), 8228802290839366895L)
checkEvaluation(XxHash64(Seq(yearMonth), 10), -1774215319882784110L)
checkEvaluation(HiveHash(Seq(dayTime)), 743331816)
checkEvaluation(HiveHash(Seq(yearMonth)), 1234)
}
private def testHash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
val toRow = RowEncoder(inputSchema).createSerializer()
val seed = scala.util.Random.nextInt()
test(s"murmur3/xxHash64/hive hash: ${inputSchema.simpleString}") {
for (_ <- 1 to 10) {
val input = toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
case (value, dt) => Literal.create(value, dt)
}
// Check that the interpreted version produces the same result as the codegen version.
checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval())
checkEvaluation(XxHash64(literals, seed), XxHash64(literals, seed).eval())
checkEvaluation(HiveHash(literals), HiveHash(literals).eval())
}
}
val longSeed = Math.abs(seed).toLong + Integer.MAX_VALUE.toLong
test(s"SPARK-30633: xxHash64 with long seed: ${inputSchema.simpleString}") {
for (_ <- 1 to 10) {
val input = toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
case (value, dt) => Literal.create(value, dt)
}
// Check that the interpreted version produces the same result as the codegen version.
checkEvaluation(XxHash64(literals, longSeed), XxHash64(literals, longSeed).eval())
}
}
}
}