blob: e60b6e27a71fcd59695a3dbcf839cc17f6629746 [file] [log] [blame]
/*
* Copyright 2019 WeBank
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.webank.wedatasphere.linkis.engine.execute
import com.webank.wedatasphere.linkis.common.utils.Logging
import com.webank.wedatasphere.linkis.engine.conf.EngineConfiguration
import com.webank.wedatasphere.linkis.engine.execute.CodeType.CodeType
import org.apache.commons.lang.StringUtils
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
/**
* Created by enjoyyin on 2018/9/14.
*/
trait CodeParser {
def parse(code: String, engineExecutorContext: EngineExecutorContext): Array[String]
}
abstract class SingleCodeParser extends CodeParser{
val codeType: CodeType
def canParse(codeType: String): Boolean = {
CodeType.getType(codeType) == this.codeType
}
}
abstract class CombinedEngineCodeParser extends CodeParser{
val parsers: Array[SingleCodeParser]
def getCodeType(code: String, engineExecutorContext: EngineExecutorContext): String
override def parse(code: String, engineExecutorContext: EngineExecutorContext): Array[String] = {
val codeType = getCodeType(code, engineExecutorContext)
parsers.find(_.canParse(codeType)) match {
case Some(parser) => parser.parse(code, engineExecutorContext)
case None => Array(code)
}
}
}
class ScalaCodeParser extends SingleCodeParser with Logging{
override val codeType: CodeType = CodeType.Scala
override def parse(code: String, engineExecutorContext: EngineExecutorContext): Array[String] = {
//val realCode = StringUtils.substringAfter(code, "\n")
val codeBuffer = new ArrayBuffer[String]()
val statementBuffer = new ArrayBuffer[String]()
code.split("\n").foreach{
case "" =>
case l if l.startsWith(" ") || l.startsWith("\t") => if(!l.trim.startsWith("//")) statementBuffer.append(l)
case l if l.startsWith("@") => statementBuffer.append(l)
case l if StringUtils.isNotBlank(l) =>
if(l.trim.startsWith("}")){
statementBuffer.append(l)
}else{
if(statementBuffer.nonEmpty) codeBuffer.append(statementBuffer.mkString("\n"))
statementBuffer.clear()
//statementBuffer.append("%scala")
statementBuffer.append(l)
}
case _ =>
}
if(statementBuffer.nonEmpty) codeBuffer.append(statementBuffer.mkString("\n"))
codeBuffer.toArray
}
}
class PythonCodeParser extends SingleCodeParser {
override val codeType: CodeType = CodeType.Python
val openBrackets = Array("{","(","[")
val closeBrackets = Array("}",")","]")
val LOG:Logger = LoggerFactory.getLogger(getClass)
override def parse(code: String, engineExecutorContext: EngineExecutorContext): Array[String] = {
//val realCode = StringUtils.substringAfter(code, "\n")
val bracketStack = new mutable.Stack[String]
val codeBuffer = new ArrayBuffer[String]()
val statementBuffer = new ArrayBuffer[String]()
var notDoc = true
//quotationMarks is used to optimize the three quotes problem(quotationMarks用来优化三引号问题)
var quotationMarks:Boolean = false
code.split("\n").foreach {
case "" =>
case l if l.trim.contains("\"\"\"")||l.trim.contains("""'''""") => quotationMarks = !quotationMarks
statementBuffer.append(l)
recordBrackets(bracketStack, l)
case l if quotationMarks => statementBuffer.append(l)
//shanhuang 用于修复python的引号问题
//recordBrackets(bracketStack, l)
case l if notDoc && l.startsWith("#") =>
case l if StringUtils.isNotBlank(statementBuffer.last) && statementBuffer.last.endsWith("""\""") =>
statementBuffer.append(l)
case l if notDoc && l.startsWith(" ") =>
statementBuffer.append(l)
recordBrackets(bracketStack, l.trim)
case l if notDoc && l.startsWith("\t") =>
statementBuffer.append(l)
recordBrackets(bracketStack, l.trim)
case l if notDoc && l.startsWith("@") =>
statementBuffer.append(l)
recordBrackets(bracketStack, l.trim)
case l if notDoc && l.startsWith("else") => //LOG.info("I am else")
statementBuffer.append(l)
recordBrackets(bracketStack, l.trim)
case l if notDoc && l.startsWith("elif") => //LOG.info("I am elif")
statementBuffer.append(l)
recordBrackets(bracketStack, l.trim)
case l if notDoc && StringUtils.isNotBlank(l) =>
if(statementBuffer.nonEmpty && bracketStack.isEmpty){
codeBuffer.append(statementBuffer.mkString("\n"))
statementBuffer.clear()
}
// statementBuffer.append("%python")
statementBuffer.append(l)
recordBrackets(bracketStack, l.trim)
case _ =>
}
if(statementBuffer.nonEmpty) codeBuffer.append(statementBuffer.mkString("\n"))
codeBuffer.toArray
}
def recordBrackets(bracketStack: mutable.Stack[String], l: String): Unit ={
val real = l.replace("\"\"\"", "").replace("'''", "").trim
if(StringUtils.endsWithAny(real, openBrackets)){
for(i <- (0 to real.length -1).reverse){
val token = real.substring(i, i + 1)
if(openBrackets.contains(token)){
bracketStack.push(token)
}
}
}
if(StringUtils.startsWithAny(real, closeBrackets)){
for(i <- 0 to real.length -1){
val token = real.substring(i, i + 1)
if(closeBrackets.contains(token)){
bracketStack.pop()
}
}
}
}
}
object Main{
def main(args: Array[String]): Unit = {
val codeParser = new PythonCodeParser
val code = "if True: \n print 1 \nelif N=123: \n print 456 \nelse: \n print 789"
println(code)
val arrCodes = codeParser.parse(code, null)
print(arrCodes.mkString("||\n"))
}
}
class SQLCodeParser extends SingleCodeParser {
override val codeType: CodeType = CodeType.SQL
val separator = ";"
val defaultLimit:Int = EngineConfiguration.ENGINE_DEFAULT_LIMIT.getValue
override def parse(code: String, engineExecutorContext: EngineExecutorContext): Array[String] = {
//val realCode = StringUtils.substringAfter(code, "\n")
val codeBuffer = new ArrayBuffer[String]()
def appendStatement(sqlStatement: String): Unit ={
codeBuffer.append(sqlStatement)
}
if (StringUtils.contains(code, separator)) {
StringUtils.split(code, ";").foreach{
case s if StringUtils.isBlank(s) =>
case s if isSelectCmdNoLimit(s) => appendStatement(s);
case s => appendStatement(s);
}
} else {
code match {
case s if StringUtils.isBlank(s) =>
case s if isSelectCmdNoLimit(s) => appendStatement(s);
case s => appendStatement(s);
}
}
codeBuffer.toArray
}
def isSelectCmdNoLimit(cmd: String): Boolean = {
var code = cmd.trim
if(!cmd.split("\\s+")(0).equalsIgnoreCase("select")) return false
if (code.contains("limit")) code = code.substring(code.lastIndexOf("limit")).trim
else if (code.contains("LIMIT")) code = code.substring(code.lastIndexOf("LIMIT")).trim.toLowerCase
else return true
val hasLimit = code.matches("limit\\s+\\d+\\s*;?")
if (hasLimit) {
if (code.indexOf(";") > 0) code = code.substring(5, code.length - 1).trim
else code = code.substring(5).trim
val limitNum = code.toInt
if (limitNum > defaultLimit) throw new IllegalArgumentException("We at most allowed to limit " + defaultLimit + ", but your SQL has been over the max rows.")
}
!hasLimit
}
}
object CodeType extends Enumeration {
type CodeType = Value
val Python, SQL, Scala, Shell, Other = Value
def getType(codeType: String): CodeType = codeType.toLowerCase() match {
case "python" | "pyspark" | "py" => Python
case "sql" | "hql" => SQL
case "scala" => Scala
case "shell" => Shell
case _ => Other
}
}