blob: 7a721a18133c6fbce15b3f20f8cd9ad8f5352aa7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.rel;
import com.google.common.base.Preconditions;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.DataSource;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
/**
 * Builder for a Druid query, not counting the "dataSource" (which will be slotted in later).
 *
 * Instances are immutable: every field is final and each {@code withXxx} method returns a new instance rather than
 * modifying the receiver. A query is assembled by starting from {@link #create(RelNode)} and layering rels on top
 * in the order given by {@link Stage}; {@link #canAccept(Stage)} tells whether a given stage may be added next.
 */
public class PartialDruidQuery
{
  // Produces a fresh RelBuilder on demand; only used when two select projections must be merged.
  private final Supplier<RelBuilder> builderSupplier;
  private final RelNode scan;
  private final Filter whereFilter;
  private final Project selectProject;
  private final Aggregate aggregate;
  private final Filter havingFilter;
  private final Project aggregateProject;
  private final Sort sort;
  private final Project sortProject;

  /**
   * Stages of a query, in the order in which they may be layered. The declaration order is significant:
   * {@link #canAccept(Stage)} compares stages with {@link Enum#compareTo}.
   */
  public enum Stage
  {
    // SCAN must be present on all queries.
    SCAN,

    // WHERE_FILTER, SELECT_PROJECT may be present on any query.
    WHERE_FILTER,
    SELECT_PROJECT,

    // AGGREGATE, HAVING_FILTER, AGGREGATE_PROJECT can only be present on aggregating queries.
    AGGREGATE,
    HAVING_FILTER,
    AGGREGATE_PROJECT,

    // SORT, SORT_PROJECT may be present on any query.
    SORT,
    SORT_PROJECT
  }

  /**
   * Note the parameter order: {@code aggregateProject} comes before {@code havingFilter}, which differs from the
   * order of the corresponding {@link Stage} constants. All call sites in this class follow the parameter order.
   */
  private PartialDruidQuery(
      final Supplier<RelBuilder> builderSupplier,
      final RelNode scan,
      final Filter whereFilter,
      final Project selectProject,
      final Aggregate aggregate,
      final Project aggregateProject,
      final Filter havingFilter,
      final Sort sort,
      final Project sortProject
  )
  {
    this.builderSupplier = Preconditions.checkNotNull(builderSupplier, "builderSupplier");
    this.scan = Preconditions.checkNotNull(scan, "scan");
    this.whereFilter = whereFilter;
    this.selectProject = selectProject;
    this.aggregate = aggregate;
    this.aggregateProject = aggregateProject;
    this.havingFilter = havingFilter;
    this.sort = sort;
    this.sortProject = sortProject;
  }

  /**
   * Creates a partial query consisting of only a scan (stage {@link Stage#SCAN}).
   *
   * @param scanRel the scan rel; must not be null
   *
   * @return a new partial query with every other stage unset
   */
  public static PartialDruidQuery create(final RelNode scanRel)
  {
    // The builder is created lazily, each time the supplier is invoked, from the scan's cluster and (when the scan
    // has a table) that table's schema.
    final Supplier<RelBuilder> builderSupplier = () -> RelFactories.LOGICAL_BUILDER.create(
        scanRel.getCluster(),
        scanRel.getTable() != null ? scanRel.getTable().getRelOptSchema() : null
    );
    return new PartialDruidQuery(builderSupplier, scanRel, null, null, null, null, null, null, null);
  }

  /**
   * @return the scan rel; never null
   */
  public RelNode getScan()
  {
    return scan;
  }

  /**
   * @return the WHERE filter, or null if this stage is not set
   */
  public Filter getWhereFilter()
  {
    return whereFilter;
  }

  /**
   * @return the projection applied before aggregation, or null if this stage is not set
   */
  public Project getSelectProject()
  {
    return selectProject;
  }

  /**
   * @return the aggregate, or null if this stage is not set
   */
  public Aggregate getAggregate()
  {
    return aggregate;
  }

  /**
   * @return the HAVING filter, or null if this stage is not set
   */
  public Filter getHavingFilter()
  {
    return havingFilter;
  }

  /**
   * @return the projection applied after aggregation, or null if this stage is not set
   */
  public Project getAggregateProject()
  {
    return aggregateProject;
  }

  /**
   * @return the sort, or null if this stage is not set
   */
  public Sort getSort()
  {
    return sort;
  }

