blob: de15ec43c4f31db7b1184633e7f73f1081611440 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
import org.apache.spark.{QueryContext, SparkException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines the basic expression abstract classes in Catalyst.
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* An expression in Catalyst.
*
* If an expression wants to be exposed in the function registry (so users can call it with
* "name(arguments...)"), the concrete implementation must be a case class whose constructor
* arguments are all of type Expression. See [[Substring]] for an example.
*
* There are a few important traits or abstract classes:
*
* - [[Nondeterministic]]: an expression that is not deterministic.
* - [[Unevaluable]]: an expression that is not supposed to be evaluated.
* - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to
* interpreted mode.
* - [[NullIntolerant]]: an expression that is null intolerant (i.e. any null input will result in
* null output).
* - [[NonSQLExpression]]: a common base trait for the expressions that do not have SQL
* expressions like representation. For example, `ScalaUDF`, `ScalaUDAF`,
* and object `MapObjects` and `Invoke`.
* - [[UserDefinedExpression]]: a common base trait for user-defined functions, including
* UDF/UDAF/UDTF.
* - [[HigherOrderFunction]]: a common base trait for higher order functions that take one or more
* (lambda) functions and apply these to some objects. The function
* produces a number of variables which can be consumed by some lambda
* functions.
* - [[NamedExpression]]: An [[Expression]] that is named.
* - [[TimeZoneAwareExpression]]: A common base trait for time zone aware expressions.
* - [[SubqueryExpression]]: A base interface for expressions that contain a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*
* - [[LeafExpression]]: an expression that has no child.
* - [[UnaryExpression]]: an expression that has one child.
* - [[BinaryExpression]]: an expression that has two children.
* - [[TernaryExpression]]: an expression that has three children.
* - [[QuaternaryExpression]]: an expression that has four children.
* - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have
* the same output data type.
*
* A few important traits used for type coercion rules:
* - [[ExpectsInputTypes]]: an expression that has the expected input types. This trait is typically
* used by operator expressions (e.g. [[Add]], [[Subtract]]) to define
* expected input types without any implicit casting.
* - [[ImplicitCastInputTypes]]: an expression that has the expected input types, which can be
* implicitly castable using [[TypeCoercion.ImplicitTypeCasts]].
* - [[ComplexTypeMergingExpression]]: to resolve output types of the complex expressions
* (e.g., [[CaseWhen]]).
*/
abstract class Expression extends TreeNode[Expression] {
/**
* Returns true when an expression is a candidate for static evaluation before the query is
* executed. A typical use case: [[org.apache.spark.sql.catalyst.optimizer.ConstantFolding]]
*
* The following conditions are used to determine suitability for constant folding:
* - A [[Coalesce]] is foldable if all of its children are foldable
* - A [[BinaryExpression]] is foldable if both its left and right children are foldable
* - A [[Not]], [[IsNull]], or [[IsNotNull]] is foldable if its child is foldable
* - A [[Literal]] is foldable
* - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable
*
* The base implementation returns false.
*/
def foldable: Boolean = false
/**
* Returns true when the current expression always returns the same result for fixed inputs from
* children. The non-deterministic expressions should not change in number and order. They should
* not be evaluated during the query planning.
*
* Note that this means that an expression should be considered as non-deterministic if:
* - it relies on some mutable internal state, or
* - it relies on some implicit input that is not part of the children expression list, or
* - it has non-deterministic child or children, or
* - it assumes the input satisfies some certain condition via the child operator.
*
* An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext.
* By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true.
*/
lazy val deterministic: Boolean = children.forall(_.deterministic)
/**
* Whether the result of this expression may be null. Generated code relies on it: when this is
* false, the null-handling paths are skipped and the result's null flag is the constant false.
*/
def nullable: Boolean
/**
* Workaround scala compiler so that we can call super on lazy vals
*/
@transient
private lazy val _references: AttributeSet =
AttributeSet.fromAttributeSets(children.map(_.references))
/**
* The attributes referenced by this expression. By default it is built from the references of
* all children; it is computed lazily once and cached in `_references`.
*/
def references: AttributeSet = _references
/**
* Returns true if the expression contains mutable state.
*
* A stateful expression should never be evaluated multiple times for a single row. This should
* only be a problem for interpreted execution. This can be prevented by creating fresh copies
* of the stateful expression before execution. A common example to trigger this issue:
* {{{
* val rand = functions.rand()
* df.select(rand, rand) // These 2 rand should not share a state.
* }}}
*/
def stateful: Boolean = false
/**
* Returns true if the expression could potentially throw an exception when evaluated.
* By default this is true if any child is throwable.
*/
lazy val throwable: Boolean = children.exists(_.throwable)
/**
* Returns a copy of this expression where all stateful expressions are replaced with fresh
* uninitialized copies. If the expression contains no stateful expressions then the original
* expression is returned. A copy is created under this expression's origin and receives this
* expression's tree node tags.
*/
def freshCopyIfContainsStatefulExpression(): Expression = {
val childrenIndexedSeq: IndexedSeq[Expression] = children match {
case types: IndexedSeq[Expression] => types
case other => other.toIndexedSeq
}
val newChildren = childrenIndexedSeq.map(_.freshCopyIfContainsStatefulExpression())
// A more efficient version of `children.zip(newChildren).exists(_ ne _)`
val anyChildChanged = {
val size = newChildren.length
var i = 0
var res: Boolean = false
while (!res && i < size) {
res |= (childrenIndexedSeq(i) ne newChildren(i))
i += 1
}
res
}
// If the children contain stateful expressions and get copied, or this expression is stateful,
// copy this expression with the new children.
if (anyChildChanged || stateful) {
CurrentOrigin.withOrigin(origin) {
val res = withNewChildrenInternal(newChildren)
res.copyTagsFrom(this)
res
}
} else {
this
}
}
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: InternalRow = null): Any
/**
* Returns an [[ExprCode]], that contains the Java source code to generate the result of
* evaluating the expression on an input row.
*
* If this expression is registered for subexpression elimination in `ctx`, the previously
* generated result is reused instead of generating new code. Otherwise the code comes from
* [[doGenCode]]. It may be moved into a separate method by `reduceCodeSize`, and it is prefixed
* with a comment naming this expression.
*
* @param ctx a [[CodegenContext]]
* @return [[ExprCode]]
*/
def genCode(ctx: CodegenContext): ExprCode = {
ctx.subExprEliminationExprs.get(ExpressionEquals(this)).map { subExprState =>
// This expression is repeated which means that the code to evaluate it has already been added
// as a function before. In that case, we just re-use it.
ExprCode(
ctx.registerComment(this.toString),
subExprState.eval.isNull,
subExprState.eval.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
reduceCodeSize(ctx, eval)
if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
eval.copy(code = ctx.registerComment(this.toString) + eval.code)
} else {
eval
}
}
}
/**
* Moves the code of `eval` into a separate private Java method when its length exceeds
* `SQLConf.methodSplitThreshold`, to keep the calling method small.
*
* This only happens when code is generated against `ctx.INPUT_ROW` and `ctx.currentVars` is
* null (see the TODO below about whole-stage codegen).
*
* `eval` is mutated in place: its code becomes a call to the new method, and its value becomes
* a fresh local holding the method's result. A non-literal `isNull` is copied into a mutable
* global field so that it is still readable after the new method returns.
*/
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
val splitThreshold = SQLConf.get.methodSplitThreshold
if (eval.code.length > splitThreshold && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = JavaCode.isNullGlobal(globalIsNull)
s"$globalIsNull = $localIsNull;"
} else {
""
}
val javaType = CodeGenerator.javaType(dataType)
val newValue = ctx.freshName("value")
val funcName = ctx.freshName(nodeName)
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
| ${eval.code}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)
eval.value = JavaCode.variable(newValue, dataType)
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
/**
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression. Concrete expression
* implementations should override this to do actual code generation.
*
* @param ctx a [[CodegenContext]]
* @param ev an [[ExprCode]] with unique terms.
* @return an [[ExprCode]] containing the Java source code to generate the given expression
*/
protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode
/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and `false` if it still contains any unresolved
* placeholders or has data types mismatch.
* Implementations of expressions should override this if the resolution of this type of
* expression involves more than just the resolution of its children and type checking.
*/
lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
/**
* Returns the [[DataType]] of the result of evaluating this expression. It is
* invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
*/
def dataType: DataType
/**
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any still contains any unresolved placeholders.
*/
def childrenResolved: Boolean = children.forall(_.resolved)
/**
* Returns an expression where a best effort attempt has been made to transform `this` in a way
* that preserves the result but removes cosmetic variations (case sensitivity, ordering for
* commutative operations, etc.).
*
* `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
* evaluate to the same result.
*
* The process of canonicalization is a one pass, bottom-up expression tree computation based on
* canonicalizing children before canonicalizing the current node. There is one exception though:
* for adjacent, same-class [[CommutativeExpression]]s, canonicalization happens in a way that
* calling `canonicalized` on the root:
* 1. Gathers and canonicalizes the non-commutative (or commutative but not same class) child
* expressions of the adjacent expressions.
* 2. Reorders the canonicalized child expressions by their hashcode.
* This means that the lazy `canonicalized` is called and computed only on the root of the
* adjacent expressions.
*/
lazy val canonicalized: Expression = withCanonicalizedChildren
/**
* The default process of canonicalization. It is a one pass, bottom-up expression tree
* computation based on canonicalizing children before canonicalizing the current node.
*/
final protected def withCanonicalizedChildren: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
withNewChildren(canonicalizedChildren)
}
/**
* Returns true when two expressions will always compute the same result, even if they differ
* cosmetically (i.e. capitalization of names in attributes may be different).
*
* See [[Expression#canonicalized]] for more details.
*/
final def semanticEquals(other: Expression): Boolean =
deterministic && other.deterministic && canonicalized == other.canonicalized
/**
* Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
* `hashCode`, an attempt has been made to eliminate cosmetic differences.
*
* See [[Expression#canonicalized]] for more details.
*/
def semanticHash(): Int = canonicalized.hashCode()
/**
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
/**
* Returns a user-facing string representation of this expression's name.
* This should usually match the name of the function in SQL.
* Defaults to the `FunctionRegistry.FUNC_ALIAS` tag when it is set, otherwise to the node name
* in lower case.
*/
def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(nodeName.toLowerCase(Locale.ROOT))
/**
* The arguments rendered by `toString`: `stringArgs` flattened one level, so each `Iterable`
* argument contributes its elements individually.
*/
protected def flatArguments: Iterator[Any] = stringArgs.flatMap {
case t: Iterable[_] => t
case single => single :: Nil
}
// Marks this as final, Expression.verboseString should never be called, and thus shouldn't be
// overridden by concrete classes.
final override def verboseString(maxFields: Int): String = simpleString(maxFields)
override def simpleString(maxFields: Int): String = toString
override def toString: String = prettyName + truncatedString(
flatArguments.toSeq, "(", ", ", ")", SQLConf.get.maxToStringFields)
/**
* Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]],
* this method may return an arbitrary user facing string.
*/
def sql: String = {
val childrenSQL = children.map(_.sql).mkString(", ")
s"$prettyName($childrenSQL)"
}
/** Not supported for expressions: always throws an internal error. */
override def simpleStringWithNodeId(): String = {
throw SparkException.internalError(s"$nodeName does not implement simpleStringWithNodeId")
}
/**
* A literal suffix for this expression's type: "L" for a resolved expression of [[LongType]],
* otherwise the empty string.
*/
protected def typeSuffix =
if (resolved) {
dataType match {
case LongType => "L"
case _ => ""
}
} else {
""
}
}
/**
 * Mix-in for an expression that has no runtime implementation but is guaranteed to be replaced
 * with a foldable value by the query optimizer (for example `CurrentDate`).
 *
 * Such an expression reports itself as foldable. Interpreted evaluation and code generation
 * are both final here and fail with a query-execution error if they are ever reached.
 */
