blob: b96875124193aea9796102ca26f1656731fc05f7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.sql.test;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.runtime.Utilities;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlUnresolvedFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.SqlParserUtil;
import org.apache.calcite.sql.parser.StringAndPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.sql2rel.RelFieldTrimmer;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.test.DiffRepository;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.TestUtil;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.hamcrest.Matcher;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.function.Consumer;
import java.util.regex.Pattern;

import static org.apache.calcite.test.Matchers.relIsValid;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import static java.util.Objects.requireNonNull;
/**
 * Abstract implementation of {@link SqlTester}
 * that talks to a mock catalog.
 *
 * <p>This is to implement the default behavior: testing is only against the
 * {@link SqlValidator}. Queries are parsed and validated, and result and
 * parameter types are checked, but queries are never executed; the
 * {@link ResultChecker} arguments of the {@code check} methods are only
 * null-checked, never applied.
 */
public abstract class AbstractSqlTester implements SqlTester, AutoCloseable {
  private static final String NL = System.lineSeparator();

  /** Detects calls to functions for which {@link #buildQuery2} must not
   * replace literals with columns. Matched with {@code find()} rather than
   * {@code String.matches(".*...*")} so that an expression spanning several
   * lines is detected too ({@code .} does not match line terminators). */
  private static final Pattern LITERAL_ONLY_FUNCTION_CALL =
      Pattern.compile("(percentile_(cont|disc)|convert|sort_array|cast)\\(",
          Pattern.CASE_INSENSITIVE);

  /** Creates an AbstractSqlTester. */
  public AbstractSqlTester() {
  }

  /**
   * {@inheritDoc}
   *
   * <p>This default implementation does nothing.
   */
  @Override public void close() {
    // no resources to release
  }

  /**
   * Parses and validates {@code sap.sql}, and checks the error (if any)
   * against {@code expectedMsgPattern}.
   *
   * <p>A parse failure is checked as a {@link SqlTests.Stage#PARSE} error and
   * validation is then skipped. Otherwise the query is validated and the
   * outcome, which is {@code null} if validation succeeded, is checked as a
   * {@link SqlTests.Stage#VALIDATE} error.
   */
  @Override public void assertExceptionIsThrown(SqlTestFactory factory,
      StringAndPos sap, @Nullable String expectedMsgPattern) {
    final SqlNode sqlNode;
    try {
      sqlNode = parseQuery(factory, sap.sql);
    } catch (Throwable e) {
      SqlTests.checkEx(e, expectedMsgPattern, sap, SqlTests.Stage.PARSE);
      return;
    }

    final SqlValidator validator = factory.createValidator();
    Throwable thrown = null;
    try {
      validator.validate(sqlNode);
    } catch (Throwable ex) {
      thrown = ex;
    }

    SqlTests.checkEx(thrown, expectedMsgPattern, sap, SqlTests.Stage.VALIDATE);
  }

  /**
   * Checks an error thrown while parsing.
   *
   * <p>Returns normally only if {@code e} is a {@link SqlParseException} whose
   * message, after {@link Util#toLinux} normalization, matches
   * {@code expectedMsgPattern}. Otherwise throws a {@link RuntimeException}
   * whose cause is {@code e}.
   *
   * @param e Error thrown while parsing
   * @param expectedMsgPattern Regular expression the message must match, or
   *     null if no error was expected
   * @param sap SQL and position, used in error messages
   */
  protected void checkParseEx(Throwable e, @Nullable String expectedMsgPattern,
      StringAndPos sap) {
    if (!(e instanceof SqlParseException)) {
      // Not a parser error at all; report it as an unexpected failure.
      throw new RuntimeException("Error while parsing query: " + sap, e);
    }
    if (expectedMsgPattern == null) {
      // The parser failed, but no error was expected.
      throw new RuntimeException("Error while parsing query: " + sap, e);
    }
    final String errMessage = e.getMessage();
    if (errMessage == null
        || !Util.toLinux(errMessage).matches(expectedMsgPattern)) {
      throw new RuntimeException("Error did not match expected ["
          + expectedMsgPattern + "] while parsing query ["
          + sap + "]", e);
    }
  }

