/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kylin.query.runtime.plan

import org.apache.calcite.rel.core.AggregateCall
import org.apache.calcite.rex.RexLiteral
import org.apache.calcite.sql.SqlKind
import org.apache.kylin.common.KylinConfig
import org.apache.kylin.engine.spark.utils.LogEx
import org.apache.kylin.measure.percentile.PercentileCounter
import org.apache.kylin.metadata.model.FunctionDesc
import org.apache.kylin.query.relnode.{KapAggregateRel, KapProjectRel, KylinAggregateCall, OLAPAggregateRel}
import org.apache.kylin.query.util.RuntimeHelper
import org.apache.spark.sql.KapFunctions._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.expressions.{CreateArray, In}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.execution.utils.SchemaProcessor
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, StructType}
import org.apache.spark.sql.udaf.SingleValueAgg
import org.apache.spark.sql.util.SparderTypeUtil
import java.util.Locale
import scala.collection.JavaConverters._

// scalastyle:off
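
/**
 * Builds the Spark aggregation step of a query plan from a Calcite [[KapAggregateRel]].
 *
 * A minimal sketch of the expected call site (the surrounding plumbing here is an
 * assumption for illustration, not part of this file):
 * {{{
 *   val resultDf = AggregatePlan.agg(childDataFrames, aggregateRel)
 * }}}
 */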
object AggregatePlan extends LogEx {
val binaryMeasureType =
List("PERCENTILE", "PERCENTILE_APPROX", "INTERSECT_COUNT", "COUNT_DISTINCT", "BITMAP_UUID",
FunctionDesc.FUNC_BITMAP_BUILD, FunctionDesc.FUNC_SUM_LC)
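
  /**
   * Entry point: turns `rel` into DataFrame operations on `inputs.get(0)`.
   *
   * When the context is an exact index match (and no to-many derived dimensions are
   * involved), aggregation is skipped: the pre-aggregated measure columns are decoded
   * with the corresponding `*_decode` functions and projected directly. Otherwise the
   * rewritten calls are compiled by [[buildAgg]] and executed as a regular aggregation.
   */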
def agg(inputs: java.util.List[DataFrame],
rel: KapAggregateRel): DataFrame = logTime("aggregate", debug = true) {
var dataFrame = inputs.get(0)
val schemaNames = dataFrame.schema.fieldNames
val groupList = rel.getRewriteGroupKeys.asScala.map(groupId => col(schemaNames.apply(groupId))).toList
if (rel.getContext != null && rel.getContext.isExactlyAggregate && !rel.getContext.isNeedToManyDerived) {
// exactly match, skip agg, direct project.
val aggCols = rel.getRewriteAggCalls.asScala.zipWithIndex.map {
case (call: KylinAggregateCall, index: Int) =>
          val funcName = OLAPAggregateRel.getAggrFuncName(call)
val dataType = call.getFunc.getReturnDataType
val argNames = call.getArgList.asScala.map(dataFrame.schema.names.apply(_))
val columnName = argNames.map(col)
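          // identityHashCode of the rel node is mixed into the generated aliases,
          // presumably to keep them unique across multiple aggregate nodes in one plan.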
val hash = System.identityHashCode(rel).toString
funcName match {
case FunctionDesc.FUNC_COUNT_DISTINCT =>
if (call.isHllCountDistinctFunc) {
val aggName = SchemaProcessor.replaceToAggravateSchemaName(index, "APPROX_COUNT_DISTINCT_DECODE", hash, argNames: _*)
KapFunctions.approx_count_distinct_decode(columnName.head, dataType.getPrecision).alias(aggName)
} else if (call.isBitmapCountDistinctFunc) {
if (rel.getContext.isExactlyFastBitmap) {
col(schemaNames.apply(call.getArgList.get(0)))
} else {
val aggName = SchemaProcessor.replaceToAggravateSchemaName(index, "PRECISE_COUNT_DISTINCT_DECODE", hash, argNames: _*)
KapFunctions.precise_count_distinct_decode(columnName.head).alias(aggName)
}
} else {
                throw new IllegalArgumentException(s"Unsupported function name $funcName")
}
case FunctionDesc.FUNC_PERCENTILE =>
val aggName = SchemaProcessor.replaceToAggravateSchemaName(index, "PERCENTILE_DECODE", hash, argNames: _*)
KapFunctions.k_percentile_decode(columnName.head, columnName(1), dataType.getPrecision).alias(aggName)
case FunctionDesc.FUNC_BITMAP_BUILD =>
val aggName = SchemaProcessor.replaceToAggravateSchemaName(index, "BITMAP_BUILD_DECODE", hash, argNames: _*)
KapFunctions.precise_bitmap_build_decode(columnName.head).alias(aggName)
case FunctionDesc.FUNC_SUM_LC =>
val aggName = SchemaProcessor.replaceToAggravateSchemaName(index, "SUM_LC_DECODE", hash, argNames: _*)
val sparkDataType = SparderTypeUtil.toSparkType(dataType)
KapFunctions.k_sum_lc_decode(columnName.head, sparkDataType.json).alias(aggName)
case _ =>
col(schemaNames.apply(call.getArgList.get(0)))
}
        case (call: AggregateCall, _: Int) =>
          col(schemaNames.apply(call.getArgList.get(0)))
}.toList
val prjList = groupList ++ aggCols
logInfo(s"Query exactly match index, skip agg, project $prjList.")
dataFrame.select(prjList: _*)
} else {
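      // No exact index match: optionally pre-filter rows for intersect_count,
      // then run a real (possibly grouping-sets) aggregation.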
dataFrame = genFiltersWhenIntersectCount(rel, dataFrame)
val aggList = buildAgg(dataFrame.schema, rel)
val groupSets = rel.getRewriteGroupSets.asScala
.map(groupSet => groupSet.asScala.map(groupId => col(schemaNames.apply(groupId))).toList).toList
SparkOperation.agg(AggArgc(dataFrame, groupList, aggList, groupSets, rel.isSimpleGroupType))
}
}
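
  /**
   * If every rewritten aggregate call is an intersect_count, tries to push an OR of
   * IN filters down onto the DataFrame so rows irrelevant to all intersect values
   * are dropped before aggregation. Falls back to the unfiltered DataFrame whenever
   * the plan shape does not qualify or filter generation fails.
   */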
private def genFiltersWhenIntersectCount(rel: KapAggregateRel, dataFrame: DataFrame): DataFrame = {
try {
      val intersects = rel.getRewriteAggCalls.asScala
        .collect { case call: KylinAggregateCall => call }
        .filter(call => !call.getFunc.isCount
          && OLAPAggregateRel.getAggrFuncName(call).equals(FunctionDesc.FUNC_INTERSECT_COUNT))
val names = dataFrame.schema.names
val children = dataFrame.queryExecution.logical
if (intersects.nonEmpty && intersects.size == rel.getRewriteAggCalls.size() && children.isInstanceOf[Project]) {
// only exists intersect count function in agg
val list = children.asInstanceOf[Project].projectList
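        // Filters can be generated only when each call has exactly three arguments and
        // none of its literal filter values contains the configured separator (a value
        // containing the separator could not be matched back unambiguously).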
        val supportGenFilter = intersects.forall { call =>
          call.getArgList.size() == 3 && {
            val listIndex = call.getArgList.get(2)
            list.apply(listIndex).eval().asInstanceOf[ArrayData].array.map(_.toString)
              .forall(!_.contains(KylinConfig.getInstanceFromEnv.getIntersectFilterOrSeparator))
          }
        }
if (supportGenFilter) {
val filters = intersects.map { call =>
val filterColumnIndex = call.getArgList.get(1)
val litIndex = call.getArgList.get(2)
new Column(In(col(names(filterColumnIndex)).expr, list.apply(litIndex).children.head.asInstanceOf[CreateArray].children))
}
val column = filters.reduceLeft(_.or(_))
dataFrame.filter(column)
} else {
dataFrame
}
} else {
dataFrame
}
    } catch {
      case e: Throwable =>
        logWarning("Error occurred when generating filters", e)
        dataFrame
    }
}
def buildAgg(schema: StructType,
rel: KapAggregateRel): List[Column] = {
val hash = System.identityHashCode(rel).toString
rel.getRewriteAggCalls.asScala.zipWithIndex.map {
case (call: KylinAggregateCall, index: Int)
if binaryMeasureType.contains(OLAPAggregateRel.getAggrFuncName(call)) =>
val dataType = call.getFunc.getReturnDataType
val isCount = call.getFunc.isCount
val funcName =
if (isCount) FunctionDesc.FUNC_COUNT else OLAPAggregateRel.getAggrFuncName(call)
val argNames = call.getArgList.asScala.map(schema.names.apply(_))
val columnName = argNames.map(col)
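        // Fallback UDAF registered per (function name, return type); only used in the
        // final else branch below when no dedicated Kap function applies.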
val registeredFuncName = RuntimeHelper.registerSingleByColName(funcName, dataType)
val aggName = SchemaProcessor.replaceToAggravateSchemaName(index, funcName, hash, argNames: _*)
if (funcName == FunctionDesc.FUNC_COUNT_DISTINCT) {
if (dataType.getName == "hllc") {
org.apache.spark.sql.KapFunctions
.approx_count_distinct(columnName.head, dataType.getPrecision)
.alias(aggName)
} else {
KapFunctions.precise_count_distinct(columnName.head).alias(aggName)
}
} else if (funcName.equalsIgnoreCase(FunctionDesc.FUNC_BITMAP_UUID)) {
KapFunctions.precise_bitmap_uuid(columnName.head).alias(aggName)
} else if (funcName.equalsIgnoreCase(FunctionDesc.FUNC_BITMAP_BUILD)) {
KapFunctions.precise_bitmap_build(columnName.head).alias(aggName)
} else if (funcName.equalsIgnoreCase(FunctionDesc.FUNC_INTERSECT_COUNT)) {
          require(columnName.size >= 3, s"Expected at least 3 input columns, but got ${columnName.size}.")
val columns = columnName.slice(0, 3).zipWithIndex.map {
case (column: Column, 2) => column.cast(ArrayType.apply(schema.fields.apply(call.getArgList.get(1)).dataType))
case (column: Column, _) => column
}
val separator = s"\\${KylinConfig.getInstanceFromEnv.getIntersectFilterOrSeparator}"
val upperBound = KylinConfig.getInstanceFromEnv.getBitmapValuesUpperBound
call.name.toUpperCase(Locale.ROOT) match {
case FunctionDesc.FUNC_INTERSECT_COUNT => KapFunctions.intersect_count(separator, upperBound, columns.toList: _*).alias(aggName)
case FunctionDesc.FUNC_INTERSECT_VALUE => KapFunctions.intersect_value(separator, upperBound, columns.toList: _*).alias(aggName)
case FunctionDesc.FUNC_INTERSECT_BITMAP_UUID => KapFunctions.intersect_bitmap(separator, upperBound, columns.toList: _*).alias(aggName)
case FunctionDesc.FUNC_INTERSECT_COUNT_V2 => KapFunctions.intersect_count_v2(columnName.last, separator, upperBound, columns.toList: _*).alias(aggName)
case FunctionDesc.FUNC_INTERSECT_VALUE_V2 => KapFunctions.intersect_value_v2(columnName.last, separator, upperBound, columns.toList: _*).alias(aggName)
case FunctionDesc.FUNC_INTERSECT_BITMAP_UUID_V2 => KapFunctions.intersect_bitmap_v2(columnName.last, separator, upperBound, columns.toList: _*).alias(aggName)
case func => throw new UnsupportedOperationException(s"Unsupported intersect count function: $func, please check the sql.")
}
} else if (funcName.equalsIgnoreCase(FunctionDesc.FUNC_PERCENTILE)) {
          require(columnName.size == 2, s"Expected exactly 2 input columns, but got ${columnName.size}.")
KapFunctions.k_percentile(columnName.head, columnName(1), dataType.getPrecision).alias(aggName)
} else if (funcName.equalsIgnoreCase(FunctionDesc.FUNC_SUM_LC)) {
KapFunctions.k_sum_lc(columnName.head, SparderTypeUtil.toSparkType(dataType)).alias(aggName)
} else {
callUDF(registeredFuncName, columnName.toList: _*).alias(aggName)
}
      case (call: AggregateCall, index: Int) =>
val funcName = OLAPAggregateRel.getAggrFuncName(call)
val schemaNames = schema.names
val argNames = call.getArgList.asScala.map(id => schemaNames.apply(id))
val columnName = argNames.map(col)
val inputType = call.getType
val aggName = SchemaProcessor.replaceToAggravateSchemaName(index,
funcName,
hash,
argNames: _*)
funcName match {
case FunctionDesc.FUNC_PERCENTILE =>
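            // percentage (and optional accuracy) must come from constant literals in the
            // child project; depending on configuration, either Kylin's t-digest based
            // percentile or Spark's percentile_approx is used.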
rel.getInput match {
case projectRel: KapProjectRel =>
val percentageArg = projectRel.getChildExps.get(call.getArgList.get(1))
val accuracyArg = if (call.getArgList.size() < 3) { None } else { Some(projectRel.getChildExps.get(call.getArgList.get(2))) }
                (percentageArg, accuracyArg) match {
                  case (percentageLitRex: RexLiteral, accuracyArgLitRex: Option[RexLiteral]) =>
                    if (KylinConfig.getInstanceFromEnv.getPercentileApproxAlgorithm.equalsIgnoreCase("t-digest")) {
                      KapFunctions.k_percentile(columnName.head, columnName(1), PercentileCounter.DEFAULT_PERCENTILE_ACCURACY).alias(aggName)
                    } else {
                      val percentage = percentageLitRex.getValue
                      val accuracy = accuracyArgLitRex.map(arg => arg.getValue).getOrElse(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)
                      percentile_approx(col(argNames.head), lit(percentage), lit(accuracy)).alias(aggName)
                    }
                  case _ =>
                    throw new UnsupportedOperationException("Invalid percentile_approx parameters, " +
                      "expecting percentile_approx(col, percentage [, accuracy]); percentage/accuracy must be constant literals")
                }
case _ =>
                throw new UnsupportedOperationException("Invalid percentile_approx parameters, " +
                  "expecting percentile_approx(col, percentage [, accuracy]); percentage/accuracy must be constant literals")
}
case FunctionDesc.FUNC_SUM =>
if (isSum0(call)) {
sum0(
col(argNames.head).cast(
SparderTypeUtil.convertSqlTypeToSparkType(inputType)))
.alias(aggName)
} else {
sum(
col(argNames.head).cast(
SparderTypeUtil.convertSqlTypeToSparkType(inputType)))
.alias(aggName)
}
case FunctionDesc.FUNC_COUNT =>
count(if (argNames.isEmpty) k_lit(1) else col(argNames.head))
.alias(aggName)
case FunctionDesc.FUNC_MAX =>
max(
col(argNames.head).cast(
SparderTypeUtil.convertSqlTypeToSparkType(inputType)))
.alias(aggName)
case FunctionDesc.FUNC_MIN =>
min(
col(argNames.head).cast(
SparderTypeUtil.convertSqlTypeToSparkType(inputType)))
.alias(aggName)
case FunctionDesc.FUNC_COUNT_DISTINCT if call.getAggregation.getName == "BITMAP_COUNT" =>
KapFunctions.precise_count_distinct(col(argNames.head)).alias(aggName)
case FunctionDesc.FUNC_COUNT_DISTINCT =>
countDistinct(argNames.head, argNames.drop(1): _*)
.alias(aggName)
case FunctionDesc.FUNC_BITMAP_BUILD =>
KapFunctions.precise_bitmap_build_pushdown(columnName.head).alias(aggName)
          // Issue 4337: Support "select (select '2012-01-02') as data, xxx from table group by xxx"
case SqlKind.SINGLE_VALUE.sql =>
SingleValueAgg(schema.head).apply(col(argNames.head)).alias(aggName)
case FunctionDesc.FUNC_GROUPING =>
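            // Under grouping sets, delegate to Spark's grouping(); for a simple GROUP BY
            // the answer is a constant: 0 if the column is a group key, 1 otherwise.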
if (!rel.isSimpleGroupType) {
grouping(argNames.head).alias(aggName)
} else {
if (rel.getRewriteGroupKeys.contains(call.getArgList.get(0))) {
k_lit(0).alias(aggName)
} else {
k_lit(1).alias(aggName)
}
}
case FunctionDesc.FUNC_COLLECT_SET =>
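            // For a KylinAggregateCall the measure is pre-aggregated, so each stored
            // value is already an array; flatten and de-duplicate instead of collecting directly.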
            call match {
              case _: KylinAggregateCall =>
                array_distinct(flatten(collect_set(col(argNames.head))))
                  .alias(aggName)
              case _ =>
                collect_set(col(argNames.head)).alias(aggName)
            }
case _ =>
          throw new IllegalArgumentException(s"Unsupported function name $funcName")
}
}.toList
}
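
  // SUM0 is Calcite's SUM variant that returns 0 instead of null over empty input;
  // only a KylinAggregateCall carries that flag.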
  private def isSum0(call: AggregateCall) = call match {
    case kylinCall: KylinAggregateCall => kylinCall.isSum0
    case _ => false
  }
}