blob: 5c8c479986a252655878fe617f6e3500d6ad62b1 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanCheckContext;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* AbstractMaterializedViewAggregateRule
* This is responsible for common aggregate rewriting
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {
// Query aggregate function pattern -> view aggregate function patterns it can be rolled up from.
// Any.INSTANCE is a placeholder standing for an arbitrary argument. The keys are matched against the
// query aggregate function and the values against the mv aggregate function in rollup() and
// isAggregateFunctionEquivalent(). A key may map to several view patterns (ArrayListMultimap).
protected static final Multimap<Function, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create();
// Shared rewriter instance; AggregateExpressionRewriter declares no fields, all per-call state is
// carried by AggregateExpressionRewriteContext.
protected static final AggregateExpressionRewriter AGGREGATE_EXPRESSION_REWRITER =
new AggregateExpressionRewriter();
static {
// support roll up when count distinct is in query
// the column type is not bitMap: count(distinct x) <- bitmap_union(to_bitmap(x))
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast: count(distinct x) <- bitmap_union(to_bitmap(cast(x as bigint)))
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));
// support roll up when bitmap_union_count is in query
// the column type is bitMap: bitmap_union_count(x) <- bitmap_union(x)
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(Any.INSTANCE),
new BitmapUnion(Any.INSTANCE));
// the column type is not bitMap: bitmap_union_count(to_bitmap(x)) <- bitmap_union(to_bitmap(x))
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(Any.INSTANCE)),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast (cast to bigint on both sides)
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));
// support roll up when the column type is not hll
// query is approx_count_distinct: ndv(x) <- hll_union(hll_hash(x)), with or without a varchar cast in mv
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_UNION_AGG; every combination of (query cast / no cast) x (mv cast / no cast) is registered
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_CARDINALITY; same four cast combinations as above
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// query is HLL_RAW_AGG or HLL_UNION; same four cast combinations as above
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
// support roll up when the column type is hll (mv stores hll_union(x) of an hll column x)
// query is HLL_UNION_AGG
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(Any.INSTANCE),
new HllUnion(Any.INSTANCE));
// query is HLL_CARDINALITY
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(Any.INSTANCE)),
new HllUnion(Any.INSTANCE));
// query is HLL_RAW_AGG or HLL_UNION
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(Any.INSTANCE),
new HllUnion(Any.INSTANCE));
}
/**
 * Rewrite the query aggregate on top of the materialized view scan.
 *
 * <p>Two strategies are tried:
 * <ol>
 *   <li>When the lineage-shuttled group by expressions of query and view are equal, the query top plan
 *   output is rewritten directly by the mv output and wrapped in a {@link LogicalProject}. If that
 *   rewrite produces no expressions, the reason is recorded and roll up is attempted instead.</li>
 *   <li>Roll up: each query top plan output that references a bottom aggregate function is rolled up
 *   from an mv aggregate function. Every other output is treated as a group by expression that must be
 *   covered by the mv output. A new {@link LogicalAggregate} is built over a project on top of
 *   {@code tempRewritedPlan}.</li>
 * </ol>
 * Roll up is refused when exactly one of query and view is a scalar aggregate (no group by).
 *
 * @param matchMode match mode between query and view (not read by this implementation)
 * @param queryStructInfo struct info of the query
 * @param viewStructInfo struct info of the materialized view
 * @param viewToQuerySlotMapping mapping from view slots to query slots
 * @param tempRewritedPlan plan that scans the materialized view
 * @param materializationContext holds the mv expression mapping and collects fail reasons
 * @return the rewritten plan, or null when the query can not be rewritten; the reason is then
 *         recorded on {@code materializationContext}
 */
