blob: 8a4012e1eea58d81ebe2f227b8d9101ea2a783b0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rex;
import org.apache.calcite.avatica.util.DateTimeUtils;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.avatica.util.TimeUnitRange;
import org.apache.calcite.rel.metadata.NullSentinel;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.NlsString;
import org.apache.calcite.util.RangeSets;
import org.apache.calcite.util.Sarg;
import org.apache.calcite.util.TimeString;
import org.apache.calcite.util.TimestampString;
import org.apache.calcite.util.Util;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.RangeSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.function.IntPredicate;
/**
* Evaluates {@link RexNode} expressions.
*
* <p>Caveats:
* <ul>
* <li>It uses interpretation, so it is not very efficient.
* <li>It is intended for testing, so does not cover very many functions and
* operators. (Feel free to contribute more!)
* <li>It is not well tested.
* </ul>
*/
public class RexInterpreter implements RexVisitor<Comparable> {
private static final NullSentinel N = NullSentinel.INSTANCE;
public static final EnumSet<SqlKind> SUPPORTED_SQL_KIND =
EnumSet.of(SqlKind.IS_NOT_DISTINCT_FROM, SqlKind.EQUALS, SqlKind.IS_DISTINCT_FROM,
SqlKind.NOT_EQUALS, SqlKind.GREATER_THAN, SqlKind.GREATER_THAN_OR_EQUAL,
SqlKind.LESS_THAN, SqlKind.LESS_THAN_OR_EQUAL, SqlKind.AND, SqlKind.OR,
SqlKind.NOT, SqlKind.CASE, SqlKind.IS_TRUE, SqlKind.IS_NOT_TRUE,
SqlKind.IS_FALSE, SqlKind.IS_NOT_FALSE, SqlKind.PLUS_PREFIX,
SqlKind.MINUS_PREFIX, SqlKind.PLUS, SqlKind.MINUS, SqlKind.TIMES,
SqlKind.DIVIDE, SqlKind.COALESCE, SqlKind.CEIL,
SqlKind.FLOOR, SqlKind.EXTRACT);
private final Map<RexNode, Comparable> environment;
/** Creates an interpreter.
*
* @param environment Values of certain expressions (usually
* {@link RexInputRef}s)
*/
private RexInterpreter(Map<RexNode, Comparable> environment) {
this.environment = ImmutableMap.copyOf(environment);
}
/** Evaluates an expression in an environment. */
public static @Nullable Comparable evaluate(RexNode e, Map<RexNode, Comparable> map) {
final Comparable v = e.accept(new RexInterpreter(map));
if (false) {
System.out.println("evaluate " + e + " on " + map + " returns " + v);
}
return v;
}
private static IllegalArgumentException unbound(RexNode e) {
return new IllegalArgumentException("unbound: " + e);
}
private Comparable getOrUnbound(RexNode e) {
final Comparable comparable = environment.get(e);
if (comparable != null) {
return comparable;
}
throw unbound(e);
}
@Override public Comparable visitInputRef(RexInputRef inputRef) {
return getOrUnbound(inputRef);
}
@Override public Comparable visitLocalRef(RexLocalRef localRef) {
throw unbound(localRef);
}
@Override public Comparable visitLiteral(RexLiteral literal) {
return Util.first(literal.getValue4(), N);
}
@Override public Comparable visitOver(RexOver over) {
throw unbound(over);
}
@Override public Comparable visitCorrelVariable(RexCorrelVariable correlVariable) {
return getOrUnbound(correlVariable);
}
@Override public Comparable visitDynamicParam(RexDynamicParam dynamicParam) {
return getOrUnbound(dynamicParam);
}
@Override public Comparable visitRangeRef(RexRangeRef rangeRef) {
throw unbound(rangeRef);
}
@Override public Comparable visitFieldAccess(RexFieldAccess fieldAccess) {
return getOrUnbound(fieldAccess);
}
@Override public Comparable visitSubQuery(RexSubQuery subQuery) {
throw unbound(subQuery);
}
@Override public Comparable visitTableInputRef(RexTableInputRef fieldRef) {
throw unbound(fieldRef);
}
@Override public Comparable visitPatternFieldRef(RexPatternFieldRef fieldRef) {
throw unbound(fieldRef);
}
@Override public Comparable visitCall(RexCall call) {
final List<Comparable> values = visitList(call.operands);
switch (call.getKind()) {
case IS_NOT_DISTINCT_FROM:
if (containsNull(values)) {
return values.get(0).equals(values.get(1));
}
// falls through EQUALS
case EQUALS:
return compare(values, c -> c == 0);
case IS_DISTINCT_FROM:
if (containsNull(values)) {
return !values.get(0).equals(values.get(1));
}
// falls through NOT_EQUALS
case NOT_EQUALS:
return compare(values, c -> c != 0);
case GREATER_THAN:
return compare(values, c -> c > 0);
case GREATER_THAN_OR_EQUAL:
return compare(values, c -> c >= 0);
case LESS_THAN:
return compare(values, c -> c < 0);
case LESS_THAN_OR_EQUAL:
return compare(values, c -> c <= 0);
case AND:
return values.stream().map(Truthy::of).min(Comparator.naturalOrder())
.get().toComparable();
case OR:
return values.stream().map(Truthy::of).max(Comparator.naturalOrder())
.get().toComparable();
case NOT:
return not(values.get(0));
case CASE:
return case_(values);
case IS_TRUE:
return values.get(0).equals(true);
case IS_NOT_TRUE:
return !values.get(0).equals(true);
case IS_NULL:
return values.get(0).equals(N);
case IS_NOT_NULL:
return !values.get(0).equals(N);
case IS_FALSE:
return values.get(0).equals(false);
case IS_NOT_FALSE:
return !values.get(0).equals(false);
case PLUS_PREFIX:
return values.get(0);
case MINUS_PREFIX:
return containsNull(values) ? N
: number(values.get(0)).negate();
case PLUS:
return containsNull(values) ? N
: number(values.get(0)).add(number(values.get(1)));
case MINUS:
return containsNull(values) ? N
: number(values.get(0)).subtract(number(values.get(1)));
case TIMES:
return containsNull(values) ? N
: number(values.get(0)).multiply(number(values.get(1)));
case DIVIDE:
return containsNull(values) ? N
: number(values.get(0)).divide(number(values.get(1)));
case CAST:
return cast(values);
case COALESCE:
return coalesce(values);
case CEIL:
case FLOOR:
return ceil(call, values);
case EXTRACT:
return extract(values);
case LIKE:
return like(values);
case SEARCH:
return search(call.operands.get(1).getType().getSqlTypeName(), values);
default:
throw unbound(call);
}
}
private static Comparable extract(List<Comparable> values) {
final Comparable v = values.get(1);
if (v == N) {
return N;
}
final TimeUnitRange timeUnitRange = (TimeUnitRange) values.get(0);
final int v2;
if (v instanceof Long) {
// TIMESTAMP
v2 = (int) (((Long) v) / TimeUnit.DAY.multiplier.longValue());
} else {
// DATE
v2 = (Integer) v;
}
return DateTimeUtils.unixDateExtract(timeUnitRange, v2);
}
private static Comparable like(List<Comparable> values) {
if (containsNull(values)) {
return N;
}
final NlsString value = (NlsString) values.get(0);
final NlsString pattern = (NlsString) values.get(1);
switch (values.size()) {
case 2:
return SqlFunctions.like(value.getValue(), pattern.getValue());
case 3:
final NlsString escape = (NlsString) values.get(2);
return SqlFunctions.like(value.getValue(), pattern.getValue(),
escape.getValue());
default:
throw new AssertionError();
}
}
@SuppressWarnings({"BetaApi", "rawtypes", "unchecked", "UnstableApiUsage"})
private static Comparable search(SqlTypeName typeName, List<Comparable> values) {
final Comparable value = values.get(0);
final Sarg sarg = (Sarg) values.get(1);
if (value == N) {
switch (sarg.nullAs) {
case FALSE:
return false;
case TRUE:
return true;
default:
return N;
}
}
return translate(sarg.rangeSet, typeName).contains(value);
}
/** Translates the values in a RangeSet from literal format to runtime format.
* For example the DATE SQL type uses DateString for literals and Integer at
* runtime. */
@SuppressWarnings({"BetaApi", "rawtypes", "unchecked", "UnstableApiUsage"})
private static RangeSet translate(RangeSet rangeSet, SqlTypeName typeName) {
switch (typeName) {
case DATE:
return RangeSets.copy(rangeSet, DateString::getDaysSinceEpoch);
case TIME:
return RangeSets.copy(rangeSet, TimeString::getMillisOfDay);
case TIMESTAMP:
return RangeSets.copy(rangeSet, TimestampString::getMillisSinceEpoch);
default:
return rangeSet;
}
}
private static Comparable coalesce(List<Comparable> values) {
for (Comparable value : values) {
if (value != N) {
return value;
}
}
return N;
}
private static Comparable ceil(RexCall call, List<Comparable> values) {
if (values.get(0) == N) {
return N;
}
final Long v = (Long) values.get(0);
final TimeUnitRange unit = (TimeUnitRange) values.get(1);
switch (unit) {
case YEAR:
case MONTH:
switch (call.getKind()) {
case FLOOR:
return DateTimeUtils.unixTimestampFloor(unit, v);
default:
return DateTimeUtils.unixTimestampCeil(unit, v);
}
default:
break;
}
final TimeUnitRange subUnit = subUnit(unit);
for (long v2 = v;;) {
final int e = DateTimeUtils.unixTimestampExtract(subUnit, v2);
if (e == 0) {
return v2;
}
v2 -= unit.startUnit.multiplier.longValue();
}
}
private static TimeUnitRange subUnit(TimeUnitRange unit) {
switch (unit) {
case QUARTER:
return TimeUnitRange.MONTH;
default:
return TimeUnitRange.DAY;
}
}
private static Comparable cast(List<Comparable> values) {
if (values.get(0) == N) {
return N;
}
return values.get(0);
}
private static Comparable not(Comparable value) {
if (value.equals(true)) {
return false;
} else if (value.equals(false)) {
return true;
} else {
return N;
}
}
private static Comparable case_(List<Comparable> values) {
final int size;
final Comparable elseValue;
if (values.size() % 2 == 0) {
size = values.size();
elseValue = N;
} else {
size = values.size() - 1;
elseValue = Util.last(values);
}
for (int i = 0; i < size; i += 2) {
if (values.get(i).equals(true)) {
return values.get(i + 1);
}
}
return elseValue;
}
private static BigDecimal number(Comparable comparable) {
return comparable instanceof BigDecimal
? (BigDecimal) comparable
: comparable instanceof BigInteger
? new BigDecimal((BigInteger) comparable)
: comparable instanceof Long
|| comparable instanceof Integer
|| comparable instanceof Short
? new BigDecimal(((Number) comparable).longValue())
: new BigDecimal(((Number) comparable).doubleValue());
}
private static Comparable compare(List<Comparable> values, IntPredicate p) {
if (containsNull(values)) {
return N;
}
Comparable v0 = values.get(0);
Comparable v1 = values.get(1);
if (v0 instanceof Number && v1 instanceof NlsString) {
try {
v1 = new BigDecimal(((NlsString) v1).getValue());
} catch (NumberFormatException e) {
return false;
}
}
if (v1 instanceof Number && v0 instanceof NlsString) {
try {
v0 = new BigDecimal(((NlsString) v0).getValue());
} catch (NumberFormatException e) {
return false;
}
}
if (v0 instanceof Number) {
v0 = number(v0);
}
if (v1 instanceof Number) {
v1 = number(v1);
}
//noinspection unchecked
final int c = v0.compareTo(v1);
return p.test(c);
}
private static boolean containsNull(List<Comparable> values) {
for (Comparable value : values) {
if (value == N) {
return true;
}
}
return false;
}
/** An enum that wraps boolean and unknown values and makes them
* comparable. */
enum Truthy {
// Order is important; AND returns the min, OR returns the max
FALSE, UNKNOWN, TRUE;
static Truthy of(Comparable c) {
return c.equals(true) ? TRUE : c.equals(false) ? FALSE : UNKNOWN;
}
Comparable toComparable() {
switch (this) {
case TRUE: return true;
case FALSE: return false;
case UNKNOWN: return N;
default:
throw new AssertionError();
}
}
}
}