  /** Returns the type of the single column returned by {@code sql}; fails
   * unless the validated query has exactly one field. */
  @Override public RelDataType getColumnType(SqlTestFactory factory,
      String sql) {
    return validateAndApply(factory, StringAndPos.of(sql),
        (sql1, validator, n) -> {
          final RelDataType rowType =
              validator.getValidatedNodeType(n);
          final List<RelDataTypeField> fields = rowType.getFieldList();
          assertThat("expected query to return 1 field", fields, hasSize(1));
          return fields.get(0).getType();
        });
  }

  /** Returns the validated row type of {@code sql}. */
  @Override public RelDataType getResultType(SqlTestFactory factory,
      String sql) {
    return validateAndApply(factory, StringAndPos.of(sql),
        (sql1, validator, n) ->
            validator.getValidatedNodeType(n));
  }

  /**
   * Parses {@code sql} and validates it with a fresh validator.
   *
   * <p>A parse failure is wrapped in a {@link RuntimeException}; validation
   * errors propagate unchanged.
   *
   * @return the validator and the node returned by
   *     {@link SqlValidator#validate}
   */
  Pair<SqlValidator, SqlNode> parseAndValidate(SqlTestFactory factory,
      String sql) {
    SqlNode sqlNode;
    try {
      sqlNode = parseQuery(factory, sql);
    } catch (Throwable e) {
      throw new RuntimeException("Error while parsing query: " + sql, e);
    }
    SqlValidator validator = factory.createValidator();
    return Pair.of(validator, validator.validate(sqlNode));
  }

  @Override public SqlNode parseQuery(SqlTestFactory factory, String sql)
      throws SqlParseException {
    SqlParser parser = factory.createParser(sql);
    return parser.parseQuery();
  }

  @Override public SqlNode parseExpression(SqlTestFactory factory,
      String expr) throws SqlParseException {
    SqlParser parser = factory.createParser(expr);
    return parser.parseExpression();
  }

  /** Checks that {@code sql} returns one column whose type string
   * ({@link SqlTests#getTypeString}) equals {@code expected}. */
  @Override public void checkColumnType(SqlTestFactory factory, String sql,
      String expected) {
    validateAndThen(factory, StringAndPos.of(sql),
        checkColumnTypeAction(is(expected)));
  }

  /** Returns an action that asserts a validated query has exactly one field
   * and that the type string of that field satisfies {@code matcher}. */
  private static ValidatedNodeConsumer checkColumnTypeAction(
      Matcher<String> matcher) {
    return (sql1, validator, validatedNode) -> {
      final RelDataType rowType =
          validator.getValidatedNodeType(validatedNode);
      final List<RelDataTypeField> fields = rowType.getFieldList();
      assertEquals(1, fields.size(), "expected query to return 1 field");
      final RelDataType actualType = fields.get(0).getType();
      String actual = SqlTests.getTypeString(actualType);
      assertThat(actual, matcher);
    };
  }

  // SqlTester methods

  @Override public void setFor(
      SqlOperator operator,
      VmName... unimplementedVmNames) {
    // do nothing
  }

  /** Builds an aggregate query over {@code inputValues} and checks it
   * with {@link SqlTests#ANY_TYPE_CHECKER}. */
  @Override public void checkAgg(SqlTestFactory factory,
      String expr,
      String[] inputValues,
      ResultChecker resultChecker) {
    String query =
        SqlTests.generateAggQuery(expr, inputValues);
    check(factory, query, SqlTests.ANY_TYPE_CHECKER, resultChecker);
  }

  /** Builds a windowed aggregate query and checks it with
   * {@link SqlTests#ANY_TYPE_CHECKER}. The {@code type} argument is not used
   * by this implementation. */
  @Override public void checkWinAgg(SqlTestFactory factory,
      String expr,
      String[] inputValues,
      String windowSpec,
      String type,
      ResultChecker resultChecker) {
    String query =
        SqlTests.generateWinAggQuery(
            expr, windowSpec, inputValues);
    check(factory, query, SqlTests.ANY_TYPE_CHECKER, resultChecker);
  }