trait FoldableUnevaluable extends Expression {
  override def foldable: Boolean = true

  final override def eval(input: InternalRow = null): Any = {
    throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
  }

  final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
  }
}
/**
 * An expression that cannot be evaluated. These expressions don't live past analysis or
 * optimization time (e.g. Star) and should not be evaluated during query planning and
 * execution.
 *
 * Evaluation and code generation come from [[FoldableUnevaluable]], and both always throw.
 */
trait Unevaluable extends Expression with FoldableUnevaluable {
  /**
   * Unevaluable expressions are never foldable, because there is no `eval` to fold them with.
   *
   * This override is `final`, so subclasses cannot turn folding back on. Expressions that the
   * optimizer replaces with a literal (e.g. CurrentDate) should extend [[FoldableUnevaluable]]
   * directly instead of this trait.
   */
  final override def foldable: Boolean = false
}
/**
 * An expression that is swapped for a different expression at runtime (currently by the
 * optimizer) before it is evaluated. This is mainly used to provide compatibility with other
 * databases; for example, "nvl" is supported by replacing it with "coalesce".
 *
 * Nullability, data type and canonical form are all delegated to `replacement`. Evaluating, or
 * generating code for, the wrapper itself is an error.
 */
trait RuntimeReplaceable extends Expression {
  /** The expression that this one is rewritten into. */
  def replacement: Expression

