[CALCITE-6389] RexBuilder.removeCastFromLiteral does not preserve semantics for some types of literal

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
index e49788c..7886c53 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
@@ -268,22 +268,29 @@
   /**
    * Used for safe operators that return null if an exception is thrown.
    */
-  private static Expression expressionHandlingSafe(Expression body, boolean safe) {
-    return safe ? safeExpression(body) : body;
+  private Expression expressionHandlingSafe(
+      Expression body, boolean safe, RelDataType targetType) {
+    return safe ? safeExpression(body, targetType) : body;
   }
 
-  private static Expression safeExpression(Expression body) {
+  private Expression safeExpression(Expression body, RelDataType targetType) {
     final ParameterExpression e_ =
         Expressions.parameter(Exception.class, new BlockBuilder().newName("e"));
 
-    return Expressions.call(
-        Expressions.lambda(
-            Expressions.block(
-                Expressions.tryCatch(
-                    Expressions.return_(null, body),
-                Expressions.catch_(e_,
-                    Expressions.return_(null, constant(null)))))),
-        BuiltInMethod.FUNCTION0_APPLY.method);
+    // The type received for the targetType is never nullable.
+    // But safe casts may return null
+    RelDataType nullableTargetType = typeFactory.createTypeWithNullability(targetType, true);
+    Expression result =
+        Expressions.call(
+            Expressions.lambda(
+                Expressions.block(
+                    Expressions.tryCatch(
+                        Expressions.return_(null, body),
+                        Expressions.catch_(e_,
+                            Expressions.return_(null, constant(null)))))),
+            BuiltInMethod.FUNCTION0_APPLY.method);
+    // FUNCTION0 always returns Object, so we need a cast to the target type
+    return EnumUtils.convert(result, typeFactory.getJavaClass(nullableTargetType));
   }
 
   Expression translateCast(
@@ -294,7 +301,7 @@
       ConstantExpression format) {
     Expression convert = getConvertExpression(sourceType, targetType, operand, format);
     Expression convert2 = checkExpressionPadTruncate(convert, sourceType, targetType);
-    Expression convert3 = expressionHandlingSafe(convert2, safe);
+    Expression convert3 = expressionHandlingSafe(convert2, safe, targetType);
     return scaleValue(sourceType, targetType, convert3);
   }
 
diff --git a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java
index 341759f..7a7f247 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java
@@ -694,7 +694,9 @@
       return false;
     }
     if (toType.getSqlTypeName() != fromTypeName
-        && SqlTypeFamily.DATETIME.getTypeNames().contains(fromTypeName)) {
+        && (SqlTypeFamily.DATETIME.getTypeNames().contains(fromTypeName)
+        || SqlTypeFamily.INTERVAL_DAY_TIME.getTypeNames().contains(fromTypeName)
+        || SqlTypeFamily.INTERVAL_YEAR_MONTH.getTypeNames().contains(fromTypeName))) {
       return false;
     }
     if (value instanceof NlsString) {
@@ -720,9 +722,10 @@
       }
     }
 
