// blob: 4a0ff5458cf9140f7352ad940eeedef0df363b5e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iotdb.db.queryengine.plan.expression.multi;
import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
import org.apache.iotdb.commons.udf.service.UDFManagementService;
import org.apache.iotdb.db.queryengine.common.NodeRef;
import org.apache.iotdb.db.queryengine.plan.expression.Expression;
import org.apache.iotdb.db.queryengine.plan.expression.ExpressionType;
import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand;
import org.apache.iotdb.db.queryengine.plan.expression.multi.builtin.BuiltInScalarFunctionHelperFactory;
import org.apache.iotdb.db.queryengine.plan.expression.visitor.ExpressionVisitor;
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation;
import org.apache.iotdb.db.queryengine.transformation.dag.memory.LayerMemoryAssigner;
import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFExecutor;
import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFInformationInferrer;
import org.apache.iotdb.udf.api.customizer.strategy.AccessStrategy;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.utils.ReadWriteIOUtils;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
/**
* Expression node for a function call: a function name, an insertion-ordered attribute map and a
* list of argument expressions.
*
* <p>The kind of function (built-in aggregation, built-in scalar, UDAF or UDTF) is resolved lazily
* by {@code initializeFunctionType()}.
*/
public class FunctionExpression extends Expression {
// Lazily resolved function kind; stays null until initializeFunctionType() has run.
private FunctionType functionType;
// Function name as passed to the constructor or read back from the serialized form.
private final String functionName;
// Key/value attributes of the call; a LinkedHashMap, so iteration follows insertion order.
private final LinkedHashMap<String, String> functionAttributes;
/**
* Argument expressions of the call, in order.
*
* <p>Example: {@code select udf(a, b, udf(c)) from root.sg.d;}
*
* <p>3 expressions [root.sg.d.a, root.sg.d.b, udf(root.sg.d.c)] will be in this field.
*/
private List<Expression> expressions;
// Cache filled by getPaths(): one entry per argument, null for non-TimeSeriesOperand arguments.
private List<PartialPath> paths;
// Only assigned by the four-argument constructor; null otherwise.
// NOTE(review): neither serialize() overload writes this field - confirm that is intended.
private List<Expression> countTimeExpressions;
// Cache filled by getParametersString(); null until first computed.
private String parametersString;
/**
 * Creates a function expression with the given name, an empty attribute map and an empty
 * argument list. Both can be filled afterwards through {@link #addAttribute} and {@link
 * #addExpression}.
 *
 * @param functionName name of the function being called
 */
public FunctionExpression(String functionName) {
  this(functionName, new LinkedHashMap<>(), new ArrayList<>());
}
/**
 * Creates a function expression without count-time expressions.
 *
 * @param functionName name of the function being called
 * @param functionAttributes attribute map, stored by reference (not copied)
 * @param expressions argument expressions, stored by reference (not copied)
 */
public FunctionExpression(
    String functionName,
    LinkedHashMap<String, String> functionAttributes,
    List<Expression> expressions) {
  // Same field assignments as the four-argument constructor, with countTimeExpressions = null.
  this(functionName, functionAttributes, expressions, null);
}
/**
 * Creates a function expression with all of its parts supplied by the caller.
 *
 * <p>None of the arguments is copied; the instance keeps the references it is given.
 *
 * @param functionName name of the function being called
 * @param functionAttributes attribute map of the call
 * @param expressions argument expressions of the call
 * @param countTimeExpressions value later returned by {@link #getCountTimeExpressions()}
 */
public FunctionExpression(
    String functionName,
    LinkedHashMap<String, String> functionAttributes,
    List<Expression> expressions,
    List<Expression> countTimeExpressions) {
  this.countTimeExpressions = countTimeExpressions;
  this.expressions = expressions;
  this.functionAttributes = functionAttributes;
  this.functionName = functionName;
}
/**
 * Reads a function expression from {@code byteBuffer}.
 *
 * <p>Read order: function name, attribute map, argument count, then that many argument
 * expressions. This is the order in which {@code serialize(ByteBuffer)} writes them. The
 * count-time expressions are not read and stay {@code null}.
 *
 * @param byteBuffer buffer positioned at the start of the serialized fields
 */
public FunctionExpression(ByteBuffer byteBuffer) {
  this.functionName = ReadWriteIOUtils.readString(byteBuffer);
  this.functionAttributes = ReadWriteIOUtils.readLinkedHashMap(byteBuffer);
  final List<Expression> arguments = new ArrayList<>();
  for (int remaining = ReadWriteIOUtils.readInt(byteBuffer); remaining > 0; remaining--) {
    arguments.add(Expression.deserialize(byteBuffer));
  }
  this.expressions = arguments;
}
/**
 * Dispatches to {@code visitor.visitFunctionExpression(this, context)} and returns its result.
 *
 * @param visitor visitor to dispatch to
 * @param context context object handed through to the visitor unchanged
 */
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
  final R visitResult = visitor.visitFunctionExpression(this, context);
  return visitResult;
}
/**
 * Resolves {@code functionType} from {@code functionName}.
 *
 * <p>Checks, in order: the built-in aggregation names, the built-in scalar names (both against
 * the lower-cased function name), then the UDF management service's UDAF check (which is given
 * the name exactly as stored). Anything that matches none of these is classified as a UDTF.
 */
