Constant-fold numeric arithmetic so INSERT values like 1+1 are accepted
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java index 151d981..d7cf419 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java
@@ -267,13 +267,18 @@ "701: Input time format aa error. Input like yyyy-MM-dd HH:mm:ss, yyyy-MM-ddTHH:mm:ss or refer to user document for more info.", e.getMessage()); } - try { - session.executeNonQueryStatement("insert into wrong_time values(1+1,'bb','cc','dd')"); - fail("No exception thrown"); - } catch (StatementExecutionException e) { - assertEquals( - "701: Insert expression must be constant after constant folding: (1 + 1) (folded to (1 + 1))", - e.getMessage()); + session.executeNonQueryStatement("insert into wrong_time values(1+1,'bb','cc','dd')"); + // The time expression 1+1 is constant-folded, so the row is stored at time 2. + try (SessionDataSet foldedRows = + session.executeQueryStatement("select * from wrong_time where time = 2")) { + assertTrue(foldedRows.hasNext()); + RowRecord row = foldedRows.next(); + assertEquals(2L, row.getFields().get(0).getLongV()); + String[] expectedTexts = {"bb", "cc", "dd"}; + for (int i = 0; i < expectedTexts.length; i++) { + assertEquals(expectedTexts[i], row.getFields().get(i + 1).getBinaryV().toString()); + } + assertFalse(foldedRows.hasNext()); } try { session.executeNonQueryStatement("insert into wrong_time values(1.0,'bb','cc','dd')");
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrExpressionInterpreter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrExpressionInterpreter.java index 9e2b2ed..3d8d1e9 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrExpressionInterpreter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrExpressionInterpreter.java
@@ -73,7 +73,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.instanceOf; -import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; @@ -84,6 +83,11 @@ import static org.apache.iotdb.commons.queryengine.plan.relational.type.TypeSignatureTranslator.toTypeSignature; import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.DeterminismEvaluator.isDeterministic; import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.isEffectivelyLiteral; +import static org.apache.tsfile.read.common.type.DoubleType.DOUBLE; +import static org.apache.tsfile.read.common.type.FloatType.FLOAT; +import static org.apache.tsfile.read.common.type.IntType.INT32; +import static org.apache.tsfile.read.common.type.LongType.INT64; +import static org.apache.tsfile.read.common.type.TimestampType.TIMESTAMP; public class IrExpressionInterpreter { @@ -515,23 +519,13 @@ } if (node.getSign() == PLUS) { - return value; + return normalizeArithmeticValue(value, type(node)); } else { - try { - Expression valueExpression = toExpression(value, type(node.getValue())); - if (valueExpression instanceof ArithmeticUnaryExpression - && ((ArithmeticUnaryExpression) valueExpression).getSign().equals(MINUS)) { - return ((ArithmeticUnaryExpression) valueExpression).getValue(); - } - return new ArithmeticUnaryExpression(MINUS, valueExpression); - // TODO use the following after we implement InterpretedFunctionInvoker - // return invokeOperator(OperatorType.NEGATION, types(node.getValue()), - // Collections.singletonList(value)); - } catch (Throwable throwable) { - throwIfInstanceOf(throwable, RuntimeException.class); - throwIfInstanceOf(throwable, Error.class); - throw new RuntimeException(throwable.getMessage(), 
throwable); + Type returnType = type(node); + if (isSupportedArithmeticUnaryType(returnType)) { // other types keep an unfolded negation + return evaluateArithmeticNegation(value, returnType); } + return new ArithmeticUnaryExpression(MINUS, toExpression(value, type(node.getValue()))); } } @@ -551,16 +545,21 @@ node.getOperator(), toExpression(left, type(node.getLeft())), toExpression(right, type(node.getRight()))); - } else { - return new ArithmeticBinaryExpression( - node.getOperator(), - toExpression(left, type(node.getLeft())), - toExpression(right, type(node.getRight()))); - // TODO use the following after we implement InterpretedFunctionInvoker - // return invokeOperator(OperatorType.valueOf(node.getOperator().name()), - // types(node.getLeft(), node.getRight()), - // ImmutableList.of(left, right)); } + + Type leftType = type(node.getLeft()); + Type rightType = type(node.getRight()); + Type returnType = type(node); + if (isSupportedArithmeticBinaryType(leftType) + && isSupportedArithmeticBinaryType(rightType) + && isSupportedArithmeticBinaryType(returnType)) { // all three types must be numeric + return evaluateArithmeticBinary(node.getOperator(), left, leftType, right, returnType); + } + + return new ArithmeticBinaryExpression( + node.getOperator(), + toExpression(left, type(node.getLeft())), + toExpression(right, type(node.getRight()))); } @Override @@ -644,67 +643,12 @@ toExpression(left, type(leftExpression)), toExpression(right, type(rightExpression))); } else { - if (!(left instanceof Number) || !(right instanceof Number)) { - throw new IllegalArgumentException( - DataNodeQueryMessages.BOTH_OBJECT_MUST_BE_TYPE_OF_NUMBER); + if (left instanceof Number && right instanceof Number) { + return evaluateNumericComparison(operator, (Number) left, (Number) right); } - if (left instanceof Integer && right instanceof Integer) { - Integer leftNum = (Integer) left; - Integer rightNum = (Integer) right; - if (operator == ComparisonExpression.Operator.LESS_THAN) { - return leftNum < rightNum; - } else if (operator == 
ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) { - return leftNum <= rightNum; - } else if (operator == ComparisonExpression.Operator.EQUAL) { - return leftNum.equals(rightNum); - } - } - - if (left instanceof Long && right instanceof Long) { - Long leftNum = (Long) left; - Long rightNum = (Long) right; - if (operator == ComparisonExpression.Operator.LESS_THAN) { - return leftNum < rightNum; - } else if (operator == ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) { - return leftNum <= rightNum; - } else if (operator == ComparisonExpression.Operator.EQUAL) { - return leftNum.equals(rightNum); - } - } - - if (left instanceof Float && right instanceof Float) { - Float leftNum = (Float) left; - Float rightNum = (Float) right; - if (operator == ComparisonExpression.Operator.LESS_THAN) { - return leftNum < rightNum; - } else if (operator == ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) { - return leftNum <= rightNum; - } else if (operator == ComparisonExpression.Operator.EQUAL) { - return leftNum.equals(rightNum); - } - } - - if (left instanceof Double && right instanceof Double) { - Double leftNum = (Double) left; - Double rightNum = (Double) right; - if (operator == ComparisonExpression.Operator.LESS_THAN) { - return leftNum < rightNum; - } else if (operator == ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) { - return leftNum <= rightNum; - } else if (operator == ComparisonExpression.Operator.EQUAL) { - return leftNum.equals(rightNum); - } - } - - return new ComparisonExpression( - operator, - toExpression(left, type(leftExpression)), - toExpression(right, type(rightExpression))); - // TODO use the following after we implement InterpretedFunctionInvoker - // return invokeOperator(OperatorType.valueOf(operator.name()), types(leftExpression, - // rightExpression), - // ImmutableList.of(left, right)); + throw new IllegalArgumentException( + DataNodeQueryMessages.BOTH_OBJECT_MUST_BE_TYPE_OF_NUMBER); } } @@ -966,6 +910,203 @@ return
values.stream().anyMatch(instanceOf(Expression.class)); } + private boolean isSupportedArithmeticBinaryType(Type type) { + return type.equals(INT32) || type.equals(INT64) || type.equals(FLOAT) || type.equals(DOUBLE); + } + + private boolean isSupportedArithmeticUnaryType(Type type) { + return isSupportedArithmeticBinaryType(type) || type.equals(TIMESTAMP); + } + + private Object evaluateArithmeticNegation(Object value, Type returnType) { + if (returnType.equals(INT32)) { + int intValue = toInt(value); + if (intValue == Integer.MIN_VALUE) { // negating MIN_VALUE overflows int + throw new SemanticException(String.format("The %s is out of range of int.", intValue)); + } + return -intValue; + } + if (returnType.equals(INT64) || returnType.equals(TIMESTAMP)) { + long longValue = toLong(value); + if (longValue == Long.MIN_VALUE) { // negating MIN_VALUE overflows long + throw new SemanticException(String.format("The %s is out of range of long.", longValue)); + } + return -longValue; + } + if (returnType.equals(FLOAT)) { + return -toFloat(value); + } + if (returnType.equals(DOUBLE)) { + return -toDouble(value); + } + throw new SemanticException(DataNodeQueryMessages.UNKNOWN_TYPE_2 + returnType); + } + + private Object evaluateArithmeticBinary( + ArithmeticBinaryExpression.Operator operator, + Object left, + Type leftType, + Object right, + Type returnType) { + if (returnType.equals(INT32)) { + return evaluateIntArithmetic(operator, toInt(left), toInt(right)); + } + if (returnType.equals(INT64)) { // leftType is only used by the DIVIDE guard + return evaluateLongArithmetic(operator, toLong(left), leftType, toLong(right)); + } + if (returnType.equals(FLOAT)) { + return evaluateFloatArithmetic(operator, toFloat(left), toFloat(right)); + } + if (returnType.equals(DOUBLE)) { + return evaluateDoubleArithmetic(operator, toDouble(left), toDouble(right)); + } + throw new SemanticException(DataNodeQueryMessages.UNKNOWN_TYPE_2 + returnType); + } + + private int evaluateIntArithmetic( + 
ArithmeticBinaryExpression.Operator operator, int left, int right) { + try { + switch (operator) { + case ADD: +
return Math.addExact(left, right); + case SUBTRACT: + return Math.subtractExact(left, right); + case MULTIPLY: + return Math.multiplyExact(left, right); + case DIVIDE: // MIN_VALUE / -1 wraps silently in Java + if (left == Integer.MIN_VALUE && right == -1) { + throw new SemanticException(String.format("int overflow: %s / %s", left, right)); + } + return left / right; + case MODULUS: + return left % right; + default: + throw new UnsupportedOperationException( + DataNodeQueryMessages.UNSUPPORTED_EXPRESSION + operator); + } + } catch (ArithmeticException e) { // thrown on *Exact overflow or a zero divisor + throw new SemanticException(e.getMessage()); + } + } + + private long evaluateLongArithmetic( + ArithmeticBinaryExpression.Operator operator, long left, Type leftType, long right) { + try { + switch (operator) { + case ADD: + return Math.addExact(left, right); + case SUBTRACT: + return Math.subtractExact(left, right); + case MULTIPLY: + return Math.multiplyExact(left, right); + case DIVIDE: // NOTE(review): INT32 MIN_VALUE / -1 fits in long; confirm guard + if ((leftType.equals(INT32) && left == Integer.MIN_VALUE && right == -1) + || (leftType.equals(INT64) && left == Long.MIN_VALUE && right == -1)) { + throw new SemanticException(String.format("long overflow: %s / %s", left, right)); + } + return left / right; + case MODULUS: + return left % right; + default: + throw new UnsupportedOperationException( + DataNodeQueryMessages.UNSUPPORTED_EXPRESSION + operator); + } + } catch (ArithmeticException e) { // thrown on *Exact overflow or a zero divisor + throw new SemanticException(e.getMessage()); + } + } + + private float evaluateFloatArithmetic( + ArithmeticBinaryExpression.Operator operator, float left, float right) { + switch (operator) { + case ADD: + return left + right; + case SUBTRACT: + return left - right; + case MULTIPLY: + return left * right; + case DIVIDE: + return left / right; + case MODULUS: + return left % right; + default: + throw new 
UnsupportedOperationException( + DataNodeQueryMessages.UNSUPPORTED_EXPRESSION + operator); + } + } + + private double evaluateDoubleArithmetic( + ArithmeticBinaryExpression.Operator operator, double left, double
right) { + switch (operator) { + case ADD: + return left + right; + case SUBTRACT: + return left - right; + case MULTIPLY: + return left * right; + case DIVIDE: + return left / right; + case MODULUS: + return left % right; + default: + throw new UnsupportedOperationException( + DataNodeQueryMessages.UNSUPPORTED_EXPRESSION + operator); + } + } + + private Object normalizeArithmeticValue(Object value, Type type) { + if (type.equals(INT32)) { + return toInt(value); + } + if (type.equals(INT64) || type.equals(TIMESTAMP)) { + return toLong(value); + } + if (type.equals(FLOAT)) { + return toFloat(value); + } + if (type.equals(DOUBLE)) { + return toDouble(value); + } + return value; // other types pass through unchanged + } + + private Boolean evaluateNumericComparison( + ComparisonExpression.Operator operator, Number left, Number right) { + int comparison = // NOTE(review): confirm this matches runtime comparison semantics + (left instanceof Float // Double.compare: NaN = NaN and sorts highest; -0.0 < 0.0 + || left instanceof Double + || right instanceof Float + || right instanceof Double) + ? Double.compare(left.doubleValue(), right.doubleValue()) + : Long.compare(left.longValue(), right.longValue()); + switch (operator) { + case LESS_THAN: + return comparison < 0; + case LESS_THAN_OR_EQUAL: + return comparison <= 0; + case EQUAL: + return comparison == 0; + default: + return null; // NOTE(review): assumes only EQUAL/LT/LTE reach here; verify + } + } + + private int toInt(Object value) { + return ((Number) value).intValue(); + } + + private long toLong(Object value) { + return ((Number) value).longValue(); + } + + private float toFloat(Object value) { + return ((Number) value).floatValue(); + } + + private double toDouble(Object value) { + return ((Number) value).doubleValue(); + } + private Object invokeOperator( OperatorType operatorType, List<? extends Type> argumentTypes,