// blob: be9a69f2683952d28c0de5d1c3d9a38a40c3f1f6 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.aggregation.builtin;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.any.DoubleAnyAggregatorFactory;
import org.apache.druid.query.aggregation.any.FloatAnyAggregatorFactory;
import org.apache.druid.query.aggregation.any.LongAnyAggregatorFactory;
import org.apache.druid.query.aggregation.any.StringAnyAggregatorFactory;
import org.apache.druid.query.aggregation.first.DoubleFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.FloatFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.LongFirstAggregatorFactory;
import org.apache.druid.query.aggregation.first.StringFirstAggregatorFactory;
import org.apache.druid.query.aggregation.last.DoubleLastAggregatorFactory;
import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory;
import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public class EarliestLatestAnySqlAggregator implements SqlAggregator
{
public static final SqlAggregator EARLIEST = new EarliestLatestAnySqlAggregator(AggregatorType.EARLIEST);
public static final SqlAggregator LATEST = new EarliestLatestAnySqlAggregator(AggregatorType.LATEST);
public static final SqlAggregator ANY_VALUE = new EarliestLatestAnySqlAggregator(AggregatorType.ANY_VALUE);
/**
 * The three flavors of this SQL aggregator. Each constant knows how to build the native
 * {@link AggregatorFactory} that matches a given output {@link ValueType}.
 */
enum AggregatorType
{
  /** Builds the {@code *FirstAggregatorFactory} family. */
  EARLIEST {
    @Override
    AggregatorFactory createAggregatorFactory(String name, String fieldName, ValueType outputType, int maxStringBytes)
    {
      final AggregatorFactory factory;
      switch (outputType) {
        case LONG:
          factory = new LongFirstAggregatorFactory(name, fieldName);
          break;
        case FLOAT:
          factory = new FloatFirstAggregatorFactory(name, fieldName);
          break;
        case DOUBLE:
          factory = new DoubleFirstAggregatorFactory(name, fieldName);
          break;
        case STRING:
        case COMPLEX:
          // COMPLEX shares the string-based factory with STRING.
          factory = new StringFirstAggregatorFactory(name, fieldName, maxStringBytes);
          break;
        default:
          throw new ISE("Cannot build EARLIEST aggregatorFactory for type[%s]", outputType);
      }
      return factory;
    }
  },

  /** Builds the {@code *LastAggregatorFactory} family. */
  LATEST {
    @Override
    AggregatorFactory createAggregatorFactory(String name, String fieldName, ValueType outputType, int maxStringBytes)
    {
      final AggregatorFactory factory;
      switch (outputType) {
        case LONG:
          factory = new LongLastAggregatorFactory(name, fieldName);
          break;
        case FLOAT:
          factory = new FloatLastAggregatorFactory(name, fieldName);
          break;
        case DOUBLE:
          factory = new DoubleLastAggregatorFactory(name, fieldName);
          break;
        case STRING:
        case COMPLEX:
          // COMPLEX shares the string-based factory with STRING.
          factory = new StringLastAggregatorFactory(name, fieldName, maxStringBytes);
          break;
        default:
          throw new ISE("Cannot build LATEST aggregatorFactory for type[%s]", outputType);
      }
      return factory;
    }
  },

  /** Builds the {@code *AnyAggregatorFactory} family. Unlike the other two, COMPLEX is rejected here. */
  ANY_VALUE {
    @Override
    AggregatorFactory createAggregatorFactory(String name, String fieldName, ValueType outputType, int maxStringBytes)
    {
      final AggregatorFactory factory;
      switch (outputType) {
        case LONG:
          factory = new LongAnyAggregatorFactory(name, fieldName);
          break;
        case FLOAT:
          factory = new FloatAnyAggregatorFactory(name, fieldName);
          break;
        case DOUBLE:
          factory = new DoubleAnyAggregatorFactory(name, fieldName);
          break;
        case STRING:
          factory = new StringAnyAggregatorFactory(name, fieldName, maxStringBytes);
          break;
        default:
          throw new ISE("Cannot build ANY aggregatorFactory for type[%s]", outputType);
      }
      return factory;
    }
  };

  /**
   * Creates the native aggregator factory for this aggregator flavor.
   *
   * @param name           output name given to the aggregator factory
   * @param fieldName      name of the input column (or virtual column) to aggregate
   * @param outputType     Druid value type that selects the concrete factory
   * @param maxStringBytes forwarded to the string-based factories only; the numeric factories do not receive it
   *
   * @return the factory matching {@code outputType}
   *
   * @throws ISE if {@code outputType} is not supported by this flavor
   */
  abstract AggregatorFactory createAggregatorFactory(
      String name,
      String fieldName,
      ValueType outputType,
      int maxStringBytes
  );
}
/** Which flavor (EARLIEST, LATEST or ANY_VALUE) this instance implements. */
private final AggregatorType aggregatorType;

/** Calcite function object handed out by {@link #calciteFunction()}. */
private final SqlAggFunction function;

/**
 * Private constructor; the instances visible in this file are the {@code EARLIEST}, {@code LATEST}
 * and {@code ANY_VALUE} constants.
 *
 * @param aggregatorType flavor to implement; also supplies the SQL function name
 */
private EarliestLatestAnySqlAggregator(final AggregatorType aggregatorType)
{
  this.function = new EarliestLatestSqlAggFunction(aggregatorType);
  this.aggregatorType = aggregatorType;
}
/**
 * Returns the Calcite aggregate function that was built for this instance at construction time.
 */
@Override
public SqlAggFunction calciteFunction()
{
  return this.function;
}
/**
 * Translates an EARLIEST / LATEST / ANY_VALUE aggregate call into a Druid {@link Aggregation}.
 *
 * <p>The first call argument is the value to aggregate: a direct column reference is used as-is,
 * any other expression is registered as a virtual column. The optional second argument is read as
 * an integer literal and passed on as the maximum string size; {@code -1} is passed when it is absent.
 *
 * <p>{@code rexBuilder} and {@code existingAggregations} are not used by this implementation.
 *
 * @param name                 output name; when finalizing, the aggregator itself gets a prefixed name and a
 *                             finalizing post-aggregator takes {@code name}
 * @param finalizeAggregations whether to wrap the aggregator in a {@link FinalizingFieldAccessPostAggregator}
 *
 * @return the aggregation, or null if the call arguments cannot be translated to Druid expressions
 *
 * @throws ISE if the call's SQL return type has no Druid value type, or if the resolved value type is not
 *             supported by the selected aggregator flavor
 */
@Nullable
@Override
public Aggregation toDruidAggregation(
    final PlannerContext plannerContext,
    final RowSignature rowSignature,
    final VirtualColumnRegistry virtualColumnRegistry,
    final RexBuilder rexBuilder,
    final String name,
    final AggregateCall aggregateCall,
    final Project project,
    final List<Aggregation> existingAggregations,
    final boolean finalizeAggregations
)
{
  // Resolve every argument ordinal of the call to the RexNode it points at.
  final List<RexNode> operands =
      aggregateCall.getArgList()
                   .stream()
                   .map(argIndex -> Expressions.fromFieldAccess(rowSignature, project, argIndex))
                   .collect(Collectors.toList());

  final List<DruidExpression> druidArgs = Expressions.toDruidExpressions(plannerContext, rowSignature, operands);
  if (druidArgs == null) {
    return null;
  }

  final String aggregatorName = finalizeAggregations ? Calcites.makePrefixedName(name, "a") : name;

  // Read a plain column directly; anything else goes through a virtual column.
  final DruidExpression valueExpression = druidArgs.get(0);
  final String fieldName = valueExpression.isDirectColumnAccess()
                           ? valueExpression.getDirectColumn()
                           : virtualColumnRegistry.getOrCreateVirtualColumnForExpression(
                               plannerContext,
                               valueExpression,
                               operands.get(0).getType()
                           ).getOutputName();

  // A second argument, when present, is a literal: the operand checker of the SQL function enforces that.
  final int maxStringBytes;
  if (operands.size() > 1) {
    maxStringBytes = RexLiteral.intValue(operands.get(1));
  } else {
    maxStringBytes = -1;
  }

  final RelDataType returnType = aggregateCall.getType();
  final ValueType outputType = Calcites.getValueTypeForRelDataType(returnType);
  if (outputType == null) {
    throw new ISE(
        "Cannot translate output sqlTypeName[%s] to Druid type for aggregator[%s]",
        returnType.getSqlTypeName(),
        aggregateCall.getName()
    );
  }

  final AggregatorFactory factory = aggregatorType.createAggregatorFactory(
      aggregatorName,
      fieldName,
      outputType,
      maxStringBytes
  );
  final FinalizingFieldAccessPostAggregator finalizer =
      finalizeAggregations ? new FinalizingFieldAccessPostAggregator(name, aggregatorName) : null;

  return Aggregation.create(Collections.singletonList(factory), finalizer);
}
/**
 * Return type inference that mirrors the type of one operand when it is numeric or a string,
 * and reports VARCHAR for every other operand type.
 */
static class EarliestLatestReturnTypeInference implements SqlReturnTypeInference
{
  /** Index of the operand whose type drives the return type. */
  private final int ordinal;

  /**
   * @param ordinal index of the operand to inspect
   */
  public EarliestLatestReturnTypeInference(int ordinal)
  {
    this.ordinal = ordinal;
  }

  /**
   * @return the operand's own type if it is numeric or a string, otherwise a VARCHAR type
   */
  @Override
  public RelDataType inferReturnType(SqlOperatorBinding sqlOperatorBinding)
  {
    final RelDataType operandType = sqlOperatorBinding.getOperandType(ordinal);

    // Numbers and strings keep their own type.
    if (SqlTypeUtil.isNumeric(operandType) || SqlTypeUtil.isString(operandType)) {
      return operandType;
    }

    // Everything else is treated as a COMPLEX type, and the return type is set to VARCHAR.
    // NOTE(review): the operand checker in EarliestLatestSqlAggFunction also accepts BOOLEAN operands;
    // presumably those end up in this branch too -- confirm that is intended.
    return sqlOperatorBinding.getTypeFactory().createSqlType(SqlTypeName.VARCHAR);
  }
}
/**
 * Calcite function definition shared by all three flavors; only the function name differs.
 *
 * <p>Accepted call shapes: a single numeric operand, a single boolean operand, or a pair of
 * (any expression, numeric literal).
 */
private static class EarliestLatestSqlAggFunction extends SqlAggFunction
{
  /** Return type follows the first operand (index 0). */
  private static final EarliestLatestReturnTypeInference ARG0_RETURN_TYPE =
      new EarliestLatestReturnTypeInference(0);

  /**
   * @param aggregatorType flavor whose enum constant name becomes the SQL function name
   */
  EarliestLatestSqlAggFunction(AggregatorType aggregatorType)
  {
    super(
        aggregatorType.name(),            // function name
        null,                             // no SqlIdentifier
        SqlKind.OTHER_FUNCTION,
        ARG0_RETURN_TYPE,                 // return type inference
        InferTypes.RETURN_TYPE,           // operand type inference
        OperandTypes.or(                  // operand type checker
            OperandTypes.NUMERIC,
            OperandTypes.BOOLEAN,
            OperandTypes.sequence(
                twoArgSignature(aggregatorType.name()),
                OperandTypes.ANY,
                OperandTypes.and(OperandTypes.NUMERIC, OperandTypes.LITERAL)
            )
        ),
        SqlFunctionCategory.STRING,
        false,                            // requiresOrder
        false,                            // requiresOver
        Optionality.FORBIDDEN             // requiresGroupOrder
    );
  }

  /** Signature text handed to the two-operand checker. */
  private static String twoArgSignature(String functionName)
  {
    return "'" + functionName + "(expr, maxBytesPerString)'\n";
  }
}
}