[CALCITE-6228] ELEMENT function infers incorrect return type

Signed-off-by: Mihai Budiu <mbudiu@feldera.com>
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index 4b07e5a..1615196 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -2103,7 +2103,7 @@
    */
   public static final SqlFunction ELEMENT =
       SqlBasicFunction.create("ELEMENT",
-          ReturnTypes.MULTISET_ELEMENT_NULLABLE,
+          ReturnTypes.MULTISET_ELEMENT_FORCE_NULLABLE,
           OperandTypes.COLLECTION);
 
   /**
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 5924f00..f869689 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -615,10 +615,11 @@
       ARG0.andThen(SqlTypeTransforms.TO_MULTISET);
 
   /**
-   * Returns the element type of a MULTISET.
+   * Returns the element type of a MULTISET, with nullability enforced.
    */
-  public static final SqlReturnTypeInference MULTISET_ELEMENT_NULLABLE =
-      MULTISET.andThen(SqlTypeTransforms.TO_COLLECTION_ELEMENT_TYPE);
+  public static final SqlReturnTypeInference MULTISET_ELEMENT_FORCE_NULLABLE =
+      MULTISET.andThen(SqlTypeTransforms.TO_COLLECTION_ELEMENT_TYPE)
+          .andThen(SqlTypeTransforms.FORCE_NULLABLE);
 
   /**
    * Same as {@link #MULTISET} but returns with nullability if any of the
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index 8fead2c..f888d4e 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -1925,17 +1925,17 @@
 
   @Test void testElement() {
     expr("element(multiset[1])")
-        .columnType("INTEGER NOT NULL");
+        .columnType("INTEGER");
     expr("1.0+element(multiset[1])")
-        .columnType("DECIMAL(12, 1) NOT NULL");
+        .columnType("DECIMAL(12, 1)");
     expr("element(multiset['1'])")
-        .columnType("CHAR(1) NOT NULL");
+        .columnType("CHAR(1)");
     expr("element(multiset[1e-2])")
-        .columnType("DOUBLE NOT NULL");
+        .columnType("DOUBLE");
     expr("element(multiset[multiset[cast(null as tinyint)]])")
-        .columnType("TINYINT MULTISET NOT NULL");
-    // Test case for <a href="https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-6227">
-    // ELEMENT(NULL) causes an assertion failure</a>.
+        .columnType("TINYINT MULTISET");
+    // Test case for https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-6227
+    // ELEMENT(NULL) causes an assertion failure.
     expr("element(null)")
         .columnType("NULL");
   }
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index badde6f..4b4e8fd 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -9816,7 +9816,7 @@
   @Test void testElementFunc() {
     final SqlOperatorFixture f = fixture();
     f.setFor(SqlStdOperatorTable.ELEMENT, VM_FENNEL, VM_JAVA);
-    f.checkString("element(multiset['abc'])", "abc", "CHAR(3) NOT NULL");
+    f.checkString("element(multiset['abc'])", "abc", "CHAR(3)");
     f.checkNull("element(multiset[cast(null as integer)])");
   }