blob: f3a82743182a9ca859eb80eba8b467ae6e3b3f26 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.math.{BigDecimal, RoundingMode}
import java.security.{MessageDigest, NoSuchAlgorithmException}
import java.util.concurrent.TimeUnit._
import java.util.zip.CRC32
import scala.annotation.tailrec
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines all the expressions for hashing.
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Expression that computes the MD5 (128-bit) digest of its child and yields it as a hex string.
 * The child is declared to be of type [[BinaryType]].
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns an MD5 128-bit checksum as a hex string of `expr`.",
  examples = """
Examples:
> SELECT _FUNC_('Spark');
8cde774d6f7333752ed72cacddb05126
""",
  since = "1.5.0",
  group = "hash_funcs")
case class Md5(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def dataType: DataType = StringType

  override def inputTypes: Seq[DataType] = Seq(BinaryType)

  /** Interpreted path: hex-encodes the MD5 digest of the (non-null) binary input. */
  protected override def nullSafeEval(input: Any): Any = {
    val bytes = input.asInstanceOf[Array[Byte]]
    val hexDigest = DigestUtils.md5Hex(bytes)
    UTF8String.fromString(hexDigest)
  }

  /** Codegen path: emits the same `DigestUtils.md5Hex` call as a Java expression. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val digestUtils = classOf[DigestUtils].getName
    defineCodeGen(ctx, ev, bytes => s"UTF8String.fromString($digestUtils.md5Hex($bytes))")
  }

  override protected def withNewChildInternal(newChild: Expression): Md5 = copy(child = newChild)
}
/**
 * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512)
 * and returns it as a hex string. The first argument is the string or binary to be hashed. The
 * second argument indicates the desired bit length of the result, which must have a value of 224,
 * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If
 * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or
 * the hash length is not one of the permitted values, the return value is NULL.
 *
 * All supported bit lengths, including 224, produce a lower-case hex string.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
_FUNC_(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of `expr`.
SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.
""",
  examples = """
Examples:
> SELECT _FUNC_('Spark', 256);
529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b
""",
  since = "1.5.0",
  group = "hash_funcs")
// scalastyle:on line.size.limit
case class Sha2(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {

  override def dataType: DataType = StringType

  // The result is NULL for an unsupported bit length even when both inputs are non-null.
  override def nullable: Boolean = true

  override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)

  /**
   * Interpreted evaluation.
   *
   * @param input1 the binary payload to digest (non-null)
   * @param input2 the requested digest bit length (non-null)
   * @return the hex-encoded digest as a [[UTF8String]], or null when the bit length is not
   *         one of 0, 224, 256, 384, 512 or SHA-224 is unavailable on this JVM
   */
  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    val bitLength = input2.asInstanceOf[Int]
    val input = input1.asInstanceOf[Array[Byte]]
    bitLength match {
      case 224 =>
        // DigestUtils is not used for SHA-224 here, so go through MessageDigest directly.
        try {
          val md = MessageDigest.getInstance("SHA-224")
          md.update(input)
          // The raw digest is arbitrary bytes, not text: hex-encode it so that the 224-bit
          // variant returns a hex string exactly like the other bit lengths do.
          UTF8String.fromString(org.apache.commons.codec.binary.Hex.encodeHexString(md.digest()))
        } catch {
          // SHA-224 is not supported on the system, return null
          case _: NoSuchAlgorithmException => null
        }
      case 256 | 0 =>
        UTF8String.fromString(DigestUtils.sha256Hex(input))
      case 384 =>
        UTF8String.fromString(DigestUtils.sha384Hex(input))
      case 512 =>
        UTF8String.fromString(DigestUtils.sha512Hex(input))
      case _ => null
    }
  }

  /**
   * Codegen evaluation; mirrors [[nullSafeEval]] branch for branch, including hex-encoding of
   * the SHA-224 digest.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val digestUtils = classOf[DigestUtils].getName
    val hex = classOf[org.apache.commons.codec.binary.Hex].getName
    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
      s"""
if ($eval2 == 224) {
try {
java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
md.update($eval1);
${ev.value} = UTF8String.fromString($hex.encodeHexString(md.digest()));
} catch (java.security.NoSuchAlgorithmException e) {
${ev.isNull} = true;
}
} else if ($eval2 == 256 || $eval2 == 0) {
${ev.value} =
UTF8String.fromString($digestUtils.sha256Hex($eval1));
} else if ($eval2 == 384) {
${ev.value} =
UTF8String.fromString($digestUtils.sha384Hex($eval1));
} else if ($eval2 == 512) {
${ev.value} =
UTF8String.fromString($digestUtils.sha512Hex($eval1));
} else {
${ev.isNull} = true;
}
"""
    })
  }

  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Sha2 =
    copy(left = newLeft, right = newRight)
}
/**
 * Expression that computes the SHA-1 digest of its child and yields it as a hex string.
 * The child is declared to be of type [[BinaryType]].
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns a sha1 hash value as a hex string of the `expr`.",
  examples = """
Examples:
> SELECT _FUNC_('Spark');
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c
""",
  since = "1.5.0",
  group = "hash_funcs")
case class Sha1(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def dataType: DataType = StringType

  override def inputTypes: Seq[DataType] = Seq(BinaryType)

  /** Interpreted path: hex-encodes the SHA-1 digest of the (non-null) binary input. */
  protected override def nullSafeEval(input: Any): Any = {
    val bytes = input.asInstanceOf[Array[Byte]]
    val hexDigest = DigestUtils.sha1Hex(bytes)
    UTF8String.fromString(hexDigest)
  }

  /** Codegen path: emits the same `DigestUtils.sha1Hex` call as a Java expression. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val digestUtils = classOf[DigestUtils].getName
    defineCodeGen(ctx, ev, bytes => s"UTF8String.fromString($digestUtils.sha1Hex($bytes))")
  }

  override protected def withNewChildInternal(newChild: Expression): Sha1 = copy(child = newChild)
}
/**
 * Expression that computes the CRC-32 checksum of its child and yields it as a bigint.
 * The child is declared to be of type [[BinaryType]].
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns a cyclic redundancy check value of the `expr` as a bigint.",
  examples = """
Examples:
> SELECT _FUNC_('Spark');
1557323817
""",
  since = "1.5.0",
  group = "hash_funcs")
case class Crc32(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def dataType: DataType = LongType

  override def inputTypes: Seq[DataType] = Seq(BinaryType)

  /** Interpreted path: feeds the whole byte array to a fresh `CRC32` and returns its value. */
  protected override def nullSafeEval(input: Any): Any = {
    val bytes = input.asInstanceOf[Array[Byte]]
    val crc = new CRC32
    crc.update(bytes, 0, bytes.length)
    crc.getValue
  }

  /** Codegen path: emits Java that performs the same `CRC32` update over the full array. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val crcClass = "java.util.zip.CRC32"
    val crcVar = ctx.freshName("checksum")
    nullSafeCodeGen(ctx, ev, bytes => {
      s"""
$crcClass $crcVar = new $crcClass();
$crcVar.update($bytes, 0, $bytes.length);
${ev.value} = $crcVar.getValue();
"""
    })
  }

  override protected def withNewChildInternal(newChild: Expression): Crc32 = copy(child = newChild)
}
/**
 * A function that calculates hash value for a group of expressions. Note that the `seed` argument
 * is not exposed to users and should only be set inside spark SQL.
 *
 * The hash value for an expression depends on its type and seed:
 *  - null:                    seed
 *  - boolean:                 turn boolean into int, 1 for true, 0 for false,
 *                             and then use murmur3 to hash this int with seed.
 *  - byte, short, int:        use murmur3 to hash the input as int with seed.
 *  - long:                    use murmur3 to hash the long input with seed.
 *  - float:                   turn it into int: java.lang.Float.floatToIntBits(input), and hash it.
 *  - double:                  turn it into long: java.lang.Double.doubleToLongBits(input),
 *                             and hash it.
 *  - decimal:                 if it's a small decimal, i.e. precision <= 18, turn it into long
 *                             and hash it. Else, turn it into bytes and hash it.
 *  - calendar interval:       hash `microseconds` first, use the result as seed to hash `days`,
 *                             and use that result as seed to hash `months`.
 *  - interval day to second:  it store long value of `microseconds`, use murmur3 to hash the long
 *                             input with seed.
 *  - interval year to month:  it store int value of `months`, use murmur3 to hash the int
 *                             input with seed.
 *  - binary:                  use murmur3 to hash the bytes with seed.
 *  - string:                  get the bytes of string and hash it.
 *  - array:                   The `result` starts with seed, then use `result` as seed, recursively
 *                             calculate hash value for each element, and assign the element hash
 *                             value to `result`.
 *  - struct:                  The `result` starts with seed, then use `result` as seed, recursively
 *                             calculate hash value for each field, and assign the field hash value
 *                             to `result`.
 *
 * Finally we aggregate the hash values for each expression by the same way of struct.
 */
