| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.sql.catalyst.expressions |
| |
| import java.math.{BigDecimal, RoundingMode} |
| import java.security.{MessageDigest, NoSuchAlgorithmException} |
| import java.util.concurrent.TimeUnit._ |
| import java.util.zip.CRC32 |
| |
| import scala.annotation.tailrec |
| |
| import org.apache.commons.codec.digest.DigestUtils |
| |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult |
| import org.apache.spark.sql.catalyst.expressions.codegen._ |
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ |
| import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} |
| import org.apache.spark.sql.catalyst.util.DateTimeConstants._ |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.types._ |
| import org.apache.spark.unsafe.Platform |
| import org.apache.spark.unsafe.hash.Murmur3_x86_32 |
| import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} |
| |
| //////////////////////////////////////////////////////////////////////////////////////////////////// |
| // This file defines all the expressions for hashing. |
| //////////////////////////////////////////////////////////////////////////////////////////////////// |
| |
| /** |
| * A function that calculates an MD5 128-bit checksum and returns it as a hex string |
| * For input of type [[BinaryType]] |
| */ |
| @ExpressionDescription( |
| usage = "_FUNC_(expr) - Returns an MD5 128-bit checksum as a hex string of `expr`.", |
| examples = """ |
| Examples: |
| > SELECT _FUNC_('Spark'); |
| 8cde774d6f7333752ed72cacddb05126 |
| """, |
| since = "1.5.0", |
| group = "hash_funcs") |
case class Md5(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  // The 128-bit digest is rendered as a 32-character lowercase hex string.
  override def dataType: DataType = StringType

  // The single argument is implicitly cast to binary before hashing.
  override def inputTypes: Seq[DataType] = Seq(BinaryType)

  protected override def nullSafeEval(input: Any): Any = {
    val bytes = input.asInstanceOf[Array[Byte]]
    UTF8String.fromString(DigestUtils.md5Hex(bytes))
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Generated code mirrors nullSafeEval: DigestUtils.md5Hex over the child bytes.
    val digestUtils = classOf[DigestUtils].getName
    defineCodeGen(ctx, ev, c => s"UTF8String.fromString($digestUtils.md5Hex($c))")
  }

  override protected def withNewChildInternal(newChild: Expression): Md5 =
    copy(child = newChild)
}
| |
| /** |
| * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512) |
| * and returns it as a hex string. The first argument is the string or binary to be hashed. The |
| * second argument indicates the desired bit length of the result, which must have a value of 224, |
| * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If |
| * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or |
| * the hash length is not one of the permitted values, the return value is NULL. |
| */ |
| // scalastyle:off line.size.limit |
| @ExpressionDescription( |
| usage = """ |
| _FUNC_(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of `expr`. |
| SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256. |
| """, |
| examples = """ |
| Examples: |
| > SELECT _FUNC_('Spark', 256); |
| 529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b |
| """, |
| since = "1.5.0", |
| group = "hash_funcs") |
| // scalastyle:on line.size.limit |
case class Sha2(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {

  override def dataType: DataType = StringType

  // Nullable even for non-null inputs: an unsupported bit length yields null.
  override def nullable: Boolean = true

  override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)

  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    val bitLength = input2.asInstanceOf[Int]
    val input = input1.asInstanceOf[Array[Byte]]
    bitLength match {
      // Previously this branch built the digest via MessageDigest and returned the
      // RAW digest bytes (UTF8String.fromBytes(md.digest())) instead of a hex
      // string, which was inconsistent with the other bit lengths and with the
      // documented contract. DigestUtils supports SHA-224 directly now.
      case 224 =>
        UTF8String.fromString(DigestUtils.sha224Hex(input))
      case 256 | 0 =>
        UTF8String.fromString(DigestUtils.sha256Hex(input))
      case 384 =>
        UTF8String.fromString(DigestUtils.sha384Hex(input))
      case 512 =>
        UTF8String.fromString(DigestUtils.sha512Hex(input))
      // Any other bit length is unsupported and yields null.
      case _ => null
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val digestUtils = classOf[DigestUtils].getName
    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
      // Keep this in sync with nullSafeEval above: same hex output for 224 as for
      // the other supported lengths, null for anything else.
      s"""
        if ($eval2 == 224) {
          ${ev.value} =
            UTF8String.fromString($digestUtils.sha224Hex($eval1));
        } else if ($eval2 == 256 || $eval2 == 0) {
          ${ev.value} =
            UTF8String.fromString($digestUtils.sha256Hex($eval1));
        } else if ($eval2 == 384) {
          ${ev.value} =
            UTF8String.fromString($digestUtils.sha384Hex($eval1));
        } else if ($eval2 == 512) {
          ${ev.value} =
            UTF8String.fromString($digestUtils.sha512Hex($eval1));
        } else {
          ${ev.isNull} = true;
        }
      """
    })
  }

  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Sha2 =
    copy(left = newLeft, right = newRight)
}
| |
| /** |
| * A function that calculates a sha1 hash value and returns it as a hex string |
| * For input of type [[BinaryType]] or [[StringType]] |
| */ |
| @ExpressionDescription( |
| usage = "_FUNC_(expr) - Returns a sha1 hash value as a hex string of the `expr`.", |
| examples = """ |
| Examples: |
| > SELECT _FUNC_('Spark'); |
| 85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c |
| """, |
| since = "1.5.0", |
| group = "hash_funcs") |
case class Sha1(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  // The 160-bit digest is rendered as a 40-character lowercase hex string.
  override def dataType: DataType = StringType

  // The single argument is implicitly cast to binary before hashing.
  override def inputTypes: Seq[DataType] = Seq(BinaryType)

  protected override def nullSafeEval(input: Any): Any = {
    val bytes = input.asInstanceOf[Array[Byte]]
    UTF8String.fromString(DigestUtils.sha1Hex(bytes))
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Generated code mirrors nullSafeEval: DigestUtils.sha1Hex over the child bytes.
    val digestUtils = classOf[DigestUtils].getName
    defineCodeGen(ctx, ev, c => s"UTF8String.fromString($digestUtils.sha1Hex($c))")
  }

  override protected def withNewChildInternal(newChild: Expression): Sha1 =
    copy(child = newChild)
}
| |
| /** |
| * A function that computes a cyclic redundancy check value and returns it as a bigint |
| * For input of type [[BinaryType]] |
| */ |
| @ExpressionDescription( |
| usage = "_FUNC_(expr) - Returns a cyclic redundancy check value of the `expr` as a bigint.", |
| examples = """ |
| Examples: |
| > SELECT _FUNC_('Spark'); |
| 1557323817 |
| """, |
| since = "1.5.0", |
| group = "hash_funcs") |
case class Crc32(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  // CRC32 values are unsigned 32-bit, so they are surfaced as a bigint.
  override def dataType: DataType = LongType

  // The single argument is implicitly cast to binary before the checksum.
  override def inputTypes: Seq[DataType] = Seq(BinaryType)

  protected override def nullSafeEval(input: Any): Any = {
    val bytes = input.asInstanceOf[Array[Byte]]
    val checksum = new CRC32
    checksum.update(bytes, 0, bytes.length)
    checksum.getValue
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Generated code allocates a fresh java.util.zip.CRC32 per row, matching
    // nullSafeEval above.
    val CRC32 = "java.util.zip.CRC32"
    val checksum = ctx.freshName("checksum")
    nullSafeCodeGen(ctx, ev, value => {
      s"""
        $CRC32 $checksum = new $CRC32();
        $checksum.update($value, 0, $value.length);
        ${ev.value} = $checksum.getValue();
      """
    })
  }

  override protected def withNewChildInternal(newChild: Expression): Crc32 =
    copy(child = newChild)
}
| |
| |
| /** |
| * A function that calculates hash value for a group of expressions. Note that the `seed` argument |
| * is not exposed to users and should only be set inside spark SQL. |
| * |
| * The hash value for an expression depends on its type and seed: |
| * - null: seed |
| * - boolean: turn boolean into int, 1 for true, 0 for false, |
| * and then use murmur3 to hash this int with seed. |
| * - byte, short, int: use murmur3 to hash the input as int with seed. |
| * - long: use murmur3 to hash the long input with seed. |
| * - float: turn it into int: java.lang.Float.floatToIntBits(input), and hash it. |
| * - double: turn it into long: java.lang.Double.doubleToLongBits(input), |
| * and hash it. |
| * - decimal: if it's a small decimal, i.e. precision <= 18, turn it into long |
| * and hash it. Else, turn it into bytes and hash it. |
| * - calendar interval: hash `microseconds` first, and use the result as seed |
| * to hash `months`. |
| * - interval day to second: it store long value of `microseconds`, use murmur3 to hash the long |
| * input with seed. |
| * - interval year to month: it store int value of `months`, use murmur3 to hash the int |
| * input with seed. |
| * - binary: use murmur3 to hash the bytes with seed. |
| * - string: get the bytes of string and hash it. |
| * - array: The `result` starts with seed, then use `result` as seed, recursively |
| * calculate hash value for each element, and assign the element hash |
| * value to `result`. |
| * - struct: The `result` starts with seed, then use `result` as seed, recursively |
| * calculate hash value for each field, and assign the field hash value |
| * to `result`. |
| * |
| * Finally we aggregate the hash values for each expression by the same way of struct. |
| */ |
abstract class HashExpression[E] extends Expression {
  /** Seed of the HashExpression. */
  val seed: E

