blob: 6b449ed8187bf36c244a0e2f69688400700774b5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.rule;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.sql.calcite.rel.DruidOuterQueryRel;
import org.apache.druid.sql.calcite.rel.DruidRel;
import org.apache.druid.sql.calcite.rel.PartialDruidQuery;
import java.util.List;
import java.util.function.BiFunction;
public class DruidRules
{
/**
 * Predicate that accepts a {@link DruidRel} only when it exposes a non-null {@link PartialDruidQuery}.
 * The rule operands declared in this class pass it to {@code operand(...)} to select the Druid rel they build on.
 */
public static final Predicate<DruidRel> CAN_BUILD_ON = rel -> rel.getPartialDruidQuery() != null;
/**
 * Private constructor: every member of this class is static, so it is never meant to be instantiated.
 */
private DruidRules()
{
// No instantiation.
}
/**
 * Assembles the planner rules contributed by this class.
 *
 * <p>The returned list contains, in this order: one {@link DruidQueryRule} for each of the stages WHERE_FILTER,
 * SELECT_PROJECT, AGGREGATE, AGGREGATE_PROJECT, HAVING_FILTER, SORT and SORT_PROJECT; the
 * {@link DruidOuterQueryRule} constants; and finally the union, union-datasource, sort-union and join rule
 * instances. New {@link DruidQueryRule} objects are created on every call.
 *
 * @return an immutable list of rules in the order described above
 */
public static List<RelOptRule> rules()
{
  final ImmutableList.Builder<RelOptRule> ruleList = ImmutableList.builder();

  // Rules that fold a single rel node into the PartialDruidQuery of the DruidRel underneath it.
  ruleList.add(
      new DruidQueryRule<>(Filter.class, PartialDruidQuery.Stage.WHERE_FILTER, PartialDruidQuery::withWhereFilter)
  );
  ruleList.add(
      new DruidQueryRule<>(
          Project.class,
          PartialDruidQuery.Stage.SELECT_PROJECT,
          PartialDruidQuery::withSelectProject
      )
  );
  ruleList.add(
      new DruidQueryRule<>(Aggregate.class, PartialDruidQuery.Stage.AGGREGATE, PartialDruidQuery::withAggregate)
  );
  ruleList.add(
      new DruidQueryRule<>(
          Project.class,
          PartialDruidQuery.Stage.AGGREGATE_PROJECT,
          PartialDruidQuery::withAggregateProject
      )
  );
  ruleList.add(
      new DruidQueryRule<>(Filter.class, PartialDruidQuery.Stage.HAVING_FILTER, PartialDruidQuery::withHavingFilter)
  );
  ruleList.add(
      new DruidQueryRule<>(Sort.class, PartialDruidQuery.Stage.SORT, PartialDruidQuery::withSort)
  );
  ruleList.add(
      new DruidQueryRule<>(Project.class, PartialDruidQuery.Stage.SORT_PROJECT, PartialDruidQuery::withSortProject)
  );

  // Rules that wrap an existing DruidRel in an outer query.
  ruleList.add(DruidOuterQueryRule.AGGREGATE);
  ruleList.add(DruidOuterQueryRule.FILTER_AGGREGATE);
  ruleList.add(DruidOuterQueryRule.FILTER_PROJECT_AGGREGATE);
  ruleList.add(DruidOuterQueryRule.PROJECT_AGGREGATE);
  ruleList.add(DruidOuterQueryRule.AGGREGATE_SORT_PROJECT);

  // Union and join rules.
  ruleList.add(DruidUnionRule.instance());
  ruleList.add(DruidUnionDataSourceRule.instance());
  ruleList.add(DruidSortUnionRule.instance());
  ruleList.add(DruidJoinRule.instance());

  return ruleList.build();
}
/**
 * Rule that absorbs a single {@code RelType} node into the {@link PartialDruidQuery} of the {@link DruidRel}
 * directly beneath it.
 *
 * <p>The operand tree is "{@code relClass} over a {@link DruidRel} accepted by {@code CAN_BUILD_ON}". The rule
 * matches only when that rel's partial query reports {@code canAccept(stage)}, and it only emits a transformation
 * when the rebuilt rel reports {@code isValidDruidQuery()}.
 *
 * @param <RelType> type of the rel node that this rule folds into the partial query
 */
public static class DruidQueryRule<RelType extends RelNode> extends RelOptRule
{
  private final PartialDruidQuery.Stage stage;
  private final BiFunction<PartialDruidQuery, RelType, PartialDruidQuery> stageApplier;

  /**
   * @param relClass class of the rel node matched on top of the DruidRel
   * @param stage    stage that the DruidRel's partial query must be able to accept; also part of the rule description
   * @param f        function producing the new partial query from the current one and the matched rel node
   */
  public DruidQueryRule(
      final Class<RelType> relClass,
      final PartialDruidQuery.Stage stage,
      final BiFunction<PartialDruidQuery, RelType, PartialDruidQuery> f
  )
  {
    super(
        operand(relClass, operand(DruidRel.class, null, CAN_BUILD_ON, any())),
        StringUtils.format("%s(%s)", DruidQueryRule.class.getSimpleName(), stage)
    );
    this.stage = stage;
    this.stageApplier = f;
  }