abstract class HashExpression[E] extends Expression {
  /** Seed of the HashExpression. */
  val seed: E

  override def foldable: Boolean = children.forall(_.foldable)

  override def nullable: Boolean = false

  /** Returns true if `dt` is, or contains at any nesting level, a [[MapType]]. */
  private def hasMapType(dt: DataType): Boolean = {
    dt.existsRecursively(_.isInstanceOf[MapType])
  }

  /**
   * Fails the type check when there are no children, or when any child contains a [[MapType]]
   * and the legacy flag allowing hashing of maps is not enabled.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.length < 1) {
      TypeCheckResult.TypeCheckFailure(
        s"input to function $prettyName requires at least one argument")
    } else if (children.exists(child => hasMapType(child.dataType)) &&
        !SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE)) {
      TypeCheckResult.TypeCheckFailure(
        s"input to function $prettyName cannot contain elements of MapType. In Spark, same maps " +
          "may have different hashcode, thus hash expressions are prohibited on MapType elements." +
          s" To restore previous behavior set ${SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE.key} " +
          "to true.")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  /**
   * Interpreted evaluation: starts from `seed` and threads the running hash through every
   * child, using the hash so far as the seed for the next child.
   */
  override def eval(input: InternalRow = null): Any = {
    var hash = seed
    var i = 0
    val len = children.length
    while (i < len) {
      hash = computeHash(children(i).eval(input), children(i).dataType, hash)
      i += 1
    }
    hash
  }