private void initializeFunctionType() {
  // Lower-case with Locale.ROOT so the lookup does not depend on the JVM default locale. With
  // the no-argument toLowerCase(), a Turkish default locale maps 'I' to a dotless 'ı', so an
  // upper-case name containing 'I' would not match its lower-case built-in name.
  // NOTE(review): assumes the native function name sets hold plain lower-case names - confirm.
  final String lowerCaseFunctionName = functionName.toLowerCase(java.util.Locale.ROOT);
  if (BuiltinAggregationFunction.getNativeFunctionNames().contains(lowerCaseFunctionName)) {
    functionType = FunctionType.BUILT_IN_AGGREGATION_FUNCTION;
  } else if (BuiltinScalarFunction.getNativeFunctionNames().contains(lowerCaseFunctionName)) {
    functionType = FunctionType.BUILT_IN_SCALAR_FUNCTION;
  } else if (UDFManagementService.getInstance().isUDAF(functionName)) {
    functionType = FunctionType.UDAF;
  } else {
    functionType = FunctionType.UDTF;
  }
}
/**
 * Tells whether this function resolved to a built-in aggregation function, resolving the
 * function type first if that has not happened yet.
 */
@Override
public boolean isBuiltInAggregationFunctionExpression() {
  final boolean unresolved = functionType == null;
  if (unresolved) {
    initializeFunctionType();
  }
  return FunctionType.BUILT_IN_AGGREGATION_FUNCTION == functionType;
}
/**
 * Tells whether this function resolved to a built-in scalar function, resolving the function
 * type first if that has not happened yet.
 */
public boolean isBuiltInScalarFunctionExpression() {
  final boolean unresolved = functionType == null;
  if (unresolved) {
    initializeFunctionType();
  }
  return FunctionType.BUILT_IN_SCALAR_FUNCTION == functionType;
}
/**
 * Tells whether this function resolved to a user-defined aggregation function (UDAF), resolving
 * the function type first if that has not happened yet.
 */
@Override
public boolean isExternalAggregationFunctionExpression() {
  final boolean unresolved = functionType == null;
  if (unresolved) {
    initializeFunctionType();
  }
  return FunctionType.UDAF == functionType;
}
/**
 * Tells whether every argument expression is a constant operand.
 *
 * <p>The answer is computed once and kept in {@code isConstantOperandCache}. A call without
 * arguments yields {@code true}.
 */
@Override
public boolean isConstantOperandInternal() {
  if (isConstantOperandCache != null) {
    return isConstantOperandCache;
  }
  // Start from "constant" and flip on the first argument that is not.
  isConstantOperandCache = true;
  for (Expression argument : expressions) {
    if (argument.isConstantOperand()) {
      continue;
    }
    isConstantOperandCache = false;
    break;
  }
  return isConstantOperandCache;
}
/**
 * Tells whether this expression has the shape of a count over a wildcard path.
 *
 * <p>All of the following must hold: the expression is an aggregation function, it has exactly
 * one argument whose path is non-null, the tail node of that path equals the one-level or the
 * multi-level path wildcard, and the function name equals {@code IoTDBConstant.COLUMN_COUNT}.
 *
 * <p>NOTE(review): the name comparison is a case-sensitive {@code equals}, unlike the
 * lower-cased lookup used when resolving the function type - confirm that is intended.
 */
