blob: 69a5f982c9f85e8040bc1f8d1d921ba625f1aa3c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.carbondata.common.exceptions.sql.MalformedCarbonCommandException;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.apache.spark.sql.catalyst.parser.ParserInterface;
import org.apache.spark.sql.execution.command.mutation.merge.DeleteAction;
import org.apache.spark.sql.execution.command.mutation.merge.InsertAction;
import org.apache.spark.sql.execution.command.mutation.merge.MergeAction;
import org.apache.spark.sql.execution.command.mutation.merge.UpdateAction;
import org.apache.spark.sql.merge.model.CarbonJoinExpression;
import org.apache.spark.sql.merge.model.CarbonMergeIntoModel;
import org.apache.spark.sql.merge.model.ColumnModel;
import org.apache.spark.sql.merge.model.TableModel;
import org.apache.spark.sql.parser.CarbonSqlBaseParser;
import org.apache.spark.util.SparkUtil;
/**
 * Walks the ANTLR parse tree produced by {@link CarbonSqlBaseParser} for a {@code MERGE INTO}
 * statement and converts it into CarbonData merge model objects
 * ({@link CarbonMergeIntoModel}, {@link MergeAction}, {@link TableModel}, {@link ColumnModel}).
 *
 * <p>SQL sub-expressions (join condition, clause conditions, assignment values) are re-parsed
 * from their source text with the Spark {@link ParserInterface} supplied at construction time.
 */
public class CarbonAntlrSqlVisitor {

  private final ParserInterface sparkParser;

  /**
   * @param sparkParser Spark parser used to turn expression text into Catalyst expressions
   */
  public CarbonAntlrSqlVisitor(ParserInterface sparkParser) {
    this.sparkParser = sparkParser;
  }

  /**
   * Returns the alias name from a table alias clause.
   *
   * <p>Handles {@code AS t}, a bare {@code t}, and either form followed by further children
   * (e.g. a column list): the identifier right after the optional {@code AS} is returned.
   *
   * @return the alias text, or {@code null} when the alias clause has no children
   */
  public String visitTableAlias(CarbonSqlBaseParser.TableAliasContext ctx) {
    if (ctx.getChildCount() == 0) {
      return null;
    }
    // A leading leaf node (a token with no children) whose text is AS is the optional keyword.
    // Reading child 1 unconditionally would NPE for a bare alias and pick the wrong node when
    // other children follow the alias.
    boolean startsWithAs = ctx.getChildCount() > 1
        && ctx.getChild(0).getChildCount() == 0
        && "AS".equalsIgnoreCase(ctx.getChild(0).getText());
    return ctx.getChild(startsWithAs ? 1 : 0).getText();
  }

  /**
   * Builds an {@link UpdateAction} from an {@code UPDATE SET col = expr, ...} assignment list.
   *
   * <p>Target column names are stripped of any qualifier, so {@code t.col} and {@code col} both
   * map to {@code col}. Value text is parsed with the Spark parser.
   *
   * @throws MalformedCarbonCommandException if a value expression cannot be parsed
   */
  public MergeAction visitCarbonAssignmentList(CarbonSqlBaseParser.AssignmentListContext ctx)
      throws MalformedCarbonCommandException {
    Map<Column, Column> assignments = new HashMap<>();
    for (int i = 0; i < ctx.getChildCount(); i++) {
      if (!(ctx.getChild(i) instanceof CarbonSqlBaseParser.AssignmentContext)) {
        // Separators between assignments are not AssignmentContext nodes; skip them.
        continue;
      }
      // The code reads the target from child 0 and the value from child 2,
      // i.e. it assumes the shape `target = value`.
      String target = unqualifiedName(ctx.getChild(i).getChild(0).getText());
      String value = ctx.getChild(i).getChild(2).getText();
      assignments.put(new Column(target), parseColumn(value));
    }
    return new UpdateAction(SparkUtil.convertMap(assignments), false);
  }

  /**
   * Converts the action of a {@code WHEN MATCHED} clause into a merge action.
   *
   * <ul>
   *   <li>a single child: {@link DeleteAction}</li>
   *   <li>trailing assignment list: {@link UpdateAction} with explicit assignments</li>
   *   <li>otherwise ({@code UPDATE SET *}): {@link UpdateAction} updating all columns</li>
   * </ul>
   *
   * @throws MalformedCarbonCommandException if an assignment value cannot be parsed
   */
  public MergeAction visitCarbonMatchedAction(CarbonSqlBaseParser.MatchedActionContext ctx)
      throws MalformedCarbonCommandException {
    if (ctx.getChildCount() == 1) {
      // WHEN MATCHED ... THEN DELETE
      return new DeleteAction();
    }
    int lastIdx = ctx.getChildCount() - 1;
    if (ctx.getChild(lastIdx) instanceof CarbonSqlBaseParser.AssignmentListContext) {
      // UPDATE SET assignmentList
      return visitCarbonAssignmentList(
          (CarbonSqlBaseParser.AssignmentListContext) ctx.getChild(lastIdx));
    }
    // UPDATE SET *
    return new UpdateAction(null, true);
  }