@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
        StructInfo queryStructInfo,
        StructInfo viewStructInfo,
        SlotMapping viewToQuerySlotMapping,
        Plan tempRewritedPlan,
        MaterializationContext materializationContext) {
    // get view and query aggregate and top plan correspondingly
    Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
    if (viewTopPlanAndAggPair == null) {
        materializationContext.recordFailReason(queryStructInfo,
                "Split view to top plan and agg fail, view doesn't not contain aggregate",
                () -> String.format("view plan = %s\n", viewStructInfo.getOriginalPlan().treeString()));
        return null;
    }
    Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair = splitToTopPlanAndAggregate(queryStructInfo);
    if (queryTopPlanAndAggPair == null) {
        materializationContext.recordFailReason(queryStructInfo,
                "Split query to top plan and agg fail",
                () -> String.format("query plan = %s\n", queryStructInfo.getOriginalPlan().treeString()));
        return null;
    }
    // Strategy 1: group by expressions of query and view are equal, try to rewrite expressions directly
    Plan queryTopPlan = queryTopPlanAndAggPair.key();
    if (isGroupByEquals(queryTopPlanAndAggPair, viewTopPlanAndAggPair, viewToQuerySlotMapping, queryStructInfo,
            viewStructInfo)) {
        List<Expression> rewrittenQueryExpressions = rewriteExpression(queryTopPlan.getOutput(),
                queryTopPlan,
                materializationContext.getMvExprToMvScanExprMapping(),
                viewToQuerySlotMapping,
                true,
                queryStructInfo.getTableBitSet());
        if (!rewrittenQueryExpressions.isEmpty()) {
            List<NamedExpression> projects = new ArrayList<>();
            for (Expression expression : rewrittenQueryExpressions) {
                // the result here is a plain project over the mv scan, so no aggregate function may remain
                if (expression.containsType(AggregateFunction.class)) {
                    materializationContext.recordFailReason(queryStructInfo,
                            "rewritten expression contains aggregate functions when group equals aggregate rewrite",
                            () -> String.format("aggregate functions = %s\n", rewrittenQueryExpressions));
                    return null;
                }
                projects.add(expression instanceof NamedExpression
                        ? (NamedExpression) expression : new Alias(expression));
            }
            return new LogicalProject<>(projects, tempRewritedPlan);
        }
        // if fails, record the reason and then try to roll up aggregate function
        materializationContext.recordFailReason(queryStructInfo,
                "Can not rewrite expression when no roll up",
                () -> String.format("expressionToWrite = %s,\n mvExprToMvScanExprMapping = %s,\n"
                                + "viewToQuerySlotMapping = %s",
                        queryTopPlan.getOutput(),
                        materializationContext.getMvExprToMvScanExprMapping(),
                        viewToQuerySlotMapping));
    }
    // Strategy 2: roll up.
    // If exactly one of query and view is a scalar aggregate (no group by), roll up is not possible.
    List<Expression> queryGroupByExpressions = queryTopPlanAndAggPair.value().getGroupByExpressions();
    List<Expression> viewGroupByExpressions = viewTopPlanAndAggPair.value().getGroupByExpressions();
    if (queryGroupByExpressions.isEmpty() != viewGroupByExpressions.isEmpty()) {
        materializationContext.recordFailReason(queryStructInfo,
                "only one the of query or view is scalar aggregate and "
                        + "can not rewrite expression meanwhile",
                () -> String.format("query aggregate = %s,\n view aggregate = %s,\n",
                        queryTopPlanAndAggPair.value().treeString(),
                        viewTopPlanAndAggPair.value().treeString()));
        return null;
    }
    // split the query top plan expressions to group expressions and functions
    Pair<Set<? extends Expression>, Set<? extends Expression>> queryGroupAndFunctionPair
            = topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair, queryStructInfo);
    Set<? extends Expression> queryTopPlanFunctionSet = queryGroupAndFunctionPair.value();
    // try to rewrite, contains both roll up aggregate functions and aggregate group expression
    List<NamedExpression> finalOutputExpressions = new ArrayList<>();
    List<Expression> finalGroupExpressions = new ArrayList<>();
    List<? extends Expression> queryExpressions = queryTopPlan.getOutput();
    // permute the mv expr mapping to query based
    Map<Expression, Expression> mvExprToMvScanExprQueryBased =
            materializationContext.getMvExprToMvScanExprMapping().keyPermute(viewToQuerySlotMapping)
                    .flattenMap().get(0);
    for (Expression topExpression : queryExpressions) {
        if (queryTopPlanFunctionSet.contains(topExpression)) {
            // aggregate function output: try to roll up and rewrite
            Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
                    topExpression,
                    queryTopPlan,
                    queryStructInfo.getTableBitSet());
            AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(
                    false, mvExprToMvScanExprQueryBased, queryTopPlan, queryStructInfo.getTableBitSet());
            // queryFunctionShuttled maybe sum(column) + count(*), so need to use expression rewriter
            Expression rollupedExpression = queryFunctionShuttled.accept(AGGREGATE_EXPRESSION_REWRITER,
                    context);
            if (!context.isValid()) {
                materializationContext.recordFailReason(queryStructInfo,
                        "Query function roll up fail",
                        () -> String.format("queryFunctionShuttled = %s,\n mvExprToMvScanExprQueryBased = %s",
                                queryFunctionShuttled, mvExprToMvScanExprQueryBased));
                return null;
            }
            finalOutputExpressions.add(new Alias(rollupedExpression));
        } else {
            // group by output: it must be expressible by the mv output
            Expression queryGroupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(
                    topExpression, queryTopPlan, queryStructInfo.getTableBitSet());
            AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(true,
                    mvExprToMvScanExprQueryBased, queryTopPlan, queryStructInfo.getTableBitSet());
            // group by expression maybe group by a + b, so we need expression rewriter
            Expression rewrittenGroupByExpression = queryGroupShuttledExpr.accept(AGGREGATE_EXPRESSION_REWRITER,
                    context);
            if (!context.isValid()) {
                // group expr can not rewrite by view
                materializationContext.recordFailReason(queryStructInfo,
                        "View dimensions doesn't not cover the query dimensions",
                        () -> String.format("mvExprToMvScanExprQueryBased is %s,\n queryGroupShuttledExpr is %s",
                                mvExprToMvScanExprQueryBased, queryGroupShuttledExpr));
                return null;
            }
            NamedExpression groupByExpression = rewrittenGroupByExpression instanceof NamedExpression
                    ? (NamedExpression) rewrittenGroupByExpression : new Alias(rewrittenGroupByExpression);
            finalOutputExpressions.add(groupByExpression);
            finalGroupExpressions.add(groupByExpression);
        }
    }
    // add project to guarantee group by column ref is slot reference,
    // this is necessary because physical createHash will need slotReference later
    List<NamedExpression> projectsUnderAggregate = finalGroupExpressions.stream()
            .map(NamedExpression.class::cast)
            .collect(Collectors.toList());
    projectsUnderAggregate.addAll(tempRewritedPlan.getOutput());
    LogicalProject<Plan> mvProject = new LogicalProject<>(projectsUnderAggregate, tempRewritedPlan);
    // add agg rewrite
    Map<ExprId, Slot> projectOutPutExprIdMap = mvProject.getOutput().stream()
            .distinct()
            .collect(Collectors.toMap(NamedExpression::getExprId, slot -> slot));
    // make the expressions to re reference project output
    finalGroupExpressions = finalGroupExpressions.stream()
            .map(expr -> {
                ExprId exprId = ((NamedExpression) expr).getExprId();
                if (projectOutPutExprIdMap.containsKey(exprId)) {
                    return projectOutPutExprIdMap.get(exprId);
                }
                return (NamedExpression) expr;
            })
            .collect(Collectors.toList());
    finalOutputExpressions = finalOutputExpressions.stream()
            .map(expr -> projectOutPutExprIdMap.containsKey(expr.getExprId())
                    ? projectOutPutExprIdMap.get(expr.getExprId()) : expr)
            .collect(Collectors.toList());
    // diamond instead of the raw type: avoids an unchecked construction
    return new LogicalAggregate<>(finalGroupExpressions, finalOutputExpressions, mvProject);
}
/**
 * Decide whether query and view aggregate by the same keys.
 * Both sides' group by expressions are shuttled through their top plan to base-table lineage.
 * The view side is then re-expressed with query slots via {@code viewToQuerySlotMapping}, and the two
 * resulting sets are compared for equality.
 */
