blob: a9be67432122f6dfc43cfe793d2aaefd10d09a86 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.spark.impl.optimization.accumulator
import org.apache.ignite.IgniteException
import org.apache.ignite.spark.impl.optimization._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.BinaryNode
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
/**
* Accumulator to store information about join query.
*/
private[apache] case class JoinSQLAccumulator(
igniteQueryContext: IgniteQueryContext,
left: QueryAccumulator,
right: QueryAccumulator,
joinType: JoinType,
outputExpressions: Seq[NamedExpression],
condition: Option[Expression],
leftAlias: Option[String],
rightAlias: Option[String],
distinct: Boolean = false,
where: Option[Seq[Expression]] = None,
groupBy: Option[Seq[Expression]] = None,
having: Option[Seq[Expression]] = None,
limit: Option[Expression] = None,
localLimit: Option[Expression] = None,
orderBy: Option[Seq[SortOrder]] = None
) extends BinaryNode with SelectAccumulator {
/** @inheritdoc */
override def compileQuery(prettyPrint: Boolean = false, nestedQuery: Boolean = false): String = {
val delim = if (prettyPrint) "\n" else " "
val tab = if (prettyPrint) " " else ""
var sql = s"SELECT$delim$tab" +
s"${fixQualifier(outputExpressions).map(exprToString(_, useQualifier = true)).mkString(", ")}$delim" +
s"FROM$delim$tab$compileJoinExpr"
if (allFilters.nonEmpty)
sql += s"${delim}WHERE$delim$tab" +
s"${fixQualifier(allFilters).map(exprToString(_, useQualifier = true)).mkString(s" AND$delim$tab")}"
if (groupBy.exists(_.nonEmpty))
sql += s"${delim}GROUP BY " +
s"${fixQualifier(groupBy.get).map(exprToString(_, useQualifier = true)).mkString(s",$delim$tab")}"
if (having.exists(_.nonEmpty))
sql += s"${delim}HAVING " +
s"${fixQualifier(having.get).map(exprToString(_, useQualifier = true)).mkString(s" AND$delim$tab")}"
if (orderBy.exists(_.nonEmpty))
sql += s"${delim}ORDER BY " +
s"${fixQualifier(orderBy.get).map(exprToString(_, useQualifier = true)).mkString(s",$delim$tab")}"
if (limit.isDefined) {
sql += s" LIMIT ${exprToString(fixQualifier0(limit.get), useQualifier = true)}"
if (nestedQuery)
sql = s"SELECT * FROM ($sql)"
}
sql
}
/**
* @return Filters for this query.
*/
private def allFilters: Seq[Expression] = {
val leftFilters =
if (isSimpleTableAcc(left))
left.asInstanceOf[SingleTableSQLAccumulator].where.getOrElse(Seq.empty)
else
Seq.empty
val rightFilters =
if (isSimpleTableAcc(right))
right.asInstanceOf[SingleTableSQLAccumulator].where.getOrElse(Seq.empty)
else Seq.empty
where.getOrElse(Seq.empty) ++ leftFilters ++ rightFilters
}
/**
* @return `table1 LEFT JOIN (SELECT....FROM...) table2` part of join query.
*/
private def compileJoinExpr: String = {
val leftJoinSql =
if (isSimpleTableAcc(left))
left.asInstanceOf[SingleTableSQLAccumulator].table.get
else
s"(${left.compileQuery()}) ${leftAlias.get}"
val rightJoinSql = {
val leftTableName =
if (isSimpleTableAcc(left))
left.qualifier
else
leftAlias.get
if (isSimpleTableAcc(right)) {
val rightTableName = right.asInstanceOf[SingleTableSQLAccumulator].table.get
if (leftTableName == rightTableName)
s"$rightTableName as ${rightAlias.get}"
else
rightTableName
} else
s"(${right.compileQuery()}) ${rightAlias.get}"
}
s"$leftJoinSql $joinTypeSQL $rightJoinSql" +
s"${condition.map(expr ⇒ s" ON ${exprToString(fixQualifier0(expr), useQualifier = true)}").getOrElse("")}"
}
/**
* @return SQL string representing specific join type.
*/
private def joinTypeSQL = joinType match {
case Inner
"JOIN"
case LeftOuter
"LEFT JOIN"
case RightOuter
"RIGHT JOIN"
case _ ⇒
throw new IgniteException(s"Unsupported join type $joinType")
}
/**
* Changes table qualifier in case of embedded query.
*
* @param exprs Expressions to fix.
* @tparam T type of input expression.
* @return copy of `exprs` with fixed qualifier.
*/
private def fixQualifier[T <: Expression](exprs: Seq[T]): Seq[T] =
exprs.map(fixQualifier0)
/**
* Changes table qualifier for single expression.
*
* @param expr Expression to fix.
* @tparam T type of input expression.
* @return copy of `expr` with fixed qualifier.
*/
private def fixQualifier0[T <: Expression](expr: T): T = expr match {
case attr: AttributeReference
attr.withQualifier(Seq(findQualifier(attr))).asInstanceOf[T]
case _ ⇒
expr.withNewChildren(fixQualifier(expr.children)).asInstanceOf[T]
}
/**
* Find right qualifier for a `attr`.
*
* @param attr Attribute to fix qualifier in
* @return Right qualifier for a `attr`
*/
private def findQualifier(attr: AttributeReference): String = {
val leftTableName =
if (isSimpleTableAcc(left))
left.qualifier
else
leftAlias.get
if (left.outputExpressions.exists(_.exprId == attr.exprId))
leftTableName
else if (isSimpleTableAcc(right) && right.qualifier != leftTableName)
right.qualifier
else
rightAlias.get
}
/** @inheritdoc */
override def simpleString: String =
s"JoinSQLAccumulator(joinType: $joinType, columns: $outputExpressions, condition: $condition)"
/** @inheritdoc */
override def withOutputExpressions(outputExpressions: Seq[NamedExpression]): SelectAccumulator = copy(outputExpressions= outputExpressions)
/** @inheritdoc */
override def withDistinct(distinct: Boolean): JoinSQLAccumulator = copy(distinct = distinct)
/** @inheritdoc */
override def withWhere(where: Seq[Expression]): JoinSQLAccumulator = copy(where = Some(where))
/** @inheritdoc */
override def withGroupBy(groupBy: Seq[Expression]): JoinSQLAccumulator = copy(groupBy = Some(groupBy))
/** @inheritdoc */
override def withHaving(having: Seq[Expression]): JoinSQLAccumulator = copy(having = Some(having))
/** @inheritdoc */
override def withLimit(limit: Expression): JoinSQLAccumulator = copy(limit = Some(limit))
/** @inheritdoc */
override def withLocalLimit(localLimit: Expression): JoinSQLAccumulator = copy(localLimit = Some(localLimit))
/** @inheritdoc */
override def withOrderBy(orderBy: Seq[SortOrder]): JoinSQLAccumulator = copy(orderBy = Some(orderBy))
/** @inheritdoc */
override def output: Seq[Attribute] = outputExpressions.map(toAttributeReference(_, Seq.empty))
/** @inheritdoc */
override lazy val qualifier: String = igniteQueryContext.uniqueTableAlias
}