-    if (toType.getSqlTypeName() == SqlTypeName.DECIMAL) {
+    if (toType.getSqlTypeName() == SqlTypeName.DECIMAL
+        && fromTypeName.getFamily() == SqlTypeFamily.NUMERIC) {
       final BigDecimal decimalValue = (BigDecimal) value;
-      return SqlTypeUtil.isValidDecimalValue(decimalValue, toType);
+      return SqlTypeUtil.canBeRepresentedExactly(decimalValue, toType);
     }
 
     if (SqlTypeName.INT_TYPES.contains(sqlType)) {
@@ -731,17 +734,23 @@
       if (s != 0) {
         return false;
       }
-      long l = decimalValue.longValue();
-      switch (sqlType) {
-      case TINYINT:
-        return l >= Byte.MIN_VALUE && l <= Byte.MAX_VALUE;
-      case SMALLINT:
-        return l >= Short.MIN_VALUE && l <= Short.MAX_VALUE;
-      case INTEGER:
-        return l >= Integer.MIN_VALUE && l <= Integer.MAX_VALUE;
-      case BIGINT:
-      default:
-        return true;
+      try {
+        // will trigger ArithmeticException when the value
+        // cannot be represented exactly as a long
+        long l = decimalValue.longValueExact();
+        switch (sqlType) {
+        case TINYINT:
+          return l >= Byte.MIN_VALUE && l <= Byte.MAX_VALUE;
+        case SMALLINT:
+          return l >= Short.MIN_VALUE && l <= Short.MAX_VALUE;
+        case INTEGER:
+          return l >= Integer.MIN_VALUE && l <= Integer.MAX_VALUE;
+        case BIGINT:
+        default:
+          return true;
+        }
+      } catch (ArithmeticException ex) {
+        return false;
       }
     }
 
diff --git a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java
index 3e91b1b..7f0bc64 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java
@@ -51,6 +51,7 @@
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 import java.math.BigDecimal;
+import java.math.RoundingMode;
 import java.nio.charset.Charset;
 import java.util.AbstractList;
 import java.util.ArrayList;
@@ -1823,6 +1824,35 @@
   }
 
   /**
+   * Returns whether the decimal value can be represented without information loss
+   * using the specified type.
+   * For example, 1111.11
+   * - cannot be represented exactly using DECIMAL(3, 1) since it overflows.
+   * - cannot be represented exactly using DECIMAL(6, 3) since it overflows.
+   * - cannot be represented exactly using DECIMAL(6, 1) since it requires rounding.
+   * - can be represented exactly using DECIMAL(6, 2)
+   *
+   * @param value  A decimal value
+   * @param toType A DECIMAL type.
+   * @return whether the value is valid for the type
+   */
+  public static boolean canBeRepresentedExactly(@Nullable BigDecimal value, RelDataType toType) {
+    assert toType.getSqlTypeName() == SqlTypeName.DECIMAL;
+    if (value == null) {
+      return true;
+    }
+    value = value.stripTrailingZeros();
+    if (value.scale() < 0) {
+      // Negative scale, convert to 0 scale.
+      // Rounding mode is irrelevant, since value is integer
+      value = value.setScale(0, RoundingMode.DOWN);
+    }
+    final int intDigits = value.precision() - value.scale();
+    final int maxIntDigits = toType.getPrecision() - toType.getScale();
+    return (intDigits <= maxIntDigits) && (value.scale() <= toType.getScale());
+  }
+
+  /**
    * Returns whether the decimal value is valid for the type. For example, 1111.11 is not
    * valid for DECIMAL(3, 1) since it overflows.
    *
diff --git a/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java b/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java
index 9db5f8e..535232a 100644
--- a/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/rex/RexBuilderTest.java
@@ -69,6 +69,7 @@
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.hasToString;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -213,6 +214,58 @@
         hasToString("1969-07-21 02:56:15.102"));
   }
 
+  /** Test cases for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6389">[CALCITE-6389]
+   * RexBuilder.removeCastFromLiteral does not preserve semantics for some types of literal</a>. */
+  @Test void testRemoveCast() {
+    final RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
+    RexBuilder builder = new RexBuilder(typeFactory);
+
+    // Can remove cast of an integer to an integer
+    BigDecimal value = new BigDecimal(10);
+    RelDataType toType = builder.typeFactory.createSqlType(SqlTypeName.INTEGER);
+    assertTrue(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));
+
+    // Can remove cast from integer to decimal
+    toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL);
+    assertTrue(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));
+
+    // 250 is too large for a TINYINT
+    value = new BigDecimal(250);
+    toType = builder.typeFactory.createSqlType(SqlTypeName.TINYINT);
+    assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));
+
+    // 50 isn't too large for a TINYINT
+    value = new BigDecimal(50);
+    toType = builder.typeFactory.createSqlType(SqlTypeName.TINYINT);
+    assertTrue(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTEGER));
+
+    // 120.25 cannot be represented with precision 2 and scale 2 without loss
+    value = new BigDecimal("120.25");
+    toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 2, 2);
+    assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.DECIMAL));
+
+    // 120.25 cannot be represented with precision 5 and scale 1 without rounding
+    value = new BigDecimal("120.25");
+    toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 5, 1);
+    assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.DECIMAL));
+
+    // longmax + 1 cannot be represented as a long
+    value = new BigDecimal(Long.MAX_VALUE).add(BigDecimal.ONE);
+    toType = builder.typeFactory.createSqlType(SqlTypeName.BIGINT);
+    assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.DECIMAL));
+
+    // Cast to decimal of an INTERVAL '5' seconds cannot be removed
+    value = new BigDecimal("5");
+    toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 5, 1);
+    assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTERVAL_SECOND));
+
+    // Cast to decimal of an INTERVAL '5' minutes cannot be removed
+    value = new BigDecimal("5");
+    toType = builder.typeFactory.createSqlType(SqlTypeName.DECIMAL, 5, 1);
+    assertFalse(builder.canRemoveCastFromLiteral(toType, value, SqlTypeName.INTERVAL_MINUTE));
+  }
+
   @Test void testTimestampString() {
     final TimestampString ts = new TimestampString(1969, 7, 21, 2, 56, 15);
     assertThat(ts, hasToString("1969-07-21 02:56:15"));