private boolean isGroupByEquals(Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair,
        Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair,
        SlotMapping viewToQuerySlotMapping,
        StructInfo queryStructInfo,
        StructInfo viewStructInfo) {
    Set<Expression> queryGroupKeys = new HashSet<>(ExpressionUtils.shuttleExpressionWithLineage(
            queryTopPlanAndAggPair.value().getGroupByExpressions(),
            queryTopPlanAndAggPair.key(),
            queryStructInfo.getTableBitSet()));
    Set<Expression> viewGroupKeysQueryBased = new HashSet<>();
    for (Expression viewGroupKey : ExpressionUtils.shuttleExpressionWithLineage(
            viewTopPlanAndAggPair.value().getGroupByExpressions(),
            viewTopPlanAndAggPair.key(),
            viewStructInfo.getTableBitSet())) {
        // rewrite view slots into query slots so both sets speak the same slot vocabulary
        viewGroupKeysQueryBased.add(
                ExpressionUtils.replace(viewGroupKey, viewToQuerySlotMapping.toSlotReferenceMap()));
    }
    return queryGroupKeys.equals(viewGroupKeysQueryBased);
}
/**
 * Roll up a query aggregate function from an aggregate already materialized in the mv. This is
 * used when the query has fewer dimensions than the mv.
 *
 * <p>Example: query {@code select max(a) + 1 from t group by b}, mv
 * {@code select max(a) from t group by a, b}. Here queryAggregateFunction is {@code max(a)} and
 * mvExprToMvScanExprQueryBased is {@code { max(a) : MTMVScan(output#0) }}.
 *
 * @param queryAggregateFunction query aggregate function to roll up
 * @param queryAggregateFunctionShuttled the query aggregate function shuttled by lineage
 * @param mvExprToMvScanExprQueryBased mv definition output expressions (query based) to mv scan output
 * @return the rolled-up function built by {@link CouldRollUp#constructRollUp}, or null when the function
 *         does not implement CouldRollUp, no matching mv expression exists, or the matched mv
 *         expression can not be rolled up
 */
