blob: f93e843ad4b82c1dcb7583291d6b8d3decb45452 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.asterix.lang.sqlpp.rewrites.visitor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.asterix.common.exceptions.CompilationException;
import org.apache.asterix.common.exceptions.ErrorCode;
import org.apache.asterix.common.functions.FunctionSignature;
import org.apache.asterix.lang.common.base.AbstractClause;
import org.apache.asterix.lang.common.base.Expression;
import org.apache.asterix.lang.common.base.ILangExpression;
import org.apache.asterix.lang.common.clause.WhereClause;
import org.apache.asterix.lang.common.expression.CallExpr;
import org.apache.asterix.lang.common.expression.FieldAccessor;
import org.apache.asterix.lang.common.expression.VariableExpr;
import org.apache.asterix.lang.common.rewrites.LangRewritingContext;
import org.apache.asterix.lang.common.struct.Identifier;
import org.apache.asterix.lang.common.struct.VarIdentifier;
import org.apache.asterix.lang.sqlpp.clause.FromClause;
import org.apache.asterix.lang.sqlpp.clause.FromTerm;
import org.apache.asterix.lang.sqlpp.clause.SelectBlock;
import org.apache.asterix.lang.sqlpp.clause.SelectClause;
import org.apache.asterix.lang.sqlpp.clause.SelectElement;
import org.apache.asterix.lang.sqlpp.clause.SelectSetOperation;
import org.apache.asterix.lang.sqlpp.expression.SelectExpression;
import org.apache.asterix.lang.sqlpp.struct.SetOperationInput;
import org.apache.asterix.lang.sqlpp.util.FunctionMapUtil;
import org.apache.asterix.lang.sqlpp.util.SqlppRewriteUtil;
import org.apache.asterix.lang.sqlpp.util.SqlppVariableUtil;
import org.apache.asterix.lang.sqlpp.visitor.base.AbstractSqlppSimpleExpressionVisitor;
import org.apache.hyracks.api.exceptions.SourceLocation;
/**
 * Rewrites SQL-92 aggregate function into a SQL++ core aggregate function. <br/>
 * For example
 * <code>SUM(e.salary + i.bonus)</code>
 * is turned into
 * <code>array_sum( (FROM g AS gi SELECT ELEMENT gi.e.salary + gi.i.bonus) )</code>
 * where <code>g</code> is a 'group as' variable.
 * <br/>
 * If the SQL-92 aggregate function call contains a filter expression then that filter expression
 * becomes a WHERE clause. <br/>
 * For example
 * <code>SUM(e.salary + i.bonus) FILTER (WHERE e.dept = 100)</code>
 * is turned into
 * <code>array_sum( (FROM g AS gi WHERE gi.e.dept = 100 SELECT ELEMENT gi.e.salary + gi.i.bonus) )</code>
 */
class Sql92AggregateFunctionVisitor extends AbstractSqlppSimpleExpressionVisitor {

    private final LangRewritingContext context;

    private final Expression groupVar;

    private final Map<VariableExpr, Identifier> groupVarFieldMap;

    private final Collection<VariableExpr> preGroupContextVars;

    private final Collection<VariableExpr> preGroupUnmappedVars;

    private final Collection<VariableExpr> outerVars;

    /**
     * Creates a visitor that rewrites every SQL-92 aggregate call it encounters so that the
     * aggregate iterates over {@code groupVar}.
     *
     * @param context
     *            rewriting context; used to allocate the fresh per-group-item variable
     * @param groupVar
     *            the 'group as' variable that rewritten aggregate arguments iterate over
     * @param groupVarFieldMap
     *            maps a variable to the name of the field of each group item that holds its value
     * @param preGroupContextVars
     *            candidate variables consulted (via {@code VariableCheckAndRewriteVisitor.pickContextVar})
     *            when a free variable is not directly present in {@code groupVarFieldMap}
     * @param preGroupUnmappedVars
     *            variables whose use inside an aggregate argument is rejected; may be {@code null}
     * @param outerVars
     *            variables that are left untouched inside aggregate arguments; may be {@code null}
     */
    Sql92AggregateFunctionVisitor(LangRewritingContext context, VariableExpr groupVar,
            Map<VariableExpr, Identifier> groupVarFieldMap, Collection<VariableExpr> preGroupContextVars,
            Collection<VariableExpr> preGroupUnmappedVars, Collection<VariableExpr> outerVars) {
        this.context = context;
        this.groupVar = groupVar;
        this.groupVarFieldMap = groupVarFieldMap;
        this.preGroupContextVars = preGroupContextVars;
        this.preGroupUnmappedVars = preGroupUnmappedVars;
        this.outerVars = outerVars;
    }