  override def foldable: Boolean = children.forall(_.foldable)

  // A hash expression never returns null: null inputs simply leave the running
  // hash at its current value (the seed).
  override def nullable: Boolean = false

  // True if `dt` is a MapType or contains one anywhere in its structure.
  private def hasMapType(dt: DataType): Boolean = {
    dt.existsRecursively(_.isInstanceOf[MapType])
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.length < 1) {
      TypeCheckResult.TypeCheckFailure(
        s"input to function $prettyName requires at least one argument")
    } else if (children.exists(child => hasMapType(child.dataType)) &&
      !SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE)) {
      // Maps have no canonical element order, so equal maps may hash differently;
      // reject them unless the legacy flag explicitly allows it.
      TypeCheckResult.TypeCheckFailure(
        s"input to function $prettyName cannot contain elements of MapType. In Spark, same maps " +
        "may have different hashcode, thus hash expressions are prohibited on MapType elements." +
        s" To restore previous behavior set ${SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE.key} " +
        "to true.")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  // Interpreted evaluation: fold each child's hash into the running value,
  // starting from the seed.
  override def eval(input: InternalRow = null): Any = {
    var hash = seed
    var i = 0
    val len = children.length
    while (i < len) {
      hash = computeHash(children(i).eval(input), children(i).dataType, hash)
      i += 1
    }
    hash
  }

  /** Interpreted per-value hash; implemented by concrete hash expressions. */
  protected def computeHash(value: Any, dataType: DataType, seed: E): E

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.isNull = FalseLiteral

    // One code snippet per child; null children are skipped (hash unchanged).
    val childrenHash = children.map { child =>
      val childGen = child.genCode(ctx)
      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
        computeHash(childGen.value, child.dataType, ev.value, ctx)
      }
    }

    val hashResultType = CodeGenerator.javaType(dataType)
    // Long-typed seeds need the "L" suffix to stay a Java long literal.
    val typedSeed = if (dataType.sameType(LongType)) s"${seed}L" else s"$seed"
    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = childrenHash,
      funcName = "computeHash",
      extraArguments = Seq(hashResultType -> ev.value),
      returnType = hashResultType,
      makeSplitFunction = body =>
        s"""
           |$body
           |return ${ev.value};
         """.stripMargin,
      foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))

    ev.copy(code =
      code"""
            |$hashResultType ${ev.value} = $typedSeed;
            |$codes
          """.stripMargin)
  }

  // Hashes one element of a container (array slot, map entry, struct field),
  // skipping it when nullable and null.
  protected def nullSafeElementHash(
      input: String,
      index: String,
      nullable: Boolean,
      elementType: DataType,
      result: String,
      ctx: CodegenContext): String = {
    val element = ctx.freshName("element")

    val jt = CodeGenerator.javaType(elementType)
    ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
      s"""
        final $jt $element = ${CodeGenerator.getValue(input, elementType, index)};
        ${computeHash(element, elementType, result, ctx)}
      """
    }
  }

  protected def genHashInt(i: String, result: String): String =
    s"$result = $hasherClassName.hashInt($i, $result);"

  protected def genHashLong(l: String, result: String): String =
    s"$result = $hasherClassName.hashLong($l, $result);"

  protected def genHashBytes(b: String, result: String): String = {
    val offset = "Platform.BYTE_ARRAY_OFFSET"
    s"$result = $hasherClassName.hashUnsafeBytes($b, $offset, $b.length, $result);"
  }

  // Booleans hash as 1/0, matching the interpreted path.
  protected def genHashBoolean(input: String, result: String): String =
    genHashInt(s"$input ? 1 : 0", result)

  // Floats/doubles hash their IEEE-754 bit patterns.
  protected def genHashFloat(input: String, result: String): String =
    genHashInt(s"Float.floatToIntBits($input)", result)

  protected def genHashDouble(input: String, result: String): String =
    genHashLong(s"Double.doubleToLongBits($input)", result)

  // Small decimals (precision <= 18) hash their unscaled long; larger ones hash
  // the unscaled BigInteger's two's-complement bytes.
  protected def genHashDecimal(
      ctx: CodegenContext,
      d: DecimalType,
      input: String,
      result: String): String = {
    if (d.precision <= Decimal.MAX_LONG_DIGITS) {
      genHashLong(s"$input.toUnscaledLong()", result)
    } else {
      val bytes = ctx.freshName("bytes")
      s"""
         |final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
         |${genHashBytes(bytes, result)}
       """.stripMargin
    }
  }

  protected def genHashTimestamp(t: String, result: String): String = genHashLong(t, result)

  protected def genHashCalendarInterval(input: String, result: String): String = {
    // BUG FIX: the previous version hashed only `microseconds` and `months`,
    // while the interpreted path (InterpretedHashFunction.hash) hashes
    // microseconds, then days, then months — so codegen and interpreted
    // evaluation disagreed for CalendarIntervalType. Include `days` here to
    // keep both paths consistent.
    val microsecondsHash = s"$hasherClassName.hashLong($input.microseconds, $result)"
    val daysHash = s"$hasherClassName.hashInt($input.days, $microsecondsHash)"
    s"$result = $hasherClassName.hashInt($input.months, $daysHash);"
  }

  // Strings hash their UTF-8 bytes in place, without copying.
  protected def genHashString(input: String, result: String): String = {
    val baseObject = s"$input.getBaseObject()"
    val baseOffset = s"$input.getBaseOffset()"
    val numBytes = s"$input.numBytes()"
    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
  }

  // Maps hash each key then each value in entry order; keys are never null.
  protected def genHashForMap(
      ctx: CodegenContext,
      input: String,
      result: String,
      keyType: DataType,
      valueType: DataType,
      valueContainsNull: Boolean): String = {
    val index = ctx.freshName("index")
    val keys = ctx.freshName("keys")
    val values = ctx.freshName("values")
    s"""
      final ArrayData $keys = $input.keyArray();
      final ArrayData $values = $input.valueArray();
      for (int $index = 0; $index < $input.numElements(); $index++) {
        ${nullSafeElementHash(keys, index, false, keyType, result, ctx)}
        ${nullSafeElementHash(values, index, valueContainsNull, valueType, result, ctx)}
      }
    """
  }

  // Arrays hash each element in order, threading the result as the seed.
  protected def genHashForArray(
      ctx: CodegenContext,
      input: String,
      result: String,
      elementType: DataType,
      containsNull: Boolean): String = {
    val index = ctx.freshName("index")
    s"""
      for (int $index = 0; $index < $input.numElements(); $index++) {
        ${nullSafeElementHash(input, index, containsNull, elementType, result, ctx)}
      }
    """
  }

  // Structs hash each field in order; wide structs are split into helper
  // functions to stay under the JVM method-size limit.
  protected def genHashForStruct(
      ctx: CodegenContext,
      input: String,
      result: String,
      fields: Array[StructField]): String = {
    val tmpInput = ctx.freshName("input")
    val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
      nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
    }
    val hashResultType = CodeGenerator.javaType(dataType)
    val code = ctx.splitExpressions(
      expressions = fieldsHash,
      funcName = "computeHashForStruct",
      arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
      returnType = hashResultType,
      makeSplitFunction = body =>
        s"""
           |$body
           |return $result;
         """.stripMargin,
      foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
    s"""
       |final InternalRow $tmpInput = $input;
       |$code
     """.stripMargin
  }

  // Dispatches codegen by data type. Tail-recursive so that UDTs unwrap to
  // their underlying SQL type without growing the stack.
  @tailrec
  private def computeHashWithTailRec(
      input: String,
      dataType: DataType,
      result: String,
      ctx: CodegenContext): String = dataType match {
    case NullType => ""
    case BooleanType => genHashBoolean(input, result)
    case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
    case LongType => genHashLong(input, result)
    case TimestampType => genHashTimestamp(input, result)
    case FloatType => genHashFloat(input, result)
    case DoubleType => genHashDouble(input, result)
    case d: DecimalType => genHashDecimal(ctx, d, input, result)
    case CalendarIntervalType => genHashCalendarInterval(input, result)
    case DayTimeIntervalType => genHashLong(input, result)
    case YearMonthIntervalType => genHashInt(input, result)
    case BinaryType => genHashBytes(input, result)
    case StringType => genHashString(input, result)
    case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull)
    case MapType(kt, vt, valueContainsNull) =>
      genHashForMap(ctx, input, result, kt, vt, valueContainsNull)
    case StructType(fields) => genHashForStruct(ctx, input, result, fields)
    case udt: UserDefinedType[_] => computeHashWithTailRec(input, udt.sqlType, result, ctx)
  }

  /** Codegen per-value hash; appends code that folds `input` into `result`. */
  protected def computeHash(
      input: String,
      dataType: DataType,
      result: String,
      ctx: CodegenContext): String = computeHashWithTailRec(input, dataType, result, ctx)

  /** Fully-qualified name of the Java hasher class used in generated code. */
  protected def hasherClassName: String
}
| |
| /** |
| * Base class for interpreted hash functions. |
| */ |
abstract class InterpretedHashFunction {
  /** Hashes a 32-bit value with the given seed. */
  protected def hashInt(i: Int, seed: Long): Long

