blob: 1542d1e9c0fa9abfa94665f5e5a165c95b8d63ee [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.asterix.bad.rules;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.asterix.bad.BADConstants;
import org.apache.asterix.bad.function.BADFunctions;
import org.apache.asterix.metadata.declared.DataSource;
import org.apache.asterix.om.base.AString;
import org.apache.asterix.om.constants.AsterixConstantValue;
import org.apache.asterix.om.functions.BuiltinFunctions;
import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalVariable;
import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.ConstantExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.ScalarFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.AssignOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.DataSourceScanOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.InnerJoinOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.SelectOperator;
import org.apache.hyracks.algebricks.core.rewriter.base.IAlgebraicRewriteRule;
public class RewriteChannelTimeFunctionToLocalVarRule implements IAlgebraicRewriteRule {
@Override
public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
throws AlgebricksException {
Mutable<ILogicalExpression> exprRef;
if (opRef.getValue().getOperatorTag() == LogicalOperatorTag.INNERJOIN) {
InnerJoinOperator selectOp = (InnerJoinOperator) opRef.getValue();
exprRef = selectOp.getCondition();
} else if (opRef.getValue().getOperatorTag() == LogicalOperatorTag.SELECT) {
SelectOperator selectOp = (SelectOperator) opRef.getValue();
exprRef = selectOp.getCondition();
} else {
return false;
}
Set<Mutable<ILogicalExpression>> activeFunctionSet = new HashSet<>();
Set<LogicalVariable> needPrevDsSet = new HashSet<>();
Set<LogicalVariable> needCurrDsSet = new HashSet<>();
String channelName =
(String) context.getMetadataProvider().getConfig().getOrDefault(BADConstants.CONFIG_CHANNEL_NAME, "");
// collect active functions
boolean rewriteFunc = collectChannelTimeFunctions(exprRef, activeFunctionSet, needPrevDsSet, needCurrDsSet,
true, channelName);
if (activeFunctionSet.size() == 0) {
return rewriteFunc;
}
// add assigns for active functions
Map<LogicalVariable, LogicalVariable> prevMap = new HashMap<>();
Map<LogicalVariable, LogicalVariable> currMap = new HashMap<>();
createChannelTimeAssignOps(opRef, needPrevDsSet, needCurrDsSet, prevMap, currMap, context, channelName);
// update expressions with new vars
updateActiveFuncExprsWithVars(prevMap, currMap, activeFunctionSet);
context.computeAndSetTypeEnvironmentForOperator(opRef.getValue());
return true;
}
private void updateActiveFuncExprsWithVars(Map<LogicalVariable, LogicalVariable> prevMap,
Map<LogicalVariable, LogicalVariable> currMap, Set<Mutable<ILogicalExpression>> activeFuncExprs) {
for (Mutable<ILogicalExpression> expr : activeFuncExprs) {
AbstractFunctionCallExpression funcExpr = (AbstractFunctionCallExpression) expr.getValue();
LogicalVariable dsVar = ((VariableReferenceExpression) ((AbstractFunctionCallExpression) expr.getValue())
.getArguments().get(0).getValue()).getVariableReference();
if (funcExpr.getFunctionIdentifier() == BADFunctions.CURRENT_CHANNEL_TIME) {
expr.setValue(new VariableReferenceExpression(currMap.get(dsVar)));
} else if (funcExpr.getFunctionIdentifier() == BADFunctions.PREVIOUS_CHANNEL_TIME) {
expr.setValue(new VariableReferenceExpression(prevMap.get(dsVar)));
}
}
}
private void createChannelTimeAssignOps(Mutable<ILogicalOperator> opRef, Set<LogicalVariable> needPrevDsSet,
Set<LogicalVariable> needCurrDsSet, Map<LogicalVariable, LogicalVariable> prevMap,
Map<LogicalVariable, LogicalVariable> currMap, IOptimizationContext context, String channelName) {
ILogicalOperator currOp = opRef.getValue();
if (opRef.getValue().getOperatorTag() == LogicalOperatorTag.DATASOURCESCAN) {
DataSourceScanOperator dataScanOp = (DataSourceScanOperator) opRef.getValue();
DataSource ds = (DataSource) dataScanOp.getDataSource();
LogicalVariable dsVar = ds.getDataRecordVariable(dataScanOp.getScanVariables());
if (needPrevDsSet.contains(dsVar)) {
LogicalVariable channelTimeVar = context.newVar();
ILogicalExpression previousChannelTimeExpr = new ScalarFunctionCallExpression(
BuiltinFunctions.getBuiltinFunctionInfo(BADFunctions.PREVIOUS_CHANNEL_TIME),
new MutableObject<>(
new ConstantExpression(new AsterixConstantValue(new AString(channelName)))));
AssignOperator assignOp =
new AssignOperator(channelTimeVar, new MutableObject<>(previousChannelTimeExpr));
assignOp.getInputs().add(new MutableObject<>(opRef.getValue()));
opRef.setValue(assignOp);
prevMap.put(dsVar, channelTimeVar);
}
if (needCurrDsSet.contains(dsVar)) {
LogicalVariable channelTimeVar = context.newVar();
ILogicalExpression previousChannelTimeExpr = new ScalarFunctionCallExpression(
BuiltinFunctions.getBuiltinFunctionInfo(BADFunctions.CURRENT_CHANNEL_TIME), new MutableObject<>(
new ConstantExpression(new AsterixConstantValue(new AString(channelName)))));
AssignOperator assignOp =
new AssignOperator(channelTimeVar, new MutableObject<>(previousChannelTimeExpr));
assignOp.getInputs().add(new MutableObject<>(opRef.getValue()));
opRef.setValue(assignOp);
currMap.put(dsVar, channelTimeVar);
}
}
for (Mutable<ILogicalOperator> input : currOp.getInputs()) {
createChannelTimeAssignOps(input, needPrevDsSet, needCurrDsSet, prevMap, currMap, context, channelName);
}
}
private boolean collectChannelTimeFunctions(Mutable<ILogicalExpression> exprRef,
Set<Mutable<ILogicalExpression>> activeFunctionSet, Set<LogicalVariable> needPrevDsSet,
Set<LogicalVariable> needCurrDsSet, boolean disjunctiveFlag, String channelName) {
boolean rewriteFunc = false;
if (exprRef.getValue().getExpressionTag() == LogicalExpressionTag.FUNCTION_CALL) {
AbstractFunctionCallExpression funcExpr = (AbstractFunctionCallExpression) exprRef.getValue();
if (funcExpr.getFunctionIdentifier().equals(BuiltinFunctions.OR)) {
disjunctiveFlag = false;
}
if (funcExpr.getFunctionIdentifier() == BADFunctions.PREVIOUS_CHANNEL_TIME
|| funcExpr.getFunctionIdentifier() == BADFunctions.CURRENT_CHANNEL_TIME) {
// collect ds var to see what assign op needs to be added
LogicalExpressionTag arg0ExprTag = funcExpr.getArguments().get(0).getValue().getExpressionTag();
if (arg0ExprTag == LogicalExpressionTag.CONSTANT) {
return false;
}
if (!disjunctiveFlag) {
LogicalVariable dsVar = ((VariableReferenceExpression) funcExpr.getArguments().get(0).getValue())
.getVariableReference();
activeFunctionSet.add(exprRef);
if (funcExpr.getFunctionIdentifier() == BADFunctions.PREVIOUS_CHANNEL_TIME) {
needPrevDsSet.add(dsVar);
} else if (funcExpr.getFunctionIdentifier() == BADFunctions.CURRENT_CHANNEL_TIME) {
needCurrDsSet.add(dsVar);
}
} else {
// if disjunctive, modify to get ts and wait for introduce filter rule to work
funcExpr.getArguments().set(0, new MutableObject<>(
new ConstantExpression(new AsterixConstantValue(new AString(channelName)))));
rewriteFunc = true;
}
} else {
for (Mutable<ILogicalExpression> argExpr : funcExpr.getArguments()) {
rewriteFunc = rewriteFunc || collectChannelTimeFunctions(argExpr, activeFunctionSet, needPrevDsSet,
needCurrDsSet, disjunctiveFlag, channelName);
}
}
}
return rewriteFunc;
}
}