  @Override public void check(SqlTestFactory factory,
      String query, TypeChecker typeChecker,
      ParameterChecker parameterChecker, ResultChecker resultChecker) {
    // This implementation does NOT check the result!
    // All it does is check the return type.
    requireNonNull(typeChecker, "typeChecker");
    requireNonNull(parameterChecker, "parameterChecker");
    requireNonNull(resultChecker, "resultChecker");

    // Parse and validate. There should be no errors.
    // There must be 1 column. Get its type.
    RelDataType actualType = getColumnType(factory, query);

    // Check result type.
    typeChecker.checkType(() -> "Query: " + query, actualType);

    // Note: getColumnType (above) and parseAndValidate (below) each parse and
    // validate the query, so it is validated twice.
    Pair<SqlValidator, SqlNode> p = parseAndValidate(factory, query);
    SqlValidator validator = requireNonNull(p.left);
    SqlNode n = requireNonNull(p.right);
    final RelDataType parameterRowType = validator.getParameterRowType(n);
    parameterChecker.checkParameters(parameterRowType);
  }

  @Override public void validateAndThen(SqlTestFactory factory,
      StringAndPos sap, ValidatedNodeConsumer consumer) {
    Pair<SqlValidator, SqlNode> p = parseAndValidate(factory, sap.sql);
    SqlValidator validator = requireNonNull(p.left);
    SqlNode rewrittenNode = requireNonNull(p.right);
    consumer.accept(sap, validator, rewrittenNode);
  }

  @Override public <R> R validateAndApply(SqlTestFactory factory,
      StringAndPos sap, ValidatedNodeFunction<R> function) {
    Pair<SqlValidator, SqlNode> p = parseAndValidate(factory, sap.sql);
    SqlValidator validator = requireNonNull(p.left);
    SqlNode rewrittenNode = requireNonNull(p.right);
    return function.apply(sap, validator, rewrittenNode);
  }

  @Override public void checkFails(SqlTestFactory factory, StringAndPos sap,
      String expectedError, boolean runtime) {
    if (runtime) {
      // We need to test that the expression fails at runtime.
      // Ironically, that means that it must succeed at prepare time.
      // Use the plain SQL: the caret form produced by addCarets() is only
      // parsed after StringAndPos.of strips it (see the else branch), whereas
      // here the text goes straight to the parser.
      final String sql = buildQuery(sap.sql);
      Pair<SqlValidator, SqlNode> p = parseAndValidate(factory, sql);
      SqlNode n = p.right;
      assertNotNull(n);
    } else {
      StringAndPos sap1 = StringAndPos.of(buildQuery(sap.addCarets()));
      checkQueryFails(factory, sap1, expectedError);
    }
  }

  @Override public void checkQueryFails(SqlTestFactory factory,
      StringAndPos sap, String expectedError) {
    assertExceptionIsThrown(factory, sap, expectedError);
  }

  @Override public void checkAggFails(SqlTestFactory factory,
      String expr,
      String[] inputValues,
      String expectedError,
      boolean runtime) {
    final String sql =
        SqlTests.generateAggQuery(expr, inputValues);
    if (runtime) {
      // A runtime failure must still parse and validate successfully.
      Pair<SqlValidator, SqlNode> p = parseAndValidate(factory, sql);
      SqlNode n = p.right;
      assertNotNull(n);
    } else {
      checkQueryFails(factory, StringAndPos.of(sql), expectedError);
    }
  }

  /** Wraps a scalar expression in a one-row query:
   * {@code values (expression)}. */
  public static String buildQuery(String expression) {
    return "values (" + expression + ")";
  }

  /** Wraps an expression in a query that groups a single-row table:
   * {@code select expression from (values (1)) as t(x) group by x}. */
  public static String buildQueryAgg(String expression) {
    return "select " + expression + " from (values (1)) as t(x) group by x";
  }