  /** Interpreted hash of a single `value` of type `dataType`, seeded with `seed`. */
  protected def computeHash(value: Any, dataType: DataType, seed: E): E

  /**
   * Codegen evaluation: declares the result variable initialised to the seed, then emits one
   * (null-guarded) hash update per child, split into helper functions via
   * `splitExpressionsWithCurrentInputs`.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.isNull = FalseLiteral

    val childrenHash = children.map { child =>
      val childGen = child.genCode(ctx)
      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
        computeHash(childGen.value, child.dataType, ev.value, ctx)
      }
    }

    val hashResultType = CodeGenerator.javaType(dataType)
    // A long-typed seed needs the `L` suffix to be a valid Java long literal.
    val typedSeed = if (dataType.sameType(LongType)) s"${seed}L" else s"$seed"
    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = childrenHash,
      funcName = "computeHash",
      extraArguments = Seq(hashResultType -> ev.value),
      returnType = hashResultType,
      makeSplitFunction = body =>
        s"""
           |$body
           |return ${ev.value};
           |""".stripMargin,
      foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))

    ev.copy(code =
      code"""
            |$hashResultType ${ev.value} = $typedSeed;
            |$codes
            |""".stripMargin)
  }

  /**
   * Generates code that reads element `index` of `input` into a fresh local and folds its hash
   * into `result`; when `nullable` is true the code is guarded by an `isNullAt` check.
   */
  protected def nullSafeElementHash(
      input: String,
      index: String,
      nullable: Boolean,
      elementType: DataType,
      result: String,
      ctx: CodegenContext): String = {
    val element = ctx.freshName("element")
    val jt = CodeGenerator.javaType(elementType)
    ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
      s"""
final $jt $element = ${CodeGenerator.getValue(input, elementType, index)};
${computeHash(element, elementType, result, ctx)}
"""
    }
  }

  /** Generates `result = hasher.hashInt(i, result);`. */
  protected def genHashInt(i: String, result: String): String =
    s"$result = $hasherClassName.hashInt($i, $result);"

  /** Generates `result = hasher.hashLong(l, result);`. */
  protected def genHashLong(l: String, result: String): String =
    s"$result = $hasherClassName.hashLong($l, $result);"

  /** Generates a `hashUnsafeBytes` call over the whole byte array `b`. */
  protected def genHashBytes(b: String, result: String): String = {
    val offset = "Platform.BYTE_ARRAY_OFFSET"
    s"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);"
  }

  /** Booleans are hashed as the int 1 (true) or 0 (false). */
  protected def genHashBoolean(input: String, result: String): String =
    genHashInt(s"$input ? 1 : 0", result)

  /** Floats are hashed via `Float.floatToIntBits`. */
  protected def genHashFloat(input: String, result: String): String =
    genHashInt(s"Float.floatToIntBits($input)", result)

  /** Doubles are hashed via `Double.doubleToLongBits`. */
  protected def genHashDouble(input: String, result: String): String =
    genHashLong(s"Double.doubleToLongBits($input)", result)

  /**
   * Decimals whose precision fits in a long are hashed as their unscaled long; wider decimals
   * are hashed as the byte array of their unscaled `BigInteger`.
   */
  protected def genHashDecimal(
      ctx: CodegenContext,
      d: DecimalType,
      input: String,
      result: String): String = {
    if (d.precision <= Decimal.MAX_LONG_DIGITS) {
      genHashLong(s"$input.toUnscaledLong()", result)
    } else {
      val bytes = ctx.freshName("bytes")
      s"""
         |final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
         |${genHashBytes(bytes, result)}
         |""".stripMargin
    }
  }

  /** Timestamps are hashed as their underlying long value. */
  protected def genHashTimestamp(t: String, result: String): String = genHashLong(t, result)

  /**
   * Calendar intervals chain three hashes: `microseconds` seeded with `result`, then `days`,
   * then `months`. Including `days` keeps the generated code in agreement with the interpreted
   * path in `InterpretedHashFunction.hash`, which hashes all three fields in this same order.
   */
  protected def genHashCalendarInterval(input: String, result: String): String = {
    val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)"
    val daysHash = s"$hasherClassName.hashInt($input.days, $microsecondsHash)"
    s"$result = $hasherClassName.hashInt($input.months, $daysHash);"
  }

  /** Strings are hashed over their backing bytes (base object, offset, byte count). */
  protected def genHashString(input: String, result: String): String = {
    val baseObject = s"$input.getBaseObject()"
    val baseOffset = s"$input.getBaseOffset()"
    val numBytes = s"$input.numBytes()"
    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
  }

  /** Maps fold each entry's key hash and then its value hash into `result`, in entry order. */
  protected def genHashForMap(
      ctx: CodegenContext,
      input: String,
      result: String,
      keyType: DataType,
      valueType: DataType,
      valueContainsNull: Boolean): String = {
    val index = ctx.freshName("index")
    val keys = ctx.freshName("keys")
    val values = ctx.freshName("values")
    s"""
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
for (int $index = 0; $index < $input.numElements(); $index++) {
${nullSafeElementHash(keys, index, false, keyType, result, ctx)}
${nullSafeElementHash(values, index, valueContainsNull, valueType, result, ctx)}
}
"""
  }

  /** Arrays fold each element's hash into `result`, in element order. */
  protected def genHashForArray(
      ctx: CodegenContext,
      input: String,
      result: String,
      elementType: DataType,
      containsNull: Boolean): String = {
    val index = ctx.freshName("index")
    s"""
for (int $index = 0; $index < $input.numElements(); $index++) {
${nullSafeElementHash(input, index, containsNull, elementType, result, ctx)}
}
"""
  }

  /**
   * Structs fold each field's hash into `result`, in field order; the per-field snippets are
   * handed to `splitExpressions` so they can be spread over helper functions.
   */
  protected def genHashForStruct(
      ctx: CodegenContext,
      input: String,
      result: String,
      fields: Array[StructField]): String = {
    val tmpInput = ctx.freshName("input")
    val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
      nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
    }
    val hashResultType = CodeGenerator.javaType(dataType)
    val code = ctx.splitExpressions(
      expressions = fieldsHash,
      funcName = "computeHashForStruct",
      arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
      returnType = hashResultType,
      makeSplitFunction = body =>
        s"""
           |$body
           |return $result;
           |""".stripMargin,
      foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
    s"""
       |final InternalRow $tmpInput = $input;
       |$code
       |""".stripMargin
  }

  /**
   * Dispatches on `dataType` to the matching `genHash*` generator. User-defined types are
   * unwrapped to their `sqlType` by a tail call, hence the separate `@tailrec` method.
   */
  @tailrec
  private def computeHashWithTailRec(
      input: String,
      dataType: DataType,
      result: String,
      ctx: CodegenContext): String = dataType match {
    case NullType => ""
    case BooleanType => genHashBoolean(input, result)
    case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
    case LongType => genHashLong(input, result)
    case TimestampType => genHashTimestamp(input, result)
    case FloatType => genHashFloat(input, result)
    case DoubleType => genHashDouble(input, result)
    case d: DecimalType => genHashDecimal(ctx, d, input, result)
    case CalendarIntervalType => genHashCalendarInterval(input, result)
    case DayTimeIntervalType => genHashLong(input, result)
    case YearMonthIntervalType => genHashInt(input, result)
    case BinaryType => genHashBytes(input, result)
    case StringType => genHashString(input, result)
    case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull)
    case MapType(kt, vt, valueContainsNull) =>
      genHashForMap(ctx, input, result, kt, vt, valueContainsNull)
    case StructType(fields) => genHashForStruct(ctx, input, result, fields)
    case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx)
  }

  /** Generates code that folds the hash of `input` (of type `dataType`) into `result`. */
  protected def computeHash(
      input: String,
      dataType: DataType,
      result: String,
      ctx: CodegenContext): String = computeHashWithTailRec(input, dataType, result, ctx)

  /** Fully-qualified name of the Java class whose static hash methods the generated code calls. */
  protected def hasherClassName: String
}
/**
 * Base class for interpreted hash functions. Subclasses supply the three primitive hashers
 * (`hashInt`, `hashLong`, `hashUnsafeBytes`); `hash` composes them over Catalyst values.
 */
