| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| package org.apache.doris.nereids.stats; |
| |
| import org.apache.doris.catalog.Env; |
| import org.apache.doris.catalog.MaterializedIndex; |
| import org.apache.doris.catalog.OlapTable; |
| import org.apache.doris.catalog.Partition; |
| import org.apache.doris.catalog.Table; |
| import org.apache.doris.common.Id; |
| import org.apache.doris.common.Pair; |
| import org.apache.doris.nereids.memo.GroupExpression; |
| import org.apache.doris.nereids.trees.expressions.NamedExpression; |
| import org.apache.doris.nereids.trees.expressions.Slot; |
| import org.apache.doris.nereids.trees.expressions.SlotReference; |
| import org.apache.doris.nereids.trees.plans.Plan; |
| import org.apache.doris.nereids.trees.plans.algebra.Aggregate; |
| import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation; |
| import org.apache.doris.nereids.trees.plans.algebra.Filter; |
| import org.apache.doris.nereids.trees.plans.algebra.Limit; |
| import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation; |
| import org.apache.doris.nereids.trees.plans.algebra.Project; |
| import org.apache.doris.nereids.trees.plans.algebra.Repeat; |
| import org.apache.doris.nereids.trees.plans.algebra.Scan; |
| import org.apache.doris.nereids.trees.plans.algebra.TopN; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalProject; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalSort; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation; |
| import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalEmptyRelation; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalLocalQuickSort; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalOneRowRelation; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalRepeat; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalTVFRelation; |
| import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; |
| import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor; |
| import org.apache.doris.qe.ConnectContext; |
| import org.apache.doris.statistics.ColumnStatistic; |
| import org.apache.doris.statistics.ColumnStatisticBuilder; |
| import org.apache.doris.statistics.StatsDeriveResult; |
| |
| import com.google.common.collect.Maps; |
| |
| import java.util.AbstractMap.SimpleEntry; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.stream.Collectors; |
| |
| /** |
| * Used to calculate the statistics for each plan node. |
| */ |
| public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void> { |
| |
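| // Fallback reduction ratio used by computeAggregate when no column statistics are available: |
| // the aggregate output row count is estimated as the child row count divided by this value. |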
| private static final int DEFAULT_AGGREGATE_RATIO = 1000; |
| |
| private final GroupExpression groupExpression; |
| |
| private StatsCalculator(GroupExpression groupExpression) { |
| this.groupExpression = groupExpression; |
| } |
| |
| /** |
| * Estimate statistics for the given group expression. Delegates to StatsCalculatorV2 |
| * when the session variable enableNereidsStatsDeriveV2 is enabled. |
| */ |
| public static void estimate(GroupExpression groupExpression) { |
| if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable().enableNereidsStatsDeriveV2) { |
| StatsCalculatorV2.estimate(groupExpression); |
| return; |
| } |
| StatsCalculator statsCalculator = new StatsCalculator(groupExpression); |
| statsCalculator.estimate(); |
| } |
| |
| private void estimate() { |
| StatsDeriveResult stats = groupExpression.getPlan().accept(this, null); |
| /* |
| In an ideal cost model, every group expression in a group would be equivalent, but in practice the costs differ. |
| We record the statistics with the lowest row count as the group statistics to avoid missing this group. |
| */ |
| if (groupExpression.getOwnerGroup().getStatistics() == null |
| || (stats.getRowCount() < groupExpression.getOwnerGroup().getStatistics().getRowCount())) { |
| groupExpression.getOwnerGroup().setStatistics(stats); |
| } |
| groupExpression.setStatDerived(true); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalEmptyRelation(LogicalEmptyRelation emptyRelation, Void context) { |
| return computeEmptyRelation(emptyRelation); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalLimit(LogicalLimit<? extends Plan> limit, Void context) { |
| return computeLimit(limit); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, Void context) { |
| return computeLimit(limit); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalOneRowRelation(LogicalOneRowRelation oneRowRelation, Void context) { |
| return computeOneRowRelation(oneRowRelation); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) { |
| return computeAggregate(aggregate); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Void context) { |
| return computeRepeat(repeat); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) { |
| return computeFilter(filter); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalOlapScan(LogicalOlapScan olapScan, Void context) { |
| olapScan.getExpressions(); |
| return computeScan(olapScan); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalTVFRelation(LogicalTVFRelation tvfRelation, Void context) { |
| return tvfRelation.getFunction().computeStats(tvfRelation.getOutput()); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalProject(LogicalProject<? extends Plan> project, Void context) { |
| return computeProject(project); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalSort(LogicalSort<? extends Plan> sort, Void context) { |
| return groupExpression.childStatistics(0); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalTopN(LogicalTopN<? extends Plan> topN, Void context) { |
| return computeTopN(topN); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) { |
| return JoinEstimation.estimate(groupExpression.childStatistics(0), |
| groupExpression.childStatistics(1), join); |
| } |
| |
| @Override |
| public StatsDeriveResult visitLogicalAssertNumRows( |
| LogicalAssertNumRows<? extends Plan> assertNumRows, Void context) { |
| return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows()); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalEmptyRelation(PhysicalEmptyRelation emptyRelation, Void context) { |
| return computeEmptyRelation(emptyRelation); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalAggregate(PhysicalAggregate<? extends Plan> agg, Void context) { |
| return computeAggregate(agg); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, Void context) { |
| return computeRepeat(repeat); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalOneRowRelation(PhysicalOneRowRelation oneRowRelation, Void context) { |
| return computeOneRowRelation(oneRowRelation); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalOlapScan(PhysicalOlapScan olapScan, Void context) { |
| return computeScan(olapScan); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalTVFRelation(PhysicalTVFRelation tvfRelation, Void context) { |
| return tvfRelation.getFunction().computeStats(tvfRelation.getOutput()); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalQuickSort(PhysicalQuickSort<? extends Plan> sort, Void context) { |
| return groupExpression.childStatistics(0); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, Void context) { |
| return computeTopN(topN); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalLocalQuickSort(PhysicalLocalQuickSort<? extends Plan> sort, Void context) { |
| return groupExpression.childStatistics(0); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalHashJoin( |
| PhysicalHashJoin<? extends Plan, ? extends Plan> hashJoin, Void context) { |
| return JoinEstimation.estimate(groupExpression.childStatistics(0), |
| groupExpression.childStatistics(1), hashJoin); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalNestedLoopJoin( |
| PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin, |
| Void context) { |
| return JoinEstimation.estimate(groupExpression.childStatistics(0), |
| groupExpression.childStatistics(1), nestedLoopJoin); |
| } |
| |
| // TODO: We should subtract the pruned columns and consider the expression transformations in the node. |
| @Override |
| public StatsDeriveResult visitPhysicalProject(PhysicalProject<? extends Plan> project, Void context) { |
| return computeProject(project); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, Void context) { |
| return computeFilter(filter); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalDistribute(PhysicalDistribute<? extends Plan> distribute, |
| Void context) { |
| return groupExpression.childStatistics(0); |
| } |
| |
| @Override |
| public StatsDeriveResult visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows, |
| Void context) { |
| return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows()); |
| } |
| |
| private StatsDeriveResult computeAssertNumRows(long desiredNumOfRows) { |
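| // NOTE: the estimate is currently pinned to a single output row; desiredNumOfRows is not used yet. |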
| StatsDeriveResult statsDeriveResult = groupExpression.childStatistics(0); |
| return statsDeriveResult.updateRowCountByLimit(1); |
| } |
| |
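| // The selectivity of the filter predicates is estimated by FilterSelectivityCalculator, and the child |
| // statistics are scaled by it; isReduced records whether rows may have been filtered out (selectivity < 1). |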
| private StatsDeriveResult computeFilter(Filter filter) { |
| StatsDeriveResult stats = groupExpression.childStatistics(0); |
| FilterSelectivityCalculator selectivityCalculator = |
| new FilterSelectivityCalculator(stats.getSlotIdToColumnStats()); |
| double selectivity = selectivityCalculator.estimate(filter.getPredicates()); |
| stats = stats.updateBySelectivity(selectivity, |
| filter.getPredicates().getInputSlots().stream().map(Slot::getExprId).collect(Collectors.toSet())); |
| stats.isReduced = selectivity < 1.0; |
| return stats; |
| } |
| |
| // TODO: 1. Subtract the pruned partitions |
| //       2. Consider the influence of runtime filters |
| //       3. Get NDV and column data size from StatisticManager once it supports them |
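| // Column statistics are looked up in the statistics cache; when only ColumnStatistic.DEFAULT is available, |
| // the count falls back to roughlyEstimatedCard(), which, for OLAP scans, sums the base-index row counts |
| // of all partitions. |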
| private StatsDeriveResult computeScan(Scan scan) { |
| Set<SlotReference> slotSet = scan.getOutput().stream().filter(SlotReference.class::isInstance) |
| .map(s -> (SlotReference) s).collect(Collectors.toSet()); |
| Map<Id, ColumnStatistic> columnStatisticMap = new HashMap<>(); |
| Table table = scan.getTable(); |
| double rowCount = Double.NaN; |
| long card = -1; |
| for (SlotReference slotReference : slotSet) { |
| String colName = slotReference.getName(); |
| if (colName == null) { |
| throw new RuntimeException("Column name of SlotReference shouldn't be null here"); |
| } |
| ColumnStatistic statistic = |
| Env.getCurrentEnv().getStatisticsCache().getColumnStatistics(table.getId(), colName); |
| if (statistic == ColumnStatistic.DEFAULT) { |
| if (card == -1) { |
| card = roughlyEstimatedCard(scan); |
| } |
| statistic = new ColumnStatisticBuilder(ColumnStatistic.DEFAULT).setCount(card).build(); |
| } |
| rowCount = statistic.count; |
| columnStatisticMap.put(slotReference.getExprId(), statistic); |
| } |
| return new StatsDeriveResult(rowCount, columnStatisticMap); |
| } |
| |
| private long roughlyEstimatedCard(Scan scan) { |
| long cardinality = 0; |
| if (scan instanceof PhysicalOlapScan || scan instanceof LogicalOlapScan) { |
| OlapTable table = (OlapTable) scan.getTable(); |
| for (long selectedPartitionId : table.getPartitionIds()) { |
| final Partition partition = table.getPartition(selectedPartitionId); |
| final MaterializedIndex baseIndex = partition.getBaseIndex(); |
| cardinality += baseIndex.getRowCount(); |
| } |
| } |
| return cardinality; |
| } |
| |
| private StatsDeriveResult computeTopN(TopN topN) { |
| StatsDeriveResult stats = groupExpression.childStatistics(0); |
| return stats.updateRowCountByLimit(topN.getLimit()); |
| } |
| |
| private StatsDeriveResult computeLimit(Limit limit) { |
| StatsDeriveResult stats = groupExpression.childStatistics(0); |
| return stats.updateRowCountByLimit(limit.getLimit()); |
| } |
| |
| private StatsDeriveResult computeAggregate(Aggregate aggregate) { |
| // TODO: since we have no column stats here, just use a fixed ratio to compute the row count. |
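| // For example, with DEFAULT_AGGREGATE_RATIO = 1000, a child producing 1,000,000 rows is estimated to |
| // yield 1,000 groups (clamped to at least 1 when the child row count is smaller than the ratio). |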
| // List<Expression> groupByExpressions = aggregate.getGroupByExpressions(); |
| StatsDeriveResult childStats = groupExpression.childStatistics(0); |
| // Map<Slot, ColumnStats> childSlotToColumnStats = childStats.getSlotToColumnStats(); |
| // long resultSetCount = groupByExpressions.stream() |
| // .flatMap(expr -> expr.getInputSlots().stream()) |
| // .filter(childSlotToColumnStats::containsKey) |
| // .map(childSlotToColumnStats::get) |
| // .map(ColumnStats::getNdv) |
| // .reduce(1L, (a, b) -> a * b); |
| long resultSetCount = (long) childStats.getRowCount() / DEFAULT_AGGREGATE_RATIO; |
| if (resultSetCount <= 0) { |
| resultSetCount = 1L; |
| } |
| |
| Map<Id, ColumnStatistic> slotToColumnStats = Maps.newHashMap(); |
| List<NamedExpression> outputExpressions = aggregate.getOutputExpressions(); |
| // TODO: 1. Estimate the output unit size by the type of the corresponding AggregateFunction |
| //       2. Handle aliases and literals in the output expression list |
| for (NamedExpression outputExpression : outputExpressions) { |
| slotToColumnStats.put(outputExpression.toSlot().getExprId(), ColumnStatistic.DEFAULT); |
| } |
| StatsDeriveResult statsDeriveResult = new StatsDeriveResult(resultSetCount, childStats.getWidth(), |
| childStats.getPenalty(), slotToColumnStats); |
| statsDeriveResult.isReduced = true; |
| // TODO: Update ColumnStats properly and add a new mapping from each output slot to its ColumnStats |
| return statsDeriveResult; |
| } |
| |
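| // Repeat replays its input once per grouping set, so the row count and the per-column count, numNulls and |
| // dataSize are scaled by the number of grouping sets; negative (unknown) values are left untouched. |
| // For example, GROUPING SETS ((a), (a, b), ()) contains 3 grouping sets, so 100 input rows become 300. |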
| private StatsDeriveResult computeRepeat(Repeat repeat) { |
| StatsDeriveResult childStats = groupExpression.childStatistics(0); |
| Map<Id, ColumnStatistic> slotIdToColumnStats = childStats.getSlotIdToColumnStats(); |
| int groupingSetNum = repeat.getGroupingSets().size(); |
| double rowCount = childStats.getRowCount(); |
| Map<Id, ColumnStatistic> columnStatisticMap = slotIdToColumnStats.entrySet() |
| .stream().map(kv -> { |
| ColumnStatistic stats = kv.getValue(); |
| return Pair.of(kv.getKey(), new ColumnStatistic( |
| stats.count < 0 ? stats.count : stats.count * groupingSetNum, |
| stats.ndv, |
| stats.avgSizeByte, |
| stats.numNulls < 0 ? stats.numNulls : stats.numNulls * groupingSetNum, |
| stats.dataSize < 0 ? stats.dataSize : stats.dataSize * groupingSetNum, |
| stats.minValue, |
| stats.maxValue, |
| stats.selectivity, |
| stats.minExpr, |
| stats.maxExpr |
| )); |
| }).collect(Collectors.toMap(Pair::key, Pair::value)); |
| return new StatsDeriveResult(rowCount < 0 ? rowCount : rowCount * groupingSetNum, columnStatisticMap); |
| } |
| |
| // TODO: do a real projection on the column stats |
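| // Each output slot is mapped to the statistics of the first input slot that has child statistics; |
| // projections with no input slots (e.g. literals) fall back to ColumnStatistic.DEFAULT. |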
| private StatsDeriveResult computeProject(Project project) { |
| List<NamedExpression> projections = project.getProjects(); |
| StatsDeriveResult childStats = groupExpression.childStatistics(0); |
| Map<Id, ColumnStatistic> childColumnStats = childStats.getSlotIdToColumnStats(); |
| Map<Id, ColumnStatistic> columnsStats = projections.stream().map(projection -> { |
| ColumnStatistic value = null; |
| Set<Slot> slots = projection.getInputSlots(); |
| if (slots.isEmpty()) { |
| value = ColumnStatistic.DEFAULT; |
| } else { |
| // TODO: just a trick here; we need to do a real projection on the column stats |
| for (Slot slot : slots) { |
| if (childColumnStats.containsKey(slot.getExprId())) { |
| value = childColumnStats.get(slot.getExprId()); |
| break; |
| } |
| } |
| if (value == null) { |
| value = ColumnStatistic.DEFAULT; |
| } |
| } |
| return new SimpleEntry<>(projection.toSlot().getExprId(), value); |
| }).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (item1, item2) -> item1)); |
| return new StatsDeriveResult(childStats.getRowCount(), childStats.getWidth(), |
| childStats.getPenalty(), columnsStats); |
| } |
| |
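| // A OneRowRelation always produces exactly one row, so every output column is given an NDV of 1. |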
| private StatsDeriveResult computeOneRowRelation(OneRowRelation oneRowRelation) { |
| Map<Id, ColumnStatistic> columnStatsMap = oneRowRelation.getProjects() |
| .stream() |
| .map(project -> { |
| ColumnStatistic statistic = new ColumnStatisticBuilder().setNdv(1).build(); |
| // TODO: compute the literal size |
| return Pair.of(project.toSlot().getExprId(), statistic); |
| }) |
| .collect(Collectors.toMap(Pair::key, Pair::value)); |
| int rowCount = 1; |
| return new StatsDeriveResult(rowCount, columnStatsMap); |
| } |
| |
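| // An EmptyRelation produces no rows, so every output column gets zeroed statistics (ndv, numNulls, avgSizeByte). |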
| private StatsDeriveResult computeEmptyRelation(EmptyRelation emptyRelation) { |
| Map<Id, ColumnStatistic> columnStatsMap = emptyRelation.getProjects() |
| .stream() |
| .map(project -> { |
| ColumnStatisticBuilder builder = new ColumnStatisticBuilder() |
| .setNdv(0) |
| .setNumNulls(0) |
| .setAvgSizeByte(0); |
| return Pair.of(project.toSlot().getExprId(), builder.build()); |
| }) |
| .collect(Collectors.toMap(Pair::key, Pair::value)); |
| int rowCount = 0; |
| return new StatsDeriveResult(rowCount, columnStatsMap); |
| } |
| } |