blob: 49904d6532cc2dbfe58454f1ba5d3105eabf6103 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iotdb.db.queryengine.plan.analyze;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.common.NodeRef;
import org.apache.iotdb.db.queryengine.plan.expression.Expression;
import org.apache.iotdb.db.queryengine.plan.expression.binary.ArithmeticBinaryExpression;
import org.apache.iotdb.db.queryengine.plan.expression.binary.CompareBinaryExpression;
import org.apache.iotdb.db.queryengine.plan.expression.binary.LogicBinaryExpression;
import org.apache.iotdb.db.queryengine.plan.expression.binary.WhenThenExpression;
import org.apache.iotdb.db.queryengine.plan.expression.leaf.ConstantOperand;
import org.apache.iotdb.db.queryengine.plan.expression.leaf.NullOperand;
import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand;
import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimestampOperand;
import org.apache.iotdb.db.queryengine.plan.expression.multi.FunctionExpression;
import org.apache.iotdb.db.queryengine.plan.expression.other.CaseWhenThenExpression;
import org.apache.iotdb.db.queryengine.plan.expression.ternary.BetweenExpression;
import org.apache.iotdb.db.queryengine.plan.expression.unary.InExpression;
import org.apache.iotdb.db.queryengine.plan.expression.unary.IsNullExpression;
import org.apache.iotdb.db.queryengine.plan.expression.unary.LikeExpression;
import org.apache.iotdb.db.queryengine.plan.expression.unary.LogicNotExpression;
import org.apache.iotdb.db.queryengine.plan.expression.unary.NegationExpression;
import org.apache.iotdb.db.queryengine.plan.expression.unary.RegularExpression;
import org.apache.iotdb.db.queryengine.plan.expression.visitor.ExpressionVisitor;
import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDAFInformationInferrer;
import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFInformationInferrer;
import org.apache.iotdb.db.utils.TypeInferenceUtils;
import org.apache.iotdb.db.utils.constant.SqlConstant;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.write.schema.IMeasurementSchema;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public class ExpressionTypeAnalyzer {
private final Map<NodeRef<Expression>, TSDataType> expressionTypes = new LinkedHashMap<>();
private ExpressionTypeAnalyzer() {}
public static TSDataType analyzeExpression(Analysis analysis, Expression expression) {
if (!analysis.getExpressionTypes().containsKey(NodeRef.of(expression))) {
ExpressionTypeAnalyzer analyzer = new ExpressionTypeAnalyzer();
analyzer.analyze(expression, null);
addExpressionTypes(analysis, analyzer);
}
return analysis.getType(expression);
}
public static TSDataType analyzeExpressionForTemplatedQuery(
Analysis analysis, Expression expression) {
if (!analysis.getExpressionTypes().containsKey(NodeRef.of(expression))) {
ExpressionTypeAnalyzer analyzer = new ExpressionTypeAnalyzer();
analyzer.analyze(expression, analysis.getDeviceTemplate().getSchemaMap());
addExpressionTypes(analysis, analyzer);
}
return analysis.getType(expression);
}
public static void analyzeExpression(
Map<NodeRef<Expression>, TSDataType> types, Expression expression) {
ExpressionTypeAnalyzer analyzer = new ExpressionTypeAnalyzer();
analyzer.analyze(expression, null);
types.putAll(analyzer.getExpressionTypes());
}
public static void analyzeExpressionUsingTemplatedInfo(
Map<NodeRef<Expression>, TSDataType> types,
Expression expression,
TemplatedInfo templatedInfo) {
ExpressionTypeAnalyzer analyzer = new ExpressionTypeAnalyzer();
analyzer.analyze(expression, templatedInfo.getSchemaMap());
types.putAll(analyzer.getExpressionTypes());
}
private static void addExpressionTypes(Analysis analysis, ExpressionTypeAnalyzer analyzer) {
analysis.addTypes(analyzer.getExpressionTypes());
}
public TSDataType analyze(Expression expression, Map<String, IMeasurementSchema> context) {
Visitor visitor = new Visitor();
return visitor.process(expression, context);
}
public Map<NodeRef<Expression>, TSDataType> getExpressionTypes() {
return expressionTypes;
}
private class Visitor extends ExpressionVisitor<TSDataType, Map<String, IMeasurementSchema>> {
@Override
public TSDataType process(Expression expression, Map<String, IMeasurementSchema> context) {
// don't double process a expression
TSDataType dataType = expressionTypes.get(NodeRef.of(expression));
if (dataType != null) {
return dataType;
}
return super.process(expression, context);
}
@Override
public TSDataType visitExpression(
Expression expression, Map<String, IMeasurementSchema> context) {
throw new UnsupportedOperationException(
"Unsupported expression type: " + expression.getClass().getName());
}
@Override
public TSDataType visitInExpression(
InExpression inExpression, Map<String, IMeasurementSchema> context) {
process(inExpression.getExpression(), context);
return setExpressionType(inExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitIsNullExpression(
IsNullExpression isNullExpression, Map<String, IMeasurementSchema> context) {
process(isNullExpression.getExpression(), context);
return setExpressionType(isNullExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitLikeExpression(
LikeExpression likeExpression, Map<String, IMeasurementSchema> context) {
checkInputExpressionDataType(
likeExpression.getExpression().getExpressionString(),
process(likeExpression.getExpression(), context),
TSDataType.TEXT);
return setExpressionType(likeExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitRegularExpression(
RegularExpression regularExpression, Map<String, IMeasurementSchema> context) {
checkInputExpressionDataType(
regularExpression.getExpression().getExpressionString(),
process(regularExpression.getExpression(), context),
TSDataType.TEXT);
return setExpressionType(regularExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitLogicNotExpression(
LogicNotExpression logicNotExpression, Map<String, IMeasurementSchema> context) {
checkInputExpressionDataType(
logicNotExpression.getExpression().getExpressionString(),
process(logicNotExpression.getExpression(), context),
TSDataType.BOOLEAN);
return setExpressionType(logicNotExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitNegationExpression(
NegationExpression negationExpression, Map<String, IMeasurementSchema> context) {
TSDataType inputExpressionType = process(negationExpression.getExpression(), context);
checkInputExpressionDataType(
negationExpression.getExpression().getExpressionString(),
inputExpressionType,
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
return setExpressionType(negationExpression, inputExpressionType);
}
@Override
public TSDataType visitArithmeticBinaryExpression(
ArithmeticBinaryExpression arithmeticBinaryExpression,
Map<String, IMeasurementSchema> context) {
checkInputExpressionDataType(
arithmeticBinaryExpression.getLeftExpression().getExpressionString(),
process(arithmeticBinaryExpression.getLeftExpression(), context),
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
checkInputExpressionDataType(
arithmeticBinaryExpression.getRightExpression().getExpressionString(),
process(arithmeticBinaryExpression.getRightExpression(), context),
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
return setExpressionType(arithmeticBinaryExpression, TSDataType.DOUBLE);
}
@Override
public TSDataType visitLogicBinaryExpression(
LogicBinaryExpression logicBinaryExpression, Map<String, IMeasurementSchema> context) {
checkInputExpressionDataType(
logicBinaryExpression.getLeftExpression().getExpressionString(),
process(logicBinaryExpression.getLeftExpression(), context),
TSDataType.BOOLEAN);
checkInputExpressionDataType(
logicBinaryExpression.getRightExpression().getExpressionString(),
process(logicBinaryExpression.getRightExpression(), context),
TSDataType.BOOLEAN);
return setExpressionType(logicBinaryExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitCompareBinaryExpression(
CompareBinaryExpression compareBinaryExpression, Map<String, IMeasurementSchema> context) {
final TSDataType leftExpressionDataType =
process(compareBinaryExpression.getLeftExpression(), context);
final TSDataType rightExpressionDataType =
process(compareBinaryExpression.getRightExpression(), context);
if (leftExpressionDataType != null
&& rightExpressionDataType != null
&& !leftExpressionDataType.equals(rightExpressionDataType)) {
final String leftExpressionString =
compareBinaryExpression.getLeftExpression().getExpressionString();
final String rightExpressionString =
compareBinaryExpression.getRightExpression().getExpressionString();
if (TSDataType.BOOLEAN.equals(leftExpressionDataType)
|| TSDataType.BOOLEAN.equals(rightExpressionDataType)) {
checkInputExpressionDataType(
leftExpressionString, leftExpressionDataType, TSDataType.BOOLEAN);
checkInputExpressionDataType(
rightExpressionString, rightExpressionDataType, TSDataType.BOOLEAN);
} else if (TSDataType.TEXT.equals(leftExpressionDataType)
|| TSDataType.TEXT.equals(rightExpressionDataType)) {
checkInputExpressionDataType(
leftExpressionString, leftExpressionDataType, TSDataType.TEXT);
checkInputExpressionDataType(
rightExpressionString, rightExpressionDataType, TSDataType.TEXT);
} else {
checkInputExpressionDataType(
leftExpressionString,
leftExpressionDataType,
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
checkInputExpressionDataType(
rightExpressionString,
rightExpressionDataType,
TSDataType.INT32,
TSDataType.INT64,
TSDataType.FLOAT,
TSDataType.DOUBLE);
}
}
return setExpressionType(compareBinaryExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitBetweenExpression(
BetweenExpression betweenExpression, Map<String, IMeasurementSchema> context) {
process(betweenExpression.getFirstExpression(), context);
process(betweenExpression.getSecondExpression(), context);
process(betweenExpression.getThirdExpression(), context);
return setExpressionType(betweenExpression, TSDataType.BOOLEAN);
}
@Override
public TSDataType visitFunctionExpression(
FunctionExpression functionExpression, Map<String, IMeasurementSchema> context) {
List<Expression> inputExpressions = functionExpression.getExpressions();
for (Expression expression : inputExpressions) {
process(expression, context);
}
if (functionExpression.isBuiltInAggregationFunctionExpression()) {
return setExpressionType(
functionExpression,
TypeInferenceUtils.getBuiltinAggregationDataType(
functionExpression.getFunctionName(),
getInputExpressionTypeForAggregation(
inputExpressions, functionExpression.getFunctionName())));
}
if (functionExpression.isBuiltInScalarFunctionExpression()) {
return setExpressionType(
functionExpression,
TypeInferenceUtils.getBuiltInScalarFunctionDataType(
functionExpression, expressionTypes.get(NodeRef.of(inputExpressions.get(0)))));
}
if (functionExpression.isExternalAggregationFunctionExpression()) {
return setExpressionType(
functionExpression,
new UDAFInformationInferrer(functionExpression.getFunctionName())
.inferOutputType(
inputExpressions.stream()
.map(Expression::getExpressionString)
.collect(Collectors.toList()),
inputExpressions.stream()
.map(f -> expressionTypes.get(NodeRef.of(f)))
.collect(Collectors.toList()),
functionExpression.getFunctionAttributes()));
}
return setExpressionType(
functionExpression,
new UDTFInformationInferrer(functionExpression.getFunctionName())
.inferOutputType(
inputExpressions.stream()
.map(Expression::getExpressionString)
.collect(Collectors.toList()),
inputExpressions.stream()
.map(f -> expressionTypes.get(NodeRef.of(f)))
.collect(Collectors.toList()),
functionExpression.getFunctionAttributes()));
}
@Override
public TSDataType visitTimeStampOperand(
TimestampOperand timestampOperand, Map<String, IMeasurementSchema> context) {
return setExpressionType(timestampOperand, TSDataType.INT64);
}
@Override
public TSDataType visitTimeSeriesOperand(
TimeSeriesOperand timeSeriesOperand, Map<String, IMeasurementSchema> context) {
if (context != null && (context.containsKey(timeSeriesOperand.getOutputSymbol()))) {
return setExpressionType(
timeSeriesOperand, context.get(timeSeriesOperand.getOutputSymbol()).getType());
}
return setExpressionType(timeSeriesOperand, timeSeriesOperand.getPath().getSeriesType());
}
@Override
public TSDataType visitConstantOperand(
ConstantOperand constantOperand, Map<String, IMeasurementSchema> context) {
return setExpressionType(constantOperand, constantOperand.getDataType());
}
@Override
public TSDataType visitNullOperand(
NullOperand nullOperand, Map<String, IMeasurementSchema> context) {
return null;
}
@Override
public TSDataType visitCaseWhenThenExpression(
CaseWhenThenExpression caseWhenThenExpression, Map<String, IMeasurementSchema> context) {
Set<TSDataType> typeSet = new HashSet<>();
for (WhenThenExpression whenThenExpression :
caseWhenThenExpression.getWhenThenExpressions()) {
typeSet.add(process(whenThenExpression, context));
}
if (!(caseWhenThenExpression.getElseExpression() instanceof NullOperand)) {
typeSet.add(process(caseWhenThenExpression.getElseExpression(), context));
}
// if TEXT exists, every branch need to be TEXT
if (typeSet.contains(TSDataType.TEXT)) {
if (typeSet.stream().anyMatch(tsDataType -> tsDataType != TSDataType.TEXT)) {
throw new SemanticException(
"CASE expression: TEXT and other types cannot exist at the same time");
}
return setExpressionType(caseWhenThenExpression, TSDataType.TEXT);
}
// if BOOLEAN exists, every branch need to be BOOLEAN
if (typeSet.contains(TSDataType.BOOLEAN)) {
if (typeSet.stream().anyMatch(tsDataType -> tsDataType != TSDataType.BOOLEAN)) {
throw new SemanticException(
"CASE expression: BOOLEAN and other types cannot exist at the same time");
}
return setExpressionType(caseWhenThenExpression, TSDataType.BOOLEAN);
}
// other 4 TSDataType can exist at the same time
// because they can be transformed by Type, finally treated as DOUBLE
return setExpressionType(caseWhenThenExpression, TSDataType.DOUBLE);
}
@Override
public TSDataType visitWhenThenExpression(
WhenThenExpression whenThenExpression, Map<String, IMeasurementSchema> context) {
TSDataType whenType = process(whenThenExpression.getWhen(), context);
if (!whenType.equals(TSDataType.BOOLEAN)) {
throw new SemanticException(
String.format(
"The expression in the WHEN clause must return BOOLEAN. expression: %s, actual data type: %s.",
whenThenExpression.getWhen().getExpressionString(), whenType.name()));
}
TSDataType thenType = process(whenThenExpression.getThen(), context);
return setExpressionType(whenThenExpression, thenType);
}
private TSDataType setExpressionType(Expression expression, TSDataType type) {
expressionTypes.put(NodeRef.of(expression), type);
return type;
}
private void checkInputExpressionDataType(
String expressionString, TSDataType actual, TSDataType... expected) {
for (TSDataType type : expected) {
if (actual == null || actual.equals(type)) {
return;
}
}
throw new SemanticException(
String.format(
"Invalid input expression data type. expression: %s, actual data type: %s, expected data type(s): %s.",
expressionString, actual.name(), Arrays.toString(expected)));
}
}
private TSDataType getInputExpressionTypeForAggregation(
List<Expression> inputExpressions, String aggregateFunctionName) {
// Some aggregate functions have a fixed output type, while others determine their output type
// based on the data type of their input.
// Currently, for all aggregate functions without a fixed output type, the output type is
// determined by the first input.
switch (aggregateFunctionName.toLowerCase()) {
case SqlConstant.MIN_TIME:
case SqlConstant.MAX_TIME:
case SqlConstant.MIN_VALUE:
case SqlConstant.MAX_VALUE:
case SqlConstant.EXTREME:
case SqlConstant.LAST_VALUE:
case SqlConstant.FIRST_VALUE:
case SqlConstant.COUNT:
case SqlConstant.AVG:
case SqlConstant.SUM:
case SqlConstant.COUNT_IF:
case SqlConstant.TIME_DURATION:
case SqlConstant.MODE:
case SqlConstant.COUNT_TIME:
case SqlConstant.STDDEV:
case SqlConstant.STDDEV_POP:
case SqlConstant.STDDEV_SAMP:
case SqlConstant.VARIANCE:
case SqlConstant.VAR_POP:
case SqlConstant.VAR_SAMP:
case SqlConstant.MAX_BY:
case SqlConstant.MIN_BY:
return expressionTypes.get(NodeRef.of(inputExpressions.get(0)));
default:
throw new IllegalArgumentException(
"Invalid Aggregation function: " + aggregateFunctionName);
}
}
}