  /**
   * Builds a query that extracts all literals as columns in an underlying
   * select.
   *
   * <p>For example,
   *
   * <blockquote>{@code 1 < 5}</blockquote>
   *
   * <p>becomes
   *
   * <blockquote>{@code SELECT p1 < p0
   * FROM (VALUES (5, 1)) AS t(p0, p1)}</blockquote>
   *
   * <p>Literals are replaced from the end of the expression backwards, so the
   * last literal becomes {@code p0}. If nothing is extracted, the inner
   * {@code VALUES} has a single unused column {@code p0} with value 1.
   *
   * <p>Null literals don't have enough type information to be extracted.
   * We push down {@code CAST(NULL AS type)} but raw nulls such as
   * {@code CASE 1 WHEN 2 THEN 'a' ELSE NULL END} are left as is. Symbol
   * literals, the operands of LITERAL_CHAIN, LOCALTIME, LOCALTIMESTAMP,
   * CURRENT_TIME and CURRENT_TIMESTAMP, and the contents of lambdas are also
   * left as is.
   *
   * <p>If the expression calls PERCENTILE_CONT, PERCENTILE_DISC, CONVERT,
   * SORT_ARRAY or CAST with the name immediately followed by "(", nothing is
   * extracted and the result equals {@link #buildQuery}. Consequently the
   * CAST(NULL ...) push-down only applies when written with whitespace
   * before the parenthesis, for example {@code CAST (NULL AS INTEGER)}.
   *
   * @param factory Test factory
   * @param expression Scalar expression
   * @return Query that evaluates a scalar expression
   */
  protected String buildQuery2(SqlTestFactory factory, String expression) {
    if (LITERAL_ONLY_FUNCTION_CALL.matcher(expression).find()) {
      // PERCENTILE_CONT requires its argument to be a literal,
      // so converting its argument to a column will cause false errors.
      // Similarly, MSSQL-style CONVERT. SORT_ARRAY and CAST are skipped too.
      return buildQuery(expression);
    }

    // "values (1 < 5)"
    // becomes
    // "select p1 < p0 from (values (5, 1)) as t(p0, p1)"
    SqlNode x;
    final String sql = buildQuery(expression);
    try {
      x = parseQuery(factory, sql);
    } catch (SqlParseException e) {
      throw TestUtil.rethrow(e);
    }

    // Collect the nodes to be replaced by columns.
    final Collection<SqlNode> literalSet = new LinkedHashSet<>();
    x.accept(
        new SqlShuttle() {
          private final List<SqlOperator> ops =
              ImmutableList.of(
                  SqlStdOperatorTable.LITERAL_CHAIN,
                  SqlStdOperatorTable.LOCALTIME,
                  SqlStdOperatorTable.LOCALTIMESTAMP,
                  SqlStdOperatorTable.CURRENT_TIME,
                  SqlStdOperatorTable.CURRENT_TIMESTAMP);

          @Override public SqlNode visit(SqlLiteral literal) {
            if (!isNull(literal)
                && literal.getTypeName() != SqlTypeName.SYMBOL) {
              literalSet.add(literal);
            }
            return literal;
          }

          @Override public SqlNode visit(SqlCall call) {
            SqlOperator operator = call.getOperator();
            if (operator.getKind() == SqlKind.LAMBDA) {
              // Leave lambdas (and everything inside them) untouched.
              return call;
            }
            if (operator instanceof SqlUnresolvedFunction) {
              // Resolve against the standard operator table, so that calls
              // to LOCALTIME etc. are recognized by the check below.
              final SqlUnresolvedFunction unresolvedFunction =
                  (SqlUnresolvedFunction) operator;
              final SqlOperator lookup =
                  SqlValidatorUtil.lookupSqlFunctionByID(
                      SqlStdOperatorTable.instance(),
                      unresolvedFunction.getSqlIdentifier(),
                      unresolvedFunction.getFunctionType());
              if (lookup != null) {
                operator = lookup;
                call =
                    operator.createCall(call.getFunctionQuantifier(),
                        call.getParserPosition(), call.getOperandList());
              }
            }
            if (operator == SqlStdOperatorTable.CAST
                && isNull(call.operand(0))) {
              literalSet.add(call);
              return call;
            } else if (ops.contains(operator)) {
              // "Argument to function 'LOCALTIME' must be a
              // literal"
              return call;
            } else {
              return super.visit(call);
            }
          }

          private boolean isNull(SqlNode sqlNode) {
            return sqlNode instanceof SqlLiteral
                && ((SqlLiteral) sqlNode).getTypeName()
                == SqlTypeName.NULL;
          }
        });

    // Sort in descending order of start position, so that replacing a node
    // with its parameter name does not shift the nodes still to be replaced.
    final List<SqlNode> nodes = new ArrayList<>(literalSet);
    nodes.sort((o1, o2) -> {
      final SqlParserPos pos0 = o1.getParserPosition();
      final SqlParserPos pos1 = o2.getParserPosition();
      int c = -Utilities.compare(pos0.getLineNum(), pos1.getLineNum());
      if (c != 0) {
        return c;
      }
      return -Utilities.compare(pos0.getColumnNum(), pos1.getColumnNum());
    });

    // Replace each node's text with "pN", remembering (text, pN) pairs.
    String sql2 = sql;
    final PairList<String, String> values = PairList.of();
    int p = 0;
    for (SqlNode literal : nodes) {
      final SqlParserPos pos = literal.getParserPosition();
      final int start =
          SqlParserUtil.lineColToIndex(
              sql, pos.getLineNum(), pos.getColumnNum());
      final int end =
          SqlParserUtil.lineColToIndex(
              sql,
              pos.getEndLineNum(),
              pos.getEndColumnNum()) + 1;
      String param = "p" + p++;
      values.add(sql2.substring(start, end), param);
      sql2 = sql2.substring(0, start)
          + param
          + sql2.substring(end);
    }
    if (values.isEmpty()) {
      values.add("1", "p0");
    }
    return "select "
        + sql2.substring("values (".length(), sql2.length() - 1)
        + " from (values ("
        + Util.commaList(values.leftList())
        + ")) as t("
        + Util.commaList(values.rightList())
        + ")";
  }