  override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)

  override def dataType: DataType = replacement.dataType

  override def nullable: Boolean = replacement.nullable

  // This expression is eventually replaced by `replacement`, so two `RuntimeReplaceable`s
  // canonicalize identically whenever their `replacement` expressions do.
  override lazy val canonicalized: Expression = replacement.canonicalized

  final override def eval(input: InternalRow = null): Any = {
    throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
  }

  final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
  }
}
/**
 * An add-on of [[RuntimeReplaceable]] that makes `replacement` this expression's only child.
 * As a result, analysis rules such as type coercion are applied to it.
 *
 * Implementations should take `replacement` as a case class constructor argument and provide
 * an auxiliary constructor that accepts only the original parameters; see [[TryAdd]] for an
 * example. They must also implement `parameters` so that explain output and `sql` render the
 * original arguments.
 */
trait InheritAnalysisRules extends UnaryLike[Expression] { self: RuntimeReplaceable =>
  override def child: Expression = replacement

  /** The original, user-facing arguments of this expression. */
  def parameters: Seq[Expression]

  override def flatArguments: Iterator[Any] = parameters.iterator

  final override def sql: String = makeSQLString(parameters.map(_.sql))

  /**
   * Builds `prettyName(arg1, arg2, ...)` from already-rendered argument strings. This is needed
   * because the actual inputs are not the children of this expression.
   */
  def makeSQLString(childrenSQL: Seq[String]): String = {
    val argumentList = childrenSQL.mkString("(", ", ", ")")
    s"$prettyName$argumentList"
  }
}
/**
 * An add-on of [[AggregateFunction]] that is rewritten (currently by the optimizer) into a
 * different aggregate expression before evaluation. This is mainly used to provide compatibility
 * with other databases; for example, the every and any/some aggregates are supported by rewriting
 * them with Min and Max respectively.
 *
 * Because such an aggregate never runs as itself, its aggregation-buffer members must never be
 * accessed, and each of them throws an internal error.
 */
trait RuntimeReplaceableAggregate extends RuntimeReplaceable { self: AggregateFunction =>
  override def aggBufferSchema: StructType =
    throw SparkException.internalError(
      "RuntimeReplaceableAggregate.aggBufferSchema should not be called")

  override def aggBufferAttributes: Seq[AttributeReference] =
    throw SparkException.internalError(
      "RuntimeReplaceableAggregate.aggBufferAttributes should not be called")

  override def inputAggBufferAttributes: Seq[AttributeReference] =
    throw SparkException.internalError(
      "RuntimeReplaceableAggregate.inputAggBufferAttributes should not be called")
}
/**
 * Expressions that don't have a SQL representation should extend this trait. Examples are
 * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
 *
 * Their `sql` is a best-effort, user-facing string. The expression tree is rendered with
 * `toString` after attributes, aliases and Python function expressions have been swapped for
 * their "pretty" stand-ins.
 */
trait NonSQLExpression extends Expression {
  final override def sql: String = {
    val prettified = transform {
      case attr: Attribute => new PrettyAttribute(attr)
      case alias: Alias => PrettyAttribute(alias.sql, alias.dataType)
      case pyFunc: PythonFuncExpression =>
        PrettyPythonUDF(pyFunc.name, pyFunc.dataType, pyFunc.children)
    }
    prettified.toString
  }
}
/**
 * An expression that is nondeterministic. It is never deterministic or foldable, and it must be
 * initialized for a partition through [[initialize]] before it is evaluated.
 */
trait Nondeterministic extends Expression {
  final override lazy val deterministic: Boolean = false
  final override def foldable: Boolean = false

  // Becomes true once `initialize` has completed; checked on every interpreted `eval`.
  @transient
  private[this] var initialized = false

  /**
   * Initializes internal state for the given partition index and then marks this expression as
   * initialized. Subclasses should override [[initializeInternal()]].
   */
  final def initialize(partitionIndex: Int): Unit = {
    initializeInternal(partitionIndex)
    // Reached only if the subclass hook did not throw.
    initialized = true
  }

  /**
   * @inheritdoc
   * Throws an exception if [[initialize()]] is not called yet.
   * Subclasses should override [[evalInternal()]].
   */
  final override def eval(input: InternalRow = null): Any = {
    require(initialized,
      s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
    evalInternal(input)
  }

  /** Per-partition initialization hook, invoked by [[initialize]]. */
  protected def initializeInternal(partitionIndex: Int): Unit

  /** The actual evaluation logic, invoked by [[eval]] once the expression is initialized. */
  protected def evalInternal(input: InternalRow): Any
}
/**
 * An expression with conditional branches, so that not every branch is evaluated at runtime.
 * Optimizations must therefore be careful with the evaluation order.
 */
trait ConditionalExpression extends Expression {
  /** Foldable only when no child (conditions and branches alike) is non-foldable. */
  final override def foldable: Boolean = !children.exists(child => !child.foldable)

