blob: 5f70dc5902a361ac8cd163a861d3ce4aaffa0e2e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression;
import com.google.common.collect.ImmutableList;
import it.unimi.dsi.fastutil.ints.IntSets;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.runtime.CalciteContextException;
import org.apache.calcite.runtime.Resources.ExInst;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
@RunWith(Enclosed.class)
public class OperatorConversionsTest
{
public static class DefaultOperandTypeCheckerTest
{
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test
public void testGetOperandCountRange()
{
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2,
IntSets.EMPTY_SET,
null
);
SqlOperandCountRange countRange = typeChecker.getOperandCountRange();
Assert.assertEquals(2, countRange.getMin());
Assert.assertEquals(3, countRange.getMax());
}
@Test
public void testIsOptional()
{
SqlOperandTypeChecker typeChecker = new DefaultOperandTypeChecker(
ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER),
2,
IntSets.EMPTY_SET,
null
);
Assert.assertFalse(typeChecker.isOptional(0));
Assert.assertFalse(typeChecker.isOptional(1));
Assert.assertTrue(typeChecker.isOptional(2));
}
@Test
public void testAllowFullOperands()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testAllowFullOperands")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertTrue(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.INTEGER, false),
new OperandSpec(SqlTypeName.DATE, false)
)
),
true
)
);
}
@Test
public void testRequiredOperandsOnly()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testRequiredOperandsOnly")
.operandTypes(SqlTypeFamily.INTEGER, SqlTypeFamily.DATE)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertTrue(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
),
true
)
);
}
@Test
public void testLiteralOperandCheckLiteral()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testLiteralOperandCheckLiteral")
.operandTypes(SqlTypeFamily.INTEGER)
.requiredOperands(1)
.literalOperands(0)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertFalse(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
),
false
)
);
}
@Test
public void testLiteralOperandCheckLiteralThrow()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testLiteralOperandCheckLiteralThrow")
.operandTypes(SqlTypeFamily.INTEGER)
.requiredOperands(1)
.literalOperands(0)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
expectedException.expect(CalciteContextException.class);
expectedException.expectMessage("Argument to function 'testLiteralOperandCheckLiteralThrow' must be a literal");
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(new OperandSpec(SqlTypeName.INTEGER, false))
),
true
);
}
@Test
public void testAnyTypeOperand()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testAnyTypeOperand")
.operandTypes(SqlTypeFamily.ANY)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertTrue(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(new OperandSpec(SqlTypeName.DISTINCT, false))
),
true
)
);
}
@Test
public void testCastableFromDateTimestampToDatetimeFamily()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testCastableFromDatetimeFamilyToTimestamp")
.operandTypes(SqlTypeFamily.DATETIME)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertTrue(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(new OperandSpec(SqlTypeName.DATE, false))
),
true
)
);
Assert.assertTrue(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(new OperandSpec(SqlTypeName.TIMESTAMP, false))
),
true
)
);
}
@Test
public void testNullForNullableOperand()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertTrue(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.VARCHAR, false),
new OperandSpec(SqlTypeName.NULL, false)
)
),
true
)
);
}
@Test
public void testNullLiteralForNullableOperand()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullLiteralForNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
Assert.assertTrue(
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.VARCHAR, false),
new OperandSpec(SqlTypeName.NULL, true)
)
),
true
)
);
}
@Test
public void testNullForNullableOperandNonNullOutput()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
SqlCallBinding binding = mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.CHAR, false, true)
)
);
Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
Assert.assertFalse(returnType.isNullable());
}
@Test
public void testNullForNullableOperandCascadeNullOutput()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableCascade")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.returnTypeCascadeNullable(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
SqlCallBinding binding = mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.CHAR, false, true)
)
);
Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
Assert.assertTrue(returnType.isNullable());
}
@Test
public void testNullForNullableOperandAlwaysNullableOutput()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNullableNonnull")
.operandTypes(SqlTypeFamily.CHARACTER)
.requiredOperands(1)
.returnTypeNullable(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
SqlCallBinding binding = mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.CHAR, false, false)
)
);
Assert.assertTrue(typeChecker.checkOperandTypes(binding, true));
RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding);
Assert.assertTrue(returnType.isNullable());
}
@Test
public void testNullForNonNullableOperand()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullForNonNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
expectedException.expect(CalciteContextException.class);
expectedException.expectMessage(
"Exception in test for operator[testNullForNonNullableOperand]: Illegal use of 'NULL'"
);
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.NULL, false),
new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
)
),
true
);
}
@Test
public void testNullLiteralForNonNullableOperand()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNullLiteralForNonNullableOperand")
.operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(1)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
expectedException.expect(CalciteContextException.class);
expectedException.expectMessage(
"Exception in test for operator[testNullLiteralForNonNullableOperand]: Illegal use of 'NULL'"
);
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.NULL, true),
new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
)
),
true
);
}
@Test
public void testNonCastableType()
{
SqlFunction function = OperatorConversions
.operatorBuilder("testNonCastableType")
.operandTypes(SqlTypeFamily.CURSOR, SqlTypeFamily.INTERVAL_DAY_TIME)
.requiredOperands(2)
.returnTypeNonNull(SqlTypeName.CHAR)
.build();
SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker();
expectedException.expect(CalciteContextException.class);
expectedException.expectMessage(
"Exception in test for operator[testNonCastableType]: Cannot apply 'testNonCastableType' to arguments of type"
);
typeChecker.checkOperandTypes(
mockCallBinding(
function,
ImmutableList.of(
new OperandSpec(SqlTypeName.INTEGER, true),
new OperandSpec(SqlTypeName.INTERVAL_HOUR, false)
)
),
true
);
}
private static SqlCallBinding mockCallBinding(
SqlFunction function,
List<OperandSpec> actualOperands
)
{
SqlValidator validator = Mockito.mock(SqlValidator.class);
Mockito.when(validator.getTypeFactory()).thenReturn(new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE));
List<SqlNode> operands = new ArrayList<>(actualOperands.size());
for (OperandSpec operand : actualOperands) {
final SqlNode node;
if (operand.isLiteral) {
node = Mockito.mock(SqlLiteral.class);
} else {
node = Mockito.mock(SqlNode.class);
}
RelDataType relDataType = Mockito.mock(RelDataType.class);
if (operand.isNullable) {
Mockito.when(relDataType.isNullable()).thenReturn(true);
} else {
Mockito.when(relDataType.isNullable()).thenReturn(false);
}
Mockito.when(validator.deriveType(ArgumentMatchers.any(), ArgumentMatchers.eq(node)))
.thenReturn(relDataType);
Mockito.when(relDataType.getSqlTypeName()).thenReturn(operand.type);
operands.add(node);
}
SqlParserPos pos = Mockito.mock(SqlParserPos.class);
Mockito.when(pos.plusAll(ArgumentMatchers.any(Collection.class)))
.thenReturn(pos);
SqlCallBinding callBinding = new SqlCallBinding(
validator,
Mockito.mock(SqlValidatorScope.class),
function.createCall(pos, operands)
);
Mockito.when(validator.newValidationError(ArgumentMatchers.any(), ArgumentMatchers.any()))
.thenAnswer((Answer<CalciteContextException>) invocationOnMock -> new CalciteContextException(
StringUtils.format("Exception in test for operator[%s]", function.getName()),
invocationOnMock.getArgument(1, ExInst.class).ex()
));
return callBinding;
}
private static class OperandSpec
{
private final SqlTypeName type;
private final boolean isLiteral;
private final boolean isNullable;
private OperandSpec(SqlTypeName type, boolean isLiteral)
{
this(type, isLiteral, type == SqlTypeName.NULL);
}
private OperandSpec(SqlTypeName type, boolean isLiteral, boolean isNullable)
{
this.type = type;
this.isLiteral = isLiteral;
this.isNullable = isNullable;
}
}
}
}