  /** Passes the consumer {@link #buildQuery} of the expression, then
   * {@link #buildQuery2} of it. */
  @Override public void forEachQuery(SqlTestFactory factory,
      String expression, Consumer<String> consumer) {
    // Why not return a list? If there is a syntax error in the expression, the
    // consumer will discover it before we try to parse it to do substitutions
    // on the parse tree.
    consumer.accept(buildQuery(expression));
    consumer.accept(buildQuery2(factory, expression));
  }

  @Override public void assertConvertsTo(SqlTestFactory factory,
      DiffRepository diffRepos,
      String sql,
      String plan,
      boolean trim,
      boolean expression,
      boolean decorrelate) {
    if (expression) {
      assertExprConvertsTo(factory, diffRepos, sql, plan);
    } else {
      assertSqlConvertsTo(factory, diffRepos, sql, plan, trim, decorrelate);
    }
  }

  /** Converts an expression to a {@link RexNode} and compares its string
   * form with the "plan" resource in {@code diffRepos}. */
  private void assertExprConvertsTo(SqlTestFactory factory,
      DiffRepository diffRepos, String expr, String plan) {
    String expr2 = diffRepos.expand("sql", expr);
    RexNode rex = convertExprToRex(factory, expr2);
    assertNotNull(rex);
    // NOTE jvs 28-Mar-2006: insert leading newline so
    // that plans come out nicely stacked instead of first
    // line immediately after CDATA start
    String actual = NL + rex + NL;
    diffRepos.assertEquals("plan", plan, actual);
  }