  /** The children that are always evaluated at runtime, whichever branch is taken. */
  def alwaysEvaluatedInputs: Seq[Expression]

  /** Returns a copy of this expression with `alwaysEvaluatedInputs` replaced. */
  def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression

  /**
   * Groups of branches such that at least one branch of each group is evaluated at runtime.
   * This allows the common expressions of a group to be evaluated eagerly.
   */
  def branchGroups: Seq[Seq[Expression]]
}
/**
 * A leaf expression: one that has no child expressions. Tree-shape behavior comes from the
 * [[LeafLike]] mix-in.
 */
abstract class LeafExpression
  extends Expression
  with LeafLike[Expression]
/**
* An expression with one input and one output. The output is by default evaluated to null
* if the input is evaluated to null.
*/
abstract class UnaryExpression extends Expression with UnaryLike[Expression] {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
/**
* Default behavior of evaluation according to the default nullability of UnaryExpression.
* It returns null when the child evaluates to null, and otherwise delegates to [[nullSafeEval]].
* A subclass that overrides nullable should probably also override this.
*/
override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
null
} else {
nullSafeEval(value)
}
}
/**
* Called by the default [[eval]] implementation with the child's non-null value.
*
* Subclasses that keep the default nullability can override this method instead of [[eval]],
* which saves them from writing the null check. Subclasses that need full control of the
* evaluation process should override [[eval]] instead.
*
* The base implementation throws an error reporting that neither method was overridden.
*/
protected def nullSafeEval(input: Any): Any =
throw QueryExecutionErrors.notOverrideExpectedMethodsError(this.getClass.getName,
"eval", "nullSafeEval")
/**
* Called by unary expressions to generate a code block that returns null if its child returns
* null, and otherwise uses `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
* defineCodeGen(ctx, ev, c => s"!($c)")
* }}}
*
* @param f function that accepts a variable name and returns a Java expression whose value is
* assigned to `ev.value`.
*/
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: String => String): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
s"${ev.value} = ${f(eval)};"
})
}
/**
* Called by unary expressions to generate a code block that returns null if its child returns
* null, and otherwise uses `f` to generate the expression.
*
* When this expression is not nullable, no null check is emitted. The code from `f` always runs,
* and the resulting `isNull` is the constant false.
*
* @param f function that accepts the non-null evaluation result name of child and returns Java
* code to compute the output. The code should assign `ev.value`. When the expression is
* nullable, it may also set `ev.isNull`, which starts out as the child's null flag.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: String => String): ExprCode = {
val childGen = child.genCode(ctx)
val resultCode = f(childGen.value)
if (nullable) {
// Guard `resultCode` with the child's null flag via `ctx.nullSafeExec`.
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
ev.copy(code = code"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = code"""
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
}
/**
 * An expression that carries a SQL query context. The context can be serialized from the
 * driver to executors, and it is kept across rule transforms.
 */
trait SupportQueryContext extends Expression with Serializable {
  protected var queryContext: Option[QueryContext] = initQueryContext()

  /** Computes the initial query context; invoked when the expression instance is constructed. */
  def initQueryContext(): Option[QueryContext]

  /** The query context, or null when there is none. */
  def getContextOrNull(): QueryContext = queryContext.orNull

  /**
   * Returns a Java expression for the query context in generated code. When `withErrorContext`
   * is set and a context exists, this is a reference object registered with `ctx`; otherwise
   * it is the literal "null".
   */
  def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String =
    queryContext match {
      case Some(context) if withErrorContext => ctx.addReferenceObj("errCtx", context)
      case _ => "null"
    }

  // Query contexts are serialized to executors, but "Origin"s are not (for performance). As a
  // result, a context would be regenerated from an empty "Origin" during rule transforms, so
  // the original context is carried over here as if it were a tag on the expression.
  override def copyTagsFrom(other: Expression): Unit = {
    other match {
      case source: SupportQueryContext => queryContext = source.queryContext
      case _ =>
    }
    super.copyTagsFrom(other)
  }
}
/** Extractor exposing the single child of any [[UnaryExpression]]. */
object UnaryExpression {
  def unapply(e: UnaryExpression): Option[Expression] = {
    val onlyChild = e.child
    Some(onlyChild)
  }
}
/**
* An expression with two inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class BinaryExpression extends Expression with BinaryLike[Expression] {
override def foldable: Boolean = left.foldable && right.foldable
override def nullable: Boolean = left.nullable || right.nullable
/**
* Default behavior of evaluation according to the default nullability of BinaryExpression.
* It returns null as soon as a child evaluates to null (`right` is not evaluated when `left` is
* null), and otherwise delegates to [[nullSafeEval]].
* A subclass that overrides nullable should probably also override this.
*/
override def eval(input: InternalRow): Any = {
val value1 = left.eval(input)
if (value1 == null) {
null
} else {
val value2 = right.eval(input)
if (value2 == null) {
null
} else {
nullSafeEval(value1, value2)
}
}
}
/**
* Called by the default [[eval]] implementation with both children's non-null values.
*
* Subclasses that keep the default nullability can override this method instead of [[eval]],
* which saves them from writing the null checks. Subclasses that need full control of the
* evaluation process should override [[eval]] instead.
*
* The base implementation throws an error reporting that neither method was overridden.
*/
protected def nullSafeEval(input1: Any, input2: Any): Any =
throw QueryExecutionErrors.notOverrideExpectedMethodsError(this.getClass.getName,
"eval", "nullSafeEval")
/**
* Short hand for generating binary evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts two variable names and returns a Java expression whose value is assigned
* to `ev.value`.
*/
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"${ev.value} = ${f(eval1, eval2)};"
})
}
/**
* Short hand for generating binary evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* When this expression is nullable, the generated code works as follows:
* - `ev.isNull` starts as true.
* - The right child's code is nested inside the null guard built for the left child (via
* `ctx.nullSafeExec`).
* - `ev.isNull` is set to false just before the code from `f` runs.
*
* When this expression is not nullable, both children's code and the code from `f` run
* unconditionally, and `isNull` is the constant false.
*
* @param f function that accepts the 2 non-null evaluation result names of children
* and returns Java code to compute the output. The code should assign `ev.value` and may
* set `ev.isNull` back to true.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): ExprCode = {
val leftGen = left.genCode(ctx)
val rightGen = right.genCode(ctx)
val resultCode = f(leftGen.value, rightGen.value)
if (nullable) {
val nullSafeEval =
leftGen.code.toString + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
rightGen.code.toString + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
}
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = code"""
${leftGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
}
/** Extractor exposing the `(left, right)` children of any [[BinaryExpression]]. */
object BinaryExpression {
  def unapply(e: BinaryExpression): Option[(Expression, Expression)] = {
    val pair = (e.left, e.right)
    Some(pair)
  }
}
/**
 * A [[BinaryExpression]] that is an operator, with two properties:
 *
 * 1. The string representation is "x symbol y", rather than "funcName(x, y)".
 * 2. Two inputs are expected to be of the same type. If the two inputs have different types,
 *    the analyzer will find the tightest common type and do the proper type casting.
 */
abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes with QueryErrorsBase {
  /**
   * Expected input type for both the left and right child expressions, similar to the
   * [[ImplicitCastInputTypes]] trait.
   */
  def inputType: AbstractDataType

  /** The operator symbol, e.g. `+`. */
  def symbol: String

  /** The operator as rendered in `toString` and `sql`; defaults to [[symbol]]. */
  def sqlOperator: String = symbol

  override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)

  /**
   * Fails with BINARY_OP_DIFF_TYPES when the two sides have different types. Otherwise, fails
   * with BINARY_OP_WRONG_TYPE when their common type is not accepted by [[inputType]].
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val leftType = left.dataType
    val rightType = right.dataType
    if (!DataTypeUtils.sameType(leftType, rightType)) {
      DataTypeMismatch(
        errorSubClass = "BINARY_OP_DIFF_TYPES",
        messageParameters = Map(
          "left" -> toSQLType(leftType),
          "right" -> toSQLType(rightType)))
    } else if (!inputType.acceptsType(leftType)) {
      DataTypeMismatch(
        errorSubClass = "BINARY_OP_WRONG_TYPE",
        messageParameters = Map(
          "inputType" -> toSQLType(inputType),
          "actualDataType" -> toSQLType(leftType)))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  override def toString: String = s"($left $sqlOperator $right)"

  override def sql: String = s"(${left.sql} $sqlOperator ${right.sql})"
}
/** Extractor exposing the `(left, right)` operands of any [[BinaryOperator]]. */
object BinaryOperator {
  def unapply(e: BinaryOperator): Option[(Expression, Expression)] = {
    val operands = (e.left, e.right)
    Some(operands)
  }
}
/**
* An expression with three inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class TernaryExpression extends Expression with TernaryLike[Expression] {
override def foldable: Boolean = children.forall(_.foldable)
override def nullable: Boolean = children.exists(_.nullable)
/**
* Default behavior of evaluation according to the default nullability of TernaryExpression.
* Children are evaluated in order (first, second, third), stopping at the first one that is null,
* in which case null is returned. Otherwise the result comes from [[nullSafeEval]].
* A subclass that overrides nullable should probably also override this.
*/
override def eval(input: InternalRow): Any = {
val value1 = first.eval(input)
if (value1 != null) {
val value2 = second.eval(input)
if (value2 != null) {
val value3 = third.eval(input)
if (value3 != null) {
return nullSafeEval(value1, value2, value3)
}
}
}
null
}
/**
* Called by the default [[eval]] implementation with all three children's non-null values.
*
* Subclasses that keep the default nullability can override this method instead of [[eval]],
* which saves them from writing the null checks. Subclasses that need full control of the
* evaluation process should override [[eval]] instead.
*
* The base implementation throws an error reporting that neither method was overridden.
*/
protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any =
throw QueryExecutionErrors.notOverrideExpectedMethodsError(this.getClass.getName,
"eval", "nullSafeEval")
/**
* Short hand for generating ternary evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts three variable names and returns a Java expression whose value is assigned
* to `ev.value`.
*/
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => {
s"${ev.value} = ${f(eval1, eval2, eval3)};"
})
}
/**
* Short hand for generating ternary evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* When this expression is nullable, the generated code works as follows:
* - `ev.isNull` starts as true.
* - Each child's code is nested inside the null guard of the previous child (via
* `ctx.nullSafeExec`).
* - `ev.isNull` is set to false just before the code from `f` runs.
*
* When this expression is not nullable, all children's code and the code from `f` run
* unconditionally, and `isNull` is the constant false.
*
* @param f function that accepts the 3 non-null evaluation result names of children
* and returns Java code to compute the output. The code should assign `ev.value` and may
* set `ev.isNull` back to true.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): ExprCode = {
val leftGen = children(0).genCode(ctx)
val midGen = children(1).genCode(ctx)
val rightGen = children(2).genCode(ctx)
val resultCode = f(leftGen.value, midGen.value, rightGen.value)
if (nullable) {
val nullSafeEval =
leftGen.code.toString + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
midGen.code.toString + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
rightGen.code.toString + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
}
}
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = code"""
${leftGen.code}
${midGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
}
/**
 * An expression with four inputs and one output. The output is by default evaluated to null
 * if any input is evaluated to null.
 */
abstract class QuaternaryExpression extends Expression with QuaternaryLike[Expression] {

  override def foldable: Boolean = children.forall(_.foldable)

  override def nullable: Boolean = children.exists(_.nullable)

  /**
   * Default behavior of evaluation according to the default nullability of QuaternaryExpression:
   * the children are evaluated in order, evaluation stops with a null result at the first null
   * value, and [[nullSafeEval]] is called only when all four values are non-null.
   * If a subclass of QuaternaryExpression overrides nullable, it probably should also override
   * this.
   */
  override def eval(input: InternalRow): Any = {
    val value1 = first.eval(input)
    if (value1 != null) {
      val value2 = second.eval(input)
      if (value2 != null) {
        val value3 = third.eval(input)
        if (value3 != null) {
          val value4 = fourth.eval(input)
          if (value4 != null) {
            return nullSafeEval(value1, value2, value3, value4)
          }
        }
      }
    }
    null
  }

