blob: 28f13e87f3ea72f9a59d3d686778e30cd043fc9f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.rules;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalTableScan;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.immutables.value.Value;
import java.util.ArrayList;
import java.util.List;
/**
 * Rule to extract a {@link org.apache.calcite.rel.core.Project}
 * from an {@link org.apache.calcite.rel.core.Aggregate}
 * and push it down towards the input.
 *
 * <p>What projections can be safely pushed down depends upon which fields the
 * Aggregate uses. A field counts as used if it is a group key, an argument of
 * an aggregate call, the filter argument of an aggregate call, one of the
 * distinct keys of an aggregate call, or a field that an aggregate call's
 * collation orders by.
 *
 * <p>To prevent cycles, this rule will not extract a {@code Project} if the
 * {@code Aggregate}'s input is already a {@code Project}.
 */
@Value.Enclosing
public class AggregateExtractProjectRule
    extends RelRule<AggregateExtractProjectRule.Config>
    implements TransformationRule {
  /** Instance built from {@link Config#DEFAULT}, which matches an
   * {@link Aggregate} on top of a {@link LogicalTableScan}. */
  public static final AggregateExtractProjectRule SCAN =
      Config.DEFAULT.toRule();

  /** Creates an AggregateExtractProjectRule. */
  protected AggregateExtractProjectRule(Config config) {
    super(config);
  }

  /**
   * Creates a rule matching the given aggregate class on top of the given
   * input class.
   *
   * @param aggregateClass class of aggregate to match
   * @param inputClass class of the aggregate's input to match
   * @param relBuilderFactory factory for the builder used by the rule
   */
  @Deprecated // to be removed before 2.0
  public AggregateExtractProjectRule(
      Class<? extends Aggregate> aggregateClass,
      Class<? extends RelNode> inputClass,
      RelBuilderFactory relBuilderFactory) {
    this(Config.DEFAULT
        .withRelBuilderFactory(relBuilderFactory)
        .as(Config.class)
        .withOperandFor(aggregateClass, inputClass));
  }

  /**
   * Creates a rule that matches exactly the given operand.
   *
   * @param operand root operand the rule should match
   * @param builderFactory factory for the builder used by the rule
   */
  @Deprecated // to be removed before 2.0
  public AggregateExtractProjectRule(RelOptRuleOperand operand,
      RelBuilderFactory builderFactory) {
    this(Config.DEFAULT
        .withRelBuilderFactory(builderFactory)
        .withOperandSupplier(b -> b.exactly(operand))
        .as(Config.class));
  }

  /**
   * Rewrites the matched aggregate as an aggregate over a new project that
   * keeps only the input fields the aggregate uses, remapping the group sets
   * and aggregate calls to the ordinals of the projected fields.
   *
   * @param call rule call; {@code rel(0)} is the aggregate and {@code rel(1)}
   *     is its input
   */
  @Override public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final RelNode input = call.rel(1);

    final ImmutableBitSet fieldsUsed = inputFieldsUsed(aggregate);

    final RelBuilder relBuilder = call.builder().push(input);
    final List<RexNode> projects = new ArrayList<>();
    // Maps each used input ordinal to its position in the new project.
    final Mapping mapping =
        Mappings.create(MappingType.INVERSE_SURJECTION,
            aggregate.getInput().getRowType().getFieldCount(),
            fieldsUsed.cardinality());
    int j = 0;
    for (int i : fieldsUsed) {
      projects.add(relBuilder.field(i));
      mapping.set(i, j++);
    }
    relBuilder.project(projects);

    final ImmutableBitSet newGroupSet =
        Mappings.apply(mapping, aggregate.getGroupSet());
    final List<ImmutableBitSet> newGroupSets =
        aggregate.getGroupSets().stream()
            .map(bitSet -> Mappings.apply(mapping, bitSet))
            .collect(Util.toImmutableList());
    final List<RelBuilder.AggCall> newAggCallList =
        aggregate.getAggCallList().stream()
            .map(aggCall -> relBuilder.aggregateCall(aggCall, mapping))
            .collect(Util.toImmutableList());
    final RelBuilder.GroupKey groupKey =
        relBuilder.groupKey(newGroupSet, newGroupSets);
    relBuilder.aggregate(groupKey, newAggCallList);
    call.transformTo(relBuilder.build());
  }

  /**
   * Computes the ordinals of the input fields that an aggregate references.
   *
   * @param aggregate aggregate to inspect
   * @return the set of input field ordinals used by the aggregate
   */
  private static ImmutableBitSet inputFieldsUsed(Aggregate aggregate) {
    // 1. group fields are always used
    final ImmutableBitSet.Builder used = aggregate.getGroupSet().rebuild();
    // 2. agg functions
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      for (int arg : aggCall.getArgList()) {
        used.set(arg);
      }
      if (aggCall.filterArg >= 0) {
        used.set(aggCall.filterArg);
      }
      // Distinct keys and collation fields are remapped along with the call,
      // so they must survive the projection as well.
      if (aggCall.distinctKeys != null) {
        used.addAll(aggCall.distinctKeys);
      }
      aggCall.getCollation().getFieldCollations()
          .forEach(fieldCollation -> used.set(fieldCollation.getFieldIndex()));
    }
    return used.build();
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    /** Default configuration: an {@link Aggregate} whose input is a
     * {@link LogicalTableScan}. */
    Config DEFAULT = ImmutableAggregateExtractProjectRule.Config.of()
        .withOperandFor(Aggregate.class, LogicalTableScan.class);

    @Override default AggregateExtractProjectRule toRule() {
      return new AggregateExtractProjectRule(this);
    }

    /** Defines an operand tree for the given classes. */
    default Config withOperandFor(Class<? extends Aggregate> aggregateClass,
        Class<? extends RelNode> inputClass) {
      return withOperandSupplier(b0 ->
          b0.operand(aggregateClass).oneInput(b1 ->
              b1.operand(inputClass)
                  // Predicate prevents matching against an Aggregate whose
                  // input is already a Project. Prevents this rule firing
                  // repeatedly.
                  .predicate(r -> !(r instanceof Project)).anyInputs()))
          .as(Config.class);
    }
  }
}