  /**
   * Converts the action of a {@code WHEN NOT MATCHED} clause into an {@link InsertAction}.
   *
   * <p>At most two children ({@code INSERT *}) yields an insert-all action; anything else yields
   * {@code InsertAction(null, false)}. Explicit column/value mappings are NOT built here.
   * {@link #visitMergeIntoCarbonTable} builds them itself.
   */
  public InsertAction visitCarbonNotMatchedAction(CarbonSqlBaseParser.NotMatchedActionContext ctx) {
    if (ctx.getChildCount() <= 2) {
      // INSERT *
      return InsertAction.apply(null, true);
    }
    return InsertAction.apply(null, false);
  }

  /**
   * Finds the action inside a {@code WHEN NOT MATCHED} clause and converts it.
   *
   * @throws IllegalArgumentException if the clause contains no not-matched action
   */
  public MergeAction visitCarbonNotMatchedClause(CarbonSqlBaseParser.NotMatchedClauseContext ctx) {
    for (int i = 0; i < ctx.getChildCount(); i++) {
      if (ctx.getChild(i) instanceof CarbonSqlBaseParser.NotMatchedActionContext) {
        return visitCarbonNotMatchedAction(
            (CarbonSqlBaseParser.NotMatchedActionContext) ctx.getChild(i));
      }
    }
    throw new IllegalArgumentException("Parse failed: no action found in not matched clause.");
  }

  /**
   * Finds the action inside a {@code WHEN MATCHED} clause and converts it.
   *
   * @throws MalformedCarbonCommandException if the clause contains no matched action or an
   *     assignment value cannot be parsed
   */
  public MergeAction visitCarbonMatchedClause(CarbonSqlBaseParser.MatchedClauseContext ctx)
      throws MalformedCarbonCommandException {
    // The clause has several children (keywords, optional condition); locate the action node.
    for (int i = 0; i < ctx.getChildCount(); i++) {
      if (ctx.getChild(i) instanceof CarbonSqlBaseParser.MatchedActionContext) {
        return visitCarbonMatchedAction((CarbonSqlBaseParser.MatchedActionContext) ctx.getChild(i));
      }
    }
    throw new MalformedCarbonCommandException("Parse failed: no action found in matched clause.");
  }

  /**
   * Returns whether a {@code WHEN MATCHED} clause with {@code childCount} children carries an
   * extra {@code AND condition}. Relies on the clause having more than 4 children only when a
   * condition is present.
   */
  public boolean containsWhenMatchedPredicateExpression(int childCount) {
    return childCount > 4;
  }

  /**
   * Returns whether a {@code WHEN NOT MATCHED} clause with {@code childCount} children carries
   * an extra {@code AND condition}. Relies on the clause having more than 5 children only when a
   * condition is present.
   */
  public boolean containsWhenNotMatchedPredicateExpression(int childCount) {
    return childCount > 5;
  }

