| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.submarine.spark.security.parser |
| |
| import org.antlr.v4.runtime.{CharStreams, CommonTokenStream} |
| import org.antlr.v4.runtime.atn.PredictionMode |
| import org.antlr.v4.runtime.misc.ParseCancellationException |
| import org.apache.spark.sql.AnalysisException |
| import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} |
| import org.apache.spark.sql.catalyst.expressions.Expression |
| import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface, PostProcessor} |
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan |
| import org.apache.spark.sql.catalyst.trees.Origin |
| import org.apache.spark.sql.types.{DataType, StructType} |
| |
/**
 * A [[ParserInterface]] that first tries Submarine's own SQL grammar and falls back to
 * `delegate` for anything that grammar does not turn into a [[LogicalPlan]].
 *
 * Only [[parsePlan]] consults the Submarine grammar; every other parse method is forwarded
 * to `delegate` unchanged.
 *
 * @param delegate the parser used for statements the Submarine grammar does not build a plan
 *                 for, and for all non-plan parsing (expressions, identifiers, types, ...)
 */
class SubmarineSqlParser(val delegate: ParserInterface) extends ParserInterface {

  private val astBuilder = new SubmarineSqlAstBuilder

  /**
   * Parses `sqlText` into a [[LogicalPlan]].
   *
   * The text is first parsed with the Submarine grammar. If the AST builder yields a
   * [[LogicalPlan]], that plan is returned. Otherwise (any other value, including `null`),
   * `sqlText` is handed to `delegate.parsePlan`.
   *
   * The delegate is called outside of [[parse]] on purpose. The SLL/LL retry and the
   * exception translation in [[parse]] then apply only to the Submarine grammar. The
   * delegate's own failures are neither retried through the Submarine grammar nor rewrapped.
   *
   * @param sqlText the SQL statement to parse
   * @return the Submarine plan, or the plan produced by `delegate`
   */
  override def parsePlan(sqlText: String): LogicalPlan = {
    val submarineResult = parse(sqlText) { parser =>
      astBuilder.visit(parser.singleStatement())
    }
    submarineResult match {
      case plan: LogicalPlan => plan
      case _ => delegate.parsePlan(sqlText)
    }
  }

  // scalastyle:off line.size.limit
  /**
   * Forked from `org.apache.spark.sql.catalyst.parser.AbstractSqlParser#parse(java.lang.String, scala.Function1)`.
   *
   * Lexes and parses `command` with the Submarine grammar and applies `toResult` to the parser.
   * Spark's `ParseErrorListener` replaces the default ANTLR error listeners on both the lexer
   * and the parser, and Spark's `PostProcessor` is registered as a parse listener.
   *
   * Parsing is first attempted in SLL prediction mode. If `toResult` raises a
   * [[ParseCancellationException]], the token stream is rewound, the parser is reset, and
   * `toResult` is run again in LL prediction mode.
   *
   * Error translation:
   *  - a [[ParseException]] that already carries a command is rethrown as-is;
   *  - a [[ParseException]] without a command is rethrown with `command` attached;
   *  - an [[AnalysisException]] is converted into a [[ParseException]] positioned at the
   *    exception's `line` / `startPosition`.
   *
   * @see https://github.com/apache/spark/blob/v2.4.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala#L81
   */
  // scalastyle:on
  private def parse[T](command: String)(toResult: SubmarineSqlBaseParser => T): T = {
    val lexer = new SubmarineSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
    lexer.removeErrorListeners()
    lexer.addErrorListener(ParseErrorListener)

    val tokenStream = new CommonTokenStream(lexer)
    val parser = new SubmarineSqlBaseParser(tokenStream)
    parser.addParseListener(PostProcessor)
    parser.removeErrorListeners()
    parser.addErrorListener(ParseErrorListener)

    try {
      try {
        // First, try parsing with the potentially faster SLL prediction mode.
        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
        toResult(parser)
      } catch {
        case _: ParseCancellationException =>
          // SLL bailed out: rewind the input and retry with full LL prediction.
          tokenStream.seek(0)
          parser.reset()
          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
          toResult(parser)
      }
    } catch {
      case e: ParseException if e.command.isDefined =>
        throw e
      case e: ParseException =>
        throw e.withCommand(command)
      case e: AnalysisException =>
        val position = Origin(e.line, e.startPosition)
        throw new ParseException(Option(command), e.message, position, position)
    }
  }

  /** Forwards to `delegate`; the Submarine grammar is not consulted. */
  override def parseExpression(sqlText: String): Expression =
    delegate.parseExpression(sqlText)

  /** Forwards to `delegate`; the Submarine grammar is not consulted. */
  override def parseTableIdentifier(sqlText: String): TableIdentifier =
    delegate.parseTableIdentifier(sqlText)

  /** Forwards to `delegate`; the Submarine grammar is not consulted. */
  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
    delegate.parseFunctionIdentifier(sqlText)

  /** Forwards to `delegate`; the Submarine grammar is not consulted. */
  override def parseTableSchema(sqlText: String): StructType =
    delegate.parseTableSchema(sqlText)

  /** Forwards to `delegate`; the Submarine grammar is not consulted. */
  override def parseDataType(sqlText: String): DataType =
    delegate.parseDataType(sqlText)

  /** Forwards to `delegate`; the Submarine grammar is not consulted. */
  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
    delegate.parseMultipartIdentifier(sqlText)

  /** Forwards to `delegate`; the Submarine grammar is not consulted. */
  override def parseRawDataType(sqlText: String): DataType =
    delegate.parseRawDataType(sqlText)

}