blob: 6e05ce8b49cfd1e694b12a2ac05253eed88d70c0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.data.input.MapBasedRow;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.ValueMatcher;
import org.apache.druid.segment.RowAdapters;
import org.apache.druid.segment.RowBasedColumnSelectorFactory;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.virtual.VirtualizedColumnSelectorFactory;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.junit.Assert;
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Test helper for SQL expression translation. It builds Calcite {@link RexNode}s against a fixed
 * {@link RowSignature}, translates them with {@link Expressions}, and asserts both the translated
 * form (a {@link DruidExpression} or a {@link DimFilter}) and the result of evaluating it against
 * a fixed map of bindings.
 */
class ExpressionTestHelper
{
  /**
   * Planner context shared by every helper instance. It is built from the operator table, macro table and
   * auth result supplied by {@link CalciteTests}, a default {@link PlannerConfig}, and an empty map and list.
   */
  private static final PlannerContext PLANNER_CONTEXT = PlannerContext.create(
      CalciteTests.createOperatorTable(),
      CalciteTests.createExprMacroTable(),
      new PlannerConfig(),
      ImmutableMap.of(),
      ImmutableList.of(),
      CalciteTests.REGULAR_USER_AUTH_RESULT
  );

  private final RowSignature rowSignature;
  private final Map<String, Object> bindings;
  private final RelDataTypeFactory typeFactory;
  private final RexBuilder rexBuilder;
  private final RelDataType relDataType;

  /**
   * @param rowSignature signature of the row that input references and translations are resolved against
   * @param bindings     column values used when evaluating translated expressions and filters
   */
  ExpressionTestHelper(RowSignature rowSignature, Map<String, Object> bindings)
  {
    this.rowSignature = rowSignature;
    this.bindings = bindings;
    this.typeFactory = new JavaTypeFactoryImpl();
    this.rexBuilder = new RexBuilder(typeFactory);
    this.relDataType = RowSignatures.toRelDataType(rowSignature, typeFactory);
  }

  /**
   * @return the Calcite type for {@code sqlTypeName}, created by this helper's type factory
   */
  RelDataType createSqlType(SqlTypeName sqlTypeName)
  {
    return typeFactory.createSqlType(sqlTypeName);
  }

  /**
   * @return an input reference to {@code columnName}, using the position reported by the row signature
   * and the type of the field at that position in the derived Calcite row type
   */
  RexNode makeInputRef(String columnName)
  {
    int columnNumber = rowSignature.indexOf(columnName);
    return rexBuilder.makeInputRef(relDataType.getFieldList().get(columnNumber).getType(), columnNumber);
  }

  /**
   * @return the constant null node of the underlying {@link RexBuilder}
   */
  RexNode getConstantNull()
  {
    return rexBuilder.constantNull();
  }

  /**
   * @return a flag literal wrapping the given enum value
   */
  RexLiteral makeFlag(Enum<?> flag)
  {
    return rexBuilder.makeFlag(flag);
  }

  /**
   * @return a null literal typed as {@code sqlTypeName}
   */
  RexLiteral makeNullLiteral(SqlTypeName sqlTypeName)
  {
    return rexBuilder.makeNullLiteral(createSqlType(sqlTypeName));
  }

  /**
   * @return a string literal
   */
  RexLiteral makeLiteral(String s)
  {
    return rexBuilder.makeLiteral(s);
  }

  /**
   * @return a timestamp literal for {@code timestamp}, converted to a Calcite timestamp string in UTC
   */
  RexNode makeLiteral(DateTime timestamp)
  {
    return rexBuilder.makeTimestampLiteral(Calcites.jodaToCalciteTimestampString(timestamp, DateTimeZone.UTC), 0);
  }

  /**
   * @return a literal of SQL type INTEGER
   */
  RexNode makeLiteral(Integer integer)
  {
    return rexBuilder.makeLiteral(new BigDecimal(integer), createSqlType(SqlTypeName.INTEGER), true);
  }