  /**
   * Converts a whole {@code MERGE INTO} statement into a {@link CarbonMergeIntoModel}.
   *
   * <p>The model carries the target and source tables, the join expression, and two parallel
   * lists. Entry {@code i} of the expression list is the optional clause condition (null when
   * absent) of the clause whose action is entry {@code i} of the action list. Clause order
   * follows the statement.
   *
   * @throws MalformedCarbonCommandException if the base parser reported an error, or if the join
   *     condition, a clause condition, an assignment value, or an insert value cannot be parsed,
   *     or if an explicit INSERT column list does not match its value list
   */
  public CarbonMergeIntoModel visitMergeIntoCarbonTable(CarbonSqlBaseParser.MergeIntoContext ctx)
      throws MalformedCarbonCommandException {
    // The base parser records syntax errors on the context instead of throwing.
    if (ctx.exception != null) {
      throw new MalformedCarbonCommandException("Parse failed!");
    }
    TableModel targetTable = visitMultipartIdentifier(ctx.target);
    TableModel sourceTable = visitMultipartIdentifier(ctx.source);
    Expression joinExpression = null;
    List<Expression> mergeExpressions = new ArrayList<>();
    List<MergeAction> mergeActions = new ArrayList<>();
    for (int i = 0; i < ctx.getChildCount(); i++) {
      if (ctx.getChild(i) instanceof CarbonSqlBaseParser.PredicatedContext) {
        // ON <join condition>
        CarbonSqlBaseParser.PredicatedContext predicated =
            (CarbonSqlBaseParser.PredicatedContext) ctx.getChild(i);
        joinExpression = visitCarbonPredicated(predicated);
        if (joinExpression == null) {
          // visitComparison swallows parse errors and returns null; don't build a model
          // without a join condition.
          throw new MalformedCarbonCommandException(
              "Parse failed: invalid merge condition " + predicated.getText());
        }
      } else if (ctx.getChild(i) instanceof CarbonSqlBaseParser.MatchedClauseContext) {
        CarbonSqlBaseParser.MatchedClauseContext clause =
            (CarbonSqlBaseParser.MatchedClauseContext) ctx.getChild(i);
        Expression condition = null;
        if (containsWhenMatchedPredicateExpression(clause.getChildCount())) {
          condition = parseClauseCondition(clause.booleanExpression().getText());
        }
        mergeExpressions.add(condition);
        mergeActions.add(visitCarbonMatchedClause(clause));
      } else if (ctx.getChild(i) instanceof CarbonSqlBaseParser.NotMatchedClauseContext) {
        CarbonSqlBaseParser.NotMatchedClauseContext clause =
            (CarbonSqlBaseParser.NotMatchedClauseContext) ctx.getChild(i);
        Expression condition = null;
        if (containsWhenNotMatchedPredicateExpression(clause.getChildCount())) {
          condition = parseClauseCondition(clause.booleanExpression().getText());
        }
        mergeExpressions.add(condition);
        // The action is the last child of the clause.
        CarbonSqlBaseParser.NotMatchedActionContext action =
            (CarbonSqlBaseParser.NotMatchedActionContext) clause.getChild(
                clause.getChildCount() - 1);
        mergeActions.add(buildInsertAction(action));
      }
    }
    return new CarbonMergeIntoModel(targetTable, sourceTable, joinExpression,
        mergeExpressions, mergeActions);
  }

  /**
   * Builds the insert action for a {@code WHEN NOT MATCHED} action node.
   *
   * <ul>
   *   <li>At most two children ({@code INSERT *}): insert-all.</li>
   *   <li>An explicit column list with no asterisk: a column-to-value mapping.</li>
   *   <li>Otherwise: {@code InsertAction(null, false)}.</li>
   * </ul>
   */
  private InsertAction buildInsertAction(CarbonSqlBaseParser.NotMatchedActionContext action)
      throws MalformedCarbonCommandException {
    if (action.getChildCount() <= 2) {
      return InsertAction.apply(null, true);
    }
    if (action.ASTERISK() != null) {
      return InsertAction.apply(null, false);
    }
    List<CarbonSqlBaseParser.MultipartIdentifierContext> columns =
        action.columns.multipartIdentifier();
    if (columns.size() != action.expression().size()) {
      throw new MalformedCarbonCommandException("Parse failed: size of columns " +
          "is not equal to size of expression in not matched action.");
    }
    Map<Column, Column> insertMap = new HashMap<>();
    for (int i = 0; i < columns.size(); i++) {
      String column = visitMultipartIdentifier(columns.get(i), "").getColName();
      // The value side may be a literal or an arbitrary expression, not only a column.
      insertMap.put(new Column(column), parseColumn(action.expression().get(i).getText()));
    }
    return InsertAction.apply(SparkUtil.convertMap(insertMap), false);
  }

  /** Parses a WHEN [NOT] MATCHED clause condition, converting parse errors to Carbon's type. */
  private Expression parseClauseCondition(String text) throws MalformedCarbonCommandException {
    try {
      return sparkParser.parseExpression(text);
    } catch (ParseException e) {
      throw new MalformedCarbonCommandException("Parse failed: " + e.getMessage());
    }
  }

  /** Parses expression text with the Spark parser and wraps the result in a {@link Column}. */
  private Column parseColumn(String text) throws MalformedCarbonCommandException {
    try {
      return new Column(sparkParser.parseExpression(text));
    } catch (Exception e) {
      throw new MalformedCarbonCommandException("Parse failed: " + e.getMessage());
    }
  }