  /**
   * Accepts the call only if the partial query of the DruidRel (operand 1) can accept this rule's stage.
   */
  @Override
  public boolean matches(final RelOptRuleCall call)
  {
    final DruidRel input = call.rel(1);
    final PartialDruidQuery currentQuery = input.getPartialDruidQuery();
    return currentQuery.canAccept(stage);
  }

  /**
   * Applies the stage function to the DruidRel's partial query and registers the rebuilt rel, provided it is a
   * valid Druid query. Otherwise nothing is registered.
   */
  @Override
  public void onMatch(final RelOptRuleCall call)
  {
    final RelType top = call.rel(0);
    final DruidRel input = call.rel(1);

    final PartialDruidQuery extendedQuery = stageApplier.apply(input.getPartialDruidQuery(), top);
    final DruidRel rebuilt = input.withPartialQuery(extendedQuery);

    if (!rebuilt.isValidDruidQuery()) {
      return;
    }
    call.transformTo(rebuilt);
  }
}
public abstract static class DruidOuterQueryRule extends RelOptRule
{
public static final RelOptRule AGGREGATE = new DruidOuterQueryRule(
operand(Aggregate.class, operand(DruidRel.class, null, CAN_BUILD_ON, any())),
"AGGREGATE"
)
{
@Override
public void onMatch(final RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final DruidRel druidRel = call.rel(1);
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
druidRel,
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
.withAggregate(aggregate)
);
if (outerQueryRel.isValidDruidQuery()) {
call.transformTo(outerQueryRel);
}
}
};
public static final RelOptRule FILTER_AGGREGATE = new DruidOuterQueryRule(
operand(Aggregate.class, operand(Filter.class, operand(DruidRel.class, null, CAN_BUILD_ON, any()))),
"FILTER_AGGREGATE"
)
{
@Override
public void onMatch(final RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Filter filter = call.rel(1);
final DruidRel druidRel = call.rel(2);
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
druidRel,
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
.withWhereFilter(filter)
.withAggregate(aggregate)
);
if (outerQueryRel.isValidDruidQuery()) {
call.transformTo(outerQueryRel);
}
}
};
public static final RelOptRule FILTER_PROJECT_AGGREGATE = new DruidOuterQueryRule(
operand(
Aggregate.class,
operand(Project.class, operand(Filter.class, operand(DruidRel.class, null, CAN_BUILD_ON, any())))
),
"FILTER_PROJECT_AGGREGATE"
)
{
@Override
public void onMatch(final RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);
final Filter filter = call.rel(2);
final DruidRel druidRel = call.rel(3);
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
druidRel,
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
.withWhereFilter(filter)
.withSelectProject(project)
.withAggregate(aggregate)
);
if (outerQueryRel.isValidDruidQuery()) {
call.transformTo(outerQueryRel);
}
}
};
public static final RelOptRule PROJECT_AGGREGATE = new DruidOuterQueryRule(
operand(Aggregate.class, operand(Project.class, operand(DruidRel.class, null, CAN_BUILD_ON, any()))),
"PROJECT_AGGREGATE"
)
{
@Override
public void onMatch(final RelOptRuleCall call)
{
final Aggregate aggregate = call.rel(0);
final Project project = call.rel(1);
final DruidRel druidRel = call.rel(2);
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
druidRel,
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
.withSelectProject(project)
.withAggregate(aggregate)
);
if (outerQueryRel.isValidDruidQuery()) {
call.transformTo(outerQueryRel);
}
}
};
public static final RelOptRule AGGREGATE_SORT_PROJECT = new DruidOuterQueryRule(
operand(
Project.class,
operand(Sort.class, operand(Aggregate.class, operand(DruidRel.class, null, CAN_BUILD_ON, any())))
),
"AGGREGATE_SORT_PROJECT"
)
{
@Override
public void onMatch(RelOptRuleCall call)
{
final Project sortProject = call.rel(0);
final Sort sort = call.rel(1);
final Aggregate aggregate = call.rel(2);
final DruidRel druidRel = call.rel(3);
final DruidOuterQueryRel outerQueryRel = DruidOuterQueryRel.create(
druidRel,
PartialDruidQuery.create(druidRel.getPartialDruidQuery().leafRel())
.withAggregate(aggregate)
.withSort(sort)
.withSortProject(sortProject)
);
if (outerQueryRel.isValidDruidQuery()) {
call.transformTo(outerQueryRel);
}
}
};
public DruidOuterQueryRule(final RelOptRuleOperand op, final String description)
{
super(op, StringUtils.format("%s(%s)", DruidOuterQueryRel.class.getSimpleName(), description));
}
@Override
public boolean matches(final RelOptRuleCall call)
{
// Subquery must be a groupBy, so stage must be >= AGGREGATE.
final DruidRel druidRel = call.rel(call.getRelList().size() - 1);
return druidRel.getPartialDruidQuery().stage().compareTo(PartialDruidQuery.Stage.AGGREGATE) >= 0;
}
}
}