// blob: e1c6a6f74740ee8e4224b6970dac57bb5b75784c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.handler.sql;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import static org.apache.solr.client.solrj.io.stream.metrics.CountDistinctMetric.APPROX_COUNT_DISTINCT;
import static org.apache.solr.client.solrj.io.stream.metrics.CountDistinctMetric.COUNT_DISTINCT;
/**
 * Implementation of {@link org.apache.calcite.rel.core.Aggregate} relational expression in Solr.
 *
 * <p>When implemented, every aggregate call is translated into a (metric, column) pair and
 * registered on the {@code Implementor}, and every grouping column is registered as a bucket.
 */
class SolrAggregate extends Aggregate implements SolrRel {

  /** SQL aggregate functions that may be applied to a single column. */
  private static final List<SqlAggFunction> SUPPORTED_AGGREGATIONS = Arrays.asList(
      SqlStdOperatorTable.COUNT,
      SqlStdOperatorTable.SUM,
      SqlStdOperatorTable.SUM0,
      SqlStdOperatorTable.MIN,
      SqlStdOperatorTable.MAX,
      SqlStdOperatorTable.AVG
  );

  /**
   * Returns the Solr agg metric identifier (includes column) for the SQL metric.
   *
   * @param metric metric name; must not be null
   * @param column column the metric is computed over
   * @return the identifier formatted as {@code name(column)}
   */
  static String solrAggMetricId(String metric, String column) {
    // CountDistinctMetric's getIdentifier returns "countDist" but all others return a lowercased value
    String funcName = COUNT_DISTINCT.equals(metric) ? COUNT_DISTINCT : metric.toLowerCase(Locale.ROOT);
    return String.format(Locale.ROOT, "%s(%s)", funcName, column);
  }

  /**
   * Creates the aggregate node. With assertions enabled, verifies that this node uses
   * {@link SolrRel#CONVENTION} and that {@code child} shares the same convention.
   */
  SolrAggregate(
      RelOptCluster cluster,
      RelTraitSet traitSet,
      List<RelHint> hints,
      RelNode child,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    super(cluster, traitSet, hints, child, groupSet, groupSets, aggCalls);
    assert getConvention() == SolrRel.CONVENTION;
    assert getConvention() == child.getConvention();
  }

  /** Creates a new {@code SolrAggregate} on the same cluster, carrying over this node's hints. */
  @Override
  public Aggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets,
                        List<AggregateCall> aggCalls) {
    return new SolrAggregate(getCluster(), traitSet, hints, input, groupSet, groupSets, aggCalls);
  }

  /**
   * Variant of {@code copy} that also accepts an {@code indicator} flag. The flag is not used by
   * this implementation; the result is the same as the copy without it.
   */
  @Override
  public Aggregate copy(RelTraitSet traitSet, RelNode input,
                        boolean indicator, ImmutableBitSet groupSet,
                        List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
    return copy(traitSet, input, groupSet, groupSets, aggCalls);
  }

  /**
   * Visits the input first, then registers on {@code implementor} a reverse agg mapping and a
   * metric pair for each named aggregate call, followed by one bucket per grouping column.
   */
  public void implement(Implementor implementor) {
    implementor.visitChild(0, getInput());

    final List<String> inNames = SolrRules.solrFieldNames(getInput().getRowType());

    for (Pair<AggregateCall, String> namedAggCall : getNamedAggCalls()) {
      AggregateCall aggCall = namedAggCall.getKey();
      String outName = namedAggCall.getValue();

      Pair<String, String> metric = toSolrMetric(implementor, aggCall, inNames);
      boolean isDistinct = SqlStdOperatorTable.COUNT.equals(aggCall.getAggregation()) && aggCall.isDistinct();

      // map the SQL COUNT to either countDist or hll for distinct ops, otherwise, the metric names map over directly
      String metricKey = isDistinct ? (aggCall.isApproximate() ? APPROX_COUNT_DISTINCT : COUNT_DISTINCT) : metric.getKey();

      implementor.addReverseAggMapping(outName, solrAggMetricId(metricKey, metric.getValue()));
      implementor.addMetricPair(outName, metricKey, metric.getValue());
    }

    for (int group : getGroupSet()) {
      implementor.addBucket(inNames.get(group));
    }
  }

  /**
   * Translates an aggregate call into a pair of (aggregation name, column).
   *
   * <ul>
   *   <li>{@code COUNT} with no arguments yields the column {@code "*"}.</li>
   *   <li>A supported aggregation over exactly one argument yields that argument's field name,
   *       after applying {@code implementor.fieldMappings} when a mapping exists.</li>
   * </ul>
   *
   * @throws AssertionError for any other combination of aggregation and argument count
   */
  private Pair<String, String> toSolrMetric(Implementor implementor, AggregateCall aggCall, List<String> inNames) {
    SqlAggFunction aggregation = aggCall.getAggregation();
    List<Integer> args = aggCall.getArgList();

    if (args.isEmpty() && aggregation.equals(SqlStdOperatorTable.COUNT)) {
      return new Pair<>(aggregation.getName(), "*");
    }

    // Only look up the argument when exactly one exists: a zero-argument call that is not COUNT
    // must end in the AssertionError below rather than fail on args.get(0) of an empty list.
    if (args.size() == 1 && SUPPORTED_AGGREGATIONS.contains(aggregation)) {
      String inName = inNames.get(args.get(0));
      String name = implementor.fieldMappings.getOrDefault(inName, inName);
      return new Pair<>(aggregation.getName(), name);
    }

    throw new AssertionError("Invalid aggregation " + aggregation + " with args " + args + " with names" + inNames);
  }
}
// End SolrAggregate.java