blob: 5a60c990b975f46339fd231c93cfa8a166c08709 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iotdb.db.utils;
import org.apache.iotdb.commons.path.MeasurementPath;
import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.plan.analyze.ExpressionUtils;
import org.apache.iotdb.db.queryengine.plan.expression.Expression;
import org.apache.iotdb.db.queryengine.plan.expression.binary.CompareBinaryExpression;
import org.apache.iotdb.db.queryengine.plan.expression.leaf.ConstantOperand;
import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand;
import org.apache.iotdb.db.queryengine.plan.expression.multi.FunctionExpression;
import org.apache.iotdb.db.queryengine.plan.expression.multi.builtin.BuiltInScalarFunctionHelper;
import org.apache.iotdb.db.queryengine.plan.expression.multi.builtin.BuiltInScalarFunctionHelperFactory;
import org.apache.iotdb.db.utils.constant.SqlConstant;

import org.apache.commons.lang3.StringUtils;
import org.apache.tsfile.enums.TSDataType;

import java.util.Collections;
import java.util.List;
import java.util.Locale;
/**
 * Static helpers for inferring and checking {@link TSDataType}s of literal values, built-in
 * aggregation functions and built-in scalar functions.
 */
public class TypeInferenceUtils {

  // Data types assigned to string literals that look like a boolean / an integer / a floating
  // point number / NaN. Read once from the IoTDB configuration when the class is initialized.
  private static final TSDataType booleanStringInferType =
      IoTDBDescriptor.getInstance().getConfig().getBooleanStringInferType();

  private static final TSDataType integerStringInferType =
      IoTDBDescriptor.getInstance().getConfig().getIntegerStringInferType();

  private static final TSDataType floatingStringInferType =
      IoTDBDescriptor.getInstance().getConfig().getFloatingStringInferType();

  private static final TSDataType nanStringInferType =
      IoTDBDescriptor.getInstance().getConfig().getNanStringInferType();

  private TypeInferenceUtils() {}

  /**
   * Returns whether {@code s} can be parsed by {@link Double#parseDouble(String)}, excluding
   * {@code null} and the exact literal {@code "NaN"} (which callers handle separately).
   *
   * <p>NOTE: {@link Double#parseDouble(String)} also accepts leading/trailing whitespace, a leading
   * sign (including signed {@code NaN}), {@code "Infinity"}, hexadecimal floating point literals
   * and a trailing {@code f/F/d/D} suffix, so all of these are reported as numbers here.
   *
   * @param s the string to check; may be {@code null}
   * @return {@code true} if {@code s} parses as a double and is not exactly {@code "NaN"}
   */
  static boolean isNumber(String s) {
    if (s == null || s.equals("NaN")) {
      return false;
    }
    try {
      Double.parseDouble(s);
    } catch (NumberFormatException e) {
      return false;
    }
    return true;
  }

  /** Returns whether {@code s} equals the SQL boolean literal true or false, ignoring case. */
  private static boolean isBoolean(String s) {
    return s.equalsIgnoreCase(SqlConstant.BOOLEAN_TRUE)
        || s.equalsIgnoreCase(SqlConstant.BOOLEAN_FALSE);
  }

  /** Returns whether {@code s} can be parsed by {@link Long#parseLong(String)}. */
  private static boolean isLong(String s) {
    try {
      Long.parseLong(s);
    } catch (NumberFormatException e) {
      return false;
    }
    return true;
  }