  /**
   * Called by the default [[eval]] implementation with four non-null inputs. If a subclass of
   * QuaternaryExpression keeps the default nullability, it can override this method to save
   * null-check code. If full control of the evaluation process is needed, override [[eval]]
   * instead. The default implementation always throws the error built by
   * `QueryExecutionErrors.notOverrideExpectedMethodsError`.
   */
  protected def nullSafeEval(input1: Any, input2: Any, input3: Any, input4: Any): Any =
    throw QueryExecutionErrors.notOverrideExpectedMethodsError(this.getClass.getName,
      "eval", "nullSafeEval")

  /**
   * Short hand for generating quaternary evaluation code: the Java expression returned by `f`
   * is assigned to `ev.value`. If any of the sub-expressions is null, the result of this
   * computation is assumed to be null.
   *
   * @param f accepts four variable names and returns Java code to compute the output.
   */
  protected def defineCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      f: (String, String, String, String) => String): ExprCode = {
    nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3, eval4) => {
      s"${ev.value} = ${f(eval1, eval2, eval3, eval4)};"
    })
  }

  /**
   * Short hand for generating quaternary evaluation code.
   *
   * If this expression is nullable, `ev.isNull` starts as true and each child's code is
   * followed by a guard built with `ctx.nullSafeExec`, which receives that child's
   * nullability. `ev.isNull` is set to false immediately before the code from `f` runs.
   * Otherwise all code is emitted without guards and `isNull` is the literal false.
   *
   * @param f function that accepts the 4 non-null evaluation result names of children
   *          and returns Java code to compute the output.
   */
  protected def nullSafeCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      f: (String, String, String, String) => String): ExprCode = {
    // Children are generated in order before `f` runs; all of them share `ctx`.
    val firstGen = children(0).genCode(ctx)
    val secondGen = children(1).genCode(ctx)
    val thirdGen = children(2).genCode(ctx)
    val fourthGen = children(3).genCode(ctx)
    val resultCode = f(firstGen.value, secondGen.value, thirdGen.value, fourthGen.value)
    if (nullable) {
      val nullSafeEval =
        firstGen.code.toString + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
          secondGen.code.toString + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
            thirdGen.code.toString + ctx.nullSafeExec(children(2).nullable, thirdGen.isNull) {
              fourthGen.code.toString + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
                s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
              }
            }
          }
        }
      ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
    } else {
      ev.copy(code = code"""
${firstGen.code}
${secondGen.code}
${thirdGen.code}
${fourthGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
    }
  }
}
/**
 * An expression with five inputs and one output. The output is by default evaluated to null if
 * any input is evaluated to null.
 */
abstract class QuinaryExpression extends Expression {

  override def foldable: Boolean = children.forall(_.foldable)

  override def nullable: Boolean = children.exists(_.nullable)

  /**
   * Default behavior of evaluation according to the default nullability of QuinaryExpression:
   * the first five children are evaluated in order, evaluation stops with a null result at the
   * first null value, and [[nullSafeEval]] is called only when all five values are non-null.
   * If a subclass of QuinaryExpression overrides nullable, it probably should also override this.
   */
  override def eval(input: InternalRow): Any = {
    val exprs = children
    val v1 = exprs(0).eval(input)
    if (v1 != null) {
      val v2 = exprs(1).eval(input)
      if (v2 != null) {
        val v3 = exprs(2).eval(input)
        if (v3 != null) {
          val v4 = exprs(3).eval(input)
          if (v4 != null) {
            val v5 = exprs(4).eval(input)
            if (v5 != null) {
              return nullSafeEval(v1, v2, v3, v4, v5)
            }
          }
        }
      }
    }
    null
  }

  /**
   * Called by the default [[eval]] implementation with five non-null inputs. If a subclass of
   * QuinaryExpression keeps the default nullability, it can override this method to save
   * null-check code. If full control of the evaluation process is needed, override [[eval]]
   * instead. The default implementation always throws the error built by
   * `QueryExecutionErrors.notOverrideExpectedMethodsError`.
   */
  protected def nullSafeEval(
      input1: Any,
      input2: Any,
      input3: Any,
      input4: Any,
      input5: Any): Any = {
    throw QueryExecutionErrors.notOverrideExpectedMethodsError(
      this.getClass.getName,
      "eval",
      "nullSafeEval")
  }

  /**
   * Short hand for generating quinary evaluation code: the Java expression returned by `f` is
   * assigned to `ev.value`. If any of the sub-expressions is null, the result of this
   * computation is assumed to be null.
   *
   * @param f
   *   accepts five variable names and returns Java code to compute the output.
   */
  protected def defineCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      f: (String, String, String, String, String) => String): ExprCode = {
    nullSafeCodeGen(
      ctx,
      ev,
      (eval1, eval2, eval3, eval4, eval5) => {
        s"${ev.value} = ${f(eval1, eval2, eval3, eval4, eval5)};"
      })
  }

  /**
   * Short hand for generating quinary evaluation code. If this expression is nullable,
   * `ev.isNull` starts as true and each child's code is followed by a guard built with
   * `ctx.nullSafeExec`, which receives that child's nullability. `ev.isNull` is set to false
   * immediately before the code from `f` runs. Otherwise all code is emitted without guards
   * and `isNull` is the literal false.
   *
   * @param f
   *   function that accepts the 5 non-null evaluation result names of children and returns Java
   *   code to compute the output.
   */
  protected def nullSafeCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      f: (String, String, String, String, String) => String): ExprCode = {
    // Children are generated in order before `f` runs; all of them share `ctx`.
    val firstGen = children(0).genCode(ctx)
    val secondGen = children(1).genCode(ctx)
    val thirdGen = children(2).genCode(ctx)
    val fourthGen = children(3).genCode(ctx)
    val fifthGen = children(4).genCode(ctx)
    val resultCode =
      f(firstGen.value, secondGen.value, thirdGen.value, fourthGen.value, fifthGen.value)
    if (nullable) {
      val nullSafeEval =
        firstGen.code.toString + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
          secondGen.code.toString + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
            thirdGen.code.toString + ctx.nullSafeExec(children(2).nullable, thirdGen.isNull) {
              fourthGen.code.toString + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
                fifthGen.code.toString + ctx.nullSafeExec(children(4).nullable, fifthGen.isNull) {
                  s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
                }
              }
            }
          }
        }
      ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
    } else {
      ev.copy(
        code = code"""
${firstGen.code}
${secondGen.code}
${thirdGen.code}
${fourthGen.code}
${fifthGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""",
        isNull = FalseLiteral)
    }
  }
}
/**
 * An expression with six required inputs, an optional seventh input and one output.
 * The output is by default evaluated to null if any input is evaluated to null.
 */
abstract class SeptenaryExpression extends Expression {

  override def foldable: Boolean = children.forall(_.foldable)

  override def nullable: Boolean = children.exists(_.nullable)

  /**
   * Default behavior of evaluation according to the default nullability of SeptenaryExpression.
   * The children are evaluated in order and evaluation stops with a null result at the first
   * null value. The seventh child is evaluated only if it exists (`children.length > 6`). It is
   * passed to [[nullSafeEval]] as `Some(value)`, or as `None` when the expression has only six
   * children.
   * If a subclass of SeptenaryExpression overrides nullable, it probably should also override
   * this.
   */
  override def eval(input: InternalRow): Any = {
    val exprs = children
    val v1 = exprs(0).eval(input)
    if (v1 != null) {
      val v2 = exprs(1).eval(input)
      if (v2 != null) {
        val v3 = exprs(2).eval(input)
        if (v3 != null) {
          val v4 = exprs(3).eval(input)
          if (v4 != null) {
            val v5 = exprs(4).eval(input)
            if (v5 != null) {
              val v6 = exprs(5).eval(input)
              if (v6 != null) {
                if (exprs.length > 6) {
                  val v7 = exprs(6).eval(input)
                  if (v7 != null) {
                    return nullSafeEval(v1, v2, v3, v4, v5, v6, Some(v7))
                  }
                } else {
                  return nullSafeEval(v1, v2, v3, v4, v5, v6, None)
                }
              }
            }
          }
        }
      }
    }
    null
  }

  /**
   * Called by the default [[eval]] implementation. If a subclass of SeptenaryExpression keeps
   * the default nullability, it can override this method to save null-check code. If full
   * control of the evaluation process is needed, override [[eval]] instead.
   *
   * `input1` to `input6` are always non-null. `input7` is `Some` of the seventh child's non-null
   * value when that child exists, and `None` when the expression has only six children.
   * The default implementation always throws the error built by
   * `QueryExecutionErrors.notOverrideExpectedMethodsError`.
   */
  protected def nullSafeEval(
      input1: Any,
      input2: Any,
      input3: Any,
      input4: Any,
      input5: Any,
      input6: Any,
      input7: Option[Any]): Any = {
    throw QueryExecutionErrors.notOverrideExpectedMethodsError(this.getClass.getName,
      "eval", "nullSafeEval")
  }

  /**
   * Short hand for generating septenary evaluation code: the Java expression returned by `f`
   * is assigned to `ev.value`. If any of the sub-expressions is null, the result of this
   * computation is assumed to be null.
   *
   * @param f accepts six variable names plus an optional seventh one (`None` when there is no
   *          seventh child) and returns Java code to compute the output.
   */
  protected def defineCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      f: (String, String, String, String, String, String, Option[String]) => String
    ): ExprCode = {
    nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3, eval4, eval5, eval6, eval7) => {
      s"${ev.value} = ${f(eval1, eval2, eval3, eval4, eval5, eval6, eval7)};"
    })
  }

  /**
   * Short hand for generating septenary evaluation code. If this expression is nullable,
   * `ev.isNull` starts as true and each child's code (including the optional seventh) is
   * followed by a guard built with `ctx.nullSafeExec`, which receives that child's nullability.
   * `ev.isNull` is set to false immediately before the code from `f` runs. Otherwise all code is
   * emitted without guards and `isNull` is the literal false.
   *
   * @param f function that accepts the non-null evaluation result names of the six required
   *          children plus, when present, that of the optional seventh child (`None`
   *          otherwise), and returns Java code to compute the output.
   */
  protected def nullSafeCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      f: (String, String, String, String, String, String, Option[String]) => String
    ): ExprCode = {
    // Children are generated in order before `f` runs; all of them share `ctx`.
    val firstGen = children(0).genCode(ctx)
    val secondGen = children(1).genCode(ctx)
    val thirdGen = children(2).genCode(ctx)
    val fourthGen = children(3).genCode(ctx)
    val fifthGen = children(4).genCode(ctx)
    val sixthGen = children(5).genCode(ctx)
    val seventhGen = if (children.length > 6) Some(children(6).genCode(ctx)) else None
    val resultCode = f(
      firstGen.value,
      secondGen.value,
      thirdGen.value,
      fourthGen.value,
      fifthGen.value,
      sixthGen.value,
      seventhGen.map(_.value))
    if (nullable) {
      val nullSafeEval =
        firstGen.code.toString + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
          secondGen.code.toString + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
            thirdGen.code.toString + ctx.nullSafeExec(children(2).nullable, thirdGen.isNull) {
              fourthGen.code.toString + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
                fifthGen.code.toString + ctx.nullSafeExec(children(4).nullable, fifthGen.isNull) {
                  sixthGen.code.toString + ctx.nullSafeExec(children(5).nullable, sixthGen.isNull) {
                    val nullSafeResultCode =
                      s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
                    // The optional seventh child adds one more guard only when it is present.
                    seventhGen.map { gen =>
                      gen.code.toString + ctx.nullSafeExec(children(6).nullable, gen.isNull) {
                        nullSafeResultCode
                      }
                    }.getOrElse(nullSafeResultCode)
                  }
                }
              }
            }
          }
        }
      ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
    } else {
      ev.copy(code = code"""
${firstGen.code}
${secondGen.code}
${thirdGen.code}
${fourthGen.code}
${fifthGen.code}
${sixthGen.code}
${seventhGen.map(_.code).getOrElse("")}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
    }
  }
}
/**
 * A trait used for resolving the nullable flags of the output data type: `nullable`,
 * `containsNull` of [[ArrayType]] and `valueContainsNull` of [[MapType]]. This is usually
 * utilized by expressions (e.g. [[CaseWhen]]) that combine data from multiple child expressions
 * of non-primitive types.
 */
