/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis
import java.util
import java.util.Locale
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
* Used for testing when all relations are already filled in and the analyzer needs only
* to resolve attribute references.
*/
object SimpleAnalyzer extends Analyzer(
new CatalogManager(
new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true),
FakeV2SessionCatalog,
new SessionCatalog(
new InMemoryCatalog,
EmptyFunctionRegistry,
new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) {
override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
}),
new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))
object FakeV2SessionCatalog extends TableCatalog {
private def fail() = throw new UnsupportedOperationException
override def listTables(namespace: Array[String]): Array[Identifier] = fail()
override def loadTable(ident: Identifier): Table = {
throw new NoSuchTableException(ident.toString)
}
override def createTable(
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
properties: util.Map[String, String]): Table = fail()
override def alterTable(ident: Identifier, changes: TableChange*): Table = fail()
override def dropTable(ident: Identifier): Boolean = fail()
override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = fail()
override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = fail()
override def name(): String = CatalogManager.SESSION_CATALOG_NAME
}
/**
* Provides a way to keep state during analysis. This enables us to decouple the concerns
* of the analysis environment from the catalog.
* The state that is kept here is per-query.
*
* Note this is thread local.
*
* @param catalogAndNamespace The catalog and namespace used in the view resolution. This overrides
* the current catalog and namespace when resolving relations inside
* views.
* @param nestedViewDepth The nesting depth in view resolution; this enables us to limit the
* depth of nested views.
* @param relationCache A mapping from qualified table names to resolved relations. This ensures
* that a table is resolved only once even if it is used multiple times
* in a query.
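*
* For example, view resolution pushes the view's catalog and namespace and bumps the
* nesting depth for the duration of a block (see `resolveViews` below):
* {{{
*   AnalysisContext.withAnalysisContext(desc.viewCatalogAndNamespace) {
*     // single-part relation names inside the view body are expanded with this
*     // catalog and namespace, instead of the current ones
*     executeSameContext(child)
*   }
* }}}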
*/
case class AnalysisContext(
catalogAndNamespace: Seq[String] = Nil,
nestedViewDepth: Int = 0,
relationCache: mutable.Map[Seq[String], LogicalPlan] = mutable.Map.empty)
object AnalysisContext {
private val value = new ThreadLocal[AnalysisContext]() {
override def initialValue: AnalysisContext = AnalysisContext()
}
def get: AnalysisContext = value.get()
def reset(): Unit = value.remove()
private def set(context: AnalysisContext): Unit = value.set(context)
def withAnalysisContext[A](catalogAndNamespace: Seq[String])(f: => A): A = {
val originContext = value.get()
val context = AnalysisContext(
catalogAndNamespace, originContext.nestedViewDepth + 1, originContext.relationCache)
set(context)
try f finally { set(originContext) }
}
}
/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
* [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
*/
class Analyzer(
override val catalogManager: CatalogManager,
conf: SQLConf)
extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {
private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
}
override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)
// Only for tests.
def this(catalog: SessionCatalog, conf: SQLConf) = {
this(
new CatalogManager(conf, FakeV2SessionCatalog, catalog),
conf)
}
def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
AnalysisHelper.markInAnalyzer {
val analyzed = executeAndTrack(plan, tracker)
try {
checkAnalysis(analyzed)
analyzed
} catch {
case e: AnalysisException =>
val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
ae.setStackTrace(e.getStackTrace)
throw ae
}
}
}
override def execute(plan: LogicalPlan): LogicalPlan = {
AnalysisContext.reset()
try {
executeSameContext(plan)
} finally {
AnalysisContext.reset()
}
}
private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan)
def resolver: Resolver = conf.resolver
/**
* If the plan cannot be resolved within maxIterations, the analyzer will throw an exception
* informing the user to increase the value of SQLConf.ANALYZER_MAX_ITERATIONS.
*/
protected def fixedPoint =
FixedPoint(
conf.analyzerMaxIterations,
errorOnExceed = true,
maxIterationsSetting = SQLConf.ANALYZER_MAX_ITERATIONS.key)
/**
* Override to provide additional rules for the "Resolution" batch.
*/
val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
/**
* Override to provide rules to do post-hoc resolution. Note that these rules will be executed
* in an individual batch. This batch runs right after the normal resolution batch and
* executes its rules in one pass.
*/
val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil
override def batches: Seq[Batch] = Seq(
Batch("Substitution", fixedPoint,
// This rule optimizes `UpdateFields` expression chains, so it looks more like an optimization
// rule. However, when manipulating a deeply nested schema, the `UpdateFields` expression tree
// could become very complex and make analysis impossible. Thus we need to optimize
// `UpdateFields` early, at the beginning of analysis.
OptimizeUpdateFields,
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
SubstituteUnresolvedOrdinals),
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
Batch("Hints", fixedPoint,
ResolveHints.ResolveJoinStrategyHints,
ResolveHints.ResolveCoalesceHints),
Batch("Simple Sanity Check", Once,
LookupFunctions),
Batch("Resolution", fixedPoint,
ResolveTableValuedFunctions ::
ResolveNamespace(catalogManager) ::
new ResolveCatalogs(catalogManager) ::
ResolveInsertInto ::
ResolveRelations ::
ResolveTables ::
ResolveReferences ::
ResolveCreateNamedStruct ::
ResolveDeserializer ::
ResolveNewInstance ::
ResolveUpCast ::
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveAggAliasInGroupBy ::
ResolveMissingReferences ::
ExtractGenerator ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
ResolveSubquery ::
ResolveSubqueryColumnAliases ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalAndUsingJoin ::
ResolveOutputRelation ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables ::
ResolveHigherOrderFunctions(v1SessionCatalog) ::
ResolveLambdaVariables ::
ResolveTimeZone ::
ResolveRandomSeed ::
ResolveBinaryArithmetic ::
ResolveUnion ::
TypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Post-Hoc Resolution", Once,
Seq(ResolveNoopDropTable) ++
postHocResolutionRules: _*),
Batch("Normalize Alter Table", Once, ResolveAlterTableChanges),
Batch("Remove Unresolved Hints", Once,
new ResolveHints.RemoveAllHints),
Batch("Nondeterministic", Once,
PullOutNondeterministic),
Batch("UDF", Once,
HandleNullInputsForUDF,
ResolveEncodersInUDF),
Batch("UpdateNullability", Once,
UpdateAttributeNullability),
Batch("Subquery", Once,
UpdateOuterReferences),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)
/**
* For [[Add]]:
* 1. if both sides are intervals, stays the same;
* 2. else if one side is date and the other is interval,
* turns it to [[DateAddInterval]];
* 3. else if one side is interval, turns it to [[TimeAdd]];
* 4. else if one side is date, turns it to [[DateAdd]];
* 5. else stays the same.
*
* For [[Subtract]]:
* 1. if both sides are intervals, stays the same;
* 2. else if the left side is date and the right side is interval,
* turns it to [[DateAddInterval(l, -r)]];
* 3. else if the right side is an interval, turns it to [[TimeAdd(l, -r)]];
* 4. else if one side is timestamp, turns it to [[SubtractTimestamps]];
* 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]];
* 6. else if the left side is date, turns it to [[DateSub]];
* 7. else stays the same.
*
* For [[Multiply]]:
* 1. If one side is interval, turns it to [[MultiplyInterval]];
* 2. otherwise, stays the same.
*
* For [[Divide]]:
* 1. If the left side is interval, turns it to [[DivideInterval]];
* 2. otherwise, stays the same.
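*
* For example (illustrative rewrites, assuming both children are resolved):
* {{{
*   date_col + interval_col  =>  DateAddInterval(date_col, interval_col)
*   ts_col - interval_col    =>  Cast(DatetimeSub(..., TimeAdd(ts_col, -interval_col)), ts type)
*   interval_col * 3         =>  MultiplyInterval(interval_col, 3)
*   interval_col / 2         =>  DivideInterval(interval_col, 2)
* }}}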
*/
object ResolveBinaryArithmetic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan => p.transformExpressionsUp {
case a @ Add(l, r) if a.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => a
case (DateType, CalendarIntervalType) => DateAddInterval(l, r)
case (_, CalendarIntervalType) => Cast(TimeAdd(l, r), l.dataType)
case (CalendarIntervalType, DateType) => DateAddInterval(r, l)
case (CalendarIntervalType, _) => Cast(TimeAdd(r, l), r.dataType)
case (DateType, dt) if dt != StringType => DateAdd(l, r)
case (dt, DateType) if dt != StringType => DateAdd(r, l)
case _ => a
}
case s @ Subtract(l, r) if s.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, CalendarIntervalType) => s
case (DateType, CalendarIntervalType) =>
DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r)))
case (_, CalendarIntervalType) =>
Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r))), l.dataType)
case (TimestampType, _) => SubtractTimestamps(l, r)
case (_, TimestampType) => SubtractTimestamps(l, r)
case (_, DateType) => SubtractDates(l, r)
case (DateType, dt) if dt != StringType => DateSub(l, r)
case _ => s
}
case m @ Multiply(l, r) if m.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => MultiplyInterval(l, r)
case (_, CalendarIntervalType) => MultiplyInterval(r, l)
case _ => m
}
case d @ Divide(l, r) if d.childrenResolved => (l.dataType, r.dataType) match {
case (CalendarIntervalType, _) => DivideInterval(l, r)
case _ => d
}
}
}
}
/**
* Substitute child plan with WindowSpecDefinitions.
*/
object WindowsSubstitution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
// Lookup WindowSpecDefinitions. This rule works with unresolved children.
case WithWindowDefinition(windowDefinitions, child) => child.resolveExpressions {
case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) =>
val errorMessage =
s"Window specification $windowName is not defined in the WINDOW clause."
val windowSpecDefinition =
windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage))
WindowExpression(c, windowSpecDefinition)
}
}
}
/**
* Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[NamedExpression]) = {
exprs.map(_.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) =>
child match {
case ne: NamedExpression => ne
case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil)
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)()
case e: ExtractValue => Alias(e, toPrettySQL(e))()
case e if optGenAliasFunc.isDefined =>
Alias(child, optGenAliasFunc.get.apply(e))()
case e => Alias(e, toPrettySQL(e))()
}
}
).asInstanceOf[Seq[NamedExpression]]
}
private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)
case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.copy(aggregations = assignAliases(g.aggregations))
case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child)
case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
}
}
object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
/*
* GROUP BY a, b, c WITH ROLLUP
* is equivalent to
* GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (a), ( ) ).
* Group Count: N + 1 (N is the number of group expressions)
*
* We need to get all of its prefixes for the rule described above; each prefix is
* represented as a sequence of expressions.
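*
* For example, rollupExprs(Seq(a, b, c)) returns
* Seq(Seq(a, b, c), Seq(a, b), Seq(a), Seq()), via `Seq.inits`.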
*/
def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toIndexedSeq
/*
* GROUP BY a, b, c WITH CUBE
* is equivalent to
* GROUP BY a, b, c GROUPING SETS ( (a, b, c), (a, b), (b, c), (a, c), (a), (b), (c), ( ) ).
* Group Count: 2 ^ N (N is the number of group expressions)
*
* We need to get all of the subsets for a given GROUP BY expression; the subsets are
* represented as sequences of expressions.
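*
* For example, cubeExprs(Seq(a, b)) returns
* Seq(Seq(a, b), Seq(a), Seq(b), Seq()).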
*/
def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = {
// `cubeExprs0` is recursive and returns a lazy Stream. Here we call `toIndexedSeq` to
// materialize it and avoid serialization problems later on.
cubeExprs0(exprs).toIndexedSeq
}
def cubeExprs0(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
case x :: xs =>
val initial = cubeExprs0(xs)
initial.map(x +: _) ++ initial
case Nil =>
Seq(Seq.empty)
}
private[analysis] def hasGroupingFunction(e: Expression): Boolean = {
e.collectFirst {
case g: Grouping => g
case g: GroupingID => g
}.isDefined
}
private def replaceGroupingFunc(
expr: Expression,
groupByExprs: Seq[Expression],
gid: Expression): Expression = {
expr transform {
case e: GroupingID =>
if (e.groupByExprs.isEmpty ||
e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) {
Alias(gid, toPrettySQL(e))()
} else {
throw new AnalysisException(
s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
s"grouping columns (${groupByExprs.mkString(",")})")
}
case e @ Grouping(col: Expression) =>
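// The grouping id encodes one bit per group-by expression, most significant bit first;
// a bit is 1 when the expression is not part of the current grouping set. grouping(col)
// therefore extracts bit (groupByExprs.length - 1 - idx) from `gid` below.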
val idx = groupByExprs.indexWhere(_.semanticEquals(col))
if (idx >= 0) {
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
Literal(1L)), ByteType), toPrettySQL(e))()
} else {
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
s"in grouping columns ${groupByExprs.mkString(",")}")
}
}
}
/*
* Create new aliases for all group-by expressions for the `Expand` operator.
*/
private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = {
groupByExprs.map {
case e: NamedExpression => Alias(e, e.name)(qualifier = e.qualifier)
case other => Alias(other, other.toString)()
}
}
/*
* Construct [[Expand]] operator with grouping sets.
*/
private def constructExpand(
selectedGroupByExprs: Seq[Seq[Expression]],
child: LogicalPlan,
groupByAliases: Seq[Alias],
gid: Attribute): LogicalPlan = {
// Change the nullability of group by aliases if necessary. For example, if we have
// GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we
// should change the nullability of b to be TRUE.
// TODO: For Cube/Rollup just set nullability to be `true`.
val expandedAttributes = groupByAliases.map { alias =>
if (selectedGroupByExprs.exists(!_.contains(alias.child))) {
alias.toAttribute.withNullability(true)
} else {
alias.toAttribute
}
}
val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs =>
groupingSetExprs.map { expr =>
val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse(
failAnalysis(s"$expr doesn't show up in the GROUP BY list $groupByAliases"))
// Map alias to expanded attribute.
expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse(
alias.toAttribute)
}
}
Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child)
}
/*
* Construct new aggregate expressions by replacing grouping functions.
*/
private def constructAggregateExprs(
groupByExprs: Seq[Expression],
aggregations: Seq[NamedExpression],
groupByAliases: Seq[Alias],
groupingAttrs: Seq[Expression],
gid: Attribute): Seq[NamedExpression] = aggregations.map {
// Collect all the found AggregateExpressions, so we can check whether an expression is
// part of any AggregateExpression.
val aggsBuffer = ArrayBuffer[Expression]()
// Returns whether the expression belongs to any expression in `aggsBuffer`.
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}
replaceGroupingFunc(_, groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
case e: AggregateExpression =>
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
case e =>
// Replace expression by expand output attribute.
val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
if (index == -1) {
e
} else {
groupingAttrs(index)
}
}.asInstanceOf[NamedExpression]
}
private def getFinalGroupByExpressions(
selectedGroupByExprs: Seq[Seq[Expression]],
groupByExprs: Seq[Expression]): Seq[Expression] = {
// In the case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
// can be empty. In that case, we derive groupByExprs from the user-supplied values for
// the grouping sets.
if (groupByExprs == Nil) {
selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
// Only unique expressions are included in the group-by expressions, determined by
// semantic equality. For example, GROUPING SETS ((a * b), (b * a)) results in the
// grouping expression (a * b).
if (result.find(_.semanticEquals(currentExpr)).isDefined) {
result
} else {
result :+ currentExpr
}
}
} else {
groupByExprs
}
}
/*
* Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
*/
private def constructAggregate(
selectedGroupByExprs: Seq[Seq[Expression]],
groupByExprs: Seq[Expression],
aggregationExprs: Seq[NamedExpression],
child: LogicalPlan): LogicalPlan = {
val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)
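// The grouping id is a single integral column (`gid` below) with one bit per group-by
// expression, so the number of group-by expressions is bounded by the bit width of
// GroupingID's data type.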
if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) {
throw new AnalysisException(
s"Grouping sets size cannot be greater than ${GroupingID.dataType.defaultSize * 8}")
}
// Expand works by setting grouping expressions to null as determined by the
// `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
// instead of the original value, we need to create new aliases for all group-by expressions
// that will only be used for the intended purpose.
val groupByAliases = constructGroupByAlias(finalGroupByExpressions)
val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)()
val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid)
val groupingAttrs = expand.output.drop(child.output.length)
val aggregations = constructAggregateExprs(
finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid)
Aggregate(groupingAttrs, aggregations, expand)
}
private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
plan.collectFirst {
case a: Aggregate =>
// This Aggregate should have the grouping id as the last grouping key.
val gid = a.groupingExpressions.last
if (!gid.isInstanceOf[AttributeReference]
|| gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
}
a.groupingExpressions.take(a.groupingExpressions.length - 1)
}.getOrElse {
failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
}
}
private def tryResolveHavingCondition(h: UnresolvedHaving): LogicalPlan = {
val aggForResolving = h.child match {
// For CUBE/ROLLUP expressions, to avoid repeated resolution, we strip them from
// groupingExpressions when resolving the condition.
case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) =>
a.copy(groupingExpressions = groupByExprs)
case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) =>
a.copy(groupingExpressions = groupByExprs)
case g: GroupingSets =>
Aggregate(
getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
g.aggregations, g.child)
}
// Try resolving the condition of the filter as though it is in the aggregate clause
val resolvedInfo =
ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving)
// Push the aggregate expressions into the aggregate (if any).
if (resolvedInfo.nonEmpty) {
val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
val newChild = h.child match {
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case x: GroupingSets =>
constructAggregate(
x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
}
// Since the exprIds of extraAggExprs will be changed in the constructed aggregate, while
// aggregateExpressions keeps the input order, we build an exprMap to resolve the
// condition again.
val exprMap = extraAggExprs.zip(
newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight(
extraAggExprs.length)).toMap
val newCond = resolvedHavingCond.transform {
case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne)
}
Project(newChild.output.dropRight(extraAggExprs.length),
Filter(newCond, newChild))
} else {
h
}
}
// This requires transformDown to resolve the having condition when generating the aggregate
// node for CUBE/ROLLUP/GROUPING SETS. It also replaces grouping()/grouping_id() in resolved
// Filter/Sort.
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
case h @ UnresolvedHaving(
_, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(h)
case h @ UnresolvedHaving(
_, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(h)
case h @ UnresolvedHaving(_, g: GroupingSets)
if g.childrenResolved && g.expressions.forall(_.resolved) =>
tryResolveHavingCondition(h)
case a if !a.childrenResolved => a // make sure all of the children are resolved.
// Ensure group by expressions and aggregate expressions have been resolved.
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child)
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child)
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child)
// We should make sure all expressions in condition have been resolved.
case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
val groupingExprs = findGroupingExprs(child)
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
f.copy(condition = newCond)
// We should make sure all [[SortOrder]]s have been resolved.
case s @ Sort(order, _, child)
if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
s.copy(order = newOrder)
}
}
object ResolvePivot extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
|| (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
|| !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
throw new AnalysisException(
s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.")
}
// Check all aggregate expressions.
aggregates.foreach(checkValidAggregateExpression)
// Check all pivot values are literal and match pivot column data type.
val evalPivotValues = pivotValues.map { value =>
val foldable = value match {
case Alias(v, _) => v.foldable
case _ => value.foldable
}
if (!foldable) {
throw new AnalysisException(
s"Literal expressions required for pivot values, found '$value'")
}
if (!Cast.canCast(value.dataType, pivotColumn.dataType)) {
throw new AnalysisException(s"Invalid pivot value '$value': " +
s"value data type ${value.dataType.simpleString} does not match " +
s"pivot column data type ${pivotColumn.dataType.catalogString}")
}
Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
}
// Group-by expressions coming from SQL are implicit and need to be deduced.
val groupByExprs = groupByExprsOpt.getOrElse {
val pivotColAndAggRefs = pivotColumn.references ++ AttributeSet(aggregates)
child.output.filterNot(pivotColAndAggRefs.contains)
}
val singleAgg = aggregates.size == 1
def outputName(value: Expression, aggregate: Expression): String = {
val stringValue = value match {
case n: NamedExpression => n.name
case _ =>
val utf8Value =
Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow)
Option(utf8Value).map(_.toString).getOrElse("null")
}
if (singleAgg) {
stringValue
} else {
val suffix = aggregate match {
case n: NamedExpression => n.name
case _ => toPrettySQL(aggregate)
}
stringValue + "_" + suffix
}
}
if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
// Since evaluating |pivotValues| if statements for each input row can get slow, this is an
// alternate plan that instead uses two steps of aggregation.
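// Roughly (an illustrative sketch of the plan built below):
//   Aggregate(groupBy :+ pivotCol, groupBy ++ aggs)         -- first aggregation
//   Aggregate(groupBy, PivotFirst(pivotCol, agg, values))   -- collect per-value results
//   Project(groupBy ++ one column per (pivot value, agg))   -- extract array elements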
val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
val namedPivotCol = pivotColumn match {
case n: NamedExpression => n
case _ => Alias(pivotColumn, "__pivot_col")()
}
val bigGroup = groupByExprs :+ namedPivotCol
val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
val pivotAggs = namedAggExps.map { a =>
Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues)
.toAggregateExpression()
, "__pivot_" + a.sql)()
}
val groupByExprsAttr = groupByExprs.map(_.toAttribute)
val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)
val pivotAggAttribute = pivotAggs.map(_.toAttribute)
val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
}
}
Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
} else {
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
def ifExpr(e: Expression) = {
If(
EqualNullSafe(
pivotColumn,
Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))),
e, Literal(null))
}
aggregates.map { aggregate =>
val filteredAggregate = aggregate.transformDown {
// The assumption is that the aggregate function ignores nulls. This is true for all current
// AggregateFunctions, with the exception of First and Last in their default mode
// (which we handle), and possibly some Hive UDAFs.
case First(expr, _) =>
First(ifExpr(expr), true)
case Last(expr, _) =>
Last(ifExpr(expr), true)
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
}.transform {
// We are duplicating aggregates that are now computing a different value for each
// pivot value.
// TODO: Don't construct the physical container until after analysis.
case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
Alias(filteredAggregate, outputName(value, aggregate))()
}
}
Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
}
}
// Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF.
// TODO: Support Pandas UDF.
private def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis.
case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) =>
failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.")
case e: Attribute =>
failAnalysis(
s"Aggregate expression required for pivot, but '${e.sql}' " +
s"did not appear in any aggregate function.")
case e => e.children.foreach(checkValidAggregateExpression)
}
}
case class ResolveNamespace(catalogManager: CatalogManager)
extends Rule[LogicalPlan] with LookupCatalog {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case s @ ShowTables(UnresolvedNamespace(Seq()), _) =>
s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace))
case s @ ShowViews(UnresolvedNamespace(Seq()), _) =>
s.copy(namespace = ResolvedNamespace(currentCatalog, catalogManager.currentNamespace))
case UnresolvedNamespace(Seq()) =>
ResolvedNamespace(currentCatalog, Seq.empty[String])
case UnresolvedNamespace(CatalogAndNamespace(catalog, ns)) =>
ResolvedNamespace(catalog, ns)
}
}
private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty
/**
* Resolve relations to temp views. This is not an actual rule, and is called by
* [[ResolveTables]] and [[ResolveRelations]].
*/
object ResolveTempViews extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedRelation(ident, _, isStreaming) =>
lookupTempView(ident, isStreaming).getOrElse(u)
case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _) =>
lookupTempView(ident)
.map(view => i.copy(table = view))
.getOrElse(i)
case u @ UnresolvedTable(ident) =>
lookupTempView(ident).foreach { _ =>
u.failAnalysis(s"${ident.quoted} is a temp view not table.")
}
u
case u @ UnresolvedTableOrView(ident) =>
lookupTempView(ident)
.map(_ => ResolvedView(ident.asIdentifier, isTemp = true))
.getOrElse(u)
}
def lookupTempView(
identifier: Seq[String], isStreaming: Boolean = false): Option[LogicalPlan] = {
// A permanent view can't refer to temp views, so there is no need to look them up at all.
if (isResolvingView) return None
val tmpView = identifier match {
case Seq(part1) => v1SessionCatalog.lookupTempView(part1)
case Seq(part1, part2) => v1SessionCatalog.lookupGlobalTempView(part1, part2)
case _ => None
}
if (isStreaming && tmpView.nonEmpty && !tmpView.get.isStreaming) {
throw new AnalysisException(s"${identifier.quoted} is not a temp view of streaming " +
s"logical plan, please use batch API such as `DataFrameReader.table` to read it.")
}
tmpView
}
}
// If we are resolving relations inside views, we need to expand single-part relation names with
// the catalog and namespace that were current when the view was created.
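// For example, if a view was created under catalog `cat` and namespace `ns`
// (hypothetical names), a single-part name `t` inside it expands to `cat.ns.t`, and a
// multi-part name whose head is not a registered catalog gets `cat` prepended.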
private def expandRelationName(nameParts: Seq[String]): Seq[String] = {
if (!isResolvingView) return nameParts
if (nameParts.length == 1) {
AnalysisContext.get.catalogAndNamespace :+ nameParts.head
} else if (catalogManager.isCatalogRegistered(nameParts.head)) {
nameParts
} else {
AnalysisContext.get.catalogAndNamespace.head +: nameParts
}
}
/**
* Resolve table relations with concrete relations from v2 catalog.
*
* [[ResolveRelations]] still resolves v1 tables.
*/
object ResolveTables extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case u: UnresolvedRelation =>
lookupV2Relation(u.multipartIdentifier, u.options, u.isStreaming)
.map { relation =>
val (catalog, ident) = relation match {
case ds: DataSourceV2Relation => (ds.catalog, ds.identifier.get)
case s: StreamingRelationV2 => (s.catalog, s.identifier.get)
}
SubqueryAlias(catalog.get.name +: ident.namespace :+ ident.name, relation)
}.getOrElse(u)
case u @ UnresolvedTable(NonSessionCatalogAndIdentifier(catalog, ident)) =>
CatalogV2Util.loadTable(catalog, ident)
.map(ResolvedTable(catalog.asTableCatalog, ident, _))
.getOrElse(u)
case u @ UnresolvedTableOrView(NonSessionCatalogAndIdentifier(catalog, ident)) =>
CatalogV2Util.loadTable(catalog, ident)
.map(ResolvedTable(catalog.asTableCatalog, ident, _))
.getOrElse(u)
case i @ InsertIntoStatement(u @ UnresolvedRelation(_, _, false), _, _, _, _)
if i.query.resolved =>
lookupV2Relation(u.multipartIdentifier, u.options, false)
.map(v2Relation => i.copy(table = v2Relation))
.getOrElse(i)
case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) =>
CatalogV2Util.loadRelation(u.catalog, u.tableName)
.map(rel => alter.copy(table = rel))
.getOrElse(alter)
case u: UnresolvedV2Relation =>
CatalogV2Util.loadRelation(u.catalog, u.tableName).getOrElse(u)
}
/**
* Performs the lookup of DataSourceV2 Tables from v2 catalog.
*/
private def lookupV2Relation(
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] =
expandRelationName(identifier) match {
case NonSessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) =>
if (isStreaming) {
Some(StreamingRelationV2(None, table.name, table, options,
table.schema.toAttributes, Some(catalog), Some(ident), None))
} else {
Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident), options))
}
case None => None
}
case _ => None
}
}
/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
// The current catalog and namespace may be different from when the view was created, so we
// must resolve the view's logical plan here, with the catalog and namespace stored in the view
// metadata. This is done by keeping the catalog and namespace in `AnalysisContext`; the
// analyzer will look at `AnalysisContext.catalogAndNamespace` when resolving relations with
// single-part names. If it is non-empty, the analyzer will expand single-part names
// with it, instead of the current catalog and namespace.
private def resolveViews(plan: LogicalPlan): LogicalPlan = plan match {
// The view's child should be a logical plan parsed from `desc.viewText`; the variable
// `viewText` should be defined, or else we throw an error when generating the View
// operator.
case view @ View(desc, _, child) if !child.resolved =>
// Resolve all the UnresolvedRelations and Views in the child.
val newChild = AnalysisContext.withAnalysisContext(desc.viewCatalogAndNamespace) {
if (AnalysisContext.get.nestedViewDepth > conf.maxNestedViewDepth) {
view.failAnalysis(s"The depth of view ${desc.identifier} exceeds the maximum " +
s"view resolution depth (${conf.maxNestedViewDepth}). Analysis is aborted to " +
s"avoid errors. Increase the value of ${SQLConf.MAX_NESTED_VIEW_DEPTH.key} to work " +
"around this.")
}
executeSameContext(child)
}
view.copy(child = newChild)
case p @ SubqueryAlias(_, view: View) =>
p.copy(child = resolveViews(view))
case _ => plan
}
def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case i @ InsertIntoStatement(table, _, _, _, _) if i.query.resolved =>
val relation = table match {
case u @ UnresolvedRelation(_, _, false) =>
lookupRelation(u.multipartIdentifier, u.options, false).getOrElse(u)
case other => other
}
EliminateSubqueryAliases(relation) match {
case v: View =>
table.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
case other => i.copy(table = other)
}
case u: UnresolvedRelation =>
lookupRelation(u.multipartIdentifier, u.options, u.isStreaming)
.map(resolveViews).getOrElse(u)
case u @ UnresolvedTable(identifier) =>
lookupTableOrView(identifier).map {
case v: ResolvedView =>
val viewStr = if (v.isTemp) "temp view" else "view"
u.failAnalysis(s"${v.identifier.quoted} is a $viewStr not table.")
case table => table
}.getOrElse(u)
case u @ UnresolvedTableOrView(identifier) =>
lookupTableOrView(identifier).getOrElse(u)
}
private def lookupTableOrView(identifier: Seq[String]): Option[LogicalPlan] = {
expandRelationName(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table if v1Table.v1Table.tableType == CatalogTableType.VIEW =>
ResolvedView(ident, isTemp = false)
case table =>
ResolvedTable(catalog.asTableCatalog, ident, table)
}
case _ => None
}
}
// Look up a relation from the session catalog with the following logic:
// 1) If the resolved catalog is not session catalog, return None.
// 2) If a relation is not found in the catalog, return None.
// 3) If a v1 table is found, create a v1 relation. Otherwise, create a v2 relation.
private def lookupRelation(
identifier: Seq[String],
options: CaseInsensitiveStringMap,
isStreaming: Boolean): Option[LogicalPlan] = {
expandRelationName(identifier) match {
case SessionCatalogAndIdentifier(catalog, ident) =>
lazy val loaded = CatalogV2Util.loadTable(catalog, ident).map {
case v1Table: V1Table =>
if (isStreaming) {
if (v1Table.v1Table.tableType == CatalogTableType.VIEW) {
throw new AnalysisException(s"${identifier.quoted} is a permanent view, " +
"which is not supported by streaming reading API such as " +
"`DataStreamReader.table` yet.")
}
SubqueryAlias(
catalog.name +: ident.asMultipartIdentifier,
UnresolvedCatalogRelation(v1Table.v1Table, options, isStreaming = true))
} else {
v1SessionCatalog.getRelation(v1Table.v1Table, options)
}
case table =>
if (isStreaming) {
val v1Fallback = table match {
case withFallback: V2TableWithV1Fallback =>
Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true))
case _ => None
}
SubqueryAlias(
catalog.name +: ident.asMultipartIdentifier,
StreamingRelationV2(None, table.name, table, options, table.schema.toAttributes,
Some(catalog), Some(ident), v1Fallback))
} else {
SubqueryAlias(
catalog.name +: ident.asMultipartIdentifier,
DataSourceV2Relation.create(table, Some(catalog), Some(ident), options))
}
}
val key = catalog.name +: ident.namespace :+ ident.name
AnalysisContext.get.relationCache.get(key).map(_.transform {
case multi: MultiInstanceRelation => multi.newInstance()
}).orElse {
loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
loaded
}
case _ => None
}
}
}
object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved =>
// `ifPartitionNotExists` means append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
s"Cannot write, IF NOT EXISTS is not supported for table: ${r.table.name}")
}
val partCols = partitionColumnNames(r.table)
validatePartitionSpec(partCols, i.partitionSpec)
val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get).toMap
val query = addStaticPartitionColumns(r, i.query, staticPartitions)
if (!i.overwrite) {
AppendData.byPosition(r, query)
} else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
OverwritePartitionsDynamic.byPosition(r, query)
} else {
OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
}
}
private def partitionColumnNames(table: Table): Seq[String] = {
// get partition column names. in v2, partition columns are columns that are stored using an
// identity partition transform because the partition values and the column values are
// identical. otherwise, partition values are produced by transforming one or more source
// columns and cannot be set directly in a query's PARTITION clause.
table.partitioning.flatMap {
case IdentityTransform(FieldReference(Seq(name))) => Some(name)
case _ => None
}
}
private def validatePartitionSpec(
partitionColumnNames: Seq[String],
partitionSpec: Map[String, Option[String]]): Unit = {
// check that each partition name is a partition column. otherwise, it is not valid
partitionSpec.keySet.foreach { partitionName =>
partitionColumnNames.find(name => conf.resolver(name, partitionName)) match {
case Some(_) =>
case None =>
throw new AnalysisException(
s"PARTITION clause cannot contain a non-partition column name: $partitionName")
}
}
}
private def addStaticPartitionColumns(
relation: DataSourceV2Relation,
query: LogicalPlan,
staticPartitions: Map[String, String]): LogicalPlan = {
if (staticPartitions.isEmpty) {
query
} else {
// add any static value as a literal column
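// For example (illustrative), INSERT INTO t PARTITION (p = '1') SELECT a, b becomes a
// Project over the query that places Literal('1'), cast to p's type, at p's position
// in the table's output.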
val withStaticPartitionValues = {
// for each static name, find the column name it will replace and check for unknowns.
val outputNameToStaticName = staticPartitions.keySet.map(staticName =>
relation.output.find(col => conf.resolver(col.name, staticName)) match {
case Some(attr) =>
attr.name -> staticName
case _ =>
throw new AnalysisException(
s"Cannot add static value for unknown column: $staticName")
}).toMap
val queryColumns = query.output.iterator
// for each output column, add the static value as a literal, or use the next input
// column. this does not fail if input columns are exhausted and adds remaining columns
// at the end. both cases will be caught by ResolveOutputRelation and will fail the
// query with a helpful error message.
relation.output.flatMap { col =>
outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match {
case Some(staticValue) =>
Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)())
case _ if queryColumns.hasNext =>
Some(queryColumns.next)
case _ =>
None
}
} ++ queryColumns
}
Project(withStaticPartitionValues, query)
}
}
private def staticDeleteExpression(
relation: DataSourceV2Relation,
staticPartitions: Map[String, String]): Expression = {
if (staticPartitions.isEmpty) {
Literal(true)
} else {
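// For example (illustrative), INSERT OVERWRITE ... PARTITION (p1 = 'a', p2 = '1')
// yields the delete condition `p1 = 'a' AND p2 = '1'` via the reduce(And) below.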
staticPartitions.map { case (name, value) =>
relation.output.find(col => conf.resolver(col.name, name)) match {
case Some(attr) =>
// the delete expression must reference the table's column names, but these attributes
// are not available when CheckAnalysis runs because the relation is not a child of
// the logical operation. instead, expressions are resolved after
// ResolveOutputRelation runs, using the query's column names that will match the
// table names at that point. because resolution happens after a future rule, create
// an UnresolvedAttribute.
EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType))
case None =>
throw new AnalysisException(s"Unknown static partition column: $name")
}
}.reduce(And)
}
}
}
/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
/**
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
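*
* A self-join such as `df.join(df, ...)` is a typical trigger: both sides expose the
* same expression IDs, so conflicting sub-plans on the right are re-instantiated with
* fresh IDs.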
*/
private def dedupRight(left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
s"between $left and $right")
/**
* For LogicalPlans like MultiInstanceRelation, Project, Aggregate, etc., whose output doesn't
* inherit directly from their children, we can just stop collecting on them, because we can
* always replace all the lower conflicting attributes with the new attributes from the new
* plan. Theoretically, we should collect recursively for Generate and Window, but we leave
* that to the next batch to reduce possible overhead because this should be a corner case.
*/
def collectConflictPlans(plan: LogicalPlan): Seq[(LogicalPlan, LogicalPlan)] = plan match {
// Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.newInstance()
Seq((oldVersion, newVersion))
case oldVersion: SerializeFromObject
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(
serializer = oldVersion.serializer.map(_.newInstance()))))
// Handle projects that create conflicting aliases.
case oldVersion @ Project(projectList, _)
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
// We don't need to search the child plan recursively if the projectList of a Project
// is composed only of Aliases and doesn't contain any conflicting attributes, because
// even if the child plan has some conflicting attributes, they will be aliased to
// non-conflicting attributes by the Project at the end.
case _ @ Project(projectList, _)
if findAliases(projectList).size == projectList.size =>
Nil
case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(
aggregateExpressions = newAliases(aggregateExpressions))))
// We don't search the child plan recursively for the same reason as the above Project.
case _ @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).size == aggregateExpressions.size =>
Nil
case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
case oldVersion: Expand
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val producedAttributes = oldVersion.producedAttributes
val newOutput = oldVersion.output.map { attr =>
if (producedAttributes.contains(attr)) {
attr.newInstance()
} else {
attr
}
}
Seq((oldVersion, oldVersion.copy(output = newOutput)))
case oldVersion @ Window(windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
case _ => plan.children.flatMap(collectConflictPlans)
}
val conflictPlans = collectConflictPlans(right)
/*
* Note that it's possible `conflictPlans` can be empty, which implies that there
* is a logical plan node that produces new references that this rule cannot handle.
* When that is the case, there must be another rule that resolves these conflicts.
* Otherwise, the analysis will fail.
*/
if (conflictPlans.isEmpty) {
right
} else {
val planMapping = conflictPlans.toMap
right.transformUpWithNewOutput {
case oldPlan =>
val newPlanOpt = planMapping.get(oldPlan)
newPlanOpt.map { newPlan =>
newPlan -> oldPlan.output.zip(newPlan.output)
}.getOrElse(oldPlan -> Nil)
}
}
}
/**
* Resolves attribute and extract-value expressions by traversing the input
* expression in a top-down manner. The traversal is done top-down because we need to
* skip over unbound lambda function expressions; the lambda expressions are
* resolved by a different rule, [[ResolveLambdaVariables]].
*
* Example:
* SELECT transform(array(1, 2, 3), (x, i) -> x + i)
*
* In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]].
*
* Note: In this routine, unresolved attributes are resolved from the input plan's
* children's attributes.
*
* @param e The expression that needs to be resolved.
* @param q The LogicalPlan whose children are used to resolve expression's attribute.
* @param trimAlias When true, trim unnecessary alias of `GetStructField`. Note that,
* we cannot trim the alias of top-level `GetStructField`, as we should
* resolve `UnresolvedAttribute` to a named expression. The caller side
* can trim the alias of top-level `GetStructField` if it's safe to do so.
* @return resolved Expression.
*/
private def resolveExpressionTopDown(
e: Expression,
q: LogicalPlan,
trimAlias: Boolean = false): Expression = {
def innerResolve(e: Expression, isTopLevel: Boolean): Expression = {
if (e.resolved) return e
e match {
case f: LambdaFunction if !f.bound => f
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val resolved =
withPosition(u) {
q.resolveChildren(nameParts, resolver)
.orElse(resolveLiteralFunction(nameParts, u, q))
.getOrElse(u)
}
val result = resolved match {
// As the doc for `resolveExpressionTopDown`'s param `trimAlias` says, when
// trimAlias = true we trim unnecessary aliases of `GetStructField`, but not the
// alias of a top-level `GetStructField`. Since we will call CleanupAliases later
// in the Analyzer, trimming non-top-level unnecessary aliases of `GetStructField`
// here is safe.
case Alias(s: GetStructField, _) if trimAlias && !isTopLevel => s
case others => others
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
ExtractValue(child, fieldExpr, resolver)
case _ => e.mapChildren(innerResolve(_, isTopLevel = false))
}
}
innerResolve(e, isTopLevel = true)
}
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan if !p.childrenResolved => p
// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
failAnalysis(
"Star (*) is not allowed in select list when GROUP BY ordinal position is used")
} else {
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
}
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)
case g: Generate if containsStar(g.generator.children) =>
failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) =>
val leftRes = leftAttributes
.map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute])
val rightRes = rightAttributes
.map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute])
f.copy(leftAttributes = leftRes, rightAttributes = rightRes)
// Intersect/Except will be rewritten to joins at the beginning of the optimizer. Here we need
// to deduplicate the right-side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
case e @ Except(left, right, _) if !e.duplicateResolved =>
e.copy(right = dedupRight(left, right))
// Only deduplicate Union here after by-name resolution for Union has finished.
case u: Union if !u.byName && !u.duplicateResolved =>
// Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing
// feature in streaming.
val newChildren = u.children.foldRight(Seq.empty[LogicalPlan]) { (head, tail) =>
head +: tail.map {
case child if head.outputSet.intersect(child.outputSet).isEmpty =>
child
case child =>
val projectList = child.output.map { attr =>
Alias(attr, attr.name)()
}
Project(projectList, child)
}
}
u.copy(children = newChildren)
// When resolving `SortOrder`s in Sort based on the child, don't report errors, as
// we still have a chance to resolve them based on its descendants.
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
val newOrdering =
ordering.map(order => resolveExpressionBottomUp(order, child).asInstanceOf[SortOrder])
Sort(newOrdering, global, child)
// A special case for Generate, because the output of Generate should not be resolved by
// ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g
case g @ Generate(generator, join, outer, qualifier, output, child) =>
val newG = resolveExpressionBottomUp(generator, child, throws = true)
if (newG.fastEquals(generator)) {
g
} else {
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}
// Skips plan which contains deserializer expressions, as they should be resolved by another
// rule: ResolveDeserializer.
case plan if containsDeserializer(plan.expressions) => plan
// SPARK-31607: Resolved struct fields in groupByExpressions and aggregateExpressions
// with CUBE/ROLLUP will be wrapped in aliases like Alias(GetStructField, name) with
// different ExprIds. This prevents aggregateExpressions from being replaced by the
// expanded groupByExpressions in `ResolveGroupingAnalytics.constructAggregateExprs()`,
// so we trim the unnecessary aliases of GetStructField here.
case a: Aggregate =>
val planForResolve = a.child match {
// SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of
// `AppendColumns`, because `AppendColumns`'s serializer might produce conflicting attribute
// names, leading to an ambiguous-reference exception.
case appendColumns: AppendColumns => appendColumns
case _ => a
}
val resolvedGroupingExprs = a.groupingExpressions
.map(resolveExpressionTopDown(_, planForResolve, trimAlias = true))
.map(trimTopLevelGetStructFieldAlias)
val resolvedAggExprs = a.aggregateExpressions
.map(resolveExpressionTopDown(_, planForResolve, trimAlias = true))
.map(_.asInstanceOf[NamedExpression])
a.copy(resolvedGroupingExprs, resolvedAggExprs, a.child)
// SPARK-31607: Struct fields resolved in selectedGroupByExprs/groupByExprs and aggregations
// will be wrapped with aliases like Alias(GetStructField, name) that carry different ExprIds.
// This prevents aggregateExpressions from being replaced by the expanded groupByExpressions
// in `ResolveGroupingAnalytics.constructAggregateExprs()`, so we trim the unnecessary
// aliases of GetStructField here.
case g: GroupingSets =>
val resolvedSelectedExprs = g.selectedGroupByExprs
.map(_.map(resolveExpressionTopDown(_, g, trimAlias = true))
.map(trimTopLevelGetStructFieldAlias))
val resolvedGroupingExprs = g.groupByExprs
.map(resolveExpressionTopDown(_, g, trimAlias = true))
.map(trimTopLevelGetStructFieldAlias)
val resolvedAggExprs = g.aggregations
.map(resolveExpressionTopDown(_, g, trimAlias = true))
.map(_.asInstanceOf[NamedExpression])
g.copy(resolvedSelectedExprs, resolvedGroupingExprs, g.child, resolvedAggExprs)
case o: OverwriteByExpression if !o.outputResolved =>
// Do not resolve expression attributes until the query attributes are resolved against the
// table by ResolveOutputRelation. That rule will alias the attributes to the table's names.
o
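// Resolve MERGE INTO. As an illustrative example, in
// `MERGE INTO target t USING source s ON t.id = s.id
//  WHEN MATCHED THEN UPDATE SET t.v = s.v
//  WHEN NOT MATCHED THEN INSERT (id, v) VALUES (s.id, s.v)`,
// the ON condition and the UPDATE assignments may reference both tables, while the INSERT
// action may only reference the source table, matching the scoping rules below.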
case m @ MergeIntoTable(targetTable, sourceTable, _, _, _)
if !m.resolved && targetTable.resolved && sourceTable.resolved =>
EliminateSubqueryAliases(targetTable) match {
case r: NamedRelation if r.skipSchemaResolution =>
// Do not resolve the expression if the target table accepts any schema.
// This allows data sources to customize their own resolution logic using
// custom resolution rules.
m
case _ =>
val newMatchedActions = m.matchedActions.map {
case DeleteAction(deleteCondition) =>
val resolvedDeleteCondition = deleteCondition.map(resolveExpressionTopDown(_, m))
DeleteAction(resolvedDeleteCondition)
case UpdateAction(updateCondition, assignments) =>
val resolvedUpdateCondition = updateCondition.map(resolveExpressionTopDown(_, m))
// The update value can access columns from both target and source tables.
UpdateAction(
resolvedUpdateCondition,
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = false))
case o => o
}
val newNotMatchedActions = m.notMatchedActions.map {
case InsertAction(insertCondition, assignments) =>
// The insert action is used when not matched, so its condition and value can only
// access columns from the source table.
val resolvedInsertCondition =
insertCondition.map(resolveExpressionTopDown(_, Project(Nil, m.sourceTable)))
InsertAction(
resolvedInsertCondition,
resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true))
case o => o
}
val resolvedMergeCondition = resolveExpressionTopDown(m.mergeCondition, m)
m.copy(mergeCondition = resolvedMergeCondition,
matchedActions = newMatchedActions,
notMatchedActions = newNotMatchedActions)
}
// Skip the HAVING clause here; it will be handled in ResolveAggregateFunctions.
case h: UnresolvedHaving => h
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
q.mapExpressions(resolveExpressionTopDown(_, q))
}
def resolveAssignments(
assignments: Seq[Assignment],
mergeInto: MergeIntoTable,
resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = {
if (assignments.isEmpty) {
val expandedColumns = mergeInto.targetTable.output
val expandedValues = mergeInto.sourceTable.output
expandedColumns.zip(expandedValues).map(kv => Assignment(kv._1, kv._2))
} else {
assignments.map { assign =>
val resolvedKey = assign.key match {
case c if !c.resolved =>
resolveExpressionTopDown(c, Project(Nil, mergeInto.targetTable))
case o => o
}
val resolvedValue = assign.value match {
// The update values may contain target and/or source references.
case c if !c.resolved =>
if (resolveValuesWithSourceOnly) {
resolveExpressionTopDown(c, Project(Nil, mergeInto.sourceTable))
} else {
resolveExpressionTopDown(c, mergeInto)
}
case o => o
}
Assignment(resolvedKey, resolvedValue)
}
}
}
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)()
case other => other
}
}
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
}
// This method is used to trim the top-level GetStructField Alias in
// groupByExpressions/selectedGroupByExprs. Since these expressions are not originally
// NamedExpressions, it is safe to trim the top-level GetStructField Alias.
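// For example (illustrative), `Alias(GetStructField(a, "b"), "b")` is trimmed to
// `GetStructField(a, "b")`; any other expression is returned unchanged.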
def trimTopLevelGetStructFieldAlias(e: Expression): Expression = {
e match {
case Alias(s: GetStructField, _) => s
case other => other
}
}
/**
* Build a project list for Project/Aggregate and expand the star if possible
*/
private def buildExpandedProjectList(
exprs: Seq[NamedExpression],
child: LogicalPlan): Seq[NamedExpression] = {
exprs.flatMap {
// Using the DataFrame/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
case s: Star => s.expand(child, resolver)
// Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
case o => o :: Nil
}.map(_.asInstanceOf[NamedExpression])
}
/**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
/**
* Expands the matching attribute.*'s in `child`'s output.
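*
* For example (illustrative), given a relation `t(a, b)`, `hash(t.*)` is expanded to
* `hash(a, b)` and `array(t.*)` to `array(a, b)`, while a `*` inside an unsupported
* expression is rejected below.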
*/
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
expr.transformUp {
case f1: UnresolvedFunction if containsStar(f1.arguments) =>
f1.copy(arguments = f1.arguments.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateNamedStruct if containsStar(c.valExprs) =>
val newChildren = c.children.grouped(2).flatMap {
case Seq(_, s: Star) => CreateStruct(s.expand(child, resolver)).children
case kv => kv
}
c.copy(children = newChildren.toList)
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case p: Murmur3Hash if containsStar(p.children) =>
p.copy(children = p.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case p: XxHash64 if containsStar(p.children) =>
p.copy(children = p.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
// count(*) has been replaced by count(1)
case o if containsStar(o.children) =>
failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
}
}
}
private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
}
/**
* Literal functions do not require the user to specify parentheses when calling them.
* When an attribute is not resolvable, we try to resolve it as a literal function.
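*
* For example (illustrative), `SELECT current_date` is resolved to a `CurrentDate()`
* expression even though no parentheses were given.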
*/
private def resolveLiteralFunction(
nameParts: Seq[String],
attribute: UnresolvedAttribute,
plan: LogicalPlan): Option[Expression] = {
if (nameParts.length != 1) return None
val isNamedExpression = plan match {
case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute)
case Project(projectList, _) => projectList.contains(attribute)
case Window(windowExpressions, _, _, _) => windowExpressions.contains(attribute)
case _ => false
}
val wrapper: Expression => Expression =
if (isNamedExpression) f => Alias(f, toPrettySQL(f))() else identity
// support CURRENT_DATE and CURRENT_TIMESTAMP
val literalFunctions = Seq(CurrentDate(), CurrentTimestamp())
val name = nameParts.head
val func = literalFunctions.find(e => caseInsensitiveResolution(e.prettyName, name))
func.map(wrapper)
}
/**
* Resolves the attribute, column value and extract value expression(s) by traversing the
* input expression in a bottom-up manner. In order to resolve nested complex type fields
* correctly, this function makes use of the `throws` parameter to control when to raise an
* AnalysisException.
*
* Example :
* SELECT a.b FROM t ORDER BY b[0].d
*
* In the above example, b needs to be resolved before d can be resolved. Given that we are
* doing a bottom-up traversal, it will first attempt to resolve d and fail, as b has not
* been resolved yet. If `throws` is false, this function will handle the exception by
* returning the original attribute. In this case `d` will be resolved in subsequent passes
* after `b` is resolved.
*/
protected[sql] def resolveExpressionBottomUp(
expr: Expression,
plan: LogicalPlan,
throws: Boolean = false): Expression = {
if (expr.resolved) return expr
// Resolve the expression in one round.
// If throws == false and resolution fails (e.g. trying to resolve `a.b` when `a`
// doesn't exist), swallow the error and return the original expression.
// Otherwise, let the exception propagate.
try {
expr transformUp {
case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
case u @ UnresolvedAttribute(nameParts) =>
val result =
withPosition(u) {
plan.resolve(nameParts, resolver)
.orElse(resolveLiteralFunction(nameParts, u, plan))
.getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
} catch {
case _: AnalysisException if !throws => expr
}
}
/**
* In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
* clauses. This rule is to convert ordinal positions to the corresponding expressions in the
* select list. This support is introduced in Spark 2.0.
*
* - When the sort references or group by expressions are not integer but foldable expressions,
* just ignore them.
* - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position
* numbers too.
*
* Before the release of Spark 2.0, literals in the order/sort by and group by clauses
* had no effect on the results.
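*
* For example (illustrative), given a table `t(a INT, b INT)`:
* {{{
*   SELECT a, b FROM t ORDER BY 2         -- resolved as ORDER BY b
*   SELECT a, count(b) FROM t GROUP BY 1  -- resolved as GROUP BY a
* }}}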
*/
object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.childrenResolved => p
// Replace the index with the corresponding attribute for ORDER BY,
// which is a 1-based position in the projection list.
case Sort(orders, global, child)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
if (index > 0 && index <= child.output.size) {
SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
} else {
s.failAnalysis(
s"ORDER BY position $index is not in select list " +
s"(valid range is [1, ${child.output.size}])")
}
case o => o
}
Sort(newOrders, global, child)
// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-based position in aggregateExpressions, which are the output columns (select expressions)
case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
val newGroups = groups.map {
case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
aggs(index - 1)
case ordinal @ UnresolvedOrdinal(index) =>
ordinal.failAnalysis(
s"GROUP BY position $index is not in select list " +
s"(valid range is [1, ${aggs.size}])")
case o => o
}
Aggregate(newGroups, aggs, child)
}
}
/**
* Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses.
* This rule is expected to run after [[ResolveReferences]] is applied.
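*
* For example (illustrative), in `SELECT a + b AS ab, count(c) FROM t GROUP BY ab`, the
* grouping key `ab` is not resolvable from `t`'s output, so it is replaced with the
* aliased select-list expression `a + b`.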
*/
object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {
// This is a strict check: we apply the rule only if the expression is not
// resolvable by the child.
private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = {
!child.output.exists(a => resolver(a.name, attrName))
}
private def mayResolveAttrByAggregateExprs(
exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = {
exprs.map { _.transform {
case u: UnresolvedAttribute if notResolvableByChild(u.name, child) =>
aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
}}
}
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case agg @ Aggregate(groups, aggs, child)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
groups.exists(!_.resolved) =>
agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child))
case gs @ GroupingSets(selectedGroups, groups, child, aggs)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
gs.copy(
selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)),
groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child))
}
}
/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
* projection, so that they will be available during sorting. Another projection is added to
* remove these attributes after sorting.
*
* The HAVING clause could also use grouping columns that are not present in the SELECT clause.
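*
* For example (illustrative), `SELECT a FROM t ORDER BY b` is rewritten to roughly
* {{{
*   Project([a], Sort([b], Project([a, b], t)))
* }}}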
*/
object ResolveMissingReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
case s @ Sort(order, _, child)
if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
val ordering = newOrder.map(_.asInstanceOf[SortOrder])
if (child.output == newChild.output) {
s.copy(order = ordering)
} else {
// Add missing attributes and then project them away.
val newSort = s.copy(order = ordering, child = newChild)
Project(child.output, newSort)
}
case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
if (child.output == newChild.output) {
f.copy(condition = newCond.head)
} else {
// Add missing attributes and then project them away.
val newFilter = Filter(newCond.head, newChild)
Project(child.output, newFilter)
}
}
/**
* This method tries to resolve expressions and find missing attributes recursively.
* Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes,
* or resolved attributes which are missing from the child's output, this method tries to
* find the missing attributes and add them to the projection.
*/
private def resolveExprsAndAddMissingAttrs(
exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
// Missing attributes can be unresolved attributes or resolved attributes which are not in
// the output attributes of the plan.
if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
(exprs, plan)
} else {
plan match {
case p: Project =>
// Resolving expressions against current plan.
val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, p))
// Recursively resolving expressions on the child of current plan.
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
// If some attributes used by expressions are resolvable only on the rewritten child
// plan, we need to add them into original projection.
val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
(newExprs, Project(p.projectList ++ missingAttrs, newChild))
case a @ Aggregate(groupExprs, aggExprs, child) =>
val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, a))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid case.
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
} else {
// Need to add non-grouping attributes, invalid case.
(exprs, a)
}
case g: Generate =>
val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, g))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
(newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild))
// For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes
// via their children.
case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] =>
val maybeResolvedExprs = exprs.map(resolveExpressionBottomUp(_, u))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child)
(newExprs, u.withNewChildren(Seq(newChild)))
// For other operators, we can't recursively resolve and add attributes via their children.
case other =>
(exprs.map(resolveExpressionBottomUp(_, other)), other)
}
}
}
}
/**
* Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the
* function registry. Note that this rule doesn't try to resolve the [[UnresolvedFunction]]. It
* only performs simple existence check according to the function identifier to quickly identify
* undefined functions without triggering relation resolution, which may incur potentially
* expensive partition/schema discovery process in some cases.
* In order to avoid duplicate external function lookups, the external function identifier is
* stored in the local hash set externalFunctionNameSet.
* @see [[ResolveFunctions]]
* @see https://issues.apache.org/jira/browse/SPARK-19737
*/
object LookupFunctions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
val externalFunctionNameSet = new mutable.HashSet[FunctionIdentifier]()
plan.resolveExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
throw new NoSuchFunctionException(
f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
f.name.funcName)
}
}
}
def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
val funcName = if (conf.caseSensitiveAnalysis) {
name.funcName
} else {
name.funcName.toLowerCase(Locale.ROOT)
}
val databaseName = name.database match {
case Some(a) => formatDatabaseName(a)
case None => v1SessionCatalog.getCurrentDatabase
}
FunctionIdentifier(funcName, Some(databaseName))
}
protected def formatDatabaseName(name: String): String = {
if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
}
}
/**
* Replaces [[UnresolvedFunc]]s with concrete [[LogicalPlan]]s.
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
object ResolveFunctions extends Rule[LogicalPlan] {
val trimWarningEnabled = new AtomicBoolean(true)
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
// Resolve functions with concrete relations from v2 catalog.
case UnresolvedFunc(multipartIdent) =>
val funcIdent = parseSessionCatalogFunctionIdentifier(multipartIdent)
ResolvedFunc(Identifier.of(funcIdent.database.toArray, funcIdent.funcName))
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) =>
withPosition(u) {
Alias(GroupingID(Nil), VirtualColumn.hiveGroupingIdName)()
}
case u @ UnresolvedGenerator(name, children) =>
withPosition(u) {
v1SessionCatalog.lookupFunction(name, children) match {
case generator: Generator => generator
case other =>
failAnalysis(s"$name is expected to be a generator. However, " +
s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
}
}
case u @ UnresolvedFunction(funcId, arguments, isDistinct, filter) =>
withPosition(u) {
v1SessionCatalog.lookupFunction(funcId, arguments) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
case wf: AggregateWindowFunction =>
if (isDistinct || filter.isDefined) {
failAnalysis("DISTINCT or FILTER specified, " +
s"but ${wf.prettyName} is not an aggregate function")
} else {
wf
}
// If we get an aggregate function, we need to wrap it in an AggregateExpression.
case agg: AggregateFunction =>
if (filter.isDefined && !filter.get.deterministic) {
failAnalysis("FILTER expression is non-deterministic, " +
"it cannot be used in aggregate functions")
}
AggregateExpression(agg, Complete, isDistinct, filter)
// This function is not an aggregate function, just return the resolved one.
case other if (isDistinct || filter.isDefined) =>
failAnalysis("DISTINCT or FILTER specified, " +
s"but ${other.prettyName} is not an aggregate function")
case e: String2TrimExpression if arguments.size == 2 =>
if (trimWarningEnabled.get) {
log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." +
" Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? trimStr FROM str)`" +
" instead.")
trimWarningEnabled.set(false)
}
e
case other =>
other
}
}
}
}
}
/**
* This rule resolves and rewrites subqueries inside expressions.
*
* Note: CTEs are handled in CTESubstitution.
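*
* For example (illustrative), in
* {{{
*   SELECT * FROM a WHERE EXISTS (SELECT 1 FROM b WHERE b.x = a.x)
* }}}
* the attribute `a.x` inside the subquery cannot be resolved from `b` alone; it is resolved
* against the outer plan and wrapped in an [[OuterReference]].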
*/
object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {
/**
* Resolves the correlated expressions in a subquery by using the outer plan's references. All
* resolved outer references are wrapped in an [[OuterReference]].
*/
private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan resolveOperatorsDown {
case q: LogicalPlan if q.childrenResolved && !q.resolved =>
q transformExpressions {
case u @ UnresolvedAttribute(nameParts) =>
withPosition(u) {
try {
outer.resolve(nameParts, resolver) match {
case Some(outerAttr) => OuterReference(outerAttr)
case None => u
}
} catch {
case _: AnalysisException => u
}
}
}
}
}
/**
* Resolves the subquery plan that is referenced in a subquery expression. The normal
* attribute references are resolved using the regular analyzer and the outer references are
* resolved from the outer plans using the resolveOuterReferences method.
*
* Outer references from the correlated predicates are updated as children of
* Subquery expression.
*/
private def resolveSubQuery(
e: SubqueryExpression,
plans: Seq[LogicalPlan])(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
// Step 1: Resolve the outer expressions.
var previous: LogicalPlan = null
var current = e.plan
do {
// Try to resolve the subquery plan using the regular analyzer.
previous = current
current = executeSameContext(current)
// Use the outer references to resolve the subquery plan if it isn't resolved yet.
val i = plans.iterator
val afterResolve = current
while (!current.resolved && current.fastEquals(afterResolve) && i.hasNext) {
current = resolveOuterReferences(current, i.next())
}
} while (!current.resolved && !current.fastEquals(previous))
// Step 2: If the subquery plan is fully resolved, pull the outer references and record
// them as children of SubqueryExpression.
if (current.resolved) {
// Record the outer references as children of subquery expression.
f(current, SubExprUtils.getOuterReferences(current))
} else {
e.withNewPlan(current)
}
}
/**
* Resolves the subquery. Apart from resolving the subquery and outer references (if any)
* in the subquery plan, the children of the subquery expression are updated to record the
* outer references. This is needed to make sure
* (1) The column(s) referred from the outer query are not pruned from the plan during
* optimization.
* (2) Any aggregate expression(s) that reference outer attributes are pushed down to
* outer plan to get evaluated.
*/
private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
plan transformExpressions {
case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
case InSubquery(values, l @ ListQuery(_, _, exprId, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, plans)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output)
})
InSubquery(values, expr.asInstanceOf[ListQuery])
}
}
/**
* Resolves and rewrites all subqueries in an operator tree.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
// In case of HAVING (a filter after an aggregate) we use both the aggregate and
// its child for resolution.
case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
resolveSubQueries(f, Seq(a, a.child))
// Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
case q: UnaryNode if q.childrenResolved =>
resolveSubQueries(q, q.children)
case j: Join if j.childrenResolved =>
resolveSubQueries(j, Seq(j, j.left, j.right))
case s: SupportsSubquery if s.childrenResolved =>
resolveSubQueries(s, s.children)
}
}
/**
* Replaces unresolved column aliases for a subquery with projections.
*/
object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved =>
// Resolves output attributes if a query has alias names in its subquery:
// e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2)
val outputAttrs = child.output
// Checks if the number of aliases equals the number of output columns
// in the subquery.
if (columnNames.size != outputAttrs.size) {
u.failAnalysis("Number of column aliases does not match number of columns. " +
s"Number of column aliases: ${columnNames.size}; " +
s"number of columns: ${outputAttrs.size}.")
}
val aliases = outputAttrs.zip(columnNames).map { case (attr, aliasName) =>
Alias(attr, aliasName)()
}
Project(aliases, child)
}
}
/**
* Turns projections that contain aggregate expressions into aggregations.
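*
* For example (illustrative), `SELECT count(a) FROM t` is initially parsed as a [[Project]]
* and is rewritten here into `Aggregate(Nil, [count(a)], t)`, i.e. a global aggregation
* with no grouping keys.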
*/
object GlobalAggregates extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child)
}
def containsAggregates(exprs: Seq[Expression]): Boolean = {
// Collect all Windowed Aggregate Expressions.
val windowedAggExprs: Set[Expression] = exprs.flatMap { expr =>
expr.collect {
case WindowExpression(ae: AggregateExpression, _) => ae
case WindowExpression(e: PythonUDF, _) if PythonUDF.isGroupedAggPandasUDF(e) => e
}
}.toSet
// Find the first Aggregate Expression that is not Windowed.
exprs.exists(_.collectFirst {
case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae
case e: PythonUDF if PythonUDF.isGroupedAggPandasUDF(e) &&
!windowedAggExprs.contains(e) => e
}.isDefined)
}
}
/**
* This rule finds aggregate expressions that are not in an aggregate operator. For example,
* those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the
* underlying aggregate operator and then projected away after the original operator.
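*
* For example (illustrative), in `SELECT a FROM t GROUP BY a HAVING count(b) > 0`, the
* `count(b)` expression is pushed into the underlying [[Aggregate]] under an alias, the
* HAVING predicate becomes a [[Filter]] over it, and a final [[Project]] restores the
* original output `a`.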
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] with AliasHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
// Resolve aggregate with a HAVING clause to Filter(..., Aggregate()). Note, to avoid wrongly
// resolving the HAVING condition expression, we skip resolving it in ResolveReferences here
// and transform it into a Filter after the aggregate is resolved. See more details in
// SPARK-31519.
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
resolveHaving(Filter(cond, agg), agg)
case f @ Filter(_, agg: Aggregate) if agg.resolved =>
resolveHaving(f, agg)
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
// Try resolving the ordering as though it is in the aggregate clause.
try {
// If a sort order is unresolved, contains references not in the aggregate, or contains an
// `AggregateExpression`, we need to push it down to the underlying aggregate operator.
val unresolvedSortOrders = sortOrder.filter { s =>
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
}
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val resolvedAggregate: Aggregate =
executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
// If we pass the analysis check, then the ordering expressions should only reference
// aggregate expressions or grouping expressions, and it's safe to push them down to
// Aggregate.
checkAnalysis(resolvedAggregate)
val originalAggExprs = aggregate.aggregateExpressions.map(trimNonTopLevelAliases)
// If the ordering expression is the same as an original aggregate expression, we don't need
// to push down this ordering expression and can reference the original aggregate
// expression instead.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
val evaluatedOrderings = resolvedAliasedOrdering.zip(unresolvedSortOrders).map {
case (evaluated, order) =>
val index = originalAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}
if (index == -1) {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
} else {
order.copy(child = originalAggExprs(index).toAttribute)
}
}
val sortOrdersMap = unresolvedSortOrders
.map(new TreeNodeRef(_))
.zip(evaluatedOrderings)
.toMap
val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new TreeNodeRef(s), s))
// Since we don't rely on sort.resolved as the stop condition for this rule,
// we need to check this to prevent applying this rule multiple times.
if (sortOrder == finalSortOrders) {
sort
} else {
Project(aggregate.output,
Sort(finalSortOrders, global,
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => sort
}
}
def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
def resolveFilterCondInAggregate(
filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = {
try {
val aggregatedCondition =
Aggregate(
agg.groupingExpressions,
Alias(filterCond, "havingCondition")() :: Nil,
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head
// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved) {
// Try to replace all aggregate expressions in the filter by an alias.
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val transformedAggregateFilter = resolvedAggregateFilter.transform {
case ae: AggregateExpression =>
val alias = Alias(ae, ae.toString)()
aggregateExpressions += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
ne.toAttribute
case _ =>
val alias = Alias(e, e.toString)()
aggregateExpressions += alias
alias.toAttribute
}
}
if (aggregateExpressions.nonEmpty) {
Some(aggregateExpressions.toSeq, transformedAggregateFilter)
} else {
None
}
} else {
None
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return None and the caller side will return the original plan.
case ae: AnalysisException => None
}
}
def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
// Try resolving the condition of the filter as though it is in the aggregate clause
val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)
// Push the aggregate expressions into the aggregate (if any).
if (resolvedInfo.nonEmpty) {
val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
Project(agg.output,
Filter(resolvedHavingCond,
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
} else {
filter
}
}
}
/**
* Extracts [[Generator]] from the projectList of a [[Project]] operator and creates [[Generate]]
* operator under [[Project]].
*
* This rule will throw an [[AnalysisException]] for the following cases:
* 1. [[Generator]] is nested in expressions, e.g. `SELECT explode(list) + 1 FROM tbl`
* 2. more than one [[Generator]] is found in projectList,
* e.g. `SELECT explode(list), explode(list) FROM tbl`
* 3. [[Generator]] is found in other operators that are not [[Project]] or [[Generate]],
* e.g. `SELECT * FROM tbl SORT BY explode(list)`
*/
object ExtractGenerator extends Rule[LogicalPlan] {
private def hasGenerator(expr: Expression): Boolean = {
expr.find(_.isInstanceOf[Generator]).isDefined
}
private def hasNestedGenerator(expr: NamedExpression): Boolean = {
def hasInnerGenerator(g: Generator): Boolean = g match {
// Since `GeneratorOuter` is just a wrapper of generators, we skip it here
case go: GeneratorOuter =>
hasInnerGenerator(go.child)
case _ =>
g.children.exists { _.find {
case _: Generator => true
case _ => false
}.isDefined }
}
trimNonTopLevelAliases(expr) match {
case UnresolvedAlias(g: Generator, _) => hasInnerGenerator(g)
case Alias(g: Generator, _) => hasInnerGenerator(g)
case MultiAlias(g: Generator, _) => hasInnerGenerator(g)
case other => hasGenerator(other)
}
}
private def hasAggFunctionInGenerator(ne: Seq[NamedExpression]): Boolean = {
ne.exists(_.find {
case g: Generator =>
g.children.exists(_.find(_.isInstanceOf[AggregateFunction]).isDefined)
case _ =>
false
}.nonEmpty)
}
private def trimAlias(expr: NamedExpression): Expression = expr match {
case UnresolvedAlias(child, _) => child
case Alias(child, _) => child
case MultiAlias(child, _) => child
case _ => expr
}
private object AliasedGenerator {
/**
* Extracts a [[Generator]] expression, any names assigned by aliases to the outputs
* and the outer flag. The outer flag is used when joining the generator output.
* @param e the [[Expression]]
* @return (the [[Generator]], seq of output names, outer flag)
*/
def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match {
case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true))
case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some((g, names, true))
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false))
case MultiAlias(g: Generator, names) if g.resolved => Some((g, names, false))
case _ => None
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
"expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator)))
case Project(projectList, _) if projectList.count(hasGenerator) > 1 =>
val generators = projectList.filter(hasGenerator).map(trimAlias)
throw new AnalysisException("Only one generator allowed per select clause but found " +
generators.size + ": " + generators.map(toPrettySQL).mkString(", "))
case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
val nestedGenerator = aggList.find(hasNestedGenerator).get
throw new AnalysisException("Generators are not supported when it's nested in " +
"expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator)))
case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
val generators = aggList.filter(hasGenerator).map(trimAlias)
throw new AnalysisException("Only one generator allowed per aggregate clause but found " +
generators.size + ": " + generators.map(toPrettySQL).mkString(", "))
case agg @ Aggregate(groupList, aggList, child) if aggList.forall {
case AliasedGenerator(_, _, _) => true
case other => other.resolved
} && aggList.exists(hasGenerator) =>
// Set this boolean flag to true once a generator in the aggregate list has been visited.
var generatorVisited = false
val projectExprs = Array.ofDim[NamedExpression](aggList.length)
val newAggList = aggList
.map(trimNonTopLevelAliases)
.zipWithIndex
.flatMap {
case (AliasedGenerator(generator, names, outer), idx) =>
// This is a sanity check; it should not happen, as the previous case would have
// thrown an exception earlier.
assert(!generatorVisited, "More than one generator found in aggregate.")
generatorVisited = true
val newGenChildren: Seq[Expression] = generator.children.zipWithIndex.map {
case (e, idx) => if (e.foldable) e else Alias(e, s"_gen_input_${idx}")()
}
val newGenerator = {
val g = generator.withNewChildren(newGenChildren.map { e =>
if (e.foldable) e else e.asInstanceOf[Alias].toAttribute
}).asInstanceOf[Generator]
if (outer) GeneratorOuter(g) else g
}
val newAliasedGenerator = if (names.length == 1) {
Alias(newGenerator, names(0))()
} else {
MultiAlias(newGenerator, names)
}
projectExprs(idx) = newAliasedGenerator
newGenChildren.filter(!_.foldable).asInstanceOf[Seq[NamedExpression]]
case (other, idx) =>
projectExprs(idx) = other.toAttribute
other :: Nil
}
val newAgg = Aggregate(groupList, newAggList, child)
Project(projectExprs.toList, newAgg)
case p @ Project(projectList, _) if hasAggFunctionInGenerator(projectList) =>
// If a generator has any aggregate function, we need to apply the `GlobalAggregates` rule
// first to replace the `Project` with an `Aggregate`.
p
case p @ Project(projectList, child) =>
// Holds the resolved generator, if one exists in the project list.
var resolvedGenerator: Generate = null
val newProjectList = projectList
.map(trimNonTopLevelAliases)
.flatMap {
case AliasedGenerator(generator, names, outer) if generator.childrenResolved =>
// This is a sanity check; it should not happen, as the previous case would have
// thrown an exception earlier.
assert(resolvedGenerator == null, "More than one generator found in SELECT.")
resolvedGenerator =
Generate(
generator,
unrequiredChildIndex = Nil,
outer = outer,
qualifier = None,
generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
child)
resolvedGenerator.generatorOutput
case other => other :: Nil
}
if (resolvedGenerator != null) {
Project(newProjectList, resolvedGenerator)
} else {
p
}
case g: Generate => g
case p if p.expressions.exists(hasGenerator) =>
throw new AnalysisException("Generators are not supported outside the SELECT clause, but " +
"got: " + p.simpleString(SQLConf.get.maxToStringFields))
}
}
/**
* Rewrites table generating expressions that need one or more of the following in order
* to be resolved:
* - concrete attribute references for their output.
* - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]].
*
* Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions
* that wrap the [[Generator]].
*/
object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case g: Generate if !g.child.resolved || !g.generator.resolved => g
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
}
/**
* Constructs the output attributes for a [[Generator]], given a list of names. If the list of
* names is empty, names are assigned from the field names of the generator's schema.
*/
private[analysis] def makeGeneratorOutput(
generator: Generator,
names: Seq[String]): Seq[Attribute] = {
val elementAttrs = generator.elementSchema.toAttributes
if (names.length == elementAttrs.length) {
names.zip(elementAttrs).map {
case (name, attr) => attr.withName(name)
}
} else if (names.isEmpty) {
elementAttrs
} else {
failAnalysis(
"The number of aliases supplied in the AS clause does not match the number of columns " +
s"output by the UDTF expected ${elementAttrs.size} aliases but got " +
s"${names.mkString(",")} ")
}
}
}
/**
* Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
* aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
* operators for every distinct [[WindowSpecDefinition]].
*
* This rule handles three cases:
* - A [[Project]] having [[WindowExpression]]s in its projectList;
* - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions.
* - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING
* clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions.
* Note: If there is a GROUP BY clause in the query, aggregations and corresponding
* filters (expressions in the HAVING clause) should be evaluated before any
* [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be
* evaluated after all [[WindowExpression]]s.
*
* For every case, the transformation works as follows:
* 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions
* it into two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for
* all regular expressions.
* 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s
* and [[WindowFunctionType]]s.
* 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a
* [[Window]] operator and inserts it into the plan tree.
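*
* For example (illustrative), `SELECT a, sum(b) OVER (PARTITION BY c) FROM t` becomes
* roughly
* {{{
*   Project([a, _we0], Window([sum(b) AS _we0], partitionBy = [c], Project([a, b, c], t)))
* }}}
* where `_we0` is an internal alias introduced for the extracted window expression.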
*/
object ExtractWindowExpressions extends Rule[LogicalPlan] {
type Spec = (Seq[Expression], Seq[SortOrder], WindowFunctionType)
private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
exprs.exists(hasWindowFunction)
private def hasWindowFunction(expr: Expression): Boolean = {
expr.find {
case window: WindowExpression => true
case _ => false
}.isDefined
}
/**
* From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
* other regular expressions that do not contain any window expression. For example, for
* `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract
* `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in
* the window expression as attribute references. So, the first returned value will be
* `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be
* [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2].
*
* @return (seq of expressions containing at least one window expression,
* seq of non-window expressions)
*/
private def extract(
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
// First, we partition the input expressions into two parts. In the first part,
// every expression contains at least one WindowExpression.
// Expressions in the second part do not have any WindowExpression.
val (expressionsWithWindowFunctions, regularExpressions) =
expressions.partition(hasWindowFunction)
// Then, we need to extract those regular expressions used in the WindowExpression.
// For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
// we need to make sure that col1 to col5 are all projected from the child of the Window
// operator.
val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
def extractExpr(expr: Expression): Expression = expr match {
case ne: NamedExpression =>
// If a named expression is not in regularExpressions, add it to
// extractedExprBuffer and replace it with an AttributeReference.
val missingExpr =
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
if (missingExpr.nonEmpty) {
extractedExprBuffer += ne
}
// alias will be cleaned in the rule CleanupAliases
ne
case e: Expression if e.foldable =>
e // No need to create an attribute reference if it will be evaluated as a Literal.
case e: Expression =>
// For other expressions, we extract it and replace it with an AttributeReference (with
// an internal column name, e.g. "_w0").
val withName = Alias(e, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
}
// Now, we extract regular expressions from expressionsWithWindowFunctions
// by using extractExpr.
val seenWindowAggregates = new ArrayBuffer[AggregateExpression]
val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
_.transform {
// Extracts children expressions of a WindowFunction (input parameters of
// a WindowFunction).
case wf: WindowFunction =>
val newChildren = wf.children.map(extractExpr)
wf.withNewChildren(newChildren)
// Extracts expressions from the partition spec and order spec.
case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) =>
val newPartitionSpec = partitionSpec.map(extractExpr)
val newOrderSpec = orderSpec.map { so =>
val newChild = extractExpr(so.child)
so.copy(child = newChild)
}
wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)
case WindowExpression(ae: AggregateExpression, _) if ae.filter.isDefined =>
failAnalysis(
"window aggregate function with filter predicate is not supported yet.")
// Extract Windowed AggregateExpression
case we @ WindowExpression(
ae @ AggregateExpression(function, _, _, _, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
val newAgg = ae.copy(aggregateFunction = newFunction)
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
case AggregateExpression(aggFunc, _, _, _, _) if hasWindowFunction(aggFunc.children) =>
failAnalysis("It is not allowed to use a window function inside an aggregate " +
"function. Please use the inner window function in a sub-query.")
// Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...),
// we need to extract SUM(x).
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) =>
val withName = Alias(agg, s"_w${extractedExprBuffer.length}")()
extractedExprBuffer += withName
withName.toAttribute
// Extracts other attributes
case attr: Attribute => extractExpr(attr)
}.asInstanceOf[NamedExpression]
}
(newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer)
} // end of extract
/**
* Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
*/
private def addWindow(
expressionsWithWindowFunctions: Seq[NamedExpression],
child: LogicalPlan): LogicalPlan = {
// First, we need to extract all WindowExpressions from expressionsWithWindowFunctions
// and put those extracted WindowExpressions to extractedWindowExprBuffer.
// This step is needed because it is possible that an expression contains multiple
// WindowExpressions with different Window Specs.
// After extracting WindowExpressions, we need to construct a project list to generate
// expressionsWithWindowFunctions based on extractedWindowExprBuffer.
// For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract
// "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to
// "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)".
// Then, the projectList will be [_we0/_we1].
val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]()
val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
// We need to use transformDown because we want to trigger
// "case alias @ Alias(window: WindowExpression, _)" first.
_.transformDown {
case alias @ Alias(window: WindowExpression, _) =>
// If a WindowExpression has an assigned alias, just use it.
extractedWindowExprBuffer += alias
alias.toAttribute
case window: WindowExpression =>
// If there is no alias assigned to the WindowExpression, we create an
// internal column.
val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")()
extractedWindowExprBuffer += withName
withName.toAttribute
}.asInstanceOf[NamedExpression]
}
// SPARK-32616: Use a linked hash map to maintain the insertion order of the Window
// operators, so that a query with multiple Window operators has a deterministic plan.
val groupedWindowExpressions = mutable.LinkedHashMap.empty[Spec, ArrayBuffer[NamedExpression]]
// Second, we group extractedWindowExprBuffer based on their Partition and Order Specs.
extractedWindowExprBuffer.foreach { expr =>
val distinctWindowSpec = expr.collect {
case window: WindowExpression => window.windowSpec
}.distinct
// We do a final check to see if we have only a single Window Spec defined in the
// expression.
if (distinctWindowSpec.isEmpty) {
failAnalysis(s"$expr does not have any WindowExpression.")
} else if (distinctWindowSpec.length > 1) {
// newExpressionsWithWindowFunctions only have expressions with a single
// WindowExpression. If we reach here, we have a bug.
failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." +
s"Please file a bug report with this error message, stack trace, and the query.")
} else {
val spec = distinctWindowSpec.head
val specKey = (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr))
val windowExprs = groupedWindowExpressions
.getOrElseUpdate(specKey, new ArrayBuffer[NamedExpression])
windowExprs += expr
}
}
// Third, we add one Window operator for each distinct Window Spec, each operator
// becoming the child of the next Window operator.
val windowOps =
groupedWindowExpressions.foldLeft(child) {
case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
Window(windowExpressions.toSeq, partitionSpec, orderSpec, last)
}
// Finally, we create a Project to output windowOps's output together with
// newExpressionsWithWindowFunctions.
Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps)
} // end of addWindow
// We have to use transformDown here to make sure the
// "Aggregate with Having clause" case is triggered first.
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
case Filter(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside WHERE clause")
case UnresolvedHaving(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside HAVING clause")
// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add a Filter operator for conditions in the Having clause.
val withFilter = Filter(condition, withAggregate)
val withWindow = addWindow(windowExpressions, withFilter)
// Finally, generate output columns according to the original aggregate expressions.
val finalProjectList = aggregateExprs.map(_.toAttribute)
Project(finalProjectList, withWindow)
case p: LogicalPlan if !p.childrenResolved => p
// Aggregate without Having clause.
case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withAggregate)
// Finally, generate output columns according to the original aggregate expressions.
val finalProjectList = aggregateExprs.map(_.toAttribute)
Project(finalProjectList, withWindow)
// We only extract Window Expressions after all expressions of the Project
// have been resolved.
case p @ Project(projectList, child)
if hasWindowFunction(projectList) && !p.expressions.exists(!_.resolved) =>
val (windowExpressions, regularExpressions) = extract(projectList)
// We add a project to get all needed expressions for window expressions from the child
// of the original Project operator.
val withProject = Project(regularExpressions, child)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withProject)
// Finally, generate output columns according to the original projectList.
val finalProjectList = projectList.map(_.toAttribute)
Project(finalProjectList, withWindow)
}
}
/**
* Pulls out nondeterministic expressions from LogicalPlans that are not Project or Filter,
* puts them into an inner Project, and finally projects them away in an outer Project.
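*
* For example (illustrative), an [[Aggregate]] grouping by `rand(0) + a` is rewritten so
* that the nondeterministic `rand(0)` is evaluated once in a [[Project]] below the
* aggregate, and the grouping expression references the projected attribute instead.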
*/
object PullOutNondeterministic extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.resolved => p // Skip unresolved nodes.
case p: Project => p
case f: Filter => f
case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
a.transformExpressions { case e =>
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
}.copy(child = newChild)
// Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail)
// and we want to retain them inside the aggregate functions.
case m: CollectMetrics => m
// TODO: It's hard to write a general rule to pull out nondeterministic expressions
// from a LogicalPlan; currently we only do it for UnaryNodes that have the same output
// schema as their child.
case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
val nondeterToAttr = getNondeterToAttr(p.expressions)
val newPlan = p.transformExpressions { case e =>
nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
}
val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
exprs.filterNot(_.deterministic).flatMap { expr =>
val leafNondeterministic = expr.collect { case n: Nondeterministic => n }
leafNondeterministic.distinct.map { e =>
val ne = e match {
case n: NamedExpression => n
case _ => Alias(e, "_nondeterministic")()
}
e -> ne
}
}.toMap
}
}
/**
* Set the seed for random number generation.
*/
object ResolveRandomSeed extends Rule[LogicalPlan] {
private lazy val random = new Random()
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if p.resolved => p
case p => p transformExpressionsUp {
case Uuid(None) => Uuid(Some(random.nextLong()))
case Shuffle(child, None) => Shuffle(child, Some(random.nextLong()))
}
}
}
/**
* Correctly handles null primitive inputs for a UDF by adding an extra [[If]] expression to
* do the null check. When a user defines a UDF with primitive parameters, there is no way to
* tell if a primitive parameter is null or not, so here we assume the primitive input is
* null-propagating and we should return null if the input is null.
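*
* For example (illustrative), a Scala `udf((i: Int) => i + 1)` applied to a nullable column
* `a` is rewritten to roughly `If(IsNull(a), null, udf(knownnotnull(a)))`, so a null input
* produces null instead of an arbitrary default value.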
*/
object HandleNullInputsForUDF extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.resolved => p // Skip unresolved nodes.
case p => p transformExpressionsUp {
case udf: ScalaUDF if udf.inputPrimitives.contains(true) =>
// Add special handling of null for parameters that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
assert(udf.inputPrimitives.length == udf.children.length)
val inputPrimitivesPair = udf.inputPrimitives.zip(udf.children)
val inputNullCheck = inputPrimitivesPair.collect {
case (isPrimitive, input) if isPrimitive && input.nullable =>
IsNull(input)
}.reduceLeftOption[Expression](Or)
if (inputNullCheck.isDefined) {
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
// as null-safe (i.e., wrap with `KnownNotNull`), because the null-returning
// branch of `If` will be called if any of these checked inputs is null. Thus we can
// prevent this rule from being applied repeatedly.
val newInputs = inputPrimitivesPair.map {
case (isPrimitive, input) =>
if (isPrimitive && input.nullable) {
KnownNotNull(input)
} else {
input
}
}
val newUDF = udf.copy(children = newInputs)
If(inputNullCheck.get, Literal.create(null, udf.dataType), newUDF)
} else {
udf
}
}
}
}
/**
* Resolves the encoders for a UDF by explicitly giving the attributes. We give the
* attributes explicitly in order to handle the case where the data type of the input
* value is not the same as the internal schema of the encoder, which could cause
* data loss. For example, the encoder should not cast the input value to Decimal(38, 18)
* if the actual data type is Decimal(30, 0).
*
* The resolved encoders then will be used to deserialize the internal row to Scala value.
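*
* A minimal sketch of the case above (the UDF and column name are hypothetical):
* {{{
*   val f = udf((d: java.math.BigDecimal) => d.precision)
*   df.select(f($"dec30_0"))  // the encoder is bound to Decimal(30, 0), not Decimal(38, 18)
* }}}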
*/
object ResolveEncodersInUDF extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.resolved => p // Skip unresolved nodes.
case p => p transformExpressionsUp {
case udf: ScalaUDF if udf.inputEncoders.nonEmpty =>
val boundEncoders = udf.inputEncoders.zipWithIndex.map { case (encOpt, i) =>
val dataType = udf.children(i).dataType
encOpt.map { enc =>
val attrs = if (enc.isSerializedAsStructForTopLevel) {
dataType.asInstanceOf[StructType].toAttributes
} else {
// the field name doesn't matter here, so we use
// a simple literal to avoid any overhead
new StructType().add("input", dataType).toAttributes
}
enc.resolveAndBind(attrs)
}
}
udf.copy(inputEncoders = boundEncoders)
}
}
}
/**
* Check and add proper window frames for all window functions.
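*
* For example, when no frame is specified, this rule fills in the defaults (a sketch):
* {{{
*   sum(a) OVER (ORDER BY b)
*     =>  sum(a) OVER (ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
*   sum(a) OVER ()
*     =>  sum(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
* }}}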
*/
object ResolveWindowFrame extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
case WindowExpression(wf: OffsetWindowFunction,
WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) if wf.frame != f =>
failAnalysis(s"Cannot specify window frame for ${wf.prettyName} function")
case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame))
if wf.frame != UnspecifiedFrame && wf.frame != f =>
failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}")
case WindowExpression(wf: WindowFunction, s @ WindowSpecDefinition(_, _, UnspecifiedFrame))
if wf.frame != UnspecifiedFrame =>
WindowExpression(wf, s.copy(frameSpecification = wf.frame))
case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame))
if e.resolved =>
val frame = if (o.nonEmpty) {
SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
} else {
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
}
we.copy(windowSpec = s.copy(frameSpecification = frame))
}
}
/**
* Check and add order to [[AggregateWindowFunction]]s.
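*
* For example (a sketch): `rank() OVER (PARTITION BY a ORDER BY b)` resolves the rank
* function to carry its ordering expressions, conceptually:
* {{{
*   WindowExpression(rank(), spec(orderSpec = [b ASC]))
*     =>  WindowExpression(rank(b), spec)
* }}}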
*/
object ResolveWindowOrder extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty =>
failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " +
s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " +
s"ORDER BY window_ordering) from table")
case WindowExpression(rank: RankLike, spec) if spec.resolved =>
val order = spec.orderSpec.map(_.child)
WindowExpression(rank.withOrder(order), spec)
}
}
/**
* Removes natural or using joins by calculating the output columns based on the output of the
* two sides, then applies a Project on top of a normal Join to eliminate the natural or
* using join.
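*
* For example (a sketch, assuming `id` is the only column name common to t1 and t2):
* {{{
*   SELECT * FROM t1 NATURAL JOIN t2
*     =>  Project([id, <t1-only columns>, <t2-only columns>],
*           Join(t1, t2, Inner, Some(t1.id = t2.id)))
* }}}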
*/
object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
if left.resolved && right.resolved && j.duplicateResolved =>
commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint)
case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
if j.resolvedExceptNatural =>
// find common column names from both sides
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint)
}
}
/**
* Resolves columns of an output table from the data in a logical plan. This rule will:
*
* - Reorder columns when the write is by name
* - Insert casts when data types do not match
* - Insert aliases when column names do not match
* - Detect plans that are not compatible with the output table and throw AnalysisException
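*
* For example, a by-name append whose query column order differs from the table and needs a
* cast is rewritten as (a sketch; names and types are hypothetical):
* {{{
*   table schema: (a int, b bigint); query output: (b int, a int)
*     =>  AppendData(table, Project([a, cast(b as bigint) AS b], query))
* }}}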
*/
object ResolveOutputRelation extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case append @ AppendData(table, query, _, isByName)
if table.resolved && query.resolved && !append.outputResolved =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
if (projection != query) {
append.copy(query = projection)
} else {
append
}
case overwrite @ OverwriteByExpression(table, _, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
if (projection != query) {
overwrite.copy(query = projection)
} else {
overwrite
}
case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
if (projection != query) {
overwrite.copy(query = projection)
} else {
overwrite
}
}
}
private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
val configKey = SQLConf.STORE_ASSIGNMENT_POLICY.key
throw new AnalysisException(s"""
|"LEGACY" store assignment policy is disallowed in Spark data source V2.
|Please set the configuration $configKey to other values.""".stripMargin)
}
}
private def commonNaturalJoinProcessing(
left: LogicalPlan,
right: LogicalPlan,
joinType: JoinType,
joinNames: Seq[String],
condition: Option[Expression],
hint: JoinHint) = {
val leftKeys = joinNames.map { keyName =>
left.output.find(attr => resolver(attr.name, keyName)).getOrElse {
throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the left " +
s"side of the join. The left-side columns: [${left.output.map(_.name).mkString(", ")}]")
}
}
val rightKeys = joinNames.map { keyName =>
right.output.find(attr => resolver(attr.name, keyName)).getOrElse {
throw new AnalysisException(s"USING column `$keyName` cannot be resolved on the right " +
s"side of the join. The right-side columns: [${right.output.map(_.name).mkString(", ")}]")
}
}
val joinPairs = leftKeys.zip(rightKeys)
val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)
// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
// the output list looks like: join keys, columns from left, columns from right
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
case LeftExistence(_) =>
leftKeys ++ lUniqueOutput
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
case FullOuter =>
// In a full outer join, the join columns are coalesced so they are non-null whenever
// either side is non-null.
val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
joinedCols ++
lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
case _ : InnerLike =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput
case _ =>
sys.error("Unsupported natural join type " + joinType)
}
// use Project to trim unnecessary fields
Project(projectList, Join(left, right, joinType, newCondition, hint))
}
/**
* Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
* to the given input attributes.
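*
* For example (a sketch): `ds.as[(Int, Int)]` builds a deserializer containing
* `GetColumnByOrdinal(0)` and `GetColumnByOrdinal(1)`, which this rule resolves against the
* output attributes of the child plan.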
*/
object ResolveDeserializer extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.childrenResolved => p
case p if p.resolved => p
case p => p transformExpressions {
case UnresolvedDeserializer(deserializer, inputAttributes) =>
val inputs = if (inputAttributes.isEmpty) {
p.children.flatMap(_.output)
} else {
inputAttributes
}
validateTopLevelTupleFields(deserializer, inputs)
val resolved = resolveExpressionBottomUp(
deserializer, LocalRelation(inputs), throws = true)
val result = resolved transformDown {
case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
inputData.dataType match {
case ArrayType(et, cn) =>
MapObjects(func, inputData, et, cn, cls) transformUp {
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
case other =>
throw new AnalysisException("need an array field but got " + other.catalogString)
}
case u: UnresolvedCatalystToExternalMap if u.child.resolved =>
u.child.dataType match {
case _: MapType =>
CatalystToExternalMap(u) transformUp {
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
case other =>
throw new AnalysisException("need a map field but got " + other.catalogString)
}
}
validateNestedTupleFields(result)
result
}
}
private def fail(schema: StructType, maxOrdinal: Int): Unit = {
throw new AnalysisException(s"Try to map ${schema.catalogString} to Tuple${maxOrdinal + 1}" +
", but failed as the number of fields does not line up.")
}
/**
* For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column
* by position. However, the actual number of columns may be different from the number of Tuple
* fields. This method is used to check the number of columns and fields, and throw an
* exception if they do not match.
*/
private def validateTopLevelTupleFields(
deserializer: Expression, inputs: Seq[Attribute]): Unit = {
val ordinals = deserializer.collect {
case GetColumnByOrdinal(ordinal, _) => ordinal
}.distinct.sorted
if (ordinals.nonEmpty && ordinals != inputs.indices) {
fail(inputs.toStructType, ordinals.last)
}
}
/**
* For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field
* by position. However, the actual number of struct fields may be different from the number
* of nested Tuple fields. This method is used to check the number of struct fields and nested
* Tuple fields, and throw an exception if they do not match.
*/
private def validateNestedTupleFields(deserializer: Expression): Unit = {
val structChildToOrdinals = deserializer
// There are 2 kinds of `GetStructField`:
// 1. resolved from `UnresolvedExtractValue`, and it will have a `name` property.
// 2. created when we build deserializer expression for nested tuple, no `name` property.
// Here we want to validate the ordinals of nested tuple, so we should only catch
// `GetStructField` without the name property.
.collect { case g: GetStructField if g.name.isEmpty => g }
.groupBy(_.child)
.mapValues(_.map(_.ordinal).distinct.sorted)
structChildToOrdinals.foreach { case (expr, ordinals) =>
val schema = expr.dataType.asInstanceOf[StructType]
if (ordinals != schema.indices) {
fail(schema, ordinals.last)
}
}
}
}
/**
* Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
* constructed is an inner class.
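*
* For example (a sketch; `Outer` is a hypothetical enclosing class):
* {{{
*   class Outer { case class Inner(i: Int) }
*   // Encoding Inner requires the Outer instance, which this rule supplies via OuterScopes.
* }}}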
*/
object ResolveNewInstance extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.childrenResolved => p
case p if p.resolved => p
case p => p transformExpressions {
case n: NewInstance if n.childrenResolved && !n.resolved =>
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
throw new AnalysisException(
s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
"access to the scope that this class was defined in.\n" +
"Try moving this class out of its parent class.")
}
n.copy(outerPointer = Some(outer))
}
}
}
/**
* Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
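*
* For example (a sketch): widening an Int to a Long is a safe up-cast, while String to Int
* may truncate and fails analysis:
* {{{
*   UpCast(intCol, LongType)     =>  Cast(intCol, LongType)
*   UpCast(strCol, IntegerType)  =>  AnalysisException("Cannot up cast ...")
* }}}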
*/
object ResolveUpCast extends Rule[LogicalPlan] {
private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
val fromStr = from match {
case l: LambdaVariable => "array element"
case e => e.sql
}
throw new AnalysisException(s"Cannot up cast $fromStr from " +
s"${from.dataType.catalogString} to ${to.catalogString}.\n" +
"The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
"You can either add an explicit cast to the input data or choose a higher precision " +
"type of the field in the target object")
}
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p if !p.childrenResolved => p
case p if p.resolved => p
case p => p transformExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u
case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] =>
throw new AnalysisException(
s"UpCast only support DecimalType as AbstractDataType yet, but got: $target")
case UpCast(child, target, walkedTypePath) if target == DecimalType
&& child.dataType.isInstanceOf[DecimalType] =>
assert(walkedTypePath.nonEmpty,
"object DecimalType should only be used inside ExpressionEncoder")
// SPARK-31750: if we want to upcast to the general decimal type, and the `child` is
// already decimal type, we can remove the `Upcast` and accept any precision/scale.
// This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`.
child
case UpCast(child, target: AtomicType, _)
if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
child.dataType == StringType =>
Cast(child, target.asNullable)
case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) =>
fail(child, u.dataType, walkedTypePath)
case u @ UpCast(child, _, _) => Cast(child, u.dataType.asNullable)
}
}
}
/**
* Rule that resolves, normalizes and rewrites column names in ALTER TABLE changes based on
* case sensitivity. Fields that cannot be resolved are left unchanged for CheckAnalysis to
* report.
*/
object ResolveAlterTableChanges extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case a @ AlterTable(_, _, t: NamedRelation, changes) if t.resolved =>
// 'colsToAdd' keeps track of new columns being added. It stores a mapping from a
// normalized parent name of fields to field names that belong to the parent.
// For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become
// Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")).
val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]]
val schema = t.schema
val normalizedChanges = changes.flatMap {
case add: AddColumn =>
def addColumn(
parentSchema: StructType,
parentName: String,
normalizedParentName: Seq[String]): TableChange = {
val fieldsAdded = colsToAdd.getOrElse(normalizedParentName, Nil)
val pos = findColumnPosition(add.position(), parentName, parentSchema, fieldsAdded)
val field = add.fieldNames().last
colsToAdd(normalizedParentName) = fieldsAdded :+ field
TableChange.addColumn(
(normalizedParentName :+ field).toArray,
add.dataType(),
add.isNullable,
add.comment,
pos)
}
val parent = add.fieldNames().init
if (parent.nonEmpty) {
// Adding a nested field, need to normalize the parent column and position
val target = schema.findNestedField(parent, includeCollections = true, conf.resolver)
if (target.isEmpty) {
// Leave unresolved. Throws error in CheckAnalysis
Some(add)
} else {
val (normalizedName, sf) = target.get
sf.dataType match {
case struct: StructType =>
Some(addColumn(struct, parent.quoted, normalizedName :+ sf.name))
case other =>
Some(add)
}
}
} else {
// Adding to the root. Just need to normalize position
Some(addColumn(schema, "root", Nil))
}
case typeChange: UpdateColumnType =>
// Hive style syntax provides the column type, even if it may not have changed
val fieldOpt = schema.findNestedField(
typeChange.fieldNames(), includeCollections = true, conf.resolver)
if (fieldOpt.isEmpty) {
// We couldn't resolve the field. Leave it to CheckAnalysis
Some(typeChange)
} else {
val (fieldNames, field) = fieldOpt.get
if (field.dataType == typeChange.newDataType()) {
// The user didn't want the field to change, so remove this change
None
} else {
Some(TableChange.updateColumnType(
(fieldNames :+ field.name).toArray, typeChange.newDataType()))
}
}
case n: UpdateColumnNullability =>
// Need to resolve column
resolveFieldNames(
schema,
n.fieldNames(),
TableChange.updateColumnNullability(_, n.nullable())).orElse(Some(n))
case position: UpdateColumnPosition =>
position.position() match {
case after: After =>
// Need to resolve column as well as position reference
val fieldOpt = schema.findNestedField(
position.fieldNames(), includeCollections = true, conf.resolver)
if (fieldOpt.isEmpty) {
Some(position)
} else {
val (normalizedPath, field) = fieldOpt.get
val targetCol = schema.findNestedField(
normalizedPath :+ after.column(), includeCollections = true, conf.resolver)
if (targetCol.isEmpty) {
// Leave unchanged to CheckAnalysis
Some(position)
} else {
Some(TableChange.updateColumnPosition(
(normalizedPath :+ field.name).toArray,
ColumnPosition.after(targetCol.get._2.name)))
}
}
case _ =>
// Need to resolve column
resolveFieldNames(
schema,
position.fieldNames(),
TableChange.updateColumnPosition(_, position.position())).orElse(Some(position))
}
case comment: UpdateColumnComment =>
resolveFieldNames(
schema,
comment.fieldNames(),
TableChange.updateColumnComment(_, comment.newComment())).orElse(Some(comment))
case rename: RenameColumn =>
resolveFieldNames(
schema,
rename.fieldNames(),
TableChange.renameColumn(_, rename.newName())).orElse(Some(rename))
case delete: DeleteColumn =>
resolveFieldNames(schema, delete.fieldNames(), TableChange.deleteColumn)
.orElse(Some(delete))
case column: ColumnChange =>
// This is informational for future developers
throw new UnsupportedOperationException(
"Please add an implementation for a column change here")
case other => Some(other)
}
a.copy(changes = normalizedChanges)
}
/**
* Returns the table change if the field can be resolved; returns None if the column is not
* found. An error will be thrown in CheckAnalysis for columns that can't be resolved.
*/
private def resolveFieldNames(
schema: StructType,
fieldNames: Array[String],
copy: Array[String] => TableChange): Option[TableChange] = {
val fieldOpt = schema.findNestedField(
fieldNames, includeCollections = true, conf.resolver)
fieldOpt.map { case (path, field) => copy((path :+ field.name).toArray) }
}
private def findColumnPosition(
position: ColumnPosition,
parentName: String,
struct: StructType,
fieldsAdded: Seq[String]): ColumnPosition = {
position match {
case null => null
case after: After =>
(struct.fieldNames ++ fieldsAdded).find(n => conf.resolver(n, after.column())) match {
case Some(colName) =>
ColumnPosition.after(colName)
case None =>
throw new AnalysisException("Couldn't find the reference column for " +
s"$after at $parentName")
}
case other => other
}
}
}
}
/**
* Removes [[SubqueryAlias]] operators from the plan. Subqueries are only required to provide
* scoping information for attributes and can be removed once analysis is complete.
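*
* For example (a sketch):
* {{{
*   SubqueryAlias("t", Project([a], relation))  =>  Project([a], relation)
* }}}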
*/
object EliminateSubqueryAliases extends Rule[LogicalPlan] {
// This is also called in the beginning of the optimization phase, and as a result
// is using transformUp rather than resolveOperators.
def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer {
plan transformUp {
case SubqueryAlias(_, child) => child
}
}
}
/**
* Removes a [[Union]] operator from the plan if it has just one child.
*/
object EliminateUnions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: Union if u.children.size == 1 => u.children.head
}
}
/**
* Cleans up unnecessary Aliases inside the plan. Basically we only need an Alias as a
* top-level expression in a Project (project list), Aggregate (aggregate expressions) or
* Window (window expressions). Note that if an expression has other expression parameters
* which are not in its `children`, e.g. `RuntimeReplaceable`, the transformation for Aliases
* in this rule can't work for those parameters.
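*
* For example (a sketch): aliases inside grouping expressions are unnecessary and trimmed,
* while the top-level alias in the aggregate list is kept:
* {{{
*   Aggregate([Alias(a, "x")], [Alias(a, "x")], child)
*     =>  Aggregate([a], [Alias(a, "x")], child)
* }}}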
*/
object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case Project(projectList, child) =>
val cleanedProjectList = projectList.map(trimNonTopLevelAliases)
Project(cleanedProjectList, child)
case Aggregate(grouping, aggs, child) =>
val cleanedAggs = aggs.map(trimNonTopLevelAliases)
Aggregate(grouping.map(trimAliases), cleanedAggs, child)
case Window(windowExprs, partitionSpec, orderSpec, child) =>
val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases)
Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
case CollectMetrics(name, metrics, child) =>
val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
CollectMetrics(name, cleanedMetrics, child)
// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
case o: ObjectConsumer => o
case o: ObjectProducer => o
case a: AppendColumns => a
case other =>
other transformExpressionsDown {
case Alias(child, _) => child
}
}
}
/**
* Ignores the event time watermark in batch queries, since watermarks are only supported in
* Structured Streaming.
* TODO: add this rule to the analyzer rule list.
*/
object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case EventTimeWatermark(_, _, child) if !child.isStreaming => child
}
}
/**
* Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
* figure out how many windows a time column can map to, we over-estimate the number of windows and
* filter out the rows where the time column is not inside the time window.
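*
* For example (a sketch in SQL; the table and column names are hypothetical):
* {{{
*   SELECT window.start, window.end, count(*)
*   FROM t GROUP BY window(ts, '10 minutes', '5 minutes')
* }}}
* Each row of `t` is expanded into the ceil(10 / 5) = 2 sliding windows its `ts` may fall
* into, and rows whose `ts` lies outside a generated window are filtered out.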
*/
object TimeWindowing extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.dsl.expressions._
private final val WINDOW_COL_NAME = "window"
private final val WINDOW_START = "start"
private final val WINDOW_END = "end"
/**
* Generates the logical plan for generating window ranges on a timestamp column. Without
* knowing what the timestamp value is, it's non-trivial to figure out deterministically how many
* window ranges a timestamp will map to given all possible combinations of a window duration,
* slide duration and start time (offset). Therefore, we over-estimate the number of windows
* there may be, and filter out the invalid ones afterwards. We use the last Project operator
* to group the window columns into a struct so they can be accessed as `window.start` and
* `window.end`.
*
* The windows are calculated as below:
* maxNumOverlapping <- ceil(windowDuration / slideDuration)
* for (i <- 0 until maxNumOverlapping)
* windowId <- ceil((timestamp - startTime) / slideDuration)
* windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime
* windowEnd <- windowStart + windowDuration
* return windowStart, windowEnd
*
* This behaves as follows for the given parameters and the time 12:05. The valid windows are
* marked with a +, and invalid ones are marked with an x. The invalid ones are filtered out
* using the Filter operator.
* window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m
* 11:55 - 12:07 + 11:52 - 12:04 x
* 12:00 - 12:12 + 11:57 - 12:09 +
* 12:05 - 12:17 + 12:02 - 12:14 +
*
* @param plan The logical plan
* @return the logical plan that will generate the time windows using the Expand operator, with
* the Filter operator for correctness and Project for usability.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case p: LogicalPlan if p.children.size == 1 =>
val child = p.children.head
val windowExpressions =
p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet
val numWindowExpr = windowExpressions.size
// Only support a single window expression for now
if (numWindowExpr == 1 &&
windowExpressions.head.timeColumn.resolved &&
windowExpressions.head.checkInputDataTypes().isSuccess) {
val window = windowExpressions.head
val metadata = window.timeColumn match {
case a: Attribute => a.metadata
case _ => Metadata.empty
}
def getWindow(i: Int, overlappingWindows: Int): Expression = {
val division = (PreciseTimestampConversion(
window.timeColumn, TimestampType, LongType) - window.startTime) / window.slideDuration
val ceil = Ceil(division)
// if the division is equal to the ceiling, our record is the start of a window
val windowId = CaseWhen(Seq((ceil === division, ceil + 1)), Some(ceil))
val windowStart = (windowId + i - overlappingWindows) *
window.slideDuration + window.startTime
val windowEnd = windowStart + window.windowDuration
CreateNamedStruct(
Literal(WINDOW_START) ::
PreciseTimestampConversion(windowStart, LongType, TimestampType) ::
Literal(WINDOW_END) ::
PreciseTimestampConversion(windowEnd, LongType, TimestampType) ::
Nil)
}
val windowAttr = AttributeReference(
WINDOW_COL_NAME, window.dataType, metadata = metadata)()
if (window.windowDuration == window.slideDuration) {
val windowStruct = Alias(getWindow(0, 1), WINDOW_COL_NAME)(
exprId = windowAttr.exprId, explicitMetadata = Some(metadata))
val replacedPlan = p transformExpressions {
case t: TimeWindow => windowAttr
}
// For backwards compatibility we add a filter to filter out nulls
val filterExpr = IsNotNull(window.timeColumn)
replacedPlan.withNewChildren(
Filter(filterExpr,
Project(windowStruct +: child.output, child)) :: Nil)
} else {
val overlappingWindows =
math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
val windows =
Seq.tabulate(overlappingWindows)(i => getWindow(i, overlappingWindows))
val projections = windows.map(_ +: child.output)
val filterExpr =
window.timeColumn >= windowAttr.getField(WINDOW_START) &&
window.timeColumn < windowAttr.getField(WINDOW_END)
val substitutedPlan = Filter(filterExpr,
Expand(projections, windowAttr +: child.output, child))
val renamedPlan = p transformExpressions {
case t: TimeWindow => windowAttr
}
renamedPlan.withNewChildren(substitutedPlan :: Nil)
}
} else if (numWindowExpr > 1) {
p.failAnalysis("Multiple time window expressions would result in a cartesian product " +
"of rows, therefore they are currently not supported.")
} else {
p // Return unchanged. Analyzer will throw exception later
}
}
}
/**
* Resolve a [[CreateNamedStruct]] if it contains [[NamePlaceholder]]s.
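*
* For example (a sketch): `struct(a, b)` is parsed with placeholders for its field names,
* which this rule fills in from the resolved named expressions:
* {{{
*   CreateNamedStruct([NamePlaceholder, a, NamePlaceholder, b])
*     =>  CreateNamedStruct([Literal("a"), a, Literal("b"), b])
* }}}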
*/
object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
case e: CreateNamedStruct if !e.resolved =>
val children = e.children.grouped(2).flatMap {
case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
Seq(Literal(e.name), e)
case kv =>
kv
}
CreateNamedStruct(children.toList)
}
}
/**
* Aggregate expressions from a subquery that reference the outer query block are pushed
* down to the outer query block for evaluation. The rule below updates such outer references
* to AttributeReferences that refer to attributes from the parent/outer query block.
*
* For example (SQL):
* {{{
* SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b))
* }}}
* Plan before the rule.
* Project [a#226]
* +- Filter exists#245 [min(b#227)#249]
* : +- Project [1 AS 1#247]
* : +- Filter (d#238 < min(outer(b#227))) <-----
* : +- SubqueryAlias r
* : +- Project [_1#234 AS c#237, _2#235 AS d#238]
* : +- LocalRelation [_1#234, _2#235]
* +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
* +- SubqueryAlias l
* +- Project [_1#223 AS a#226, _2#224 AS b#227]
* +- LocalRelation [_1#223, _2#224]
* Plan after the rule.
* Project [a#226]
* +- Filter exists#245 [min(b#227)#249]
* : +- Project [1 AS 1#247]
* : +- Filter (d#238 < outer(min(b#227)#249)) <-----
* : +- SubqueryAlias r
* : +- Project [_1#234 AS c#237, _2#235 AS d#238]
* : +- LocalRelation [_1#234, _2#235]
* +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
* +- SubqueryAlias l
* +- Project [_1#223 AS a#226, _2#224 AS b#227]
* +- LocalRelation [_1#223, _2#224]
*/
object UpdateOuterReferences extends Rule[LogicalPlan] {
private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child }
private def updateOuterReferenceInSubquery(
plan: LogicalPlan,
refExprs: Seq[Expression]): LogicalPlan = {
plan resolveExpressions { case e =>
val outerAlias =
refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e)))
outerAlias match {
case Some(a: Alias) => OuterReference(a.toAttribute)
case _ => e
}
}
}
def apply(plan: LogicalPlan): LogicalPlan = {
plan resolveOperators {
case f @ Filter(_, a: Aggregate) if f.resolved =>
f transformExpressions {
case s: SubqueryExpression if s.children.nonEmpty =>
// Collect the aliases from output of aggregate.
val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
// Update the subquery plan to record the OuterReference to point to outer query plan.
s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases))
}
}
}
}