  /**
   * @return the projection applied after the sort, or null if this stage is not set
   */
  public Project getSortProject()
  {
    return sortProject;
  }

  /**
   * Returns a copy of this query with the WHERE filter replaced.
   *
   * @throws ISE if {@link #canAccept(Stage)} rejects {@link Stage#WHERE_FILTER}
   */
  public PartialDruidQuery withWhereFilter(final Filter newWhereFilter)
  {
    validateStage(Stage.WHERE_FILTER);
    return new PartialDruidQuery(
        builderSupplier,
        scan,
        newWhereFilter,
        selectProject,
        aggregate,
        aggregateProject,
        havingFilter,
        sort,
        sortProject
    );
  }

  /**
   * Returns a copy of this query with the given select projection applied. If a select projection is already
   * present, the two are merged into one; if the merged projection turns out to be an identity over the existing
   * projection's input, the resulting query has no select projection at all.
   *
   * @throws ISE if {@link #canAccept(Stage)} rejects {@link Stage#SELECT_PROJECT}
   */
  public PartialDruidQuery withSelectProject(final Project newSelectProject)
  {
    validateStage(Stage.SELECT_PROJECT);

    // Possibly merge together two projections.
    final Project theProject;
    if (selectProject == null) {
      theProject = newSelectProject;
    } else {
      // Rewrite the new projection's expressions in terms of the existing projection's input.
      final List<RexNode> newProjectRexNodes = RelOptUtil.pushPastProject(
          newSelectProject.getProjects(),
          selectProject
      );

      if (RexUtil.isIdentity(newProjectRexNodes, selectProject.getInput().getRowType())) {
        // The projection is gone.
        theProject = null;
      } else {
        // Build a single projection directly on top of the existing projection's input, keeping the field names
        // of the new (outer) projection.
        final RelBuilder relBuilder = builderSupplier.get();
        relBuilder.push(selectProject.getInput());
        relBuilder.project(
            newProjectRexNodes,
            newSelectProject.getRowType().getFieldNames()
        );
        theProject = (Project) relBuilder.build();
      }
    }

    return new PartialDruidQuery(
        builderSupplier,
        scan,
        whereFilter,
        theProject,
        aggregate,
        aggregateProject,
        havingFilter,
        sort,
        sortProject
    );
  }

  /**
   * Returns a copy of this query with the aggregate replaced.
   *
   * @throws ISE if {@link #canAccept(Stage)} rejects {@link Stage#AGGREGATE}
   */
  public PartialDruidQuery withAggregate(final Aggregate newAggregate)
  {
    validateStage(Stage.AGGREGATE);
    return new PartialDruidQuery(
        builderSupplier,
        scan,
        whereFilter,
        selectProject,
        newAggregate,
        aggregateProject,
        havingFilter,
        sort,
        sortProject
    );
  }

  /**
   * Returns a copy of this query with the HAVING filter replaced.
   *
   * @throws ISE if {@link #canAccept(Stage)} rejects {@link Stage#HAVING_FILTER}
   */
  public PartialDruidQuery withHavingFilter(final Filter newHavingFilter)
  {
    validateStage(Stage.HAVING_FILTER);
    return new PartialDruidQuery(
        builderSupplier,
        scan,
        whereFilter,
        selectProject,
        aggregate,
        aggregateProject,
        newHavingFilter,
        sort,
        sortProject
    );
  }

  /**
   * Returns a copy of this query with the post-aggregation projection replaced.
   *
   * @throws ISE if {@link #canAccept(Stage)} rejects {@link Stage#AGGREGATE_PROJECT}
   */
  public PartialDruidQuery withAggregateProject(final Project newAggregateProject)
  {
    validateStage(Stage.AGGREGATE_PROJECT);
    return new PartialDruidQuery(
        builderSupplier,
        scan,
        whereFilter,
        selectProject,
        aggregate,
        newAggregateProject,
        havingFilter,
        sort,
        sortProject
    );
  }

  /**
   * Returns a copy of this query with the sort replaced.
   *
   * @throws ISE if {@link #canAccept(Stage)} rejects {@link Stage#SORT}
   */
  public PartialDruidQuery withSort(final Sort newSort)
  {
    validateStage(Stage.SORT);
    return new PartialDruidQuery(
        builderSupplier,
        scan,
        whereFilter,
        selectProject,
        aggregate,
        aggregateProject,
        havingFilter,
        newSort,
        sortProject
    );
  }

