blob: dbfd8c92005a0509255a099c8b716599a61e4824 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.core;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPatternFieldRef;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlBitOpAggFunction;
import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.util.ImmutableBitSet;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.ImmutableSortedSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import static com.google.common.base.Preconditions.checkArgument;
/**
* Relational expression that represent a MATCH_RECOGNIZE node.
*
* <p>Each output row has the columns defined in the measure statements.
*/
public abstract class Match extends SingleRel {
//~ Instance fields ---------------------------------------------
private static final String STAR = "*";
protected final ImmutableMap<String, RexNode> measures;
protected final RexNode pattern;
protected final boolean strictStart;
protected final boolean strictEnd;
protected final boolean allRows;
protected final RexNode after;
protected final ImmutableMap<String, RexNode> patternDefinitions;
protected final Set<RexMRAggCall> aggregateCalls;
protected final Map<String, SortedSet<RexMRAggCall>> aggregateCallsPreVar;
protected final ImmutableMap<String, SortedSet<String>> subsets;
protected final ImmutableBitSet partitionKeys;
protected final RelCollation orderKeys;
protected final @Nullable RexNode interval;
//~ Constructors -----------------------------------------------
/**
* Creates a Match.
*
* @param cluster Cluster
* @param traitSet Trait set
* @param input Input relational expression
* @param rowType Row type
* @param pattern Regular expression that defines pattern variables
* @param strictStart Whether it is a strict start pattern
* @param strictEnd Whether it is a strict end pattern
* @param patternDefinitions Pattern definitions
* @param measures Measure definitions
* @param after After match definitions
* @param subsets Subsets of pattern variables
* @param allRows Whether all rows per match (false means one row per match)
* @param partitionKeys Partition by columns
* @param orderKeys Order by columns
* @param interval Interval definition, null if WITHIN clause is not defined
*/
protected Match(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
RelDataType rowType, RexNode pattern,
boolean strictStart, boolean strictEnd,
Map<String, RexNode> patternDefinitions, Map<String, RexNode> measures,
RexNode after, Map<String, ? extends SortedSet<String>> subsets,
boolean allRows, ImmutableBitSet partitionKeys, RelCollation orderKeys,
@Nullable RexNode interval) {
super(cluster, traitSet, input);
this.rowType = Objects.requireNonNull(rowType, "rowType");
this.pattern = Objects.requireNonNull(pattern, "pattern");
checkArgument(!patternDefinitions.isEmpty());
this.strictStart = strictStart;
this.strictEnd = strictEnd;
this.patternDefinitions = ImmutableMap.copyOf(patternDefinitions);
this.measures = ImmutableMap.copyOf(measures);
this.after = Objects.requireNonNull(after, "after");
this.subsets = copyMap(subsets);
this.allRows = allRows;
this.partitionKeys = Objects.requireNonNull(partitionKeys, "partitionKeys");
this.orderKeys = Objects.requireNonNull(orderKeys, "orderKeys");
this.interval = interval;
final AggregateFinder aggregateFinder = new AggregateFinder();
for (RexNode rex : this.patternDefinitions.values()) {
if (rex instanceof RexCall) {
aggregateFinder.go((RexCall) rex);
}
}
for (RexNode rex : this.measures.values()) {
if (rex instanceof RexCall) {
aggregateFinder.go((RexCall) rex);
}
}
aggregateCalls = ImmutableSortedSet.copyOf(aggregateFinder.aggregateCalls);
aggregateCallsPreVar =
copyMap(aggregateFinder.aggregateCallsPerVar);
}
/** Creates an immutable map of a map of sorted sets. */
private static <K extends Comparable<K>, V>
ImmutableSortedMap<K, SortedSet<V>>
copyMap(Map<K, ? extends SortedSet<V>> map) {
final ImmutableSortedMap.Builder<K, SortedSet<V>> b =
ImmutableSortedMap.naturalOrder();
for (Map.Entry<K, ? extends SortedSet<V>> e : map.entrySet()) {
b.put(e.getKey(), ImmutableSortedSet.copyOf(e.getValue()));
}
return b.build();
}
//~ Methods --------------------------------------------------
public ImmutableMap<String, RexNode> getMeasures() {
return measures;
}
public RexNode getAfter() {
return after;
}
public RexNode getPattern() {
return pattern;
}
public boolean isStrictStart() {
return strictStart;
}
public boolean isStrictEnd() {
return strictEnd;
}
public boolean isAllRows() {
return allRows;
}
public ImmutableMap<String, RexNode> getPatternDefinitions() {
return patternDefinitions;
}
public ImmutableMap<String, SortedSet<String>> getSubsets() {
return subsets;
}
public ImmutableBitSet getPartitionKeys() {
return partitionKeys;
}
public RelCollation getOrderKeys() {
return orderKeys;
}
public @Nullable RexNode getInterval() {
return interval;
}
@Override public RelWriter explainTerms(RelWriter pw) {
return super.explainTerms(pw)
.item("partition", getPartitionKeys().asList())
.item("order", getOrderKeys())
.item("outputFields", getRowType().getFieldNames())
.item("allRows", isAllRows())
.item("after", getAfter())
.item("pattern", getPattern())
.item("isStrictStarts", isStrictStart())
.item("isStrictEnds", isStrictEnd())
.itemIf("interval", getInterval(), getInterval() != null)
.item("subsets", getSubsets().values().asList())
.item("patternDefinitions", getPatternDefinitions().values().asList())
.item("inputFields", getInput().getRowType().getFieldNames());
}
/**
* Find aggregate functions in operands.
*/
private static class AggregateFinder extends RexVisitorImpl<Void> {
final NavigableSet<RexMRAggCall> aggregateCalls = new TreeSet<>();
final Map<String, NavigableSet<RexMRAggCall>> aggregateCallsPerVar =
new TreeMap<>();
AggregateFinder() {
super(true);
}
@Override public Void visitCall(RexCall call) {
SqlAggFunction aggFunction = null;
switch (call.getKind()) {
case SUM:
aggFunction = new SqlSumAggFunction(call.getType());
break;
case SUM0:
aggFunction = new SqlSumEmptyIsZeroAggFunction();
break;
case MAX:
case MIN:
aggFunction = new SqlMinMaxAggFunction(call.getKind());
break;
case COUNT:
aggFunction = SqlStdOperatorTable.COUNT;
break;
case ANY_VALUE:
aggFunction = SqlStdOperatorTable.ANY_VALUE;
break;
case BIT_AND:
case BIT_OR:
case BIT_XOR:
aggFunction = new SqlBitOpAggFunction(call.getKind());
break;
default:
visitEach(call.operands);
}
if (aggFunction != null) {
RexMRAggCall aggCall =
new RexMRAggCall(aggFunction, call.getType(), call.getOperands(),
aggregateCalls.size());
aggregateCalls.add(aggCall);
Set<String> pv = new PatternVarFinder().go(call.getOperands());
if (pv.size() == 0) {
pv.add(STAR);
}
for (String alpha : pv) {
final NavigableSet<RexMRAggCall> set;
if (aggregateCallsPerVar.containsKey(alpha)) {
set = aggregateCallsPerVar.get(alpha);
} else {
set = new TreeSet<>();
aggregateCallsPerVar.put(alpha, set);
}
boolean update = true;
for (RexMRAggCall rex : set) {
if (rex.equals(aggCall)) {
update = false;
break;
}
}
if (update) {
set.add(aggCall);
}
}
}
return null;
}
public void go(RexCall call) {
call.accept(this);
}
}
/**
* Visits the operands of an aggregate call to retrieve relevant pattern
* variables.
*/
private static class PatternVarFinder extends RexVisitorImpl<Void> {
final Set<String> patternVars = new HashSet<>();
PatternVarFinder() {
super(true);
}
@Override public Void visitPatternFieldRef(RexPatternFieldRef fieldRef) {
patternVars.add(fieldRef.getAlpha());
return null;
}
@Override public Void visitCall(RexCall call) {
visitEach(call.operands);
return null;
}
public Set<String> go(RexNode rex) {
rex.accept(this);
return patternVars;
}
public Set<String> go(List<RexNode> rexNodeList) {
visitEach(rexNodeList);
return patternVars;
}
}
/**
* Aggregate calls in match recognize.
*/
public static final class RexMRAggCall extends RexCall
implements Comparable<RexMRAggCall> {
public final int ordinal;
RexMRAggCall(SqlAggFunction aggFun,
RelDataType type,
List<RexNode> operands,
int ordinal) {
super(type, aggFun, operands);
this.ordinal = ordinal;
digest = toString(); // can compute here because class is final
}
@Override public int compareTo(RexMRAggCall o) {
return toString().compareTo(o.toString());
}
@Override public boolean equals(@Nullable Object obj) {
return obj == this
|| obj instanceof RexMRAggCall
&& toString().equals(obj.toString());
}
@Override public int hashCode() {
return toString().hashCode();
}
}
}