Unit test for DefaultOperandTypeChecker (#11152)

* Less strict operand type check and implicit casting

* fix ci

* Clean up unnecessary changes

* more cleanup

* unused import
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
index c7d0780..9aba1b7 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
@@ -19,6 +19,7 @@
 
 package org.apache.druid.sql.calcite.expression;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
@@ -500,13 +501,15 @@
    * Operand type checker that is used in 'simple' situations: there are a particular number of operands, with
    * particular types, some of which may be optional or nullable, and some of which may be required to be literals.
    */
-  private static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
+  @VisibleForTesting
+  static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
   {
     private final List<SqlTypeFamily> operandTypes;
     private final int requiredOperands;
     private final IntSet nullableOperands;
     private final IntSet literalOperands;
 
+    @VisibleForTesting
     DefaultOperandTypeChecker(
         final List<SqlTypeFamily> operandTypes,
         final int requiredOperands,
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java
new file mode 100644
index 0000000..0268bb6
--- /dev/null
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.expression;
+
+import com.google.common.collect.ImmutableList;
+import it.unimi.dsi.fastutil.ints.IntSets;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.runtime.CalciteContextException;
+import org.apache.calcite.runtime.Resources.ExInst;
+import org.apache.calcite.sql.SqlCallBinding;
+import org.apache.calcite.sql.SqlFunction;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperandCountRange;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.SqlOperandTypeChecker;
+import org.apache.calcite.sql.type.SqlTypeFamily;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql.validate.SqlValidatorScope;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentMatchers;
+import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+@RunWith(Enclosed.class)
+public class OperatorConversionsTest
+{
+  public static class DefaultOperandTypeCheckerTest
+  {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    @Test
+    public void testGetOperandCountRange()
+    {
+      SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
+          ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
+          2,
+          IntSets.EMPTY_SET,
+          null
+      );
+      SqlOperandCountRange countRange = typeChecker.getOperandCountRange();
+      Assert.assertEquals(2, countRange.getMin());
+      Assert.assertEquals(3, countRange.getMax());
+    }
+
+    @Test
+    public void testIsOptional()
+    {
+      SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
+          ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
+          2,
+          IntSets.EMPTY_SET,
+          null
+      );
+      Assert.assertFalse(typeChecker.isOptional(0));
+      Assert.assertFalse(typeChecker.isOptional(1));
+      Assert.assertTrue(typeChecker.isOptional(2));
+    }
+
+    @Test
+    public void testAllowFullOperands()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testAllowFullOperands")
+          .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      Assert.assertTrue(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(
+                      new OperandSpec(SqlTypeName.INTEGER, false),
+                      new OperandSpec(SqlTypeName.DATE, false)
+                  )
+              ),
+              true
+          )
+      );
+    }
+
+    @Test
+    public void testRequiredOperandsOnly()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testRequiredOperandsOnly")
+          .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      Assert.assertTrue(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
+              ),
+              true
+          )
+      );
+    }
+
+    @Test
+    public void testLiteralOperandCheckLiteral()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testLiteralOperandCheckLiteral")
+          .operandTypes(SqlTypeFamily.INTEGER)
+          .requiredOperands(1)
+          .literalOperands(0)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      Assert.assertFalse(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
+              ),
+              false
+          )
+      );
+    }
+
+    @Test
+    public void testLiteralOperandCheckLiteralThrow()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testLiteralOperandCheckLiteralThrow")
+          .operandTypes(SqlTypeFamily.INTEGER)
+          .requiredOperands(1)
+          .literalOperands(0)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      expectedException.expect(CalciteContextException.class);
+      expectedException.expectMessage("Argument to function 'testLiteralOperandCheckLiteralThrow' must be a literal");
+      typeChecker.checkOperandTypes(
+          mockCallBinding(
+              function,
+              ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
+          ),
+          true
+      );
+    }
+
+    @Test
+    public void testAnyTypeOperand()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testAnyTypeOperand")
+          .operandTypes(SqlTypeFamily.ANY)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      Assert.assertTrue(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(new OperandSpec(SqlTypeName.DISTINCT, false))
+              ),
+              true
+          )
+      );
+    }
+
+    @Test
+    public void testCastableFromDateTimestampToDatetimeFamily()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testCastableFromDatetimeFamilyToTimestamp")
+          .operandTypes(SqlTypeFamily.DATETIME)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      Assert.assertTrue(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(new OperandSpec(SqlTypeName.DATE, false))
+              ),
+              true
+          )
+      );
+      Assert.assertTrue(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(new OperandSpec(SqlTypeName.TIMESTAMP, false))
+              ),
+              true
+          )
+      );
+    }
+
+    @Test
+    public void testNullForNullableOperand()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNullForNullableOperand")
+          .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      Assert.assertTrue(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(
+                      new OperandSpec(SqlTypeName.VARCHAR, false),
+                      new OperandSpec(SqlTypeName.NULL, false)
+                  )
+              ),
+              true
+          )
+      );
+    }
+
+    @Test
+    public void testNullLiteralForNullableOperand()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNullLiteralForNullableOperand")
+          .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      Assert.assertTrue(
+          typeChecker.checkOperandTypes(
+              mockCallBinding(
+                  function,
+                  ImmutableList.of(
+                      new OperandSpec(SqlTypeName.VARCHAR, false),
+                      new OperandSpec(SqlTypeName.NULL, true)
+                  )
+              ),
+              true
+          )
+      );
+    }
+
+    @Test
+    public void testNullForNonNullableOperand()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNullForNonNullableOperand")
+          .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      expectedException.expect(CalciteContextException.class);
+      expectedException.expectMessage(
+          "Exception in test for operator[testNullForNonNullableOperand]: Illegal use of 'NULL'"
+      );
+      typeChecker.checkOperandTypes(
+          mockCallBinding(
+              function,
+              ImmutableList.of(
+                  new OperandSpec(SqlTypeName.NULL, false),
+                  new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
+              )
+          ),
+          true
+      );
+    }
+
+    @Test
+    public void testNullLiteralForNonNullableOperand()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNullLiteralForNonNullableOperand")
+          .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+          .requiredOperands(1)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      expectedException.expect(CalciteContextException.class);
+      expectedException.expectMessage(
+          "Exception in test for operator[testNullLiteralForNonNullableOperand]: Illegal use of 'NULL'"
+      );
+      typeChecker.checkOperandTypes(
+          mockCallBinding(
+              function,
+              ImmutableList.of(
+                  new OperandSpec(SqlTypeName.NULL, true),
+                  new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
+              )
+          ),
+          true
+      );
+    }
+
+    @Test
+    public void testNonCastableType()
+    {
+      SqlFunction function = OperatorConversions
+          .operatorBuilder("testNonCastableType")
+          .operandTypes(SqlTypeFamily.CURSOR, SqlTypeFamily.INTERVAL_DAY_TIME)
+          .requiredOperands(2)
+          .returnTypeNonNull(SqlTypeName.CHAR)
+          .build();
+      SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+      expectedException.expect(CalciteContextException.class);
+      expectedException.expectMessage(
+          "Exception in test for operator[testNonCastableType]: Cannot apply 'testNonCastableType' to arguments of type"
+      );
+      typeChecker.checkOperandTypes(
+          mockCallBinding(
+              function,
+              ImmutableList.of(
+                  new OperandSpec(SqlTypeName.INTEGER, true),
+                  new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
+              )
+          ),
+          true
+      );
+    }
+
+    private static SqlCallBinding mockCallBinding(
+        SqlFunction function,
+        List<OperandSpec> actualOperands
+    )
+    {
+      SqlValidator validator = Mockito.mock(SqlValidator.class);
+      List<SqlNode> operands = new ArrayList<>(actualOperands.size());
+      for (OperandSpec operand : actualOperands) {
+        final SqlNode node;
+        if (operand.isLiteral) {
+          node = Mockito.mock(SqlLiteral.class);
+        } else {
+          node = Mockito.mock(SqlNode.class);
+        }
+        RelDataType relDataType = Mockito.mock(RelDataType.class);
+        Mockito.when(validator.deriveType(ArgumentMatchers.any(), ArgumentMatchers.eq(node)))
+               .thenReturn(relDataType);
+        Mockito.when(relDataType.getSqlTypeName()).thenReturn(operand.type);
+        operands.add(node);
+      }
+      SqlParserPos pos = Mockito.mock(SqlParserPos.class);
+      Mockito.when(pos.plusAll(ArgumentMatchers.any(Collection.class)))
+             .thenReturn(pos);
+      SqlCallBinding callBinding = new SqlCallBinding(
+          validator,
+          Mockito.mock(SqlValidatorScope.class),
+          function.createCall(pos, operands)
+      );
+
+      Mockito.when(validator.newValidationError(ArgumentMatchers.any(), ArgumentMatchers.any()))
+             .thenAnswer((Answer<CalciteContextException>) invocationOnMock -> new CalciteContextException(
+                 StringUtils.format("Exception in test for operator[%s]", function.getName()),
+                 invocationOnMock.getArgument(1, ExInst.class).ex()
+             ));
+      return callBinding;
+    }
+
+    private static class OperandSpec
+    {
+      private final SqlTypeName type;
+      private final boolean isLiteral;
+
+      private OperandSpec(SqlTypeName type, boolean isLiteral)
+      {
+        this.type = type;
+        this.isLiteral = isLiteral;
+      }
+    }
+  }
+}