  /**
   * Gets the predicted data type of the given value.
   *
   * <p>Boxed {@code Boolean}, {@code Integer}, {@code Long}, {@code Float} and {@code Double}
   * values map directly to BOOLEAN, INT32, INT64, FLOAT and DOUBLE. Any other value maps to {@link
   * TSDataType#TEXT} unless {@code inferType} is true. In that case its string form is inspected:
   * boolean literals, integers, floating point numbers and {@code "NaN"} map to the configured
   * inference types, and {@code "null"}/{@code "NULL"} (or a {@code null} value) yields {@code
   * null}.
   *
   * @param value the value to inspect; may be {@code null}
   * @param inferType whether to infer a type from the string form of non-primitive values
   * @return the predicted data type, or {@code null} when the value represents SQL NULL
   */
  public static TSDataType getPredictedDataType(Object value, boolean inferType) {
    if (value instanceof Boolean) {
      return TSDataType.BOOLEAN;
    } else if (value instanceof Integer) {
      return TSDataType.INT32;
    } else if (value instanceof Long) {
      return TSDataType.INT64;
    } else if (value instanceof Float) {
      return TSDataType.FLOAT;
    } else if (value instanceof Double) {
      return TSDataType.DOUBLE;
    } else if (!inferType) {
      return TSDataType.TEXT;
    }

    // A missing value is treated like the "null" literal instead of failing on toString().
    if (value == null) {
      return null;
    }

    String strValue = value.toString();
    if (isBoolean(strValue)) {
      return booleanStringInferType;
    } else if (isNumber(strValue)) {
      // parseDouble tolerates surrounding whitespace but parseLong does not, hence the trim.
      if (isLong(StringUtils.trim(strValue))) {
        return integerStringInferType;
      } else {
        return floatingStringInferType;
      }
    } else if ("null".equals(strValue) || "NULL".equals(strValue)) {
      return null;
    } else if ("NaN".equals(strValue)) {
      // "NaN" is returned if the NaN Literal is given in Parser
      return nanStringInferType;
    } else {
      return TSDataType.TEXT;
    }
  }

  /**
   * Returns the output data type of a built-in aggregation function after checking that the input
   * data type is supported by it.
   *
   * @param aggregationFunctionName the function name, matched case-insensitively
   * @param dataType the input series data type, or {@code null} for a NullOperand input
   * @return INT64 for time/count-like functions, DOUBLE for statistical functions, otherwise
   *     {@code dataType}
   * @throws IllegalArgumentException if the name is {@code null} or not a known aggregation
   * @throws SemanticException if the input data type is not supported by the function
   */
  public static TSDataType getBuiltinAggregationDataType(
      String aggregationFunctionName, TSDataType dataType) {
    if (aggregationFunctionName == null) {
      throw new IllegalArgumentException("AggregateFunction Name must not be null");
    }
    verifyIsAggregationDataTypeMatched(aggregationFunctionName, dataType);

    // Locale.ROOT: default-locale lowercasing turns 'I' into a dotless 'ı' under Turkish locales.
    switch (aggregationFunctionName.toLowerCase(Locale.ROOT)) {
      case SqlConstant.MIN_TIME:
      case SqlConstant.MAX_TIME:
      case SqlConstant.COUNT:
      case SqlConstant.COUNT_TIME:
      case SqlConstant.COUNT_IF:
      case SqlConstant.TIME_DURATION:
        return TSDataType.INT64;
      case SqlConstant.MIN_VALUE:
      case SqlConstant.LAST_VALUE:
      case SqlConstant.FIRST_VALUE:
      case SqlConstant.MAX_VALUE:
      case SqlConstant.EXTREME:
      case SqlConstant.MODE:
      case SqlConstant.MAX_BY:
      case SqlConstant.MIN_BY:
        return dataType;
      case SqlConstant.AVG:
      case SqlConstant.SUM:
      case SqlConstant.STDDEV:
      case SqlConstant.STDDEV_POP:
      case SqlConstant.STDDEV_SAMP:
      case SqlConstant.VARIANCE:
      case SqlConstant.VAR_POP:
      case SqlConstant.VAR_SAMP:
        return TSDataType.DOUBLE;
      default:
        throw new IllegalArgumentException(
            "Invalid Aggregation function: " + aggregationFunctionName);
    }
  }