private static Function rollup(AggregateFunction queryAggregateFunction,
        Expression queryAggregateFunctionShuttled,
        Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
    if (!(queryAggregateFunction instanceof CouldRollUp)) {
        return null;
    }
    // the mv expression being rolled up, and the mv scan expression that carries its value
    Expression matchedViewExpr = null;
    Expression mvScanExpr = null;
    // simple case: the mv holds exactly this function and it is not one of the equivalence-map keys
    boolean directMatch = mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)
            && AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.keySet().stream()
                    .noneMatch(aggFunction -> aggFunction.equals(queryAggregateFunction));
    if (directMatch) {
        matchedViewExpr = queryAggregateFunctionShuttled;
        mvScanExpr = mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled);
    } else {
        // complex case, e.g. query count(distinct p) answered by mv bitmap_union(to_bitmap(p)).
        // No early exit: if several mv functions are equivalent, the last one iterated is used.
        for (Expression candidate : mvExprToMvScanExprQueryBased.keySet()) {
            boolean equivalent = candidate instanceof Function
                    && isAggregateFunctionEquivalent(queryAggregateFunction, (Function) candidate);
            if (equivalent) {
                matchedViewExpr = candidate;
                mvScanExpr = mvExprToMvScanExprQueryBased.get(candidate);
            }
        }
    }
    if (mvScanExpr == null || !canRollup(matchedViewExpr)) {
        return null;
    }
    return ((CouldRollUp) queryAggregateFunction).constructRollUp(mvScanExpr);
}
/**
 * Decide whether the matched mv expression may serve as a roll-up source.
 * Returns false for null and for any non-aggregate function. For an aggregate function it returns
 * true only when the function implements CouldRollUp and is not distinct. Any other expression is
 * accepted.
 */
private static boolean canRollup(Expression rollupExpression) {
    if (rollupExpression instanceof AggregateFunction) {
        AggregateFunction viewAggregate = (AggregateFunction) rollupExpression;
        return viewAggregate instanceof CouldRollUp && !viewAggregate.isDistinct();
    }
    // null fails the instanceof checks above, so it is rejected here by the explicit null test
    return rollupExpression != null && !(rollupExpression instanceof Function);
}
/**
 * Partition the query top plan output into group-by-derived and aggregate-function-derived expressions.
 *
 * <p>The aggregate functions of the bottom aggregate are identified by the ExprIds of its non-group-by
 * outputs. Each top plan output is traced through the top plan with ExpressionLineageReplacer. If the
 * trace reaches any of those ExprIds, the output counts as a function expression. For example, with
 * {@code project(bitmap_count()#1) <- aggregate(bitmap_union()#2)}, output #1 reaches #2 and is a
 * function expression.
 *
 * @return pair of (group by expressions, function expressions), both taken from the top plan output
 */