abstract class InterpretedHashFunction {
  protected def hashInt(i: Int, seed: Long): Long

  protected def hashLong(l: Long, seed: Long): Long

  protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long

  /**
   * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity
   * of input `value`.
   */
  def hash(value: Any, dataType: DataType, seed: Long): Long = value match {
    case null => seed
    case flag: Boolean => hashInt(if (flag) 1 else 0, seed)
    case byteValue: Byte => hashInt(byteValue, seed)
    case shortValue: Short => hashInt(shortValue, seed)
    case intValue: Int => hashInt(intValue, seed)
    case longValue: Long => hashLong(longValue, seed)
    case floatValue: Float => hashInt(java.lang.Float.floatToIntBits(floatValue), seed)
    case doubleValue: Double => hashLong(java.lang.Double.doubleToLongBits(doubleValue), seed)
    case decimal: Decimal => hashDecimalValue(decimal, dataType, seed)
    case interval: CalendarInterval =>
      // microseconds first, then days, then months, each seeded by the previous result
      val afterMicros = hashLong(interval.microseconds, seed)
      val afterDays = hashInt(interval.days, afterMicros)
      hashInt(interval.months, afterDays)
    case bytes: Array[Byte] =>
      hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
    case str: UTF8String =>
      hashUnsafeBytes(str.getBaseObject, str.getBaseOffset, str.numBytes(), seed)
    case array: ArrayData => hashArrayElements(array, dataType, seed)
    case map: MapData => hashMapEntries(map, dataType, seed)
    case row: InternalRow => hashRowFields(row, dataType, seed)
  }

