/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes.calcite

import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.util._

import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelCollation, RelNode, RelWriter, SingleRel}
import org.apache.calcite.sql.SqlRankFunction
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.util.{ImmutableBitSet, NumberUtil}

import scala.collection.JavaConversions._
import scala.collection.mutable

/**
* Relational expression that returns the rows in which the rank function value of each row
* is in the given range.
*
* <p>NOTES: Different from [[org.apache.calcite.sql.fun.SqlStdOperatorTable.RANK]],
* [[Rank]] is a Relational expression, not a window function.
*
* <p>[[Rank]] outputs the rank function value as its last column.
*
* <p>This RelNode handles only a single rank function and is an optimization for some cases, e.g.
* <ol>
* <li>
* a single rank function (in an `OVER` clause) with a filter in a SQL query statement
* (see the example below)
* </li>
* <li>
* `ORDER BY` with `LIMIT` in a SQL query statement
* (equivalent to `ROW_NUMBER` with filter and project)
* </li>
* </ol>
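*
* <p>For example (an illustrative sketch; the table and field names are hypothetical), a Top-N
* query such as
* {{{
*   SELECT a, b, rk FROM (
*     SELECT a, b,
*       ROW_NUMBER() OVER (PARTITION BY a ORDER BY b DESC) AS rk
*     FROM MyTable
*   ) WHERE rk <= 10
* }}}
* can be planned as a single [[Rank]] node whose partition key is `a`, whose sort collation is
* `b DESC`, and whose rank range is the constant range [1, 10].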
*
* @param cluster cluster that this relational expression belongs to
* @param traitSet the traits of this rel
* @param input input relational expression
* @param rankFunction rank function, including: CUME_DIST, DENSE_RANK, PERCENT_RANK, RANK,
* ROW_NUMBER
* @param partitionKey partition keys (may be empty)
* @param sortCollation order keys for rank function
* @param rankRange the expected range of rank function value
*/
abstract class Rank(
cluster: RelOptCluster,
traitSet: RelTraitSet,
input: RelNode,
val rankFunction: SqlRankFunction,
val partitionKey: ImmutableBitSet,
val sortCollation: RelCollation,
val rankRange: RankRange)
extends SingleRel(cluster, traitSet, input) {
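// validate the rank range: a constant range must have a positive end that is not smaller
// than its start; a variable range must reference an existing input field for its end index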
rankRange match {
case r: ConstantRankRange =>
if (r.rankEnd <= 0) {
throw new TableException(s"Rank end can't smaller than zero. The rank end is ${r.rankEnd}")
}
if (r.rankStart > r.rankEnd) {
throw new TableException(
s"Rank start '${r.rankStart}' can't greater than rank end '${r.rankEnd}'.")
}
case v: VariableRankRange =>
if (v.rankEndIndex < 0) {
throw new TableException(s"Rank end index can't smaller than zero.")
}
if (v.rankEndIndex >= input.getRowType.getFieldCount) {
throw new TableException(s"Rank end index can't greater than input field count.")
}
}
override def deriveRowType(): RelDataType = {
val typeFactory = cluster.getRexBuilder.getTypeFactory
val typeBuilder = typeFactory.builder()
input.getRowType.getFieldList.foreach(typeBuilder.add)
// rank function column is always the last column, and its type is BIGINT NOT NULL
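// e.g. an input row type (a INT, b VARCHAR) yields (a INT, b VARCHAR, rk BIGINT NOT NULL);
// the name "rk" is uniquified against the existing field names if necessary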
val allFieldNames = mutable.Set[String](input.getRowType.getFieldNames: _*)
val rankFieldName = FlinkRelOptUtil.buildUniqueFieldName(allFieldNames, "rk")
val bigIntType = typeFactory.createSqlType(SqlTypeName.BIGINT)
typeBuilder.add(rankFieldName, typeFactory.createTypeWithNullability(bigIntType, false))
typeBuilder.build()
}
override def explainTerms(pw: RelWriter): RelWriter = {
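// render the output columns as name-to-index pairs, e.g. "a=$0, b=$1, rk=$2"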
val select = getRowType.getFieldNames.zipWithIndex.map {
case (name, idx) => s"$name=$$$idx"
}.mkString(", ")
super.explainTerms(pw)
.item("rankFunction", rankFunction)
.item("partitionBy", partitionKey.map(i => s"$$$i").mkString(","))
.item("orderBy", Rank.sortFieldsToString(sortCollation))
.item("rankRange", rankRange.toString())
.item("select", select)
}
override def estimateRowCount(mq: RelMetadataQuery): Double = {
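// estimate the output as (rows emitted per group) * (number of groups), capped by the input
// row count; e.g. a constant rank range of [1, 10] is expected to emit about 10 rows per
// group, and with 100 distinct partition key values the estimate is min(10 * 100, input rows)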
val countPerGroup = FlinkRelMdUtil.getRankRangeNdv(rankRange)
if (partitionKey.isEmpty) {
// only one group
countPerGroup
} else {
val inputRowCount = mq.getRowCount(input)
val numOfGroup = mq.getDistinctRowCount(input, partitionKey, null)
if (numOfGroup != null) {
NumberUtil.min(numOfGroup * countPerGroup, inputRowCount)
} else {
NumberUtil.min(mq.getRowCount(input) * 0.1, inputRowCount)
}
}
}
override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = {
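// a simple cost model: both the row count and the CPU cost are proportional to the number
// of input rows, and no I/O cost is charged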
val rowCount = mq.getRowCount(input)
val cpuCost = rowCount
planner.getCostFactory.makeCost(rowCount, cpuCost, 0)
}
}
object Rank {
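/**
* Returns the sort fields of the given collation as a string of `$index direction` pairs,
* e.g. "$0 ASC, $1 DESC".
*/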
def sortFieldsToString(collationSort: RelCollation): String = {
val fieldCollations = collationSort.getFieldCollations
.map(c => (c.getFieldIndex, SortUtil.directionToOrder(c.getDirection)))
fieldCollations.map {
case (index, order) => s"$$$index ${order.getShortName}"
}.mkString(", ")
}
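/**
* Returns the sort fields of the given collation as a string of `fieldName direction` pairs
* resolved against the given input row type, e.g. "a ASC, b DESC".
*/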
def sortFieldsToString(collationSort: RelCollation, inputType: RelDataType): String = {
val fieldCollations = collationSort.getFieldCollations
.map(c => (c.getFieldIndex, SortUtil.directionToOrder(c.getDirection)))
val inputFieldNames = inputType.getFieldNames
fieldCollations.map {
case (index, order) => s"${inputFieldNames.get(index)} ${order.getShortName}"
}.mkString(", ")
}
}