trait ComplexTypeMergingExpression extends Expression {

  /**
   * The data types used to resolve the output type of the expression. Defaults to the data
   * types of all child expressions. The collection must not be empty.
   */
  @transient
  lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)

  /**
   * Validates [[inputTypesForMerging]] via `SparkException.require`. Fails with error class
   * `COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.NO_INPUTS` when there are no input types, and with
   * `COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.MISMATCHED_TYPES` when the input types do not
   * satisfy `TypeCoercion.haveSameType`.
   */
  def dataTypeCheck: Unit = {
    SparkException.require(
      requirement = inputTypesForMerging.nonEmpty,
      errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.NO_INPUTS",
      messageParameters = Map("expression" -> toSQLExpr(this)))
    SparkException.require(
      requirement = TypeCoercion.haveSameType(inputTypesForMerging),
      errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.MISMATCHED_TYPES",
      messageParameters = Map(
        "expression" -> toSQLExpr(this),
        "inputTypes" -> inputTypesForMerging.map(toSQLType).mkString("[", ", ", "]")))
  }

  // Validates first, then folds the input types pairwise with
  // `findCommonTypeDifferentOnlyInNullFlags`. The `.get` relies on `dataTypeCheck` having
  // rejected mismatched types.
  private lazy val internalDataType: DataType = {
    dataTypeCheck
    inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_, _).get)
  }

  override def dataType: DataType = internalDataType
}
/**
 * Common base trait shared by user-defined functions (UDFs, UDAFs and UDTFs, whatever language
 * they are written in) and by Hive function wrappers.
 */