  /**
   * Checks that {@code dataType} is an accepted input type for the given aggregation function.
   *
   * @param aggrFuncName the function name, matched case-insensitively
   * @param dataType the input data type; {@code null} (NullOperand input) skips the check
   * @throws SemanticException if the data type is not supported by the function
   * @throws IllegalArgumentException if the function name is not a known aggregation
   */
  private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDataType dataType) {
    // input is NullOperand, needn't check
    if (dataType == null) {
      return;
    }
    switch (aggrFuncName.toLowerCase(Locale.ROOT)) {
      case SqlConstant.AVG:
      case SqlConstant.SUM:
      case SqlConstant.EXTREME:
      case SqlConstant.MIN_VALUE:
      case SqlConstant.MAX_VALUE:
      case SqlConstant.STDDEV:
      case SqlConstant.STDDEV_POP:
      case SqlConstant.STDDEV_SAMP:
      case SqlConstant.VARIANCE:
      case SqlConstant.VAR_POP:
      case SqlConstant.VAR_SAMP:
        if (dataType.isNumeric()) {
          return;
        }
        throw new SemanticException(
            "Aggregate functions [AVG, SUM, EXTREME, MIN_VALUE, MAX_VALUE, STDDEV, STDDEV_POP, STDDEV_SAMP, VARIANCE, VAR_POP, VAR_SAMP] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]");
      case SqlConstant.COUNT:
      case SqlConstant.COUNT_TIME:
      case SqlConstant.MIN_TIME:
      case SqlConstant.MAX_TIME:
      case SqlConstant.FIRST_VALUE:
      case SqlConstant.LAST_VALUE:
      case SqlConstant.TIME_DURATION:
      case SqlConstant.MODE:
      case SqlConstant.MAX_BY:
      case SqlConstant.MIN_BY:
        return;
      case SqlConstant.COUNT_IF:
        if (dataType != TSDataType.BOOLEAN) {
          throw new SemanticException(
              String.format(
                  "Input series of Aggregation function [%s] only supports data type [BOOLEAN]",
                  aggrFuncName));
        }
        return;
      default:
        throw new IllegalArgumentException("Invalid Aggregation function: " + aggrFuncName);
    }
  }

  /**
   * Binds types for the non-series input expressions of a built-in aggregation function and checks
   * their semantics.
   *
   * <p>e.g. for {@code COUNT_IF(s1>1, keep>2, 'ignoreNull'='false')} the type {@link
   * TSDataType#INT64} is bound to {@code keep}. Only COUNT_IF currently has such inputs; for the
   * other known functions this is a no-op.
   *
   * @param functionName the aggregation function name, matched case-insensitively
   * @param inputExpressions the function's input expressions
   * @param outputExpressionLists receives the type-bound expressions (COUNT_IF only)
   * @throws IllegalArgumentException if the name is {@code null} or not a known aggregation
   * @throws SemanticException if the COUNT_IF keep condition is missing or malformed
   */
  public static void bindTypeForBuiltinAggregationNonSeriesInputExpressions(
      String functionName,
      List<Expression> inputExpressions,
      List<List<Expression>> outputExpressionLists) {
    if (functionName == null) {
      throw new IllegalArgumentException("AggregateFunction Name must not be null");
    }
    switch (functionName.toLowerCase(Locale.ROOT)) {
      case SqlConstant.AVG:
      case SqlConstant.SUM:
      case SqlConstant.EXTREME:
      case SqlConstant.MIN_VALUE:
      case SqlConstant.MAX_VALUE:
      case SqlConstant.COUNT:
      case SqlConstant.COUNT_TIME:
      case SqlConstant.MIN_TIME:
      case SqlConstant.MAX_TIME:
      case SqlConstant.FIRST_VALUE:
      case SqlConstant.LAST_VALUE:
      case SqlConstant.TIME_DURATION:
      case SqlConstant.MODE:
      case SqlConstant.STDDEV:
      case SqlConstant.STDDEV_POP:
      case SqlConstant.STDDEV_SAMP:
      case SqlConstant.VARIANCE:
      case SqlConstant.VAR_POP:
      case SqlConstant.VAR_SAMP:
      case SqlConstant.MAX_BY:
      case SqlConstant.MIN_BY:
        return;
      case SqlConstant.COUNT_IF:
        bindTypeForCountIfKeepExpression(functionName, inputExpressions, outputExpressionLists);
        return;
      default:
        throw new IllegalArgumentException("Invalid Aggregation function: " + functionName);
    }
  }

  /**
   * Validates the keep condition (second input) of COUNT_IF and appends it, with the 'keep'
   * operand typed as INT64, to {@code outputExpressionLists}.
   */
  private static void bindTypeForCountIfKeepExpression(
      String functionName,
      List<Expression> inputExpressions,
      List<List<Expression>> outputExpressionLists) {
    if (inputExpressions.size() < 2) {
      throw new SemanticException(
          String.format(
              "Aggregation function [%s] requires a keep condition as its second input",
              functionName));
    }
    Expression keepExpression = inputExpressions.get(1);

    // A constant keep condition needs no type binding.
    if (keepExpression instanceof ConstantOperand) {
      outputExpressionLists.add(Collections.singletonList(keepExpression));
      return;
    }

    if (!(keepExpression instanceof CompareBinaryExpression)) {
      throw new SemanticException(
          String.format(
              "Keep condition of Aggregation function [%s] need to be constant or compare expression constructed by keep and a long number",
              functionName));
    }

    CompareBinaryExpression compareExpression = (CompareBinaryExpression) keepExpression;
    Expression leftExpression = compareExpression.getLeftExpression();
    Expression rightExpression = compareExpression.getRightExpression();
    if (!(leftExpression instanceof TimeSeriesOperand
        && leftExpression.getExpressionString().equalsIgnoreCase("keep")
        && rightExpression.isConstantOperand())) {
      throw new SemanticException(
          String.format(
              "Please check input keep condition of Aggregation function [%s]", functionName));
    }

    // 'keep' arrives as a TimeSeriesOperand; rebuild the comparison with an INT64-typed path.
    outputExpressionLists.add(
        Collections.singletonList(
            ExpressionUtils.reconstructBinaryExpression(
                keepExpression,
                new TimeSeriesOperand(
                    new MeasurementPath(
                        ((TimeSeriesOperand) leftExpression).getPath(), TSDataType.INT64)),
                rightExpression)));
  }

  /**
   * Returns the output data type of a built-in scalar function, checking the input type first.
   *
   * @param functionExpression the scalar function expression
   * @param dataType the input data type, or {@code null} for a NullOperand input (check skipped)
   * @return the return type reported by the function's {@link BuiltInScalarFunctionHelper}
   * @throws IllegalArgumentException if the function name is {@code null}
   */
  public static TSDataType getBuiltInScalarFunctionDataType(
      FunctionExpression functionExpression, TSDataType dataType) {
    String functionName = functionExpression.getFunctionName();
    if (functionName == null) {
      throw new IllegalArgumentException("ScalarFunction Name must not be null.");
    }
    BuiltInScalarFunctionHelper helper =
        BuiltInScalarFunctionHelperFactory.createHelper(functionName);
    // check input data type first if it is not a NullOperand
    if (dataType != null) {
      helper.checkBuiltInScalarFunctionInputDataType(dataType);
    }
    return helper.getBuiltInScalarFunctionReturnType(functionExpression);
  }

  /**
   * Returns whether a value of {@code fromType} may be implicitly cast to {@code toType}.
   *
   * <p>Identical types are always castable; otherwise only the widenings INT32 to
   * INT64/FLOAT/DOUBLE and INT64/FLOAT to DOUBLE are allowed.
   *
   * @throws IllegalArgumentException if the types differ and {@code fromType} is not one of INT32,
   *     INT64, FLOAT, DOUBLE, BOOLEAN or TEXT
   */
  public static boolean canAutoCast(TSDataType fromType, TSDataType toType) {
    if (fromType.equals(toType)) {
      return true;
    }
    switch (fromType) {
      case INT32:
        switch (toType) {
          case INT64:
          case FLOAT:
          case DOUBLE:
            return true;
          default:
            return false;
        }
      case INT64:
      case FLOAT:
        return toType.equals(TSDataType.DOUBLE);
      case DOUBLE:
      case BOOLEAN:
      case TEXT:
        return false;
      default:
        throw new IllegalArgumentException("Unknown data type: " + fromType);
    }
  }
}