  /**
   * Returns a copy of this query with the post-sort projection replaced.
   *
   * @throws ISE if {@link #canAccept(Stage)} rejects {@link Stage#SORT_PROJECT}
   */
  public PartialDruidQuery withSortProject(final Project newSortProject)
  {
    validateStage(Stage.SORT_PROJECT);
    return new PartialDruidQuery(
        builderSupplier,
        scan,
        whereFilter,
        selectProject,
        aggregate,
        aggregateProject,
        havingFilter,
        sort,
        newSortProject
    );
  }

  /**
   * @return the row type of {@link #leafRel()}
   */
  public RelDataType getRowType()
  {
    return leafRel().getRowType();
  }

  /**
   * @return the traits of {@link #leafRel()}, as an array
   */
  public RelTrait[] getRelTraits()
  {
    return leafRel().getTraitSet().toArray(new RelTrait[0]);
  }

  /**
   * Converts this partial query into a {@link DruidQuery} by delegating to {@code DruidQuery.fromPartialQuery},
   * passing this object along with all arguments unchanged.
   */
  public DruidQuery build(
      final DataSource dataSource,
      final RowSignature sourceRowSignature,
      final PlannerContext plannerContext,
      final RexBuilder rexBuilder,
      final boolean finalizeAggregations
  )
  {
    return DruidQuery.fromPartialQuery(
        this,
        dataSource,
        sourceRowSignature,
        plannerContext,
        rexBuilder,
        finalizeAggregations
    );
  }

  /**
   * Returns whether a rel for the given stage may be layered on top of this query in its current state.
   *
   * A stage is accepted when it comes strictly after the current {@link #stage()}, with two extra restrictions:
   * stages between AGGREGATE and SORT (exclusive) require an aggregate to be present, and stages after SORT require
   * a sort to be present. As a special case, SELECT_PROJECT is accepted on top of SELECT_PROJECT.
   *
   * @param stage the stage to be added
   *
   * @return true if the stage can be added
   */
  public boolean canAccept(final Stage stage)
  {
    final Stage currentStage = stage();

    if (currentStage == Stage.SELECT_PROJECT && stage == Stage.SELECT_PROJECT) {
      // Special case: allow layering SELECT_PROJECT on top of SELECT_PROJECT. Calcite's builtin rules cannot
      // always collapse these, so we have to (one example: testSemiJoinWithOuterTimeExtract). See
      // withSelectProject for the code here that handles this.
      return true;
    } else if (stage.compareTo(currentStage) <= 0) {
      // Cannot go backwards.
      return false;
    } else if (stage.compareTo(Stage.AGGREGATE) > 0 && stage.compareTo(Stage.SORT) < 0 && aggregate == null) {
      // Cannot do post-aggregation stages without an aggregation.
      return false;
    } else if (stage.compareTo(Stage.SORT) > 0 && sort == null) {
      // Cannot do post-sort stages without a sort.
      return false;
    } else {
      // Looks good.
      return true;
    }
  }

  /**
   * Returns the stage corresponding to the rel at the end of the query. It will match the rel returned from
   * {@link #leafRel()}.
   *
   * @return stage
   */
  @SuppressWarnings("VariableNotUsedInsideIf")
  public Stage stage()
  {
    // Checked from the last stage to the first, so the latest stage that is set wins.
    if (sortProject != null) {
      return Stage.SORT_PROJECT;
    } else if (sort != null) {
      return Stage.SORT;
    } else if (aggregateProject != null) {
      return Stage.AGGREGATE_PROJECT;
    } else if (havingFilter != null) {
      return Stage.HAVING_FILTER;
    } else if (aggregate != null) {
      return Stage.AGGREGATE;
    } else if (selectProject != null) {
      return Stage.SELECT_PROJECT;
    } else if (whereFilter != null) {
      return Stage.WHERE_FILTER;
    } else {
      return Stage.SCAN;
    }
  }

  /**
   * Returns the rel at the end of the query. It will match the stage returned from {@link #stage()}.
   *
   * @return leaf rel
   */
  public RelNode leafRel()
  {
    final Stage currentStage = stage();

    switch (currentStage) {
      case SORT_PROJECT:
        return sortProject;
      case SORT:
        return sort;
      case AGGREGATE_PROJECT:
        return aggregateProject;
      case HAVING_FILTER:
        return havingFilter;
      case AGGREGATE:
        return aggregate;
      case SELECT_PROJECT:
        return selectProject;
      case WHERE_FILTER:
        return whereFilter;
      case SCAN:
        return scan;
      default:
        throw new ISE("Unknown stage: %s", currentStage);
    }
  }