  /** Hashes a 64-bit value with the given seed. */
  protected def hashLong(l: Long, seed: Long): Long

  /** Hashes `length` raw bytes starting at `base` + `offset` with the given seed. */
  protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long

  /**
   * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity
   * of input `value`.
   */
  def hash(value: Any, dataType: DataType, seed: Long): Long = {
    // NOTE: dispatch is on the runtime class of `value`; `dataType` is consulted
    // only where the runtime class is ambiguous (decimals, arrays, maps, structs).
    // Match order matters: keep Array[Byte] before the generic container cases.
    value match {
      // Null contributes nothing: the hash of null is the seed itself.
      case null => seed
      case b: Boolean => hashInt(if (b) 1 else 0, seed)
      // Byte/short widen to int before hashing.
      case b: Byte => hashInt(b, seed)
      case s: Short => hashInt(s, seed)
      case i: Int => hashInt(i, seed)
      case l: Long => hashLong(l, seed)
      // Floats/doubles hash their IEEE-754 bit patterns, not numeric values.
      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
      case d: Decimal =>
        // Small decimals (precision <= 18) hash the unscaled long; larger ones
        // hash the unscaled BigInteger's two's-complement bytes. Must stay in
        // sync with HashExpression.genHashDecimal.
        val precision = dataType.asInstanceOf[DecimalType].precision
        if (precision <= Decimal.MAX_LONG_DIGITS) {
          hashLong(d.toUnscaledLong, seed)
        } else {
          val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray
          hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
        }
      // Intervals chain microseconds -> days -> months, each using the previous
      // hash as the seed.
      case c: CalendarInterval => hashInt(c.months, hashInt(c.days, hashLong(c.microseconds, seed)))
      case a: Array[Byte] =>
        hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
      // Strings hash their UTF-8 bytes in place.
      case s: UTF8String =>
        hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)

      case array: ArrayData =>
        // A UDT backed by an array reports the UDT as `dataType`; unwrap it.
        val elementType = dataType match {
          case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
          case ArrayType(et, _) => et
        }
        // Fold element hashes left-to-right, threading the result as the seed.
        var result = seed
        var i = 0
        while (i < array.numElements()) {
          result = hash(array.get(i, elementType), elementType, result)
          i += 1
        }
        result

      case map: MapData =>
        val (kt, vt) = dataType match {
          case udt: UserDefinedType[_] =>
            val mapType = udt.sqlType.asInstanceOf[MapType]
            mapType.keyType -> mapType.valueType
          case MapType(kt, vt, _) => kt -> vt
        }
        val keys = map.keyArray()
        val values = map.valueArray()
        // Hash key then value for each entry, in entry order.
        var result = seed
        var i = 0
        while (i < map.numElements()) {
          result = hash(keys.get(i, kt), kt, result)
          result = hash(values.get(i, vt), vt, result)
          i += 1
        }
        result

      case struct: InternalRow =>
        val types: Array[DataType] = dataType match {
          case udt: UserDefinedType[_] =>
            udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
          case StructType(fields) => fields.map(_.dataType)
        }
        // Fold field hashes in declaration order.
        var result = seed
        var i = 0
        val len = struct.numFields
        while (i < len) {
          result = hash(struct.get(i, types(i)), types(i), result)
          i += 1
        }
        result
    }
  }
}
| |
| /** |
| * A MurMur3 Hash expression. |
| * |
| * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle |
| * and bucketing have same data distribution. |
| */ |
| @ExpressionDescription( |
| usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.", |
| examples = """ |
| Examples: |
| > SELECT _FUNC_('Spark', array(123), 2); |
| -1321691492 |
| """, |
| since = "2.0.0", |
| group = "hash_funcs") |
case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] {
  // Seed 42 is the fixed default used by Spark for shuffle and bucketing.
  def this(arguments: Seq[Expression]) = this(arguments, 42)

  override def prettyName: String = "hash"

  override def dataType: DataType = IntegerType

  override protected def hasherClassName: String = classOf[Murmur3_x86_32].getName

  // Interpreted path delegates to the shared Murmur3 implementation; the long
  // result is narrowed to this expression's 32-bit output.
  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int =
    Murmur3HashFunction.hash(value, dataType, seed).toInt

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Murmur3Hash =
    copy(children = newChildren)
}
| |
object Murmur3HashFunction extends InterpretedHashFunction {
  // Murmur3 works with 32-bit seeds, so the long seed is narrowed before use.
  override protected def hashInt(i: Int, seed: Long): Long =
    Murmur3_x86_32.hashInt(i, seed.toInt)

