| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.sql.catalyst.analysis |
| |
| import java.util |
| import java.util.Locale |
| import java.util.concurrent.atomic.AtomicBoolean |
| |
| import scala.collection.mutable |
| import scala.collection.mutable.ArrayBuffer |
| import scala.util.{Failure, Random, Success, Try} |
| |
| import org.apache.spark.{SparkException, SparkUnsupportedOperationException} |
| import org.apache.spark.sql.AnalysisException |
| import org.apache.spark.sql.catalyst._ |
| import org.apache.spark.sql.catalyst.catalog._ |
| import org.apache.spark.sql.catalyst.encoders.OuterScopes |
| import org.apache.spark.sql.catalyst.expressions._ |
| import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ |
| import org.apache.spark.sql.catalyst.expressions.aggregate._ |
| import org.apache.spark.sql.catalyst.expressions.objects._ |
| import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields |
| import org.apache.spark.sql.catalyst.plans._ |
| import org.apache.spark.sql.catalyst.plans.logical._ |
| import org.apache.spark.sql.catalyst.rules._ |
| import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 |
| import org.apache.spark.sql.catalyst.trees.AlwaysProcess |
| import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin |
| import org.apache.spark.sql.catalyst.trees.TreePattern._ |
| import org.apache.spark.sql.catalyst.types.DataTypeUtils |
| import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils} |
| import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ |
| import org.apache.spark.sql.connector.catalog.{View => _, _} |
| import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ |
| import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition} |
| import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction} |
| import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} |
| import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} |
| import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation |
| import org.apache.spark.sql.internal.SQLConf |
| import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy} |
| import org.apache.spark.sql.internal.connector.V1Function |
| import org.apache.spark.sql.types._ |
| import org.apache.spark.sql.types.DayTimeIntervalType.DAY |
| import org.apache.spark.sql.util.CaseInsensitiveStringMap |
| import org.apache.spark.util.ArrayImplicits._ |
| |
| /** |
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]]. Used for testing when all
* relations are already filled in and the analyzer needs only to resolve attribute
* references.
*
* The built-in function and table function registries are set so that the Spark Connect
* project can test unresolved functions.
| */ |
| object SimpleAnalyzer extends Analyzer( |
| new CatalogManager( |
| FakeV2SessionCatalog, |
| new SessionCatalog( |
| new InMemoryCatalog, |
| FunctionRegistry.builtin, |
| TableFunctionRegistry.builtin) { |
| override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {} |
| })) { |
| override def resolver: Resolver = caseSensitiveResolution |
| } |
| |
| object FakeV2SessionCatalog extends TableCatalog with FunctionCatalog { |
| private def fail() = throw SparkUnsupportedOperationException() |
| override def listTables(namespace: Array[String]): Array[Identifier] = fail() |
| override def loadTable(ident: Identifier): Table = { |
| throw new NoSuchTableException(ident.asMultipartIdentifier) |
| } |
| override def createTable( |
| ident: Identifier, |
| schema: StructType, |
| partitions: Array[Transform], |
| properties: util.Map[String, String]): Table = fail() |
| override def alterTable(ident: Identifier, changes: TableChange*): Table = fail() |
| override def dropTable(ident: Identifier): Boolean = fail() |
| override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = fail() |
| override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = fail() |
| override def name(): String = CatalogManager.SESSION_CATALOG_NAME |
| override def listFunctions(namespace: Array[String]): Array[Identifier] = fail() |
| override def loadFunction(ident: Identifier): UnboundFunction = fail() |
| } |
| |
| /** |
| * Provides a way to keep state during the analysis, mostly for resolving views and subqueries. |
* This enables us to decouple the concerns of the analysis environment from the catalog, and to
* resolve star expressions in subqueries that reference the outer query plans.
| * The state that is kept here is per-query. |
| * |
| * Note this is thread local. |
| * |
| * @param catalogAndNamespace The catalog and namespace used in the view resolution. This overrides |
| * the current catalog and namespace when resolving relations inside |
| * views. |
* @param nestedViewDepth The nested depth in the view resolution. This enables us to limit
*                        the depth of nested views.
| * @param maxNestedViewDepth The maximum allowed depth of nested view resolution. |
* @param relationCache A mapping from qualified table names and time travel spec to resolved
*                      relations. This ensures that a table is resolved only once even if it
*                      is used multiple times in a query.
| * @param referredTempViewNames All the temp view names referred by the current view we are |
| * resolving. It's used to make sure the relation resolution is |
| * consistent between view creation and view resolution. For example, |
| * if `t` was a permanent table when the current view was created, it |
| * should still be a permanent table when resolving the current view, |
| * even if a temp view `t` has been created. |
| * @param outerPlan The query plan from the outer query that can be used to resolve star |
| * expressions in a subquery. |
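*
* A minimal usage sketch (illustrative; `viewDesc` is a [[CatalogTable]] describing a view,
* mirroring how `ResolveRelations.resolveViews` uses it):
* {{{
*   AnalysisContext.withAnalysisContext(viewDesc) {
*     // Relation lookups in here use the view's catalog/namespace and its recorded
*     // temp view/function names instead of the current session's.
*     executeSameContext(viewChild)
*   }
* }}}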
| */ |
| case class AnalysisContext( |
| catalogAndNamespace: Seq[String] = Nil, |
| nestedViewDepth: Int = 0, |
| maxNestedViewDepth: Int = -1, |
| relationCache: mutable.Map[(Seq[String], Option[TimeTravelSpec]), LogicalPlan] = |
| mutable.Map.empty, |
| referredTempViewNames: Seq[Seq[String]] = Seq.empty, |
// 1. If we are resolving a view, this field will be restored from the view metadata,
//    by calling `AnalysisContext.withAnalysisContext(viewDesc)`.
// 2. If we are not resolving a view, this field will be updated every time the analyzer
//    looks up a temporary function, and is then exported to the view metadata.
| referredTempFunctionNames: mutable.Set[String] = mutable.Set.empty, |
| referredTempVariableNames: Seq[Seq[String]] = Seq.empty, |
| outerPlan: Option[LogicalPlan] = None) |
| |
| object AnalysisContext { |
| private val value = new ThreadLocal[AnalysisContext]() { |
| override def initialValue: AnalysisContext = AnalysisContext() |
| } |
| |
| def get: AnalysisContext = value.get() |
| def reset(): Unit = value.remove() |
| |
| private def set(context: AnalysisContext): Unit = value.set(context) |
| |
| def withAnalysisContext[A](viewDesc: CatalogTable)(f: => A): A = { |
| val originContext = value.get() |
| val maxNestedViewDepth = if (originContext.maxNestedViewDepth == -1) { |
| // Here we start to resolve views, get `maxNestedViewDepth` from configs. |
| SQLConf.get.maxNestedViewDepth |
| } else { |
| originContext.maxNestedViewDepth |
| } |
| val context = AnalysisContext( |
| viewDesc.viewCatalogAndNamespace, |
| originContext.nestedViewDepth + 1, |
| maxNestedViewDepth, |
| originContext.relationCache, |
| viewDesc.viewReferredTempViewNames, |
| mutable.Set(viewDesc.viewReferredTempFunctionNames: _*), |
| viewDesc.viewReferredTempVariableNames) |
| set(context) |
| try f finally { set(originContext) } |
| } |
| |
| def withNewAnalysisContext[A](f: => A): A = { |
| val originContext = value.get() |
| reset() |
| try f finally { set(originContext) } |
| } |
| |
| def withOuterPlan[A](outerPlan: LogicalPlan)(f: => A): A = { |
| val originContext = value.get() |
| val context = originContext.copy(outerPlan = Some(outerPlan)) |
| set(context) |
| try f finally { set(originContext) } |
| } |
| } |
| |
| /** |
| * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and |
| * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. |
| */ |
| class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor[LogicalPlan] |
| with CheckAnalysis with SQLConfHelper with ColumnResolutionHelper { |
| |
| private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog |
| |
| override protected def validatePlanChanges( |
| previousPlan: LogicalPlan, |
| currentPlan: LogicalPlan): Option[String] = { |
| LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan) |
| } |
| |
| override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts) |
| |
| // Only for tests. |
| def this(catalog: SessionCatalog) = { |
| this(new CatalogManager(FakeV2SessionCatalog, catalog)) |
| } |
| |
| def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { |
| if (plan.analyzed) return plan |
| AnalysisHelper.markInAnalyzer { |
| val analyzed = executeAndTrack(plan, tracker) |
| checkAnalysis(analyzed) |
| analyzed |
| } |
| } |
| |
| override def execute(plan: LogicalPlan): LogicalPlan = { |
| AnalysisContext.withNewAnalysisContext { |
| executeSameContext(plan) |
| } |
| } |
| |
| private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) |
| |
| def resolver: Resolver = conf.resolver |
| |
| /** |
* If the plan cannot be resolved within maxIterations, the analyzer will throw an exception to
* inform the user to increase the value of SQLConf.ANALYZER_MAX_ITERATIONS.
| */ |
| protected def fixedPoint = |
| FixedPoint( |
| conf.analyzerMaxIterations, |
| errorOnExceed = true, |
| maxIterationsSetting = SQLConf.ANALYZER_MAX_ITERATIONS.key) |
| |
| /** |
| * Override to provide additional rules for the "Resolution" batch. |
| */ |
| val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil |
| |
| /** |
| * Override to provide rules to do post-hoc resolution. Note that these rules will be executed |
| * in an individual batch. This batch is to run right after the normal resolution batch and |
| * execute its rules in one pass. |
| */ |
| val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil |
| |
| private def typeCoercionRules(): List[Rule[LogicalPlan]] = if (conf.ansiEnabled) { |
| AnsiTypeCoercion.typeCoercionRules |
| } else { |
| TypeCoercion.typeCoercionRules |
| } |
| |
| override def batches: Seq[Batch] = Seq( |
| Batch("Substitution", fixedPoint, |
| new SubstituteExecuteImmediate(catalogManager), |
// This rule optimizes `UpdateFields` expression chains, so it looks more like an optimization
// rule. However, when manipulating deeply nested schemas, the `UpdateFields` expression tree
// can become very complex and make analysis impossible. Thus we need to optimize
// `UpdateFields` early, at the beginning of analysis.
| OptimizeUpdateFields, |
| CTESubstitution, |
| WindowsSubstitution, |
| EliminateUnions, |
| SubstituteUnresolvedOrdinals), |
| Batch("Disable Hints", Once, |
| new ResolveHints.DisableHints), |
| Batch("Hints", fixedPoint, |
| ResolveHints.ResolveJoinStrategyHints, |
| ResolveHints.ResolveCoalesceHints), |
| Batch("Simple Sanity Check", Once, |
| LookupFunctions), |
| Batch("Keep Legacy Outputs", Once, |
| KeepLegacyOutputs), |
| Batch("Resolution", fixedPoint, |
| new ResolveCatalogs(catalogManager) :: |
| ResolveInsertInto :: |
| ResolveRelations :: |
| ResolvePartitionSpec :: |
| ResolveFieldNameAndPosition :: |
| AddMetadataColumns :: |
| DeduplicateRelations :: |
| new ResolveReferences(catalogManager) :: |
| // Please do not insert any other rules in between. See the TODO comments in rule |
| // ResolveLateralColumnAliasReference for more details. |
| ResolveLateralColumnAliasReference :: |
| ResolveExpressionsWithNamePlaceholders :: |
| ResolveDeserializer :: |
| ResolveNewInstance :: |
| ResolveUpCast :: |
| ResolveGroupingAnalytics :: |
| ResolvePivot :: |
| ResolveUnpivot :: |
| ResolveOrdinalInOrderByAndGroupBy :: |
| ExtractGenerator :: |
| ResolveGenerate :: |
| ResolveFunctions :: |
| ResolveTableSpec :: |
| ResolveAliases :: |
| ResolveSubquery :: |
| ResolveSubqueryColumnAliases :: |
| ResolveWindowOrder :: |
| ResolveWindowFrame :: |
| ResolveNaturalAndUsingJoin :: |
| ResolveOutputRelation :: |
| new ResolveDataFrameDropColumns(catalogManager) :: |
| new ResolveSetVariable(catalogManager) :: |
| ExtractWindowExpressions :: |
| GlobalAggregates :: |
| ResolveAggregateFunctions :: |
| TimeWindowing :: |
| SessionWindowing :: |
| ResolveWindowTime :: |
| ResolveInlineTables :: |
| ResolveLambdaVariables :: |
| ResolveTimeZone :: |
| ResolveRandomSeed :: |
| ResolveBinaryArithmetic :: |
| ResolveIdentifierClause :: |
| ResolveUnion :: |
| ResolveRowLevelCommandAssignments :: |
| RewriteDeleteFromTable :: |
| RewriteUpdateTable :: |
| RewriteMergeIntoTable :: |
| MoveParameterizedQueriesDown :: |
| BindParameters :: |
| typeCoercionRules() ++ |
| Seq( |
| ResolveWithCTE, |
| ExtractDistributedSequenceID) ++ |
| Seq(ResolveUpdateEventTimeWatermarkColumn) ++ |
| extendedResolutionRules : _*), |
| Batch("Remove TempResolvedColumn", Once, RemoveTempResolvedColumn), |
| Batch("Post-Hoc Resolution", Once, |
| Seq(ResolveCommandsWithIfExists) ++ |
| postHocResolutionRules: _*), |
| Batch("Remove Unresolved Hints", Once, |
| new ResolveHints.RemoveAllHints), |
| Batch("Nondeterministic", Once, |
| PullOutNondeterministic), |
| Batch("UpdateNullability", Once, |
| UpdateAttributeNullability), |
| Batch("UDF", Once, |
| HandleNullInputsForUDF, |
| ResolveEncodersInUDF), |
| Batch("Subquery", Once, |
| UpdateOuterReferences), |
| Batch("Cleanup", fixedPoint, |
| CleanupAliases), |
| Batch("HandleSpecialCommand", Once, |
| HandleSpecialCommand), |
| Batch("Remove watermark for batch query", Once, |
| EliminateEventTimeWatermark) |
| ) |
| |
| /** |
| * For [[Add]]: |
* 1. if both sides are intervals, stays the same;
| * 2. else if one side is date and the other is interval, |
| * turns it to [[DateAddInterval]]; |
| * 3. else if one side is interval, turns it to [[TimeAdd]]; |
* 4. else if one side is date, turns it to [[DateAdd]];
| * 5. else stays the same. |
| * |
| * For [[Subtract]]: |
* 1. if both sides are intervals, stays the same;
| * 2. else if the left side is date and the right side is interval, |
| * turns it to [[DateAddInterval(l, -r)]]; |
| * 3. else if the right side is an interval, turns it to [[TimeAdd(l, -r)]]; |
| * 4. else if one side is timestamp, turns it to [[SubtractTimestamps]]; |
| * 5. else if the right side is date, turns it to [[DateDiff]]/[[SubtractDates]]; |
| * 6. else if the left side is date, turns it to [[DateSub]]; |
* 7. else stays the same.
| * |
| * For [[Multiply]]: |
| * 1. If one side is interval, turns it to [[MultiplyInterval]]; |
| * 2. otherwise, stays the same. |
| * |
| * For [[Divide]]: |
| * 1. If the left side is interval, turns it to [[DivideInterval]]; |
| * 2. otherwise, stays the same. |
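*
* Illustrative rewrites (the exact replacement depends on the operand types and the
* evaluation mode; `date_col`, `ts_col` and `dt_interval_col` are example columns):
* {{{
*   date_col + INTERVAL '3' DAY  ==>  DateAdd(date_col, ExtractANSIIntervalDays(...))
*   ts_col - INTERVAL '2' HOUR   ==>  DatetimeSub(..., TimeAdd(ts_col, -INTERVAL '2' HOUR))
*   dt_interval_col * 2          ==>  MultiplyDTInterval(dt_interval_col, 2)
* }}}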
| */ |
| object ResolveBinaryArithmetic extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = |
| plan.resolveExpressionsUpWithPruning(_.containsPattern(BINARY_ARITHMETIC), ruleId) { |
| case a @ Add(l, r, mode) if a.childrenResolved => (l.dataType, r.dataType) match { |
| case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r)) |
| case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r) |
| case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l)) |
| case (_: DayTimeIntervalType, DateType) => TimeAdd(Cast(r, TimestampType), l) |
| case (DateType, _: YearMonthIntervalType) => DateAddYMInterval(l, r) |
| case (_: YearMonthIntervalType, DateType) => DateAddYMInterval(r, l) |
| case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => |
| TimestampAddYMInterval(l, r) |
| case (_: YearMonthIntervalType, TimestampType | TimestampNTZType) => |
| TimestampAddYMInterval(r, l) |
| case (CalendarIntervalType, CalendarIntervalType) | |
| (_: DayTimeIntervalType, _: DayTimeIntervalType) => a |
| case (_: NullType, _: AnsiIntervalType) => |
| a.copy(left = Cast(a.left, a.right.dataType)) |
| case (_: AnsiIntervalType, _: NullType) => |
| a.copy(right = Cast(a.right, a.left.dataType)) |
| case (DateType, CalendarIntervalType) => |
| DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI) |
| case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType) |
| case (CalendarIntervalType, DateType) => |
| DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI) |
| case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType) |
| case (DateType, dt) if dt != StringType => DateAdd(l, r) |
| case (dt, DateType) if dt != StringType => DateAdd(r, l) |
| case _ => a |
| } |
| case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, r.dataType) match { |
| case (DateType, DayTimeIntervalType(DAY, DAY)) => |
| DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI)) |
| case (DateType, _: DayTimeIntervalType) => |
| DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI))) |
| case (DateType, _: YearMonthIntervalType) => |
| DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))) |
| case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => |
| DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))) |
| case (CalendarIntervalType, CalendarIntervalType) | |
| (_: DayTimeIntervalType, _: DayTimeIntervalType) => s |
| case (_: NullType, _: AnsiIntervalType) => |
| s.copy(left = Cast(s.left, s.right.dataType)) |
| case (_: AnsiIntervalType, _: NullType) => |
| s.copy(right = Cast(s.right, s.left.dataType)) |
| case (DateType, CalendarIntervalType) => |
| DatetimeSub(l, r, DateAddInterval(l, |
| UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == EvalMode.ANSI)) |
| case (_, CalendarIntervalType | _: DayTimeIntervalType) => |
| Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType) |
| case _ if AnyTimestampTypeExpression.unapply(l) || |
| AnyTimestampTypeExpression.unapply(r) => SubtractTimestamps(l, r) |
| case (_, DateType) => SubtractDates(l, r) |
| case (DateType, dt) if dt != StringType => DateSub(l, r) |
| case _ => s |
| } |
| case m @ Multiply(l, r, mode) if m.childrenResolved => (l.dataType, r.dataType) match { |
| case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI) |
| case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI) |
| case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r) |
| case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l) |
| case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r) |
| case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l) |
| case _ => m |
| } |
| case d @ Divide(l, r, mode) if d.childrenResolved => (l.dataType, r.dataType) match { |
| case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI) |
| case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r) |
| case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r) |
| case _ => d |
| } |
| } |
| } |
| |
| /** |
* Substitutes [[UnresolvedWindowExpression]]s in the child plan with [[WindowExpression]]s,
* using the [[WindowSpecDefinition]]s declared in the plan's WINDOW clause.
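*
* For example (illustrative), in
* {{{
*   SELECT sum(a) OVER w FROM t WINDOW w AS (PARTITION BY b ORDER BY c)
* }}}
* the reference `w` is replaced by its definition `(PARTITION BY b ORDER BY c)`.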
| */ |
| object WindowsSubstitution extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsAnyPattern(WITH_WINDOW_DEFINITION, UNRESOLVED_WINDOW_EXPRESSION), ruleId) { |
| // Lookup WindowSpecDefinitions. This rule works with unresolved children. |
| case WithWindowDefinition(windowDefinitions, child) => child.resolveExpressions { |
| case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => |
| val windowSpecDefinition = windowDefinitions.getOrElse(windowName, |
| throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowName)) |
| WindowExpression(c, windowSpecDefinition) |
| } |
| } |
| } |
| |
| /** |
| * Replaces [[UnresolvedAlias]]s with concrete aliases. |
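*
* For example (illustrative), in `SELECT a + 1 FROM t` the expression `a + 1` receives the
* auto-generated alias `(a + 1)` (via `toPrettySQL`), tagged with the `AUTO_GENERATED_ALIAS`
* metadata key.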
| */ |
| object ResolveAliases extends Rule[LogicalPlan] { |
| private def assignAliases(exprs: Seq[NamedExpression]) = { |
| def extractOnly(e: Expression): Boolean = e match { |
| case _: ExtractValue => e.children.forall(extractOnly) |
| case _: Literal => true |
| case _: Attribute => true |
| case _ => false |
| } |
| exprs.map(_.transformUpWithPruning(_.containsPattern(UNRESOLVED_ALIAS)) { |
| case u @ UnresolvedAlias(child, optGenAliasFunc) => |
| child match { |
| case ne: NamedExpression => ne |
| case go @ GeneratorOuter(g: Generator) if g.resolved => MultiAlias(go, Nil) |
| case e if !e.resolved => u |
| case g: Generator => MultiAlias(g, Nil) |
| case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)() |
| case e: ExtractValue if extractOnly(e) => Alias(e, toPrettySQL(e))() |
| case e if optGenAliasFunc.isDefined => |
| Alias(child, optGenAliasFunc.get.apply(e))() |
| case l: Literal => Alias(l, toPrettySQL(l))() |
| case e => |
| val metaForAutoGeneratedAlias = new MetadataBuilder() |
| .putString(AUTO_GENERATED_ALIAS, "true") |
| .build() |
| Alias(e, toPrettySQL(e))(explicitMetadata = Some(metaForAutoGeneratedAlias)) |
| } |
| } |
| ).asInstanceOf[Seq[NamedExpression]] |
| } |
| |
| private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = |
| exprs.exists(_.exists(_.isInstanceOf[UnresolvedAlias])) |
| |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(UNRESOLVED_ALIAS), ruleId) { |
| case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => |
| Aggregate(groups, assignAliases(aggs), child) |
| |
| case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child) |
| if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) => |
| Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child) |
| |
| case up: Unpivot if up.child.resolved && |
| (up.ids.exists(hasUnresolvedAlias) || up.values.exists(_.exists(hasUnresolvedAlias))) => |
| up.copy(ids = up.ids.map(assignAliases), values = up.values.map(_.map(assignAliases))) |
| |
| case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => |
| Project(assignAliases(projectList), child) |
| |
| case c: CollectMetrics if c.child.resolved && hasUnresolvedAlias(c.metrics) => |
| c.copy(metrics = assignAliases(c.metrics)) |
| } |
| } |
| |
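/**
 * Resolves CUBE/ROLLUP/GROUPING SETS into [[Expand]] + [[Aggregate]], and rewrites
 * `grouping()`/`grouping_id()` calls into bit operations on the internal grouping id column.
 *
 * For example (illustrative), with `GROUP BY a, b GROUPING SETS ((a, b), (a))`, `grouping(b)`
 * is rewritten to `cast((grouping_id >> 0) & 1 as byte)`, since `b` is the last of the two
 * grouping columns.
 */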
| object ResolveGroupingAnalytics extends Rule[LogicalPlan] { |
| private[analysis] def hasGroupingFunction(e: Expression): Boolean = { |
| e.exists (g => g.isInstanceOf[Grouping] || g.isInstanceOf[GroupingID]) |
| } |
| |
| private def replaceGroupingFunc( |
| expr: Expression, |
| groupByExprs: Seq[Expression], |
| gid: Expression): Expression = { |
| expr transform { |
| case e: GroupingID => |
| if (e.groupByExprs.isEmpty || |
| e.groupByExprs.map(_.canonicalized) == groupByExprs.map(_.canonicalized)) { |
| Alias(gid, toPrettySQL(e))() |
| } else { |
| throw QueryCompilationErrors.groupingIDMismatchError(e, groupByExprs) |
| } |
| case e @ Grouping(col: Expression) => |
| val idx = groupByExprs.indexWhere(_.semanticEquals(col)) |
| if (idx >= 0) { |
| Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), |
| Literal(1L)), ByteType), toPrettySQL(e))() |
| } else { |
| throw QueryCompilationErrors.groupingColInvalidError(col, groupByExprs) |
| } |
| } |
| } |
| |
| /* |
* Create new aliases for all group by expressions for the `Expand` operator.
| */ |
| private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = { |
| groupByExprs.map { |
| case e: NamedExpression => Alias(e, e.name)(qualifier = e.qualifier) |
| case other => Alias(other, other.toString)() |
| } |
| } |
| |
| /* |
| * Construct [[Expand]] operator with grouping sets. |
| */ |
| private def constructExpand( |
| selectedGroupByExprs: Seq[Seq[Expression]], |
| child: LogicalPlan, |
| groupByAliases: Seq[Alias], |
| gid: Attribute): LogicalPlan = { |
| // Change the nullability of group by aliases if necessary. For example, if we have |
| // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we |
| // should change the nullability of b to be TRUE. |
| // TODO: For Cube/Rollup just set nullability to be `true`. |
| val expandedAttributes = groupByAliases.map { alias => |
| if (selectedGroupByExprs.exists(!_.contains(alias.child))) { |
| alias.toAttribute.withNullability(true) |
| } else { |
| alias.toAttribute |
| } |
| } |
| |
| val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs => |
| groupingSetExprs.map { expr => |
| val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse( |
| throw QueryCompilationErrors.selectExprNotInGroupByError(expr, groupByAliases)) |
| // Map alias to expanded attribute. |
| expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse( |
| alias.toAttribute) |
| } |
| } |
| |
| Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child) |
| } |
| |
| /* |
| * Construct new aggregate expressions by replacing grouping functions. |
| */ |
| private def constructAggregateExprs( |
| groupByExprs: Seq[Expression], |
| aggregations: Seq[NamedExpression], |
| groupByAliases: Seq[Alias], |
| groupingAttrs: Seq[Expression], |
| gid: Attribute): Seq[NamedExpression] = { |
| def replaceExprs(e: Expression): Expression = e match { |
| case e: AggregateExpression => e |
| case e => |
| // Replace expression by expand output attribute. |
| val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) |
| if (index == -1) { |
| e.mapChildren(replaceExprs) |
| } else { |
| groupingAttrs(index) |
| } |
| } |
| aggregations |
| .map(replaceGroupingFunc(_, groupByExprs, gid)) |
| .map(replaceExprs) |
| .map(_.asInstanceOf[NamedExpression]) |
| } |
| |
| /* |
| * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets. |
| */ |
| private def constructAggregate( |
| selectedGroupByExprs: Seq[Seq[Expression]], |
| groupByExprs: Seq[Expression], |
| aggregationExprs: Seq[NamedExpression], |
| child: LogicalPlan): LogicalPlan = { |
| |
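// Each group-by expression occupies one bit in the grouping id, so the number of grouping
// columns is capped by the bit width of `GroupingID.dataType` (32 or 64 bits).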
| if (groupByExprs.size > GroupingID.dataType.defaultSize * 8) { |
| throw QueryCompilationErrors.groupingSizeTooLargeError(GroupingID.dataType.defaultSize * 8) |
| } |
| |
// Expand works by setting grouping expressions to null as determined by the
// `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
// instead of the original value, we need to create new aliases for all group by expressions
// that will only be used for the intended purpose.
| val groupByAliases = constructGroupByAlias(groupByExprs) |
| |
| val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)() |
| val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid) |
| val groupingAttrs = expand.output.drop(child.output.length) |
| |
| val aggregations = constructAggregateExprs( |
| groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid) |
| |
| Aggregate(groupingAttrs, aggregations, expand) |
| } |
| |
| private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { |
| plan.collectFirst { |
| case a: Aggregate => |
| // this Aggregate should have grouping id as the last grouping key. |
| val gid = a.groupingExpressions.last |
| if (!gid.isInstanceOf[AttributeReference] |
| || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) { |
| throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() |
| } |
| a.groupingExpressions.take(a.groupingExpressions.length - 1) |
| }.getOrElse { |
| throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() |
| } |
| } |
| |
| private def tryResolveHavingCondition( |
| h: UnresolvedHaving, |
| aggregate: Aggregate, |
| selectedGroupByExprs: Seq[Seq[Expression]], |
| groupByExprs: Seq[Expression]): LogicalPlan = { |
// For CUBE/ROLLUP expressions, to avoid resolving them repeatedly, we remove them from
// groupingExpressions when resolving the condition.
| val aggForResolving = aggregate.copy(groupingExpressions = groupByExprs) |
// HACK ALERT! Ideally we should only resolve GROUPING SETS + HAVING when the having condition
| // is fully resolved, similar to the rule `ResolveAggregateFunctions`. However, Aggregate |
| // with GROUPING SETS is marked as unresolved and many analyzer rules can't apply to |
| // UnresolvedHaving because its child is not resolved. Here we explicitly resolve columns |
| // and subqueries of UnresolvedHaving so that the rewrite works in most cases. |
| // TODO: mark Aggregate as resolved even if it has GROUPING SETS. We can expand it at the end |
| // of the analysis phase. |
| val colResolved = h.mapExpressions { e => |
| resolveExpressionByPlanOutput( |
| resolveColWithAgg(e, aggForResolving), aggForResolving, includeLastResort = true) |
| } |
| val cond = if (SubqueryExpression.hasSubquery(colResolved.havingCondition)) { |
| val fake = Project(Alias(colResolved.havingCondition, "fake")() :: Nil, aggregate.child) |
| ResolveSubquery(fake).asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child |
| } else { |
| colResolved.havingCondition |
| } |
| // Try resolving the condition of the filter as though it is in the aggregate clause |
| val (extraAggExprs, Seq(resolvedHavingCond)) = |
| ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(cond), aggForResolving) |
| |
| // Push the aggregate expressions into the aggregate (if any). |
| val newChild = constructAggregate(selectedGroupByExprs, groupByExprs, |
| aggregate.aggregateExpressions ++ extraAggExprs, aggregate.child) |
| |
| // Since the output exprId will be changed in the constructed aggregate, here we build an |
| // attrMap to resolve the condition again. |
| val attrMap = AttributeMap((aggForResolving.output ++ extraAggExprs.map(_.toAttribute)) |
| .zip(newChild.output)) |
| val newCond = resolvedHavingCond.transform { |
| case a: AttributeReference => attrMap.getOrElse(a, a) |
| } |
| |
| if (extraAggExprs.isEmpty) { |
| Filter(newCond, newChild) |
| } else { |
| Project(newChild.output.dropRight(extraAggExprs.length), |
| Filter(newCond, newChild)) |
| } |
| } |
| |
// This requires transformDown to resolve the having condition when generating the aggregate
// node for CUBE/ROLLUP/GROUPING SETS. It also replaces grouping()/grouping_id() in resolved
// Filter/Sort.
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDownWithPruning( |
| _.containsPattern(GROUPING_ANALYTICS), ruleId) { |
| case h @ UnresolvedHaving(_, agg @ Aggregate( |
| GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, _)) |
| if agg.childrenResolved && aggExprs.forall(_.resolved) => |
| tryResolveHavingCondition(h, agg, selectedGroupByExprs, groupByExprs) |
| |
| // Make sure all of the children are resolved. |
| // We can't put this at the beginning, because `Aggregate` with GROUPING SETS is unresolved |
| // but we need to resolve `UnresolvedHaving` above it. |
| case a if !a.childrenResolved => a |
| |
| // Ensure group by expressions and aggregate expressions have been resolved. |
| case Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, child) |
| if aggExprs.forall(_.resolved) => |
| constructAggregate(selectedGroupByExprs, groupByExprs, aggExprs, child) |
| |
| // We should make sure all expressions in condition have been resolved. |
| case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved => |
| val groupingExprs = findGroupingExprs(child) |
| // The unresolved grouping id will be resolved by ResolveReferences |
| val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) |
| f.copy(condition = newCond) |
| |
| // We should make sure all [[SortOrder]]s have been resolved. |
| case s @ Sort(order, _, child) |
| if order.exists(hasGroupingFunction) && order.forall(_.resolved) => |
| val groupingExprs = findGroupingExprs(child) |
| val gid = VirtualColumn.groupingIdAttribute |
| // The unresolved grouping id will be resolved by ResolveReferences |
| val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) |
| s.copy(order = newOrder) |
| } |
| } |
| |
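/**
 * Resolves [[Pivot]] by rewriting it into either a single [[Aggregate]] with one conditional
 * aggregate per (pivot value, aggregate) pair, or, when all aggregate data types are supported
 * by [[PivotFirst]], into two [[Aggregate]] steps plus a [[Project]].
 *
 * For example (illustrative):
 * {{{
 *   SELECT * FROM courseSales PIVOT (sum(earnings) FOR course IN ('dotNET', 'Java'))
 * }}}
 * produces one output column per pivot value, e.g. `dotNET` and `Java`.
 */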
| object ResolvePivot extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( |
| _.containsPattern(PIVOT), ruleId) { |
| case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) |
| || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) |
| || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p |
| case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => |
| if (!RowOrdering.isOrderable(pivotColumn.dataType)) { |
| throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) |
| } |
| // Check all aggregate expressions. |
| aggregates.foreach(checkValidAggregateExpression) |
| // Check all pivot values are literal and match pivot column data type. |
| val evalPivotValues = pivotValues.map { value => |
| val foldable = trimAliases(value).foldable |
| if (!foldable) { |
| throw QueryCompilationErrors.nonLiteralPivotValError(value) |
| } |
| if (!Cast.canCast(value.dataType, pivotColumn.dataType)) { |
| throw QueryCompilationErrors.pivotValDataTypeMismatchError(value, pivotColumn) |
| } |
| Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) |
| } |
| // Group-by expressions coming from SQL are implicit and need to be deduced. |
| val groupByExprs = groupByExprsOpt.getOrElse { |
| val pivotColAndAggRefs = pivotColumn.references ++ AttributeSet(aggregates) |
| child.output.filterNot(pivotColAndAggRefs.contains) |
| } |
| val singleAgg = aggregates.size == 1 |
| def outputName(value: Expression, aggregate: Expression): String = { |
| val stringValue = value match { |
| case n: NamedExpression => n.name |
| case _ => |
| val utf8Value = |
| Cast(value, StringType, Some(conf.sessionLocalTimeZone)).eval(EmptyRow) |
| Option(utf8Value).map(_.toString).getOrElse("null") |
| } |
| if (singleAgg) { |
| stringValue |
| } else { |
| val suffix = aggregate match { |
| case n: NamedExpression => n.name |
| case _ => toPrettySQL(aggregate) |
| } |
| stringValue + "_" + suffix |
| } |
| } |
| if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) { |
// Since evaluating |pivotValues| if statements for each input row can get slow, this is an
| // alternate plan that instead uses two steps of aggregation. |
| val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) |
| val namedPivotCol = pivotColumn match { |
| case n: NamedExpression => n |
| case _ => Alias(pivotColumn, "__pivot_col")() |
| } |
| val bigGroup = groupByExprs :+ namedPivotCol |
| val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) |
| val pivotAggs = namedAggExps.map { a => |
| Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues) |
| .toAggregateExpression() |
| , "__pivot_" + a.sql)() |
| } |
| val groupByExprsAttr = groupByExprs.map(_.toAttribute) |
| val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg) |
| val pivotAggAttribute = pivotAggs.map(_.toAttribute) |
| val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) => |
| aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) => |
| Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() |
| } |
| } |
| Project(groupByExprsAttr ++ pivotOutputs, secondAgg) |
| } else { |
| val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => |
| def ifExpr(e: Expression) = { |
| If( |
| EqualNullSafe( |
| pivotColumn, |
| Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))), |
| e, Literal(null)) |
| } |
| aggregates.map { aggregate => |
| val filteredAggregate = aggregate.transformDown { |
// Assumption is the aggregate function ignores nulls. This is true for all current
// AggregateFunctions, with the exception of First and Last in their default mode
// (which we handle) and possibly some Hive UDAFs.
| case First(expr, _) => |
| First(ifExpr(expr), true) |
| case Last(expr, _) => |
| Last(ifExpr(expr), true) |
| case a: ApproximatePercentile => |
| // ApproximatePercentile takes two literals for accuracy and percentage which |
| // should not be wrapped by if-else. |
| a.withNewChildren(ifExpr(a.first) :: a.second :: a.third :: Nil) |
| case a: AggregateFunction => |
| a.withNewChildren(a.children.map(ifExpr)) |
| }.transform { |
| // We are duplicating aggregates that are now computing a different value for each |
| // pivot value. |
| // TODO: Don't construct the physical container until after analysis. |
| case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) |
| } |
| Alias(filteredAggregate, outputName(value, aggregate))() |
| } |
| } |
| Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) |
| } |
| } |
| |
| // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF. |
| // TODO: Support Pandas UDF. |
| private def checkValidAggregateExpression(expr: Expression): Unit = expr match { |
| case a: AggregateExpression => |
| if (a.aggregateFunction.isInstanceOf[PythonUDAF]) { |
| throw QueryCompilationErrors.pandasUDFAggregateNotSupportedInPivotError() |
| } else { |
| // OK and leave the argument check to CheckAnalysis. |
| } |
| case e: Attribute => |
| throw QueryCompilationErrors.aggregateExpressionRequiredForPivotError(e.sql) |
| case e => e.children.foreach(checkValidAggregateExpression) |
| } |
| } |
| |
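/**
 * Resolves [[Unpivot]] by filling in missing ids or values from the child output and then
 * rewriting the operator into an [[Expand]].
 *
 * For example (illustrative):
 * {{{
 *   SELECT * FROM sales UNPIVOT (earnings FOR year IN (`2020`, `2021`))
 * }}}
 * expands each input row into one row per value column, with a `year` variable column and an
 * `earnings` value column.
 */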
| object ResolveUnpivot extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( |
| _.containsPattern(UNPIVOT), ruleId) { |
| |
// Once children are resolved, we can determine values from ids and vice versa
// if only one of them is given, and only AttributeReferences are given.
| case up @ Unpivot(Some(ids), None, _, _, _, _) if up.childrenResolved && |
| ids.forall(_.resolved) && |
| ids.forall(_.isInstanceOf[AttributeReference]) => |
| val idAttrs = AttributeSet(up.ids.get) |
| val values = up.child.output.filterNot(idAttrs.contains) |
| up.copy(values = Some(values.map(Seq(_)))) |
| case up @ Unpivot(None, Some(values), _, _, _, _) if up.childrenResolved && |
| values.forall(_.forall(_.resolved)) && |
| values.forall(_.forall(_.isInstanceOf[AttributeReference])) => |
| val valueAttrs = AttributeSet(up.values.get.flatten) |
| val ids = up.child.output.filterNot(valueAttrs.contains) |
| up.copy(ids = Some(ids)) |
| |
| case up: Unpivot if !up.childrenResolved || !up.ids.exists(_.forall(_.resolved)) || |
| !up.values.exists(_.nonEmpty) || !up.values.exists(_.forall(_.forall(_.resolved))) || |
| !up.values.get.forall(_.length == up.valueColumnNames.length) || |
| !up.valuesTypeCoercioned => up |
| |
| // TypeCoercionBase.UnpivotCoercion determines valueType |
| // and casts values once values are set and resolved |
| case Unpivot(Some(ids), Some(values), aliases, variableColumnName, valueColumnNames, child) => |
| |
| def toString(values: Seq[NamedExpression]): String = |
| values.map(v => v.name).mkString("_") |
| |
| // construct unpivot expressions for Expand |
| val exprs: Seq[Seq[Expression]] = |
| values.zip(aliases.getOrElse(values.map(_ => None))).map { |
| case (vals, Some(alias)) => (ids :+ Literal(alias)) ++ vals |
| case (Seq(value), None) => (ids :+ Literal(value.name)) :+ value |
// there is more than one value in vals
| case (vals, None) => (ids :+ Literal(toString(vals))) ++ vals |
| } |
| |
| // construct output attributes |
| val variableAttr = AttributeReference(variableColumnName, StringType, nullable = false)() |
| val valueAttrs = valueColumnNames.zipWithIndex.map { |
| case (valueColumnName, idx) => |
| AttributeReference( |
| valueColumnName, |
| values.head(idx).dataType, |
| values.map(_(idx)).exists(_.nullable))() |
| } |
| val output = (ids.map(_.toAttribute) :+ variableAttr) ++ valueAttrs |
| |
| // expand the unpivot expressions |
| Expand(exprs, output, child) |
| } |
| } |
| |
| private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty |
| private def isReferredTempViewName(nameParts: Seq[String]): Boolean = { |
| AnalysisContext.get.referredTempViewNames.exists { n => |
| (n.length == nameParts.length) && n.zip(nameParts).forall { |
| case (a, b) => resolver(a, b) |
| } |
| } |
| } |
| |
// If we are resolving database objects (relations, functions, etc.) inside views, we may need
// to expand single- or multi-part identifiers with the current catalog and namespace of when
// the view was created.
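// For example (illustrative): while resolving a view created under `cat1.ns1`, the
// single-part name `t` expands to `cat1.ns1.t`, while a multi-part name whose head is a
// registered catalog is left as-is.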
| private def expandIdentifier(nameParts: Seq[String]): Seq[String] = { |
| if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts |
| |
| if (nameParts.length == 1) { |
| AnalysisContext.get.catalogAndNamespace :+ nameParts.head |
| } else if (catalogManager.isCatalogRegistered(nameParts.head)) { |
| nameParts |
| } else { |
| AnalysisContext.get.catalogAndNamespace.head +: nameParts |
| } |
| } |
| |
| /** |
| * Adds metadata columns to output for child relations when nodes are missing resolved attributes. |
| * |
| * References to metadata columns are resolved using columns from [[LogicalPlan.metadataOutput]], |
| * but the relation's output does not include the metadata columns until the relation is replaced. |
| * Unless this rule adds metadata to the relation's output, the analyzer will detect that nothing |
| * produces the columns. |
| * |
| * This rule only adds metadata columns when a node is resolved but is missing input from its |
| * children. This ensures that metadata columns are not added to the plan unless they are used. By |
| * checking only resolved nodes, this ensures that * expansion is already done so that metadata |
| * columns are not accidentally selected by *. This rule resolves operators downwards to avoid |
| * projecting away metadata columns prematurely. |
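*
* For example (illustrative), in `SELECT _metadata.file_path FROM parquet_table`, the
* `_metadata` reference is resolved from the relation's [[LogicalPlan.metadataOutput]], and
* this rule then adds the metadata column to the relation's output so that something in the
* plan actually produces it.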
| */ |
| object AddMetadataColumns extends Rule[LogicalPlan] { |
| import org.apache.spark.sql.catalyst.util._ |
| |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDownWithPruning( |
| AlwaysProcess.fn, ruleId) { |
| case hint: UnresolvedHint => hint |
| // Add metadata output to all node types |
| case node if node.children.nonEmpty && node.resolved && hasMetadataCol(node) => |
| val inputAttrs = AttributeSet(node.children.flatMap(_.output)) |
| val metaCols = getMetadataAttributes(node).filterNot(inputAttrs.contains) |
| if (metaCols.isEmpty) { |
| node |
| } else { |
| val newNode = node.mapChildren(addMetadataCol(_, metaCols.map(_.exprId).toSet)) |
| // We should not change the output schema of the plan. We should project away the extra |
| // metadata columns if necessary. |
| if (newNode.sameOutput(node)) { |
| newNode |
| } else { |
| Project(node.output, newNode) |
| } |
| } |
| } |
| |
| private def getMetadataAttributes(plan: LogicalPlan): Seq[Attribute] = { |
| plan.expressions.flatMap(_.collect { |
| case a: Attribute if a.isMetadataCol => a |
| case a: Attribute |
| if plan.children.exists(c => c.metadataOutput.exists(_.exprId == a.exprId)) => |
| plan.children.collectFirst { |
| case c if c.metadataOutput.exists(_.exprId == a.exprId) => |
| c.metadataOutput.find(_.exprId == a.exprId).get |
| }.get |
| }) |
| } |
| |
| private def hasMetadataCol(plan: LogicalPlan): Boolean = { |
| plan.expressions.exists(_.exists { |
| case a: Attribute => |
| // If an attribute is resolved before being labeled as metadata |
| // (i.e. from the originating Dataset), we check with expression ID |
| a.isMetadataCol || |
| plan.children.exists(c => c.metadataOutput.exists(_.exprId == a.exprId)) |
| case _ => false |
| }) |
| } |
| |
| private def addMetadataCol( |
| plan: LogicalPlan, |
| requiredAttrIds: Set[ExprId]): LogicalPlan = plan match { |
| case s: ExposesMetadataColumns if s.metadataOutput.exists( a => |
| requiredAttrIds.contains(a.exprId)) => |
| s.withMetadataColumns() |
| case p: Project if p.metadataOutput.exists(a => requiredAttrIds.contains(a.exprId)) => |
| val newProj = p.copy( |
| // Do not leak the qualified-access-only restriction to normal plan outputs. |
| projectList = p.projectList ++ p.metadataOutput.map(_.markAsAllowAnyAccess()), |
| child = addMetadataCol(p.child, requiredAttrIds)) |
| newProj.copyTagsFrom(p) |
| newProj |
| case _ => plan.withNewChildren(plan.children.map(addMetadataCol(_, requiredAttrIds))) |
| } |
| } |
| |
| /** |
| * Replaces unresolved relations (tables and views) with concrete relations from the catalog. |
| */ |
| object ResolveRelations extends Rule[LogicalPlan] { |
// The current catalog and namespace may be different from when the view was created, so we
// must resolve the view's logical plan here, with the catalog and namespace stored in the view
// metadata. This is done by keeping the catalog and namespace in `AnalysisContext`, and the
// analyzer will look at `AnalysisContext.catalogAndNamespace` when resolving relations with
// single-part names. If `AnalysisContext.catalogAndNamespace` is non-empty, the analyzer will
// expand single-part names with it, instead of the current catalog and namespace.
| private def resolveViews(plan: LogicalPlan): LogicalPlan = plan match { |
// The view's child should be a logical plan parsed from `desc.viewText`. The variable
// `viewText` should be defined, or else we throw an error when generating the View
// operator.
| case view @ View(desc, isTempView, child) if !child.resolved => |
| // Resolve all the UnresolvedRelations and Views in the child. |
| val newChild = AnalysisContext.withAnalysisContext(desc) { |
| val nestedViewDepth = AnalysisContext.get.nestedViewDepth |
| val maxNestedViewDepth = AnalysisContext.get.maxNestedViewDepth |
| if (nestedViewDepth > maxNestedViewDepth) { |
| throw QueryCompilationErrors.viewDepthExceedsMaxResolutionDepthError( |
| desc.identifier, maxNestedViewDepth, view) |
| } |
| SQLConf.withExistingConf(View.effectiveSQLConf(desc.viewSQLConfigs, isTempView)) { |
| executeSameContext(child) |
| } |
| } |
// Fail the analysis eagerly because outside AnalysisContext, the unresolved operators
// inside a view may be resolved incorrectly.
| checkAnalysis(newChild) |
| view.copy(child = newChild) |
| case p @ SubqueryAlias(_, view: View) => |
| p.copy(child = resolveViews(view)) |
| case _ => plan |
| } |
| |
| private def unwrapRelationPlan(plan: LogicalPlan): LogicalPlan = { |
| EliminateSubqueryAliases(plan) match { |
| case v: View if v.isTempViewStoringAnalyzedPlan => v.child |
| case other => other |
| } |
| } |
| |
| def apply(plan: LogicalPlan) |
| : LogicalPlan = plan.resolveOperatorsUpWithPruning(AlwaysProcess.fn, ruleId) { |
| case i @ InsertIntoStatement(table, _, _, _, _, _, _) => |
| val relation = table match { |
| case u: UnresolvedRelation if !u.isStreaming => |
| resolveRelation(u).getOrElse(u) |
| case other => other |
| } |
| |
// Inserting into a file-based temporary view is allowed
// (e.g., spark.read.parquet("path").createOrReplaceTempView("t")).
// Thus, we need to look at the raw plan if `relation` is a temporary view.
| unwrapRelationPlan(relation) match { |
| case v: View => |
| throw QueryCompilationErrors.insertIntoViewNotAllowedError(v.desc.identifier, table) |
| case other => i.copy(table = other) |
| } |
| |
| // TODO (SPARK-27484): handle streaming write commands when we have them. |
| case write: V2WriteCommand => |
| write.table match { |
| case u: UnresolvedRelation if !u.isStreaming => |
| resolveRelation(u).map(unwrapRelationPlan).map { |
| case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( |
| v.desc.identifier, write) |
| case r: DataSourceV2Relation => write.withNewTable(r) |
| case u: UnresolvedCatalogRelation => |
| throw QueryCompilationErrors.writeIntoV1TableNotAllowedError( |
| u.tableMeta.identifier, write) |
| case other => |
| throw QueryCompilationErrors.writeIntoTempViewNotAllowedError( |
| u.multipartIdentifier.quoted) |
| }.getOrElse(write) |
| case _ => write |
| } |
| |
| case u: UnresolvedRelation => |
| resolveRelation(u).map(resolveViews).getOrElse(u) |
| |
| case r @ RelationTimeTravel(u: UnresolvedRelation, timestamp, version) |
| if timestamp.forall(ts => ts.resolved && !SubqueryExpression.hasSubquery(ts)) => |
| val timeTravelSpec = TimeTravelSpec.create(timestamp, version, conf.sessionLocalTimeZone) |
| resolveRelation(u, timeTravelSpec).getOrElse(r) |
| |
| case u @ UnresolvedTable(identifier, cmd, suggestAlternative) => |
| lookupTableOrView(identifier).map { |
| case v: ResolvedPersistentView => |
| val nameParts = v.catalog.name() +: v.identifier.asMultipartIdentifier |
| throw QueryCompilationErrors.expectTableNotViewError( |
| nameParts, cmd, suggestAlternative, u) |
| case _: ResolvedTempView => |
| throw QueryCompilationErrors.expectTableNotViewError( |
| identifier, cmd, suggestAlternative, u) |
| case table => table |
| }.getOrElse(u) |
| |
| case u @ UnresolvedView(identifier, cmd, allowTemp, suggestAlternative) => |
| lookupTableOrView(identifier, viewOnly = true).map { |
| case _: ResolvedTempView if !allowTemp => |
| throw QueryCompilationErrors.expectPermanentViewNotTempViewError( |
| identifier, cmd, u) |
| case t: ResolvedTable => |
| val nameParts = t.catalog.name() +: t.identifier.asMultipartIdentifier |
| throw QueryCompilationErrors.expectViewNotTableError( |
| nameParts, cmd, suggestAlternative, u) |
| case other => other |
| }.getOrElse(u) |
| |
| case u @ UnresolvedTableOrView(identifier, cmd, allowTempView) => |
| lookupTableOrView(identifier).map { |
| case _: ResolvedTempView if !allowTempView => |
| throw QueryCompilationErrors.expectPermanentViewNotTempViewError( |
| identifier, cmd, u) |
| case other => other |
| }.getOrElse(u) |
| } |
| |
| private def lookupTempView(identifier: Seq[String]): Option[TemporaryViewRelation] = { |
// If we are resolving a view and this name was not a temp view when that view was created,
// return None early here.
| if (isResolvingView && !isReferredTempViewName(identifier)) return None |
| v1SessionCatalog.getRawLocalOrGlobalTempView(identifier) |
| } |
| |
| private def resolveTempView( |
| identifier: Seq[String], |
| isStreaming: Boolean = false, |
| isTimeTravel: Boolean = false): Option[LogicalPlan] = { |
| lookupTempView(identifier).map { v => |
| val tempViewPlan = v1SessionCatalog.getTempViewRelation(v) |
| if (isStreaming && !tempViewPlan.isStreaming) { |
| throw QueryCompilationErrors.readNonStreamingTempViewError(identifier.quoted) |
| } |
| if (isTimeTravel) { |
| throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(identifier)) |
| } |
| tempViewPlan |
| } |
| } |
| |
| /** |
| * Resolves relations to `ResolvedTable` or `Resolved[Temp/Persistent]View`. This is |
| * for resolving DDL and misc commands. |
| */ |
| private def lookupTableOrView( |
| identifier: Seq[String], |
| viewOnly: Boolean = false): Option[LogicalPlan] = { |
| lookupTempView(identifier).map { tempView => |
| ResolvedTempView(identifier.asIdentifier, tempView.tableMeta) |
| }.orElse { |
| expandIdentifier(identifier) match { |
| case CatalogAndIdentifier(catalog, ident) => |
| if (viewOnly && !CatalogV2Util.isSessionCatalog(catalog)) { |
| throw QueryCompilationErrors.catalogOperationNotSupported(catalog, "views") |
| } |
| CatalogV2Util.loadTable(catalog, ident).map { |
| case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog) && |
| v1Table.v1Table.tableType == CatalogTableType.VIEW => |
| val v1Ident = v1Table.catalogTable.identifier |
| val v2Ident = Identifier.of(v1Ident.database.toArray, v1Ident.identifier) |
| ResolvedPersistentView( |
| catalog, v2Ident, v1Table.catalogTable) |
| case table => |
| ResolvedTable.create(catalog.asTableCatalog, ident, table) |
| } |
| case _ => None |
| } |
| } |
| } |
| |
| private def createRelation( |
| catalog: CatalogPlugin, |
| ident: Identifier, |
| table: Option[Table], |
| options: CaseInsensitiveStringMap, |
| isStreaming: Boolean): Option[LogicalPlan] = { |
| table.map { |
| case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog) => |
| if (isStreaming) { |
| if (v1Table.v1Table.tableType == CatalogTableType.VIEW) { |
| throw QueryCompilationErrors.permanentViewNotSupportedByStreamingReadingAPIError( |
| ident.quoted) |
| } |
| SubqueryAlias( |
| catalog.name +: ident.asMultipartIdentifier, |
| UnresolvedCatalogRelation(v1Table.v1Table, options, isStreaming = true)) |
| } else { |
| v1SessionCatalog.getRelation(v1Table.v1Table, options) |
| } |
| |
| case table => |
| if (isStreaming) { |
| val v1Fallback = table match { |
| case withFallback: V2TableWithV1Fallback => |
| Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true)) |
| case _ => None |
| } |
| SubqueryAlias( |
| catalog.name +: ident.asMultipartIdentifier, |
| StreamingRelationV2(None, table.name, table, options, table.columns.toAttributes, |
| Some(catalog), Some(ident), v1Fallback)) |
| } else { |
| SubqueryAlias( |
| catalog.name +: ident.asMultipartIdentifier, |
| DataSourceV2Relation.create(table, Some(catalog), Some(ident), options)) |
| } |
| } |
| } |
| |
| /** |
* Resolves a relation to a v1 relation if it's a v1 table from the session catalog, or to a
* v2 relation otherwise. This is for resolving DML commands and SELECT queries.
| */ |
| private def resolveRelation( |
| u: UnresolvedRelation, |
| timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { |
| val timeTravelSpecFromOptions = TimeTravelSpec.fromOptions( |
| u.options, |
| conf.getConf(SQLConf.TIME_TRAVEL_TIMESTAMP_KEY), |
| conf.getConf(SQLConf.TIME_TRAVEL_VERSION_KEY), |
| conf.sessionLocalTimeZone |
| ) |
| if (timeTravelSpec.nonEmpty && timeTravelSpecFromOptions.nonEmpty) { |
| throw new AnalysisException("MULTIPLE_TIME_TRAVEL_SPEC", Map.empty[String, String]) |
| } |
| val finalTimeTravelSpec = timeTravelSpec.orElse(timeTravelSpecFromOptions) |
| resolveTempView(u.multipartIdentifier, u.isStreaming, finalTimeTravelSpec.isDefined).orElse { |
| expandIdentifier(u.multipartIdentifier) match { |
| case CatalogAndIdentifier(catalog, ident) => |
| val key = |
| ((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq, |
| finalTimeTravelSpec) |
| AnalysisContext.get.relationCache.get(key).map { cache => |
| val cachedRelation = cache.transform { |
| case multi: MultiInstanceRelation => |
| val newRelation = multi.newInstance() |
| newRelation.copyTagsFrom(multi) |
| newRelation |
| } |
| u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => |
| val cachedConnectRelation = cachedRelation.clone() |
| cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) |
| cachedConnectRelation |
| }.getOrElse(cachedRelation) |
| }.orElse { |
| val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec) |
| val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming) |
| loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) |
| u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => |
| loaded.map { loadedRelation => |
| val loadedConnectRelation = loadedRelation.clone() |
| loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) |
| loadedConnectRelation |
| } |
| }.getOrElse(loaded) |
| } |
| case _ => None |
| } |
| } |
| } |
| |
| /** Consumes an unresolved relation and resolves it to a v1 or v2 relation or temporary view. */ |
| def resolveRelationOrTempView(u: UnresolvedRelation): LogicalPlan = { |
| EliminateSubqueryAliases(resolveRelation(u).getOrElse(u)) |
| } |
| } |
| |
| /** Handle INSERT INTO for DSv2 */ |
| object ResolveInsertInto extends ResolveInsertionBase { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( |
| AlwaysProcess.fn, ruleId) { |
| case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _, _) |
| if i.query.resolved => |
// ifPartitionNotExists means append with validation, but validation is not supported.
| if (i.ifPartitionNotExists) { |
| throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name) |
| } |
| |
| // Create a project if this is an INSERT INTO BY NAME query. |
| val projectByName = if (i.userSpecifiedCols.nonEmpty) { |
| Some(createProjectForByNameQuery(r.table.name, i)) |
| } else { |
| None |
| } |
| val isByName = projectByName.nonEmpty || i.byName |
| |
| val partCols = partitionColumnNames(r.table) |
| validatePartitionSpec(partCols, i.partitionSpec) |
| |
| val staticPartitions = i.partitionSpec.filter(_._2.isDefined).transform((_, v) => v.get) |
| val query = addStaticPartitionColumns(r, projectByName.getOrElse(i.query), staticPartitions, |
| isByName) |
| |
| if (!i.overwrite) { |
| if (isByName) { |
| AppendData.byName(r, query) |
| } else { |
| AppendData.byPosition(r, query) |
| } |
| } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) { |
| if (isByName) { |
| OverwritePartitionsDynamic.byName(r, query) |
| } else { |
| OverwritePartitionsDynamic.byPosition(r, query) |
| } |
| } else { |
| if (isByName) { |
| OverwriteByExpression.byName(r, query, staticDeleteExpression(r, staticPartitions)) |
| } else { |
| OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions)) |
| } |
| } |
| } |
| |
| private def partitionColumnNames(table: Table): Seq[String] = { |
| // get partition column names. in v2, partition columns are columns that are stored using an |
| // identity partition transform because the partition values and the column values are |
| // identical. otherwise, partition values are produced by transforming one or more source |
| // columns and cannot be set directly in a query's PARTITION clause. |
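// for example (illustrative): for a table partitioned by (identity(region), days(ts)), only
// "region" is returned here; values for days(ts) cannot be set in a PARTITION clause.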
| table.partitioning.flatMap { |
| case IdentityTransform(FieldReference(Seq(name))) => Some(name) |
| case _ => None |
| }.toImmutableArraySeq |
| } |
| |
| private def validatePartitionSpec( |
| partitionColumnNames: Seq[String], |
| partitionSpec: Map[String, Option[String]]): Unit = { |
| // check that each partition name is a partition column. otherwise, it is not valid |
| partitionSpec.keySet.foreach { partitionName => |
| partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { |
| case Some(_) => |
| case None => |
| throw QueryCompilationErrors.nonPartitionColError(partitionName) |
| } |
| } |
| } |
| |
| private def addStaticPartitionColumns( |
| relation: DataSourceV2Relation, |
| query: LogicalPlan, |
| staticPartitions: Map[String, String], |
| isByName: Boolean): LogicalPlan = { |
| |
| if (staticPartitions.isEmpty) { |
| query |
| |
| } else { |
| // add any static value as a literal column |
| val withStaticPartitionValues = { |
| // for each static name, find the column name it will replace and check for unknowns. |
| val outputNameToStaticName = staticPartitions.keySet.map { staticName => |
| if (isByName) { |
| // If this is INSERT INTO BY NAME, the query output's names will be the user specified |
| // column names. We need to make sure the static partition column name doesn't appear |
| // there to catch the following ambiguous query: |
| // INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2') |
| if (query.output.exists(col => conf.resolver(col.name, staticName))) { |
| throw QueryCompilationErrors.staticPartitionInUserSpecifiedColumnsError(staticName) |
| } |
| } |
| relation.output.find(col => conf.resolver(col.name, staticName)) match { |
| case Some(attr) => |
| attr.name -> staticName |
| case _ => |
| throw QueryCompilationErrors.missingStaticPartitionColumn(staticName) |
| } |
| }.toMap |
| |
| val queryColumns = query.output.iterator |
| |
| // for each output column, add the static value as a literal, or use the next input |
| // column. this does not fail if input columns are exhausted and adds remaining columns |
| // at the end. both cases will be caught by ResolveOutputRelation and will fail the |
| // query with a helpful error message. |
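// illustrative example, assuming a table t(a, b, p) whose last column p is an identity
// partition column: INSERT OVERWRITE t PARTITION (p = 'x') SELECT a, b FROM src keeps
// columns a and b from the query and appends Alias(Cast('x', p.dataType), "p") for p.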
| relation.output.flatMap { col => |
| outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match { |
| case Some(staticValue) => |
// SPARK-30844: try our best to follow StoreAssignmentPolicy for static partition
// values, but we cannot follow it completely because static type checking is not
// possible: the parser has erased the type info of static partition values and
// converted them to strings.
| val cast = Cast(Literal(staticValue), col.dataType, ansiEnabled = true) |
| cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) |
| Some(Alias(cast, col.name)()) |
| case _ if queryColumns.hasNext => |
| Some(queryColumns.next()) |
| case _ => |
| None |
| } |
| } ++ queryColumns |
| } |
| |
| Project(withStaticPartitionValues, query) |
| } |
| } |
| |
| private def staticDeleteExpression( |
| relation: DataSourceV2Relation, |
| staticPartitions: Map[String, String]): Expression = { |
| if (staticPartitions.isEmpty) { |
| Literal(true) |
| } else { |
| staticPartitions.map { case (name, value) => |
| relation.output.find(col => conf.resolver(col.name, name)) match { |
| case Some(attr) => |
| // the delete expression must reference the table's column names, but these attributes |
| // are not available when CheckAnalysis runs because the relation is not a child of |
| // the logical operation. instead, expressions are resolved after |
| // ResolveOutputRelation runs, using the query's column names that will match the |
| // table names at that point. because resolution happens after a future rule, create |
| // an UnresolvedAttribute. |
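// For example (illustrative), PARTITION (p = '3') produces the delete condition
// p <=> CAST('3' AS <type of p>).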
| EqualNullSafe( |
| UnresolvedAttribute.quoted(attr.name), |
| Cast(Literal(value), attr.dataType)) |
| case None => |
| throw QueryCompilationErrors.missingStaticPartitionColumn(name) |
| } |
| }.reduce(And) |
| } |
| } |
| } |
| |
| /** |
* Resolves column references in the query plan. It transforms the query plan tree bottom-up,
* and only tries to resolve references for a plan node if all its child nodes are resolved
* and there are no conflicting attributes between the child nodes (see `hasConflictingAttrs`
* for details).
| * |
| * The general workflow to resolve references: |
| * 1. Expands the star in Project/Aggregate/Generate. |
| * 2. Resolves the columns to [[AttributeReference]] with the output of the children plans. This |
| * includes metadata columns as well. |
* 3. Resolves the columns to literal functions that can be invoked without parentheses,
| * e.g. `SELECT col, current_date FROM t`. |
| * 4. Resolves the columns to outer references with the outer plan if we are resolving subquery |
| * expressions. |
| * 5. Resolves the columns to SQL variables. |
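*
* For example (illustrative), in `SELECT col, current_date FROM t`, `current_date` is first
* matched against the output of `t` (option 2); only if `t` has no such column does it fall
* back to the literal function (option 3).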
| * |
* Some plan nodes have special column reference resolution logic; please read these sub-rules for
| * details: |
| * - [[ResolveReferencesInAggregate]] |
| * - [[ResolveReferencesInUpdate]] |
| * - [[ResolveReferencesInSort]] |
| * |
| * Note: even if we use a single rule to resolve columns, it's still non-trivial to have a |
| * reliable column resolution order, as the rule will be executed multiple times, with other |
| * rules in the same batch. We should resolve columns with the next option only if all the |
* previous options are permanently not applicable. If the current option may become applicable
* in a later iteration (after other rules update the plan), we should not try the next option.
| */ |
| class ResolveReferences(val catalogManager: CatalogManager) |
| extends Rule[LogicalPlan] with ColumnResolutionHelper { |
| |
| private val resolveColumnDefaultInCommandInputQuery = |
| new ResolveColumnDefaultInCommandInputQuery(catalogManager) |
| private val resolveReferencesInAggregate = |
| new ResolveReferencesInAggregate(catalogManager) |
| private val resolveReferencesInUpdate = |
| new ResolveReferencesInUpdate(catalogManager) |
| private val resolveReferencesInSort = |
| new ResolveReferencesInSort(catalogManager) |
| |
| /** |
* Returns true if there are conflicting attributes among the children's outputs of a plan.
| * |
| * The children logical plans may output columns with conflicting attribute IDs. This may happen |
| * in cases such as self-join. We should wait for the rule [[DeduplicateRelations]] to eliminate |
| * conflicting attribute IDs, otherwise we can't resolve columns correctly due to ambiguity. |
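*
* For example (illustrative), a self-join such as `df.join(df, ...)` initially produces two
* children whose outputs share the same attribute IDs.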
| */ |
| def hasConflictingAttrs(p: LogicalPlan): Boolean = { |
| p.children.length > 1 && { |
| // Note that duplicated attributes are allowed within a single node, |
| // e.g., df.select($"a", $"a"), so we should only check conflicting |
| // attributes between nodes. |
| val uniqueAttrs = mutable.HashSet[ExprId]() |
| p.children.head.outputSet.foreach(a => uniqueAttrs.add(a.exprId)) |
| p.children.tail.exists { child => |
| val uniqueSize = uniqueAttrs.size |
| val childSize = child.outputSet.size |
| child.outputSet.foreach(a => uniqueAttrs.add(a.exprId)) |
| uniqueSize + childSize > uniqueAttrs.size |
| } |
| } |
| } |
| |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { |
// Don't wait for other rules to resolve the child plans of `InsertIntoStatement`: we need
// to resolve the column "DEFAULT" in the child plans, so they must still be unresolved here.
| case i: InsertIntoStatement => resolveColumnDefaultInCommandInputQuery(i) |
| |
// Don't wait for other rules to resolve the child plans of `SetVariable`: we need
// to resolve the column "DEFAULT" in the child plans, so they must still be unresolved here.
| case s: SetVariable => resolveColumnDefaultInCommandInputQuery(s) |
| |
| // Wait for other rules to resolve child plans first |
| case p: LogicalPlan if !p.childrenResolved => p |
| |
| // Wait for the rule `DeduplicateRelations` to resolve conflicting attrs first. |
| case p: LogicalPlan if hasConflictingAttrs(p) => p |
| |
| // If the projection list contains Stars, expand it. |
| case p: Project if containsStar(p.projectList) => |
| p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) |
// If the filter condition contains Stars, expand it.
case p: Filter if containsStar(Seq(p.condition)) =>
p.copy(condition = expandStarExpression(p.condition, p.child))
| // If the aggregate function argument contains Stars, expand it. |
| case a: Aggregate if containsStar(a.aggregateExpressions) => |
| if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { |
| throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError() |
| } else { |
| a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) |
| } |
| case g: Generate if containsStar(g.generator.children) => |
| throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF", |
| extractStar(g.generator.children)) |
| // If the Unpivot ids or values contain Stars, expand them. |
| case up: Unpivot if up.ids.exists(containsStar) || |
| // Only expand Stars in one-dimensional values |
| up.values.exists(values => values.exists(_.length == 1) && values.exists(containsStar)) => |
| up.copy( |
| ids = up.ids.map(buildExpandedProjectList(_, up.child)), |
// The inner exprs in Option[Seq[Seq[exprs]]] are one-dimensional, e.g. Option(Seq(Seq("*"))).
// Star expansion turns the single NamedExpression into multiple ones, which we here have
// to turn into Option(Seq(Seq("col1"), Seq("col2"))).
| values = up.values.map(_.flatMap(buildExpandedProjectList(_, up.child)).map(Seq(_))) |
| ) |
| |
| case u @ Union(children, _, _) |
| // if there are duplicate output columns, give them unique expr ids |
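// (e.g. a Union child produced by df.select($"a", $"a") outputs the same attribute twice)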
| if children.exists(c => c.output.map(_.exprId).distinct.length < c.output.length) => |
| val newChildren = children.map { c => |
| if (c.output.map(_.exprId).distinct.length < c.output.length) { |
| val existingExprIds = mutable.HashSet[ExprId]() |
| val projectList = c.output.map { attr => |
| if (existingExprIds.contains(attr.exprId)) { |
| // replace non-first duplicates with aliases and tag them |
| val newMetadata = new MetadataBuilder().withMetadata(attr.metadata) |
| .putNull("__is_duplicate").build() |
| Alias(attr, attr.name)(explicitMetadata = Some(newMetadata)) |
| } else { |
| // leave first duplicate alone |
| existingExprIds.add(attr.exprId) |
| attr |
| } |
| } |
| Project(projectList, c) |
| } else { |
| c |
| } |
| } |
| u.withNewChildren(newChildren) |
| |
| // A special case for Generate, because the output of Generate should not be resolved by |
| // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. |
| case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g |
| |
| case g @ Generate(generator, join, outer, qualifier, output, child) => |
| val newG = resolveExpressionByPlanOutput( |
| generator, child, throws = true, includeLastResort = true) |
| if (newG.fastEquals(generator)) { |
| g |
| } else { |
| Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) |
| } |
| |
| case mg: MapGroups if mg.dataOrder.exists(!_.resolved) => |
| // Resolve against `AppendColumns`'s children, instead of `AppendColumns`, |
// because `AppendColumns`'s serializer might produce conflicting attribute
// names, leading to an ambiguous-references exception.
| val planForResolve = mg.child match { |
| case appendColumns: AppendColumns => appendColumns.child |
| case plan => plan |
| } |
| val resolvedOrder = mg.dataOrder |
| .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder]) |
| mg.copy(dataOrder = resolvedOrder) |
| |
| // Left and right sort expression have to be resolved against the respective child plan only |
| case cg: CoGroup if cg.leftOrder.exists(!_.resolved) || cg.rightOrder.exists(!_.resolved) => |
| // Resolve against `AppendColumns`'s children, instead of `AppendColumns`, |
// because `AppendColumns`'s serializer might produce conflicting attribute
// names, leading to an ambiguous-references exception.
| val (leftPlanForResolve, rightPlanForResolve) = Seq(cg.left, cg.right).map { |
| case appendColumns: AppendColumns => appendColumns.child |
| case plan => plan |
| } match { |
| case Seq(left, right) => (left, right) |
| } |
| |
| val resolvedLeftOrder = cg.leftOrder |
| .map(resolveExpressionByPlanOutput(_, leftPlanForResolve).asInstanceOf[SortOrder]) |
| val resolvedRightOrder = cg.rightOrder |
| .map(resolveExpressionByPlanOutput(_, rightPlanForResolve).asInstanceOf[SortOrder]) |
| |
| cg.copy(leftOrder = resolvedLeftOrder, rightOrder = resolvedRightOrder) |
| |
| // Skips plan which contains deserializer expressions, as they should be resolved by another |
| // rule: ResolveDeserializer. |
| case plan if containsDeserializer(plan.expressions) => plan |
| |
| case a: Aggregate => resolveReferencesInAggregate(a) |
| |
| // Special case for Project as it supports lateral column alias. |
| case p: Project => |
| val resolvedBasic = p.projectList.map(resolveExpressionByPlanChildren(_, p)) |
| // Lateral column alias has higher priority than outer reference. |
| val resolvedWithLCA = resolveLateralColumnAlias(resolvedBasic) |
| val resolvedFinal = resolvedWithLCA.map(resolveColsLastResort) |
| p.copy(projectList = resolvedFinal.map(_.asInstanceOf[NamedExpression])) |
| |
| case o: OverwriteByExpression if o.table.resolved => |
| // The delete condition of `OverwriteByExpression` will be passed to the table |
| // implementation and should be resolved based on the table schema. |
| o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table)) |
| |
| case u: UpdateTable => resolveReferencesInUpdate(u) |
| |
| case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _) |
| if !m.resolved && targetTable.resolved && sourceTable.resolved => |
| |
| EliminateSubqueryAliases(targetTable) match { |
| case r: NamedRelation if r.skipSchemaResolution => |
| // Do not resolve the expression if the target table accepts any schema. |
| // This allows data sources to customize their own resolution logic using |
| // custom resolution rules. |
| m |
| |
| case _ => |
| val newMatchedActions = m.matchedActions.map { |
| case DeleteAction(deleteCondition) => |
| val resolvedDeleteCondition = deleteCondition.map( |
| resolveExpressionByPlanChildren(_, m)) |
| DeleteAction(resolvedDeleteCondition) |
| case UpdateAction(updateCondition, assignments) => |
| val resolvedUpdateCondition = updateCondition.map( |
| resolveExpressionByPlanChildren(_, m)) |
| UpdateAction( |
| resolvedUpdateCondition, |
| // The update value can access columns from both target and source tables. |
| resolveAssignments(assignments, m, MergeResolvePolicy.BOTH)) |
| case UpdateStarAction(updateCondition) => |
| val assignments = targetTable.output.map { attr => |
| Assignment(attr, UnresolvedAttribute(Seq(attr.name))) |
| } |
| UpdateAction( |
| updateCondition.map(resolveExpressionByPlanChildren(_, m)), |
| // For UPDATE *, the value must be from source table. |
| resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE)) |
| case o => o |
| } |
| val newNotMatchedActions = m.notMatchedActions.map { |
| case InsertAction(insertCondition, assignments) => |
| // The insert action is used when not matched, so its condition and value can only |
| // access columns from the source table. |
| val resolvedInsertCondition = insertCondition.map( |
| resolveExpressionByPlanOutput(_, m.sourceTable)) |
| InsertAction( |
| resolvedInsertCondition, |
| resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE)) |
| case InsertStarAction(insertCondition) => |
| // The insert action is used when not matched, so its condition and value can only |
| // access columns from the source table. |
| val resolvedInsertCondition = insertCondition.map( |
| resolveExpressionByPlanOutput(_, m.sourceTable)) |
| val assignments = targetTable.output.map { attr => |
| Assignment(attr, UnresolvedAttribute(Seq(attr.name))) |
| } |
| InsertAction( |
| resolvedInsertCondition, |
| resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE)) |
| case o => o |
| } |
| val newNotMatchedBySourceActions = m.notMatchedBySourceActions.map { |
| case DeleteAction(deleteCondition) => |
| val resolvedDeleteCondition = deleteCondition.map( |
| resolveExpressionByPlanOutput(_, targetTable)) |
| DeleteAction(resolvedDeleteCondition) |
| case UpdateAction(updateCondition, assignments) => |
| val resolvedUpdateCondition = updateCondition.map( |
| resolveExpressionByPlanOutput(_, targetTable)) |
| UpdateAction( |
| resolvedUpdateCondition, |
| // The update value can access columns from the target table only. |
| resolveAssignments(assignments, m, MergeResolvePolicy.TARGET)) |
| case o => o |
| } |
| |
| val resolvedMergeCondition = resolveExpressionByPlanChildren(m.mergeCondition, m) |
| m.copy(mergeCondition = resolvedMergeCondition, |
| matchedActions = newMatchedActions, |
| notMatchedActions = newNotMatchedActions, |
| notMatchedBySourceActions = newNotMatchedBySourceActions) |
| } |
| |
| // UnresolvedHaving can host grouping expressions and aggregate functions. We should resolve |
| // columns with `agg.output` and the rule `ResolveAggregateFunctions` will push them down to |
| // Aggregate later. |
| case u @ UnresolvedHaving(cond, agg: Aggregate) if !cond.resolved => |
| u.mapExpressions { e => |
| // Columns in HAVING should be resolved with `agg.child.output` first, to follow the SQL |
| // standard. See more details in SPARK-31519. |
| val resolvedWithAgg = resolveColWithAgg(e, agg) |
| resolveExpressionByPlanChildren(resolvedWithAgg, u, includeLastResort = true) |
| } |
| |
| // RepartitionByExpression can host missing attributes that are from a descendant node. |
| // For example, `spark.table("t").select($"a").repartition($"b")`. We can resolve `b` with |
// table `t` even if there is a Project node between the table scan node and the repartition
// node. We also need to propagate the missing attributes from the descendant node to the
// current node, and project them away at the end via an extra Project.
| case r @ RepartitionByExpression(partitionExprs, child, _, _) |
| if !r.resolved || r.missingInput.nonEmpty => |
| val resolvedBasic = partitionExprs.map(resolveExpressionByPlanChildren(_, r)) |
| val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedBasic, child) |
| // Missing columns should be resolved right after basic column resolution. |
| // See the doc of `ResolveReferences`. |
| val resolvedFinal = newPartitionExprs.map(resolveColsLastResort) |
| if (child.output == newChild.output) { |
| r.copy(resolvedFinal, newChild) |
| } else { |
| Project(child.output, r.copy(resolvedFinal, newChild)) |
| } |
| |
| // Filter can host both grouping expressions/aggregate functions and missing attributes. |
| // The grouping expressions/aggregate functions resolution takes precedence over missing |
| // attributes. See the classdoc of `ResolveReferences` for details. |
| case f @ Filter(cond, child) if !cond.resolved || f.missingInput.nonEmpty => |
| val resolvedBasic = resolveExpressionByPlanChildren(cond, f) |
| val resolvedWithAgg = resolveColWithAgg(resolvedBasic, child) |
| val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), child) |
| // Missing columns should be resolved right after basic column resolution. |
| // See the doc of `ResolveReferences`. |
| val resolvedFinal = resolveColsLastResort(newCond.head) |
| if (child.output == newChild.output) { |
| f.copy(condition = resolvedFinal) |
| } else { |
| // Add missing attributes and then project them away. |
| val newFilter = Filter(resolvedFinal, newChild) |
| Project(child.output, newFilter) |
| } |
| |
| case s: Sort if !s.resolved || s.missingInput.nonEmpty => |
| resolveReferencesInSort(s) |
| |
| case q: LogicalPlan => |
| logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") |
| q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true)) |
| } |
| |
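// Which plan an assignment value is resolved against: BOTH (matched UPDATE actions may
// reference target and source columns), SOURCE (INSERT and UPDATE * actions), or TARGET
// (NOT MATCHED BY SOURCE UPDATE actions).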
| private object MergeResolvePolicy extends Enumeration { |
| val BOTH, SOURCE, TARGET = Value |
| } |
| |
| def resolveAssignments( |
| assignments: Seq[Assignment], |
| mergeInto: MergeIntoTable, |
| resolvePolicy: MergeResolvePolicy.Value): Seq[Assignment] = { |
| assignments.map { assign => |
| val resolvedKey = assign.key match { |
| case c if !c.resolved => |
| resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable)) |
| case o => o |
| } |
| val resolvedValue = assign.value match { |
| case c if !c.resolved => |
| val resolvePlan = resolvePolicy match { |
| case MergeResolvePolicy.BOTH => mergeInto |
| case MergeResolvePolicy.SOURCE => Project(Nil, mergeInto.sourceTable) |
| case MergeResolvePolicy.TARGET => Project(Nil, mergeInto.targetTable) |
| } |
| val resolvedExpr = resolveExprInAssignment(c, resolvePlan) |
| val withDefaultResolved = if (conf.enableDefaultColumns) { |
| resolveColumnDefaultInAssignmentValue( |
| resolvedKey, |
| resolvedExpr, |
| QueryCompilationErrors |
| .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates()) |
| } else { |
| resolvedExpr |
| } |
| checkResolvedMergeExpr(withDefaultResolved, resolvePlan) |
| withDefaultResolved |
| case o => o |
| } |
| Assignment(resolvedKey, resolvedValue) |
| } |
| } |
| |
| private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = { |
| val resolved = resolveExprInAssignment(e, p) |
| checkResolvedMergeExpr(resolved, p) |
| resolved |
| } |
| |
| private def checkResolvedMergeExpr(e: Expression, p: LogicalPlan): Unit = { |
| e.references.filter(!_.resolved).foreach { a => |
| // Note: This will throw error only on unresolved attribute issues, |
| // not other resolution errors like mismatched data types. |
| val cols = p.inputSet.toSeq.map(attr => toSQLId(attr.name)).mkString(", ") |
| a.failAnalysis( |
| errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", |
| messageParameters = Map( |
| "objectName" -> toSQLId(a.name), |
| "proposal" -> cols)) |
| } |
| } |
| |
// Expand the star expression using the input plan first. If that fails, try to resolve
// the star expression using the outer query plan and wrap the resolved attributes
// in outer references. If that also fails, throw the original exception.
| private def expand(s: Star, plan: LogicalPlan): Seq[NamedExpression] = { |
| withPosition(s) { |
| try { |
| s.expand(plan, resolver) |
| } catch { |
| case e: AnalysisException => |
| AnalysisContext.get.outerPlan.map { |
| // Only Project and Aggregate can host star expressions. |
| case u @ (_: Project | _: Aggregate) => |
| Try(s.expand(u.children.head, resolver)) match { |
| case Success(expanded) => expanded.map(wrapOuterReference) |
| case Failure(_) => throw e |
| } |
| // Do not use the outer plan to resolve the star expression |
| // since the star usage is invalid. |
| case _ => throw e |
| }.getOrElse { throw e } |
| } |
| } |
| } |
| |
| /** |
* Builds a project list for Project/Aggregate and expands the stars if possible.
| */ |
| private def buildExpandedProjectList( |
| exprs: Seq[NamedExpression], |
| child: LogicalPlan): Seq[NamedExpression] = { |
| exprs.flatMap { |
// Using the DataFrame/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
| case s: Star => expand(s, child) |
| // Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b |
| case UnresolvedAlias(s: Star, _) => expand(s, child) |
| case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil |
| case o => o :: Nil |
| }.map(_.asInstanceOf[NamedExpression]) |
| } |
| |
| /** |
| * Returns true if `exprs` contains a [[Star]]. |
| */ |
| def containsStar(exprs: Seq[Expression]): Boolean = |
| exprs.exists(_.collect { case _: Star => true }.nonEmpty) |
| |
| private def extractStar(exprs: Seq[Expression]): Seq[Star] = |
| exprs.flatMap(_.collect { case s: Star => s }) |
| |
| /** |
| * Expands the matching attribute.*'s in `child`'s output. |
| */ |
| def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = { |
| expr.transformUp { |
| case f0: UnresolvedFunction if !f0.isDistinct && |
| f0.nameParts.map(_.toLowerCase(Locale.ROOT)) == Seq("count") && |
| f0.arguments == Seq(UnresolvedStar(None)) => |
| // Transform COUNT(*) into COUNT(1). |
| f0.copy(nameParts = Seq("count"), arguments = Seq(Literal(1))) |
| case f1: UnresolvedFunction if containsStar(f1.arguments) => |
// SPECIAL CASE: We want to block count(tblName.*) because in Spark, count(tblName.*) is
// expanded while count(*) is converted to count(1). They will produce different
// results and confuse users if there are any null values. count(t1.*, t2.*) is
// still allowed, since it's well-defined in Spark.
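// For example (illustrative), if column a of t is nullable, count(t.*) expands to
// count(a, b, ...) and skips rows that contain a null, while count(*) counts every row.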
| if (!conf.allowStarWithSingleTableIdentifierInCount && |
| f1.nameParts == Seq("count") && |
| f1.arguments.length == 1) { |
| f1.arguments.foreach { |
| case u: UnresolvedStar if u.isQualifiedByTable(child, resolver) => |
| throw QueryCompilationErrors |
| .singleTableStarInCountNotAllowedError(u.target.get.mkString(".")) |
| case _ => // do nothing |
| } |
| } |
| f1.copy(arguments = f1.arguments.flatMap { |
| case s: Star => expand(s, child) |
| case o => o :: Nil |
| }) |
| case c: CreateNamedStruct if containsStar(c.valExprs) => |
| val newChildren = c.children.grouped(2).flatMap { |
| case Seq(k, s : Star) => CreateStruct(expand(s, child)).children |
| case kv => kv |
| } |
c.copy(children = newChildren.toList)
| case c: CreateArray if containsStar(c.children) => |
| c.copy(children = c.children.flatMap { |
| case s: Star => expand(s, child) |
| case o => o :: Nil |
| }) |
| case p: Murmur3Hash if containsStar(p.children) => |
| p.copy(children = p.children.flatMap { |
| case s: Star => expand(s, child) |
| case o => o :: Nil |
| }) |
| case p: XxHash64 if containsStar(p.children) => |
| p.copy(children = p.children.flatMap { |
| case s: Star => expand(s, child) |
| case o => o :: Nil |
| }) |
| case p: In if containsStar(p.children) => |
| p.copy(list = p.list.flatMap { |
| case s: Star => expand(s, child) |
| case o => o :: Nil |
| }) |
| // count(*) has been replaced by count(1) |
| case o if containsStar(o.children) => |
| throw QueryCompilationErrors.invalidStarUsageError(s"expression `${o.prettyName}`", |
| extractStar(o.children)) |
| } |
| } |
| } |
| |
| private def containsDeserializer(exprs: Seq[Expression]): Boolean = { |
| exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) |
| } |
| |
| /** |
| * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by |
* clauses. This rule converts ordinal positions to the corresponding expressions in the
* select list. This support was introduced in Spark 2.0.
| * |
* - When the sort references or group by expressions are not integers but foldable expressions,
| * just ignore them. |
| * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position |
| * numbers too. |
| * |
* Before the release of Spark 2.0, literals in order/sort by and group by clauses
* had no effect on the results.
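*
* For example (illustrative), `SELECT a, b FROM t ORDER BY 2` sorts by `b`, and
* `SELECT a, count(*) FROM t GROUP BY 1` groups by `a`.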
| */ |
| object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(UNRESOLVED_ORDINAL), ruleId) { |
| case p if !p.childrenResolved => p |
| // Replace the index with the related attribute for ORDER BY, |
// which is a 1-based position in the projection list.
| case Sort(orders, global, child) |
| if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => |
| val newOrders = orders map { |
| case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => |
| if (index > 0 && index <= child.output.size) { |
| SortOrder(child.output(index - 1), direction, nullOrdering, Seq.empty) |
| } else { |
| throw QueryCompilationErrors.orderByPositionRangeError(index, child.output.size, s) |
| } |
| case o => o |
| } |
| Sort(newOrders, global, child) |
| |
| // Replace the index with the corresponding expression in aggregateExpressions. The index is |
// a 1-based position in aggregateExpressions, i.e. the output columns (select expressions).
| case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && |
| groups.exists(containUnresolvedOrdinal) => |
| val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs)) |
| Aggregate(newGroups, aggs, child) |
| } |
| |
| private def containUnresolvedOrdinal(e: Expression): Boolean = e match { |
| case _: UnresolvedOrdinal => true |
| case gs: BaseGroupingSets => gs.children.exists(containUnresolvedOrdinal) |
| case _ => false |
| } |
| |
| private def resolveGroupByExpressionOrdinal( |
| expr: Expression, |
| aggs: Seq[Expression]): Expression = expr match { |
| case ordinal @ UnresolvedOrdinal(index) => |
| withPosition(ordinal) { |
| if (index > 0 && index <= aggs.size) { |
| val ordinalExpr = aggs(index - 1) |
| if (ordinalExpr.exists(_.isInstanceOf[AggregateExpression])) { |
| throw QueryCompilationErrors.groupByPositionRefersToAggregateFunctionError( |
| index, ordinalExpr) |
| } else { |
| trimAliases(ordinalExpr) match { |
// HACK ALERT: If the ordinal expression is also an integer literal, don't use it
//             but still keep the ordinal literal. The reason is we may repeatedly
//             analyze the plan. Using a different integer literal may lead to
//             a repeated GROUP BY ordinal resolution, which is wrong. GROUP BY
//             a constant is meaningless, so the actual value does not matter here.
| // TODO: (SPARK-45932) GROUP BY ordinal should pull out grouping expressions to |
| // a Project, then the resolved ordinal expression is always |
| // `AttributeReference`. |
| case Literal(_: Int, IntegerType) => |
| Literal(index) |
| case _ => ordinalExpr |
| } |
| } |
| } else { |
| throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size) |
| } |
| } |
| case gs: BaseGroupingSets => |
| gs.withNewChildren(gs.children.map(resolveGroupByExpressionOrdinal(_, aggs))) |
| case others => others |
| } |
| } |
| |
| |
| /** |
| * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the |
| * function registry. Note that this rule doesn't try to resolve the [[UnresolvedFunction]]. It |
* only performs a simple existence check based on the function identifier to quickly identify
* undefined functions without triggering relation resolution, which may incur a potentially
* expensive partition/schema discovery process in some cases.
* To avoid duplicate external function lookups, external function identifiers are stored
* in the local hash set externalFunctionNameSet.
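* For example (illustrative), `SELECT my_udf(a) FROM t` fails fast here if `my_udf` is not
* registered anywhere, without resolving the relation `t` first.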
| * @see [[ResolveFunctions]] |
| * @see https://issues.apache.org/jira/browse/SPARK-19737 |
| */ |
| object LookupFunctions extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = { |
| val externalFunctionNameSet = new mutable.HashSet[Seq[String]]() |
| |
| plan.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) { |
| case f @ UnresolvedFunction(nameParts, _, _, _, _, _) => |
| if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts).isDefined) { |
| f |
| } else { |
| val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) |
| val fullName = |
| normalizeFuncName((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq) |
| if (externalFunctionNameSet.contains(fullName)) { |
| f |
| } else if (catalog.asFunctionCatalog.functionExists(ident)) { |
| externalFunctionNameSet.add(fullName) |
| f |
| } else { |
| val catalogPath = (catalog.name() +: catalogManager.currentNamespace).mkString(".") |
| throw QueryCompilationErrors.unresolvedRoutineError( |
| nameParts, |
| Seq("system.builtin", "system.session", catalogPath), |
| f.origin) |
| } |
| } |
| } |
| } |
| |
| def normalizeFuncName(name: Seq[String]): Seq[String] = { |
| if (conf.caseSensitiveAnalysis) { |
| name |
| } else { |
| name.map(_.toLowerCase(Locale.ROOT)) |
| } |
| } |
| } |
| |
| /** |
| * Replaces [[UnresolvedFunctionName]]s with concrete [[LogicalPlan]]s. |
| * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. |
| * Replaces [[UnresolvedGenerator]]s with concrete [[Expression]]s. |
| * Replaces [[UnresolvedTableValuedFunction]]s with concrete [[LogicalPlan]]s. |
| */ |
| object ResolveFunctions extends Rule[LogicalPlan] { |
| val trimWarningEnabled = new AtomicBoolean(true) |
| |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsAnyPattern(UNRESOLVED_FUNC, UNRESOLVED_FUNCTION, GENERATOR, |
| UNRESOLVED_TABLE_VALUED_FUNCTION, UNRESOLVED_TVF_ALIASES), ruleId) { |
| // Resolve functions with concrete relations from v2 catalog. |
| case u @ UnresolvedFunctionName(nameParts, cmd, requirePersistentFunc, mismatchHint, _) => |
| lookupBuiltinOrTempFunction(nameParts) |
| .orElse(lookupBuiltinOrTempTableFunction(nameParts)).map { info => |
| if (requirePersistentFunc) { |
| throw QueryCompilationErrors.expectPersistentFuncError( |
| nameParts.head, cmd, mismatchHint, u) |
| } else { |
| ResolvedNonPersistentFunc(nameParts.head, V1Function(info)) |
| } |
| }.getOrElse { |
| val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) |
| val fullName = catalog.name +: ident.namespace :+ ident.name |
| CatalogV2Util.loadFunction(catalog, ident).map { func => |
| ResolvedPersistentFunc(catalog.asFunctionCatalog, ident, func) |
| }.getOrElse(u.copy(possibleQualifiedName = Some(fullName.toImmutableArraySeq))) |
| } |
| |
| // Resolve table-valued function references. |
| case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => |
| withPosition(u) { |
| try { |
| val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { |
| val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name) |
| if (CatalogV2Util.isSessionCatalog(catalog)) { |
| v1SessionCatalog.resolvePersistentTableFunction( |
| ident.asFunctionIdentifier, u.functionArgs) |
| } else { |
| throw QueryCompilationErrors.missingCatalogAbilityError( |
| catalog, "table-valued functions") |
| } |
| } |
| resolvedFunc.transformAllExpressionsWithPruning( |
| _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) { |
| case t: FunctionTableSubqueryArgumentExpression => |
| resolvedFunc match { |
| case Generate(_: PythonUDTF, _, _, _, _, _) => |
| case Generate(_: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) => |
| case _ => |
| assert(!t.hasRepartitioning, |
| "Cannot evaluate the table-valued function call because it included the " + |
| "PARTITION BY clause, but only Python table functions support this " + |
| "clause") |
| } |
| t |
| } |
| } catch { |
| case _: NoSuchFunctionException => |
| u.failAnalysis( |
| errorClass = "UNRESOLVABLE_TABLE_VALUED_FUNCTION", |
| messageParameters = Map("name" -> toSQLId(u.name))) |
| } |
| } |
| |
| // Resolve table-valued functions' output column aliases. |
| case u: UnresolvedTVFAliases if u.child.resolved => |
| // Add `Project` with the aliases. |
| val outputAttrs = u.child.output |
// Checks if the number of aliases is equal to the expected one.
| if (u.outputNames.size != outputAttrs.size) { |
| u.failAnalysis( |
| errorClass = "NUM_TABLE_VALUE_ALIASES_MISMATCH", |
| messageParameters = Map( |
| "funcName" -> toSQLId(u.name), |
| "aliasesNum" -> u.outputNames.size.toString, |
| "outColsNum" -> outputAttrs.size.toString)) |
| } |
| val aliases = outputAttrs.zip(u.outputNames).map { |
| case (attr, name) => Alias(attr, name)() |
| } |
| Project(aliases, u.child) |
| |
| case p: LogicalPlan |
| if p.resolved && p.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION) => |
| withPosition(p) { |
| val tableArgs = |
| mutable.ArrayBuffer.empty[(FunctionTableSubqueryArgumentExpression, LogicalPlan)] |
| |
| val tvf = p.transformExpressionsWithPruning( |
| _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) { |
| case t: FunctionTableSubqueryArgumentExpression => |
| val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") |
| tableArgs.append((t, SubqueryAlias(alias, t.evaluable))) |
| UnresolvedAttribute(Seq(alias, "c")) |
| } |
| |
| assert(tableArgs.nonEmpty) |
| if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) { |
| throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError( |
| tableArgs.size) |
| } |
| val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") |
| |
| // Propagate the column indexes for TABLE arguments to the PythonUDTF instance. |
| val tvfWithTableColumnIndexes = tvf match { |
| case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _) |
| if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty => |
| val partitionColumnIndexes = |
| PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes) |
| g.copy(generator = pyudtf.copy( |
| pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes))) |
| case _ => tvf |
| } |
| |
| Project( |
| Seq(UnresolvedStar(Some(Seq(alias)))), |
| LateralJoin( |
| tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)), |
| LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None) |
| ) |
| } |
| |
| case q: LogicalPlan => |
| q.transformExpressionsUpWithPruning( |
| _.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR), |
| ruleId) { |
| case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _) |
| if hasLambdaAndResolvedArguments(arguments) => withPosition(u) { |
| resolveBuiltinOrTempFunction(nameParts, arguments, Some(u)).map { |
| case func: HigherOrderFunction => func |
| case other => other.failAnalysis( |
| errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", |
| messageParameters = Map( |
| "class" -> other.getClass.getCanonicalName)) |
| }.getOrElse { |
| throw QueryCompilationErrors.unresolvedRoutineError( |
| nameParts, |
| // We don't support persistent high-order functions yet. |
| Seq("system.builtin", "system.session"), |
| u.origin) |
| } |
| } |
| |
| case u if !u.childrenResolved => u // Skip until children are resolved. |
| |
| case u @ UnresolvedGenerator(name, arguments) => withPosition(u) { |
| // For generator function, the parser only accepts v1 function name and creates |
| // `FunctionIdentifier`. |
| v1SessionCatalog.lookupFunction(name, arguments) match { |
| case generator: Generator => generator |
| case other => throw QueryCompilationErrors.generatorNotExpectedError( |
| name, other.getClass.getCanonicalName) |
| } |
| } |
| |
| case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _) => withPosition(u) { |
| resolveBuiltinOrTempFunction(nameParts, arguments, Some(u)).getOrElse { |
| val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) |
| if (CatalogV2Util.isSessionCatalog(catalog)) { |
| resolveV1Function(ident.asFunctionIdentifier, arguments, u) |
| } else { |
| resolveV2Function(catalog.asFunctionCatalog, ident, arguments, u) |
| } |
| } |
| } |
| |
| case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) { |
| // Check if this is a call to a Python user-defined table function whose polymorphic |
// 'analyze' method returned metadata indicating requested partitioning and/or
| // ordering properties of the input relation. In that event, make sure that the UDTF |
| // call did not include any explicit PARTITION BY and/or ORDER BY clauses for the |
| // corresponding TABLE argument, and then update the TABLE argument representation |
| // to apply the requested partitioning and/or ordering. |
| val analyzeResult = u.resolveElementMetadata(u.func, u.children) |
| val newChildren = u.children.map { |
| case NamedArgumentExpression(key, t: FunctionTableSubqueryArgumentExpression) => |
| NamedArgumentExpression(key, analyzeResult.applyToTableArgument(u.name, t)) |
| case t: FunctionTableSubqueryArgumentExpression => |
| analyzeResult.applyToTableArgument(u.name, t) |
| case c => c |
| } |
| PythonUDTF( |
| u.name, u.func, analyzeResult.schema, Some(analyzeResult.pickledAnalyzeResult), |
| newChildren, u.evalType, u.udfDeterministic, u.resultId) |
| } |
| } |
| } |
| |
| /** |
* Checks whether the arguments of a function are either resolved or lambda functions.
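* For example (illustrative), for `transform(arr, x -> x + 1)` this returns true once `arr`
* is resolved, since the lambda itself cannot be resolved before the function is bound.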
| */ |
| private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = { |
| val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction]) |
| lambdas.nonEmpty && others.forall(_.resolved) |
| } |
| |
| def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = { |
| if (name.length == 1) { |
| v1SessionCatalog.lookupBuiltinOrTempFunction(name.head) |
| } else { |
| None |
| } |
| } |
| |
| def lookupBuiltinOrTempTableFunction(name: Seq[String]): Option[ExpressionInfo] = { |
| if (name.length == 1) { |
| v1SessionCatalog.lookupBuiltinOrTempTableFunction(name.head) |
| } else { |
| None |
| } |
| } |
| |
| private def resolveBuiltinOrTempFunction( |
| name: Seq[String], |
| arguments: Seq[Expression], |
| u: Option[UnresolvedFunction]): Option[Expression] = { |
| if (name.length == 1) { |
| v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments).map { func => |
| if (u.isDefined) validateFunction(func, arguments.length, u.get) else func |
| } |
| } else { |
| None |
| } |
| } |
| |
| private def resolveBuiltinOrTempTableFunction( |
| name: Seq[String], |
| arguments: Seq[Expression]): Option[LogicalPlan] = { |
| if (name.length == 1) { |
| v1SessionCatalog.resolveBuiltinOrTempTableFunction(name.head, arguments) |
| } else { |
| None |
| } |
| } |
| |
| private def resolveV1Function( |
| ident: FunctionIdentifier, |
| arguments: Seq[Expression], |
| u: UnresolvedFunction): Expression = { |
| val func = v1SessionCatalog.resolvePersistentFunction(ident, arguments) |
| validateFunction(func, arguments.length, u) |
| } |
| |
| private def validateFunction( |
| func: Expression, |
| numArgs: Int, |
| u: UnresolvedFunction): Expression = { |
| func match { |
| case owg: SupportsOrderingWithinGroup if u.isDistinct => |
| throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError( |
| owg.prettyName) |
| case owg: SupportsOrderingWithinGroup |
| if !owg.orderingFilled && u.orderingWithinGroup.isEmpty => |
| throw QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError( |
| owg.prettyName) |
| case owg: SupportsOrderingWithinGroup |
| if owg.orderingFilled && u.orderingWithinGroup.nonEmpty => |
| throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError( |
| owg.prettyName, 0, u.orderingWithinGroup.length) |
| case f |
| if !f.isInstanceOf[SupportsOrderingWithinGroup] && u.orderingWithinGroup.nonEmpty => |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| func.prettyName, "WITHIN GROUP (ORDER BY ...)") |
| // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within |
| // the context of a Window clause. They do not need to be wrapped in an |
| // AggregateExpression. |
| case wf: AggregateWindowFunction => |
| if (u.isDistinct) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| wf.prettyName, "DISTINCT") |
| } else if (u.filter.isDefined) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| wf.prettyName, "FILTER clause") |
| } else if (u.ignoreNulls) { |
| wf match { |
| case nthValue: NthValue => |
| nthValue.copy(ignoreNulls = u.ignoreNulls) |
| case _ => |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| wf.prettyName, "IGNORE NULLS") |
| } |
| } else { |
| wf |
| } |
| case owf: FrameLessOffsetWindowFunction => |
| if (u.isDistinct) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| owf.prettyName, "DISTINCT") |
| } else if (u.filter.isDefined) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| owf.prettyName, "FILTER clause") |
| } else if (u.ignoreNulls) { |
| owf match { |
| case lead: Lead => |
| lead.copy(ignoreNulls = u.ignoreNulls) |
| case lag: Lag => |
| lag.copy(ignoreNulls = u.ignoreNulls) |
| } |
| } else { |
| owf |
| } |
| // We get an aggregate function, we need to wrap it in an AggregateExpression. |
| case agg: AggregateFunction => |
| // Note: PythonUDAF does not support these advanced clauses. |
| if (agg.isInstanceOf[PythonUDAF]) checkUnsupportedAggregateClause(agg, u) |
// After parsing, inverse distribution functions have not set the ordering within group yet.
| val newAgg = agg match { |
| case owg: SupportsOrderingWithinGroup |
| if !owg.orderingFilled && u.orderingWithinGroup.nonEmpty => |
| owg.withOrderingWithinGroup(u.orderingWithinGroup) |
| case _ => |
| agg |
| } |
| |
| u.filter match { |
| case Some(filter) if !filter.deterministic => |
| throw QueryCompilationErrors.nonDeterministicFilterInAggregateError( |
| filterExpr = filter) |
| case Some(filter) if filter.dataType != BooleanType => |
| throw QueryCompilationErrors.nonBooleanFilterInAggregateError( |
| filterExpr = filter) |
| case Some(filter) if filter.exists(_.isInstanceOf[AggregateExpression]) => |
| throw QueryCompilationErrors.aggregateInAggregateFilterError( |
| filterExpr = filter, |
| aggExpr = filter.find(_.isInstanceOf[AggregateExpression]).get) |
| case Some(filter) if filter.exists(_.isInstanceOf[WindowExpression]) => |
| throw QueryCompilationErrors.windowFunctionInAggregateFilterError( |
| filterExpr = filter, |
| windowExpr = filter.find(_.isInstanceOf[WindowExpression]).get) |
| case _ => |
| } |
| if (u.ignoreNulls) { |
| val aggFunc = newAgg match { |
| case first: First => first.copy(ignoreNulls = u.ignoreNulls) |
| case last: Last => last.copy(ignoreNulls = u.ignoreNulls) |
| case any_value: AnyValue => any_value.copy(ignoreNulls = u.ignoreNulls) |
| case _ => |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| newAgg.prettyName, "IGNORE NULLS") |
| } |
| aggFunc.toAggregateExpression(u.isDistinct, u.filter) |
| } else { |
| newAgg.toAggregateExpression(u.isDistinct, u.filter) |
| } |
| // This function is not an aggregate function, just return the resolved one. |
| case other => |
| checkUnsupportedAggregateClause(other, u) |
| if (other.isInstanceOf[String2TrimExpression] && numArgs == 2) { |
| if (trimWarningEnabled.get) { |
| log.warn("Two-parameter TRIM/LTRIM/RTRIM function signatures are deprecated." + |
| " Use SQL syntax `TRIM((BOTH | LEADING | TRAILING)? trimStr FROM str)`" + |
| " instead.") |
| trimWarningEnabled.set(false) |
| } |
| } |
| other |
| } |
| } |
| |
| private def checkUnsupportedAggregateClause(func: Expression, u: UnresolvedFunction): Unit = { |
| if (u.isDistinct) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| func.prettyName, "DISTINCT") |
| } |
| if (u.filter.isDefined) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| func.prettyName, "FILTER clause") |
| } |
| if (u.ignoreNulls) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| func.prettyName, "IGNORE NULLS") |
| } |
| } |
| |
| private def resolveV2Function( |
| catalog: FunctionCatalog, |
| ident: Identifier, |
| arguments: Seq[Expression], |
| u: UnresolvedFunction): Expression = { |
| val unbound = catalog.loadFunction(ident) |
| val inputType = StructType(arguments.zipWithIndex.map { |
| case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable) |
| }) |
| val bound = try { |
| unbound.bind(inputType) |
| } catch { |
| case unsupported: UnsupportedOperationException => |
| throw QueryCompilationErrors.functionCannotProcessInputError( |
| unbound, arguments, unsupported) |
| } |
| |
| if (bound.inputTypes().length != arguments.length) { |
| throw QueryCompilationErrors.v2FunctionInvalidInputTypeLengthError( |
| bound, arguments) |
| } |
| |
| bound match { |
| case scalarFunc: ScalarFunction[_] => |
| processV2ScalarFunction(scalarFunc, arguments, u) |
| case aggFunc: V2AggregateFunction[_, _] => |
| processV2AggregateFunction(aggFunc, arguments, u) |
| case _ => |
| failAnalysis( |
| errorClass = "INVALID_UDF_IMPLEMENTATION", |
| messageParameters = Map("funcName" -> toSQLId(bound.name()))) |
| } |
| } |
| |
| private def processV2ScalarFunction( |
| scalarFunc: ScalarFunction[_], |
| arguments: Seq[Expression], |
| u: UnresolvedFunction): Expression = { |
| if (u.isDistinct) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| scalarFunc.name(), "DISTINCT") |
| } else if (u.filter.isDefined) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| scalarFunc.name(), "FILTER clause") |
| } else if (u.ignoreNulls) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| scalarFunc.name(), "IGNORE NULLS") |
| } else { |
| V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) |
| } |
| } |
| |
| private def processV2AggregateFunction( |
| aggFunc: V2AggregateFunction[_, _], |
| arguments: Seq[Expression], |
| u: UnresolvedFunction): Expression = { |
| if (u.ignoreNulls) { |
| throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( |
| aggFunc.name(), "IGNORE NULLS") |
| } |
| val aggregator = V2Aggregator(aggFunc, arguments) |
| aggregator.toAggregateExpression(u.isDistinct, u.filter) |
| } |
| } |
| |
| /** |
| * This rule resolves and rewrites subqueries inside expressions. |
| * |
| * Note: CTEs are handled in CTESubstitution. |
| */ |
| object ResolveSubquery extends Rule[LogicalPlan] { |
| /** |
| * Resolves the subquery plan that is referenced in a subquery expression, by invoking the |
| * entire analyzer recursively. We set outer plan in `AnalysisContext`, so that the analyzer |
| * can resolve outer references. |
| * |
| * Outer references of the subquery are updated as children of Subquery expression. |
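*
* For example (illustrative), in `SELECT * FROM a WHERE EXISTS (SELECT 1 FROM b WHERE
* b.x = a.x)`, the outer reference `a.x` is recorded as a child of the Exists expression.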
| */ |
| private def resolveSubQuery( |
| e: SubqueryExpression, |
| outer: LogicalPlan)( |
| f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { |
| val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) { |
| executeSameContext(e.plan) |
| } |
| |
| // If the subquery plan is fully resolved, pull the outer references and record |
| // them as children of SubqueryExpression. |
| if (newSubqueryPlan.resolved) { |
| // Record the outer references as children of subquery expression. |
| f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan)) |
| } else { |
| e.withNewPlan(newSubqueryPlan) |
| } |
| } |
| |
| /** |
* Resolves the subquery. Apart from resolving the subquery and outer references (if any)
| * in the subquery plan, the children of subquery expression are updated to record the |
| * outer references. This is needed to make sure |
| * (1) The column(s) referred from the outer query are not pruned from the plan during |
| * optimization. |
| * (2) Any aggregate expression(s) that reference outer attributes are pushed down to |
* the outer plan to get evaluated.
| */ |
| private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { |
| plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { |
| case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved => |
| resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId)) |
| case e @ Exists(sub, _, exprId, _, _) if !sub.resolved => |
| resolveSubQuery(e, outer)(Exists(_, _, exprId)) |
| case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _)) |
| if values.forall(_.resolved) && !l.resolved => |
| val expr = resolveSubQuery(l, outer)((plan, exprs) => { |
| ListQuery(plan, exprs, exprId, plan.output.length) |
| }) |
| InSubquery(values, expr.asInstanceOf[ListQuery]) |
| case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved => |
| resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId)) |
| case a: FunctionTableSubqueryArgumentExpression if !a.plan.resolved => |
| resolveSubQuery(a, outer)( |
| (plan, outerAttrs) => a.copy(plan = plan, outerAttrs = outerAttrs)) |
| } |
| } |
| |
| /** |
* Resolves and rewrites all subqueries in an operator tree.
| */ |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(PLAN_EXPRESSION), ruleId) { |
| case j: LateralJoin if j.left.resolved => |
| // We can't pass `LateralJoin` as the outer plan, as its right child is not resolved yet |
| // and we can't call `LateralJoin.resolveChildren` to resolve outer references. Here we |
| // create a fake Project node as the outer plan. |
| resolveSubQueries(j, Project(Nil, j.left)) |
| // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries. |
| case q: UnaryNode if q.childrenResolved => |
| resolveSubQueries(q, q) |
| case r: RelationTimeTravel => |
| resolveSubQueries(r, r) |
| case j: Join if j.childrenResolved && j.duplicateResolved => |
| resolveSubQueries(j, j) |
| case tvf: UnresolvedTableValuedFunction => |
| resolveSubQueries(tvf, tvf) |
| case s: SupportsSubquery if s.childrenResolved => |
| resolveSubQueries(s, s) |
| } |
| } |
| |
| /** |
| * Replaces unresolved column aliases for a subquery with projections. |
| */ |
| object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { |
| |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(UNRESOLVED_SUBQUERY_COLUMN_ALIAS), ruleId) { |
| case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => |
| // Resolves output attributes if a query has alias names in its subquery: |
| // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) |
| val outputAttrs = child.output |
// Checks if the number of aliases equals the number of output columns
| // in the subquery. |
| if (columnNames.size != outputAttrs.size) { |
| throw QueryCompilationErrors.aliasNumberNotMatchColumnNumberError( |
| columnNames.size, outputAttrs.size, u) |
| } |
| val aliases = outputAttrs.zip(columnNames).map { case (attr, aliasName) => |
| Alias(attr, aliasName)() |
| } |
| Project(aliases, child) |
| } |
| } |
| |
| /** |
| * Turns projections that contain aggregate expressions into aggregations. |
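*
* For example (illustrative), `SELECT count(*) FROM t` is parsed as a Project over the
* relation and is turned here into a global Aggregate with empty grouping expressions.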
| */ |
| object GlobalAggregates extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| t => t.containsAnyPattern(AGGREGATE_EXPRESSION, PYTHON_UDF) && t.containsPattern(PROJECT), |
| ruleId) { |
| case Project(projectList, child) if containsAggregates(projectList) => |
| Aggregate(Nil, projectList, child) |
| } |
| |
| def containsAggregates(exprs: Seq[Expression]): Boolean = { |
| // Collect all Windowed Aggregate Expressions. |
| val windowedAggExprs: Set[Expression] = exprs.flatMap { expr => |
| expr.collect { |
| case WindowExpression(ae: AggregateExpression, _) => ae |
| case UnresolvedWindowExpression(ae: AggregateExpression, _) => ae |
| } |
| }.toSet |
| |
| // Check whether there is any Aggregate Expression that is not Windowed. |
| exprs.exists(_.exists { |
| case ae: AggregateExpression => !windowedAggExprs.contains(ae) |
| case _ => false |
| }) |
| } |
| } |
| |
| /** |
| * This rule finds aggregate expressions that are not in an aggregate operator. For example, |
| * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the |
| * underlying aggregate operator and then projected away after the original operator. |
| * |
| * We need to make sure the expressions are all fully resolved before looking for aggregate |
| * functions and group by expressions in them. |
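| * |
| * For example (a sketch), for |
| * {{{ |
| * SELECT a FROM t GROUP BY a HAVING max(b) > 0 |
| * }}} |
| * the aggregate function `max(b)` in the HAVING condition is appended to the underlying |
| * Aggregate (aliased with its own string representation), the Filter condition then |
| * references the new attribute, and a final Project keeps only the original output column |
| * `a`. |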
| */ |
| object ResolveAggregateFunctions extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(AGGREGATE), ruleId) { |
| case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved && cond.resolved => |
| resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { |
| val newCond = newExprs.head |
| if (newCond.resolved) { |
| Filter(newCond, newChild) |
| } else { |
| // The condition can still be unresolved after the resolution above, as we may mark a |
| // `TempResolvedColumn` as unresolved if it is neither an aggregate function input nor a |
| // grouping expression. We should keep `UnresolvedHaving`, as the rule `ResolveReferences` |
| // can re-resolve `TempResolvedColumn`, and `UnresolvedHaving` has a special column |
| // resolution order. |
| UnresolvedHaving(newCond, newChild) |
| } |
| }) |
| |
| case Filter(cond, agg: Aggregate) if agg.resolved && cond.resolved => |
| resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => { |
| Filter(newExprs.head, newChild) |
| }) |
| |
| case s @ Sort(_, _, agg: Aggregate) if agg.resolved && s.order.forall(_.resolved) => |
| resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => { |
| val newSortOrder = s.order.zip(newExprs).map { |
| case (sortOrder, expr) => sortOrder.copy(child = expr) |
| } |
| s.copy(order = newSortOrder, child = newChild) |
| }) |
| |
| case s @ Sort(_, _, f @ Filter(cond, agg: Aggregate)) |
| if agg.resolved && cond.resolved && s.order.forall(_.resolved) => |
| resolveOperatorWithAggregate(s.order.map(_.child), agg, (newExprs, newChild) => { |
| val newSortOrder = s.order.zip(newExprs).map { |
| case (sortOrder, expr) => sortOrder.copy(child = expr) |
| } |
| s.copy(order = newSortOrder, child = f.copy(child = newChild)) |
| }) |
| } |
| |
| /** |
| * Resolves the given expressions as if they are in the given Aggregate operator, which means |
| * the column can be resolved using `agg.child` and aggregate functions/grouping columns are |
| * allowed. It returns a list of named expressions that need to be appended to |
| * `agg.aggregateExpressions`, and the list of resolved expressions. |
| */ |
| def resolveExprsWithAggregate( |
| exprs: Seq[Expression], |
| agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = { |
| val extraAggExprs = ArrayBuffer.empty[NamedExpression] |
| val transformed = exprs.map { e => |
| if (!e.resolved) { |
| e |
| } else { |
| buildAggExprList(e, agg, extraAggExprs) |
| } |
| } |
| (extraAggExprs.toSeq, transformed) |
| } |
| |
| private def buildAggExprList( |
| expr: Expression, |
| agg: Aggregate, |
| aggExprList: ArrayBuffer[NamedExpression]): Expression = { |
| // Avoid adding an extra aggregate expression if it's already present in |
| // `agg.aggregateExpressions`. |
| val index = agg.aggregateExpressions.indexWhere { |
| case Alias(child, _) => child semanticEquals expr |
| case other => other semanticEquals expr |
| } |
| if (index >= 0) { |
| agg.aggregateExpressions(index).toAttribute |
| } else { |
| expr match { |
| case ae: AggregateExpression => |
| val cleaned = trimTempResolvedColumn(ae) |
| val alias = Alias(cleaned, cleaned.toString)() |
| aggExprList += alias |
| alias.toAttribute |
| case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) => |
| trimTempResolvedColumn(grouping) match { |
| case ne: NamedExpression => |
| aggExprList += ne |
| ne.toAttribute |
| case other => |
| val alias = Alias(other, other.toString)() |
| aggExprList += alias |
| alias.toAttribute |
| } |
| case t: TempResolvedColumn => |
| if (t.child.isInstanceOf[Attribute]) { |
| // This column is neither inside aggregate functions nor a grouping column. It |
| // shouldn't be resolved with `agg.child.output`. Mark it as "hasTried", so that it |
| // can be re-resolved later or go back to `UnresolvedAttribute` at the end. |
| withOrigin(t.origin)(t.copy(hasTried = true)) |
| } else { |
| // This is a nested column; we still have a chance to match grouping expressions with |
| // the top-level column. Here we wrap the underlying `Attribute` with |
| // `TempResolvedColumn` and try again. |
| val childWithTempCol = t.child.transformUp { |
| case a: Attribute => TempResolvedColumn(a, Seq(a.name)) |
| } |
| val newChild = buildAggExprList(childWithTempCol, agg, aggExprList) |
| if (newChild.containsPattern(TEMP_RESOLVED_COLUMN)) { |
| withOrigin(t.origin)(t.copy(hasTried = true)) |
| } else { |
| newChild |
| } |
| } |
| case other => |
| other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList))) |
| } |
| } |
| } |
| |
| private def trimTempResolvedColumn(input: Expression): Expression = input.transform { |
| case t: TempResolvedColumn => t.child |
| } |
| |
| def resolveOperatorWithAggregate( |
| exprs: Seq[Expression], |
| agg: Aggregate, |
| buildOperator: (Seq[Expression], Aggregate) => LogicalPlan): LogicalPlan = { |
| val (extraAggExprs, resolvedExprs) = resolveExprsWithAggregate(exprs, agg) |
| if (extraAggExprs.isEmpty) { |
| buildOperator(resolvedExprs, agg) |
| } else { |
| Project(agg.output, buildOperator(resolvedExprs, agg.copy( |
| aggregateExpressions = agg.aggregateExpressions ++ extraAggExprs))) |
| } |
| } |
| } |
| |
| /** |
| * Extracts a [[Generator]] from the projectList of a [[Project]] operator and creates a |
| * [[Generate]] operator under [[Project]]. |
| * |
| * This rule will throw [[AnalysisException]] for the following cases: |
| * 1. [[Generator]] is nested in expressions, e.g. `SELECT explode(list) + 1 FROM tbl` |
| * 2. more than one [[Generator]] is found in projectList, |
| * e.g. `SELECT explode(list), explode(list) FROM tbl` |
| * 3. [[Generator]] is found in other operators that are not [[Project]] or [[Generate]], |
| * e.g. `SELECT * FROM tbl SORT BY explode(list)` |
| */ |
| object ExtractGenerator extends Rule[LogicalPlan] { |
| def hasGenerator(expr: Expression): Boolean = { |
| expr.exists(_.isInstanceOf[Generator]) |
| } |
| |
| private def hasNestedGenerator(expr: NamedExpression): Boolean = { |
| @scala.annotation.tailrec |
| def hasInnerGenerator(g: Generator): Boolean = g match { |
| // Since `GeneratorOuter` is just a wrapper of generators, we skip it here |
| case go: GeneratorOuter => |
| hasInnerGenerator(go.child) |
| case _ => |
| g.children.exists { _.exists { |
| case _: Generator => true |
| case _ => false |
| } } |
| } |
| trimNonTopLevelAliases(expr) match { |
| case UnresolvedAlias(g: Generator, _) => hasInnerGenerator(g) |
| case Alias(g: Generator, _) => hasInnerGenerator(g) |
| case MultiAlias(g: Generator, _) => hasInnerGenerator(g) |
| case other => hasGenerator(other) |
| } |
| } |
| |
| private def hasAggFunctionInGenerator(ne: Seq[NamedExpression]): Boolean = { |
| ne.exists(_.exists { |
| case g: Generator => |
| g.children.exists(_.exists(_.isInstanceOf[AggregateFunction])) |
| case _ => |
| false |
| }) |
| } |
| |
| private def trimAlias(expr: NamedExpression): Expression = expr match { |
| case UnresolvedAlias(child, _) => child |
| case Alias(child, _) => child |
| case MultiAlias(child, _) => child |
| case _ => expr |
| } |
| |
| private object AliasedGenerator { |
| /** |
| * Extracts a [[Generator]] expression, any names assigned by aliases to the outputs |
| * and the outer flag. The outer flag is used when joining the generator output. |
| * @param e the [[Expression]] |
| * @return (the [[Generator]], seq of output names, outer flag) |
| */ |
| def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match { |
| case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true)) |
| case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some((g, names, true)) |
| case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false)) |
| case MultiAlias(g: Generator, names) if g.resolved => Some((g, names, false)) |
| case _ => None |
| } |
| } |
| |
| // We must wait until all expressions except for generator functions are resolved before |
| // rewriting generator functions in Project/Aggregate. This is necessary to make this rule |
| // stable for different execution orders of analyzer rules. See also SPARK-47241. |
| private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean = { |
| namedExprs.forall { ne => |
| ne.resolved || { |
| trimNonTopLevelAliases(ne) match { |
| case AliasedGenerator(_, _, _) => true |
| case _ => false |
| } |
| } |
| } |
| } |
| |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(GENERATOR), ruleId) { |
| case Project(projectList, _) if projectList.exists(hasNestedGenerator) => |
| val nestedGenerator = projectList.find(hasNestedGenerator).get |
| throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) |
| |
| case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) => |
| val nestedGenerator = aggList.find(hasNestedGenerator).get |
| throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) |
| |
| case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 => |
| val generators = aggList.filter(hasGenerator).map(trimAlias) |
| throw QueryCompilationErrors.moreThanOneGeneratorError(generators) |
| |
| case Aggregate(groupList, aggList, child) if canRewriteGenerator(aggList) && |
| aggList.exists(hasGenerator) => |
| // Set to true once a generator in the aggregate list has been visited. |
| var generatorVisited = false |
| |
| val projectExprs = Array.ofDim[NamedExpression](aggList.length) |
| val newAggList = aggList |
| .toIndexedSeq |
| .map(trimNonTopLevelAliases) |
| .zipWithIndex |
| .flatMap { |
| case (AliasedGenerator(generator, names, outer), idx) => |
| // This is a sanity check; it should not happen, as the previous case would have |
| // thrown an exception earlier. |
| assert(!generatorVisited, "More than one generator found in aggregate.") |
| generatorVisited = true |
| |
| val newGenChildren: Seq[Expression] = generator.children.zipWithIndex.map { |
| case (e, idx) => if (e.foldable) e else Alias(e, s"_gen_input_${idx}")() |
| } |
| val newGenerator = { |
| val g = generator.withNewChildren(newGenChildren.map { e => |
| if (e.foldable) e else e.asInstanceOf[Alias].toAttribute |
| }).asInstanceOf[Generator] |
| if (outer) GeneratorOuter(g) else g |
| } |
| val newAliasedGenerator = if (names.length == 1) { |
| Alias(newGenerator, names(0))() |
| } else { |
| MultiAlias(newGenerator, names) |
| } |
| projectExprs(idx) = newAliasedGenerator |
| newGenChildren.filter(!_.foldable).asInstanceOf[Seq[NamedExpression]] |
| case (other, idx) => |
| projectExprs(idx) = other.toAttribute |
| other :: Nil |
| } |
| |
| val newAgg = Aggregate(groupList, newAggList, child) |
| Project(projectExprs.toList, newAgg) |
| |
| case p @ Project(projectList, _) if hasAggFunctionInGenerator(projectList) => |
| // If a generator has any aggregate function, we need to apply the `GlobalAggregates` rule |
| // first for replacing `Project` with `Aggregate`. |
| p |
| |
| case p @ Project(projectList, child) if canRewriteGenerator(projectList) && |
| projectList.exists(hasGenerator) => |
| val (resolvedGenerator, newProjectList) = projectList |
| .map(trimNonTopLevelAliases) |
| .foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { (res, e) => |
| e match { |
| // If there is more than one generator, we only rewrite the first one here and let |
| // the next analyzer iteration rewrite the next one. |
| case AliasedGenerator(generator, names, outer) if res._1.isEmpty && |
| generator.childrenResolved => |
| val g = Generate( |
| generator, |
| unrequiredChildIndex = Nil, |
| outer = outer, |
| qualifier = None, |
| generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), |
| child) |
| (Some(g), res._2 ++ g.nullableOutput) |
| case other => |
| (res._1, res._2 :+ other) |
| } |
| } |
| |
| if (resolvedGenerator.isDefined) { |
| Project(newProjectList, resolvedGenerator.get) |
| } else { |
| p |
| } |
| |
| case g @ Generate(GeneratorOuter(generator), _, _, _, _, _) => |
| g.copy(generator = generator, outer = true) |
| |
| case g: Generate => g |
| |
| case u: UnresolvedTableValuedFunction => u |
| |
| case p: Project => p |
| |
| case a: Aggregate => a |
| |
| case p if p.expressions.exists(hasGenerator) => |
| throw QueryCompilationErrors.generatorOutsideSelectError(p) |
| } |
| } |
| |
| /** |
| * Rewrites table generating expressions that need one or more of the following in order |
| * to be resolved: |
| * - concrete attribute references for their output. |
| * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]. |
| * |
| * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions |
| * that wrap the [[Generator]]. |
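| * |
| * For example, for `SELECT explode(map('a', 1)) AS (k, v)`, the generator's element schema |
| * has two fields (key and value), and this rule renames the corresponding output attributes |
| * to `k` and `v`. |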
| */ |
| object ResolveGenerate extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(GENERATE), ruleId) { |
| case g: Generate if !g.child.resolved || !g.generator.resolved => g |
| case g: Generate if !g.resolved => withPosition(g) { |
| // Check nested generators. |
| if (g.generator.children.exists(ExtractGenerator.hasGenerator)) { |
| throw QueryCompilationErrors.nestedGeneratorError(g.generator) |
| } |
| g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) |
| } |
| } |
| |
| /** |
| * Constructs the output attributes for a [[Generator]], given a list of names. If the list |
| * of names is empty, names are taken from the field names of the generator's element schema. |
| */ |
| private[analysis] def makeGeneratorOutput( |
| generator: Generator, |
| names: Seq[String]): Seq[Attribute] = { |
| val elementAttrs = DataTypeUtils.toAttributes(generator.elementSchema) |
| |
| if (names.length == elementAttrs.length) { |
| names.zip(elementAttrs).map { |
| case (name, attr) => attr.withName(name) |
| } |
| } else if (names.isEmpty) { |
| elementAttrs |
| } else { |
| throw QueryCompilationErrors.aliasesNumberNotMatchUDTFOutputError( |
| elementAttrs.size, names.mkString(",")) |
| } |
| } |
| } |
| |
| /** |
| * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and |
| * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]] |
| * operators for every distinct [[WindowSpecDefinition]]. |
| * |
| * This rule handles three cases: |
| * - A [[Project]] having [[WindowExpression]]s in its projectList; |
| * - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions; |
| * - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING |
| * clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions. |
| * Note: If there is a GROUP BY clause in the query, aggregations and corresponding |
| * filters (expressions in the HAVING clause) should be evaluated before any |
| * [[WindowExpression]]. If a query has SELECT DISTINCT, the DISTINCT part should be |
| * evaluated after all [[WindowExpression]]s. |
| * |
| * Note: [[ResolveLateralColumnAliasReference]] rule is applied before this rule. To guarantee |
| * this order, we make sure this rule applies only when the [[Project]] or [[Aggregate]] doesn't |
| * contain any [[LATERAL_COLUMN_ALIAS_REFERENCE]]. |
| * |
| * For every case, the transformation works as follows: |
| * 1. For a list of [[Expression]]s (a projectList or an aggregateExpressions), partitions |
| * it into two lists of [[Expression]]s, one for all [[WindowExpression]]s and another for |
| * all regular expressions. |
| * 2. For all [[WindowExpression]]s, groups them based on their [[WindowSpecDefinition]]s |
| * and [[WindowFunctionType]]s. |
| * 3. For every distinct [[WindowSpecDefinition]] and [[WindowFunctionType]], creates a |
| * [[Window]] operator and inserts it into the plan tree. |
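| * |
| * For example (an illustrative sketch), `SELECT a, sum(b) OVER (PARTITION BY c) FROM t` is |
| * rewritten into roughly: |
| * {{{ |
| * Project [a, _we0 AS sum(b) OVER (PARTITION BY c ...)] |
| * +- Window [sum(b) ... AS _we0], [c] |
| * +- Project [a, b, c] |
| * +- t |
| * }}} |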
| */ |
| object ExtractWindowExpressions extends Rule[LogicalPlan] { |
| type Spec = (Seq[Expression], Seq[SortOrder], WindowFunctionType) |
| |
| private def hasWindowFunction(exprs: Seq[Expression]): Boolean = |
| exprs.exists(hasWindowFunction) |
| |
| private def hasWindowFunction(expr: Expression): Boolean = { |
| expr.exists { |
| case _: WindowExpression => true |
| case _ => false |
| } |
| } |
| |
| /** |
| * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and |
| * other regular expressions that do not contain any window expression. For example, for |
| * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract |
| * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in |
| * the window expression as attribute references. So, the first returned value will be |
| * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be |
| * `[col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]`. |
| * |
| * @return (seq of expressions containing at least one window expression, |
| * seq of non-window expressions) |
| */ |
| private def extract( |
| expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = { |
| // First, we partition the input expressions into two parts. Every expression in the |
| // first part contains at least one WindowExpression; expressions in the second part |
| // do not contain any WindowExpression. |
| val (expressionsWithWindowFunctions, regularExpressions) = |
| expressions.partition(hasWindowFunction) |
| |
| // Then, we need to extract those regular expressions used in the WindowExpression. |
| // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5), |
| // we need to make sure that col1 to col5 are all projected from the child of the Window |
| // operator. |
| val extractedExprMap = mutable.LinkedHashMap.empty[Expression, NamedExpression] |
| def getOrExtract(key: Expression, value: Expression): Expression = { |
| extractedExprMap.getOrElseUpdate(key.canonicalized, |
| Alias(value, s"_w${extractedExprMap.size}")()).toAttribute |
| } |
| def extractExpr(expr: Expression): Expression = expr match { |
| case ne: NamedExpression => |
| // If a named expression is not in regularExpressions, add it to |
| // extractedExprMap and replace it with an AttributeReference. |
| val missingExpr = |
| AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprMap.values) |
| if (missingExpr.nonEmpty) { |
| extractedExprMap += ne.canonicalized -> ne |
| } |
| // alias will be cleaned in the rule CleanupAliases |
| ne |
| case e: Expression if e.foldable => |
| e // No need to create an attribute reference if it will be evaluated as a Literal. |
| case e: NamedArgumentExpression => |
| // For NamedArgumentExpression, we extract the value and replace it with |
| // an AttributeReference (with an internal column name, e.g. "_w0"). |
| NamedArgumentExpression(e.key, getOrExtract(e, e.value)) |
| case e: Expression => |
| // For other expressions, we extract it and replace it with an AttributeReference (with |
| // an internal column name, e.g. "_w0"). |
| getOrExtract(e, e) |
| } |
| |
| // Now, we extract regular expressions from expressionsWithWindowFunctions |
| // by using extractExpr. |
| val seenWindowAggregates = new ArrayBuffer[AggregateExpression] |
| val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { |
| _.transform { |
| // Extracts children expressions of a WindowFunction (input parameters of |
| // a WindowFunction). |
| case wf: WindowFunction => |
| val newChildren = wf.children.map(extractExpr) |
| wf.withNewChildren(newChildren) |
| |
| // Extracts expressions from the partition spec and order spec. |
| case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => |
| val newPartitionSpec = partitionSpec.map(extractExpr) |
| val newOrderSpec = orderSpec.map { so => |
| val newChild = extractExpr(so.child) |
| so.copy(child = newChild) |
| } |
| wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) |
| |
| case WindowExpression(ae: AggregateExpression, _) if ae.filter.isDefined => |
| throw QueryCompilationErrors.windowAggregateFunctionWithFilterNotSupportedError() |
| |
| // Extract Windowed AggregateExpression |
| case we @ WindowExpression( |
| ae @ AggregateExpression(function, _, _, _, _), |
| spec: WindowSpecDefinition) => |
| val newChildren = function.children.map(extractExpr) |
| val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] |
| val newAgg = ae.copy(aggregateFunction = newFunction) |
| seenWindowAggregates += newAgg |
| WindowExpression(newAgg, spec) |
| |
| case AggregateExpression(aggFunc, _, _, _, _) if hasWindowFunction(aggFunc.children) => |
| throw QueryCompilationErrors.windowFunctionInsideAggregateFunctionNotAllowedError() |
| |
| // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), |
| // we need to extract SUM(x). |
| case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => |
| extractedExprMap.getOrElseUpdate(agg.canonicalized, |
| Alias(agg, s"_w${extractedExprMap.size}")()).toAttribute |
| |
| // Extracts other attributes |
| case attr: Attribute => extractExpr(attr) |
| |
| }.asInstanceOf[NamedExpression] |
| } |
| |
| (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprMap.values) |
| } // end of extract |
| |
| /** |
| * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. |
| */ |
| private def addWindow( |
| expressionsWithWindowFunctions: Seq[NamedExpression], |
| child: LogicalPlan): LogicalPlan = { |
| // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions |
| // and put those extracted WindowExpressions to extractedWindowExprBuffer. |
| // This step is needed because it is possible that an expression contains multiple |
| // WindowExpressions with different Window Specs. |
| // After extracting WindowExpressions, we need to construct a project list to generate |
| // expressionsWithWindowFunctions based on extractedWindowExprBuffer. |
| // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract |
| // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to |
| // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)". |
| // Then, the projectList will be [_we0/_we1]. |
| val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]() |
| val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { |
| // We need to use transformDown because we want to trigger |
| // "case alias @ Alias(window: WindowExpression, _)" first. |
| _.transformDown { |
| case alias @ Alias(window: WindowExpression, _) => |
| // If a WindowExpression has an assigned alias, just use it. |
| extractedWindowExprBuffer += alias |
| alias.toAttribute |
| case window: WindowExpression => |
| // If there is no alias assigned to the WindowExpression, we create an |
| // internal column name for it. |
| val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")() |
| extractedWindowExprBuffer += withName |
| withName.toAttribute |
| }.asInstanceOf[NamedExpression] |
| } |
| |
| // SPARK-32616: Use a linked hash map to maintain the insertion order of the Window |
| // operators, so that a query with multiple Window operators gets a deterministic plan. |
| val groupedWindowExpressions = mutable.LinkedHashMap.empty[Spec, ArrayBuffer[NamedExpression]] |
| // Second, we group extractedWindowExprBuffer based on their Partition and Order Specs. |
| extractedWindowExprBuffer.foreach { expr => |
| val distinctWindowSpec = expr.collect { |
| case window: WindowExpression => window.windowSpec |
| }.distinct |
| |
| // We do a final check to see if there is only a single Window Spec defined in each |
| // expression. |
| if (distinctWindowSpec.isEmpty) { |
| throw QueryCompilationErrors.expressionWithoutWindowExpressionError(expr) |
| } else if (distinctWindowSpec.length > 1) { |
| // newExpressionsWithWindowFunctions should only contain expressions with a single |
| // WindowExpression. If we reach here, we have a bug. |
| throw QueryCompilationErrors.expressionWithMultiWindowExpressionsError( |
| expr, distinctWindowSpec) |
| } else { |
| val spec = distinctWindowSpec.head |
| val specKey = (spec.partitionSpec, spec.orderSpec, WindowFunctionType.functionType(expr)) |
| val windowExprs = groupedWindowExpressions |
| .getOrElseUpdate(specKey, new ArrayBuffer[NamedExpression]) |
| windowExprs += expr |
| } |
| } |
| |
| // Third, we add one Window operator per distinct Window Spec, chaining them so that |
| // each Window operator becomes the child of the next one. |
| val windowOps = |
| groupedWindowExpressions.foldLeft(child) { |
| case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => |
| Window(windowExpressions.toSeq, partitionSpec, orderSpec, last) |
| } |
| |
| // Finally, we create a Project that outputs windowOps's output plus |
| // newExpressionsWithWindowFunctions. |
| Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps) |
| } // end of addWindow |
| |
| // We have to transform the plan top-down here to make sure the |
| // "Aggregate with Having clause" case is triggered first. |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDownWithPruning( |
| _.containsPattern(WINDOW_EXPRESSION), ruleId) { |
| |
| case Filter(condition, _) if hasWindowFunction(condition) => |
| throw QueryCompilationErrors.windowFunctionNotAllowedError("WHERE") |
| |
| case UnresolvedHaving(condition, _) if hasWindowFunction(condition) => |
| throw QueryCompilationErrors.windowFunctionNotAllowedError("HAVING") |
| |
| // Aggregate with Having clause. This rule works with an unresolved Aggregate because |
| // a resolved Aggregate will not have Window Functions. |
| case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) |
| if child.resolved && |
| hasWindowFunction(aggregateExprs) && |
| a.expressions.forall(_.resolved) => |
| aggregateExprs.foreach(_.transformDownWithPruning( |
| _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { |
| case lcaRef: LateralColumnAliasReference => |
| throw QueryCompilationErrors.lateralColumnAliasInAggWithWindowAndHavingUnsupportedError( |
| lcaRef.nameParts) |
| }) |
| val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) |
| // Create an Aggregate operator to evaluate aggregation functions. |
| val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) |
| // Add a Filter operator for conditions in the Having clause. |
| val withFilter = Filter(condition, withAggregate) |
| val withWindow = addWindow(windowExpressions, withFilter) |
| |
| // Finally, generate output columns according to the original projectList. |
| val finalProjectList = aggregateExprs.map(_.toAttribute) |
| Project(finalProjectList, withWindow) |
| |
| case p: LogicalPlan if !p.childrenResolved => p |
| |
| // Aggregate without Having clause. |
| // Make sure the lateral column aliases are properly handled first. |
| case a @ Aggregate(groupingExprs, aggregateExprs, child) |
| if hasWindowFunction(aggregateExprs) && |
| a.expressions.forall(_.resolved) && |
| !aggregateExprs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => |
| val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) |
| // Create an Aggregate operator to evaluate aggregation functions. |
| val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) |
| // Add Window operators. |
| val withWindow = addWindow(windowExpressions, withAggregate) |
| |
| // Finally, generate output columns according to the original projectList. |
| val finalProjectList = aggregateExprs.map(_.toAttribute) |
| Project(finalProjectList, withWindow) |
| |
| // We only extract Window Expressions after all expressions of the Project |
| // have been resolved and lateral column aliases have been handled. |
| case p @ Project(projectList, child) |
| if hasWindowFunction(projectList) && |
| p.expressions.forall(_.resolved) && |
| !projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => |
| val (windowExpressions, regularExpressions) = extract(projectList.toIndexedSeq) |
| // We add a project to get all needed expressions for window expressions from the child |
| // of the original Project operator. |
| val withProject = Project(regularExpressions, child) |
| // Add Window operators. |
| val withWindow = addWindow(windowExpressions, withProject) |
| |
| // Finally, generate output columns according to the original projectList. |
| val finalProjectList = projectList.map(_.toAttribute) |
| Project(finalProjectList, withWindow) |
| } |
| } |
| |
| /** |
| * Set the seed for random number generation. |
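| * |
| * For example, an unseeded call such as `rand()` carries `UnresolvedSeed` after parsing; this |
| * rule replaces it with a freshly drawn literal seed, so re-executions of the same analyzed |
| * plan reuse that seed. |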
| */ |
| object ResolveRandomSeed extends Rule[LogicalPlan] { |
| private lazy val random = new Random() |
| |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(EXPRESSION_WITH_RANDOM_SEED), ruleId) { |
| case p if p.resolved => p |
| case p => p.transformExpressionsUpWithPruning( |
| _.containsPattern(EXPRESSION_WITH_RANDOM_SEED), ruleId) { |
| case e: ExpressionWithRandomSeed if e.seedExpression == UnresolvedSeed => |
| e.withNewSeed(random.nextLong()) |
| } |
| } |
| } |
| |
| /** |
| * Correctly handles null primitive inputs for UDFs by adding an extra [[If]] expression to do |
| * the null check. When a user defines a UDF with primitive parameters, there is no way to |
| * tell if the primitive parameter is null or not, so here we assume the primitive input is |
| * null-propagating: we should return null if any such input is null. |
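| * |
| * For example (a sketch), for a UDF `f: Int => Int` applied to a nullable column `a`, the |
| * expression becomes `if (isnull(a)) null else f(knownnotnull(a))`. |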
| */ |
| object HandleNullInputsForUDF extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(SCALA_UDF)) { |
| case p if !p.resolved => p // Skip unresolved nodes. |
| |
| case p => p.transformExpressionsUpWithPruning(_.containsPattern(SCALA_UDF)) { |
| |
| case udf: ScalaUDF if udf.inputPrimitives.contains(true) => |
| // Some inputs are primitive types that can't accept null, so add special null |
| // handling: the result of such an operation, when passed null, is generally null. |
| assert(udf.inputPrimitives.length == udf.children.length) |
| |
| val inputPrimitivesPair = udf.inputPrimitives.zip(udf.children) |
| val inputNullCheck = inputPrimitivesPair.collect { |
| case (isPrimitive, input) if isPrimitive && input.nullable => |
| IsNull(input) |
| }.reduceLeftOption[Expression](Or) |
| |
| if (inputNullCheck.isDefined) { |
| // Once we add an `If` check above the udf, it is safe to mark those checked inputs |
| // as null-safe (i.e., wrap with `KnownNotNull`), because the null-returning |
| // branch of `If` will be called if any of these checked inputs is null. Thus we can |
| // prevent this rule from being applied repeatedly. |
| val newInputs = inputPrimitivesPair.map { |
| case (isPrimitive, input) => |
| if (isPrimitive && input.nullable) { |
| KnownNotNull(input) |
| } else { |
| input |
| } |
| } |
| val newUDF = udf.copy(children = newInputs) |
| If(inputNullCheck.get, Literal.create(null, udf.dataType), newUDF) |
| } else { |
| udf |
| } |
| } |
| } |
| } |
| |
| /** |
| * Resolves the encoders for a UDF by explicitly giving the attributes. We give the |
| * attributes explicitly in order to handle the case where the data type of the input |
| * value is not the same as the internal schema of the encoder, which could cause |
| * data loss. For example, the encoder should not cast the input value to Decimal(38, 18) |
| * if the actual data type is Decimal(30, 0). |
| * |
| * The resolved encoders will then be used to deserialize the internal row to a Scala value. |
| */ |
| object ResolveEncodersInUDF extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(SCALA_UDF), ruleId) { |
| case p if !p.resolved => p // Skip unresolved nodes. |
| |
| case p => p.transformExpressionsUpWithPruning(_.containsPattern(SCALA_UDF), ruleId) { |
| |
| case udf: ScalaUDF if udf.inputEncoders.nonEmpty => |
| val boundEncoders = udf.inputEncoders.zipWithIndex.map { case (encOpt, i) => |
| val dataType = udf.children(i).dataType |
| encOpt.map { enc => |
| val attrs = if (enc.isSerializedAsStructForTopLevel) { |
| // Value class that has been replaced with its underlying type |
| if (enc.schema.fields.length == 1 && enc.schema.fields.head.dataType == dataType) { |
| DataTypeUtils.toAttributes(enc.schema) |
| } else { |
| DataTypeUtils.toAttributes(dataType.asInstanceOf[StructType]) |
| } |
| } else { |
| // the field name doesn't matter here, so we use |
| // a simple literal to avoid any overhead |
| DataTypeUtils.toAttribute(StructField("input", dataType)) :: Nil |
| } |
| enc.resolveAndBind(attrs) |
| } |
| } |
| udf.copy(inputEncoders = boundEncoders) |
| } |
| } |
| } |
| |
| /** |
| * Check and add proper window frames for all window functions. |
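| * |
| * For example, `sum(a) OVER (ORDER BY b)` gets the default frame |
| * `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`, while without an ORDER BY the default |
| * frame is `ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING`. |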
| */ |
| object ResolveWindowFrame extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning( |
| _.containsPattern(WINDOW_EXPRESSION), ruleId) { |
| case WindowExpression(wf: FrameLessOffsetWindowFunction, |
| WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) if wf.frame != f => |
| throw QueryCompilationErrors.cannotSpecifyWindowFrameError(wf.prettyName) |
| case WindowExpression(wf: WindowFunction, WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) |
| if wf.frame != UnspecifiedFrame && wf.frame != f => |
| throw QueryCompilationErrors.windowFrameNotMatchRequiredFrameError(f, wf.frame) |
| case WindowExpression(wf: WindowFunction, s @ WindowSpecDefinition(_, _, UnspecifiedFrame)) |
| if wf.frame != UnspecifiedFrame => |
| WindowExpression(wf, s.copy(frameSpecification = wf.frame)) |
| case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) |
| if e.resolved => |
| val frame = if (o.nonEmpty) { |
| SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) |
| } else { |
| SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing) |
| } |
| we.copy(windowSpec = s.copy(frameSpecification = frame)) |
| } |
| } |
| |
| /** |
| * Check and add order to [[AggregateWindowFunction]]s. |
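| * |
| * For example, `rank() OVER (PARTITION BY a)` is rejected here because window functions |
| * require the window to be ordered, while for `rank() OVER (ORDER BY b)` the order `[b]` is |
| * propagated into the [[RankLike]] function. |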
| */ |
| object ResolveWindowOrder extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning( |
| _.containsPattern(WINDOW_EXPRESSION), ruleId) { |
| case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => |
| throw QueryCompilationErrors.windowFunctionWithWindowFrameNotOrderedError(wf) |
| case WindowExpression(rank: RankLike, spec) if spec.resolved => |
| val order = spec.orderSpec.map(_.child) |
| WindowExpression(rank.withOrder(order), spec) |
| } |
| } |
| |
| /** |
| * Removes natural and USING joins by calculating the output columns based on the outputs of |
| * the two sides, then applying a Project on top of a normal Join to eliminate the natural |
| * or USING join. |
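| * |
| * For example (a sketch), `SELECT * FROM l JOIN r USING (k)` becomes a Project that outputs |
| * `k` once plus the remaining columns of both sides, on top of an inner Join with the |
| * condition `l.k = r.k`; the duplicate key column from the right side is hidden. |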
| */ |
| object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(NATURAL_LIKE_JOIN), ruleId) { |
| case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint) |
| if left.resolved && right.resolved && j.duplicateResolved => |
| val project = commonNaturalJoinProcessing( |
| left, right, joinType, usingCols, None, hint) |
| j.getTagValue(LogicalPlan.PLAN_ID_TAG) |
| .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _)) |
| project |
| case j @ Join(left, right, NaturalJoin(joinType), condition, hint) |
| if j.resolvedExceptNatural => |
| // find common column names from both sides |
| val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) |
| val project = commonNaturalJoinProcessing( |
| left, right, joinType, joinNames, condition, hint) |
| j.getTagValue(LogicalPlan.PLAN_ID_TAG) |
| .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _)) |
| project |
| } |
| } |
| |
| /** |
| * Resolves columns of an output table from the data in a logical plan. This rule will: |
| * |
| * - Reorder columns when the write is by name |
| * - Insert casts when data types do not match |
| * - Insert aliases when column names do not match |
| * - Detect plans that are not compatible with the output table and throw AnalysisException |
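| * |
| * For example (an illustrative sketch), writing a query column `x: INT` to a table column |
| * `x: BIGINT` inserts an implicit cast, roughly `Project [cast(x AS BIGINT) AS x]` over the |
| * query, subject to the configured store assignment policy. |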
| */ |
| object ResolveOutputRelation extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( |
| _.containsPattern(COMMAND), ruleId) { |
| case v2Write: V2WriteCommand |
| if v2Write.table.resolved && v2Write.query.resolved && !v2Write.outputResolved => |
| validateStoreAssignmentPolicy() |
| TableOutputResolver.suitableForByNameCheck(v2Write.isByName, |
| expected = v2Write.table.output, queryOutput = v2Write.query.output) |
| val projection = TableOutputResolver.resolveOutputColumns( |
| v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf) |
| if (projection != v2Write.query) { |
| val cleanedTable = v2Write.table match { |
| case r: DataSourceV2Relation => |
| r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) |
| case other => other |
| } |
| v2Write.withNewQuery(projection).withNewTable(cleanedTable) |
| } else { |
| v2Write |
| } |
| } |
| } |
| |
| private def validateStoreAssignmentPolicy(): Unit = { |
| // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2. |
| if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { |
| throw QueryCompilationErrors.legacyStoreAssignmentPolicyError() |
| } |
| } |
| |
| private def commonNaturalJoinProcessing( |
| left: LogicalPlan, |
| right: LogicalPlan, |
| joinType: JoinType, |
| joinNames: Seq[String], |
| condition: Option[Expression], |
| hint: JoinHint): LogicalPlan = { |
| import org.apache.spark.sql.catalyst.util._ |
| |
| val leftKeys = joinNames.map { keyName => |
| left.output.find(attr => resolver(attr.name, keyName)).getOrElse { |
| throw QueryCompilationErrors.unresolvedUsingColForJoinError( |
| keyName, left.schema.fieldNames.sorted.map(toSQLId).mkString(", "), "left") |
| } |
| } |
| val rightKeys = joinNames.map { keyName => |
| right.output.find(attr => resolver(attr.name, keyName)).getOrElse { |
| throw QueryCompilationErrors.unresolvedUsingColForJoinError( |
| keyName, right.schema.fieldNames.sorted.map(toSQLId).mkString(", "), "right") |
| } |
| } |
| val joinPairs = leftKeys.zip(rightKeys) |
| |
| val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And) |
| |
| // columns not in joinPairs |
| val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) |
| val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) |
| |
| // the output list looks like: join keys, columns from left, columns from right |
| val (projectList, hiddenList) = joinType match { |
| case LeftOuter => |
| (leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)), |
| rightKeys.map(_.withNullability(true))) |
| case LeftExistence(_) => |
| (leftKeys ++ lUniqueOutput, Seq.empty) |
| case RightOuter => |
| (rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput, |
| leftKeys.map(_.withNullability(true))) |
| case FullOuter => |
| // In full outer join, we should return non-null values for the join columns |
| // if either side has non-null values for those columns. Therefore, for each |
| // join column pair, add a coalesce to return the non-null value, if it exists. |
| val joinedCols = joinPairs.map { case (l, r) => |
| // Since this is a full outer join, either side could be null, so we explicitly |
| // set the nullability to true for both sides. |
| Alias(Coalesce(Seq(l.withNullability(true), r.withNullability(true))), l.name)() |
| } |
| (joinedCols ++ |
| lUniqueOutput.map(_.withNullability(true)) ++ |
| rUniqueOutput.map(_.withNullability(true)), |
| leftKeys.map(_.withNullability(true)) ++ |
| rightKeys.map(_.withNullability(true))) |
| case _ : InnerLike => |
| (leftKeys ++ lUniqueOutput ++ rUniqueOutput, rightKeys) |
| case _ => |
| throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType) |
| } |
| |
| // use Project to hide duplicated common keys |
| // propagate hidden columns from nested USING/NATURAL JOINs |
| val project = Project(projectList, Join(left, right, joinType, newCondition, hint)) |
| project.setTagValue( |
| Project.hiddenOutputTag, |
| hiddenList.map(_.markAsQualifiedAccessOnly()) ++ |
| project.child.metadataOutput.filter(_.qualifiedAccessOnly)) |
| project |
| } |
| |
| /** |
| * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved |
| * to the given input attributes. |
| */ |
| object ResolveDeserializer extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(UNRESOLVED_DESERIALIZER), ruleId) { |
| case p if !p.childrenResolved => p |
| case p if p.resolved => p |
| |
| case p => p.transformExpressionsWithPruning( |
| _.containsPattern(UNRESOLVED_DESERIALIZER), ruleId) { |
| case UnresolvedDeserializer(deserializer, inputAttributes) => |
| val inputs = if (inputAttributes.isEmpty) { |
| p.children.flatMap(_.output) |
| } else { |
| inputAttributes |
| } |
| |
| validateTopLevelTupleFields(deserializer, inputs) |
| val resolved = resolveExpressionByPlanOutput( |
| deserializer, LocalRelation(inputs), throws = true) |
| val result = resolved transformDown { |
| case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => |
| inputData.dataType match { |
| case ArrayType(et, cn) => |
| MapObjects(func, inputData, et, cn, cls) transformUp { |
| case UnresolvedExtractValue(child, fieldName) if child.resolved => |
| ExtractValue(child, fieldName, resolver) |
| } |
| case other => |
| throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, |
| "array") |
| } |
| case u: UnresolvedCatalystToExternalMap if u.child.resolved => |
| u.child.dataType match { |
| case _: MapType => |
| CatalystToExternalMap(u) transformUp { |
| case UnresolvedExtractValue(child, fieldName) if child.resolved => |
| ExtractValue(child, fieldName, resolver) |
| } |
| case other => |
| throw QueryCompilationErrors.dataTypeMismatchForDeserializerError(other, "map") |
| } |
| } |
| validateNestedTupleFields(result) |
| result |
| } |
| } |
| |
| private def fail(schema: StructType, maxOrdinal: Int): Unit = { |
| throw QueryCompilationErrors.fieldNumberMismatchForDeserializerError(schema, maxOrdinal) |
| } |
| |
| /** |
| * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column |
| * by position. However, the actual number of columns may be different from the number of Tuple |
| * fields. This method is used to check the number of columns and fields, and throw an |
| * exception if they do not match. |
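| * |
| * For example (a sketch), deserializing a two-column input as a three-field tuple, e.g. |
| * `ds.as[(Int, Int, Int)]` over two columns, fails here with a field-number mismatch error. |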
| */ |
| private def validateTopLevelTupleFields( |
| deserializer: Expression, inputs: Seq[Attribute]): Unit = { |
| val ordinals = deserializer.collect { |
| case GetColumnByOrdinal(ordinal, _) => ordinal |
| }.distinct.sorted |
| |
| if (ordinals.nonEmpty && ordinals != inputs.indices) { |
| fail(inputs.toStructType, ordinals.last) |
| } |
| } |
| |
| /** |
| * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field |
| * by position. However, the actual number of struct fields may be different from the number |
| * of nested Tuple fields. This method is used to check the number of struct fields and nested |
| * Tuple fields, and throw an exception if they do not match. |
| */ |
| private def validateNestedTupleFields(deserializer: Expression): Unit = { |
| val structChildToOrdinals = deserializer |
| // There are 2 kinds of `GetStructField`: |
| // 1. resolved from `UnresolvedExtractValue`, and it will have a `name` property. |
| // 2. created when we build deserializer expression for nested tuple, no `name` property. |
| // Here we want to validate the ordinals of nested tuple, so we should only catch |
| // `GetStructField` without the name property. |
| .collect { case g: GetStructField if g.name.isEmpty => g } |
| .groupBy(_.child) |
| .transform((_, v) => v.map(_.ordinal).distinct.sorted) |
| |
| structChildToOrdinals.foreach { case (expr, ordinals) => |
| val schema = expr.dataType.asInstanceOf[StructType] |
| if (ordinals != schema.indices) { |
| fail(schema, ordinals.last) |
| } |
| } |
| } |
| } |
| |
| /** |
| * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being |
| * constructed is an inner class. |
| */ |
| object ResolveNewInstance extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(NEW_INSTANCE), ruleId) { |
| case p if !p.childrenResolved => p |
| case p if p.resolved => p |
| |
| case p => p.transformExpressionsUpWithPruning(_.containsPattern(NEW_INSTANCE), ruleId) { |
| case n: NewInstance if n.childrenResolved && !n.resolved => |
| val outer = OuterScopes.getOuterScope(n.cls) |
| if (outer == null) { |
| throw QueryCompilationErrors.outerScopeFailureForNewInstanceError(n.cls.getName) |
| } |
| n.copy(outerPointer = Some(outer)) |
| } |
| } |
| } |
| |
| /** |
| * Replaces the [[UpCast]] expression with [[Cast]], and throws an exception if the cast may |
| * truncate. |
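| * |
| * For example, upcasting an `int` column to `long` succeeds, while `ds.as[Int]` over a |
| * `string` column fails analysis because the cast from string to int may truncate. |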
| */ |
| object ResolveUpCast extends Rule[LogicalPlan] { |
| private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { |
| val fromStr = from match { |
| case l: LambdaVariable => "array element" |
| case e => e.sql |
| } |
| throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath) |
| } |
| |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| _.containsPattern(UP_CAST), ruleId) { |
| case p if !p.childrenResolved => p |
| case p if p.resolved => p |
| |
| case p => p.transformExpressionsWithPruning(_.containsPattern(UP_CAST), ruleId) { |
| case u @ UpCast(child, _, _) if !child.resolved => u |
| |
| case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => |
| throw SparkException.internalError( |
| s"UpCast only supports DecimalType as AbstractDataType yet, but got: $target") |
| |
| case UpCast(child, target, walkedTypePath) if target == DecimalType |
| && child.dataType.isInstanceOf[DecimalType] => |
| assert(walkedTypePath.nonEmpty, |
| "object DecimalType should only be used inside ExpressionEncoder") |
| |
| // SPARK-31750: if we want to upcast to the general decimal type, and the `child` is |
| // already decimal type, we can remove the `Upcast` and accept any precision/scale. |
| // This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`. |
| child |
| |
| case UpCast(child, target: AtomicType, _) |
| if conf.getConf(SQLConf.LEGACY_LOOSE_UPCAST) && |
| child.dataType == StringType => |
| Cast(child, target.asNullable) |
| |
| case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => |
| fail(child, u.dataType, walkedTypePath) |
| |
| case u @ UpCast(child, _, _) => Cast(child, u.dataType) |
| } |
| } |
| } |
| |
| /** |
| * Rule to resolve, normalize and rewrite field names based on case sensitivity for commands. |
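| * |
| * For example (a sketch), in `ALTER TABLE t ADD COLUMN c INT AFTER A`, the position `AFTER A` |
| * is resolved against the existing field names using the configured resolver, so with |
| * case-insensitive resolution it matches an existing column `a`. |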
| */ |
| object ResolveFieldNameAndPosition extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { |
| case cmd: CreateIndex if cmd.table.resolved && |
| cmd.columns.exists(_._1.isInstanceOf[UnresolvedFieldName]) => |
| val table = cmd.table.asInstanceOf[ResolvedTable] |
| cmd.copy(columns = cmd.columns.map { |
| case (u: UnresolvedFieldName, prop) => resolveFieldNames(table, u.name, u) -> prop |
| case other => other |
| }) |
| |
| case a: DropColumns if a.table.resolved && hasUnresolvedFieldName(a) && a.ifExists => |
| // For DropColumns with an IF EXISTS clause, resolve the columns and ignore missing column errors. |
| val table = a.table.asInstanceOf[ResolvedTable] |
| val columnsToDrop = a.columnsToDrop |
| a.copy(columnsToDrop = columnsToDrop.flatMap(c => resolveFieldNamesOpt(table, c.name, c))) |
| |
| case a: AlterTableCommand if a.table.resolved && hasUnresolvedFieldName(a) => |
| val table = a.table.asInstanceOf[ResolvedTable] |
| a.transformExpressions { |
| case u: UnresolvedFieldName => resolveFieldNames(table, u.name, u) |
| } |
| |
| case a @ AddColumns(r: ResolvedTable, cols) if !a.resolved => |
| // 'colsToAdd' keeps track of new columns being added. It stores a mapping from a |
| // normalized parent name of fields to field names that belong to the parent. |
| // For example, if we add columns "a.b.c", "a.b.d", and "a.c", 'colsToAdd' will become |
| // Map(Seq("a", "b") -> Seq("c", "d"), Seq("a") -> Seq("c")). |
| val colsToAdd = mutable.Map.empty[Seq[String], Seq[String]] |
| def resolvePosition( |
| col: QualifiedColType, |
| parentSchema: StructType, |
| resolvedParentName: Seq[String]): Option[FieldPosition] = { |
| val fieldsAdded = colsToAdd.getOrElse(resolvedParentName, Nil) |
| val resolvedPosition = col.position.map { |
| case u: UnresolvedFieldPosition => u.position match { |
| case after: After => |
| val allFields = parentSchema.fieldNames ++ fieldsAdded |
| allFields.find(n => conf.resolver(n, after.column())) match { |
| case Some(colName) => |
| ResolvedFieldPosition(ColumnPosition.after(colName)) |
| case None => |
| throw QueryCompilationErrors.referenceColNotFoundForAlterTableChangesError( |
| col.colName, allFields) |
| } |
| case _ => ResolvedFieldPosition(u.position) |
| } |
| case resolved => resolved |
| } |
| colsToAdd(resolvedParentName) = fieldsAdded :+ col.colName |
| resolvedPosition |
| } |
| val schema = r.table.columns.asSchema |
| val resolvedCols = cols.map { col => |
| col.path match { |
| case Some(parent: UnresolvedFieldName) => |
| // Adding a nested field, need to resolve the parent column and position. |
| val resolvedParent = resolveFieldNames(r, parent.name, parent) |
| val parentSchema = resolvedParent.field.dataType match { |
| case s: StructType => s |
| case _ => throw QueryCompilationErrors.invalidFieldName( |
| col.name, parent.name, parent.origin) |
| } |
| val resolvedPosition = resolvePosition(col, parentSchema, resolvedParent.name) |
| col.copy(path = Some(resolvedParent), position = resolvedPosition) |
| case _ => |
| // Adding to the root. Just need to resolve position. |
| val resolvedPosition = resolvePosition(col, schema, Nil) |
| col.copy(position = resolvedPosition) |
| } |
| } |
| val resolved = a.copy(columnsToAdd = resolvedCols) |
| resolved.copyTagsFrom(a) |
| resolved |
| |
| case a @ AlterColumn( |
| table: ResolvedTable, ResolvedFieldName(path, field), dataType, _, _, position, _) => |
| val newDataType = dataType.flatMap { dt => |
| // Hive style syntax provides the column type, even if it may not have changed. |
| val existing = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType) |
| if (existing == dt) None else Some(dt) |
| } |
| val newPosition = position map { |
| case u @ UnresolvedFieldPosition(after: After) => |
| // TODO: since the field name is already resolved, it's more efficient if |
| // `ResolvedFieldName` carries the parent struct and we resolve column position |
| // based on the parent struct, instead of re-resolving the entire column path. |
| val resolved = resolveFieldNames(table, path :+ after.column(), u) |
| ResolvedFieldPosition(ColumnPosition.after(resolved.field.name)) |
| case u: UnresolvedFieldPosition => ResolvedFieldPosition(u.position) |
| case other => other |
| } |
| val resolved = a.copy(dataType = newDataType, position = newPosition) |
| resolved.copyTagsFrom(a) |
| resolved |
| } |
| |
| /** |
| * Returns the resolved field name if the field can be resolved, and throws a missing-field |
| * error otherwise. `resolveFieldNamesOpt` below instead returns None if the column is not |
| * found. |
| */ |
| private def resolveFieldNames( |
| table: ResolvedTable, |
| fieldName: Seq[String], |
| context: Expression): ResolvedFieldName = { |
| resolveFieldNamesOpt(table, fieldName, context) |
| .getOrElse(throw QueryCompilationErrors.missingFieldError(fieldName, table, context.origin)) |
| } |
| |
| private def resolveFieldNamesOpt( |
| table: ResolvedTable, |
| fieldName: Seq[String], |
| context: Expression): Option[ResolvedFieldName] = { |
| table.schema.findNestedField( |
| fieldName, includeCollections = true, conf.resolver, context.origin |
| ).map { |
| case (path, field) => ResolvedFieldName(path, field) |
| } |
| } |
| |
| private def hasUnresolvedFieldName(a: AlterTableCommand): Boolean = { |
| a.expressions.exists(_.exists(_.isInstanceOf[UnresolvedFieldName])) |
| } |
| } |
| |
| /** |
| * A rule to handle special commands that need to be notified when analysis is done. This rule |
| * should run after all other analysis rules are run. |
| */ |
| object HandleSpecialCommand extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( |
| _.containsPattern(COMMAND)) { |
| case c: AnalysisOnlyCommand if c.resolved => |
| checkAnalysis(c) |
| c.markAsAnalyzed(AnalysisContext.get) |
| case c: KeepAnalyzedQuery if c.resolved => |
| c.storeAnalyzedQuery() |
| } |
| } |
| } |
| |
| /** |
| * Removes [[SubqueryAlias]] operators from the plan. Subqueries are only required to provide |
| * scoping information for attributes and can be removed once analysis is complete. |
| */ |
| object EliminateSubqueryAliases extends Rule[LogicalPlan] { |
| // This rule is also called at the beginning of the optimization phase, and as a result |
| // uses transformUp rather than resolveOperators. |
| def apply(plan: LogicalPlan): LogicalPlan = AnalysisHelper.allowInvokingTransformsInAnalyzer { |
| plan.transformUpWithPruning(AlwaysProcess.fn, ruleId) { |
| case SubqueryAlias(_, child) => child |
| } |
| } |
| } |
| |
| /** |
| * Removes a [[Union]] operator from the plan if it has just one child. |
| */ |
| object EliminateUnions extends Rule[LogicalPlan] { |
| def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( |
| _.containsPattern(UNION), ruleId) { |
| case u: Union if u.children.size == 1 => u.children.head |
| } |
| } |
| |
| /** |
| * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level |
| * expression in Project(project list) or Aggregate(aggregate expressions) or |
| * Window(window expressions). Notice that if an expression has other expression parameters which |
| * are not in its `children`, e.g. `RuntimeReplaceable`, the transformation for Aliases in this |
| * rule can't work for those parameters. |
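| * |
| * For example (a sketch), the top-level alias in `Project [(a + 1) AS x]` is kept, while an |
| * alias nested inside another operator's expression, e.g. `Filter ((a AS y) > 1)`, is |
| * useless and is trimmed to `Filter (a > 1)`. |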
| */ |
| object CleanupAliases extends Rule[LogicalPlan] with AliasHelper { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( |
| // trimNonTopLevelAliases can transform Alias and MultiAlias. |
| _.containsAnyPattern(ALIAS, MULTI_ALIAS)) { |
| case Project(projectList, child) => |
| val cleanedProjectList = projectList.map(trimNonTopLevelAliases) |
| Project(cleanedProjectList, child) |
| |
| case Aggregate(grouping, aggs, child) => |
| val cleanedAggs = aggs.map(trimNonTopLevelAliases) |
| Aggregate(grouping.map(trimAliases), cleanedAggs, child) |
| |
| case Window(windowExprs, partitionSpec, orderSpec, child) => |
| val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases) |
| Window(cleanedWindowExprs, partitionSpec.map(trimAliases), |
| orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) |
| |
| case CollectMetrics(name, metrics, child, dataframeId) => |
| val cleanedMetrics = metrics.map(trimNonTopLevelAliases) |
| CollectMetrics(name, cleanedMetrics, child, dataframeId) |
| |
| case Unpivot(ids, values, aliases, variableColumnName, valueColumnNames, child) => |
| val cleanedIds = ids.map(_.map(trimNonTopLevelAliases)) |
| val cleanedValues = values.map(_.map(_.map(trimNonTopLevelAliases))) |
| Unpivot( |
| cleanedIds, |
| cleanedValues, |
| aliases, |
| variableColumnName, |
| valueColumnNames, |
| child) |
| |
| // Operators that operate on objects should only have expressions from encoders, which should |
| // never have extra aliases. |
| case o: ObjectConsumer => o |
| case o: ObjectProducer => o |
| case a: AppendColumns => a |
| |
| case other => |
| other.transformExpressionsDownWithPruning(_.containsAnyPattern(ALIAS, MULTI_ALIAS)) { |
| case Alias(child, _) => child |
| } |
| } |
| } |
| |
| /** |
 * Removes event time watermark nodes from batch queries, since watermarks are only supported
 * in Structured Streaming.
| * TODO: add this rule into analyzer rule list. |
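 *
 * For example (a sketch; `df` stands for a hypothetical batch DataFrame):
 * {{{
 *   // Inserts an EventTimeWatermark node that this rule then removes, so the call is
 *   // effectively a no-op for batch queries.
 *   df.withWatermark("eventTime", "10 seconds")
 * }}}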
| */ |
| object EliminateEventTimeWatermark extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( |
| _.containsPattern(EVENT_TIME_WATERMARK)) { |
| case EventTimeWatermark(_, _, child) if child.resolved && !child.isStreaming => child |
| case UpdateEventTimeWatermarkColumn(_, _, child) if child.resolved && !child.isStreaming => |
| child |
| } |
| } |
| |
| /** |
 * Resolves expressions that contain [[NamePlaceholder]]s.
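 *
 * For illustration, a minimal sketch (the column name is hypothetical): `CreateStruct` puts a
 * [[NamePlaceholder]] in front of each not-yet-resolved named child, and this rule fills in
 * the name once the child resolves.
 * {{{
 *   import org.apache.spark.sql.catalyst.dsl.expressions._
 *   import org.apache.spark.sql.catalyst.expressions.CreateStruct
 *
 *   // CreateNamedStruct(Seq(NamePlaceholder, $"a")); once $"a" resolves to an attribute
 *   // named "a", this rule rewrites NamePlaceholder to Literal("a").
 *   val structExpr = CreateStruct(Seq($"a"))
 * }}}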
| */ |
| object ResolveExpressionsWithNamePlaceholders extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning( |
| _.containsAnyPattern(ARRAYS_ZIP, CREATE_NAMED_STRUCT), ruleId) { |
| case e: ArraysZip if !e.resolved => |
| val names = e.children.zip(e.names).map { |
| case (e: NamedExpression, NamePlaceholder) if e.resolved => |
| Literal(e.name) |
| case (_, other) => other |
| } |
| ArraysZip(e.children, names) |
| |
| case e: CreateNamedStruct if !e.resolved => |
| val children = e.children.grouped(2).flatMap { |
| case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => |
| Seq(Literal(e.name), e) |
| case kv => |
| kv |
| } |
| CreateNamedStruct(children.toList) |
| } |
| } |
| |
| /** |
 * Aggregate expressions in a subquery that reference the outer query block are pushed down
 * to the outer query block for evaluation. This rule updates such outer references to point
 * at the AttributeReferences of the pushed-down aggregate expressions in the parent/outer
 * query block.
| * |
| * For example (SQL): |
| * {{{ |
| * SELECT l.a FROM l GROUP BY 1 HAVING EXISTS (SELECT 1 FROM r WHERE r.d < min(l.b)) |
| * }}} |
 * Plan before the rule:
 * {{{
 *   Project [a#226]
 *   +- Filter exists#245 [min(b#227)#249]
 *      :  +- Project [1 AS 1#247]
 *      :     +- Filter (d#238 < min(outer(b#227)))       <-----
 *      :        +- SubqueryAlias r
 *      :           +- Project [_1#234 AS c#237, _2#235 AS d#238]
 *      :              +- LocalRelation [_1#234, _2#235]
 *      +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
 *         +- SubqueryAlias l
 *            +- Project [_1#223 AS a#226, _2#224 AS b#227]
 *               +- LocalRelation [_1#223, _2#224]
 * }}}
 * Plan after the rule:
 * {{{
 *   Project [a#226]
 *   +- Filter exists#245 [min(b#227)#249]
 *      :  +- Project [1 AS 1#247]
 *      :     +- Filter (d#238 < outer(min(b#227)#249))   <-----
 *      :        +- SubqueryAlias r
 *      :           +- Project [_1#234 AS c#237, _2#235 AS d#238]
 *      :              +- LocalRelation [_1#234, _2#235]
 *      +- Aggregate [a#226], [a#226, min(b#227) AS min(b#227)#249]
 *         +- SubqueryAlias l
 *            +- Project [_1#223 AS a#226, _2#224 AS b#227]
 *               +- LocalRelation [_1#223, _2#224]
 * }}}
| */ |
| object UpdateOuterReferences extends Rule[LogicalPlan] { |
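  // `stripAlias` is only applied to the Aliases collected from the Aggregate output below,
  // hence the intentionally partial match.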
| private def stripAlias(expr: Expression): Expression = expr match { case a: Alias => a.child } |
| |
| private def updateOuterReferenceInSubquery( |
| plan: LogicalPlan, |
| refExprs: Seq[Expression]): LogicalPlan = { |
| plan resolveExpressions { case e => |
| val outerAlias = |
| refExprs.find(stripAlias(_).semanticEquals(stripOuterReference(e))) |
| outerAlias match { |
| case Some(a: Alias) => OuterReference(a.toAttribute) |
| case _ => e |
| } |
| } |
| } |
| |
| def apply(plan: LogicalPlan): LogicalPlan = { |
| plan.resolveOperatorsWithPruning( |
| _.containsAllPatterns(PLAN_EXPRESSION, FILTER, AGGREGATE), ruleId) { |
| case f @ Filter(_, a: Aggregate) if f.resolved => |
| f.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { |
| case s: SubqueryExpression if s.children.nonEmpty => |
| // Collect the aliases from output of aggregate. |
| val outerAliases = a.aggregateExpressions collect { case a: Alias => a } |
| // Update the subquery plan to record the OuterReference to point to outer query plan. |
| s.withNewPlan(updateOuterReferenceInSubquery(s.plan, outerAliases)) |
| } |
| } |
| } |
| } |
| |
| /** |
 * The rule `ResolveReferences` in the main resolution batch creates [[TempResolvedColumn]]s in
 * UnresolvedHaving/Filter/Sort to hold columns that are temporarily resolved against
 * `agg.child`.
| * |
| * If the expression hosting [[TempResolvedColumn]] is fully resolved, the rule |
| * `ResolveAggregationFunctions` will |
| * - Replace [[TempResolvedColumn]] with [[AttributeReference]] if it's inside aggregate functions |
| * or grouping expressions. |
| * - Mark [[TempResolvedColumn]] as `hasTried` if not inside aggregate functions or grouping |
| * expressions, hoping other rules can re-resolve it. |
 * `ResolveReferences` will re-resolve [[TempResolvedColumn]] if `hasTried` is true, and keep it
 * unchanged if the resolution fails. We should turn it back into [[UnresolvedAttribute]] so that
 * the analyzer can report a missing-column error later.
 *
 * If the expression hosting [[TempResolvedColumn]] is not resolved, [[TempResolvedColumn]] will
 * remain with `hasTried` as false. We should strip [[TempResolvedColumn]] so that users can see
 * the real reason why the expression failed to resolve, e.g. a type mismatch.
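 *
 * For example (SQL; the table and column names are hypothetical):
 * {{{
 *   SELECT a FROM t GROUP BY a HAVING b > 0
 * }}}
 * Here `b` is temporarily resolved against `agg.child` as a [[TempResolvedColumn]]. It is
 * neither a grouping expression nor inside an aggregate function, re-resolution fails, and
 * this rule turns it back into an [[UnresolvedAttribute]] so the reported error points at `b`.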
| */ |
| object RemoveTempResolvedColumn extends Rule[LogicalPlan] { |
| override def apply(plan: LogicalPlan): LogicalPlan = { |
| plan.resolveExpressionsWithPruning(_.containsPattern(TEMP_RESOLVED_COLUMN)) { |
| case t: TempResolvedColumn => |
| if (t.hasTried) { |
| UnresolvedAttribute(t.nameParts) |
| } else { |
| t.child |
| } |
| } |
| } |
| } |