  /** Small decimals hash as their unscaled long; wider ones as their unscaled bytes. */
  private def hashDecimalValue(decimal: Decimal, dataType: DataType, seed: Long): Long = {
    val precision = dataType.asInstanceOf[DecimalType].precision
    if (precision > Decimal.MAX_LONG_DIGITS) {
      val unscaledBytes = decimal.toJavaBigDecimal.unscaledValue().toByteArray
      hashUnsafeBytes(unscaledBytes, Platform.BYTE_ARRAY_OFFSET, unscaledBytes.length, seed)
    } else {
      hashLong(decimal.toUnscaledLong, seed)
    }
  }

  /** Threads the running hash through every element of `array`, in order. */
  private def hashArrayElements(array: ArrayData, dataType: DataType, seed: Long): Long = {
    val elementType = dataType match {
      case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
      case ArrayType(et, _) => et
    }
    var running = seed
    var idx = 0
    while (idx < array.numElements()) {
      running = hash(array.get(idx, elementType), elementType, running)
      idx += 1
    }
    running
  }

  /** Threads the running hash through each entry's key and then its value, in entry order. */
  private def hashMapEntries(map: MapData, dataType: DataType, seed: Long): Long = {
    val (keyType, valueType) = dataType match {
      case udt: UserDefinedType[_] =>
        val mapType = udt.sqlType.asInstanceOf[MapType]
        mapType.keyType -> mapType.valueType
      case MapType(kt, vt, _) => kt -> vt
    }
    val keyArray = map.keyArray()
    val valueArray = map.valueArray()
    var running = seed
    var idx = 0
    while (idx < map.numElements()) {
      running = hash(keyArray.get(idx, keyType), keyType, running)
      running = hash(valueArray.get(idx, valueType), valueType, running)
      idx += 1
    }
    running
  }

  /** Threads the running hash through every field of `row`, in field order. */
  private def hashRowFields(row: InternalRow, dataType: DataType, seed: Long): Long = {
    val fieldTypes: Array[DataType] = dataType match {
      case udt: UserDefinedType[_] =>
        udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
      case StructType(fields) => fields.map(_.dataType)
    }
    val fieldCount = row.numFields
    var running = seed
    var idx = 0
    while (idx < fieldCount) {
      running = hash(row.get(idx, fieldTypes(idx)), fieldTypes(idx), running)
      idx += 1
    }
    running
  }
}
/**
 * A MurMur3 Hash expression.
 *
 * This is the hash used for both shuffle and bucketing, so that the two are guaranteed to
 * produce the same data distribution.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.",
  examples = """
Examples:
> SELECT _FUNC_('Spark', array(123), 2);
-1321691492
""",
  since = "2.0.0",
  group = "hash_funcs")
case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] {
  /** Builds the expression with the default seed of 42. */
  def this(arguments: Seq[Expression]) = this(arguments, 42)

  override def dataType: DataType = IntegerType

  override def prettyName: String = "hash"

  override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName

  /** Delegates to the interpreted Murmur3 function and narrows its long result to an int. */
  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
    val wideHash: Long = Murmur3HashFunction.hash(value, dataType, seed)
    wideHash.toInt
  }

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash =
    copy(children = newChildren)
}
/**
 * Interpreted Murmur3 (x86, 32-bit) hash function. The seed is carried as a long by the base
 * class and narrowed to an int before each call into [[Murmur3_x86_32]].
 */
object Murmur3HashFunction extends InterpretedHashFunction {
  override protected def hashInt(i: Int, seed: Long): Long =
    Murmur3_x86_32.hashInt(i, seed.toInt)