  /**
   * Estimates the per-row cost of running this query.
   *
   * The estimate starts from {@code CostEstimates.COST_BASE} and is adjusted, in order, for the select projection,
   * the WHERE filter, the aggregate, the sort, and finally the post-aggregation and post-sort projections.
   *
   * @return the estimated cost
   */
  public double estimateCost()
  {
    double cost = CostEstimates.COST_BASE;

    if (getSelectProject() != null) {
      for (final RexNode rexNode : getSelectProject().getProjects()) {
        if (rexNode.isA(SqlKind.INPUT_REF)) {
          cost += CostEstimates.COST_COLUMN_READ;
        } else {
          cost += CostEstimates.COST_EXPRESSION;
        }
      }
    }

    if (getWhereFilter() != null) {
      // We assume filters are free and have a selectivity of CostEstimates.MULTIPLIER_FILTER. They aren't actually
      // free, but we want to encourage filters, so let's go with it.
      cost *= CostEstimates.MULTIPLIER_FILTER;
    }

    if (getAggregate() != null) {
      // Use cardinality(), the number of grouping columns. ImmutableBitSet.size() would instead report the bit
      // capacity of the set (a multiple of 64), regardless of how many columns are actually grouped on.
      final int numGroupingColumns = getAggregate().getGroupSet().cardinality();

      if (getSelectProject() == null) {
        // No projection before aggregation, that means the aggregate operator is reading things directly.
        // Account for the costs.
        cost += CostEstimates.COST_COLUMN_READ * numGroupingColumns;
      }

      cost += CostEstimates.COST_DIMENSION * numGroupingColumns;
      cost += CostEstimates.COST_AGGREGATION * getAggregate().getAggCallList().size();
    }

    if (getSort() != null) {
      // A Sort rel may carry an ordering, a limit, or both; each is costed separately.
      if (!getSort().collation.getFieldCollations().isEmpty()) {
        cost *= CostEstimates.MULTIPLIER_ORDER_BY;
      }

      if (getSort().fetch != null) {
        cost *= CostEstimates.MULTIPLIER_LIMIT;
      }
    }

    if (getAggregateProject() != null) {
      cost += CostEstimates.COST_EXPRESSION * getAggregateProject().getProjects().size();
    }

    if (getSortProject() != null) {
      cost += CostEstimates.COST_EXPRESSION * getSortProject().getProjects().size();
    }

    return cost;
  }

  /**
   * Throws if the given stage cannot be layered on top of the current one.
   *
   * @throws ISE if {@link #canAccept(Stage)} returns false for the stage
   */
  private void validateStage(final Stage stage)
  {
    if (!canAccept(stage)) {
      throw new ISE("Cannot move from stage[%s] to stage[%s]", stage(), stage);
    }
  }

  /**
   * Two partial queries are equal when all of their rels are equal. The builder supplier is not compared.
   */
  @Override
  public boolean equals(final Object o)
  {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    final PartialDruidQuery that = (PartialDruidQuery) o;
    return Objects.equals(scan, that.scan) &&
           Objects.equals(whereFilter, that.whereFilter) &&
           Objects.equals(selectProject, that.selectProject) &&
           Objects.equals(aggregate, that.aggregate) &&
           Objects.equals(havingFilter, that.havingFilter) &&
           Objects.equals(aggregateProject, that.aggregateProject) &&
           Objects.equals(sort, that.sort) &&
           Objects.equals(sortProject, that.sortProject);
  }

  /**
   * Hashes the same fields that {@link #equals(Object)} compares.
   */
  @Override
  public int hashCode()
  {
    return Objects.hash(
        scan,
        whereFilter,
        selectProject,
        aggregate,
        havingFilter,
        aggregateProject,
        sort,
        sortProject
    );
  }

  @Override
  public String toString()
  {
    return "PartialDruidQuery{" +
           "scan=" + scan +
           ", whereFilter=" + whereFilter +
           ", selectProject=" + selectProject +
           ", aggregate=" + aggregate +
           ", havingFilter=" + havingFilter +
           ", aggregateProject=" + aggregateProject +
           ", sort=" + sort +
           ", sortProject=" + sortProject +
           '}';
  }
}