trait UserDefinedExpression {

  /** The name of this user-defined function. */
  def name: String
}
trait CommutativeExpression extends Expression {

  /**
   * Collects adjacent commutative operations. Every node on which `f` is defined is expanded
   * recursively into the operands `f` returns. Any other node becomes a single operand and is
   * returned in canonicalized form.
   */
  private def gatherCommutative(
      e: Expression,
      f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = e match {
    case c: CommutativeExpression if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
    case other => other.canonicalized :: Nil
  }

  /**
   * Reorders adjacent commutative operators such as [[And]] in the expression tree, to remove
   * cosmetic variations. The operands gathered from this expression with `f` are sorted by the
   * `hashCode` of their canonicalized (non-commutative) forms.
   */
  protected def orderCommutative(
      f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] =
    gatherCommutative(this, f).sortBy(_.hashCode())

  /**
   * Helper method to generate a canonicalized plan. The operands are collected and ordered with
   * [[orderCommutative]].
   *
   * If their number is below the MULTI_COMMUTATIVE_OP_OPT_THRESHOLD config, they are combined
   * again with `buildBinaryOp` (via `reduce`). Otherwise, i.e. when the number of operands is
   * greater than or equal to the threshold, a single [[MultiCommutativeOp]] is created as the
   * canonicalized plan.
   *
   * @param collectOperands extracts the direct operands of a commutative node of this kind
   * @param buildBinaryOp   rebuilds a binary node from two operands
   * @param evalMode        optional evaluation mode recorded on the [[MultiCommutativeOp]]
   */
  protected def buildCanonicalizedPlan(
      collectOperands: PartialFunction[Expression, Seq[Expression]],
      buildBinaryOp: (Expression, Expression) => Expression,
      evalMode: Option[EvalMode.Value] = None): Expression = {
    val operands = orderCommutative(collectOperands)
    if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) {
      operands.reduce(buildBinaryOp)
    } else {
      MultiCommutativeOp(operands, this.getClass, evalMode)(this)
    }
  }
}
/**
 * Canonicalization helper for commutative expressions. When a long tree of commutative
 * operations is canonicalized, its reordered operands are kept in this single flat node instead
 * of being rebuilt into a tree of newly created commutative operators. This is a memory
 * optimization that avoids creating many intermediate objects for large commutative trees.
 * It is applied to the following commutative expressions:
 * Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor.
 *
 * @param operands     The operands that make up the commutative expression tree.
 * @param opCls        The class of the root operator of the expression tree.
 * @param evalMode     The optional expression evaluation mode.
 * @param originalRoot The root operator of the tree before canonicalization. It is used to
 *                     derive the result data type of Add and Multiply when the operands are
 *                     decimals.
 */
case class MultiCommutativeOp(
    operands: Seq[Expression],
    opCls: Class[_],
    evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends Unevaluable {

  // Result type of applying the original Add/Multiply to two operand types. Decimal pairs use
  // the operator's decimal result rule; any other pair keeps the left-hand type.
  private def singleOpDataType(lType: DataType, rType: DataType): DataType =
    (originalRoot, lType, rType) match {
      case (add: Add, DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
        add.resultDecimalType(p1, s1, p2, s2)
      case (multiply: Multiply, DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
        multiply.resultDecimalType(p1, s1, p2, s2)
      case (_: Add | _: Multiply, _, _) =>
        lType
    }

  override def dataType: DataType = originalRoot match {
    case _: Add | _: Multiply => operands.map(_.dataType).reduce(singleOpDataType(_, _))
    case root => root.dataType
  }

  override def nullable: Boolean = operands.exists(_.nullable)

  override def children: Seq[Expression] = operands

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
    copy(operands = newChildren)(originalRoot)

  override protected final def otherCopyArgs: Seq[AnyRef] = Seq(originalRoot)
}