blob: 3347c46f039232500c49c99551e861b254b1d607 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.test.fuzzer;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgramBuilderBase;
import org.apache.calcite.rex.RexUnknownAs;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Sarg;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
/**
* Generates random {@link RexNode} instances for tests.
*/
public class RexFuzzer extends RexProgramBuilderBase {
private static final int MAX_VARS = 2;
private static final SqlOperator[] BOOL_TO_BOOL = {
SqlStdOperatorTable.NOT,
SqlStdOperatorTable.IS_TRUE,
SqlStdOperatorTable.IS_FALSE,
SqlStdOperatorTable.IS_NOT_TRUE,
SqlStdOperatorTable.IS_NOT_FALSE,
};
private static final SqlOperator[] ANY_TO_BOOL = {
SqlStdOperatorTable.IS_NULL,
SqlStdOperatorTable.IS_NOT_NULL,
SqlStdOperatorTable.IS_UNKNOWN,
SqlStdOperatorTable.IS_NOT_UNKNOWN,
};
private static final SqlOperator[] COMPARABLE_TO_BOOL = {
SqlStdOperatorTable.EQUALS,
SqlStdOperatorTable.NOT_EQUALS,
SqlStdOperatorTable.GREATER_THAN,
SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
SqlStdOperatorTable.LESS_THAN,
SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
SqlStdOperatorTable.IS_DISTINCT_FROM,
SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
};
private static final SqlOperator[] BOOL_TO_BOOL_MULTI_ARG = {
SqlStdOperatorTable.OR,
SqlStdOperatorTable.AND,
SqlStdOperatorTable.COALESCE,
};
private static final SqlOperator[] ANY_SAME_TYPE_MULTI_ARG = {
SqlStdOperatorTable.COALESCE,
};
private static final SqlOperator[] NUMERIC_TO_NUMERIC = {
SqlStdOperatorTable.PLUS,
SqlStdOperatorTable.MINUS,
SqlStdOperatorTable.MULTIPLY,
// Divide by zero is not allowed, so we do not generate divide
// SqlStdOperatorTable.DIVIDE,
// SqlStdOperatorTable.DIVIDE_INTEGER,
};
private static final SqlOperator[] UNARY_NUMERIC = {
SqlStdOperatorTable.UNARY_MINUS,
SqlStdOperatorTable.UNARY_PLUS,
};
private static final int[] INT_VALUES = {-1, 0, 1, 100500};
private final RelDataType intType;
private final RelDataType nullableIntType;
/**
* Generates randomized {@link RexNode}.
*
* @param rexBuilder builder to be used to create nodes
* @param typeFactory type factory
*/
public RexFuzzer(RexBuilder rexBuilder, JavaTypeFactory typeFactory) {
setUp();
this.rexBuilder = rexBuilder;
this.typeFactory = typeFactory;
intType = typeFactory.createSqlType(SqlTypeName.INTEGER);
nullableIntType = typeFactory.createTypeWithNullability(intType, true);
}
public RexNode getExpression(Random r, int depth) {
return getComparableExpression(r, depth);
}
private RexNode fuzzOperator(Random r, SqlOperator[] operators, RexNode... args) {
return rexBuilder.makeCall(operators[r.nextInt(operators.length)], args);
}
private RexNode fuzzOperator(Random r, SqlOperator[] operators, int length,
Function<Random, RexNode> factory) {
List<RexNode> args = new ArrayList<>(length);
for (int i = 0; i < length; i++) {
args.add(factory.apply(r));
}
return rexBuilder.makeCall(operators[r.nextInt(operators.length)], args);
}
public RexNode getComparableExpression(Random r, int depth) {
int v = r.nextInt(2);
switch (v) {
case 0:
return getBoolExpression(r, depth);
case 1:
return getIntExpression(r, depth);
}
throw new AssertionError("should not reach here");
}
public RexNode getSimpleBool(Random r) {
int v = r.nextInt(2);
switch (v) {
case 0:
boolean nullable = r.nextBoolean();
int field = r.nextInt(MAX_VARS);
return nullable ? vBool(field) : vBoolNotNull(field);
case 1:
return r.nextBoolean() ? trueLiteral : falseLiteral;
case 2:
return nullBool;
}
throw new AssertionError("should not reach here");
}
public RexNode getBoolExpression(Random r, int depth) {
int v = depth <= 0 ? 0 : r.nextInt(8);
switch (v) {
case 0:
return getSimpleBool(r);
case 1:
return fuzzOperator(r, ANY_TO_BOOL, getExpression(r, depth - 1));
case 2:
return fuzzOperator(r, BOOL_TO_BOOL, getBoolExpression(r, depth - 1));
case 3:
return fuzzOperator(r, COMPARABLE_TO_BOOL, getBoolExpression(r, depth - 1),
getBoolExpression(r, depth - 1));
case 4:
return fuzzOperator(r, COMPARABLE_TO_BOOL, getIntExpression(r, depth - 1),
getIntExpression(r, depth - 1));
case 5:
return fuzzOperator(r, BOOL_TO_BOOL_MULTI_ARG, r.nextInt(3) + 2,
x -> getBoolExpression(x, depth - 1));
case 6:
return fuzzCase(r, depth - 1,
x -> getBoolExpression(x, depth - 1));
case 7:
return fuzzSearch(r, getIntExpression(r, depth - 1));
}
throw new AssertionError("should not reach here");
}
public RexNode getSimpleInt(Random r) {
int v = r.nextInt(3);
switch (v) {
case 0:
boolean nullable = r.nextBoolean();
int field = r.nextInt(MAX_VARS);
return nullable ? vInt(field) : vIntNotNull(field);
case 1: {
int i = r.nextInt(INT_VALUES.length + 1);
int val = i < INT_VALUES.length ? INT_VALUES[i] : r.nextInt();
return rexBuilder.makeLiteral(val,
r.nextBoolean() ? intType : nullableIntType);
}
case 2:
return nullInt;
}
throw new AssertionError("should not reach here");
}
public RexNode getIntExpression(Random r, int depth) {
int v = depth <= 0 ? 0 : r.nextInt(5);
switch (v) {
case 0:
return getSimpleInt(r);
case 1:
return fuzzOperator(r, UNARY_NUMERIC, getIntExpression(r, depth - 1));
case 2:
return fuzzOperator(r, NUMERIC_TO_NUMERIC, getIntExpression(r, depth - 1),
getIntExpression(r, depth - 1));
case 3:
return fuzzOperator(r, ANY_SAME_TYPE_MULTI_ARG, r.nextInt(3) + 2,
x -> getIntExpression(x, depth - 1));
case 4:
return fuzzCase(r, depth - 1,
x -> getIntExpression(x, depth - 1));
}
throw new AssertionError("should not reach here");
}
public RexNode fuzzCase(Random r, int depth, Function<Random, RexNode> resultFactory) {
boolean caseArgWhen = r.nextBoolean();
int caseBranches = 1 + (depth <= 0 ? 0 : r.nextInt(3));
List<RexNode> args = new ArrayList<>(caseBranches + 1);
Function<Random, RexNode> exprFactory;
if (!caseArgWhen) {
exprFactory = x -> getBoolExpression(x, depth - 1);
} else {
int type = r.nextInt(2);
RexNode arg;
Function<Random, RexNode> baseExprFactory;
switch (type) {
case 0:
baseExprFactory = x -> getBoolExpression(x, depth - 1);
break;
case 1:
baseExprFactory = x -> getIntExpression(x, depth - 1);
break;
default:
throw new AssertionError("should not reach here: " + type);
}
arg = baseExprFactory.apply(r);
// emulate case when arg=2 then .. when arg=4 then ...
exprFactory = x -> eq(arg, baseExprFactory.apply(x));
}
for (int i = 0; i < caseBranches; i++) {
args.add(exprFactory.apply(r)); // when
args.add(resultFactory.apply(r)); // then
}
args.add(resultFactory.apply(r)); // else
return case_(args);
}
@SuppressWarnings("UnstableApiUsage")
public RexNode fuzzSearch(Random r, RexNode intExpression) {
final RangeSet<BigDecimal> rangeSet = TreeRangeSet.create();
final Generator<BigDecimal> integerGenerator = RexFuzzer::fuzzInt;
final Generator<RexUnknownAs> unknownGenerator =
enumGenerator(RexUnknownAs.class);
int i = 0;
for (;;) {
rangeSet.add(fuzzRange(r, integerGenerator));
if (r.nextBoolean() || i++ == 8) {
break;
}
}
final Sarg<BigDecimal> sarg =
Sarg.of(unknownGenerator.generate(r), rangeSet);
return rexBuilder.makeCall(SqlStdOperatorTable.SEARCH, intExpression,
rexBuilder.makeSearchArgumentLiteral(sarg, intExpression.getType()));
}
private static <T extends Enum<T>> Generator<T> enumGenerator(
Class<T> enumClass) {
final T[] enumConstants = enumClass.getEnumConstants();
return r -> enumConstants[r.nextInt(enumConstants.length)];
}
<T extends Comparable<T>> Range<T> fuzzRange(Random r,
Generator<T> generator) {
final Map.Entry<T, T> pair;
switch (r.nextInt(10)) {
case 0:
return Range.all();
case 1:
return Range.atLeast(generator.generate(r));
case 2:
return Range.atMost(generator.generate(r));
case 3:
return Range.greaterThan(generator.generate(r));
case 4:
return Range.lessThan(generator.generate(r));
case 5:
return Range.singleton(generator.generate(r));
case 6:
pair = orderedPair(r, false, generator);
return Range.closed(pair.getKey(), pair.getValue());
case 7:
pair = orderedPair(r, false, generator);
return Range.closedOpen(pair.getKey(), pair.getValue());
case 8:
pair = orderedPair(r, false, generator);
return Range.openClosed(pair.getKey(), pair.getValue());
case 9:
pair = orderedPair(r, true, generator);
return Range.open(pair.getKey(), pair.getValue());
default:
throw new AssertionError();
}
}
/** Generates a pair of values, the first being less than or equal to the
* second. */
static <T extends Comparable<T>> Pair<T, T> orderedPair(Random r,
boolean strict, Generator<T> generator) {
for (;;) {
final T v0 = generator.generate(r);
final T v1 = generator.generate(r);
int c = v0.compareTo(v1);
if (strict && c == 0) {
continue;
}
return c <= 0 ? Pair.of(v0, v1) : Pair.of(v1, v0);
}
}
/** Generates an integer between -5 and 10 (inclusive). All values are equally
* likely. */
static BigDecimal fuzzInt(Random r) {
return BigDecimal.valueOf(r.nextInt(16) - 5);
}
/** Generates values of a particular type, given a random-number generator.
*
* @param <T> Value type */
interface Generator<T> {
T generate(Random r);
}
}