blob: 749bf4b0cefa44bd445ebbf2befb67819483f3e8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.ignite
import org.apache.ignite.IgniteException
import org.apache.ignite.spark.impl.optimization._
import org.apache.ignite.spark.impl.optimization.accumulator._
import org.apache.ignite.spark.impl.{IgniteSQLAccumulatorRelation, IgniteSQLRelation, sqlCacheName}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.LogicalRelation
/**
 * Query plan optimization for Ignite-based queries.
 *
 * The rule walks a Spark logical plan bottom-up and folds chains of operators whose expressions Ignite
 * accepts (project, filter, aggregate, sort, limit, union, join) into `QueryAccumulator` nodes. Each
 * accumulator is finally replaced by a `LogicalRelation` backed by an `IgniteSQLAccumulatorRelation`.
 */
object IgniteOptimization extends Rule[LogicalPlan] with Logging {
    /** @inheritdoc */
    override def apply(plan: LogicalPlan): LogicalPlan = {
        logDebug("")
        logDebug("== Plan Before Ignite Operator Push Down ==")
        logDebug(plan.toString())

        val transformed = fixAmbiguousOutput(pushDownOperators(plan))

        logDebug("")
        logDebug("== Plan After Ignite Operator Push Down ==")
        logDebug(transformed.toString())

        makeIgniteAccRelation(transformed)
    }