  override protected def hashLong(l: Long, seed: Long): Long =
    Murmur3_x86_32.hashLong(l, seed.toInt)

  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long =
    Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt)
}
| |
| /** |
| * A xxHash64 64-bit hash expression. |
| */ |
| @ExpressionDescription( |
| usage = "_FUNC_(expr1, expr2, ...) - Returns a 64-bit hash value of the arguments.", |
| examples = """ |
| Examples: |
| > SELECT _FUNC_('Spark', array(123), 2); |
| 5602566077635097486 |
| """, |
| since = "3.0.0", |
| group = "hash_funcs") |
case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] {
  // Seed 42 mirrors Murmur3Hash's default.
  def this(arguments: Seq[Expression]) = this(arguments, 42L)

  override def prettyName: String = "xxhash64"

  override def dataType: DataType = LongType

  override protected def hasherClassName: String = classOf[XXH64].getName

  // Interpreted path delegates to the shared xxHash64 implementation.
  override protected def computeHash(value: Any, dataType: DataType, seed: Long): Long =
    XxHash64Function.hash(value, dataType, seed)

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): XxHash64 =
    copy(children = newChildren)
}
| |
object XxHash64Function extends InterpretedHashFunction {
  // xxHash64 natively takes 64-bit seeds, so no narrowing is needed.
  override protected def hashInt(i: Int, seed: Long): Long =
    XXH64.hashInt(i, seed)