public boolean isCountStar() {
  if (!isAggregationFunctionExpression()) {
    return false;
  }
  if (getPaths().size() != 1) {
    return false;
  }
  if (paths.get(0) == null) {
    return false;
  }
  final boolean tailIsWildcard =
      paths.get(0).getTailNode().equals(IoTDBConstant.ONE_LEVEL_PATH_WILDCARD)
          || paths.get(0).getTailNode().equals(IoTDBConstant.MULTI_LEVEL_PATH_WILDCARD);
  if (!tailIsWildcard) {
    return false;
  }
  return functionName.equals(IoTDBConstant.COLUMN_COUNT);
}
/**
 * Adds (or overwrites) a function attribute.
 *
 * <p>The memoized parameter string includes the attributes, so it is dropped here and rebuilt
 * on the next request instead of being served stale.
 *
 * <p>NOTE(review): caches kept by the superclass (if any) are not visible in this class and are
 * not touched here.
 *
 * @param key attribute key
 * @param value attribute value
 */
public void addAttribute(String key, String value) {
  functionAttributes.put(key, value);
  parametersString = null;
}
/**
 * Appends an argument expression to the call.
 *
 * <p>Both the memoized path list and the memoized parameter string are derived from the
 * argument list, so they are dropped here and rebuilt on the next request instead of being
 * served stale.
 *
 * <p>NOTE(review): caches kept by the superclass (if any) are not visible in this class and are
 * not touched here.
 *
 * @param expression argument to append
 */
public void addExpression(Expression expression) {
  expressions.add(expression);
  paths = null;
  parametersString = null;
}
/**
 * Replaces the argument list of the call. The list is stored by reference, not copied.
 *
 * <p>Both the memoized path list and the memoized parameter string are derived from the
 * argument list, so they are dropped here and rebuilt on the next request; otherwise {@link
 * #getPaths()} would keep returning the paths of the previous arguments.
 *
 * <p>NOTE(review): caches kept by the superclass (if any) are not visible in this class and are
 * not touched here.
 *
 * @param expressions new argument expressions
 */
public void setExpressions(List<Expression> expressions) {
  this.expressions = expressions;
  this.paths = null;
  this.parametersString = null;
}
/** Returns the function name exactly as it was supplied at construction or deserialization. */
public String getFunctionName() {
  return this.functionName;
}
/**
 * Returns the attribute map of the call. This is the internal map itself, not a copy, so
 * changes made through the returned reference affect this expression.
 */
public LinkedHashMap<String, String> getFunctionAttributes() {
  return this.functionAttributes;
}
/**
 * Returns the argument expressions of the call. This is the internal list itself, not a copy.
 */
@Override
public List<Expression> getExpressions() {
  return this.expressions;
}
/**
 * Builds the output symbol of the call: {@code functionName(arg1, arg2, ..., attributes)}.
 *
 * <p>Each argument contributes its own output symbol, separated by {@code ", "}. When
 * attributes are present they are appended after the arguments, either by the helper of the
 * matching built-in scalar function or by {@link #appendAttributes}.
 */
@Override
public String getOutputSymbolInternal() {
  final StringBuilder parameters = new StringBuilder();
  String separator = "";
  for (Expression argument : expressions) {
    parameters.append(separator).append(argument.getOutputSymbol());
    separator = ", ";
  }
  if (functionAttributes.isEmpty()) {
    return functionName + "(" + parameters + ")";
  }
  final boolean hasArguments = !expressions.isEmpty();
  // Some built-in scalar functions may have different header.
  if (BuiltinScalarFunction.contains(functionName)) {
    BuiltInScalarFunctionHelperFactory.createHelper(functionName)
        .appendFunctionAttributes(hasArguments, parameters, functionAttributes);
  } else {
    appendAttributes(hasArguments, parameters, functionAttributes);
  }
  return functionName + "(" + parameters + ")";
}
/**
 * Registers a {@code UDTFExecutor} for this expression, keyed by its expression string, after
 * letting every argument register its own executors.
 *
 * <p>Nothing happens when the map already has an entry for this expression string.
 *
 * @param expressionName2Executor map from expression string to executor, updated in place
 * @param zoneId zone id handed to the newly created executor and to the arguments
 */