private Pair<Set<? extends Expression>, Set<? extends Expression>> topPlanSplitToGroupAndFunction(
        Pair<Plan, LogicalAggregate<Plan>> topPlanAndAggPair, StructInfo queryStructInfo) {
    LogicalAggregate<Plan> bottomQueryAggregate = topPlanAndAggPair.value();
    Set<Expression> groupByExpressionSet = new HashSet<>(bottomQueryAggregate.getGroupByExpressions());
    Set<ExprId> bottomAggregateFunctionExprIdSet = bottomQueryAggregate.getOutput().stream()
            .filter(output -> !groupByExpressionSet.contains(output))
            .map(NamedExpression::getExprId)
            .collect(Collectors.toSet());
    Plan queryTopPlan = topPlanAndAggPair.key();
    Set<Expression> topGroupByExpressions = new HashSet<>();
    Set<Expression> topFunctionExpressions = new HashSet<>();
    for (Expression topOutput : queryTopPlan.getOutput()) {
        ExpressionLineageReplacer.ExpressionReplaceContext replaceContext =
                new ExpressionLineageReplacer.ExpressionReplaceContext(ImmutableList.of(topOutput),
                        ImmutableSet.of(), ImmutableSet.of(), queryStructInfo.getTableBitSet());
        queryTopPlan.accept(ExpressionLineageReplacer.INSTANCE, replaceContext);
        boolean usesAggregateFunction = !Sets.intersection(bottomAggregateFunctionExprIdSet,
                replaceContext.getExprIdExpressionMap().keySet()).isEmpty();
        Set<Expression> bucket = usesAggregateFunction ? topFunctionExpressions : topGroupByExpressions;
        bucket.add(topOutput);
    }
    return Pair.of(topGroupByExpressions, topFunctionExpressions);
}
/**
 * Run PLAN_SPLITTER over the struct info's top plan, splitting at LogicalAggregate.
 *
 * @return pair of (the struct info's top plan, the bottom aggregate), or null when the split's
 *         bottom plan is not a LogicalAggregate
 */
private Pair<Plan, LogicalAggregate<Plan>> splitToTopPlanAndAggregate(StructInfo structInfo) {
    PlanSplitContext splitContext = new PlanSplitContext(Sets.newHashSet(LogicalAggregate.class));
    Plan topPlan = structInfo.getTopPlan();
    topPlan.accept(StructInfo.PLAN_SPLITTER, splitContext);
    if (splitContext.getBottomPlan() instanceof LogicalAggregate) {
        return Pair.of(topPlan, (LogicalAggregate<Plan>) splitContext.getBottomPlan());
    }
    return null;
}
/**
 * Check that the struct info's top plan passes PLAN_PATTERN_CHECKER with the supported join types.
 * The check also requires that the plan contains a top aggregate and has at most one of them.
 */
@Override
protected boolean checkPattern(StructInfo structInfo) {
    PlanCheckContext checkContext = PlanCheckContext.of(SUPPORTED_JOIN_TYPE_SET);
    // the checker fills checkContext, so it must run before the aggregate counters are read
    if (!structInfo.getTopPlan().accept(StructInfo.PLAN_PATTERN_CHECKER, checkContext)) {
        return false;
    }
    // if query or mv contains more than one top aggregate, should fail
    return checkContext.isContainsTopAggregate() && checkContext.getTopAggregateNum() <= 1;
}
/**
 * Check whether a query aggregate function can be answered by a view aggregate function.
 *
 * <p>Functions that are equal match directly. Otherwise the equivalence map is consulted. For a key
 * equal to the query function, each registered view pattern is matched against the view function
 * with {@code Any.equals}. The single argument extracted from each side must then be equal.
 * For example, query {@code count(distinct a)} matches view {@code bitmap_union(to_bitmap(a))}
 * because both reduce to the argument {@code a}. Functions with other than exactly one extracted
 * argument are not supported.
 */
private static boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
    if (queryFunction.equals(viewFunction)) {
        return true;
    }
    for (Map.Entry<Function, Collection<Expression>> patternEntry :
            AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) {
        Function queryPattern = patternEntry.getKey();
        if (!queryPattern.equals(queryFunction)) {
            continue;
        }
        for (Expression viewPattern : patternEntry.getValue()) {
            if (!Any.equals(viewPattern, viewFunction)) {
                continue;
            }
            List<Expression> viewArguments = extractArguments(viewPattern, viewFunction);
            List<Expression> queryArguments = extractArguments(queryPattern, queryFunction);
            boolean bothSingleArgument = queryArguments.size() == 1 && viewArguments.size() == 1;
            if (bothSingleArgument && Objects.equals(queryArguments.get(0), viewArguments.get(0))) {
                return true;
            }
        }
    }
    return false;
}
/**
 * Pull the argument(s) standing in the Any position of a pattern out of an actual function.
 * For pattern {@code bitmap_union(to_bitmap(Any))} and actual
 * {@code bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))}, the result is
 * {@code case when a = 5 then 1 else 2 end}. The pattern's concrete (non-Any) nodes are skipped,
 * and {@code collectFirst} picks the first node of the actual function that matches none of them.
 */
