blob: ccbee6af2ea48c62d450ac0fe65b451e253ad3ca [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.sql.planner;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.samza.SamzaException;
import org.apache.samza.sql.interfaces.UdfMetadata;
import org.apache.samza.sql.schema.SamzaSqlFieldType;
import org.apache.samza.sql.udfs.SamzaSqlUdfMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A {@link SqlOperandTypeChecker} for Samza SQL UDFs.
 *
 * <p>The number of operands is constrained by the {@link SqlOperandCountRange} supplied at construction.
 * When {@link UdfMetadata} is present and argument checking is not disabled for it, the Calcite type of
 * each operand is compared with the {@link SamzaSqlFieldType} declared for the matching UDF argument, and a
 * {@link SamzaSqlValidatorException} is thrown on a mismatch.
 */
class Checker implements SqlOperandTypeChecker {
  private static final Logger LOG = LoggerFactory.getLogger(Checker.class);

  /** Operand type names that are accepted for any declared UDF argument type. */
  private static final List<SqlTypeName> ANY_SQL_TYPE_NAMES = ImmutableList.of(SqlTypeName.ANY, SqlTypeName.OTHER);

  /** Accepts any number of operands and performs no operand type checking. */
  static final Checker ANY_CHECKER = new Checker();

  /** Metadata of the UDF whose calls are checked; empty for {@link #ANY_CHECKER}. */
  private final Optional<UdfMetadata> udfMetadataOptional;

  /** The allowed number of operands for a call. */
  private final SqlOperandCountRange range;

  /**
   * Creates a checker for a UDF that accepts between {@code min} and {@code max} operands.
   *
   * @param min the minimum number of operands.
   * @param max the maximum number of operands.
   * @param udfMetadata the metadata of the UDF whose calls are checked; must not be null.
   * @return a checker with an exact operand count when {@code min == max}, a ranged one otherwise.
   */
  public static Checker getChecker(int min, int max, UdfMetadata udfMetadata) {
    if (min == max) {
      return new Checker(min, udfMetadata);
    }
    return new Checker(min, max, udfMetadata);
  }

  private Checker(int size, UdfMetadata udfMetadata) {
    this.range = SqlOperandCountRanges.of(size);
    this.udfMetadataOptional = Optional.of(udfMetadata);
  }

  private Checker(int min, int max, UdfMetadata udfMetadata) {
    this.range = SqlOperandCountRanges.between(min, max);
    this.udfMetadataOptional = Optional.of(udfMetadata);
  }

  private Checker() {
    this.range = SqlOperandCountRanges.any();
    this.udfMetadataOptional = Optional.empty();
  }

  /**
   * Checks the operand types of a UDF call against the argument types declared in the UDF metadata.
   *
   * <p>Nothing is checked, and {@code true} is returned, when there is no UDF metadata, when argument
   * checking is disabled for the UDF, or when {@code throwOnFailure} is false. Every declared argument
   * that has a matching operand in the call is checked. Mismatches are only reported when the UDF class
   * has exactly one {@link SamzaSqlUdfMethod}-annotated method.
   *
   * @param callBinding the binding of the UDF call being validated.
   * @param throwOnFailure whether a mismatch should be reported by throwing.
   * @return {@code true}; failures are reported by throwing.
   * @throws SamzaSqlValidatorException if an operand type does not match the declared argument type.
   * @throws SamzaException if a declared argument type has no Calcite equivalent.
   */
  @Override
  public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
    if (!udfMetadataOptional.isPresent() || udfMetadataOptional.get().isDisableArgCheck() || !throwOnFailure) {
      return true;
    }

    // 1. Map each operand index of the call to the Calcite type derived for that operand.
    Map<Integer, RelDataType> argumentIndexToCalciteType = IntStream.range(0, callBinding.getOperandCount())
        .boxed()
        .collect(Collectors.toMap(operandIndex -> operandIndex, callBinding::getOperandType));

    UdfMetadata udfMetadata = udfMetadataOptional.get();
    List<SamzaSqlFieldType> udfArguments = udfMetadata.getArguments();
    // Reflection over the UDF class is only needed on a mismatch, so compute it lazily and at most once.
    Boolean singleUdfMethod = null;

    // 2. Compare each declared UDF argument type with the Calcite type of the matching operand.
    for (int udfArgumentIndex = 0; udfArgumentIndex < udfArguments.size(); ++udfArgumentIndex) {
      SqlTypeName udfArgumentAsSqlType = toCalciteSqlType(udfArguments.get(udfArgumentIndex));
      RelDataType parsedSqlArgType = argumentIndexToCalciteType.get(udfArgumentIndex);
      if (parsedSqlArgType == null) {
        // The call supplies fewer operands than the UDF declares arguments; nothing to compare here.
        continue;
      }
      SqlTypeName parsedSqlTypeName = parsedSqlArgType.getSqlTypeName();
      if (isCompatible(udfArgumentAsSqlType, parsedSqlTypeName)) {
        continue;
      }
      if (singleUdfMethod == null) {
        singleUdfMethod = hasOneUdfMethod(udfMetadata);
      }
      if (!singleUdfMethod) {
        // NOTE(review): with several annotated UDF methods the metadata presumably describes only one
        // overload, so a mismatch is not reported.
        continue;
      }
      // 3. Fail on a mismatch between the declared SamzaSqlFieldType and the operand's Calcite type.
      String msg = String.format("Type mismatch in udf class: %s at argument index: %d. "
          + "Expected type: %s, actual type: %s.", udfMetadata.getName(), udfArgumentIndex,
          udfArgumentAsSqlType, parsedSqlTypeName);
      LOG.error(msg);
      throw new SamzaSqlValidatorException(msg);
    }

    // 4. Every operand that could be compared matched its declared UDF argument type.
    return true;
  }

  /**
   * Returns whether an operand of Calcite type {@code actual} may be passed to a UDF argument whose
   * declared type converts to the Calcite type {@code expected}.
   */
  private static boolean isCompatible(SqlTypeName expected, SqlTypeName actual) {
    return Objects.equals(expected, actual)
        // Operands typed ANY/OTHER are accepted for any declared argument.
        || ANY_SQL_TYPE_NAMES.contains(actual)
        // Arguments declared as ANY or ROW (both converted to SqlTypeName.ANY) accept any operand.
        || expected == SqlTypeName.ANY
        // Static strings passed as UDF arguments are parsed as CHAR by Calcite, while STRING maps to VARCHAR.
        || (actual == SqlTypeName.CHAR && expected == SqlTypeName.VARCHAR);
  }

  /**
   * Checks if there is only one UdfMethod in the input {@link UdfMetadata}.
   *
   * <p>Counts the public methods of the class that declares the UDF method which carry the
   * {@link SamzaSqlUdfMethod} annotation.
   *
   * @param udfMetadata the metadata for a UDF.
   * @return true if there is exactly one annotated udf method in the UDF class, false otherwise.
   */
  @VisibleForTesting
  boolean hasOneUdfMethod(UdfMetadata udfMetadata) {
    Class<?> udfClass = udfMetadata.getUdfMethod().getDeclaringClass();
    int numAnnotatedUdfMethods = 0;
    for (Method method : udfClass.getMethods()) {
      if (method.isAnnotationPresent(SamzaSqlUdfMethod.class)) {
        numAnnotatedUdfMethods += 1;
      }
    }
    return numAnnotatedUdfMethods == 1;
  }

  /** @return the operand count range supplied at construction. */
  @Override
  public SqlOperandCountRange getOperandCountRange() {
    return range;
  }

  /**
   * Returns the signature text reported for the operator.
   *
   * <p>NOTE(review): the "(Drill - Opaque)" suffix looks inherited from Apache Drill; consider a
   * Samza-specific description.
   */
  @Override
  public String getAllowedSignatures(SqlOperator op, String opName) {
    return opName + "(Drill - Opaque)";
  }

  /** @return {@link Consistency#NONE}. */
  @Override
  public Consistency getConsistency() {
    return Consistency.NONE;
  }

  /** @return false for every operand index. */
  @Override
  public boolean isOptional(int i) {
    return false;
  }

  /**
   * Converts the {@link SamzaSqlFieldType} to the calcite {@link SqlTypeName}.
   *
   * @param samzaSqlFieldType the samza sql field type.
   * @return the converted calcite SqlTypeName.
   * @throws SamzaException if the field type has no Calcite equivalent.
   */
  @VisibleForTesting
  static SqlTypeName toCalciteSqlType(SamzaSqlFieldType samzaSqlFieldType) {
    switch (samzaSqlFieldType) {
      case ANY:
      case ROW:
        return SqlTypeName.ANY;
      case MAP:
        return SqlTypeName.MAP;
      case ARRAY:
        return SqlTypeName.ARRAY;
      case REAL:
        return SqlTypeName.REAL;
      case DOUBLE:
        return SqlTypeName.DOUBLE;
      case STRING:
        return SqlTypeName.VARCHAR;
      case INT16:
      case INT32:
        return SqlTypeName.INTEGER;
      case FLOAT:
        return SqlTypeName.FLOAT;
      case INT64:
        return SqlTypeName.BIGINT;
      case BOOLEAN:
        return SqlTypeName.BOOLEAN;
      case BYTES:
        return SqlTypeName.VARBINARY;
      default:
        String msg = String.format("Field Type %s is not supported", samzaSqlFieldType);
        LOG.error(msg);
        throw new SamzaException(msg);
    }
  }
}