@Override
public void constructUdfExecutors(
    Map<String, UDTFExecutor> expressionName2Executor, ZoneId zoneId) {
  final String executorKey = getExpressionString();
  if (!expressionName2Executor.containsKey(executorKey)) {
    // Arguments first, so nested calls are registered before the call that contains them.
    expressions.forEach(
        argument -> argument.constructUdfExecutors(expressionName2Executor, zoneId));
    expressionName2Executor.put(executorKey, new UDTFExecutor(functionName, zoneId));
  }
}
/**
 * Binds input column indexes for every argument and then for this expression itself.
 *
 * <p>If {@code inputLocations} has an entry for this expression's string, {@code
 * inputColumnIndex} is set to the value column index of the first location in that entry;
 * otherwise it is left untouched.
 *
 * @param inputLocations map from expression string to its input locations
 */
@Override
public void bindInputLayerColumnIndexWithExpression(
    Map<String, List<InputLocation>> inputLocations) {
  expressions.forEach(
      argument -> argument.bindInputLayerColumnIndexWithExpression(inputLocations));
  final String ownKey = getExpressionString();
  if (!inputLocations.containsKey(ownKey)) {
    return;
  }
  final List<InputLocation> ownLocations = inputLocations.get(ownKey);
  inputColumnIndex = ownLocations.get(0).getValueColumnIndex();
}
/**
 * Lets every argument update the memory assigner's statistics, then increases the reference
 * count of this expression on the assigner.
 *
 * @param memoryAssigner assigner whose statistics are updated
 */
@Override
public void updateStatisticsForMemoryAssigner(LayerMemoryAssigner memoryAssigner) {
  expressions.forEach(argument -> argument.updateStatisticsForMemoryAssigner(memoryAssigner));
  memoryAssigner.increaseExpressionReference(this);
}
/**
 * Tells whether this expression is mappable.
 *
 * <p>For a function that is neither an aggregation nor a built-in scalar function, the access
 * strategy inferred for it must be {@code MAPPABLE_ROW_BY_ROW}. In every case, all argument
 * expressions must be mappable as well.
 *
 * @param expressionTypes data types of the expressions, looked up per argument
 */
@Override
public boolean isMappable(Map<NodeRef<Expression>, TSDataType> expressionTypes) {
  if (!(isAggregationFunctionExpression() || isBuiltInScalarFunctionExpression())) {
    // this is a UDF function
    final UDTFInformationInferrer inferrer = new UDTFInformationInferrer(functionName);
    final List<String> argumentStrings =
        expressions.stream().map(Expression::getExpressionString).collect(Collectors.toList());
    final List<TSDataType> argumentTypes =
        expressions.stream()
            .map(argument -> expressionTypes.get(NodeRef.of(argument)))
            .collect(Collectors.toList());
    final boolean mappableRowByRow =
        inferrer
            .getAccessStrategy(argumentStrings, argumentTypes, functionAttributes)
            .getAccessStrategyType()
            .equals(AccessStrategy.AccessStrategyType.MAPPABLE_ROW_BY_ROW);
    if (!mappableRowByRow) {
      return false;
    }
  }
  // Function expression is mappable only when all its child expressions are mappable
  for (Expression argument : expressions) {
    if (!argument.isMappable(expressionTypes)) {
      return false;
    }
  }
  return true;
}
/**
 * Returns one entry per argument expression: the path of the argument when it is a {@code
 * TimeSeriesOperand}, {@code null} otherwise.
 *
 * <p>The list is built on the first call and the same list instance is returned afterwards.
 */
public List<PartialPath> getPaths() {
  if (paths != null) {
    return paths;
  }
  paths = new ArrayList<>();
  for (Expression argument : expressions) {
    if (argument instanceof TimeSeriesOperand) {
      paths.add(((TimeSeriesOperand) argument).getPath());
    } else {
      // Keep positions aligned with the argument list.
      paths.add(null);
    }
  }
  return paths;
}
/**
 * Returns the count-time expressions given to the four-argument constructor, or {@code null}
 * when the instance was created any other way.
 */