  override protected def hashLong(l: Long, seed: Long): Long =
    Murmur3_x86_32.hashLong(l, seed.toInt)

  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long =
    Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt)
}
/**
 * A 64-bit hash expression backed by xxHash64.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2, ...) - Returns a 64-bit hash value of the arguments.",
  examples = """
Examples:
> SELECT _FUNC_('Spark', array(123), 2);
5602566077635097486
""",
  since = "3.0.0",
  group = "hash_funcs")
case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] {
  /** Builds the expression with the default seed of 42. */
  def this(arguments: Seq[Expression]) = this(arguments, 42L)

  override def dataType: DataType = LongType

  override def prettyName: String = "xxhash64"

  override protected def hasherClassName: String = classOf[XXH64].getName

  /** Delegates directly to the interpreted xxHash64 function. */
  override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long =
    XxHash64Function.hash(value, dataType, seed)

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 =
    copy(children = newChildren)
}
/**
 * Interpreted xxHash64 hash function; each primitive hasher forwards to [[XXH64]] with the
 * long seed unchanged.
 */
object XxHash64Function extends InterpretedHashFunction {
  override protected def hashInt(i: Int, seed: Long): Long = {
    XXH64.hashInt(i, seed)
  }

  override protected def hashLong(l: Long, seed: Long): Long = {
    XXH64.hashLong(l, seed)
  }

  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long =
    XXH64.hashUnsafeBytes(base, offset, len, seed)
}
/**
* Simulates Hive's hashing function from Hive v1.2.1 at
* org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
*
* We should use this hash function for both shuffle and bucket of Hive tables, so that
* we can guarantee shuffle and bucketing have same data distribution
*
* Unlike the seeded generators inherited from HashExpression, the generators overridden here
* call the hasher without a seed: each value is hashed on its own and the results are combined
* as `31 * running + valueHash` (children, array elements, struct fields) or accumulated as
* `keyHash ^ valueHash` (map entries).
*/
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.",
since = "2.2.0",
group = "hash_funcs")
case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
// The running hash always starts from 0.
override val seed = 0
override def dataType: DataType = IntegerType
override def prettyName: String = "hive-hash"
override protected def hasherClassName: String = classOf[HiveHasher].getName
// Interpreted hash of one value. The `seed` parameter is ignored: `this.seed` (0) is always
// what gets passed to HiveHashFunction.
override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
HiveHashFunction.hash(value, dataType, this.seed).toInt
}
// Codegen evaluation. For every child the generated code resets `childHash` to 0, computes the
// child's hash into it (skipped when the child is null, leaving 0), and then folds it into the
// result as `31 * result + childHash`.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = FalseLiteral
val childHash = ctx.freshName("childHash")
val childrenHash = children.map { child =>
val childGen = child.genCode(ctx)
val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, childHash, ctx)
}
s"""
|${childGen.code}
|$childHash = 0;
|$codeToComputeHash
|${ev.value} = (31 * ${ev.value}) + $childHash;
""".stripMargin
}
// Each split function declares its own `childHash` local, since the one declared in the main
// body below is not passed in as an argument.
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value),
returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
|${CodeGenerator.JAVA_INT} $childHash = 0;
|$body
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
code"""
|${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
|${CodeGenerator.JAVA_INT} $childHash = 0;
|$codes
""".stripMargin)
}
// Interpreted evaluation: folds each child's hash into the running value as
// `31 * hash + childHash`, starting from `seed`.
override def eval(input: InternalRow = null): Int = {
var hash = seed
var i = 0
val len = children.length
while (i < len) {
hash = (31 * hash) + computeHash(children(i).eval(input), children(i).dataType, hash)
i += 1
}
hash
}
// The primitive generators below assign the hasher's result directly to `result`; no seed is
// passed to the hasher.
override protected def genHashInt(i: String, result: String): String =
s"$result = $hasherClassName.hashInt($i);"
override protected def genHashLong(l: String, result: String): String =
s"$result = $hasherClassName.hashLong($l);"
override protected def genHashBytes(b: String, result: String): String =
s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"
// Decimals are normalized through HiveHashFunction.normalizeDecimal and then hashed with
// `hashCode()`, regardless of precision.
override protected def genHashDecimal(
ctx: CodegenContext,
d: DecimalType,
input: String,
result: String): String = {
s"""
$result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
$input.toJavaBigDecimal()).hashCode();"""
}
// Calendar intervals delegate to HiveHashFunction.hashCalendarInterval; the long result is
// narrowed to int in the generated code.
override protected def genHashCalendarInterval(input: String, result: String): String = {
s"""
$result = (int)
${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashCalendarInterval($input);
"""
}
// Timestamps delegate to HiveHashFunction.hashTimestamp; the long result is narrowed to int in
// the generated code.
override protected def genHashTimestamp(input: String, result: String): String =
s"""
$result = (int) ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestamp($input);
"""
// Strings are hashed over their backing bytes (base object, offset, byte count), without a seed.
override protected def genHashString(input: String, result: String): String = {
val baseObject = s"$input.getBaseObject()"
val baseOffset = s"$input.getBaseOffset()"
val numBytes = s"$input.numBytes()"
s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);"
}
// Arrays: each element is hashed into `childResult` (reset to 0 per element, so a null element
// contributes 0) and folded into `result` as `31 * result + childResult`.
override protected def genHashForArray(
ctx: CodegenContext,
input: String,
result: String,
elementType: DataType,
containsNull: Boolean): String = {
val index = ctx.freshName("index")
val childResult = ctx.freshName("childResult")
s"""
int $childResult = 0;
for (int $index = 0; $index < $input.numElements(); $index++) {
$childResult = 0;
${nullSafeElementHash(input, index, containsNull, elementType, childResult, ctx)};
$result = (31 * $result) + $childResult;
}
"""
}
// Maps: for every entry the key and value are hashed separately (each starting from 0) and
// `keyResult ^ valueResult` is added to `result`.
override protected def genHashForMap(
ctx: CodegenContext,
input: String,
result: String,
keyType: DataType,
valueType: DataType,
valueContainsNull: Boolean): String = {
val index = ctx.freshName("index")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val keyResult = ctx.freshName("keyResult")
val valueResult = ctx.freshName("valueResult")
s"""
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
int $keyResult = 0;
int $valueResult = 0;
for (int $index = 0; $index < $input.numElements(); $index++) {
$keyResult = 0;
${nullSafeElementHash(keys, index, false, keyType, keyResult, ctx)}
$valueResult = 0;
${nullSafeElementHash(values, index, valueContainsNull, valueType, valueResult, ctx)}
$result += $keyResult ^ $valueResult;
}
"""
}
// Structs: each field is hashed into `childResult` (reset to 0 per field) and folded into
// `result` as `31 * result + childResult`. The per-field snippets go through splitExpressions,
// and every split function declares its own `childResult` local.
override protected def genHashForStruct(
ctx: CodegenContext,
input: String,
result: String,
fields: Array[StructField]): String = {
val tmpInput = ctx.freshName("input")
val childResult = ctx.freshName("childResult")
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
val computeFieldHash = nullSafeElementHash(
tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
s"""
|$childResult = 0;
|$computeFieldHash
|$result = (31 * $result) + $childResult;
""".stripMargin
}
val code = ctx.splitExpressions(
expressions = fieldsHash,
funcName = "computeHashForStruct",
arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result),
returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
|${CodeGenerator.JAVA_INT} $childResult = 0;
|$body
|return $result;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
s"""
|final InternalRow $tmpInput = $input;
|${CodeGenerator.JAVA_INT} $childResult = 0;
|$code
""".stripMargin
}
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): HiveHash =
copy(children = newChildren)
}
/**
 * Interpreted counterpart of the Hive-compatible hash. The primitive hashers ignore the seed
 * and forward to [[HiveHasher]]; containers, decimals, timestamps and calendar intervals get
 * Hive-specific treatment in `hash`.
 */