    /**
     * Change query plan by accumulating query parts supported by Ignite into `QueryAccumulator`.
     *
     * Nodes are visited bottom-up (`transformUp`). A `LogicalRelation` over an `IgniteSQLRelation` starts a new
     * accumulator. Each supported parent node is then merged into the accumulator produced by its child(ren).
     * An unsupported node sets `stepSkipped`, which disables push down for the nodes above it until the next
     * Ignite relation is reached.
     *
     * @param plan Query plan.
     * @return Transformed plan.
     * @throws IgniteException If the table of an Ignite relation is unknown, or if a node that passed its guard
     *                         has a child of an unexpected type.
     */
    private def pushDownOperators(plan: LogicalPlan): LogicalPlan = {
        val aliasIndexIterator = Iterator.from(1)

        // Flag to indicate that some step was skipped due to an unsupported node or expression.
        // While it is set, the transformation of higher-level nodes is skipped.
        // The flag is shared by the whole traversal and is cleared by every Ignite relation found.
        var stepSkipped = true

        // Applying optimization rules from bottom to top tree nodes.
        plan.transformUp {
            // We found a basic node to transform: create a new accumulator and go to the upper layers.
            case LogicalRelation(igniteSqlRelation: IgniteSQLRelation[_, _], output, _catalogTable, _) ⇒
                // Clear flag to optimize each statement separately.
                stepSkipped = false

                val igniteQueryContext = IgniteQueryContext(
                    igniteContext = igniteSqlRelation.ic,
                    sqlContext = igniteSqlRelation.sqlContext,
                    catalogTable = _catalogTable,
                    aliasIndex = aliasIndexIterator,
                    cacheName =
                        sqlCacheName(igniteSqlRelation.ic.ignite(), igniteSqlRelation.tableName,
                            igniteSqlRelation.schemaName)
                            .getOrElse(throw new IgniteException("Unknown table")))

                // LogicalRelation is the bottom-most node of the plan. Replace it with an accumulator into
                // which supported operators of the higher tree levels are pushed.
                SingleTableSQLAccumulator(
                    igniteQueryContext = igniteQueryContext,
                    table = Some(igniteSqlRelation.tableName),
                    tableExpression = None,
                    outputExpressions = output.map(attr ⇒ attr.withQualifier(Seq(igniteSqlRelation.tableName))))

            case project: Project if !stepSkipped && exprsAllowed(project.projectList) ⇒
                // Project layer just changes output of current query.
                project.child match {
                    case acc: SelectAccumulator ⇒
                        acc.withOutputExpressions(
                            substituteExpressions(project.projectList, acc.outputExpressions))

                    case other ⇒
                        // NOTE(review): presumably a UnionSQLAccumulator child (not matched as SelectAccumulator)
                        // would also fail here; confirm Spark never places a Project directly above one.
                        throw unexpectedChild(other, "SelectAccumulator")
                }

            case sort: Sort if !stepSkipped && isSortPushDownAllowed(sort.order, sort.global) ⇒
                sort.child match {
                    case acc: QueryAccumulator ⇒
                        acc.withOrderBy(sort.order)

                    case other ⇒
                        throw unexpectedChild(other, "QueryAccumulator")
                }

            case filter: Filter if !stepSkipped && exprsAllowed(filter.condition) ⇒
                filter.child match {
                    case acc: SelectAccumulator ⇒
                        // Conditions over aggregates, or over an already grouped query, are added to HAVING.
                        // All other conditions go to WHERE.
                        if (hasAggregateInside(filter.condition) || acc.groupBy.isDefined)
                            acc.withHaving(acc.having.getOrElse(Nil) :+ filter.condition)
                        else
                            acc.withWhere(acc.where.getOrElse(Nil) :+ filter.condition)

                    case other ⇒
                        throw unexpectedChild(other, "SelectAccumulator")
                }

            case agg: Aggregate
                if !stepSkipped && exprsAllowed(agg.groupingExpressions) && exprsAllowed(agg.aggregateExpressions) ⇒
                agg.child match {
                    case acc: SelectAccumulator if acc.groupBy.isEmpty ⇒
                        // Not grouped yet: GROUP BY is added to the same query.
                        acc
                            .withGroupBy(agg.groupingExpressions)
                            .withOutputExpressions(
                                substituteExpressions(agg.aggregateExpressions, acc.outputExpressions))

                    case acc: QueryAccumulator ⇒
                        // The child is already grouped (or is not a select accumulator), so it is aggregated
                        // as a sub-query.
                        wrapAsSubQuery(acc, agg.aggregateExpressions)

                    case other ⇒
                        throw unexpectedChild(other, "QueryAccumulator")
                }

            case limit: LocalLimit if !stepSkipped && exprsAllowed(limit.limitExpr) ⇒
                limit.child match {
                    case acc: QueryAccumulator ⇒
                        acc.withLocalLimit(limit.limitExpr)

                    case other ⇒
                        throw unexpectedChild(other, "QueryAccumulator")
                }

            case limit: GlobalLimit if !stepSkipped && exprsAllowed(limit.limitExpr) ⇒
                // Only the direct child accumulator receives the limit. Transforming the child's whole subtree
                // would also push the outer LIMIT into nested sub-queries and join/union inputs.
                limit.child match {
                    case acc: QueryAccumulator ⇒
                        acc.withLimit(limit.limitExpr)

                    case other ⇒
                        throw unexpectedChild(other, "QueryAccumulator")
                }

            case union: Union if !stepSkipped && isAllChildrenOptimized(union.children) ⇒
                // The guard guarantees that every child is a QueryAccumulator.
                val subQueries = union.children.collect { case acc: QueryAccumulator ⇒ acc }

                UnionSQLAccumulator(
                    subQueries.head.igniteQueryContext,
                    subQueries,
                    subQueries.head.output)

            case join: Join
                if !stepSkipped && isAllChildrenOptimized(Seq(join.left, join.right)) &&
                    join.condition.forall(exprsAllowed) ⇒
                val left = join.left.asInstanceOf[QueryAccumulator]

                // A left side that is not a simple table accumulator is referenced through a fresh alias.
                val leftAlias =
                    if (!isSimpleTableAcc(left))
                        Some(left.igniteQueryContext.uniqueTableAlias)
                    else
                        None

                val right = join.right.asInstanceOf[QueryAccumulator]

                // The right side also gets an alias when its qualifier equals the left side's alias or
                // qualifier (e.g. a self-join).
                val rightAlias =
                    if (!isSimpleTableAcc(right) || leftAlias.getOrElse(left.qualifier) == right.qualifier)
                        Some(right.igniteQueryContext.uniqueTableAlias)
                    else
                        None

                JoinSQLAccumulator(
                    left.igniteQueryContext,
                    left,
                    right,
                    join.joinType,
                    left.output ++ right.output,
                    join.condition,
                    leftAlias,
                    rightAlias)

            case unknown ⇒
                stepSkipped = true

                unknown
        }
    }