public List<Expression> getCountTimeExpressions() {
  return countTimeExpressions;
}
/**
 * Builds the expression string of the call: the function name followed by the parameter string
 * in parentheses.
 */
@Override
public String getExpressionStringInternal() {
  return new StringBuilder()
      .append(functionName)
      .append("(")
      .append(getParametersString())
      .append(")")
      .toString();
}
/**
 * Produces the text that goes between the parentheses of the function column name.
 *
 * <p>Example:
 *
 * <p>Full column name -> udf(root.sg.d.s1, sin(root.sg.d.s1), 'key1'='value1', 'key2'='value2')
 *
 * <p>The parameter part -> root.sg.d.s1, sin(root.sg.d.s1), 'key1'='value1', 'key2'='value2'
 *
 * <p>The result is computed on the first call and cached in {@code parametersString}.
 */
private String getParametersString() {
  if (parametersString != null) {
    return parametersString;
  }
  final StringBuilder parameters = new StringBuilder();
  String separator = "";
  for (Expression argument : expressions) {
    parameters.append(separator).append(argument.getExpressionString());
    separator = ", ";
  }
  if (!functionAttributes.isEmpty()) {
    final boolean hasArguments = !expressions.isEmpty();
    // Some built-in scalar functions may have different header.
    if (BuiltinScalarFunction.contains(functionName)) {
      BuiltInScalarFunctionHelperFactory.createHelper(functionName)
          .appendFunctionAttributes(hasArguments, parameters, functionAttributes);
    } else {
      appendAttributes(hasArguments, parameters, functionAttributes);
    }
  }
  parametersString = parameters.toString();
  return parametersString;
}
/**
 * Appends the attributes to {@code builder} as {@code "key"="value"} pairs separated by {@code
 * ", "}, in the iteration order of the map.
 *
 * <p>When {@code hasExpression} is true, a leading {@code ", "} separates the attributes from
 * the text already in the builder. A {@code null} or empty attribute map leaves the builder
 * untouched; without this guard an empty map would append a dangling separator and then fail
 * with {@code NoSuchElementException}.
 *
 * @param hasExpression whether the builder already holds at least one argument
 * @param builder target the attribute text is appended to
 * @param functionAttributes attributes to render
 */
public static void appendAttributes(
    boolean hasExpression, StringBuilder builder, Map<String, String> functionAttributes) {
  if (functionAttributes == null || functionAttributes.isEmpty()) {
    return;
  }
  String separator = hasExpression ? ", " : "";
  for (Entry<String, String> attribute : functionAttributes.entrySet()) {
    builder
        .append(separator)
        .append("\"")
        .append(attribute.getKey())
        .append("\"=\"")
        .append(attribute.getValue())
        .append("\"");
    separator = ", ";
  }
}
/** Returns {@code ExpressionType.FUNCTION}, the type tag of this expression class. */
@Override
public ExpressionType getExpressionType() {
  final ExpressionType typeTag = ExpressionType.FUNCTION;
  return typeTag;
}
/**
 * Writes this expression to {@code byteBuffer}.
 *
 * <p>Write order: function name, attribute map, argument count, then each argument. The
 * count-time expressions are not written.
 *
 * @param byteBuffer buffer to write to
 */
@Override
protected void serialize(ByteBuffer byteBuffer) {
  ReadWriteIOUtils.write(functionName, byteBuffer);
  ReadWriteIOUtils.write(functionAttributes, byteBuffer);
  final int argumentCount = expressions.size();
  ReadWriteIOUtils.write(argumentCount, byteBuffer);
  expressions.forEach(argument -> Expression.serialize(argument, byteBuffer));
}
/**
 * Writes this expression to {@code stream}.
 *
 * <p>Write order: function name, attribute map, argument count, then each argument. The
 * count-time expressions are not written.
 *
 * @param stream stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
protected void serialize(DataOutputStream stream) throws IOException {
  ReadWriteIOUtils.write(functionName, stream);
  ReadWriteIOUtils.write(functionAttributes, stream);
  final int argumentCount = expressions.size();
  ReadWriteIOUtils.write(argumentCount, stream);
  for (int index = 0; index < argumentCount; index++) {
    Expression.serialize(expressions.get(index), stream);
  }
}
}