  /**
   * @return a literal of SQL type BIGINT
   */
  RexNode makeLiteral(Long bigint)
  {
    return rexBuilder.makeLiteral(new BigDecimal(bigint), createSqlType(SqlTypeName.BIGINT), true);
  }

  /**
   * @return an exact numeric literal
   */
  RexNode makeLiteral(BigDecimal bigDecimal)
  {
    return rexBuilder.makeExactLiteral(bigDecimal);
  }

  /**
   * @return an interval literal with value {@code v} and the given qualifier
   */
  RexNode makeLiteral(BigDecimal v, SqlIntervalQualifier intervalQualifier)
  {
    return rexBuilder.makeIntervalLiteral(v, intervalQualifier);
  }

  /**
   * @return a literal of SQL type DOUBLE
   */
  RexNode makeLiteral(Double d)
  {
    return rexBuilder.makeLiteral(d, createSqlType(SqlTypeName.DOUBLE), true);
  }

  /**
   * @return a call to {@code op} with the given operands
   */
  RexNode makeCall(SqlOperator op, RexNode... exprs)
  {
    return rexBuilder.makeCall(op, exprs);
  }

  /**
   * @return an abstract cast of {@code exp} to {@code type}
   */
  RexNode makeAbstractCast(RelDataType type, RexNode exp)
  {
    return rexBuilder.makeAbstractCast(type, exp);
  }

  /**
   * @return Representation of variable that is bound in an expression. Intended use is as one of
   * the args to {@link #buildExpectedExpression(String, Object...)}.
   */
  Variable makeVariable(String name)
  {
    return new Variable(name);
  }

  /**
   * Marker for an identifier argument. Its {@link #toString()} wraps the name in double quotes, which
   * distinguishes it from plain strings (those are rendered in single quotes by
   * {@link #quoteIfNeeded(Object)}).
   */
  private static class Variable
  {
    private final String name;

    Variable(String name)
    {
      this.name = name;
    }

    @Override
    public String toString()
    {
      return "\"" + name + "\"";
    }
  }

  /**
   * Builds the expected expression {@code functionName(arg1,arg2,...)}. Each argument is rendered by
   * {@link #quoteIfNeeded(Object)} and the arguments are separated by a single comma with no spaces.
   *
   * @param functionName name placed in front of the parenthesized argument list
   * @param args         arguments; strings, booleans, nulls and {@link Variable}s get special rendering
   *
   * @return a {@link DruidExpression} created from the assembled expression string
   */
  DruidExpression buildExpectedExpression(String functionName, Object... args)
  {
    final String argsString = Arrays.stream(args)
                                    .map(ExpressionTestHelper::quoteIfNeeded)
                                    .collect(Collectors.joining(","));
    return DruidExpression.fromExpression(functionName + "(" + argsString + ")");
  }

  /**
   * Renders one argument for {@link #buildExpectedExpression(String, Object...)}: {@code null} becomes
   * the text {@code null}, strings are wrapped in single quotes (no escaping is applied), booleans become
   * {@code 1} or {@code 0}, and anything else uses its {@code toString()}.
   */
  private static String quoteIfNeeded(@Nullable Object arg)
  {
    if (arg == null) {
      return "null";
    } else if (arg instanceof String) {
      return "'" + arg + "'";
    } else if (arg instanceof Boolean) {
      return (Boolean) arg ? "1" : "0";
    } else {
      return arg.toString();
    }
  }

  /**
   * Variant of {@link #testExpression(RexNode, DruidExpression, Object)} that builds the call to {@code op}
   * with an explicit return type of {@code sqlTypeName}.
   */
  void testExpression(
      final SqlTypeName sqlTypeName,
      final SqlOperator op,
      final List<RexNode> exprs,
      final DruidExpression expectedExpression,
      final Object expectedResult
  )
  {
    RelDataType returnType = createSqlType(sqlTypeName);
    testExpression(rexBuilder.makeCall(returnType, op, exprs), expectedExpression, expectedResult);
  }