    /**
     * Wraps `acc` into a new `SingleTableSQLAccumulator` that selects `outputExpressions` from it.
     *
     * The wrapped accumulator is referenced through a fresh alias taken from its query context.
     *
     * @param acc Accumulator to use as a sub-query.
     * @param outputExpressions Output of the new outer query.
     * @return New accumulator selecting from `acc`.
     */
    private def wrapAsSubQuery(acc: QueryAccumulator,
        outputExpressions: Seq[NamedExpression]): SingleTableSQLAccumulator =
        SingleTableSQLAccumulator(
            igniteQueryContext = acc.igniteQueryContext,
            table = None,
            tableExpression = Some((acc, acc.igniteQueryContext.uniqueTableAlias)),
            outputExpressions = outputExpressions)

    /**
     * Builds the exception thrown when a node passed its push-down guard (so `stepSkipped == false`)
     * but its child is not of the expected accumulator type.
     *
     * @param child Unexpected child node.
     * @param expected Name of the expected accumulator type.
     * @return Exception to throw.
     */
    private def unexpectedChild(child: LogicalPlan, expected: String): IgniteException =
        new IgniteException(s"stepSkipped == false but child is not $expected: ${child.nodeName}")

    /**
     * Gives unique names to ambiguous (duplicated) output columns of accumulators built over joins.
     *
     * The plan is visited top-down. Duplicated names are aliased by `fixAmbiguousOutput(exprs, ctx)`, and
     * references in the parent accumulator are renamed (`changeOnlyName = true`) to match.
     *
     * @param plan Query plan.
     * @return Transformed plan.
     */
    private def fixAmbiguousOutput(plan: LogicalPlan): LogicalPlan = plan.transformDown {
        case acc: SingleTableSQLAccumulator if acc.children.exists(_.isInstanceOf[JoinSQLAccumulator]) ⇒
            val fixedChildOutput =
                fixAmbiguousOutput(acc.tableExpression.get._1.outputExpressions, acc.igniteQueryContext)

            val newOutput = substituteExpressions(acc.outputExpressions, fixedChildOutput, changeOnlyName = true)

            // NOTE(review): the copy built below is discarded, so this case returns `acc` unchanged.
            // Simply returning the copy would presumably be wrong too: the child join's output is never
            // updated with `fixedChildOutput`. Confirm the intended behaviour before changing this.
            acc.copy(
                outputExpressions = newOutput,
                where = acc.where.map(
                    substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                groupBy = acc.groupBy.map(
                    substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                having = acc.having.map(
                    substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                limit = acc.limit.map(
                    substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                localLimit = acc.localLimit.map(
                    substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                orderBy = acc.orderBy.map(
                    substituteExpressions(_, fixedChildOutput, changeOnlyName = true)))

            acc

        case acc: JoinSQLAccumulator
            if acc.left.isInstanceOf[JoinSQLAccumulator] || acc.right.isInstanceOf[JoinSQLAccumulator] ⇒
            // Left side: alias duplicated names in the nested join's output and rename references here.
            val leftFixed = acc.left match {
                case leftJoin: JoinSQLAccumulator ⇒
                    val fixedChildOutput = fixAmbiguousOutput(acc.left.outputExpressions, acc.igniteQueryContext)

                    val newOutput =
                        substituteExpressions(acc.outputExpressions, fixedChildOutput, changeOnlyName = true)

                    acc.copy(
                        outputExpressions = newOutput,
                        left = leftJoin.copy(outputExpressions = fixedChildOutput),
                        condition = acc.condition.map(
                            substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                        where = acc.where.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                        groupBy = acc.groupBy.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                        having = acc.having.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                        limit = acc.limit.map(
                            substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                        localLimit = acc.localLimit.map(
                            substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                        orderBy = acc.orderBy.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)))

                case _ ⇒ acc
            }

            // Right side.
            // NOTE(review): unlike the left branch, this fixes `leftFixed.outputExpressions` (the parent's output)
            // rather than `rightJoin.outputExpressions`. It also substitutes the output without `changeOnlyName`
            // and rebuilds condition/where/... from `acc`, dropping the left-branch substitutions.
            // Confirm this is intended.
            val fixed = leftFixed.right match {
                case rightJoin: JoinSQLAccumulator ⇒
                    val fixedChildOutput =
                        fixAmbiguousOutput(leftFixed.outputExpressions, leftFixed.igniteQueryContext)

                    val newOutput = substituteExpressions(leftFixed.outputExpressions, fixedChildOutput)

                    leftFixed.copy(
                        outputExpressions = newOutput,
                        right = rightJoin.copy(outputExpressions = fixedChildOutput),
                        condition = acc.condition.map(
                            substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                        where = acc.where.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                        groupBy = acc.groupBy.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                        having = acc.having.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)),
                        limit = acc.limit.map(
                            substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                        localLimit = acc.localLimit.map(
                            substituteExpression(_, fixedChildOutput, changeOnlyName = true)),
                        orderBy = acc.orderBy.map(
                            substituteExpressions(_, fixedChildOutput, changeOnlyName = true)))

                case _ ⇒ leftFixed
            }

            // NOTE(review): this final copy overrides the condition/where/groupBy/having/limit/localLimit/orderBy
            // computed above, substituting against the original `acc.outputExpressions` instead.
            fixed.copy(
                condition = acc.condition.map(
                    substituteExpression(_, acc.outputExpressions, changeOnlyName = true)),
                where = acc.where.map(
                    substituteExpressions(_, acc.outputExpressions, changeOnlyName = true)),
                groupBy = acc.groupBy.map(
                    substituteExpressions(_, acc.outputExpressions, changeOnlyName = true)),
                having = acc.having.map(
                    substituteExpressions(_, acc.outputExpressions, changeOnlyName = true)),
                limit = acc.limit.map(
                    substituteExpression(_, acc.outputExpressions, changeOnlyName = true)),
                localLimit = acc.localLimit.map(
                    substituteExpression(_, acc.outputExpressions, changeOnlyName = true)),
                orderBy = acc.orderBy.map(
                    substituteExpressions(_, acc.outputExpressions, changeOnlyName = true)))

        case unknown ⇒
            unknown
    }

    /**
     * Makes output column names unique.
     *
     * The first expression with a given name is kept as is. Every later expression whose name was already seen
     * is wrapped into an `Alias` named by `ctx.uniqueColumnAlias`. The alias keeps the original `exprId`, so
     * references to it can still be matched by id (see `substituteExpression`).
     *
     * @param exprs Output expressions to check.
     * @param ctx Query context used to generate unique column aliases.
     * @return Expressions in the same order as `exprs`, with duplicated names aliased.
     */
    private def fixAmbiguousOutput(exprs: Seq[NamedExpression], ctx: IgniteQueryContext): Seq[NamedExpression] =
        exprs.foldLeft((Set.empty[String], Vector.empty[NamedExpression])) {
            case ((seenNames, fixed), next) if seenNames(next.name) ⇒
                (seenNames, fixed :+ Alias(next, ctx.uniqueColumnAlias(next))(exprId = next.exprId))

            case ((seenNames, fixed), next) ⇒
                (seenNames + next.name, fixed :+ next)
        }._2

    /**
     * Substitutes each `QueryAccumulator` with a `LogicalRelation` that contains an `IgniteSQLAccumulatorRelation`.
     *
     * The plan is visited top-down. Because the replacement is a leaf relation, accumulators nested inside a
     * replaced accumulator are not visited separately.
     *
     * @param plan Query plan.
     * @return Transformed plan.
     */
    private def makeIgniteAccRelation(plan: LogicalPlan): LogicalPlan =
        plan.transformDown {
            case acc: QueryAccumulator ⇒
                new LogicalRelation (
                    relation = IgniteSQLAccumulatorRelation(acc),
                    output = acc.outputExpressions.map(toAttributeReference(_, Seq.empty)),
                    catalogTable = acc.igniteQueryContext.catalogTable,
                    false) // isStreaming (4th LogicalRelation parameter)
        }

    /**
     * @param order Order.
     * @param global True if the order applies to the entire result set, false if the ordering is per-partition.
     * @return True if the sort can be pushed down to Ignite, false otherwise.
     */
    private def isSortPushDownAllowed(order: Seq[SortOrder], global: Boolean): Boolean =
        global && order.map(_.child).forall(exprsAllowed)

    /**
     * @param children Plans to check.
     * @return True if all plans are `QueryAccumulator`, false otherwise.
     */
    private def isAllChildrenOptimized(children: Seq[LogicalPlan]): Boolean =
        children.forall {
            case _: QueryAccumulator ⇒
                true

            case _ ⇒
                false
        }

    /**
     * Changes each expression of the `exprs` collection to the expression with the same `exprId` from
     * `substitution`, recursing into children where no match is found.
     *
     * @param exprs Expressions to substitute.
     * @param substitution Expressions for substitution.
     * @param changeOnlyName If true, substitute only the expression name.
     * @tparam T Concrete expression type.
     * @return Substituted expressions.
     */
    private def substituteExpressions[T <: Expression](exprs: Seq[T], substitution: Seq[NamedExpression],
        changeOnlyName: Boolean = false): Seq[T] = {

        exprs.map(substituteExpression(_, substitution, changeOnlyName))
    }

    /**
     * Substitutes a single expression using `substitution`, matched by `exprId`.
     *
     * If `expr` is a named expression with a match in `substitution`:
     *  - when `changeOnlyName` is false, the match itself is returned;
     *  - otherwise, an `Alias` gets its child replaced by an `AttributeReference` to the match, keeping its own
     *    name and id. An `AttributeReference` takes the match's name, id and qualifier. Any other named
     *    expression is returned unchanged.
     *
     * In every other case, the substitution is applied recursively to the children of `expr`.
     *
     * @param expr Expression to substitute.
     * @param substitution Expressions for substitution.
     * @param changeOnlyName If true, substitute only the expression name.
     * @tparam T Concrete expression type.
     * @return Substituted expression.
     */
    private def substituteExpression[T <: Expression](expr: T, substitution: Seq[NamedExpression],
        changeOnlyName: Boolean): T = expr match {
        case ne: NamedExpression ⇒
            substitution.find(_.exprId == ne.exprId) match {
                case Some(found) ⇒
                    if (!changeOnlyName)
                        found.asInstanceOf[T]
                    else ne match {
                        case alias: Alias ⇒
                            Alias(
                                AttributeReference(
                                    found.name,
                                    found.dataType,
                                    nullable = found.nullable,
                                    metadata = found.metadata)(
                                    exprId = found.exprId,
                                    qualifier = found.qualifier),
                                alias.name) (
                                exprId = alias.exprId,
                                qualifier = alias.qualifier,
                                explicitMetadata = alias.explicitMetadata).asInstanceOf[T]

                        case attr: AttributeReference ⇒
                            attr.copy(name = found.name)(
                                exprId = found.exprId,
                                qualifier = found.qualifier).asInstanceOf[T]

                        case _ ⇒ ne.asInstanceOf[T]
                    }

                case None ⇒
                    expr.withNewChildren(
                        substituteExpressions(expr.children, substitution, changeOnlyName)).asInstanceOf[T]
            }

        case _ ⇒
            expr.withNewChildren(
                substituteExpressions(expr.children, substitution, changeOnlyName)).asInstanceOf[T]
    }
}