  override protected def hashLong(l: Long, seed: Long): Long =
    XXH64.hashLong(l, seed)

  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long =
    XXH64.hashUnsafeBytes(base, offset, len, seed)
}
| |
| /** |
| * Simulates Hive's hashing function from Hive v1.2.1 at |
| * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() |
| * |
| * We should use this hash function for both shuffle and bucket of Hive tables, so that |
| * we can guarantee shuffle and bucketing have same data distribution |
| */ |
| @ExpressionDescription( |
| usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.", |
| since = "2.2.0", |
| group = "hash_funcs") |
case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
  // Hive hashing has no user-visible seed; it always starts from 0.
  override val seed = 0

  override def dataType: DataType = IntegerType

  override def prettyName: String = "hive-hash"

  override protected def hasherClassName: String = classOf[HiveHasher].getName

  override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
    // NOTE: the incoming `seed` is deliberately ignored — Hive hashes every
    // child independently from seed 0 and combines results with
    // `31 * hash + childHash` (see eval/doGenCode below).
    HiveHashFunction.hash(value, dataType, this.seed).toInt
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.isNull = FalseLiteral

    // Per-child: reset childHash to 0, hash the child (null children keep 0),
    // then fold with 31 * result + childHash — mimicking Hive's list hashing.
    val childHash = ctx.freshName("childHash")
    val childrenHash = children.map { child =>
      val childGen = child.genCode(ctx)
      val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) {
        computeHash(childGen.value, child.dataType, childHash, ctx)
      }
      s"""
         |${childGen.code}
         |$childHash = 0;
         |$codeToComputeHash
         |${ev.value} = (31 * ${ev.value}) + $childHash;
       """.stripMargin
    }

    // Split into helper methods when there are many children, to stay under
    // the JVM method-size limit; each split declares its own childHash.
    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = childrenHash,
      funcName = "computeHash",
      extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value),
      returnType = CodeGenerator.JAVA_INT,
      makeSplitFunction = body =>
        s"""
           |${CodeGenerator.JAVA_INT} $childHash = 0;
           |$body
           |return ${ev.value};
         """.stripMargin,
      foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))


    ev.copy(code =
      code"""
            |${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
            |${CodeGenerator.JAVA_INT} $childHash = 0;
            |$codes
          """.stripMargin)
  }

  // Interpreted path. Unlike the base class, the fold is
  // `31 * hash + childHash`, matching the generated code above.
  override def eval(input: InternalRow = null): Int = {
    var hash = seed
    var i = 0
    val len = children.length
    while (i < len) {
      hash = (31 * hash) + computeHash(children(i).eval(input), children(i).dataType, hash)
      i += 1
    }
    hash
  }

  // The Hive hasher takes no seed parameter, so these overrides drop it.
  override protected def genHashInt(i: String, result: String): String =
    s"$result = $hasherClassName.hashInt($i);"

  override protected def genHashLong(l: String, result: String): String =
    s"$result = $hasherClassName.hashLong($l);"

  override protected def genHashBytes(b: String, result: String): String =
    s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"

  // Decimals are normalized the way Hive does before taking hashCode().
  override protected def genHashDecimal(
      ctx: CodegenContext,
      d: DecimalType,
      input: String,
      result: String): String = {
    s"""
      $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
        $input.toJavaBigDecimal()).hashCode();"""
  }

  override protected def genHashCalendarInterval(input: String, result: String): String = {
    s"""
        $result = (int)
          ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashCalendarInterval($input);
     """
  }

  override protected def genHashTimestamp(input: String, result: String): String =
    s"""
        $result = (int) ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestamp($input);
     """

  override protected def genHashString(input: String, result: String): String = {
    val baseObject = s"$input.getBaseObject()"
    val baseOffset = s"$input.getBaseOffset()"
    val numBytes = s"$input.numBytes()"
    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);"
  }

  // Arrays: each element hashed from 0, folded with 31 * result + childResult.
  override protected def genHashForArray(
      ctx: CodegenContext,
      input: String,
      result: String,
      elementType: DataType,
      containsNull: Boolean): String = {
    val index = ctx.freshName("index")
    val childResult = ctx.freshName("childResult")
    s"""
      int $childResult = 0;
      for (int $index = 0; $index < $input.numElements(); $index++) {
        $childResult = 0;
        ${nullSafeElementHash(input, index, containsNull, elementType, childResult, ctx)};
        $result = (31 * $result) + $childResult;
      }
    """
  }

  // Maps: entries are combined order-insensitively via `+=` of keyHash ^ valueHash,
  // mimicking java.util.Map.hashCode() as Hive does.
  override protected def genHashForMap(
      ctx: CodegenContext,
      input: String,
      result: String,
      keyType: DataType,
      valueType: DataType,
      valueContainsNull: Boolean): String = {
    val index = ctx.freshName("index")
    val keys = ctx.freshName("keys")
    val values = ctx.freshName("values")
    val keyResult = ctx.freshName("keyResult")
    val valueResult = ctx.freshName("valueResult")
    s"""
      final ArrayData $keys = $input.keyArray();
      final ArrayData $values = $input.valueArray();
      int $keyResult = 0;
      int $valueResult = 0;
      for (int $index = 0; $index < $input.numElements(); $index++) {
        $keyResult = 0;
        ${nullSafeElementHash(keys, index, false, keyType, keyResult, ctx)}
        $valueResult = 0;
        ${nullSafeElementHash(values, index, valueContainsNull, valueType, valueResult, ctx)}
        $result += $keyResult ^ $valueResult;
      }
    """
  }

  // Structs: like arrays — each field hashed from 0, folded with 31 * result.
  override protected def genHashForStruct(
      ctx: CodegenContext,
      input: String,
      result: String,
      fields: Array[StructField]): String = {
    val tmpInput = ctx.freshName("input")
    val childResult = ctx.freshName("childResult")
    val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
      val computeFieldHash = nullSafeElementHash(
        tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
      s"""
         |$childResult = 0;
         |$computeFieldHash
         |$result = (31 * $result) + $childResult;
       """.stripMargin
    }

    // Wide structs are split into helper methods to respect JVM limits.
    val code = ctx.splitExpressions(
      expressions = fieldsHash,
      funcName = "computeHashForStruct",
      arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result),
      returnType = CodeGenerator.JAVA_INT,
      makeSplitFunction = body =>
        s"""
           |${CodeGenerator.JAVA_INT} $childResult = 0;
           |$body
           |return $result;
         """.stripMargin,
      foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
    s"""
       |final InternalRow $tmpInput = $input;
       |${CodeGenerator.JAVA_INT} $childResult = 0;
       |$code
     """.stripMargin
  }

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): HiveHash =
    copy(children = newChildren)
}
| |
| object HiveHashFunction extends InterpretedHashFunction { |
| override protected def hashInt(i: Int, seed: Long): Long = { |
| HiveHasher.hashInt(i) |
| } |
| |
| override protected def hashLong(l: Long, seed: Long): Long = { |
| HiveHasher.hashLong(l) |
| } |
| |
| override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = { |
| HiveHasher.hashUnsafeBytes(base, offset, len) |
| } |
| |
  // Hive's HiveDecimal caps both precision and scale at 38; used by
  // normalizeDecimal below to mimic HiveDecimalV1.normalize().
  private val HIVE_DECIMAL_MAX_PRECISION = 38
  private val HIVE_DECIMAL_MAX_SCALE = 38
| |
| // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize() |
| def normalizeDecimal(input: BigDecimal): BigDecimal = { |
| if (input == null) return null |
| |
| def trimDecimal(input: BigDecimal) = { |
| var result = input |
| if (result.compareTo(BigDecimal.ZERO) == 0) { |
| // Special case for 0, because java doesn't strip zeros correctly on that number. |
| result = BigDecimal.ZERO |
| } else { |
| result = result.stripTrailingZeros |
| if (result.scale < 0) { |
| // no negative scale decimals |
| result = result.setScale(0) |
| } |
| } |
| result |
| } |
| |
| var result = trimDecimal(input) |
| val intDigits = result.precision - result.scale |
| if (intDigits > HIVE_DECIMAL_MAX_PRECISION) { |
| return null |
| } |
| |
| val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE, |
| Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale)) |
| if (result.scale > maxScale) { |
| result = result.setScale(maxScale, RoundingMode.HALF_UP) |
| // Trimming is again necessary, because rounding may introduce new trailing 0's. |
| result = trimDecimal(result) |
| } |
| result |
| } |
| |
| /** |
| * Mimics TimestampWritable.hashCode() in Hive |
| */ |
| def hashTimestamp(timestamp: Long): Long = { |
| val timestampInSeconds = MICROSECONDS.toSeconds(timestamp) |
| val nanoSecondsPortion = (timestamp % MICROS_PER_SECOND) * NANOS_PER_MICROS |
| |
| var result = timestampInSeconds |
| result <<= 30 // the nanosecond part fits in 30 bits |
| result |= nanoSecondsPortion |
| ((result >>> 32) ^ result).toInt |
| } |
| |
| /** |
| * Hive allows input intervals to be defined using units below but the intervals |
| * have to be from the same category: |
| * - year, month (stored as HiveIntervalYearMonth) |
| * - day, hour, minute, second, nanosecond (stored as HiveIntervalDayTime) |
| * |
| * e.g. (INTERVAL '30' YEAR + INTERVAL '-23' DAY) fails in Hive |
| * |
| * This method mimics HiveIntervalDayTime.hashCode() in Hive. |
| * |
| * Two differences wrt Hive due to how intervals are stored in Spark vs Hive: |
| * |
| * - If the `INTERVAL` is backed as HiveIntervalYearMonth in Hive, then this method will not |
| * produce Hive compatible result. The reason being Spark's representation of calendar does not |
| * have such categories based on the interval and is unified. |
| * |
| * - Spark's [[CalendarInterval]] has precision upto microseconds but Hive's |
| * HiveIntervalDayTime can store data with precision upto nanoseconds. So, any input intervals |
| * with nanosecond values will lead to wrong output hashes (i.e. non adherent with Hive output) |
| */ |
| def hashCalendarInterval(calendarInterval: CalendarInterval): Long = { |
| val totalMicroSeconds = calendarInterval.days * MICROS_PER_DAY + calendarInterval.microseconds |
| val totalSeconds = totalMicroSeconds / MICROS_PER_SECOND.toInt |
| val result: Int = (17 * 37) + (totalSeconds ^ totalSeconds >> 32).toInt |
| |
| val nanoSeconds = (totalMicroSeconds - (totalSeconds * MICROS_PER_SECOND.toInt)).toInt * 1000 |
| (result * 37) + nanoSeconds |
| } |
| |
| override def hash(value: Any, dataType: DataType, seed: Long): Long = { |
| value match { |
| case null => 0 |
| case array: ArrayData => |
| val elementType = dataType match { |
| case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType |
| case ArrayType(et, _) => et |
| } |
| |
| var result = 0 |
| var i = 0 |
| val length = array.numElements() |
| while (i < length) { |
| result = (31 * result) + hash(array.get(i, elementType), elementType, 0).toInt |
| i += 1 |
| } |
| result |
| |
| case map: MapData => |
| val (kt, vt) = dataType match { |
| case udt: UserDefinedType[_] => |
| val mapType = udt.sqlType.asInstanceOf[MapType] |
| mapType.keyType -> mapType.valueType |
| case MapType(_kt, _vt, _) => _kt -> _vt |
| } |
| val keys = map.keyArray() |
| val values = map.valueArray() |
| |
| var result = 0 |
| var i = 0 |
| val length = map.numElements() |
| while (i < length) { |
| result += hash(keys.get(i, kt), kt, 0).toInt ^ hash(values.get(i, vt), vt, 0).toInt |
| i += 1 |
| } |
| result |
| |
| case struct: InternalRow => |
| val types: Array[DataType] = dataType match { |
| case udt: UserDefinedType[_] => |
| udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray |
| case StructType(fields) => fields.map(_.dataType) |
| } |
| |
| var result = 0 |
| var i = 0 |
| val length = struct.numFields |
| while (i < length) { |
| result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt |
| i += 1 |
| } |
| result |
| |
| case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode() |
| case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp) |
| case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval) |
| case _ => super.hash(value, dataType, 0) |
| } |
| } |
| } |