  /** Converts a query to a relational expression, optionally trims it with
   * {@link #createFieldTrimmer}, and compares the plan with the "plan"
   * resource in {@code diffRepos}. */
  private void assertSqlConvertsTo(SqlTestFactory factory,
      DiffRepository diffRepos, String sql, String plan,
      boolean trim,
      boolean decorrelate) {
    String sql2 = diffRepos.expand("sql", sql);
    final Pair<SqlValidator, RelRoot> pair =
        convertSqlToRel2(factory, sql2, decorrelate, trim);
    final RelRoot root = requireNonNull(pair.right);
    final SqlValidator validator = requireNonNull(pair.left);
    RelNode rel = root.project();

    assertNotNull(rel);
    assertThat(rel, relIsValid());

    if (trim) {
      final RelBuilder relBuilder =
          RelFactories.LOGICAL_BUILDER.create(rel.getCluster(), null);
      final RelFieldTrimmer trimmer =
          createFieldTrimmer(validator, relBuilder);
      rel = trimmer.trim(rel);
      assertNotNull(rel);
      assertThat(rel, relIsValid());
    }

    // NOTE jvs 28-Mar-2006: insert leading newline so
    // that plans come out nicely stacked instead of first
    // line immediately after CDATA start
    String actual = NL + RelOptUtil.toString(rel);
    diffRepos.assertEquals("plan", plan, actual);
  }

  /** Parses, validates and converts a scalar expression to a
   * {@link RexNode}. Checked parse exceptions are rethrown unchecked. */
  private RexNode convertExprToRex(SqlTestFactory factory, String expr) {
    requireNonNull(expr, "expr");
    final SqlNode sqlQuery;
    try {
      sqlQuery = parseExpression(factory, expr);
    } catch (RuntimeException | Error e) {
      throw e;
    } catch (Exception e) {
      throw TestUtil.rethrow(e);
    }

    final SqlToRelConverter converter = factory.createSqlToRelConverter();
    final SqlValidator validator = requireNonNull(converter.validator);
    final SqlNode validatedQuery = validator.validate(sqlQuery);
    return converter.convertExpression(validatedQuery);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Types are flattened if either {@code decorrelate} or {@code trim} is
   * set; decorrelation then runs before trimming.
   */
  @Override public Pair<SqlValidator, RelRoot> convertSqlToRel2(
      SqlTestFactory factory, String sql, boolean decorrelate,
      boolean trim) {
    requireNonNull(sql, "sql");
    final SqlNode sqlQuery;
    try {
      sqlQuery = parseQuery(factory, sql);
    } catch (RuntimeException | Error e) {
      throw e;
    } catch (Exception e) {
      throw TestUtil.rethrow(e);
    }
    final SqlToRelConverter converter = factory.createSqlToRelConverter();
    final SqlValidator validator = requireNonNull(converter.validator);

    final SqlNode validatedQuery = validator.validate(sqlQuery);
    RelRoot root =
        converter.convertQuery(validatedQuery, false, true);
    requireNonNull(root, "root");
    if (decorrelate || trim) {
      root = root.withRel(converter.flattenTypes(root.rel, true));
    }
    if (decorrelate) {
      // Pass the validated node, as for convertQuery above; validate may
      // return a different (rewritten) node than the one it was given.
      root = root.withRel(converter.decorrelate(validatedQuery, root.rel));
    }
    if (trim) {
      root = root.withRel(converter.trimUnusedFields(true, root.rel));
    }
    return Pair.of(validator, root);
  }

  /** Flattens types of {@code relNode}, then trims unused fields, using a
   * fresh converter from {@code factory}. */
  @Override public RelNode trimRelNode(SqlTestFactory factory,
      RelNode relNode) {
    final SqlToRelConverter converter = factory.createSqlToRelConverter();
    RelNode r2 = converter.flattenTypes(relNode, true);
    return converter.trimUnusedFields(true, r2);
  }

  /**
   * Creates a RelFieldTrimmer.
   *
   * @param validator Validator
   * @param relBuilder Builder
   * @return Field trimmer
   */
  public RelFieldTrimmer createFieldTrimmer(SqlValidator validator,
      RelBuilder relBuilder) {
    return new RelFieldTrimmer(validator, relBuilder);
  }
}