[CALCITE-6283] Function ARRAY_APPEND with a NULL array argument crashes with NullPointerException

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 130e01f..003e66f 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1180,6 +1180,10 @@
   private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBinding) {
     final RelDataType arrayType = opBinding.collectOperandTypes().get(0);
     final RelDataType componentType = arrayType.getComponentType();
+    if (componentType == null) {
+      // NULL used for array.
+      return arrayType;
+    }
     final RelDataType elementType = opBinding.collectOperandTypes().get(1);
     RelDataType type =
         opBinding.getTypeFactory().leastRestrictive(
@@ -1196,7 +1200,7 @@
   public static final SqlFunction ARRAY_APPEND =
       SqlBasicFunction.create(SqlKind.ARRAY_APPEND,
           SqlLibraryOperators::arrayAppendPrependReturnType,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "EXISTS(array, lambda)" function (Spark); returns whether a predicate holds
    * for one or more elements in the array. */
@@ -1311,35 +1315,35 @@
   public static final SqlFunction ARRAY_MAX =
       SqlBasicFunction.create(SqlKind.ARRAY_MAX,
           ReturnTypes.TO_COLLECTION_ELEMENT_FORCE_NULLABLE,
-          OperandTypes.ARRAY);
+          OperandTypes.ARRAY_NONNULL);
 
   /** The "ARRAY_MAX(array)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_MIN =
       SqlBasicFunction.create(SqlKind.ARRAY_MIN,
           ReturnTypes.TO_COLLECTION_ELEMENT_FORCE_NULLABLE,
-          OperandTypes.ARRAY);
+          OperandTypes.ARRAY_NONNULL);
 
   /** The "ARRAY_POSITION(array, element)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_POSITION =
       SqlBasicFunction.create(SqlKind.ARRAY_POSITION,
           ReturnTypes.BIGINT_NULLABLE,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "ARRAY_PREPEND(array, element)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_PREPEND =
       SqlBasicFunction.create(SqlKind.ARRAY_PREPEND,
           SqlLibraryOperators::arrayAppendPrependReturnType,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "ARRAY_REMOVE(array, element)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_REMOVE =
       SqlBasicFunction.create(SqlKind.ARRAY_REMOVE,
           ReturnTypes.ARG0_NULLABLE,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "ARRAY_REPEAT(element, count)" function. */
   @LibraryOperator(libraries = {SPARK})
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java
index bed38a7..13eb7f3 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java
@@ -21,7 +21,6 @@
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlOperandCountRange;
 import org.apache.calcite.sql.SqlOperator;
-import org.apache.calcite.sql.SqlUtil;
 
 import com.google.common.collect.ImmutableList;
 
@@ -34,19 +33,14 @@
 public class ArrayElementOperandTypeChecker implements SqlOperandTypeChecker {
   //~ Instance fields --------------------------------------------------------
 
-  private final boolean allowNullCheck;
-  private final boolean allowCast;
+  private final boolean arrayMayBeNull;
+  private final boolean elementMayBeNull;
 
   //~ Constructors -----------------------------------------------------------
 
-  public ArrayElementOperandTypeChecker() {
-    this.allowNullCheck = false;
-    this.allowCast = false;
-  }
-
-  public ArrayElementOperandTypeChecker(boolean allowNullCheck, boolean allowCast) {
-    this.allowNullCheck = allowNullCheck;
-    this.allowCast = allowCast;
+  public ArrayElementOperandTypeChecker(boolean arrayMayBeNull, boolean elementMayBeNull) {
+    this.arrayMayBeNull = arrayMayBeNull;
+    this.elementMayBeNull = elementMayBeNull;
   }
 
   //~ Methods ----------------------------------------------------------------
@@ -54,20 +48,19 @@
   @Override public boolean checkOperandTypes(
       SqlCallBinding callBinding,
       boolean throwOnFailure) {
-    if (allowNullCheck) {
-      // no operand can be null for type-checking to succeed
-      for (SqlNode node : callBinding.operands()) {
-        if (SqlUtil.isNullLiteral(node, allowCast)) {
-          if (throwOnFailure) {
-            throw callBinding.getValidator().newValidationError(node, RESOURCE.nullIllegal());
-          } else {
-            return false;
-          }
-        }
+    final SqlNode op0 = callBinding.operand(0);
+    RelDataType arrayType = SqlTypeUtil.deriveType(callBinding, op0);
+
+    // Check if op0 is allowed to be NULL
+    if (!this.arrayMayBeNull && arrayType.getSqlTypeName() == SqlTypeName.NULL) {
+      if (throwOnFailure) {
+        throw callBinding.getValidator().newValidationError(op0, RESOURCE.nullIllegal());
+      } else {
+        return false;
       }
     }
 
-    final SqlNode op0 = callBinding.operand(0);
+    // Check that op0 is an ARRAY type
     if (!OperandTypes.ARRAY.checkSingleOperandType(
         callBinding,
         op0,
@@ -75,20 +68,29 @@
         throwOnFailure)) {
       return false;
     }
-
     RelDataType arrayComponentType =
         getComponentTypeOrThrow(SqlTypeUtil.deriveType(callBinding, op0));
+
     final SqlNode op1 = callBinding.operand(1);
-    RelDataType aryType1 = SqlTypeUtil.deriveType(callBinding, op1);
+    RelDataType elementType = SqlTypeUtil.deriveType(callBinding, op1);
+
+    // Check if elementType is allowed to be NULL
+    if (!this.elementMayBeNull && elementType.getSqlTypeName() == SqlTypeName.NULL) {
+      if (throwOnFailure) {
+        throw callBinding.getValidator().newValidationError(op1, RESOURCE.nullIllegal());
+      } else {
+        return false;
+      }
+    }
 
     RelDataType biggest =
         callBinding.getTypeFactory().leastRestrictive(
-            ImmutableList.of(arrayComponentType, aryType1));
+            ImmutableList.of(arrayComponentType, elementType));
     if (biggest == null) {
       if (throwOnFailure) {
         throw callBinding.newError(
             RESOURCE.typeNotComparable(
-                arrayComponentType.toString(), aryType1.toString()));
+                arrayComponentType.toString(), elementType.toString()));
       }
 
       return false;
diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
index 18767f5..67367ba 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
@@ -612,10 +612,15 @@
       new ArrayFunctionOperandTypeChecker();
 
   public static final SqlOperandTypeChecker ARRAY_ELEMENT =
-      new ArrayElementOperandTypeChecker();
+      new ArrayElementOperandTypeChecker(true, true);
 
   public static final SqlOperandTypeChecker ARRAY_ELEMENT_NONNULL =
-      new ArrayElementOperandTypeChecker(true, false);
+      new ArrayElementOperandTypeChecker(false, true);
+
+  /** Type checker that accepts an ARRAY as the first argument, but not
+   * an expression with type NULL (i.e. a NULL literal). */
+  public static final SqlOperandTypeChecker ARRAY_NONNULL =
+      family(SqlTypeFamily.ARRAY).and(new NotNullOperandTypeChecker(1, false));
 
   public static final SqlOperandTypeChecker ARRAY_INSERT =
       new ArrayInsertOperandTypeChecker();
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 0a0e37e..8aaef82 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -6314,6 +6314,25 @@
     f.checkScalar("rand_integer(2, 11)", 1, "INTEGER NOT NULL");
   }
 
+  /** Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-6283">
+   * [CALCITE-6283] Function array_append with a NULL array argument crashes with
+   * NullPointerException</a>. */
+  @Test void testArrayNullFunc() {
+    final String expected = "Illegal use of 'NULL'";
+    final SqlOperatorFixture f = fixture().withLibrary(SqlLibrary.SPARK);
+    f.checkFails("array_append(^null^, 2)", expected, false);
+    f.checkFails("array_prepend(^null^, 2)", expected, false);
+    f.checkFails("array_remove(^null^, 2)", expected, false);
+    f.checkFails("array_contains(^null^, 2)", expected, false);
+    f.checkFails("array_position(^null^, 2)", expected, false);
+    f.checkFails("^array_min(null)^",
+        "Cannot apply 'ARRAY_MIN' to arguments of type 'ARRAY_MIN\\(<NULL>\\)'."
+            + " Supported form\\(s\\): 'ARRAY_MIN\\(<ARRAY>\\)'", false);
+    f.checkFails("^array_max(null)^",
+        "Cannot apply 'ARRAY_MAX' to arguments of type 'ARRAY_MAX\\(<NULL>\\)'."
+        + " Supported form\\(s\\): 'ARRAY_MAX\\(<ARRAY>\\)'", false);
+  }
+
   /** Tests {@code ARRAY_APPEND} function from Spark. */
   @Test void testArrayAppendFunc() {
     final SqlOperatorFixture f0 = fixture();
@@ -6421,7 +6440,7 @@
         "INTEGER is not comparable to BOOLEAN", false);
 
     // check null without cast
-    f.checkFails("array_contains(array[1, 2], ^null^)", "Illegal use of 'NULL'", false);
+    f.checkNull("array_contains(array[1, 2], null)");
     f.checkFails("array_contains(^null^, array[1, 2])", "Illegal use of 'NULL'", false);
     f.checkFails("array_contains(^null^, null)", "Illegal use of 'NULL'", false);
   }