Unit test for DefaultOperandTypeChecker (#11152)
* Less strict operand type check and implicit casting
* fix ci
* Clean up unnecessary changes
* more cleanup
* unused import
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
index c7d0780..9aba1b7 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java
@@ -19,6 +19,7 @@
package org.apache.druid.sql.calcite.expression;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
@@ -500,13 +501,15 @@
* Operand type checker that is used in 'simple' situations: there are a particular number of operands, with
* particular types, some of which may be optional or nullable, and some of which may be required to be literals.
*/
- private static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
+ @VisibleForTesting
+ static class DefaultOperandTypeChecker implements SqlOperandTypeChecker
{
private final List<SqlTypeFamily> operandTypes;
private final int requiredOperands;
private final IntSet nullableOperands;
private final IntSet literalOperands;
+ @VisibleForTesting
DefaultOperandTypeChecker(
final List<SqlTypeFamily> operandTypes,
final int requiredOperands,
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java
new file mode 100644
index 0000000..0268bb6
--- /dev/null
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.expression;
+
+import com.google.common.collect.ImmutableList;
+import it.unimi.dsi.fastutil.ints.IntSets;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.runtime.CalciteContextException;
+import org.apache.calcite.runtime.Resources.ExInst;
+import org.apache.calcite.sql.SqlCallBinding;
+import org.apache.calcite.sql.SqlFunction;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperandCountRange;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.SqlOperandTypeChecker;
+import org.apache.calcite.sql.type.SqlTypeFamily;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql.validate.SqlValidatorScope;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentMatchers;
+import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+@RunWith(Enclosed.class)
+public class OperatorConversionsTest
+{
+ public static class DefaultOperandTypeCheckerTest
+ {
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void testGetOperandCountRange()
+ {
+ SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
+ ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
+ 2,
+ IntSets.EMPTY_SET,
+ null
+ );
+ SqlOperandCountRange countRange = typeChecker.getOperandCountRange();
+ Assert.assertEquals(2, countRange.getMin());
+ Assert.assertEquals(3, countRange.getMax());
+ }
+
+ @Test
+ public void testIsOptional()
+ {
+ SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
+ ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
+ 2,
+ IntSets.EMPTY_SET,
+ null
+ );
+ Assert.assertFalse(typeChecker.isOptional(0));
+ Assert.assertFalse(typeChecker.isOptional(1));
+ Assert.assertTrue(typeChecker.isOptional(2));
+ }
+
+ @Test
+ public void testAllowFullOperands()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testAllowFullOperands")
+ .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ Assert.assertTrue(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(
+ new OperandSpec(SqlTypeName.INTEGER, false),
+ new OperandSpec(SqlTypeName.DATE, false)
+ )
+ ),
+ true
+ )
+ );
+ }
+
+ @Test
+ public void testRequiredOperandsOnly()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testRequiredOperandsOnly")
+ .operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ Assert.assertTrue(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
+ ),
+ true
+ )
+ );
+ }
+
+ @Test
+ public void testLiteralOperandCheckLiteral()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testLiteralOperandCheckLiteral")
+ .operandTypes(SqlTypeFamily.INTEGER)
+ .requiredOperands(1)
+ .literalOperands(0)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ Assert.assertFalse(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
+ ),
+ false
+ )
+ );
+ }
+
+ @Test
+ public void testLiteralOperandCheckLiteralThrow()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testLiteralOperandCheckLiteralThrow")
+ .operandTypes(SqlTypeFamily.INTEGER)
+ .requiredOperands(1)
+ .literalOperands(0)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ expectedException.expect(CalciteContextException.class);
+ expectedException.expectMessage("Argument to function 'testLiteralOperandCheckLiteralThrow' must be a literal");
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
+ ),
+ true
+ );
+ }
+
+ @Test
+ public void testAnyTypeOperand()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testAnyTypeOperand")
+ .operandTypes(SqlTypeFamily.ANY)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ Assert.assertTrue(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(new OperandSpec(SqlTypeName.DISTINCT, false))
+ ),
+ true
+ )
+ );
+ }
+
+ @Test
+ public void testCastableFromDateTimestampToDatetimeFamily()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testCastableFromDatetimeFamilyToTimestamp")
+ .operandTypes(SqlTypeFamily.DATETIME)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ Assert.assertTrue(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(new OperandSpec(SqlTypeName.DATE, false))
+ ),
+ true
+ )
+ );
+ Assert.assertTrue(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(new OperandSpec(SqlTypeName.TIMESTAMP, false))
+ ),
+ true
+ )
+ );
+ }
+
+ @Test
+ public void testNullForNullableOperand()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testNullForNullableOperand")
+ .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ Assert.assertTrue(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(
+ new OperandSpec(SqlTypeName.VARCHAR, false),
+ new OperandSpec(SqlTypeName.NULL, false)
+ )
+ ),
+ true
+ )
+ );
+ }
+
+ @Test
+ public void testNullLiteralForNullableOperand()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testNullLiteralForNullableOperand")
+ .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ Assert.assertTrue(
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(
+ new OperandSpec(SqlTypeName.VARCHAR, false),
+ new OperandSpec(SqlTypeName.NULL, true)
+ )
+ ),
+ true
+ )
+ );
+ }
+
+ @Test
+ public void testNullForNonNullableOperand()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testNullForNonNullableOperand")
+ .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ expectedException.expect(CalciteContextException.class);
+ expectedException.expectMessage(
+ "Exception in test for operator[testNullForNonNullableOperand]: Illegal use of 'NULL'"
+ );
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(
+ new OperandSpec(SqlTypeName.NULL, false),
+ new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
+ )
+ ),
+ true
+ );
+ }
+
+ @Test
+ public void testNullLiteralForNonNullableOperand()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testNullLiteralForNonNullableOperand")
+ .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
+ .requiredOperands(1)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ expectedException.expect(CalciteContextException.class);
+ expectedException.expectMessage(
+ "Exception in test for operator[testNullLiteralForNonNullableOperand]: Illegal use of 'NULL'"
+ );
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(
+ new OperandSpec(SqlTypeName.NULL, true),
+ new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
+ )
+ ),
+ true
+ );
+ }
+
+ @Test
+ public void testNonCastableType()
+ {
+ SqlFunction function = OperatorConversions
+ .operatorBuilder("testNonCastableType")
+ .operandTypes(SqlTypeFamily.CURSOR, SqlTypeFamily.INTERVAL_DAY_TIME)
+ .requiredOperands(2)
+ .returnTypeNonNull(SqlTypeName.CHAR)
+ .build();
+ SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
+ expectedException.expect(CalciteContextException.class);
+ expectedException.expectMessage(
+ "Exception in test for operator[testNonCastableType]: Cannot apply 'testNonCastableType' to arguments of type"
+ );
+ typeChecker.checkOperandTypes(
+ mockCallBinding(
+ function,
+ ImmutableList.of(
+ new OperandSpec(SqlTypeName.INTEGER, true),
+ new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
+ )
+ ),
+ true
+ );
+ }
+
+ private static SqlCallBinding mockCallBinding(
+ SqlFunction function,
+ List<OperandSpec> actualOperands
+ )
+ {
+ SqlValidator validator = Mockito.mock(SqlValidator.class);
+ List<SqlNode> operands = new ArrayList<>(actualOperands.size());
+ for (OperandSpec operand : actualOperands) {
+ final SqlNode node;
+ if (operand.isLiteral) {
+ node = Mockito.mock(SqlLiteral.class);
+ } else {
+ node = Mockito.mock(SqlNode.class);
+ }
+ RelDataType relDataType = Mockito.mock(RelDataType.class);
+ Mockito.when(validator.deriveType(ArgumentMatchers.any(), ArgumentMatchers.eq(node)))
+ .thenReturn(relDataType);
+ Mockito.when(relDataType.getSqlTypeName()).thenReturn(operand.type);
+ operands.add(node);
+ }
+ SqlParserPos pos = Mockito.mock(SqlParserPos.class);
+ Mockito.when(pos.plusAll(ArgumentMatchers.any(Collection.class)))
+ .thenReturn(pos);
+ SqlCallBinding callBinding = new SqlCallBinding(
+ validator,
+ Mockito.mock(SqlValidatorScope.class),
+ function.createCall(pos, operands)
+ );
+
+ Mockito.when(validator.newValidationError(ArgumentMatchers.any(), ArgumentMatchers.any()))
+ .thenAnswer((Answer<CalciteContextException>) invocationOnMock -> new CalciteContextException(
+ StringUtils.format("Exception in test for operator[%s]", function.getName()),
+ invocationOnMock.getArgument(1, ExInst.class).ex()
+ ));
+ return callBinding;
+ }
+
+ private static class OperandSpec
+ {
+ private final SqlTypeName type;
+ private final boolean isLiteral;
+
+ private OperandSpec(SqlTypeName type, boolean isLiteral)
+ {
+ this.type = type;
+ this.isLiteral = isLiteral;
+ }
+ }
+ }
+}