object HiveHashFunction extends InterpretedHashFunction {
  override protected def hashInt(i: Int, seed: Long): Long = HiveHasher.hashInt(i)

  override protected def hashLong(l: Long, seed: Long): Long = HiveHasher.hashLong(l)

  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long =
    HiveHasher.hashUnsafeBytes(base, offset, len)

  private val HIVE_DECIMAL_MAX_PRECISION = 38
  private val HIVE_DECIMAL_MAX_SCALE = 38

  /**
   * Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize().
   *
   * @return the normalized decimal; null when `input` is null or when it has more integer
   *         digits than the maximum Hive precision
   */
  def normalizeDecimal(input: BigDecimal): BigDecimal = {
    // Strips trailing zeros, never leaving a negative scale behind.
    def trimDecimal(value: BigDecimal): BigDecimal = {
      if (value.compareTo(BigDecimal.ZERO) == 0) {
        // Special case for 0, because java doesn't strip zeros correctly on that number.
        BigDecimal.ZERO
      } else {
        val stripped = value.stripTrailingZeros
        // no negative scale decimals
        if (stripped.scale < 0) stripped.setScale(0) else stripped
      }
    }

    if (input == null) {
      null
    } else {
      val trimmed = trimDecimal(input)
      val intDigits = trimmed.precision - trimmed.scale
      if (intDigits > HIVE_DECIMAL_MAX_PRECISION) {
        null
      } else {
        val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE,
          Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, trimmed.scale))
        if (trimmed.scale > maxScale) {
          // Trimming is again necessary, because rounding may introduce new trailing 0's.
          trimDecimal(trimmed.setScale(maxScale, RoundingMode.HALF_UP))
        } else {
          trimmed
        }
      }
    }
  }

  /**
   * Mimics TimestampWritable.hashCode() in Hive
   */
  def hashTimestamp(timestamp: Long): Long = {
    val seconds = MICROSECONDS.toSeconds(timestamp)
    val nanosOfSecond = (timestamp % MICROS_PER_SECOND) * NANOS_PER_MICROS
    // the nanosecond part fits in 30 bits
    val packed = (seconds << 30) | nanosOfSecond
    ((packed >>> 32) ^ packed).toInt
  }

  /**
   * Hive allows input intervals to be defined using units below but the intervals
   * have to be from the same category:
   * - year, month (stored as HiveIntervalYearMonth)
   * - day, hour, minute, second, nanosecond (stored as HiveIntervalDayTime)
   *
   * e.g. (INTERVAL '30' YEAR + INTERVAL '-23' DAY) fails in Hive
   *
   * This method mimics HiveIntervalDayTime.hashCode() in Hive.
   *
   * Two differences wrt Hive due to how intervals are stored in Spark vs Hive:
   *
   * - If the `INTERVAL` is backed as HiveIntervalYearMonth in Hive, then this method will not
   *   produce Hive compatible result. The reason being Spark's representation of calendar does not
   *   have such categories based on the interval and is unified.
   *
   * - Spark's [[CalendarInterval]] has precision upto microseconds but Hive's
   *   HiveIntervalDayTime can store data with precision upto nanoseconds. So, any input intervals
   *   with nanosecond values will lead to wrong output hashes (i.e. non adherent with Hive output)
   */
  def hashCalendarInterval(calendarInterval: CalendarInterval): Long = {
    val microsPerSecond = MICROS_PER_SECOND.toInt
    val micros = calendarInterval.days * MICROS_PER_DAY + calendarInterval.microseconds
    val wholeSeconds = micros / microsPerSecond
    val secondsHash: Int = (17 * 37) + (wholeSeconds ^ wholeSeconds >> 32).toInt
    val nanos = (micros - (wholeSeconds * microsPerSecond)).toInt * 1000
    (secondsHash * 37) + nanos
  }

  /**
   * Hive-style hash of `value`. Nulls hash to 0; arrays, maps and structs combine the hashes
   * of their parts (each computed from 0); decimals, timestamps and calendar intervals use the
   * dedicated helpers above; everything else falls through to the base implementation.
   */
  override def hash(value: Any, dataType: DataType, seed: Long): Long = value match {
    case null => 0
    case array: ArrayData => hiveHashArray(array, dataType)
    case map: MapData => hiveHashMap(map, dataType)
    case row: InternalRow => hiveHashStruct(row, dataType)
    case decimal: Decimal => normalizeDecimal(decimal.toJavaBigDecimal).hashCode()
    case micros: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(micros)
    case interval: CalendarInterval => hashCalendarInterval(interval)
    case _ => super.hash(value, dataType, 0)
  }

  /** Combines element hashes as `31 * acc + elementHash`, in element order. */
  private def hiveHashArray(array: ArrayData, dataType: DataType): Int = {
    val elementType = dataType match {
      case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
      case ArrayType(et, _) => et
    }
    var acc = 0
    var idx = 0
    val count = array.numElements()
    while (idx < count) {
      acc = (31 * acc) + hash(array.get(idx, elementType), elementType, 0).toInt
      idx += 1
    }
    acc
  }

  /** Sums `keyHash ^ valueHash` over all entries. */
  private def hiveHashMap(map: MapData, dataType: DataType): Int = {
    val (keyType, valueType) = dataType match {
      case udt: UserDefinedType[_] =>
        val mapType = udt.sqlType.asInstanceOf[MapType]
        mapType.keyType -> mapType.valueType
      case MapType(_kt, _vt, _) => _kt -> _vt
    }
    val keyArray = map.keyArray()
    val valueArray = map.valueArray()
    var acc = 0
    var idx = 0
    val count = map.numElements()
    while (idx < count) {
      val keyHash = hash(keyArray.get(idx, keyType), keyType, 0).toInt
      val valueHash = hash(valueArray.get(idx, valueType), valueType, 0).toInt
      acc += keyHash ^ valueHash
      idx += 1
    }
    acc
  }

  /** Combines field hashes as `31 * acc + fieldHash`, in field order. */
  private def hiveHashStruct(row: InternalRow, dataType: DataType): Int = {
    val fieldTypes: Array[DataType] = dataType match {
      case udt: UserDefinedType[_] =>
        udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
      case StructType(fields) => fields.map(_.dataType)
    }
    var acc = 0
    var idx = 0
    val count = row.numFields
    while (idx < count) {
      acc = (31 * acc) + hash(row.get(idx, fieldTypes(idx)), fieldTypes(idx), 0).toInt
      idx += 1
    }
    acc
  }
}