private static List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
    Set<Object> patternNodes = functionWithAny.collectToSet(node -> !(node instanceof Any));
    // Equality is tested as patternNode.equals(candidate) in a linear scan. Set#contains would instead
    // hash the candidate and call candidate.equals(...), which may not agree for pattern nodes holding Any.
    return actualFunction.collectFirst(candidate -> patternNodes.stream()
            .noneMatch(patternNode -> patternNode.equals(candidate)));
}
/**
 * Rewrites a lineage-shuttled query expression in terms of mv scan output.
 * Aggregate functions are rolled up, slots are mapped through the mv mapping, and in group-by mode
 * whole sub-expressions found in the mapping are replaced directly. When anything can not be
 * rewritten, the context is marked invalid and the traversal stops changing the tree.
 */
protected static class AggregateExpressionRewriter
        extends DefaultExpressionRewriter<AggregateExpressionRewriteContext> {

    @Override
    public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
            AggregateExpressionRewriteContext rewriteContext) {
        if (rewriteContext.isValid()) {
            Expression shuttledFunction = ExpressionUtils.shuttleExpressionWithLineage(
                    aggregateFunction,
                    rewriteContext.getQueryTopPlan(),
                    rewriteContext.getQueryTableBitSet());
            Function rolledUp = rollup(aggregateFunction, shuttledFunction,
                    rewriteContext.getMvExprToMvScanExprQueryBasedMapping());
            if (rolledUp != null) {
                return rolledUp;
            }
            rewriteContext.setValid(false);
        }
        return aggregateFunction;
    }

    @Override
    public Expression visitSlot(Slot slot, AggregateExpressionRewriteContext rewriteContext) {
        if (rewriteContext.isValid()) {
            Map<Expression, Expression> mvMapping = rewriteContext.getMvExprToMvScanExprQueryBasedMapping();
            if (mvMapping.containsKey(slot)) {
                return mvMapping.get(slot);
            }
            // a slot the mv does not expose makes the whole expression unrewritable
            rewriteContext.setValid(false);
        }
        return slot;
    }

    @Override
    public Expression visit(Expression expr, AggregateExpressionRewriteContext rewriteContext) {
        if (!rewriteContext.isValid()) {
            return expr;
        }
        Map<Expression, Expression> mvMapping = rewriteContext.getMvExprToMvScanExprQueryBasedMapping();
        // for group by expression try to get corresponding expression directly
        if (rewriteContext.isOnlyContainGroupByExpression() && mvMapping.containsKey(expr)) {
            return mvMapping.get(expr);
        }
        List<Expression> rewrittenChildren = new ArrayList<>(expr.arity());
        boolean childChanged = false;
        for (Expression child : expr.children()) {
            Expression rewrittenChild = child.accept(this, rewriteContext);
            if (!rewriteContext.isValid()) {
                return expr;
            }
            childChanged |= rewrittenChild != child;
            rewrittenChildren.add(rewrittenChild);
        }
        return childChanged ? expr.withChildren(rewrittenChildren) : expr;
    }
}
/**
 * Per-traversal state for AggregateExpressionRewriter.
 * It holds the mv mapping, the query top plan and its table bit set, and whether only a group by
 * expression is being rewritten. The {@code valid} flag starts as true and is cleared by the
 * rewriter once some sub-expression can not be rewritten.
 */
protected static class AggregateExpressionRewriteContext {
    private final boolean onlyContainGroupByExpression;
    private final Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping;
    private final Plan queryTopPlan;
    private final BitSet queryTableBitSet;
    private boolean valid = true;

    public AggregateExpressionRewriteContext(boolean onlyContainGroupByExpression,
            Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping, Plan queryTopPlan,
            BitSet queryTableBitSet) {
        this.onlyContainGroupByExpression = onlyContainGroupByExpression;
        this.mvExprToMvScanExprQueryBasedMapping = mvExprToMvScanExprQueryBasedMapping;
        this.queryTopPlan = queryTopPlan;
        this.queryTableBitSet = queryTableBitSet;
    }

    public boolean isOnlyContainGroupByExpression() {
        return onlyContainGroupByExpression;
    }

    public Map<Expression, Expression> getMvExprToMvScanExprQueryBasedMapping() {
        return mvExprToMvScanExprQueryBasedMapping;
    }

    public Plan getQueryTopPlan() {
        return queryTopPlan;
    }

    public BitSet getQueryTableBitSet() {
        return queryTableBitSet;
    }

    public boolean isValid() {
        return valid;
    }

    public void setValid(boolean valid) {
        this.valid = valid;
    }
}
}