  /** Drops every dotted qualifier, returning the text after the last '.' (or all of it). */
  private static String unqualifiedName(String name) {
    return name.substring(name.lastIndexOf('.') + 1);
  }

  /**
   * Splits a comparison into table and column names for both operands.
   *
   * <p>Assumes each side is a dereference {@code table.column}: child 0 of the operand has the
   * table at index 0 and the column at index 2.
   */
  public CarbonJoinExpression visitComparison(CarbonSqlBaseParser.ComparisonContext ctx) {
    String t1Name = ctx.left.getChild(0).getChild(0).getText();
    String c1Name = ctx.left.getChild(0).getChild(2).getText();
    String t2Name = ctx.right.getChild(0).getChild(0).getText();
    String c2Name = ctx.right.getChild(0).getChild(2).getText();
    return new CarbonJoinExpression(t1Name, c1Name, t2Name, c2Name);
  }

  /**
   * Parses the comparison text into a Catalyst expression.
   *
   * @param x unused; only distinguishes this overload
   * @return the parsed expression, or {@code null} if the text cannot be parsed (the error's
   *     stack trace is printed); callers must handle the null
   */
  public Expression visitComparison(CarbonSqlBaseParser.ComparisonContext ctx, String x) {
    try {
      return sparkParser.parseExpression(ctx.getText());
    } catch (ParseException e) {
      e.printStackTrace();
      return null;
    }
  }

  /** Splits the comparison held in child 0 of a predicated node; see {@link #visitComparison}. */
  public CarbonJoinExpression visitPredicated(CarbonSqlBaseParser.PredicatedContext ctx) {
    return visitComparison((CarbonSqlBaseParser.ComparisonContext) ctx.getChild(0));
  }

  /**
   * Parses the comparison held in child 0 of a predicated node.
   *
   * @return the parsed expression, or {@code null} when parsing fails
   */
  public Expression visitCarbonPredicated(CarbonSqlBaseParser.PredicatedContext ctx) {
    return visitComparison((CarbonSqlBaseParser.ComparisonContext) ctx.getChild(0), "");
  }

  /**
   * Reads a {@code table.column} dereference.
   *
   * @return a model with table and column set when the node has exactly three children,
   *     otherwise an empty {@link ColumnModel}
   */
  public ColumnModel visitDereference(CarbonSqlBaseParser.DereferenceContext ctx) {
    if (ctx.getChildCount() != 3) {
      return new ColumnModel();
    }
    return new ColumnModel(ctx.getChild(0).getText(), ctx.getChild(2).getText());
  }

  /**
   * Reads a table name of the form {@code table} or {@code database.table}.
   *
   * @return the table model; left empty when the identifier has any other number of parts
   */
  public TableModel visitMultipartIdentifier(CarbonSqlBaseParser.MultipartIdentifierContext ctx) {
    TableModel table = new TableModel();
    List<CarbonSqlBaseParser.ErrorCapturingIdentifierContext> parts = ctx.parts;
    if (parts.size() == 2) {
      table.setDatabase(parts.get(0).getText());
      table.setTable(parts.get(1).getText());
    } else if (parts.size() == 1) {
      table.setTable(parts.get(0).getText());
    }
    return table;
  }

  /**
   * Reads a column name of the form {@code column} or {@code table.column}.
   *
   * @param x unused; only distinguishes this overload from the table variant
   * @return the column model; left empty when the identifier has any other number of parts
   */
  public ColumnModel visitMultipartIdentifier(CarbonSqlBaseParser.MultipartIdentifierContext ctx,
      String x) {
    ColumnModel column = new ColumnModel();
    List<CarbonSqlBaseParser.ErrorCapturingIdentifierContext> parts = ctx.parts;
    if (parts.size() == 2) {
      column.setTable(parts.get(0).getText());
      column.setColName(parts.get(1).getText());
    } else if (parts.size() == 1) {
      column.setColName(parts.get(0).getText());
    }
    return column;
  }

  /** Returns the text of the first child of an unquoted identifier. */
  public String visitUnquotedIdentifier(CarbonSqlBaseParser.UnquotedIdentifierContext ctx) {
    return ctx.getChild(0).getText();
  }

  /** Returns the operator text (first child) of a comparison operator node. */
  public String visitComparisonOperator(CarbonSqlBaseParser.ComparisonOperatorContext ctx) {
    return ctx.getChild(0).getText();
  }

  /** Returns the text of the first child of a table identifier. */
  public String visitTableIdentifier(CarbonSqlBaseParser.TableIdentifierContext ctx) {
    return ctx.getChild(0).getText();
  }
}