    /**
     * Rewrites {@code callExpr} in place if it is a SQL-92 aggregate function call; otherwise
     * delegates to the default traversal.
     */
    @Override
    public Expression visit(CallExpr callExpr, ILangExpression arg) throws CompilationException {
        FunctionSignature signature = callExpr.getFunctionSignature();
        if (FunctionMapUtil.isSql92AggregateFunction(signature)) {
            rewriteSql92AggregateFunction(callExpr, arg);
            return callExpr;
        } else {
            return super.visit(callExpr, arg);
        }
    }

    /**
     * Mutates {@code callExpr}: its single argument (and optional FILTER expression) is wrapped into
     * a subquery over the group variable, the function is renamed to its core counterpart, and the
     * aggregate filter is cleared.
     *
     * @throws CompilationException
     *             if the call does not have exactly one argument
     */
    private void rewriteSql92AggregateFunction(CallExpr callExpr, ILangExpression arg) throws CompilationException {
        FunctionSignature signature = callExpr.getFunctionSignature();
        List<Expression> argList = callExpr.getExprList();
        if (argList.size() != 1) {
            // Only unary SQL-92 aggregate functions are supported (e.g. binary ones are not yet supported).
            throw new CompilationException(ErrorCode.COMPILATION_INVALID_PARAMETER_NUMBER, callExpr.getSourceLocation(),
                    signature.getName(), argList.size());
        }
        Expression filterExpr = callExpr.getAggregateFilterExpr();
        Expression expr = argList.get(0);
        Expression newExpr = wrapAggregationArgument(expr, filterExpr, groupVar, groupVarFieldMap, preGroupContextVars,
                preGroupUnmappedVars, outerVars, context);
        // The new argument is itself visited so that aggregate calls nested inside it are rewritten too.
        // A mutable list is kept because later rewriting stages may modify the argument list.
        List<Expression> newExprList = new ArrayList<>(1);
        newExprList.add(newExpr.accept(this, arg));
        // Rewrites the SQL-92 function name to core functions,
        // e.g., SUM --> array_sum
        callExpr.setFunctionSignature(FunctionMapUtil.sql92ToCoreAggregateFunction(signature));
        callExpr.setExprList(newExprList);
        // The filter now lives in the WHERE clause of the generated subquery.
        callExpr.setAggregateFilterExpr(null);
    }

    /**
     * Builds {@code FROM groupVar AS gi [WHERE filter'] SELECT ELEMENT expr'}, where {@code expr'}
     * and {@code filter'} are {@code expr} and {@code filterExpr} with their free variables rewritten
     * into accesses on the fresh group item variable {@code gi}.
     *
     * @param expr
     *            the aggregate function argument
     * @param filterExpr
     *            the aggregate FILTER expression, or {@code null} if there is none
     * @param groupVar
     *            the collection to iterate over in the generated FROM clause
     * @param groupVarFieldMap
     *            maps a variable to the field of each group item that holds its value
     * @param preGroupContextVars
     *            candidate variables for resolving free variables not directly in {@code groupVarFieldMap}
     * @param preGroupUnmappedVars
     *            variables that must not be referenced; may be {@code null}
     * @param outerVars
     *            variables that are left untouched; may be {@code null}
     * @param context
     *            rewriting context used to allocate the group item variable
     * @return the generated select expression
     * @throws CompilationException
     *             if {@code expr} or {@code filterExpr} references a forbidden or unresolvable variable
     */
    static Expression wrapAggregationArgument(Expression expr, Expression filterExpr, Expression groupVar,
            Map<VariableExpr, Identifier> groupVarFieldMap, Collection<VariableExpr> preGroupContextVars,
            Collection<VariableExpr> preGroupUnmappedVars, Collection<VariableExpr> outerVars,
            LangRewritingContext context) throws CompilationException {
        SourceLocation sourceLoc = expr.getSourceLocation();

        // From clause
        VariableExpr groupItemVar = new VariableExpr(context.newVariable());
        groupItemVar.setSourceLocation(sourceLoc);
        FromTerm fromTerm = new FromTerm(groupVar, groupItemVar, null, null);
        fromTerm.setSourceLocation(sourceLoc);
        FromClause fromClause = new FromClause(Collections.singletonList(fromTerm));
        fromClause.setSourceLocation(sourceLoc);

        // Where clause if filter expression is present
        List<AbstractClause> whereClauseList = null;
        if (filterExpr != null) {
            Expression newFilterExpr = rewriteAggregationArgumentExpr(filterExpr, groupItemVar, groupVarFieldMap,
                    preGroupContextVars, preGroupUnmappedVars, outerVars, context);
            WhereClause whereClause = new WhereClause(newFilterExpr);
            whereClause.setSourceLocation(sourceLoc);
            whereClauseList = new ArrayList<>(1);
            whereClauseList.add(whereClause);
        }

        // Select clause.
        Expression newExpr = rewriteAggregationArgumentExpr(expr, groupItemVar, groupVarFieldMap, preGroupContextVars,
                preGroupUnmappedVars, outerVars, context);
        SelectElement selectElement = new SelectElement(newExpr);
        selectElement.setSourceLocation(sourceLoc);
        SelectClause selectClause = new SelectClause(selectElement, null, false);
        selectClause.setSourceLocation(sourceLoc);

        // Construct the select expression.
        SelectBlock selectBlock = new SelectBlock(selectClause, fromClause, whereClauseList, null, null);
        selectBlock.setSourceLocation(sourceLoc);
        SelectSetOperation selectSetOperation = new SelectSetOperation(new SetOperationInput(selectBlock, null), null);
        selectSetOperation.setSourceLocation(sourceLoc);
        SelectExpression selectExpr = new SelectExpression(null, selectSetOperation, null, null, true);
        selectExpr.setSourceLocation(sourceLoc);
        return selectExpr;
    }