  /**
   * Variant of {@link #testExpression(SqlOperator, List, DruidExpression, Object)} for a single operand.
   */
  void testExpression(
      final SqlOperator op,
      final RexNode expr,
      final DruidExpression expectedExpression,
      final Object expectedResult
  )
  {
    testExpression(op, Collections.singletonList(expr), expectedExpression, expectedResult);
  }

  /**
   * Variant of {@link #testExpression(RexNode, DruidExpression, Object)} that first builds the call to
   * {@code op} with the given operands.
   */
  void testExpression(
      final SqlOperator op,
      final List<? extends RexNode> exprs,
      final DruidExpression expectedExpression,
      final Object expectedResult
  )
  {
    testExpression(rexBuilder.makeCall(op, exprs), expectedExpression, expectedResult);
  }

  /**
   * Translates {@code rexNode} into a {@link DruidExpression}, asserts that it equals
   * {@code expectedExpression}, then parses and evaluates the translated expression against this helper's
   * bindings and asserts that the resulting value equals {@code expectedResult}.
   */
  void testExpression(
      final RexNode rexNode,
      final DruidExpression expectedExpression,
      final Object expectedResult
  )
  {
    DruidExpression expression = Expressions.toDruidExpression(PLANNER_CONTEXT, rowSignature, rexNode);
    Assert.assertEquals("Expression for: " + rexNode, expectedExpression, expression);

    ExprEval<?> result = Parser.parse(expression.getExpression(), PLANNER_CONTEXT.getExprMacroTable())
                               .eval(Parser.withMap(bindings));
    Assert.assertEquals("Result for: " + rexNode, expectedResult, result.value());
  }

  /**
   * Translates a call to {@code op} into a {@link DimFilter} and checks three things in order: the filter
   * equals {@code expectedFilter}; the virtual columns it requires equal {@code expectedVirtualColumns}
   * (compared after sorting both by output name); and matching the filter against a row made from this
   * helper's bindings yields {@code expectedResult}.
   */
  void testFilter(
      final SqlOperator op,
      final List<? extends RexNode> exprs,
      final List<VirtualColumn> expectedVirtualColumns,
      final DimFilter expectedFilter,
      final boolean expectedResult
  )
  {
    final RexNode rexNode = rexBuilder.makeCall(op, exprs);
    final VirtualColumnRegistry virtualColumnRegistry = VirtualColumnRegistry.create(rowSignature);

    final DimFilter filter = Expressions.toFilter(PLANNER_CONTEXT, rowSignature, virtualColumnRegistry, rexNode);
    Assert.assertEquals("Filter for: " + rexNode, expectedFilter, filter);

    // Keep only the required columns that the registry resolves to a virtual column; the rest are dropped.
    final List<VirtualColumn> virtualColumns =
        filter.getRequiredColumns()
              .stream()
              .map(virtualColumnRegistry::getVirtualColumn)
              .filter(Objects::nonNull)
              .sorted(Comparator.comparing(VirtualColumn::getOutputName))
              .collect(Collectors.toList());

    Assert.assertEquals(
        "Virtual columns for: " + rexNode,
        expectedVirtualColumns.stream()
                              .sorted(Comparator.comparing(VirtualColumn::getOutputName))
                              .collect(Collectors.toList()),
        virtualColumns
    );

    // The matcher is built from expectedFilter, which was asserted equal to the translated filter above.
    final ValueMatcher matcher = expectedFilter.toFilter().makeMatcher(
        new VirtualizedColumnSelectorFactory(
            RowBasedColumnSelectorFactory.create(
                RowAdapters.standardRow(),
                () -> new MapBasedRow(0L, bindings),
                rowSignature,
                false
            ),
            VirtualColumns.create(virtualColumns)
        )
    );

    Assert.assertEquals("Result for: " + rexNode, expectedResult, matcher.matches());
  }
}