    /**
     * Rewrites the free variables of {@code expr} so that they refer to the group item variable.
     * <ul>
     * <li>A variable in {@code groupVarFieldMap} becomes {@code groupItemVar.<mapped field>}.</li>
     * <li>A variable in {@code outerVars} is left as is.</li>
     * <li>A variable in {@code preGroupUnmappedVars} is rejected.</li>
     * <li>Any other variable is resolved through {@code preGroupContextVars}. The group item field of
     * that context variable is then accessed via {@code VariableCheckAndRewriteVisitor.generateFieldAccess}.</li>
     * </ul>
     *
     * @return the result of {@code SqlppRewriteUtil.substituteExpression} with the computed mapping
     * @throws CompilationException
     *             {@code COMPILATION_ILLEGAL_USE_OF_IDENTIFIER} for a variable in {@code preGroupUnmappedVars};
     *             {@code COMPILATION_ILLEGAL_STATE} if no group item field can be found for a variable
     */
    private static Expression rewriteAggregationArgumentExpr(Expression expr, VariableExpr groupItemVar,
            Map<VariableExpr, Identifier> groupVarFieldMap, Collection<VariableExpr> preGroupContextVars,
            Collection<VariableExpr> preGroupUnmappedVars, Collection<VariableExpr> outerVars,
            LangRewritingContext context) throws CompilationException {
        // Maps field variable expressions to field accesses.
        Set<VariableExpr> freeVars = SqlppRewriteUtil.getFreeVariable(expr);
        Map<Expression, Expression> varExprMap = new HashMap<>();
        for (VariableExpr usedVar : freeVars) {
            if (groupVarFieldMap.containsKey(usedVar)) {
                // Reference to a field in the group variable:
                // rewrites to a reference to a field in the group variable.
                FieldAccessor fa =
                        new FieldAccessor(groupItemVar, new VarIdentifier(groupVarFieldMap.get(usedVar).getValue()));
                fa.setSourceLocation(usedVar.getSourceLocation());
                varExprMap.put(usedVar, fa);
            } else if (containsVar(outerVars, usedVar)) {
                // Variables from enclosing scopes are kept unchanged.
            } else if (containsVar(preGroupUnmappedVars, usedVar)) {
                throw new CompilationException(ErrorCode.COMPILATION_ILLEGAL_USE_OF_IDENTIFIER,
                        expr.getSourceLocation(),
                        SqlppVariableUtil.toUserDefinedVariableName(usedVar.getVar().getValue()).getValue());
            } else {
                // Rewrites to a reference to a single field in the group variable.
                VariableExpr preGroupVar = VariableCheckAndRewriteVisitor.pickContextVar(preGroupContextVars, usedVar);
                Identifier groupVarField = groupVarFieldMap.get(preGroupVar);
                if (groupVarField == null) {
                    throw new CompilationException(ErrorCode.COMPILATION_ILLEGAL_STATE, expr.getSourceLocation());
                }
                FieldAccessor faInner = new FieldAccessor(groupItemVar, groupVarField);
                faInner.setSourceLocation(usedVar.getSourceLocation());
                Expression faOuter = VariableCheckAndRewriteVisitor.generateFieldAccess(faInner, usedVar.getVar(),
                        usedVar.getSourceLocation());
                varExprMap.put(usedVar, faOuter);
            }
        }
        return SqlppRewriteUtil.substituteExpression(expr, varExprMap, context);
    }

    /**
     * Null-safe membership test: a {@code null} collection is treated as empty.
     */
    private static boolean containsVar(Collection<VariableExpr> vars, VariableExpr var) {
        return vars != null && vars.contains(var);
    }
}