/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.parser
import java.util.Locale
import java.util.concurrent.TimeUnit
import scala.collection.mutable.{ArrayBuffer, Set}
import scala.jdk.CollectionConverters._
import scala.util.{Left, Right}
import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
import org.apache.commons.codec.DecoderException
import org.apache.commons.codec.binary.Hex
import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.PARTITION_SPECIFICATION
import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, ClusterBySpec}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyValue, First, Last}
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLStmt
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LEGACY_BANG_EQUALS_NOT
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.random.RandomSampler
/**
* The AstBuilder converts an ANTLR4 ParseTree into a Catalyst Expression, LogicalPlan or
* TableIdentifier.
*/
class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import ParserUtils._
protected def withIdentClause(
ctx: IdentifierReferenceContext,
builder: Seq[String] => LogicalPlan): LogicalPlan = {
val exprCtx = ctx.expression
if (exprCtx != null) {
PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, builder)
} else {
builder.apply(visitMultipartIdentifier(ctx.multipartIdentifier))
}
}
protected def withFuncIdentClause(
ctx: FunctionNameContext,
builder: Seq[String] => LogicalPlan): LogicalPlan = {
val exprCtx = ctx.expression
if (exprCtx != null) {
PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, builder)
} else {
builder.apply(getFunctionMultiparts(ctx))
}
}
protected def withFuncIdentClause(
ctx: FunctionNameContext,
otherExprs: Seq[Expression],
builder: (Seq[String], Seq[Expression]) => Expression): Expression = {
val exprCtx = ctx.expression
if (exprCtx != null) {
ExpressionWithUnresolvedIdentifier(
withOrigin(exprCtx) { expression(exprCtx) },
otherExprs,
builder)
} else {
builder.apply(getFunctionMultiparts(ctx), otherExprs)
}
}
/**
* Override the default behavior for all visit methods. This will only return a non-null result
* when the context has only one child. This is done because there is no generic method to
* combine the results of the context children. In all other cases null is returned.
*/
override def visitChildren(node: RuleNode): AnyRef = {
if (node.getChildCount == 1) {
node.getChild(0).accept(this)
} else {
null
}
}
override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
visit(ctx.statement).asInstanceOf[LogicalPlan]
}
override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
visitNamedExpression(ctx.namedExpression)
}
override def visitSingleTableIdentifier(
ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
visitTableIdentifier(ctx.tableIdentifier)
}
override def visitSingleFunctionIdentifier(
ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
visitFunctionIdentifier(ctx.functionIdentifier)
}
override def visitSingleMultipartIdentifier(
ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
visitMultipartIdentifier(ctx.multipartIdentifier)
}
override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
typedVisit[DataType](ctx.dataType)
}
override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
val schema = StructType(visitColTypeList(ctx.colTypeList))
withOrigin(ctx)(schema)
}
/* ********************************************************************************************
* Plan parsing
* ******************************************************************************************** */
protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
/**
* Create a top-level plan with Common Table Expressions.
*/
override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)
// Apply CTEs
query.optionalMap(ctx.ctes)(withCTE)
}
override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) {
val dmlStmt = plan(ctx.dmlStatementNoWith)
// Apply CTEs
dmlStmt.optionalMap(ctx.ctes)(withCTE)
}
private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = {
val ctes = ctx.namedQuery.asScala.map { nCtx =>
val namedQuery = visitNamedQuery(nCtx)
(namedQuery.alias, namedQuery)
}
// Check for duplicate names.
val duplicates = ctes.groupBy(_._1).filter(_._2.size > 1).keys
if (duplicates.nonEmpty) {
throw QueryParsingErrors.duplicateCteDefinitionNamesError(
duplicates.mkString("'", "', '", "'"), ctx)
}
UnresolvedWith(plan, ctes.toSeq)
}
/**
* Create a logical query plan for a hive-style FROM statement body.
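*
* For example (illustrative; multiple SELECT bodies over the same FROM are unioned):
* {{{
*   FROM src
*   SELECT key WHERE key > 10
* }}}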
*/
private def withFromStatementBody(
ctx: FromStatementBodyContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// Two cases: a TRANSFORM clause or a regular SELECT clause.
if (ctx.transformClause != null) {
withTransformQuerySpecification(
ctx,
ctx.transformClause,
ctx.lateralView,
ctx.whereClause,
ctx.aggregationClause,
ctx.havingClause,
ctx.windowClause,
plan
)
} else {
withSelectQuerySpecification(
ctx,
ctx.selectClause,
ctx.lateralView,
ctx.whereClause,
ctx.aggregationClause,
ctx.havingClause,
ctx.windowClause,
plan
)
}
}
override def visitFromStatement(ctx: FromStatementContext): LogicalPlan = withOrigin(ctx) {
val from = visitFromClause(ctx.fromClause)
val selects = ctx.fromStatementBody.asScala.map { body =>
withFromStatementBody(body, from).
// Add organization statements.
optionalMap(body.queryOrganization)(withQueryResultClauses)
}
// If there are multiple SELECTs, just UNION them together into one query.
if (selects.length == 1) {
selects.head
} else {
Union(selects.toSeq)
}
}
/**
* Create a named logical plan.
*
* This is only used for Common Table Expressions.
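*
* For example (illustrative; names are placeholders):
* {{{
*   WITH cte AS (SELECT * FROM src)
*   SELECT * FROM cte
* }}}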
*/
override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)(
(columnAliases, plan) =>
UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan)
)
SubqueryAlias(ctx.name.getText, subQuery)
}
/**
* Create a logical plan which allows for multiple inserts using one 'from' statement. These
* queries have the following SQL form:
* {{{
* [WITH cte...]?
* FROM src
* [INSERT INTO tbl1 SELECT *]+
* }}}
* For example:
* {{{
* FROM db.tbl1 A
* INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
* INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
* }}}
* This (Hive) feature cannot be combined with set operators.
*/
override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
val from = visitFromClause(ctx.fromClause)
// Build the insert clauses.
val inserts = ctx.multiInsertQueryBody.asScala.map { body =>
withInsertInto(body.insertInto,
withFromStatementBody(body.fromStatementBody, from).
optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses))
}
// If there are multiple INSERTs, just UNION them together into one query.
if (inserts.length == 1) {
inserts.head
} else {
Union(inserts.toSeq)
}
}
/**
* Create a logical plan for a regular (single-insert) query.
*/
override def visitSingleInsertQuery(
ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
withInsertInto(ctx.insertInto(), visitQuery(ctx.query))
}
/**
* Parameters used for writing a query to a table:
* (table ident, tableColumnList, partitionKeys, ifPartitionNotExists, byName).
*/
type InsertTableParams =
(IdentifierReferenceContext, Seq[String], Map[String, Option[String]], Boolean, Boolean)
/**
* Parameters used for writing a query to a directory: (isLocal, CatalogStorageFormat, provider).
*/
type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String])
/**
* Add an
* {{{
* INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList]
* INSERT INTO [TABLE] tableIdentifier [partitionSpec] ([BY NAME] | [identifierList])
* INSERT INTO [TABLE] tableIdentifier REPLACE whereClause
* INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
* INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
* }}}
* operation to the logical plan.
*/
private def withInsertInto(
ctx: InsertIntoContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
ctx match {
// We cannot push withIdentClause() into the write command because:
// 1. `PlanWithUnresolvedIdentifier` is not a NamedRelation
// 2. Write commands do not hold the table logical plan as a child, and we need to add
// additional resolution code to resolve identifiers inside the write commands.
case table: InsertIntoTableContext =>
val (relationCtx, cols, partition, ifPartitionNotExists, byName)
= visitInsertIntoTable(table)
withIdentClause(relationCtx, ident => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident),
partition,
cols,
query,
overwrite = false,
ifPartitionNotExists,
byName)
})
case table: InsertOverwriteTableContext =>
val (relationCtx, cols, partition, ifPartitionNotExists, byName)
= visitInsertOverwriteTable(table)
withIdentClause(relationCtx, ident => {
InsertIntoStatement(
createUnresolvedRelation(relationCtx, ident),
partition,
cols,
query,
overwrite = true,
ifPartitionNotExists,
byName)
})
case ctx: InsertIntoReplaceWhereContext =>
withIdentClause(ctx.identifierReference, ident => {
OverwriteByExpression.byPosition(
createUnresolvedRelation(ctx.identifierReference, ident),
query,
expression(ctx.whereClause().booleanExpression()))
})
case dir: InsertOverwriteDirContext =>
val (isLocal, storage, provider) = visitInsertOverwriteDir(dir)
InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
case hiveDir: InsertOverwriteHiveDirContext =>
val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir)
InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
case _ =>
throw QueryParsingErrors.invalidInsertIntoError(ctx)
}
}
/**
* Add an INSERT INTO TABLE operation to the logical plan.
*/
override def visitInsertIntoTable(
ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
blockBang(ctx.errorCapturingNot())
if (ctx.EXISTS != null) {
invalidStatement("INSERT INTO ... IF NOT EXISTS", ctx)
}
(ctx.identifierReference, cols, partitionKeys, false, ctx.NAME() != null)
}
/**
* Add an INSERT OVERWRITE TABLE operation to the logical plan.
*/
override def visitInsertOverwriteTable(
ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
assert(ctx.OVERWRITE() != null)
val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
blockBang(ctx.errorCapturingNot())
val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
operationNotAllowed("IF NOT EXISTS with dynamic partitions: " +
dynamicPartitionKeys.keys.mkString(", "), ctx)
}
(ctx.identifierReference, cols, partitionKeys, ctx.EXISTS() != null, ctx.NAME() != null)
}
/**
* Write to a directory, returning an [[InsertIntoDir]] logical plan.
*/
override def visitInsertOverwriteDir(
ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) {
throw QueryParsingErrors.insertOverwriteDirectoryUnsupportedError()
}
/**
* Write to a directory, returning an [[InsertIntoDir]] logical plan.
*/
override def visitInsertOverwriteHiveDir(
ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) {
throw QueryParsingErrors.insertOverwriteDirectoryUnsupportedError()
}
private def getTableAliasWithoutColumnAlias(
ctx: TableAliasContext, op: String): Option[String] = {
if (ctx == null) {
None
} else {
val ident = ctx.strictIdentifier()
if (ctx.identifierList() != null) {
throw QueryParsingErrors.columnAliasInOperationNotAllowedError(op, ctx)
}
if (ident != null) Some(ident.getText) else None
}
}
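/**
* Create a [[DeleteFromTable]] logical plan. For example (illustrative; names are placeholders):
* {{{
*   DELETE FROM db.tbl WHERE id = 1
* }}}
*/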
override def visitDeleteFromTable(
ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference)
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val predicate = if (ctx.whereClause() != null) {
expression(ctx.whereClause().booleanExpression())
} else {
Literal.TrueLiteral
}
DeleteFromTable(aliasedTable, predicate)
}
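/**
* Create an [[UpdateTable]] logical plan. For example (illustrative; names are placeholders):
* {{{
*   UPDATE db.tbl SET col1 = 1 WHERE col2 > 2
* }}}
*/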
override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedRelation(ctx.identifierReference)
val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
val assignments = withAssignments(ctx.setClause().assignmentList())
val predicate = if (ctx.whereClause() != null) {
Some(expression(ctx.whereClause().booleanExpression()))
} else {
None
}
UpdateTable(aliasedTable, assignments, predicate)
}
protected def withAssignments(assignCtx: SqlBaseParser.AssignmentListContext): Seq[Assignment] =
withOrigin(assignCtx) {
assignCtx.assignment().asScala.map { assign =>
Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)),
expression(assign.value))
}.toSeq
}
override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
val withSchemaEvolution = ctx.EVOLUTION() != null
val targetTable = createUnresolvedRelation(ctx.target)
val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
val sourceTableOrQuery = if (ctx.source != null) {
createUnresolvedRelation(ctx.source)
} else if (ctx.sourceQuery != null) {
visitQuery(ctx.sourceQuery)
} else {
throw QueryParsingErrors.emptySourceForMergeError(ctx)
}
val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE")
val aliasedSource =
sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery)
val mergeCondition = expression(ctx.mergeCondition)
val matchedActions = ctx.matchedClause().asScala.map {
clause => {
if (clause.matchedAction().DELETE() != null) {
DeleteAction(Option(clause.matchedCond).map(expression))
} else if (clause.matchedAction().UPDATE() != null) {
val condition = Option(clause.matchedCond).map(expression)
if (clause.matchedAction().ASTERISK() != null) {
UpdateStarAction(condition)
} else {
UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList()))
}
} else {
throw SparkException.internalError(
s"Unrecognized matched action: ${clause.matchedAction().getText}")
}
}
}
val notMatchedActions = ctx.notMatchedClause().asScala.map {
clause => {
if (clause.notMatchedAction().INSERT() != null) {
val condition = Option(clause.notMatchedCond).map(expression)
if (clause.notMatchedAction().ASTERISK() != null) {
InsertStarAction(condition)
} else {
val columns = clause.notMatchedAction().columns.multipartIdentifier()
.asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr)))
val values = clause.notMatchedAction().expression().asScala.map(expression)
if (columns.size != values.size) {
throw QueryParsingErrors.insertedValueNumberNotMatchFieldNumberError(clause)
}
InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq)
}
} else {
throw SparkException.internalError(
s"Unrecognized matched action: ${clause.notMatchedAction().getText}")
}
}
}
val notMatchedBySourceActions = ctx.notMatchedBySourceClause().asScala.map {
clause => {
val notMatchedBySourceAction = clause.notMatchedBySourceAction()
if (notMatchedBySourceAction.DELETE() != null) {
DeleteAction(Option(clause.notMatchedBySourceCond).map(expression))
} else if (notMatchedBySourceAction.UPDATE() != null) {
val condition = Option(clause.notMatchedBySourceCond).map(expression)
UpdateAction(condition,
withAssignments(clause.notMatchedBySourceAction().assignmentList()))
} else {
throw SparkException.internalError(
s"Unrecognized matched action: ${clause.notMatchedBySourceAction().getText}")
}
}
}
if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
throw QueryParsingErrors.mergeStatementWithoutWhenClauseError(ctx)
}
// An empty condition (None) means the clause's condition is not set; only the last clause of
// each kind may omit it.
val matchedActionSize = matchedActions.length
if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) {
throw QueryParsingErrors.nonLastMatchedClauseOmitConditionError(ctx)
}
val notMatchedActionSize = notMatchedActions.length
if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) {
throw QueryParsingErrors.nonLastNotMatchedClauseOmitConditionError(ctx)
}
val notMatchedBySourceActionSize = notMatchedBySourceActions.length
if (notMatchedBySourceActionSize >= 2 &&
!notMatchedBySourceActions.init.forall(_.condition.nonEmpty)) {
throw QueryParsingErrors.nonLastNotMatchedBySourceClauseOmitConditionError(ctx)
}
MergeIntoTable(
aliasedTarget,
aliasedSource,
mergeCondition,
matchedActions.toSeq,
notMatchedActions.toSeq,
notMatchedBySourceActions.toSeq,
withSchemaEvolution)
}
/**
* Returns the parameters for [[ExecuteImmediateQuery]] logical plan.
* Expected format:
* {{{
* EXECUTE IMMEDIATE {query_string|string_literal}
* [INTO target1, target2] [USING param1, param2, ...]
* }}}
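* For example (illustrative; parameter markers and values are placeholders):
* {{{
*   EXECUTE IMMEDIATE 'SELECT * FROM tbl WHERE id = ?' USING 42
* }}}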
*/
override def visitExecuteImmediate(ctx: ExecuteImmediateContext): LogicalPlan = withOrigin(ctx) {
// Because of how the parsing rules are written, we know that exactly one of the query
// parameter's string literal and variable reference is non-null - hence use Either to
// represent this.
val queryString = Option(ctx.queryParam.stringLit()).map(sl => Left(string(visitStringLit(sl))))
val queryVariable = Option(ctx.queryParam.multipartIdentifier)
.map(mpi => Right(UnresolvedAttribute(visitMultipartIdentifier(mpi))))
val targetVars = Option(ctx.targetVariable).toSeq
.flatMap(v => visitMultipartIdentifierList(v))
val exprs = Option(ctx.executeImmediateUsing)
.map(visitExecuteImmediateUsing)
.getOrElse(Seq.empty)
ExecuteImmediateQuery(exprs, queryString.getOrElse(queryVariable.get), targetVars)
}
override def visitExecuteImmediateUsing(
ctx: ExecuteImmediateUsingContext): Seq[Expression] = withOrigin(ctx) {
val expressions = Option(ctx).toSeq
.flatMap(ctx => visitNamedExpressionSeq(ctx.params))
val resultExpr = expressions.map(e => e._1)
validateExecImmediateArguments(resultExpr, ctx)
resultExpr
}
/**
* Performs validation on the arguments to EXECUTE IMMEDIATE.
*/
private def validateExecImmediateArguments(
expressions: Seq[Expression],
ctx: ExecuteImmediateUsingContext): Unit = {
val duplicateAliases = expressions
.filter(_.isInstanceOf[Alias])
.groupBy {
case Alias(_, name) => name
}.filter(group => group._2.size > 1)
if (duplicateAliases.nonEmpty) {
throw QueryParsingErrors.duplicateArgumentNamesError(duplicateAliases.keys.toSeq, ctx)
}
}
override def visitMultipartIdentifierList(
ctx: MultipartIdentifierListContext): Seq[UnresolvedAttribute] = withOrigin(ctx) {
ctx.multipartIdentifier.asScala.map(typedVisit[Seq[String]]).map(new UnresolvedAttribute(_))
.toSeq
}
/**
* Create a partition specification map.
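*
* For example (illustrative), `PARTITION (dt = '2024-01-01', hour)` maps `dt` to
* `Some("2024-01-01")` and the dynamic partition column `hour` to `None`.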
*/
override def visitPartitionSpec(
ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
val legacyNullAsString =
conf.getConf(SQLConf.LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL)
val keepPartitionSpecAsString =
conf.getConf(SQLConf.LEGACY_KEEP_PARTITION_SPEC_AS_STRING_LITERAL)
val parts = ctx.partitionVal.asScala.map { pVal =>
// Check if the query attempted to refer to a DEFAULT column value within the PARTITION clause
// and return a specific error to help guide the user, since this is not allowed.
if (pVal.DEFAULT != null) {
throw QueryParsingErrors.defaultColumnReferencesNotAllowedInPartitionSpec(ctx)
}
val name = pVal.identifier.getText
val value = Option(pVal.constant).map(v => {
visitStringConstant(v, legacyNullAsString, keepPartitionSpecAsString)
})
name -> value
}
// Before calling `toMap`, we check for duplicated keys to avoid silently ignoring partition
// values in a partition spec like PARTITION(a='1', b='2', a='3'). The real semantic check for
// partition columns will be done in the analyzer.
if (conf.caseSensitiveAnalysis) {
checkDuplicateKeys(parts.toSeq, ctx)
} else {
checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx)
}
parts.toMap
}
/**
* Create a partition specification map without optional values.
*/
protected def visitNonOptionalPartitionSpec(
ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
visitPartitionSpec(ctx).map {
case (key, None) => throw QueryParsingErrors.emptyPartitionKeyError(key, ctx)
case (key, Some(value)) => key -> value
}
}
/**
* Convert a constant of any type into a string. This is typically used in DDL commands, and its
* main purpose is to prevent slight differences due to back-to-back conversions, i.e.:
* String -> Literal -> String.
*/
protected def visitStringConstant(
ctx: ConstantContext,
legacyNullAsString: Boolean = false,
keepPartitionSpecAsString: Boolean = false): String = withOrigin(ctx) {
expression(ctx) match {
case Literal(null, _) if !legacyNullAsString => null
case l @ Literal(null, _) => l.toString
case l: Literal =>
if (keepPartitionSpecAsString && !ctx.isInstanceOf[StringLiteralContext]) {
ctx.getText
} else {
// TODO For v2 commands, we will cast the string back to its actual value,
// which is a waste and can be improved in the future.
Cast(l, conf.defaultStringType, Some(conf.sessionLocalTimeZone)).eval().toString
}
case other =>
throw new SparkIllegalArgumentException(
errorClass = "_LEGACY_ERROR_TEMP_3222",
messageParameters = Map("expr" -> other.sql)
)
}
}
/**
* Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOW clauses to the logical plan. These
* clauses determine the shape (ordering/partitioning/rows) of the query result.
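*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT * FROM tbl ORDER BY col LIMIT 10 OFFSET 5
* }}}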
*/
private def withQueryResultClauses(
ctx: QueryOrganizationContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
import ctx._
// Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
val withOrder = if (
!order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
// ORDER BY ...
Sort(order.asScala.map(visitSortItem).toSeq, global = true, query)
} else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
// SORT BY ...
Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query)
} else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
// DISTRIBUTE BY ...
withRepartitionByExpression(ctx, expressionList(distributeBy), query)
} else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
// SORT BY ... DISTRIBUTE BY ...
Sort(
sort.asScala.map(visitSortItem).toSeq,
global = false,
withRepartitionByExpression(ctx, expressionList(distributeBy), query))
} else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
// CLUSTER BY ...
val expressions = expressionList(clusterBy)
Sort(
expressions.map(SortOrder(_, Ascending)),
global = false,
withRepartitionByExpression(ctx, expressions, query))
} else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
// [EMPTY]
query
} else {
throw QueryParsingErrors.combinationQueryResultClausesUnsupportedError(ctx)
}
// WINDOW
val withWindow = withOrder.optionalMap(windowClause)(withWindowClause)
// OFFSET
// - OFFSET 0 is the same as omitting the OFFSET clause
val withOffset = withWindow.optional(offset) {
Offset(typedVisit(offset), withWindow)
}
// LIMIT
// - LIMIT ALL is the same as omitting the LIMIT clause
withOffset.optional(limit) {
Limit(typedVisit(limit), withOffset)
}
}
/**
* Create a clause for DISTRIBUTE BY.
*/
protected def withRepartitionByExpression(
ctx: QueryOrganizationContext,
expressions: Seq[Expression],
query: LogicalPlan): LogicalPlan = {
throw QueryParsingErrors.distributeByUnsupportedError(ctx)
}
override def visitTransformQuerySpecification(
ctx: TransformQuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
val from = OneRowRelation().optional(ctx.fromClause) {
visitFromClause(ctx.fromClause)
}
withTransformQuerySpecification(
ctx,
ctx.transformClause,
ctx.lateralView,
ctx.whereClause,
ctx.aggregationClause,
ctx.havingClause,
ctx.windowClause,
from
)
}
override def visitRegularQuerySpecification(
ctx: RegularQuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
val from = OneRowRelation().optional(ctx.fromClause) {
visitFromClause(ctx.fromClause)
}
withSelectQuerySpecification(
ctx,
ctx.selectClause,
ctx.lateralView,
ctx.whereClause,
ctx.aggregationClause,
ctx.havingClause,
ctx.windowClause,
from
)
}
private def getAliasFunc(ctx: ParseTree): Option[Expression => String] = {
if (conf.getConf(SQLConf.STABLE_DERIVED_COLUMN_ALIAS_ENABLED)) {
Some(_ => toExprAlias(ctx))
} else {
None
}
}
override def visitNamedExpressionSeq(
ctx: NamedExpressionSeqContext): Seq[(Expression, Option[Expression => String])] = {
Option(ctx).toSeq
.flatMap(_.namedExpression.asScala)
.map(ctx => (typedVisit[Expression](ctx), getAliasFunc(ctx)))
}
override def visitExpressionSeq(
ctx: ExpressionSeqContext): Seq[(Expression, Option[Expression => String])] = {
Option(ctx).toSeq
.flatMap(_.expression.asScala)
.map(ctx => (typedVisit[Expression](ctx), getAliasFunc(ctx)))
}
/**
* Create a logical plan using a having clause.
*/
private def withHavingClause(
ctx: HavingClauseContext, plan: LogicalPlan): LogicalPlan = {
// Note that we add a cast to non-predicate expressions. If the expression itself is
// already boolean, the optimizer will get rid of the unnecessary cast.
val predicate = expression(ctx.booleanExpression) match {
case p: Predicate => p
case e => Cast(e, BooleanType)
}
UnresolvedHaving(predicate, plan)
}
/**
* Create a logical plan using a where clause.
*/
private def withWhereClause(ctx: WhereClauseContext, plan: LogicalPlan): LogicalPlan = {
Filter(expression(ctx.booleanExpression), plan)
}
/**
* Add a hive-style transform (SELECT TRANSFORM/MAP/REDUCE) query specification to a logical plan.
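*
* For example (illustrative; the script and names are placeholders):
* {{{
*   SELECT TRANSFORM (key, value) USING 'cat' AS (k, v) FROM src
* }}}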
*/
private def withTransformQuerySpecification(
ctx: ParserRuleContext,
transformClause: TransformClauseContext,
lateralView: java.util.List[LateralViewContext],
whereClause: WhereClauseContext,
aggregationClause: AggregationClauseContext,
havingClause: HavingClauseContext,
windowClause: WindowClauseContext,
relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
if (transformClause.setQuantifier != null) {
throw QueryParsingErrors.transformNotSupportQuantifierError(transformClause.setQuantifier)
}
// Create the attributes.
val (attributes, schemaLess) = if (transformClause.colTypeList != null) {
// Typed return columns.
(DataTypeUtils.toAttributes(createSchema(transformClause.colTypeList)), false)
} else if (transformClause.identifierSeq != null) {
// Untyped return columns.
val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name =>
AttributeReference(name, StringType, nullable = true)()
}
(attrs, false)
} else {
(Seq(AttributeReference("key", StringType)(),
AttributeReference("value", StringType)()), true)
}
val plan = visitCommonSelectQueryClausePlan(
relation,
visitExpressionSeq(transformClause.expressionSeq),
lateralView,
whereClause,
aggregationClause,
havingClause,
windowClause,
isDistinct = false)
ScriptTransformation(
string(visitStringLit(transformClause.script)),
attributes,
plan,
withScriptIOSchema(
ctx,
transformClause.inRowFormat,
visitStringLit(transformClause.recordWriter),
transformClause.outRowFormat,
visitStringLit(transformClause.recordReader),
schemaLess
)
)
}
/**
* Add a regular (SELECT) query specification to a logical plan. The query specification
* is the core of the logical plan: this is where sourcing (FROM clause), projection (SELECT),
* aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) take place.
*
* Note that query hints are ignored (both by the parser and the builder).
*/
private def withSelectQuerySpecification(
ctx: ParserRuleContext,
selectClause: SelectClauseContext,
lateralView: java.util.List[LateralViewContext],
whereClause: WhereClauseContext,
aggregationClause: AggregationClauseContext,
havingClause: HavingClauseContext,
windowClause: WindowClauseContext,
relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
val isDistinct = selectClause.setQuantifier() != null &&
selectClause.setQuantifier().DISTINCT() != null
val plan = visitCommonSelectQueryClausePlan(
relation,
visitNamedExpressionSeq(selectClause.namedExpressionSeq),
lateralView,
whereClause,
aggregationClause,
havingClause,
windowClause,
isDistinct)
// Hint
selectClause.hints.asScala.foldRight(plan)(withHints)
}
def visitCommonSelectQueryClausePlan(
relation: LogicalPlan,
expressions: Seq[(Expression, Option[Expression => String])],
lateralView: java.util.List[LateralViewContext],
whereClause: WhereClauseContext,
aggregationClause: AggregationClauseContext,
havingClause: HavingClauseContext,
windowClause: WindowClauseContext,
isDistinct: Boolean): LogicalPlan = {
// Add lateral views.
val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate)
// Add where.
val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause)
// Add aggregation or a project.
val namedExpressions = expressions.map {
case (e: NamedExpression, _) => e
case (e: Expression, aliasFunc) => UnresolvedAlias(e, aliasFunc)
}
def createProject() = if (namedExpressions.nonEmpty) {
Project(namedExpressions, withFilter)
} else {
withFilter
}
val withProject = if (aggregationClause == null && havingClause != null) {
if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) {
// If the legacy conf is set, treat HAVING without GROUP BY as WHERE.
val predicate = expression(havingClause.booleanExpression) match {
case p: Predicate => p
case e => Cast(e, BooleanType)
}
Filter(predicate, createProject())
} else {
// According to SQL standard, HAVING without GROUP BY means global aggregate.
withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter))
}
} else if (aggregationClause != null) {
val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter)
aggregate.optionalMap(havingClause)(withHavingClause)
} else {
// When hitting this branch, `having` must be null.
createProject()
}
// Distinct
val withDistinct = if (isDistinct) {
Distinct(withProject)
} else {
withProject
}
// Window
val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause)
withWindow
}
// Script Transform's input/output format.
type ScriptIOFormat =
(Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = {
def entry(key: String, value: StringLitContext): Seq[(String, String)] = {
Option(value).toSeq.map(x => key -> string(visitStringLit(x)))
}
// TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
// expects a seq of pairs in which the old parsers' token names are used as keys.
// Transforming the result of visitRowFormatDelimited would be quite a bit messier than
// retrieving the key value pairs ourselves.
val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++
entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++
entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++
entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs) ++
Option(ctx.linesSeparatedBy).toSeq.map { stringLitCtx =>
val value = string(visitStringLit(stringLitCtx))
validate(
value == "\n",
s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
ctx)
"TOK_TABLEROWFORMATLINES" -> value
}
(entries, None, Seq.empty, None)
}
/**
* Create a [[ScriptInputOutputSchema]].
*/
protected def withScriptIOSchema(
ctx: ParserRuleContext,
inRowFormat: RowFormatContext,
recordWriter: Token,
outRowFormat: RowFormatContext,
recordReader: Token,
schemaLess: Boolean): ScriptInputOutputSchema = {
def format(fmt: RowFormatContext): ScriptIOFormat = fmt match {
case c: RowFormatDelimitedContext =>
getRowFormatDelimited(c)
case c: RowFormatSerdeContext =>
throw QueryParsingErrors.transformWithSerdeUnsupportedError(ctx)
// SPARK-32106: When there is no definition about format, we return empty result
// to use a built-in default Serde in SparkScriptTransformationExec.
case null =>
(Nil, None, Seq.empty, None)
}
val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat)
val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat)
ScriptInputOutputSchema(
inFormat, outFormat,
inSerdeClass, outSerdeClass,
inSerdeProps, outSerdeProps,
reader, writer,
schemaLess)
}
/**
* Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma
* separated) relations here; these get converted into a single plan by a condition-less inner
* join.
*/
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
val relationPrimary = relation.relationPrimary()
val right = if (conf.ansiRelationPrecedence) {
visitRelation(relation)
} else {
plan(relationPrimary)
}
val join = right.optionalMap(left) { (left, right) =>
if (relation.LATERAL != null) {
relationPrimary match {
case _: AliasedQueryContext =>
case _: TableValuedFunctionContext =>
case other =>
throw QueryParsingErrors.invalidLateralJoinRelationError(other)
}
LateralJoin(left, LateralSubquery(right), Inner, None)
} else {
Join(left, right, Inner, None, JoinHint.NONE)
}
}
if (conf.ansiRelationPrecedence) join else withRelationExtensions(relation, join)
}
if (ctx.pivotClause() != null) {
if (ctx.unpivotClause() != null) {
throw QueryParsingErrors.unpivotWithPivotInFromClauseNotAllowedError(ctx)
}
if (!ctx.lateralView.isEmpty) {
throw QueryParsingErrors.lateralWithPivotInFromClauseNotAllowedError(ctx)
}
withPivot(ctx.pivotClause, from)
} else if (ctx.unpivotClause() != null) {
if (!ctx.lateralView.isEmpty) {
throw QueryParsingErrors.lateralWithUnpivotInFromClauseNotAllowedError(ctx)
}
withUnpivot(ctx.unpivotClause, from)
} else {
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
}
}
/**
* Connect two queries by a Set operator.
*
* Supported Set operators are:
* - UNION [ DISTINCT | ALL ]
* - EXCEPT [ DISTINCT | ALL ]
* - MINUS [ DISTINCT | ALL ]
* - INTERSECT [DISTINCT | ALL]
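*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT * FROM t1 UNION ALL SELECT * FROM t2
* }}}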
*/
override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
val left = plan(ctx.left)
val right = plan(ctx.right)
val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
ctx.operator.getType match {
case SqlBaseParser.UNION if all =>
Union(left, right)
case SqlBaseParser.UNION =>
Distinct(Union(left, right))
case SqlBaseParser.INTERSECT if all =>
Intersect(left, right, isAll = true)
case SqlBaseParser.INTERSECT =>
Intersect(left, right, isAll = false)
case SqlBaseParser.EXCEPT if all =>
Except(left, right, isAll = true)
case SqlBaseParser.EXCEPT =>
Except(left, right, isAll = false)
case SqlBaseParser.SETMINUS if all =>
Except(left, right, isAll = true)
case SqlBaseParser.SETMINUS =>
Except(left, right, isAll = false)
}
}
/**
* Add a [[WithWindowDefinition]] operator to a logical plan.
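*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT SUM(a) OVER w FROM tbl WINDOW w AS (PARTITION BY b ORDER BY c)
* }}}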
*/
private def withWindowClause(
ctx: WindowClauseContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// Collect all window specifications defined in the WINDOW clause.
val baseWindowTuples = ctx.namedWindow.asScala.map {
wCtx =>
(wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec))
}
baseWindowTuples.groupBy(_._1).foreach { kv =>
if (kv._2.size > 1) {
throw QueryParsingErrors.repetitiveWindowDefinitionError(kv._1, ctx)
}
}
val baseWindowMap = baseWindowTuples.toMap
// Handle cases like
// window w1 as (partition by p_mfgr order by p_name
// range between 2 preceding and 2 following),
// w2 as w1
val windowMapView = baseWindowMap.transform {
case (_, WindowSpecReference(name)) =>
baseWindowMap.get(name) match {
case Some(spec: WindowSpecDefinition) =>
spec
case Some(ref) =>
throw QueryParsingErrors.invalidWindowReferenceError(name, ctx)
case None =>
throw QueryParsingErrors.cannotResolveWindowReferenceError(name, ctx)
}
case (_, spec: WindowSpecDefinition) => spec
}
// Note that the transformed map may be a view instead of a materialized map. We force
// materialization by mapping over identity.
WithWindowDefinition(windowMapView.map(identity), query)
}
/**
* Add an [[Aggregate]] to a logical plan.
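*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT a, b, COUNT(*) FROM tbl GROUP BY CUBE(a, b)
* }}}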
*/
private def withAggregationClause(
ctx: AggregationClauseContext,
selectExpressions: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
if (ctx.groupingExpressionsWithGroupingAnalytics.isEmpty) {
val groupByExpressions = expressionList(ctx.groupingExpressions)
if (ctx.GROUPING != null) {
// GROUP BY ... GROUPING SETS (...)
// `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping
// expressions that do not exist in GROUPING SETS (...), and the value is always null.
// For example, `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output
// of column `c` is always null.
val groupingSets =
ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq)
Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)),
selectExpressions, query)
} else {
// GROUP BY .... (WITH CUBE | WITH ROLLUP)?
val mappedGroupByExpressions = if (ctx.CUBE != null) {
Seq(Cube(groupByExpressions.map(Seq(_))))
} else if (ctx.ROLLUP != null) {
Seq(Rollup(groupByExpressions.map(Seq(_))))
} else {
groupByExpressions
}
Aggregate(mappedGroupByExpressions, selectExpressions, query)
}
} else {
val groupByExpressions =
ctx.groupingExpressionsWithGroupingAnalytics.asScala
.map(groupByExpr => {
val groupingAnalytics = groupByExpr.groupingAnalytics
if (groupingAnalytics != null) {
visitGroupingAnalytics(groupingAnalytics)
} else {
expression(groupByExpr.expression)
}
})
Aggregate(groupByExpressions.toSeq, selectExpressions, query)
}
}
override def visitGroupingAnalytics(
groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = {
val groupingSets = groupingAnalytics.groupingSet.asScala
.map(_.expression.asScala.map(e => expression(e)).toSeq)
if (groupingAnalytics.CUBE != null) {
// CUBE(A, B, (A, B), ()) is not supported.
if (groupingSets.exists(_.isEmpty)) {
throw QueryParsingErrors.invalidGroupingSetError("CUBE", groupingAnalytics)
}
Cube(groupingSets.toSeq)
} else if (groupingAnalytics.ROLLUP != null) {
// ROLLUP(A, B, (A, B), ()) is not supported.
if (groupingSets.exists(_.isEmpty)) {
throw QueryParsingErrors.invalidGroupingSetError("ROLLUP", groupingAnalytics)
}
Rollup(groupingSets.toSeq)
} else {
assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null)
val groupingSets = groupingAnalytics.groupingElement.asScala.flatMap { expr =>
val groupingAnalytics = expr.groupingAnalytics()
if (groupingAnalytics != null) {
visitGroupingAnalytics(groupingAnalytics).selectedGroupByExprs
} else {
Seq(expr.groupingSet().expression().asScala.map(e => expression(e)).toSeq)
}
}
GroupingSets(groupingSets.toSeq)
}
}
/**
* Add [[UnresolvedHint]]s to a logical plan.
*/
private def withHints(
ctx: HintContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
var plan = query
ctx.hintStatements.asScala.reverse.foreach { stmt =>
plan = UnresolvedHint(stmt.hintName.getText,
stmt.parameters.asScala.map(expression).toSeq, plan)
}
plan
}
/**
* Add a [[Pivot]] to a logical plan.
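*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT * FROM sales PIVOT (SUM(amount) FOR quarter IN ('Q1', 'Q2'))
* }}}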
*/
private def withPivot(
ctx: PivotClauseContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
val aggregates = Option(ctx.aggregates).toSeq
.flatMap(_.namedExpression.asScala)
.map(typedVisit[Expression])
val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
UnresolvedAttribute.quoted(ctx.pivotColumn.errorCapturingIdentifier.getText)
} else {
CreateStruct(
ctx.pivotColumn.identifiers.asScala.map(
identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq)
}
val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue)
Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query)
}
/**
* Create a Pivot column value with or without an alias.
*/
override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
val e = expression(ctx.expression)
if (ctx.errorCapturingIdentifier != null) {
Alias(e, ctx.errorCapturingIdentifier.getText)()
} else {
e
}
}
/**
* Add an [[Unpivot]] to a logical plan.
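*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT * FROM sales UNPIVOT (amount FOR quarter IN (q1, q2))
* }}}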
*/
private def withUnpivot(
ctx: UnpivotClauseContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// This is needed to create the Unpivot and to filter the unpivot result for nulls further down.
val valueColumnNames =
Option(ctx.unpivotOperator().unpivotSingleValueColumnClause())
.map(_.unpivotValueColumn().identifier().getText)
.map(Seq(_))
.getOrElse(
Option(ctx.unpivotOperator().unpivotMultiValueColumnClause())
.map(_.unpivotValueColumns.asScala.map(_.identifier().getText).toSeq)
.get
)
val unpivot = if (ctx.unpivotOperator().unpivotSingleValueColumnClause() != null) {
val unpivotClause = ctx.unpivotOperator().unpivotSingleValueColumnClause()
val variableColumnName = unpivotClause.unpivotNameColumn().identifier().getText
val (unpivotColumns, unpivotAliases) =
unpivotClause.unpivotColumns.asScala.map(visitUnpivotColumnAndAlias).toSeq.unzip
Unpivot(
None,
Some(unpivotColumns.map(Seq(_))),
// None when all elements are None
Some(unpivotAliases).filter(_.exists(_.isDefined)),
variableColumnName,
valueColumnNames,
query
)
} else {
val unpivotClause = ctx.unpivotOperator().unpivotMultiValueColumnClause()
val variableColumnName = unpivotClause.unpivotNameColumn().identifier().getText
val (unpivotColumns, unpivotAliases) =
unpivotClause.unpivotColumnSets.asScala.map(visitUnpivotColumnSet).toSeq.unzip
Unpivot(
None,
Some(unpivotColumns),
// None when all elements are None
Some(unpivotAliases).filter(_.exists(_.isDefined)),
variableColumnName,
valueColumnNames,
query
)
}
// exclude null values by default
val filtered = if (ctx.nullOperator == null || ctx.nullOperator.EXCLUDE() != null) {
Filter(IsNotNull(Coalesce(valueColumnNames.map(UnresolvedAttribute(_)))), unpivot)
} else {
unpivot
}
// alias unpivot result
if (ctx.errorCapturingIdentifier() != null) {
val alias = ctx.errorCapturingIdentifier().getText
SubqueryAlias(alias, filtered)
} else {
filtered
}
}
/**
* Create an Unpivot column.
*/
override def visitUnpivotColumn(ctx: UnpivotColumnContext): NamedExpression = withOrigin(ctx) {
UnresolvedAttribute(visitMultipartIdentifier(ctx.multipartIdentifier))
}
/**
* Create an Unpivot column.
*/
override def visitUnpivotColumnAndAlias(ctx: UnpivotColumnAndAliasContext):
(NamedExpression, Option[String]) = withOrigin(ctx) {
val attr = visitUnpivotColumn(ctx.unpivotColumn())
val alias = Option(ctx.unpivotAlias()).map(_.errorCapturingIdentifier().getText)
(attr, alias)
}
/**
* Create an Unpivot struct column with or without an alias.
* Each struct field is renamed to the respective value column name.
*/
override def visitUnpivotColumnSet(ctx: UnpivotColumnSetContext):
(Seq[NamedExpression], Option[String]) =
withOrigin(ctx) {
val exprs = ctx.unpivotColumns.asScala.map(visitUnpivotColumn).toSeq
val alias = Option(ctx.unpivotAlias()).map(_.errorCapturingIdentifier().getText)
(exprs, alias)
}
/**
* Add a [[Generate]] (Lateral View) to a logical plan.
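*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT name, course FROM people LATERAL VIEW explode(courses) t AS course
* }}}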
*/
private def withGenerate(
query: LogicalPlan,
ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
val expressions = expressionList(ctx.expression)
Generate(
UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
unrequiredChildIndex = Nil,
outer = ctx.OUTER != null,
// scalastyle:off caselocale
Some(ctx.tblName.getText.toLowerCase),
// scalastyle:on caselocale
ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.quoted).toSeq,
query)
}
/**
* Create a single relation referenced in a FROM clause. This method is used when a part of the
* joined relation is nested, for example:
* {{{
* select * from t1 join (t2 cross join t3) on col1 = col2
* }}}
*/
override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
withRelationExtensions(ctx, plan(ctx.relationPrimary))
}
private def withRelationExtensions(ctx: RelationContext, query: LogicalPlan): LogicalPlan = {
ctx.relationExtension().asScala.foldLeft(query) { (left, extension) =>
if (extension.joinRelation() != null) {
withJoinRelation(extension.joinRelation(), left)
} else if (extension.pivotClause() != null) {
withPivot(extension.pivotClause(), left)
} else {
assert(extension.unpivotClause() != null)
withUnpivot(extension.unpivotClause(), left)
}
}
}
/**
* Join one more [[LogicalPlan]] to the current logical plan.
*/
private def withJoinRelation(ctx: JoinRelationContext, base: LogicalPlan): LogicalPlan = {
withOrigin(ctx) {
val baseJoinType = ctx.joinType match {
case null => Inner
case jt if jt.CROSS != null => Cross
case jt if jt.FULL != null => FullOuter
case jt if jt.SEMI != null => LeftSemi
case jt if jt.ANTI != null => LeftAnti
case jt if jt.LEFT != null => LeftOuter
case jt if jt.RIGHT != null => RightOuter
case _ => Inner
}
if (ctx.LATERAL != null) {
ctx.right match {
case _: AliasedQueryContext =>
case _: TableValuedFunctionContext =>
case other =>
throw QueryParsingErrors.invalidLateralJoinRelationError(other)
}
}
// Resolve the join type and join condition
val (joinType, condition) = Option(ctx.joinCriteria) match {
case Some(c) if c.USING != null =>
if (ctx.LATERAL != null) {
throw QueryParsingErrors.lateralJoinWithUsingJoinUnsupportedError(ctx)
}
(UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None)
case Some(c) if c.booleanExpression != null =>
(baseJoinType, Option(expression(c.booleanExpression)))
case Some(c) =>
throw SparkException.internalError(s"Unimplemented joinCriteria: $c")
case None if ctx.NATURAL != null =>
if (ctx.LATERAL != null) {
throw QueryParsingErrors.incompatibleJoinTypesError(
joinType1 = ctx.LATERAL.toString, joinType2 = ctx.NATURAL.toString, ctx = ctx
)
}
if (baseJoinType == Cross) {
throw QueryParsingErrors.incompatibleJoinTypesError(
joinType1 = ctx.NATURAL.toString, joinType2 = baseJoinType.toString, ctx = ctx
)
}
(NaturalJoin(baseJoinType), None)
case None =>
(baseJoinType, None)
}
if (ctx.LATERAL != null) {
if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) {
throw QueryParsingErrors.unsupportedLateralJoinTypeError(ctx, joinType.sql)
}
LateralJoin(base, LateralSubquery(plan(ctx.right)), joinType, condition)
} else {
Join(base, plan(ctx.right), joinType, condition, JoinHint.NONE)
}
}
}
/**
* Add a [[Sample]] to a logical plan.
*
* This currently supports the following sampling methods:
* - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
* - TABLESAMPLE(x PERCENT) [REPEATABLE (y)]: Sample the table down to the given percentage with
* seed 'y'. Note that percentages are defined as a number between 0 and 100.
* - TABLESAMPLE(BUCKET x OUT OF y) [REPEATABLE (z)]: Sample the table down to an 'x' divided by
* 'y' fraction with seed 'z'.
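*
* For example (illustrative; names are placeholders):
* {{{
*   SELECT * FROM tbl TABLESAMPLE (10 PERCENT) REPEATABLE (42)
* }}}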
*/
private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// Create a sampled plan if we need one.
def sample(fraction: Double, seed: Long): Sample = {
// The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
// function takes X PERCENT as the input and the range of X is [0, 100], we need to
// adjust the fraction.
val eps = RandomSampler.roundingEpsilon
validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
s"Sampling fraction ($fraction) must be on interval [0, 1]",
ctx)
Sample(0.0, fraction, withReplacement = false, seed, query)
}
if (ctx.sampleMethod() == null) {
throw QueryParsingErrors.emptyInputForTableSampleError(ctx)
}
val seed = if (ctx.seed != null) {
ctx.seed.getText.toLong
} else {
(math.random() * 1000).toLong
}
ctx.sampleMethod() match {
case ctx: SampleByRowsContext =>
Limit(expression(ctx.expression), query)
case ctx: SampleByPercentileContext =>
val fraction = ctx.percentage.getText.toDouble
val sign = if (ctx.negativeSign == null) 1 else -1
sample(sign * fraction / 100.0d, seed)
case ctx: SampleByBytesContext =>
val bytesStr = ctx.bytes.getText
if (bytesStr.matches("[0-9]+[bBkKmMgG]")) {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError("byteLengthLiteral", ctx)
} else {
throw QueryParsingErrors.invalidByteLengthLiteralError(bytesStr, ctx)
}
case ctx: SampleByBucketContext if ctx.ON() != null =>
if (ctx.identifier != null) {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
"BUCKET x OUT OF y ON colname", ctx)
} else {
throw QueryParsingErrors.tableSampleByBytesUnsupportedError(
"BUCKET x OUT OF y ON function", ctx)
}
case ctx: SampleByBucketContext =>
sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble, seed)
}
}
/**
* Create a logical plan for a sub-query.
*/
override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
plan(ctx.query)
}
/**
* Create an un-aliased table reference. This is typically used for top-level table references,
* for example:
* {{{
* INSERT INTO db.tbl2
* TABLE db.tbl1
* }}}
*/
override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
createUnresolvedRelation(ctx.identifierReference)
}
/**
* Create an aliased table reference. This is typically used in FROM clauses.
*/
override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
val relation = createUnresolvedRelation(ctx.identifierReference)
val table = mayApplyAliasPlan(
ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel))
table.optionalMap(ctx.sample)(withSample)
}
override def visitVersion(ctx: VersionContext): Option[String] = {
if (ctx != null) {
if (ctx.INTEGER_VALUE != null) {
Some(ctx.INTEGER_VALUE().getText)
} else {
Option(string(visitStringLit(ctx.stringLit())))
}
} else {
None
}
}
private def extractNamedArgument(expr: FunctionArgumentContext, funcName: String) : Expression = {
Option(expr.namedArgumentExpression).map { n =>
if (conf.getConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS)) {
NamedArgumentExpression(n.key.getText, expression(n.value))
} else {
throw QueryCompilationErrors.namedArgumentsNotEnabledError(funcName, n.key.getText)
}
}.getOrElse {
expression(expr)
}
}
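/**
* Apply a time travel specification (version or timestamp) to a relation.
* For example (illustrative; the exact temporal clause keywords are defined by the grammar):
* {{{
*   SELECT * FROM tbl VERSION AS OF 1
*   SELECT * FROM tbl TIMESTAMP AS OF '2024-01-01 00:00:00'
* }}}
*/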
private def withTimeTravel(
ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) {
val version = visitVersion(ctx.version)
val timestamp = Option(ctx.timestamp).map(expression)
if (timestamp.exists(_.references.nonEmpty)) {
throw QueryParsingErrors.invalidTimeTravelSpec(
"timestamp expression cannot refer to any columns", ctx.timestamp)
}
RelationTimeTravel(plan, timestamp, version)
}
/**
* Create a relation for a table argument of a table-valued function.
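*
* For example (illustrative; `my_tvf` and the column names are placeholders):
* {{{
*   SELECT * FROM my_tvf(TABLE(tbl) PARTITION BY col1 ORDER BY col2)
* }}}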
*/
override def visitFunctionTableSubqueryArgument(
ctx: FunctionTableSubqueryArgumentContext): Expression = withOrigin(ctx) {
val p = Option(ctx.identifierReference).map { r =>
// Make sure that the identifier after the TABLE keyword is surrounded by parentheses, as
// required by the SQL standard. If not, return an informative error message.
if (ctx.LEFT_PAREN() == null) {
throw QueryParsingErrors.invalidTableFunctionIdentifierArgumentMissingParentheses(
ctx, argumentName = ctx.identifierReference().getText)
}
createUnresolvedRelation(r)
}.getOrElse {
plan(ctx.query)
}
var withSinglePartition = false
var partitionByExpressions = Seq.empty[Expression]
var orderByExpressions = Seq.empty[SortOrder]
Option(ctx.tableArgumentPartitioning).foreach { p =>
if (p.SINGLE != null) {
withSinglePartition = true
}
partitionByExpressions = p.partition.asScala.map(expression).toSeq
orderByExpressions = p.sortItem.asScala.map(visitSortItem).toSeq
def invalidPartitionOrOrderingExpression(clause: String): String = {
"The table function call includes a table argument with an invalid " +
s"partitioning/ordering specification: the $clause clause included multiple " +
"expressions without parentheses surrounding them; please add parentheses around " +
"these expressions and then retry the query again"
}
validate(
Option(p.invalidMultiPartitionExpression).isEmpty,
message = invalidPartitionOrOrderingExpression("PARTITION BY"),
ctx = p.invalidMultiPartitionExpression)
validate(
Option(p.invalidMultiSortItem).isEmpty,
message = invalidPartitionOrOrderingExpression("ORDER BY"),
ctx = p.invalidMultiSortItem)
}
validate(
!(withSinglePartition && partitionByExpressions.nonEmpty),
message = "WITH SINGLE PARTITION cannot be specified if PARTITION BY is also present",
ctx = ctx.tableArgumentPartitioning)
validate(
!(orderByExpressions.nonEmpty && partitionByExpressions.isEmpty && !withSinglePartition),
message = "ORDER BY cannot be specified unless either " +
"PARTITION BY or WITH SINGLE PARTITION is also present",
ctx = ctx.tableArgumentPartitioning)
FunctionTableSubqueryArgumentExpression(
plan = p,
partitionByExpressions = partitionByExpressions,
withSinglePartition = withSinglePartition,
orderByExpressions = orderByExpressions)
}
private def extractFunctionTableNamedArgument(
expr: FunctionTableReferenceArgumentContext, funcName: String) : Expression = {
Option(expr.functionTableNamedArgumentExpression).map { n =>
if (conf.getConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS)) {
NamedArgumentExpression(
n.key.getText, visitFunctionTableSubqueryArgument(n.functionTableSubqueryArgument))
} else {
throw QueryCompilationErrors.namedArgumentsNotEnabledError(funcName, n.key.getText)
}
}.getOrElse {
visitFunctionTableSubqueryArgument(expr.functionTableSubqueryArgument)
}
}
/**
* Create a table-valued function call with arguments, e.g. range(1000)
*/
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
: LogicalPlan = withOrigin(ctx) {
val func = ctx.functionTable
val aliases = if (func.tableAlias.identifierList != null) {
visitIdentifierList(func.tableAlias.identifierList)
} else {
Seq.empty
}
withFuncIdentClause(
func.functionName,
ident => {
if (ident.length > 1) {
throw QueryParsingErrors.invalidTableValuedFunctionNameError(ident, ctx)
}
val funcName = func.functionName.getText
val args = func.functionTableArgument.asScala.map { e =>
Option(e.functionArgument).map(extractNamedArgument(_, funcName))
.getOrElse {
extractFunctionTableNamedArgument(e.functionTableReferenceArgument, funcName)
}
}.toSeq
val tvf = UnresolvedTableValuedFunction(ident, args)
val tvfAliases = if (aliases.nonEmpty) UnresolvedTVFAliases(ident, tvf, aliases) else tvf
tvfAliases.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
})
}
/**
* Create an inline table (a virtual table in Hive parlance).
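* For example:
* {{{
*   VALUES (1, 'a'), (2, 'b') AS t(id, name)
* }}}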
*/
override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
// Get the backing expressions.
val rows = ctx.expression.asScala.map { e =>
expression(e) match {
// inline table comes in two styles:
// style 1: values (1), (2), (3) -- multiple columns are supported
// style 2: values 1, 2, 3 -- only a single column is supported here
case struct: CreateNamedStruct => struct.valExprs // style 1
case child => Seq(child) // style 2
}
}
val aliases = if (ctx.tableAlias.identifierList != null) {
visitIdentifierList(ctx.tableAlias.identifierList)
} else {
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}
val table = UnresolvedInlineTable(aliases, rows.toSeq)
table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
}
/**
* Create an alias (SubqueryAlias) for a join relation. This is practically the same as
* visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different
* hooks. We could add alias names for output columns, for example:
* {{{
* SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d)
* }}}
*/
override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample)
mayApplyAliasPlan(ctx.tableAlias, relation)
}
/**
* Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
* visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different
* hooks. We could add alias names for output columns, for example:
* {{{
* SELECT col1, col2 FROM testData AS t(col1, col2)
* }}}
*/
override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
val relation = plan(ctx.query).optionalMap(ctx.sample)(withSample)
if (ctx.tableAlias.strictIdentifier == null) {
// For un-aliased subqueries, use a default alias name that is unlikely to conflict with
// normal subquery names, so that parent operators can only access the columns in the
// subquery by unqualified names. Users can still use this special qualifier to access the
// columns if they know it, but that is not recommended.
SubqueryAlias(SubqueryAlias.generateSubqueryName(), relation)
} else {
mayApplyAliasPlan(ctx.tableAlias, relation)
}
}
/**
* Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]].
*/
private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = {
SubqueryAlias(alias.getText, plan)
}
/**
* If aliases are specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and
* column aliases for a [[LogicalPlan]].
*/
private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = {
if (tableAlias.strictIdentifier != null) {
val alias = tableAlias.strictIdentifier.getText
if (tableAlias.identifierList != null) {
val columnNames = visitIdentifierList(tableAlias.identifierList)
SubqueryAlias(alias, UnresolvedSubqueryColumnAliases(columnNames, plan))
} else {
SubqueryAlias(alias, plan)
}
} else {
plan
}
}
/**
* Create a Sequence of Strings for a parenthesis-enclosed alias list.
*/
override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
visitIdentifierSeq(ctx.identifierSeq)
}
/**
* Create a Sequence of Strings for an identifier list.
*/
override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
ctx.ident.asScala.map(_.getText).toSeq
}
/* ********************************************************************************************
* Table Identifier parsing
* ******************************************************************************************** */
/**
* Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
*/
override def visitTableIdentifier(
ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
}
/**
* Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern.
*/
override def visitFunctionIdentifier(
ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
}
/**
* Create a multi-part identifier.
*/
override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
withOrigin(ctx) {
ctx.parts.asScala.map(_.getText).toSeq
}
/* ********************************************************************************************
* Expression parsing
* ******************************************************************************************** */
/**
* Create an expression from the given context. This method just passes the context on to the
* visitor and only takes care of typing (we assume that the visitor returns an Expression here).
*/
protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
/**
* Create sequence of expressions from the given sequence of contexts.
*/
private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
trees.asScala.map(expression).toSeq
}
/**
* Create a star (i.e. all) expression; this selects all elements (in the specified object).
* Both un-targeted (global `*`) and targeted (e.g. `t.*`) stars are supported.
*/
override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
val target = Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq)
if (ctx.exceptClause != null) {
visitStarExcept(ctx, target)
} else {
UnresolvedStar(target)
}
}
/**
* Create a star-except (i.e. all - except list) expression; this selects all elements in the
* specified object except those in the except list.
* Both un-targeted (global `*`) and targeted (e.g. `t.*`) stars are supported.
*/
def visitStarExcept(ctx: StarContext, target: Option[Seq[String]]): Expression = withOrigin(ctx) {
val exceptCols = ctx.exceptClause
.exceptCols.multipartIdentifier.asScala.map(typedVisit[Seq[String]])
UnresolvedStarExcept(
target,
exceptCols.toSeq)
}
/**
* Check for the inappropriate usage of the '!' token.
* '!' used to be a synonym for 'NOT' in the lexer, but that was too general.
* '!' should only be a synonym for 'NOT' when used as a prefix in a logical operation.
* We now enforce this restriction explicitly.
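* For example (illustrative):
* {{{
*   SELECT !(c1 = 1) FROM t            -- allowed: '!' is a prefix 'NOT'
*   SELECT c1 ! IN (1, 2) FROM t       -- rejected unless LEGACY_BANG_EQUALS_NOT is set
* }}}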
*/
def blockBang(ctx: ErrorCapturingNotContext): ErrorCapturingNotContext = {
val tolerateBang = conf.getConf(LEGACY_BANG_EQUALS_NOT)
if (ctx != null && ctx.BANG() != null && !tolerateBang) {
withOrigin(ctx) {
throw new ParseException(
errorClass = "SYNTAX_DISCONTINUED.BANG_EQUALS_NOT",
messageParameters = Map("clause" -> toSQLStmt("!")),
ctx)
}
}
ctx
}
/**
* Create an aliased expression if an alias is specified. Both single and multi-aliases are
* supported.
*/
override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
val e = expression(ctx.expression)
if (ctx.name != null) {
Alias(e, ctx.name.getText)()
} else if (ctx.identifierList != null) {
MultiAlias(e, visitIdentifierList(ctx.identifierList))
} else {
e
}
}
/**
* Combine a number of boolean expressions into a balanced expression tree. These expressions are
* either combined by a logical [[And]] or a logical [[Or]].
*
* A balanced binary tree is created because regular left recursive trees cause considerable
* performance degradations and can cause stack overflows.
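*
* For example, `a AND b AND c AND d` is combined as `(a AND b) AND (c AND d)` instead of
* `((a AND b) AND c) AND d`.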
*/
override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
val expressionType = ctx.operator.getType
val expressionCombiner = expressionType match {
case SqlBaseParser.AND => And.apply _
case SqlBaseParser.OR => Or.apply _
}
// Collect all similar left hand contexts.
val contexts = ArrayBuffer(ctx.right)
var current = ctx.left
def collectContexts: Boolean = current match {
case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
contexts += lbc.right
current = lbc.left
true
case _ =>
contexts += current
false
}
while (collectContexts) {
// No body - all updates take place in collectContexts.
}
// Reverse the contexts to have them in the same sequence as in the SQL statement and turn
// them into expressions.
val expressions = contexts.reverseIterator.map(expression).to(ArrayBuffer)
// Create a balanced tree.
def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
case 0 =>
expressions(low)
case 1 =>
expressionCombiner(expressions(low), expressions(high))
case x =>
val mid = low + x / 2
expressionCombiner(
reduceToExpressionTree(low, mid),
reduceToExpressionTree(mid + 1, high))
}
reduceToExpressionTree(0, expressions.size - 1)
}
/**
* Invert a boolean expression.
*/
override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
Not(expression(ctx.booleanExpression()))
}
/**
* Create a filtering correlated sub-query (EXISTS).
*/
override def visitExists(ctx: ExistsContext): Expression = {
Exists(plan(ctx.query))
}
/**
* Create a comparison expression. This compares two expressions. The following comparison
* operators are supported:
* - Equal: '=' or '=='
* - Null-safe Equal: '<=>'
* - Not Equal: '<>' or '!='
* - Less than: '<'
* - Less than or Equal: '<='
* - Greater than: '>'
* - Greater than or Equal: '>='
*/
override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
val left = expression(ctx.left)
val right = expression(ctx.right)
val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
operator.getSymbol.getType match {
case SqlBaseParser.EQ =>
EqualTo(left, right)
case SqlBaseParser.NSEQ =>
EqualNullSafe(left, right)
case SqlBaseParser.NEQ | SqlBaseParser.NEQJ =>
Not(EqualTo(left, right))
case SqlBaseParser.LT =>
LessThan(left, right)
case SqlBaseParser.LTE =>
LessThanOrEqual(left, right)
case SqlBaseParser.GT =>
GreaterThan(left, right)
case SqlBaseParser.GTE =>
GreaterThanOrEqual(left, right)
}
}
/**
* Create a predicated expression. A predicated expression is a normal expression with a
* predicate attached to it, for example:
* {{{
* a + 1 IS NULL
* }}}
*/
override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) {
val e = expression(ctx.valueExpression)
if (ctx.predicate != null) {
withPredicate(e, ctx.predicate)
} else {
e
}
}
/**
* Add a predicate to the given expression. Supported predicates are:
* - (NOT) BETWEEN
* - (NOT) IN
* - (NOT) (LIKE | ILIKE) (ANY | SOME | ALL)
* - (NOT) RLIKE
* - IS (NOT) NULL
* - IS (NOT) (TRUE | FALSE | UNKNOWN)
* - IS (NOT) DISTINCT FROM
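*
* For example (with illustrative columns):
* {{{
*   a NOT BETWEEN 1 AND 10
*   b LIKE ANY ('a%', 'b%')
*   c IS NOT DISTINCT FROM d
* }}}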
*/
private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
// Invert a predicate if it has a valid NOT clause.
def invertIfNotDefined(e: Expression): Expression = {
val withNot = blockBang(ctx.errorCapturingNot)
withNot match {
case null => e
case _ => Not(e)
}
}
def getValueExpressions(e: Expression): Seq[Expression] = e match {
case c: CreateNamedStruct => c.valExprs
case other => Seq(other)
}
def lowerLikeArgsIfNeeded(
expr: Expression,
patterns: Seq[UTF8String]): (Expression, Seq[UTF8String]) = ctx.kind.getType match {
// scalastyle:off caselocale
case SqlBaseParser.ILIKE => (Lower(expr), patterns.map(_.toLowerCase))
// scalastyle:on caselocale
case _ => (expr, patterns)
}
def getLike(expr: Expression, pattern: Expression): Expression = ctx.kind.getType match {
case SqlBaseParser.ILIKE => new ILike(expr, pattern)
case _ => new Like(expr, pattern)
}
val withNot = blockBang(ctx.errorCapturingNot)
// Create the predicate.
ctx.kind.getType match {
case SqlBaseParser.BETWEEN =>
invertIfNotDefined(UnresolvedFunction(
"between", Seq(e, expression(ctx.lower), expression(ctx.upper)), isDistinct = false))
case SqlBaseParser.IN if ctx.query != null =>
invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query))))
case SqlBaseParser.IN =>
invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq))
case SqlBaseParser.LIKE | SqlBaseParser.ILIKE =>
Option(ctx.quantifier).map(_.getType) match {
case Some(SqlBaseParser.ANY) | Some(SqlBaseParser.SOME) =>
validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
val expressions = expressionList(ctx.expression)
if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
// If there are many pattern expressions, chaining them with Or would throw a
// StackOverflowError, so we use LikeAny or NotLikeAny instead.
val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
val (expr, pat) = lowerLikeArgsIfNeeded(e, patterns)
withNot match {
case null => LikeAny(expr, pat)
case _ => NotLikeAny(expr, pat)
}
} else {
ctx.expression.asScala.map(expression)
.map(p => invertIfNotDefined(getLike(e, p))).toSeq.reduceLeft(Or)
}
case Some(SqlBaseParser.ALL) =>
validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx)
val expressions = expressionList(ctx.expression)
if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) {
// If there are many pattern expressions, chaining them with And would throw a
// StackOverflowError, so we use LikeAll or NotLikeAll instead.
val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String])
val (expr, pat) = lowerLikeArgsIfNeeded(e, patterns)
withNot match {
case null => LikeAll(expr, pat)
case _ => NotLikeAll(expr, pat)
}
} else {
ctx.expression.asScala.map(expression)
.map(p => invertIfNotDefined(getLike(e, p))).toSeq.reduceLeft(And)
}
case _ =>
val escapeChar = Option(ctx.escapeChar)
.map(stringLitCtx => string(visitStringLit(stringLitCtx))).map { str =>
if (str.length != 1) {
throw QueryParsingErrors.invalidEscapeStringError(str, ctx)
}
str.charAt(0)
}.getOrElse('\\')
val likeExpr = ctx.kind.getType match {
case SqlBaseParser.ILIKE => ILike(e, expression(ctx.pattern), escapeChar)
case _ => Like(e, expression(ctx.pattern), escapeChar)
}
invertIfNotDefined(likeExpr)
}
case SqlBaseParser.RLIKE =>
invertIfNotDefined(RLike(e, expression(ctx.pattern)))
case SqlBaseParser.NULL if withNot != null =>
IsNotNull(e)
case SqlBaseParser.NULL =>
IsNull(e)
case SqlBaseParser.TRUE => withNot match {
case null => EqualNullSafe(e, Literal(true))
case _ => Not(EqualNullSafe(e, Literal(true)))
}
case SqlBaseParser.FALSE => withNot match {
case null => EqualNullSafe(e, Literal(false))
case _ => Not(EqualNullSafe(e, Literal(false)))
}
case SqlBaseParser.UNKNOWN => withNot match {
case null => IsUnknown(e)
case _ => IsNotUnknown(e)
}
case SqlBaseParser.DISTINCT if withNot != null =>
EqualNullSafe(e, expression(ctx.right))
case SqlBaseParser.DISTINCT =>
Not(EqualNullSafe(e, expression(ctx.right)))
}
}
/**
* Create a binary arithmetic expression. The following arithmetic operators are supported:
* - Multiplication: '*'
* - Division: '/'
* - Hive Long Division: 'DIV'
* - Modulo: '%'
* - Addition: '+'
* - Subtraction: '-'
* - Concatenation: '||'
* - Binary AND: '&'
* - Binary XOR: '^'
* - Binary OR: '|'
*/
override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
val left = expression(ctx.left)
val right = expression(ctx.right)
ctx.operator.getType match {
case SqlBaseParser.ASTERISK =>
Multiply(left, right)
case SqlBaseParser.SLASH =>
Divide(left, right)
case SqlBaseParser.PERCENT =>
Remainder(left, right)
case SqlBaseParser.DIV =>
IntegralDivide(left, right)
case SqlBaseParser.PLUS =>
Add(left, right)
case SqlBaseParser.MINUS =>
Subtract(left, right)
case SqlBaseParser.CONCAT_PIPE =>
Concat(left :: right :: Nil)
case SqlBaseParser.AMPERSAND =>
BitwiseAnd(left, right)
case SqlBaseParser.HAT =>
BitwiseXor(left, right)
case SqlBaseParser.PIPE =>
BitwiseOr(left, right)
}
}
/**
* Create a unary arithmetic expression. The following arithmetic operators are supported:
* - Plus: '+'
* - Minus: '-'
* - Bitwise Not: '~'
*/
override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
val value = expression(ctx.valueExpression)
ctx.operator.getType match {
case SqlBaseParser.PLUS =>
UnaryPositive(value)
case SqlBaseParser.MINUS =>
UnaryMinus(value)
case SqlBaseParser.TILDE =>
BitwiseNot(value)
}
}
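/**
* Create a CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_USER, USER or SESSION_USER expression,
* falling back to an [[UnresolvedAttribute]] when reserved keywords are not enforced, so
* that columns named e.g. `CURRENT_DATE` still resolve.
*/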
override def visitCurrentLike(ctx: CurrentLikeContext): Expression = withOrigin(ctx) {
if (conf.enforceReservedKeywords) {
ctx.name.getType match {
case SqlBaseParser.CURRENT_DATE =>
CurrentDate()
case SqlBaseParser.CURRENT_TIMESTAMP =>
CurrentTimestamp()
case SqlBaseParser.CURRENT_USER | SqlBaseParser.USER | SqlBaseParser.SESSION_USER =>
CurrentUser()
}
} else {
// If the parser is not in ANSI mode, we should return `UnresolvedAttribute`, in case there
// are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`.
UnresolvedAttribute.quoted(ctx.name.getText)
}
}
/**
* Create a [[Collate]] expression.
*/
override def visitCollate(ctx: CollateContext): Expression = withOrigin(ctx) {
val collationName = visitCollateClause(ctx.collateClause())
Collate(expression(ctx.primaryExpression), collationName)
}
override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) {
if (!SQLConf.get.collationEnabled) {
throw QueryCompilationErrors.collationNotEnabledError()
}
ctx.identifier.getText
}
/**
* Create a [[Cast]] expression.
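* For example: `CAST(col AS INT)` or `TRY_CAST(col AS INT)` (with `col` an illustrative column).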
*/
override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
val rawDataType = typedVisit[DataType](ctx.dataType())
val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
ctx.name.getType match {
case SqlBaseParser.CAST =>
val cast = Cast(expression(ctx.expression), dataType)
cast.setTagValue(Cast.USER_SPECIFIED_CAST, ())
cast
case SqlBaseParser.TRY_CAST =>
val cast = Cast(expression(ctx.expression), dataType, evalMode = EvalMode.TRY)
cast.setTagValue(Cast.USER_SPECIFIED_CAST, ())
cast
}
}
/**
* Create a [[Cast]] expression for '::' syntax.
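* For example: `col :: INT` (with `col` an illustrative column).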
*/
override def visitCastByColon(ctx: CastByColonContext): Expression = withOrigin(ctx) {
val rawDataType = typedVisit[DataType](ctx.dataType())
val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
val cast = Cast(expression(ctx.primaryExpression), dataType)
cast.setTagValue(Cast.USER_SPECIFIED_CAST, ())
cast
}
/**
* Create a [[CreateStruct]] expression.
*/
override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
CreateStruct.create(ctx.argument.asScala.map(expression).toSeq)
}
/**
* Create a [[First]] expression.
*/
override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
val ignoreNullsExpr = ctx.IGNORE != null
First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
}
/**
* Create an [[AnyValue]] expression.
*/
override def visitAny_value(ctx: Any_valueContext): Expression = withOrigin(ctx) {
val ignoreNullsExpr = ctx.IGNORE != null
AnyValue(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
}
/**
* Create a [[Last]] expression.
*/
override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
val ignoreNullsExpr = ctx.IGNORE != null
Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
}
/**
* Create a Position expression.
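* For example: `POSITION('b' IN 'abc')` returns 2.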
*/
override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) {
new StringLocate(expression(ctx.substr), expression(ctx.str))
}
/**
* Create an Extract expression.
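* For example: `EXTRACT(YEAR FROM DATE '2020-01-01')` returns 2020.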
*/
override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) {
val arguments = Seq(Literal(ctx.field.getText), expression(ctx.source))
UnresolvedFunction("extract", arguments, isDistinct = false)
}
/**
* Create a Substring/Substr expression.
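* For example: `SUBSTRING('Spark SQL', 5)` or `SUBSTRING('Spark SQL' FROM 5 FOR 1)`.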
*/
override def visitSubstring(ctx: SubstringContext): Expression = withOrigin(ctx) {
if (ctx.len != null) {
Substring(expression(ctx.str), expression(ctx.pos), expression(ctx.len))
} else {
new Substring(expression(ctx.str), expression(ctx.pos))
}
}
/**
* Create a Trim expression.
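* For example: `TRIM(BOTH 'x' FROM 'xSparkx')` returns 'Spark'.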
*/
override def visitTrim(ctx: TrimContext): Expression = withOrigin(ctx) {
val srcStr = expression(ctx.srcStr)
val trimStr = Option(ctx.trimStr).map(expression)
Option(ctx.trimOption).map(_.getType).getOrElse(SqlBaseParser.BOTH) match {
case SqlBaseParser.BOTH =>
StringTrim(srcStr, trimStr)
case SqlBaseParser.LEADING =>
StringTrimLeft(srcStr, trimStr)
case SqlBaseParser.TRAILING =>
StringTrimRight(srcStr, trimStr)
case other =>
throw QueryParsingErrors.trimOptionUnsupportedError(other, ctx)
}
}
/**
* Create an Overlay expression.
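* For example: `OVERLAY('Spark SQL' PLACING '_' FROM 6)` returns 'Spark_SQL'.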
*/
override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
val input = expression(ctx.input)
val replace = expression(ctx.replace)
val position = expression(ctx.position)
val lengthOpt = Option(ctx.length).map(expression)
lengthOpt match {
case Some(length) => Overlay(input, replace, position, length)
case None => new Overlay(input, replace, position)
}
}
/**
* Create a (windowed) Function expression.
*/
override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
// Create the function call.
val name = ctx.functionName.getText
val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
// Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is a `Buffer` in Scala 2.13.
val arguments = ctx.argument.asScala.map { e =>
extractNamedArgument(e, name)
}.toSeq match {
case Seq(UnresolvedStar(None))
if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct =>
// Transform COUNT(*) into COUNT(1).
Seq(Literal(1))
case expressions =>
expressions
}
val order = ctx.sortItem.asScala.map(visitSortItem)
val filter = Option(ctx.where).map(expression(_))
val ignoreNulls =
Option(ctx.nullsOption).map(_.getType == SqlBaseParser.IGNORE).getOrElse(false)
// Is this an IDENTIFIER clause instead of a function call?
if (ctx.functionName.identFunc != null &&
arguments.length == 1 && // One argument
ctx.setQuantifier == null && // No other clause
ctx.where == null &&
ctx.nullsOption == null &&
ctx.windowSpec == null) {
new ExpressionWithUnresolvedIdentifier(arguments.head, UnresolvedAttribute(_))
} else {
// It's a function call
val funcCtx = ctx.functionName
val func = withFuncIdentClause(
funcCtx,
arguments ++ filter ++ order.toSeq,
(ident, otherExprs) => {
val orderings = otherExprs.takeRight(order.size).asInstanceOf[Seq[SortOrder]]
val args = otherExprs.take(arguments.length)
val filterExpr = if (filter.isDefined) {
Some(otherExprs(args.length))
} else {
None
}
UnresolvedFunction(ident, args, isDistinct, filterExpr, ignoreNulls, orderings)
}
)
// Check if the function is evaluated in a windowed context.
ctx.windowSpec match {
case spec: WindowRefContext =>
UnresolvedWindowExpression(func, visitWindowRef(spec))
case spec: WindowDefContext =>
WindowExpression(func, visitWindowDef(spec))
case _ => func
}
}
}
/**
* Create a (database, function name) pair, where the database is optional.
*/
protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = {
visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq)
}
/**
* Create a (database, function name) pair, where the database is optional.
*/
private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = {
texts match {
case Seq(db, fn) => FunctionIdentifier(fn, Option(db))
case Seq(fn) => FunctionIdentifier(fn, None)
case _ =>
throw QueryParsingErrors.functionNameUnsupportedError(texts.mkString("."), ctx)
}
}
protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = {
if (ctx.qualifiedName != null) {
ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq
} else {
Seq(ctx.getText)
}
}
/**
* Create a [[LambdaFunction]].
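* For example, as used in a higher-order function (`arr` being an illustrative array column):
* {{{
*   transform(arr, x -> x + 1)
* }}}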
*/
override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
val arguments = ctx.identifier().asScala.map { name =>
UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
}
val function = expression(ctx.expression).transformUp {
case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
}
LambdaFunction(function, arguments.toSeq)
}
/**
* Create a reference to a named window specification, i.e. [[WindowSpecReference]].
*/
override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
WindowSpecReference(ctx.name.getText)
}
/**
* Create a window definition, i.e. [[WindowSpecDefinition]].
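* For example (with illustrative columns `a` and `b`):
* {{{
*   OVER (PARTITION BY a ORDER BY b ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
* }}}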
*/
override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
// CLUSTER BY ... | PARTITION BY ... ORDER BY ...
val partition = ctx.partition.asScala.map(expression)
val order = ctx.sortItem.asScala.map(visitSortItem)
// RANGE/ROWS BETWEEN ...
val frameSpecOption = Option(ctx.windowFrame).map { frame =>
val frameType = frame.frameType.getType match {
case SqlBaseParser.RANGE => RangeFrame
case SqlBaseParser.ROWS => RowFrame
}
SpecifiedWindowFrame(
frameType,
visitFrameBound(frame.start),
Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
}
WindowSpecDefinition(
partition.toSeq,
order.toSeq,
frameSpecOption.getOrElse(UnspecifiedFrame))
}
/**
* Create or resolve a frame boundary expression.
*/
override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
def value: Expression = {
val e = expression(ctx.expression)
validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
e
}
ctx.boundType.getType match {
case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
UnboundedPreceding
case SqlBaseParser.PRECEDING =>
UnaryMinus(value)
case SqlBaseParser.CURRENT =>
CurrentRow
case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
UnboundedFollowing
case SqlBaseParser.FOLLOWING =>
value
}
}
/**
* Create a [[CreateStruct]] expression.
*/
override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq)
}
/**
* Create a [[ScalarSubquery]] expression.
*/
override def visitSubqueryExpression(
ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
ScalarSubquery(plan(ctx.query))
}
/**
* Create a value based [[CaseWhen]] expression. This has the following SQL form:
* {{{
* CASE [expression]
* WHEN [value] THEN [expression]
* ...
* ELSE [expression]
* END
* }}}
*/
override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
val e = expression(ctx.value)
val branches = ctx.whenClause.asScala.map { wCtx =>
(EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
}
CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
}
/**
* Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
* {{{
* CASE
* WHEN [predicate] THEN [expression]
* ...
* ELSE [expression]
* END
* }}}
*
* @param ctx the parse tree
* */
override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
val branches = ctx.whenClause.asScala.map { wCtx =>
(expression(wCtx.condition), expression(wCtx.result))
}
CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
}
/**
* Currently, regexes are only supported in expressions of SELECT statements; in other
* places (e.g. WHERE `(a)?+.+` = 2) a regex is not meaningful.
*/
private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) {
var parent = ctx.getParent
while (parent != null) {
if (parent.isInstanceOf[NamedExpressionContext]) return true
parent = parent.getParent
}
false
}
/**
* Returns whether the pattern is a regex expression (instead of a normal string). A normal
* string consists only of letters, digits, and underscores.
*/
private def isRegex(pattern: String): Boolean = {
pattern.exists(p => !Character.isLetterOrDigit(p) && p != '_')
}
/**
* Create a dereference expression. The return type depends on the type of the parent.
* If the parent is an [[UnresolvedAttribute]], the result can be an [[UnresolvedAttribute]] or
* an [[UnresolvedRegex]] for a regex quoted in backticks; if the parent is some other
* expression, the result is an [[UnresolvedExtractValue]].
*/
override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
val attr = ctx.fieldName.getText
expression(ctx.base) match {
case unresolved_attr @ UnresolvedAttribute(nameParts) =>
ctx.fieldName.getStart.getText match {
case escapedIdentifier(columnNameRegex)
if conf.supportQuotedRegexColumnName &&
isRegex(columnNameRegex) && canApplyRegex(ctx) =>
UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name),
conf.caseSensitiveAnalysis)
case _ =>
UnresolvedAttribute(nameParts :+ attr)
}
case e =>
UnresolvedExtractValue(e, Literal(attr))
}
}
/**
* Create an [[UnresolvedAttribute]] expression, or an [[UnresolvedRegex]] if it is a regex
* quoted in backticks.
*/
override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
ctx.getStart.getText match {
case escapedIdentifier(columnNameRegex)
if conf.supportQuotedRegexColumnName &&
isRegex(columnNameRegex) && canApplyRegex(ctx) =>
UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
case _ =>
UnresolvedAttribute.quoted(ctx.getText)
}
}
/**
* Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
*/
override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
}
/**
* Create an expression for an expression between parentheses. This is needed because the
* ANTLR visitor cannot automatically convert the nested context into an expression.
*/
override def visitParenthesizedExpression(
ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
expression(ctx.expression)
}
/**
* Create a [[SortOrder]] expression.
*/
override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
val direction = if (ctx.DESC != null) {
Descending
} else {
Ascending
}
val nullOrdering = if (ctx.FIRST != null) {
NullsFirst
} else if (ctx.LAST != null) {
NullsLast
} else {
direction.defaultNullOrdering
}
SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty)
}
/**
* Create a typed Literal expression. A typed literal has the following SQL syntax:
* {{{
* [TYPE] '[VALUE]'
* }}}
* Currently Date, Timestamp, Interval and Binary typed literals are supported.
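* For example:
* {{{
*   DATE '2020-01-01'
*   TIMESTAMP '2020-01-01 12:00:00'
*   TIMESTAMP_NTZ '2020-01-01 12:00:00'
*   X'1C'
* }}}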
*/
override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
val value = string(visitStringLit(ctx.stringLit))
val valueType = ctx.literalType.start.getType
def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = {
f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse {
throw QueryParsingErrors.cannotParseValueTypeError(ctx.literalType.getText, value, ctx)
}
}
def constructTimestampLTZLiteral(value: String): Literal = {
val zoneId = getZoneId(conf.sessionLocalTimeZone)
val specialTs = convertSpecialTimestamp(value, zoneId).map(Literal(_, TimestampType))
specialTs.getOrElse(toLiteral(stringToTimestamp(_, zoneId), TimestampType))
}
valueType match {
case DATE =>
val zoneId = getZoneId(conf.sessionLocalTimeZone)
val specialDate = convertSpecialDate(value, zoneId).map(Literal(_, DateType))
specialDate.getOrElse(toLiteral(stringToDate, DateType))
case TIMESTAMP_NTZ =>
convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone))
.map(Literal(_, TimestampNTZType))
.getOrElse(toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType))
case TIMESTAMP_LTZ =>
constructTimestampLTZLiteral(value)
case TIMESTAMP =>
SQLConf.get.timestampType match {
case TimestampNTZType =>
convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone))
.map(Literal(_, TimestampNTZType))
.getOrElse {
val containsTimeZonePart =
DateTimeUtils.parseTimestampString(UTF8String.fromString(value))._2.isDefined
// If the input string contains time zone part, return a timestamp with local time
// zone literal.
if (containsTimeZonePart) {
constructTimestampLTZLiteral(value)
} else {
toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType)
}
}
case TimestampType =>
constructTimestampLTZLiteral(value)
}
case INTERVAL =>
val interval = try {
IntervalUtils.stringToInterval(UTF8String.fromString(value))
} catch {
case e: IllegalArgumentException =>
val ex = QueryParsingErrors.cannotParseValueTypeError(
ctx.literalType.getText, value, ctx)
ex.setStackTrace(e.getStackTrace)
throw ex
}
if (!conf.legacyIntervalEnabled) {
val units = value
.split("\\s")
.map(_.toLowerCase(Locale.ROOT).stripSuffix("s"))
.filter(s => s != "interval" && s.matches("[a-z]+"))
constructMultiUnitsIntervalLiteral(ctx, interval, units.toImmutableArraySeq)
} else {
Literal(interval, CalendarIntervalType)
}
case BINARY_HEX =>
val padding = if (value.length % 2 != 0) "0" else ""
try {
Literal(Hex.decodeHex(padding + value))
} catch {
case e: DecoderException =>
val ex = QueryParsingErrors.cannotParseValueTypeError("X", value, ctx)
ex.setStackTrace(e.getStackTrace)
throw ex
}
case _ =>
throw QueryParsingErrors.literalValueTypeUnsupportedError(
unsupportedType = ctx.literalType.getText,
supportedTypes =
Seq("DATE", "TIMESTAMP_NTZ", "TIMESTAMP_LTZ", "TIMESTAMP", "INTERVAL", "X"),
ctx)
}
}
/**
* Create a NULL literal expression.
*/
override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
Literal(null)
}
/**
* Create a Boolean literal expression.
*/
override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
if (ctx.getText.toBoolean) {
Literal.TrueLiteral
} else {
Literal.FalseLiteral
}
}
/**
* Create an integral literal expression. The code selects the narrowest integral type
* possible: an Integer, a Long, or otherwise a BigDecimal.
*/
override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
BigDecimal(ctx.getText) match {
case v if v.isValidInt =>
Literal(v.intValue)
case v if v.isValidLong =>
Literal(v.longValue)
case v => Literal(v.underlying())
}
}
/**
* Create a decimal literal for a regular decimal number.
*/
override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
Literal(BigDecimal(ctx.getText).underlying())
}
/**
* Create a decimal literal for a regular decimal number or a scientific decimal number.
*/
override def visitLegacyDecimalLiteral(
ctx: LegacyDecimalLiteralContext): Literal = withOrigin(ctx) {
Literal(BigDecimal(ctx.getText).underlying())
}
/**
* Create a double literal for a number with an exponent, e.g. 1E-30.
*/
override def visitExponentLiteral(ctx: ExponentLiteralContext): Literal = {
numericLiteral(ctx, ctx.getText, /* exponent values don't have a suffix */
Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
}
/** Create a numeric literal expression. */
private def numericLiteral(
ctx: NumberContext,
rawStrippedQualifier: String,
minValue: BigDecimal,
maxValue: BigDecimal,
typeName: String)(converter: String => Any): Literal = withOrigin(ctx) {
try {
val rawBigDecimal = BigDecimal(rawStrippedQualifier)
if (rawBigDecimal < minValue || rawBigDecimal > maxValue) {
throw QueryParsingErrors.invalidNumericLiteralRangeError(
rawStrippedQualifier, minValue, maxValue, typeName, ctx)
}
Literal(converter(rawStrippedQualifier))
} catch {
case e: NumberFormatException =>
throw new ParseException(
errorClass = "_LEGACY_ERROR_TEMP_0060",
messageParameters = Map("msg" -> e.getMessage),
ctx)
}
}
/**
* Create a Byte Literal expression.
*/
override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = {
val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
numericLiteral(ctx, rawStrippedQualifier,
Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte)
}
/**
* Create a Short Literal expression.
*/
override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = {
val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
numericLiteral(ctx, rawStrippedQualifier,
Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort)
}
/**
* Create a Long Literal expression.
*/
override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = {
val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
numericLiteral(ctx, rawStrippedQualifier,
Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong)
}
/**
* Create a Float Literal expression.
*/
override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = {
val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
numericLiteral(ctx, rawStrippedQualifier,
Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat)
}
/**
* Create a Double Literal expression.
*/
override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = {
val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1)
numericLiteral(ctx, rawStrippedQualifier,
Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble)
}
/**
* Create a BigDecimal Literal expression.
*/
override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = {
val raw = ctx.getText.substring(0, ctx.getText.length - 2)
try {
Literal(BigDecimal(raw).underlying())
} catch {
case e: SparkArithmeticException =>
throw new ParseException(
errorClass = e.getErrorClass,
messageParameters = e.getMessageParameters.asScala.toMap,
ctx)
}
}
/**
* Create a String literal expression.
*/
override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
Literal.create(createString(ctx), conf.defaultStringType)
}
/**
* Create a String from a string literal context. This supports multiple consecutive string
* literals, which are concatenated; for example, the expression "'hello' 'world'" will be
* converted into "helloworld".
*
* Special characters can be escaped by using Hive/C-style escaping.
*/
private def createString(ctx: StringLiteralContext): String = {
if (conf.escapedStringLiterals) {
ctx.stringLit.asScala.map(x => stringWithoutUnescape(visitStringLit(x))).mkString
} else {
ctx.stringLit.asScala.map(x => string(visitStringLit(x))).mkString
}
}
/**
* Create an [[UnresolvedRelation]] from an identifier reference.
*/
private def createUnresolvedRelation(
ctx: IdentifierReferenceContext): LogicalPlan = withOrigin(ctx) {
withIdentClause(ctx, UnresolvedRelation(_))
}
/**
* Create an [[UnresolvedRelation]] from a multi-part identifier.
*/
private def createUnresolvedRelation(
ctx: ParserRuleContext, ident: Seq[String]): UnresolvedRelation = withOrigin(ctx) {
UnresolvedRelation(ident)
}
/**
* Create an [[UnresolvedTable]] from an identifier reference.
*/
private def createUnresolvedTable(
ctx: IdentifierReferenceContext,
commandName: String,
suggestAlternative: Boolean = false): LogicalPlan = withOrigin(ctx) {
withIdentClause(ctx, UnresolvedTable(_, commandName, suggestAlternative))
}
/**
* Create an [[UnresolvedView]] from an identifier reference.
*/
private def createUnresolvedView(
ctx: IdentifierReferenceContext,
commandName: String,
allowTemp: Boolean = true,
suggestAlternative: Boolean = false): LogicalPlan = withOrigin(ctx) {
withIdentClause(ctx, UnresolvedView(_, commandName, allowTemp, suggestAlternative))
}
/**
* Create an [[UnresolvedTableOrView]] from an identifier reference.
*/
private def createUnresolvedTableOrView(
ctx: IdentifierReferenceContext,
commandName: String,
allowTempView: Boolean = true): LogicalPlan = withOrigin(ctx) {
withIdentClause(ctx, UnresolvedTableOrView(_, commandName, allowTempView))
}
private def createUnresolvedTableOrView(
ctx: ParserRuleContext,
ident: Seq[String],
commandName: String,
allowTempView: Boolean): UnresolvedTableOrView = withOrigin(ctx) {
UnresolvedTableOrView(ident, commandName, allowTempView)
}
/**
* Create an [[UnresolvedFunctionName]] from a multi-part identifier.
*/
private def createUnresolvedFunctionName(
ctx: ParserRuleContext,
ident: Seq[String],
commandName: String,
requirePersistent: Boolean = false,
funcTypeMismatchHint: Option[String] = None,
possibleQualifiedName: Option[Seq[String]] = None): UnresolvedFunctionName = withOrigin(ctx) {
UnresolvedFunctionName(
ident,
commandName,
requirePersistent,
funcTypeMismatchHint,
possibleQualifiedName)
}
/**
* Construct a [[Literal]] from a [[CalendarInterval]] and units represented as a
* [[Seq]] of [[String]].
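* For example, the units of `INTERVAL 1 YEAR 2 MONTHS` map to a year-month interval type,
* while mixing year-month and day-time units throws a mixed-units error.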
*/
private def constructMultiUnitsIntervalLiteral(
ctx: ParserRuleContext,
calendarInterval: CalendarInterval,
units: Seq[String]): Literal = {
val yearMonthFields = Set.empty[Byte]
val dayTimeFields = Set.empty[Byte]
for (unit <- units) {
if (YearMonthIntervalType.stringToField.contains(unit)) {
yearMonthFields += YearMonthIntervalType.stringToField(unit)
} else if (DayTimeIntervalType.stringToField.contains(unit)) {
dayTimeFields += DayTimeIntervalType.stringToField(unit)
} else if (unit == "week") {
dayTimeFields += DayTimeIntervalType.DAY
} else {
assert(unit == "millisecond" || unit == "microsecond")
dayTimeFields += DayTimeIntervalType.SECOND
}
}
if (yearMonthFields.nonEmpty) {
if (dayTimeFields.nonEmpty) {
val literalStr = source(ctx)
throw QueryParsingErrors.mixedIntervalUnitsError(literalStr, ctx)
}
Literal(
calendarInterval.months,
YearMonthIntervalType(yearMonthFields.min, yearMonthFields.max)
)
} else {
Literal(
IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS),
DayTimeIntervalType(dayTimeFields.min, dayTimeFields.max))
}
}
/**
* Create a [[CalendarInterval]] or ANSI interval literal expression.
* Two syntaxes are supported:
* - multiple unit value pairs, for instance: interval 2 months 2 days.
* - from-to unit, for instance: interval '1-2' year to month.
*/
override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
val calendarInterval = parseIntervalLiteral(ctx)
if (ctx.errorCapturingUnitToUnitInterval != null && !conf.legacyIntervalEnabled) {
// Check the `to` unit to distinguish year-month and day-time intervals because
// `CalendarInterval` doesn't have enough info. For instance, new CalendarInterval(0, 0, 0)
// can be derived from INTERVAL '0-0' YEAR TO MONTH as well as from
// INTERVAL '0 00:00:00' DAY TO SECOND.
val fromUnit =
ctx.errorCapturingUnitToUnitInterval.body.from.getText.toLowerCase(Locale.ROOT)
val toUnit = ctx.errorCapturingUnitToUnitInterval.body.to.getText.toLowerCase(Locale.ROOT)
if (toUnit == "month") {
assert(calendarInterval.days == 0 && calendarInterval.microseconds == 0)
val start = YearMonthIntervalType.stringToField(fromUnit)
Literal(calendarInterval.months, YearMonthIntervalType(start, YearMonthIntervalType.MONTH))
} else {
assert(calendarInterval.months == 0)
val micros = IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS)
val start = DayTimeIntervalType.stringToField(fromUnit)
val end = DayTimeIntervalType.stringToField(toUnit)
Literal(micros, DayTimeIntervalType(start, end))
}
} else if (ctx.errorCapturingMultiUnitsInterval != null && !conf.legacyIntervalEnabled) {
val units =
ctx.errorCapturingMultiUnitsInterval.body.unit.asScala.map(
_.getText.toLowerCase(Locale.ROOT).stripSuffix("s")).toSeq
constructMultiUnitsIntervalLiteral(ctx, calendarInterval, units)
} else {
Literal(calendarInterval, CalendarIntervalType)
}
}
/**
* Create a [[CalendarInterval]] object.
*/
protected def parseIntervalLiteral(ctx: IntervalContext): CalendarInterval = withOrigin(ctx) {
if (ctx.errorCapturingMultiUnitsInterval != null) {
val innerCtx = ctx.errorCapturingMultiUnitsInterval
if (innerCtx.unitToUnitInterval != null) {
throw QueryParsingErrors.moreThanOneFromToUnitInIntervalLiteralError(
innerCtx.unitToUnitInterval)
}
visitMultiUnitsInterval(innerCtx.multiUnitsInterval)
} else {
assert(ctx.errorCapturingUnitToUnitInterval != null)
val innerCtx = ctx.errorCapturingUnitToUnitInterval
if (innerCtx.error1 != null || innerCtx.error2 != null) {
val errorCtx = if (innerCtx.error1 != null) innerCtx.error1 else innerCtx.error2
throw QueryParsingErrors.moreThanOneFromToUnitInIntervalLiteralError(errorCtx)
}
visitUnitToUnitInterval(innerCtx.body)
}
}
/**
* Creates a [[CalendarInterval]] with multiple unit value pairs, e.g. 1 YEAR 2 DAYS.
*/
override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = {
withOrigin(ctx) {
val units = ctx.unit.asScala
val values = ctx.intervalValue().asScala
try {
assert(units.length == values.length)
val kvs = units.indices.map { i =>
val u = units(i).getText
val v = if (values(i).stringLit() != null) {
val value = string(visitStringLit(values(i).stringLit()))
// SPARK-32840: For invalid cases, e.g. INTERVAL '1 day 2' hour or
// INTERVAL 'interval 1' day, we need to check the value parts up front, before they are
// concatenated with units and accidentally become valid ones, e.g. '1 day 2 hour'.
// Ideally, we only need to ensure here that the value parts don't contain any units.
if (value.exists(Character.isLetter)) {
throw QueryParsingErrors.invalidIntervalFormError(value, ctx)
}
if (values(i).MINUS() == null) {
value
} else if (value.startsWith("-")) {
value.replaceFirst("-", "")
} else {
s"-$value"
}
} else {
values(i).getText
}
UTF8String.fromString(" " + v + " " + u)
}
IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
} catch {
case st: SparkThrowable => throw st
case i: IllegalArgumentException =>
val e = new ParseException(
errorClass = "_LEGACY_ERROR_TEMP_0062",
messageParameters = Map("msg" -> i.getMessage),
ctx)
e.setStackTrace(i.getStackTrace)
throw e
}
}
}
/**
* Creates a [[CalendarInterval]] with from-to unit, e.g. '2-1' YEAR TO MONTH.
*/
override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = {
withOrigin(ctx) {
val value = Option(ctx.intervalValue().stringLit()).map(s => string(visitStringLit(s)))
.map { interval =>
if (ctx.intervalValue().MINUS() == null) {
interval
} else if (interval.startsWith("-")) {
interval.replaceFirst("-", "")
} else {
s"-$interval"
}
}.getOrElse {
throw QueryParsingErrors.invalidFromToUnitValueError(ctx.intervalValue)
}
try {
val from = ctx.from.getText.toLowerCase(Locale.ROOT)
val to = ctx.to.getText.toLowerCase(Locale.ROOT)
(from, to) match {
case ("year", "month") =>
IntervalUtils.fromYearMonthString(value)
case ("day", "hour") | ("day", "minute") | ("day", "second") | ("hour", "minute") |
("hour", "second") | ("minute", "second") =>
IntervalUtils.fromDayTimeString(value,
DayTimeIntervalType.stringToField(from), DayTimeIntervalType.stringToField(to))
case _ =>
throw QueryParsingErrors.fromToIntervalUnsupportedError(from, to, ctx)
}
} catch {
// Handle Exceptions thrown by CalendarInterval
case e: IllegalArgumentException =>
val pe = new ParseException(
errorClass = "_LEGACY_ERROR_TEMP_0063",
messageParameters = Map("msg" -> e.getMessage),
ctx)
pe.setStackTrace(e.getStackTrace)
throw pe
}
}
}
/* ********************************************************************************************
* DataType parsing
* ******************************************************************************************** */
/**
* Create a top-level table schema.
*/
protected def createSchema(ctx: CreateOrReplaceTableColTypeListContext): StructType = {
val columns = Option(ctx).toArray.flatMap(visitCreateOrReplaceTableColTypeList)
StructType(columns.map(_.toV1Column))
}
/**
* Get CREATE TABLE column definitions.
*/
override def visitCreateOrReplaceTableColTypeList(
ctx: CreateOrReplaceTableColTypeListContext): Seq[ColumnDefinition] = withOrigin(ctx) {
ctx.createOrReplaceTableColType().asScala.map(visitCreateOrReplaceTableColType).toSeq
}
/**
* Get a CREATE TABLE column definition.
*/
override def visitCreateOrReplaceTableColType(
ctx: CreateOrReplaceTableColTypeContext): ColumnDefinition = withOrigin(ctx) {
import ctx._
val name: String = colName.getText
// Check that no duplicates exist among any CREATE TABLE column options specified.
var nullable = true
var defaultExpression: Option[DefaultExpressionContext] = None
var generationExpression: Option[GenerationExpressionContext] = None
var commentSpec: Option[CommentSpecContext] = None
ctx.colDefinitionOption().asScala.foreach { option =>
if (option.NULL != null) {
blockBang(option.errorCapturingNot)
if (!nullable) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, name, "NOT NULL")
}
nullable = false
}
Option(option.defaultExpression()).foreach { expr =>
if (!conf.getConf(SQLConf.ENABLE_DEFAULT_COLUMNS)) {
throw QueryParsingErrors.defaultColumnNotEnabledError(ctx)
}
if (defaultExpression.isDefined) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, name, "DEFAULT")
}
defaultExpression = Some(expr)
}
Option(option.generationExpression()).foreach { expr =>
if (generationExpression.isDefined) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, name, "GENERATED ALWAYS AS")
}
generationExpression = Some(expr)
}
Option(option.commentSpec()).foreach { spec =>
if (commentSpec.isDefined) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, name, "COMMENT")
}
commentSpec = Some(spec)
}
}
ColumnDefinition(
name = name,
dataType = typedVisit[DataType](ctx.dataType),
nullable = nullable,
comment = commentSpec.map(visitCommentSpec),
defaultValue = defaultExpression.map(visitDefaultExpression),
generationExpression = generationExpression.map(visitGenerationExpression)
)
}
/**
* Create a location string.
*/
override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) {
string(visitStringLit(ctx.stringLit))
}
/**
* Create an optional location string.
*/
protected def visitLocationSpecList(ctx: java.util.List[LocationSpecContext]): Option[String] = {
ctx.asScala.headOption.map(visitLocationSpec)
}
private def getDefaultExpression(
exprCtx: ExpressionContext,
place: String): DefaultValueExpression = {
// Make sure it can be converted to Catalyst expressions.
val expr = expression(exprCtx)
if (expr.containsPattern(PARAMETER)) {
throw QueryParsingErrors.parameterMarkerNotAllowed(place, expr.origin)
}
// Extract the raw expression text so that we can save the user provided text. We don't
// use `Expression.sql` to avoid storing incorrect text caused by bugs in any expression's
// `sql` method. Note: `exprCtx.getText` returns a string without spaces, so we need to
// get the text from the underlying char stream instead.
val start = exprCtx.getStart.getStartIndex
val end = exprCtx.getStop.getStopIndex
val originalSQL = exprCtx.getStart.getInputStream.getText(new Interval(start, end))
DefaultValueExpression(expr, originalSQL)
}
/**
* Create `DefaultValueExpression` for a column.
*/
override def visitDefaultExpression(ctx: DefaultExpressionContext): DefaultValueExpression =
withOrigin(ctx) {
getDefaultExpression(ctx.expression(), "DEFAULT")
}
/**
* Create `DefaultValueExpression` for a SQL variable.
*/
override def visitVariableDefaultExpression(
ctx: VariableDefaultExpressionContext): DefaultValueExpression =
withOrigin(ctx) {
getDefaultExpression(ctx.expression(), "DEFAULT")
}
/**
* Create a generation expression string.
*/
override def visitGenerationExpression(ctx: GenerationExpressionContext): String =
withOrigin(ctx) {
getDefaultExpression(ctx.expression(), "GENERATED").originalSQL
}
/**
* Create an optional comment string.
*/
protected def visitCommentSpecList(ctx: java.util.List[CommentSpecContext]): Option[String] = {
ctx.asScala.headOption.map(visitCommentSpec)
}
/**
* Create a [[BucketSpec]].
*/
override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
BucketSpec(
ctx.INTEGER_VALUE.getText.toInt,
visitIdentifierList(ctx.identifierList),
Option(ctx.orderedIdentifierList)
.toSeq
.flatMap(_.orderedIdentifier.asScala)
.map { orderedIdCtx =>
Option(orderedIdCtx.ordering).map(_.getText).foreach { dir =>
if (dir.toLowerCase(Locale.ROOT) != "asc") {
operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx)
}
}
orderedIdCtx.ident.getText
})
}
/**
* Create a [[ClusterBySpec]].
*/
override def visitClusterBySpec(ctx: ClusterBySpecContext): ClusterBySpec = withOrigin(ctx) {
val columnNames = ctx.multipartIdentifierList.multipartIdentifier.asScala
.map(typedVisit[Seq[String]]).map(FieldReference(_)).toSeq
ClusterBySpec(columnNames)
}
/**
* Convert a property list into a key-value map.
* This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]].
*/
override def visitPropertyList(
ctx: PropertyListContext): Map[String, String] = withOrigin(ctx) {
val properties = ctx.property.asScala.map { property =>
val key = visitPropertyKey(property.key)
val value = visitPropertyValue(property.value)
key -> value
}
// Check for duplicate property names.
checkDuplicateKeys(properties.toSeq, ctx)
properties.toMap
}
/**
* Parse a key-value map from a [[PropertyListContext]], assuming all values are specified.
*/
def visitPropertyKeyValues(ctx: PropertyListContext): Map[String, String] = {
val props = visitPropertyList(ctx)
val badKeys = props.collect { case (key, null) => key }
if (badKeys.nonEmpty) {
operationNotAllowed(
s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
}
props
}
/**
* Parse a list of keys from a [[PropertyListContext]], assuming no values are specified.
*/
def visitPropertyKeys(ctx: PropertyListContext): Seq[String] = {
val props = visitPropertyList(ctx)
val badKeys = props.filter { case (_, v) => v != null }.keys
if (badKeys.nonEmpty) {
operationNotAllowed(
s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
}
props.keys.toSeq
}
/**
* A property key can either be a String or a collection of dot-separated elements. This
* function extracts the property key based on whether it is a string literal or a property
* identifier.
*/
override def visitPropertyKey(key: PropertyKeyContext): String = {
if (key.stringLit() != null) {
string(visitStringLit(key.stringLit()))
} else {
key.getText
}
}
/**
* A property value can be a String, Integer, Boolean or Decimal. This function extracts
* the property value based on whether it is a string, integer, boolean or decimal literal.
*/
override def visitPropertyValue(value: PropertyValueContext): String = {
if (value == null) {
null
} else if (value.stringLit() != null) {
string(visitStringLit(value.stringLit()))
} else if (value.booleanValue != null) {
value.getText.toLowerCase(Locale.ROOT)
} else {
value.getText
}
}
/**
* Parse a key-value map from an [[ExpressionPropertyListContext]], assuming all values are
* specified.
*/
override def visitExpressionPropertyList(
ctx: ExpressionPropertyListContext): OptionList = {
val options = ctx.expressionProperty.asScala.map { property =>
val key: String = visitPropertyKey(property.key)
val value: Expression = Option(property.value).map(expression).getOrElse {
operationNotAllowed(s"A value must be specified for the key: $key.", ctx)
}
key -> value
}.toSeq
OptionList(options)
}
/**
* Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
*/
type TableHeader = (IdentifierReferenceContext, Boolean, Boolean, Boolean)
/**
* Type to keep track of table clauses:
* - partition transforms
* - partition columns
* - bucketSpec
* - properties
* - options
* - location
* - comment
* - serde
* - clusterBySpec
*
* Note: Partition transforms are based on the existing table schema definition. They can be
* simple column names, or functions like `year(date_col)`. Partition columns are column names
* with data types like `i INT`, which should be appended to the existing table schema.
*/
type TableClauses = (
Seq[Transform], Seq[ColumnDefinition], Option[BucketSpec], Map[String, String],
OptionList, Option[String], Option[String], Option[SerdeInfo], Option[ClusterBySpec])
/**
* Validate a CREATE TABLE statement and return the [[TableHeader]].
*/
override def visitCreateTableHeader(
ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
blockBang(ctx.errorCapturingNot)
val temporary = ctx.TEMPORARY != null
val ifNotExists = ctx.EXISTS != null
if (temporary && ifNotExists) {
invalidStatement("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx)
}
(ctx.identifierReference(), temporary, ifNotExists, ctx.EXTERNAL != null)
}
/**
* Parse a qualified name to a multipart name.
*/
override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
ctx.identifier.asScala.map(_.getText).toSeq
}
/**
* Parse a list of transforms or columns.
*/
override def visitPartitionFieldList(
ctx: PartitionFieldListContext): (Seq[Transform], Seq[ColumnDefinition]) = withOrigin(ctx) {
val (transforms, columns) = ctx.fields.asScala.map {
case transform: PartitionTransformContext =>
(Some(visitPartitionTransform(transform)), None)
case field: PartitionColumnContext =>
val f = visitColType(field.colType)
// The parser rule of `visitColType` only supports basic column info with comment.
val col = ColumnDefinition(f.name, f.dataType, f.nullable, f.getComment())
(None, Some(col))
}.unzip
(transforms.flatten.toSeq, columns.flatten.toSeq)
}
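/**
* Create a partition [[Transform]], e.g. an identity transform like `col`, or an applied
* transform like `years(ts)` or `bucket(8, id)` (illustrative column names).
*/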
override def visitPartitionTransform(
ctx: PartitionTransformContext): Transform = withOrigin(ctx) {
def getFieldReference(
ctx: ApplyTransformContext,
arg: V2Expression): FieldReference = {
lazy val name: String = ctx.identifier.getText
arg match {
case ref: FieldReference =>
ref
case nonRef =>
throw QueryParsingErrors.partitionTransformNotExpectedError(name, nonRef.describe, ctx)
}
}
def getSingleFieldReference(
ctx: ApplyTransformContext,
arguments: Seq[V2Expression]): FieldReference = {
lazy val name: String = ctx.identifier.getText
if (arguments.size > 1) {
throw QueryParsingErrors.wrongNumberArgumentsForTransformError(name, arguments.size, ctx)
} else if (arguments.isEmpty) {
throw SparkException.internalError(s"Not enough arguments for transform $name")
} else {
getFieldReference(ctx, arguments.head)
}
}
ctx.transform match {
case identityCtx: IdentityTransformContext =>
IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName)))
case applyCtx: ApplyTransformContext =>
val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq
applyCtx.identifier.getText match {
case "bucket" =>
val numBuckets: Int = arguments.head match {
case LiteralValue(shortValue, ShortType) =>
shortValue.asInstanceOf[Short].toInt
case LiteralValue(intValue, IntegerType) =>
intValue.asInstanceOf[Int]
case LiteralValue(longValue, LongType) =>
longValue.asInstanceOf[Long].toInt
case lit =>
throw QueryParsingErrors.invalidBucketsNumberError(lit.describe, applyCtx)
}
val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg))
BucketTransform(LiteralValue(numBuckets, IntegerType), fields)
case "years" =>
YearsTransform(getSingleFieldReference(applyCtx, arguments))
case "months" =>
MonthsTransform(getSingleFieldReference(applyCtx, arguments))
case "days" =>
DaysTransform(getSingleFieldReference(applyCtx, arguments))
case "hours" =>
HoursTransform(getSingleFieldReference(applyCtx, arguments))
case name =>
ApplyTransform(name, arguments)
}
}
}
/**
* Parse an argument to a transform. An argument may be a field reference (qualified name) or
* a value literal.
*/
override def visitTransformArgument(ctx: TransformArgumentContext): V2Expression = {
withOrigin(ctx) {
val reference = Option(ctx.qualifiedName)
.map(typedVisit[Seq[String]])
.map(FieldReference(_))
val literal = Option(ctx.constant)
.map(typedVisit[Literal])
.map(lit => LiteralValue(lit.value, lit.dataType))
reference.orElse(literal)
.getOrElse(throw SparkException.internalError("Invalid transform argument"))
}
}
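/**
* Filter out the reserved namespace properties LOCATION and OWNER. Unless the legacy behavior
* is enabled via [[SQLConf.LEGACY_PROPERTY_NON_RESERVED]], specifying a reserved property in
* the property list is a parse error.
*/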
private def cleanNamespaceProperties(
properties: Map[String, String],
ctx: ParserRuleContext): Map[String, String] = withOrigin(ctx) {
import SupportsNamespaces._
val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED)
properties.filter {
case (PROP_LOCATION, _) if !legacyOn =>
throw QueryParsingErrors.cannotCleanReservedNamespacePropertyError(
PROP_LOCATION, ctx, "please use the LOCATION clause to specify it")
case (PROP_LOCATION, _) => false
case (PROP_OWNER, _) if !legacyOn =>
throw QueryParsingErrors.cannotCleanReservedNamespacePropertyError(
PROP_OWNER, ctx, "it will be set to the current user")
case (PROP_OWNER, _) => false
case _ => true
}
}
/**
* Create a [[CreateNamespace]] command.
*
* For example:
* {{{
* CREATE NAMESPACE [IF NOT EXISTS] ns1.ns2.ns3
* create_namespace_clauses;
*
* create_namespace_clauses (order insensitive):
* [COMMENT namespace_comment]
* [LOCATION path]
* [WITH PROPERTIES (key1=val1, key2=val2, ...)]
* }}}
*/
override def visitCreateNamespace(ctx: CreateNamespaceContext): LogicalPlan = withOrigin(ctx) {
import SupportsNamespaces._
checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
checkDuplicateClauses(ctx.PROPERTIES, "WITH PROPERTIES", ctx)
checkDuplicateClauses(ctx.DBPROPERTIES, "WITH DBPROPERTIES", ctx)
if (!ctx.PROPERTIES.isEmpty && !ctx.DBPROPERTIES.isEmpty) {
throw QueryParsingErrors.propertiesAndDbPropertiesBothSpecifiedError(ctx)
}
var properties = ctx.propertyList.asScala.headOption
.map(visitPropertyKeyValues)
.getOrElse(Map.empty)
properties = cleanNamespaceProperties(properties, ctx)
visitCommentSpecList(ctx.commentSpec()).foreach {
properties += PROP_COMMENT -> _
}
visitLocationSpecList(ctx.locationSpec()).foreach {
properties += PROP_LOCATION -> _
}
blockBang(ctx.errorCapturingNot)
CreateNamespace(
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_)),
ctx.EXISTS != null,
properties)
}
/**
* Create a [[DropNamespace]] command.
*
* For example:
* {{{
* DROP (DATABASE|SCHEMA|NAMESPACE) [IF EXISTS] ns1.ns2 [RESTRICT|CASCADE];
* }}}
*/
override def visitDropNamespace(ctx: DropNamespaceContext): LogicalPlan = withOrigin(ctx) {
DropNamespace(
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_)),
ctx.EXISTS != null,
ctx.CASCADE != null)
}
/**
* Create a [[SetNamespaceProperties]] logical plan.
*
* For example:
* {{{
* ALTER (DATABASE|SCHEMA|NAMESPACE) database
* SET (DBPROPERTIES|PROPERTIES) (property_name=property_value, ...);
* }}}
*/
override def visitSetNamespaceProperties(ctx: SetNamespacePropertiesContext): LogicalPlan = {
withOrigin(ctx) {
val properties = cleanNamespaceProperties(visitPropertyKeyValues(ctx.propertyList), ctx)
SetNamespaceProperties(
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_)),
properties)
}
}
/**
* Create a [[SetNamespaceLocation]] logical plan.
*
* For example:
* {{{
* ALTER (DATABASE|SCHEMA|NAMESPACE) namespace SET LOCATION path;
* }}}
*/
override def visitSetNamespaceLocation(ctx: SetNamespaceLocationContext): LogicalPlan = {
withOrigin(ctx) {
SetNamespaceLocation(
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_)),
visitLocationSpec(ctx.locationSpec))
}
}
/**
* Create a [[ShowNamespaces]] command.
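*
* For example:
* {{{
*   SHOW (NAMESPACES|DATABASES|SCHEMAS) [(FROM|IN) namespace] [[LIKE] pattern];
* }}}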
*/
override def visitShowNamespaces(ctx: ShowNamespacesContext): LogicalPlan = withOrigin(ctx) {
val multiPart = Option(ctx.multipartIdentifier).map(visitMultipartIdentifier)
ShowNamespaces(
UnresolvedNamespace(multiPart.getOrElse(Seq.empty[String])),
Option(ctx.pattern).map(x => string(visitStringLit(x))))
}
/**
* Create a [[DescribeNamespace]].
*
* For example:
* {{{
* DESCRIBE (DATABASE|SCHEMA|NAMESPACE) [EXTENDED] database;
* }}}
*/
override def visitDescribeNamespace(ctx: DescribeNamespaceContext): LogicalPlan =
withOrigin(ctx) {
DescribeNamespace(
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_)),
ctx.EXTENDED != null)
}
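/**
* Filter out the reserved table properties, such as provider, location, owner and external.
* Unless the legacy behavior is enabled via [[SQLConf.LEGACY_PROPERTY_NON_RESERVED]],
* specifying a reserved property in a TBLPROPERTIES list is a parse error. Table comments are
* deliberately not reserved.
*/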
def cleanTableProperties[ValueType](
ctx: ParserRuleContext, properties: Map[String, ValueType]): Map[String, ValueType] = {
import TableCatalog._
val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED)
properties.filter {
case (PROP_PROVIDER, _) if !legacyOn =>
throw QueryParsingErrors.cannotCleanReservedTablePropertyError(
PROP_PROVIDER, ctx, "please use the USING clause to specify it")
case (PROP_PROVIDER, _) => false
case (PROP_LOCATION, _) if !legacyOn =>
throw QueryParsingErrors.cannotCleanReservedTablePropertyError(
PROP_LOCATION, ctx, "please use the LOCATION clause to specify it")
case (PROP_LOCATION, _) => false
case (PROP_OWNER, _) if !legacyOn =>
throw QueryParsingErrors.cannotCleanReservedTablePropertyError(
PROP_OWNER, ctx, "it will be set to the current user")
case (PROP_OWNER, _) => false
case (PROP_EXTERNAL, _) if !legacyOn =>
throw QueryParsingErrors.cannotCleanReservedTablePropertyError(
PROP_EXTERNAL, ctx, "please use CREATE EXTERNAL TABLE")
case (PROP_EXTERNAL, _) => false
// It's safe to set any table comment, so we don't make it a reserved table property.
case (PROP_COMMENT, _) => true
case (k, _) =>
val isReserved = CatalogV2Util.TABLE_RESERVED_PROPERTIES.contains(k)
if (!legacyOn && isReserved) {
throw QueryParsingErrors.cannotCleanReservedTablePropertyError(
k, ctx, "please remove it from the TBLPROPERTIES list.")
}
!isReserved
}
}
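/**
* Extract the `path` option from the table options and merge it with the LOCATION clause.
* Specifying both a `path` option and a LOCATION clause (or duplicated paths) is an error.
*/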
def cleanTableOptions(
ctx: ParserRuleContext,
options: OptionList,
location: Option[String]): (OptionList, Option[String]) = {
var path = location
val filtered = cleanTableProperties(ctx, options.options.toMap).filter {
case (key, value) if key.equalsIgnoreCase("path") =>
val newValue: String =
if (value == null) {
""
} else value match {
case Literal(_, _: StringType) => value.toString
case _ => throw QueryCompilationErrors.optionMustBeLiteralString(key)
}
if (path.nonEmpty) {
throw QueryParsingErrors.duplicatedTablePathsFoundError(path.get, newValue, ctx)
}
path = Some(newValue)
false
case _ => true
}
(OptionList(filtered.toSeq), path)
}
/**
* Create a [[SerdeInfo]] for creating tables.
*
* Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format)
*/
override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) {
(ctx.fileFormat, ctx.storageHandler) match {
// Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format
case (c: TableFileFormatContext, null) =>
SerdeInfo(formatClasses = Some(FormatClasses(string(visitStringLit(c.inFmt)),
string(visitStringLit(c.outFmt)))))
// Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO
case (c: GenericFileFormatContext, null) =>
SerdeInfo(storedAs = Some(c.identifier.getText))
case (null, _) =>
invalidStatement("STORED BY", ctx)
case _ =>
throw QueryParsingErrors.storedAsAndStoredByBothSpecifiedError(ctx)
}
}
/**
* Create a [[SerdeInfo]] used for creating tables.
*
* Example format:
* {{{
* SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)]
* }}}
*
* OR
*
* {{{
* DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]]
* [COLLECTION ITEMS TERMINATED BY char]
* [MAP KEYS TERMINATED BY char]
* [LINES TERMINATED BY char]
* [NULL DEFINED AS char]
* }}}
*/
def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) {
ctx match {
case serde: RowFormatSerdeContext => visitRowFormatSerde(serde)
case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited)
}
}
/**
* Create SERDE row format name and properties pair.
*/
override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) {
import ctx._
SerdeInfo(
serde = Some(string(visitStringLit(name))),
serdeProperties = Option(propertyList).map(visitPropertyKeyValues).getOrElse(Map.empty))
}
/**
* Create a delimited row format properties object.
*/
override def visitRowFormatDelimited(
ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) {
// Collect the entries if any.
def entry(key: String, value: StringLitContext): Seq[(String, String)] = {
Option(value).toSeq.map(x => key -> string(visitStringLit(x)))
}
// TODO we need proper support for the NULL format.
val entries =
entry("field.delim", ctx.fieldsTerminatedBy) ++
entry("serialization.format", ctx.fieldsTerminatedBy) ++
entry("escape.delim", ctx.escapedBy) ++
// The following typo is inherited from Hive...
entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++
entry("mapkey.delim", ctx.keysTerminatedBy) ++
Option(ctx.linesSeparatedBy).toSeq.map { token =>
val value = string(visitStringLit(token))
validate(
value == "\n",
s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
ctx)
"line.delim" -> value
}
SerdeInfo(serdeProperties = entries.toMap)
}
/**
* Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT
* and STORED AS.
*
* The following are allowed. Anything else is not:
* ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE]
* ROW FORMAT DELIMITED ... STORED AS TEXTFILE
* ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ...
*/
protected def validateRowFormatFileFormat(
rowFormatCtx: RowFormatContext,
createFileFormatCtx: CreateFileFormatContext,
parentCtx: ParserRuleContext): Unit = {
if (rowFormatCtx == null || createFileFormatCtx == null) {
return
}
(rowFormatCtx, createFileFormatCtx.fileFormat) match {
case (_, ffTable: TableFileFormatContext) => // OK
case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) =>
ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
case ("sequencefile" | "textfile" | "rcfile") => // OK
case fmt =>
operationNotAllowed(
s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde",
parentCtx)
}
case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) =>
ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
case "textfile" => // OK
case fmt => operationNotAllowed(
s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx)
}
case _ =>
// should never happen
def str(ctx: ParserRuleContext): String = {
(0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ")
}
operationNotAllowed(
s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}",
parentCtx)
}
}
protected def validateRowFormatFileFormat(
rowFormatCtx: Seq[RowFormatContext],
createFileFormatCtx: Seq[CreateFileFormatContext],
parentCtx: ParserRuleContext): Unit = {
if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) {
validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx)
}
}
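/**
* Parse the clauses of a CREATE/REPLACE TABLE statement into a [[TableClauses]] tuple. This
* rejects duplicate clauses and invalid combinations, such as CLUSTER BY together with
* PARTITIONED BY or CLUSTERED BY.
*/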
override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = {
checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx)
checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx)
checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
checkDuplicateClauses(ctx.clusterBySpec(), "CLUSTER BY", ctx)
checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
if (ctx.skewSpec.size > 0) {
invalidStatement("CREATE TABLE ... SKEWED BY", ctx)
}
val (partTransforms, partCols) =
Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil))
val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
val cleanedProperties = cleanTableProperties(ctx, properties)
val options = Option(ctx.options).map(visitExpressionPropertyList)
.getOrElse(OptionList(Seq.empty))
val location = visitLocationSpecList(ctx.locationSpec())
val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location)
val comment = visitCommentSpecList(ctx.commentSpec())
val serdeInfo =
getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx)
val clusterBySpec = ctx.clusterBySpec().asScala.headOption.map(visitClusterBySpec)
if (clusterBySpec.isDefined) {
if (partCols.nonEmpty || partTransforms.nonEmpty) {
throw QueryParsingErrors.clusterByWithPartitionedBy(ctx)
}
if (bucketSpec.isDefined) {
throw QueryParsingErrors.clusterByWithBucketing(ctx)
}
}
(partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment,
serdeInfo, clusterBySpec)
}
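/**
* Merge the ROW FORMAT and STORED AS/BY clauses, if any, into a single [[SerdeInfo]] after
* validating that the specified formats are compatible.
*/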
protected def getSerdeInfo(
rowFormatCtx: Seq[RowFormatContext],
createFileFormatCtx: Seq[CreateFileFormatContext],
ctx: ParserRuleContext): Option[SerdeInfo] = {
validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx)
val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat)
val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat)
(fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r))
}
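/**
* Combine PARTITIONED BY transforms and partition columns into a single list of
* [[Transform]]s. Mixing the two is not allowed; plain partition columns are converted to
* identity transforms over the column names.
*/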
private def partitionExpressions(
partTransforms: Seq[Transform],
partCols: Seq[ColumnDefinition],
ctx: ParserRuleContext): Seq[Transform] = {
if (partTransforms.nonEmpty) {
if (partCols.nonEmpty) {
val references = partTransforms.map(_.describe()).mkString(", ")
val columns = partCols
.map(column => s"${column.name} ${column.dataType.simpleString}")
.mkString(", ")
operationNotAllowed(
s"""PARTITION BY: Cannot mix partition expressions and partition columns:
|Expressions: $references
|Columns: $columns""".stripMargin, ctx)
}
partTransforms
} else {
// Columns were added to create the schema; convert them to column references.
partCols.map { column =>
IdentityTransform(FieldReference(Seq(column.name)))
}
}
}
/**
* Create a table, returning a [[CreateTable]] or [[CreateTableAsSelect]] logical plan.
*
* Expected format:
* {{{
* CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name
* [USING table_provider]
* create_table_clauses
* [[AS] select_statement];
*
* create_table_clauses (order insensitive):
* [PARTITIONED BY (partition_fields)]
* [OPTIONS table_property_list]
* [ROW FORMAT row_format]
* [STORED AS file_format]
* [CLUSTER BY (col_name, col_name, ...)]
* [CLUSTERED BY (col_name, col_name, ...)
* [SORTED BY (col_name [ASC|DESC], ...)]
* INTO num_buckets BUCKETS
* ]
* [LOCATION path]
* [COMMENT table_comment]
* [TBLPROPERTIES (property_name=property_value, ...)]
*
* partition_fields:
* col_name, transform(col_name), transform(constant, col_name), ... |
* col_name data_type [NOT NULL] [COMMENT col_comment], ...
* }}}
*/
override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
val (identifierContext, temp, ifNotExists, external) =
visitCreateTableHeader(ctx.createTableHeader)
val columns = Option(ctx.createOrReplaceTableColTypeList())
.map(visitCreateOrReplaceTableColTypeList).getOrElse(Nil)
val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
val (partTransforms, partCols, bucketSpec, properties, options, location,
comment, serdeInfo, clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses())
if (provider.isDefined && serdeInfo.isDefined) {
invalidStatement(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
}
if (temp) {
val asSelect = if (ctx.query == null) "" else " AS ..."
operationNotAllowed(
s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx)
}
val partitioning =
partitionExpressions(partTransforms, partCols, ctx) ++
bucketSpec.map(_.asTransform) ++
clusterBySpec.map(_.asTransform)
val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
serdeInfo, external)
Option(ctx.query).map(plan) match {
case Some(_) if columns.nonEmpty =>
operationNotAllowed(
"Schema may not be specified in a Create Table As Select (CTAS) statement",
ctx)
case Some(_) if partCols.nonEmpty =>
// non-reference partition columns are not allowed because schema can't be specified
operationNotAllowed(
"Partition column types may not be specified in Create Table As Select (CTAS)",
ctx)
case Some(query) =>
CreateTableAsSelect(withIdentClause(identifierContext, UnresolvedIdentifier(_)),
partitioning, query, tableSpec, Map.empty, ifNotExists)
case _ =>
// Note: table schema includes both the table columns list and the partition columns
// with data type.
val allColumns = columns ++ partCols
CreateTable(
withIdentClause(identifierContext, UnresolvedIdentifier(_)),
allColumns, partitioning, tableSpec, ignoreIfExists = ifNotExists)
}
}
/**
* Replace a table, returning a [[ReplaceTable]] or [[ReplaceTableAsSelect]]
* logical plan.
*
* Expected format:
* {{{
* [CREATE OR] REPLACE TABLE [db_name.]table_name
* [USING table_provider]
* replace_table_clauses
* [[AS] select_statement];
*
* replace_table_clauses (order insensitive):
* [OPTIONS table_property_list]
* [PARTITIONED BY (partition_fields)]
* [CLUSTER BY (col_name, col_name, ...)]
* [CLUSTERED BY (col_name, col_name, ...)
* [SORTED BY (col_name [ASC|DESC], ...)]
* INTO num_buckets BUCKETS
* ]
* [LOCATION path]
* [COMMENT table_comment]
* [TBLPROPERTIES (property_name=property_value, ...)]
*
* partition_fields:
* col_name, transform(col_name), transform(constant, col_name), ... |
* col_name data_type [NOT NULL] [COMMENT col_comment], ...
* }}}
*/
override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) {
val orCreate = ctx.replaceTableHeader().CREATE() != null
val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo,
clusterBySpec) = visitCreateTableClauses(ctx.createTableClauses())
val columns = Option(ctx.createOrReplaceTableColTypeList())
.map(visitCreateOrReplaceTableColTypeList).getOrElse(Nil)
val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
if (provider.isDefined && serdeInfo.isDefined) {
invalidStatement(s"REPLACE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
}
val partitioning =
partitionExpressions(partTransforms, partCols, ctx) ++
bucketSpec.map(_.asTransform) ++
clusterBySpec.map(_.asTransform)
val tableSpec = UnresolvedTableSpec(properties, provider, options, location, comment,
serdeInfo, external = false)
Option(ctx.query).map(plan) match {
case Some(_) if columns.nonEmpty =>
operationNotAllowed(
"Schema may not be specified in a Replace Table As Select (RTAS) statement",
ctx)
case Some(_) if partCols.nonEmpty =>
// non-reference partition columns are not allowed because schema can't be specified
operationNotAllowed(
"Partition column types may not be specified in Replace Table As Select (RTAS)",
ctx)
case Some(query) =>
ReplaceTableAsSelect(
withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)),
partitioning, query, tableSpec, writeOptions = Map.empty, orCreate = orCreate)
case _ =>
// Note: table schema includes both the table columns list and the partition columns
// with data type.
val allColumns = columns ++ partCols
ReplaceTable(
withIdentClause(ctx.replaceTableHeader.identifierReference(), UnresolvedIdentifier(_)),
allColumns, partitioning, tableSpec, orCreate = orCreate)
}
}
/**
* Create a [[DropTable]] command.
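*
* For example:
* {{{
*   DROP TABLE [IF EXISTS] multi_part_name [PURGE];
* }}}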
*/
override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) {
// DROP TABLE works with either a table or a temporary view.
DropTable(
withIdentClause(ctx.identifierReference, UnresolvedIdentifier(_, allowTemp = true)),
ctx.EXISTS != null,
ctx.PURGE != null)
}
/**
* Create a [[DropView]] command.
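*
* For example:
* {{{
*   DROP VIEW [IF EXISTS] multi_part_name;
* }}}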
*/
override def visitDropView(ctx: DropViewContext): AnyRef = withOrigin(ctx) {
DropView(
withIdentClause(ctx.identifierReference, UnresolvedIdentifier(_, allowTemp = true)),
ctx.EXISTS != null)
}
/**
* Create a [[SetCatalogAndNamespace]] command.
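*
* For example:
* {{{
*   USE ns1.ns2;
* }}}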
*/
override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) {
SetCatalogAndNamespace(withIdentClause(ctx.identifierReference, UnresolvedNamespace(_)))
}
/**
* Create a [[ShowTables]] command.
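*
* For example:
* {{{
*   SHOW TABLES [(FROM|IN) namespace] [[LIKE] pattern];
* }}}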
*/
override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) {
val ns = if (ctx.identifierReference() != null) {
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_))
} else {
CurrentNamespace
}
ShowTables(ns, Option(ctx.pattern).map(x => string(visitStringLit(x))))
}
/**
* Create a [[ShowTablesExtended]] or [[ShowTablePartition]] command.
*/
override def visitShowTableExtended(
ctx: ShowTableExtendedContext): LogicalPlan = withOrigin(ctx) {
Option(ctx.partitionSpec).map { spec =>
val table = withOrigin(ctx.pattern) {
if (ctx.identifierReference() != null) {
withIdentClause(ctx.identifierReference(), ns => {
val names = ns :+ string(visitStringLit(ctx.pattern))
UnresolvedTable(names, "SHOW TABLE EXTENDED ... PARTITION ...")
})
} else {
val names = Seq.empty[String] :+ string(visitStringLit(ctx.pattern))
UnresolvedTable(names, "SHOW TABLE EXTENDED ... PARTITION ...")
}
}
ShowTablePartition(table, UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(spec)))
}.getOrElse {
val ns = if (ctx.identifierReference() != null) {
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_))
} else {
CurrentNamespace
}
ShowTablesExtended(ns, string(visitStringLit(ctx.pattern)))
}
}
/**
* Create a [[ShowViews]] command.
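*
* For example:
* {{{
*   SHOW VIEWS [(FROM|IN) namespace] [[LIKE] pattern];
* }}}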
*/
override def visitShowViews(ctx: ShowViewsContext): LogicalPlan = withOrigin(ctx) {
val ns = if (ctx.identifierReference() != null) {
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_))
} else {
CurrentNamespace
}
ShowViews(ns, Option(ctx.pattern).map(x => string(visitStringLit(x))))
}
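/**
* Parse a column position: either FIRST, or AFTER an existing column.
*/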
override def visitColPosition(ctx: ColPositionContext): ColumnPosition = {
ctx.position.getType match {
case SqlBaseParser.FIRST => ColumnPosition.first()
case SqlBaseParser.AFTER => ColumnPosition.after(ctx.afterCol.getText)
}
}
/**
* Parse new column info from ADD COLUMN into a [[QualifiedColType]].
*/
override def visitQualifiedColTypeWithPosition(
ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) {
val name = typedVisit[Seq[String]](ctx.name)
// Check that no duplicates exist among any ALTER TABLE ADD|REPLACE column options specified.
var nullable = true
var defaultExpression: Option[DefaultExpressionContext] = None
var commentSpec: Option[CommentSpecContext] = None
var colPosition: Option[ColPositionContext] = None
val columnName = name.last
ctx.colDefinitionDescriptorWithPosition.asScala.foreach { option =>
blockBang(option.errorCapturingNot)
if (option.NULL != null) {
blockBang(option.errorCapturingNot)
if (!nullable) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, columnName, "NOT NULL", isCreate = false)
}
nullable = false
}
Option(option.defaultExpression()).foreach { expr =>
if (defaultExpression.isDefined) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, columnName, "DEFAULT", isCreate = false)
}
defaultExpression = Some(expr)
}
Option(option.commentSpec()).foreach { spec =>
if (commentSpec.isDefined) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, columnName, "COMMENT", isCreate = false)
}
commentSpec = Some(spec)
}
Option(option.colPosition()).foreach { spec =>
if (colPosition.isDefined) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
option, columnName, "FIRST|AFTER", isCreate = false)
}
colPosition = Some(spec)
}
}
// Add the 'DEFAULT expression' clause in the column definition, if any, to the column metadata.
val defaultExpr = defaultExpression.map(visitDefaultExpression).map { field =>
if (conf.getConf(SQLConf.ENABLE_DEFAULT_COLUMNS)) {
field.originalSQL
} else {
throw QueryParsingErrors.defaultColumnNotEnabledError(ctx)
}
}
QualifiedColType(
path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None,
colName = name.last,
dataType = typedVisit[DataType](ctx.dataType),
nullable = nullable,
comment = commentSpec.map(visitCommentSpec),
position = colPosition.map(pos =>
UnresolvedFieldPosition(typedVisit[ColumnPosition](pos))),
default = defaultExpr)
}
/**
* Parse an [[AddColumns]] command.
*
* For example:
* {{{
* ALTER TABLE table1
* ADD COLUMNS (col_name data_type [COMMENT col_comment], ...);
* }}}
*/
override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) {
val colToken = if (ctx.COLUMN() != null) "COLUMN" else "COLUMNS"
AddColumns(
createUnresolvedTable(ctx.identifierReference, s"ALTER TABLE ... ADD $colToken"),
ctx.columns.qualifiedColTypeWithPosition.asScala.map(typedVisit[QualifiedColType]).toSeq
)
}
/**
* Parse a [[RenameColumn]] command.
*
* For example:
* {{{
* ALTER TABLE table1 RENAME COLUMN a.b.c TO x
* }}}
*/
override def visitRenameTableColumn(
ctx: RenameTableColumnContext): LogicalPlan = withOrigin(ctx) {
RenameColumn(
createUnresolvedTable(ctx.table, "ALTER TABLE ... RENAME COLUMN"),
UnresolvedFieldName(typedVisit[Seq[String]](ctx.from)),
ctx.to.getText)
}
/**
* Parse an [[AlterColumn]] command to alter a column's property.
*
* For example:
* {{{
* ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint
* ALTER TABLE table1 ALTER COLUMN a.b.c SET NOT NULL
* ALTER TABLE table1 ALTER COLUMN a.b.c DROP NOT NULL
* ALTER TABLE table1 ALTER COLUMN a.b.c COMMENT 'new comment'
* ALTER TABLE table1 ALTER COLUMN a.b.c FIRST
* ALTER TABLE table1 ALTER COLUMN a.b.c AFTER x
* }}}
*/
override def visitAlterTableAlterColumn(
ctx: AlterTableAlterColumnContext): LogicalPlan = withOrigin(ctx) {
val action = ctx.alterColumnAction
val verb = if (ctx.CHANGE != null) "CHANGE" else "ALTER"
if (action == null) {
operationNotAllowed(
s"ALTER TABLE table $verb COLUMN requires a TYPE, a SET/DROP, a COMMENT, or a FIRST/AFTER",
ctx)
}
val dataType = if (action.dataType != null) {
Some(typedVisit[DataType](action.dataType))
} else {
None
}
val nullable = if (action.setOrDrop != null) {
action.setOrDrop.getType match {
case SqlBaseParser.SET => Some(false)
case SqlBaseParser.DROP => Some(true)
}
} else {
None
}
val comment = if (action.commentSpec != null) {
Some(visitCommentSpec(action.commentSpec()))
} else {
None
}
val position = if (action.colPosition != null) {
Some(UnresolvedFieldPosition(typedVisit[ColumnPosition](action.colPosition)))
} else {
None
}
val setDefaultExpression: Option[String] =
if (action.defaultExpression != null) {
Option(action.defaultExpression()).map(visitDefaultExpression).map(_.originalSQL)
} else if (action.dropDefault != null) {
Some("")
} else {
None
}
if (setDefaultExpression.isDefined && !conf.getConf(SQLConf.ENABLE_DEFAULT_COLUMNS)) {
throw QueryParsingErrors.defaultColumnNotEnabledError(ctx)
}
assert(Seq(dataType, nullable, comment, position, setDefaultExpression)
.count(_.nonEmpty) == 1)
AlterColumn(
createUnresolvedTable(ctx.table, s"ALTER TABLE ... $verb COLUMN"),
UnresolvedFieldName(typedVisit[Seq[String]](ctx.column)),
dataType = dataType,
nullable = nullable,
comment = comment,
position = position,
setDefaultExpression = setDefaultExpression)
}
/**
* Parse an [[AlterColumn]] command. This is Hive SQL syntax.
*
* For example:
* {{{
* ALTER TABLE table [PARTITION partition_spec]
* CHANGE [COLUMN] column_old_name column_new_name column_dataType [COMMENT column_comment]
* [FIRST | AFTER column_name];
* }}}
*/
override def visitHiveChangeColumn(ctx: HiveChangeColumnContext): LogicalPlan = withOrigin(ctx) {
if (ctx.partitionSpec != null) {
invalidStatement("ALTER TABLE ... PARTITION ... CHANGE COLUMN", ctx)
}
val columnNameParts = typedVisit[Seq[String]](ctx.colName)
if (!conf.resolver(columnNameParts.last, ctx.colType().colName.getText)) {
throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError("Renaming column",
"ALTER COLUMN", ctx, Some("please run RENAME COLUMN instead"))
}
if (ctx.colType.NULL != null) {
throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError(
"NOT NULL", "ALTER COLUMN", ctx,
Some("please run ALTER COLUMN ... SET/DROP NOT NULL instead"))
}
AlterColumn(
createUnresolvedTable(ctx.table, "ALTER TABLE ... CHANGE COLUMN"),
UnresolvedFieldName(columnNameParts),
dataType = Option(ctx.colType().dataType()).map(typedVisit[DataType]),
nullable = None,
comment = Option(ctx.colType().commentSpec()).map(visitCommentSpec),
position = Option(ctx.colPosition).map(
pos => UnresolvedFieldPosition(typedVisit[ColumnPosition](pos))),
setDefaultExpression = None)
}
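/**
* Parse a [[ReplaceColumns]] command. This is Hive SQL syntax.
*
* For example:
* {{{
*   ALTER TABLE table [PARTITION partition_spec]
*   REPLACE COLUMNS (col_name data_type [COMMENT col_comment], ...);
* }}}
*/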
override def visitHiveReplaceColumns(
ctx: HiveReplaceColumnsContext): LogicalPlan = withOrigin(ctx) {
if (ctx.partitionSpec != null) {
invalidStatement("ALTER TABLE ... PARTITION ... REPLACE COLUMNS", ctx)
}
ReplaceColumns(
createUnresolvedTable(ctx.table, "ALTER TABLE ... REPLACE COLUMNS"),
ctx.columns.qualifiedColTypeWithPosition.asScala.map { colType =>
val name = typedVisit[Seq[String]](colType.name)
if (name.length > 1) {
throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError(
"Replacing with a nested column", "REPLACE COLUMNS", ctx)
}
var commentSpec: Option[CommentSpecContext] = None
colType.colDefinitionDescriptorWithPosition.asScala.foreach { opt =>
blockBang(opt.errorCapturingNot)
if (opt.NULL != null) {
throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError(
"NOT NULL", "REPLACE COLUMNS", ctx)
}
if (opt.colPosition != null) {
throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError(
"Column position", "REPLACE COLUMNS", ctx)
}
if (Option(opt.defaultExpression()).map(visitDefaultExpression).isDefined) {
throw QueryParsingErrors.defaultColumnNotImplementedYetError(ctx)
}
Option(opt.commentSpec()).foreach { spec =>
if (commentSpec.isDefined) {
throw QueryParsingErrors.duplicateTableColumnDescriptor(
opt, name.last, "COMMENT", isCreate = false, alterType = "REPLACE")
}
commentSpec = Some(spec)
}
}
QualifiedColType(
path = None,
colName = name.last,
dataType = typedVisit[DataType](colType.dataType),
nullable = true,
comment = commentSpec.map(visitCommentSpec),
position = None,
default = None)
}.toSeq
)
}
/**
* Parse a [[DropColumns]] command.
*
* For example:
* {{{
* ALTER TABLE table1 DROP COLUMN a.b.c
* ALTER TABLE table1 DROP COLUMNS a.b.c, x, y
* }}}
*/
override def visitDropTableColumns(
ctx: DropTableColumnsContext): LogicalPlan = withOrigin(ctx) {
val ifExists = ctx.EXISTS() != null
val columnsToDrop = ctx.columns.multipartIdentifier.asScala.map(typedVisit[Seq[String]])
DropColumns(
createUnresolvedTable(ctx.identifierReference, "ALTER TABLE ... DROP COLUMNS"),
columnsToDrop.map(UnresolvedFieldName(_)).toSeq,
ifExists)
}
/**
* Parse [[SetViewProperties]] or [[SetTableProperties]] commands.
*
* For example:
* {{{
* ALTER TABLE table SET TBLPROPERTIES ('table_property' = 'property_value');
* ALTER VIEW view SET TBLPROPERTIES ('table_property' = 'property_value');
* }}}
*/
override def visitSetTableProperties(
ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) {
val properties = visitPropertyKeyValues(ctx.propertyList)
val cleanedTableProperties = cleanTableProperties(ctx, properties)
if (ctx.VIEW != null) {
SetViewProperties(
createUnresolvedView(
ctx.identifierReference,
commandName = "ALTER VIEW ... SET TBLPROPERTIES",
allowTemp = false,
suggestAlternative = true),
cleanedTableProperties)
} else {
SetTableProperties(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... SET TBLPROPERTIES",
true),
cleanedTableProperties)
}
}
/**
* Parse [[UnsetViewProperties]] or [[UnsetTableProperties]] commands.
*
* For example:
* {{{
* ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key');
* ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key');
* }}}
*/
override def visitUnsetTableProperties(
ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) {
val properties = visitPropertyKeys(ctx.propertyList)
val cleanedProperties = cleanTableProperties(ctx, properties.map(_ -> "").toMap).keys.toSeq
val ifExists = ctx.EXISTS != null
if (ctx.VIEW != null) {
UnsetViewProperties(
createUnresolvedView(
ctx.identifierReference,
commandName = "ALTER VIEW ... UNSET TBLPROPERTIES",
allowTemp = false,
suggestAlternative = true),
cleanedProperties,
ifExists)
} else {
UnsetTableProperties(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... UNSET TBLPROPERTIES",
true),
cleanedProperties,
ifExists)
}
}
/**
* Create a [[SetTableLocation]] command.
*
* For example:
* {{{
* ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION "loc";
* }}}
*/
override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) {
SetTableLocation(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... SET LOCATION ..."),
Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
visitLocationSpec(ctx.locationSpec))
}
/**
* Create a [[DescribeColumn]] or a [[DescribeRelation]] command.
*/
override def visitDescribeRelation(ctx: DescribeRelationContext): LogicalPlan = withOrigin(ctx) {
val isExtended = ctx.EXTENDED != null || ctx.FORMATTED != null
val relation = createUnresolvedTableOrView(ctx.identifierReference, "DESCRIBE TABLE")
if (ctx.describeColName != null) {
if (ctx.partitionSpec != null) {
throw QueryParsingErrors.descColumnForPartitionUnsupportedError(ctx)
} else {
DescribeColumn(
relation,
UnresolvedAttribute(ctx.describeColName.nameParts.asScala.map(_.getText).toSeq),
isExtended)
}
} else {
val partitionSpec = if (ctx.partitionSpec != null) {
// According to the syntax, visitPartitionSpec returns `Map[String, Option[String]]`.
visitPartitionSpec(ctx.partitionSpec).map {
case (key, Some(value)) => key -> value
case (key, _) =>
throw QueryParsingErrors.emptyPartitionKeyError(key, ctx.partitionSpec)
}
} else {
Map.empty[String, String]
}
DescribeRelation(relation, partitionSpec, isExtended)
}
}
/**
* Create an [[AnalyzeTable]] or an [[AnalyzeColumn]] command.
* Example SQL for analyzing a table or a set of partitions:
* {{{
* ANALYZE TABLE multi_part_name [PARTITION (partcol1[=val1], partcol2[=val2], ...)]
* COMPUTE STATISTICS [NOSCAN];
* }}}
*
* Example SQL for analyzing columns:
* {{{
* ANALYZE TABLE multi_part_name COMPUTE STATISTICS FOR COLUMNS column1, column2;
* }}}
*
* Example SQL for analyzing all columns of a table:
* {{{
* ANALYZE TABLE multi_part_name COMPUTE STATISTICS FOR ALL COLUMNS;
* }}}
*/
override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) {
def checkPartitionSpec(): Unit = {
if (ctx.partitionSpec != null) {
logWarning(
log"Partition specification is ignored when collecting column statistics: " +
log"${MDC(PARTITION_SPECIFICATION, ctx.partitionSpec.getText)}")
}
}
if (ctx.identifier != null &&
ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
throw QueryParsingErrors.computeStatisticsNotExpectedError(ctx.identifier())
}
if (ctx.ALL() != null) {
checkPartitionSpec()
AnalyzeColumn(
createUnresolvedTableOrView(ctx.identifierReference, "ANALYZE TABLE ... FOR ALL COLUMNS"),
None,
allColumns = true)
} else if (ctx.identifierSeq() == null) {
val partitionSpec = if (ctx.partitionSpec != null) {
visitPartitionSpec(ctx.partitionSpec)
} else {
Map.empty[String, Option[String]]
}
AnalyzeTable(
createUnresolvedTableOrView(
ctx.identifierReference,
"ANALYZE TABLE",
allowTempView = false),
partitionSpec,
noScan = ctx.identifier != null)
} else {
checkPartitionSpec()
AnalyzeColumn(
createUnresolvedTableOrView(ctx.identifierReference, "ANALYZE TABLE ... FOR COLUMNS ..."),
Option(visitIdentifierSeq(ctx.identifierSeq())),
allColumns = false)
}
}
/**
* Create an [[AnalyzeTables]] command.
* Example SQL for analyzing all tables in the default database:
* {{{
* ANALYZE TABLES IN default COMPUTE STATISTICS;
* }}}
*/
override def visitAnalyzeTables(ctx: AnalyzeTablesContext): LogicalPlan = withOrigin(ctx) {
if (ctx.identifier != null &&
ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
throw QueryParsingErrors.computeStatisticsNotExpectedError(ctx.identifier())
}
val ns = if (ctx.identifierReference() != null) {
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_))
} else {
CurrentNamespace
}
AnalyzeTables(ns, noScan = ctx.identifier != null)
}
/**
* Create a [[RepairTable]].
*
* For example:
* {{{
* [MSCK] REPAIR TABLE multi_part_name [{ADD|DROP|SYNC} PARTITIONS]
* }}}
*/
override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) {
val (enableAddPartitions, enableDropPartitions, option) =
if (ctx.SYNC() != null) {
(true, true, " ... SYNC PARTITIONS")
} else if (ctx.DROP() != null) {
(false, true, " ... DROP PARTITIONS")
} else if (ctx.ADD() != null) {
(true, false, " ... ADD PARTITIONS")
} else {
(true, false, "")
}
RepairTable(
createUnresolvedTable(ctx.identifierReference, s"MSCK REPAIR TABLE$option"),
enableAddPartitions,
enableDropPartitions)
}
/**
* Create a [[LoadData]].
*
* For example:
* {{{
* LOAD DATA [LOCAL] INPATH 'filepath' [OVERWRITE] INTO TABLE multi_part_name
* [PARTITION (partcol1=val1, partcol2=val2 ...)]
* }}}
*/
override def visitLoadData(ctx: LoadDataContext): LogicalPlan = withOrigin(ctx) {
LoadData(
child = createUnresolvedTable(ctx.identifierReference, "LOAD DATA"),
path = string(visitStringLit(ctx.path)),
isLocal = ctx.LOCAL != null,
isOverwrite = ctx.OVERWRITE != null,
partition = Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)
)
}
/**
* Create a [[ShowCreateTable]] command.
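*
* For example:
* {{{
*   SHOW CREATE TABLE multi_part_name [AS SERDE];
* }}}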
*/
override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) {
ShowCreateTable(
createUnresolvedTableOrView(
ctx.identifierReference,
"SHOW CREATE TABLE",
allowTempView = false),
ctx.SERDE != null)
}
/**
* Create a [[CacheTable]] or [[CacheTableAsSelect]].
*
* For example:
* {{{
* CACHE [LAZY] TABLE multi_part_name
* [OPTIONS tablePropertyList] [[AS] query]
* }}}
*/
override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val query = Option(ctx.query).map(plan)
withIdentClause(ctx.identifierReference, ident => {
if (query.isDefined && ident.length > 1) {
val catalogAndNamespace = ident.init
throw QueryParsingErrors.addCatalogInCacheTableAsSelectNotAllowedError(
catalogAndNamespace.quoted, ctx)
}
val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
val isLazy = ctx.LAZY != null
if (query.isDefined) {
CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
} else {
CacheTable(createUnresolvedRelation(ctx.identifierReference, ident), ident, isLazy, options)
}
})
}
/**
* Create an [[UncacheTable]] logical plan.
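*
* For example:
* {{{
*   UNCACHE TABLE [IF EXISTS] multi_part_name;
* }}}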
*/
override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) {
UncacheTable(createUnresolvedRelation(ctx.identifierReference), ctx.EXISTS != null)
}
/**
* Create a [[TruncateTable]] command.
*
* For example:
* {{{
* TRUNCATE TABLE multi_part_name [PARTITION (partcol1=val1, partcol2=val2 ...)]
* }}}
*/
override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) {
val table = createUnresolvedTable(ctx.identifierReference, "TRUNCATE TABLE")
Option(ctx.partitionSpec).map { spec =>
TruncatePartition(table, UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(spec)))
}.getOrElse(TruncateTable(table))
}
/**
* A command for users to list the partition names of a table. If a partition spec is
* specified, only the partitions that match the spec are returned; otherwise all partitions
* are returned.
*
* This function creates a [[ShowPartitions]] logical plan.
*
* The syntax of using this command in SQL is:
* {{{
* SHOW PARTITIONS multi_part_name [partition_spec];
* }}}
*/
override def visitShowPartitions(ctx: ShowPartitionsContext): LogicalPlan = withOrigin(ctx) {
val partitionKeys = Option(ctx.partitionSpec).map { specCtx =>
UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(specCtx), None)
}
ShowPartitions(
createUnresolvedTable(ctx.identifierReference, "SHOW PARTITIONS"),
partitionKeys)
}
/**
* Create a [[RefreshTable]].
*
* For example:
* {{{
* REFRESH TABLE multi_part_name
* }}}
*/
override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) {
RefreshTable(createUnresolvedTableOrView(ctx.identifierReference, "REFRESH TABLE"))
}
/**
* A command for users to list the column names for a table.
* This function creates a [[ShowColumns]] logical plan.
*
* The syntax of using this command in SQL is:
* {{{
* SHOW COLUMNS (FROM | IN) tableName=multipartIdentifier
* ((FROM | IN) namespace=multipartIdentifier)?
* }}}
*/
override def visitShowColumns(ctx: ShowColumnsContext): LogicalPlan = withOrigin(ctx) {
withIdentClause(ctx.table, ident => {
val table = createUnresolvedTableOrView(
ctx.table,
ident,
"SHOW COLUMNS",
allowTempView = true)
val namespace = Option(ctx.ns).map(visitMultipartIdentifier)
// Use namespace only if table name doesn't specify it. If namespace is already specified
// in the table name, it's checked against the given namespace after table/view is resolved.
val tableWithNamespace = if (namespace.isDefined && table.multipartIdentifier.length == 1) {
CurrentOrigin.withOrigin(table.origin) {
table.copy(multipartIdentifier = namespace.get ++ table.multipartIdentifier)
}
} else {
table
}
ShowColumns(tableWithNamespace, namespace)
})
}
/**
* Create a [[RecoverPartitions]] command.
*
* For example:
* {{{
* ALTER TABLE multi_part_name RECOVER PARTITIONS;
* }}}
*/
override def visitRecoverPartitions(
ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) {
RecoverPartitions(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... RECOVER PARTITIONS"))
}
/**
* Create an [[AddPartitions]] command.
*
* For example:
* {{{
* ALTER TABLE multi_part_name ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1']
* ALTER VIEW multi_part_name ADD [IF NOT EXISTS] PARTITION spec
* }}}
*
* ALTER VIEW ... ADD PARTITION ... is not supported because the concept of partitioning
* is associated with physical tables.
*/
override def visitAddTablePartition(
ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) {
if (ctx.VIEW != null) {
invalidStatement("ALTER VIEW ... ADD PARTITION", ctx)
}
// Create partition spec to location mapping.
val specsAndLocs = ctx.partitionSpecLocation.asScala.map { splCtx =>
val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec)
val location = Option(splCtx.locationSpec).map(visitLocationSpec)
UnresolvedPartitionSpec(spec, location)
}
blockBang(ctx.errorCapturingNot)
AddPartitions(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... ADD PARTITION ..."),
specsAndLocs.toSeq,
ctx.EXISTS != null)
}
/**
* Create a [[RenamePartitions]] command.
*
* For example:
* {{{
* ALTER TABLE multi_part_name PARTITION spec1 RENAME TO PARTITION spec2;
* }}}
*/
override def visitRenameTablePartition(
ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) {
RenamePartitions(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... RENAME TO PARTITION"),
UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(ctx.from)),
UnresolvedPartitionSpec(visitNonOptionalPartitionSpec(ctx.to)))
}
/**
* Create a [[DropPartitions]] command.
*
* For example:
* {{{
* ALTER TABLE multi_part_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]
* [PURGE];
* ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...];
* }}}
*
* ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning
* is associated with physical tables.
*/
override def visitDropTablePartitions(
ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) {
if (ctx.VIEW != null) {
invalidStatement("ALTER VIEW ... DROP PARTITION", ctx)
}
val partSpecs = ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec)
.map(spec => UnresolvedPartitionSpec(spec))
DropPartitions(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... DROP PARTITION ..."),
partSpecs.toSeq,
ifExists = ctx.EXISTS != null,
purge = ctx.PURGE != null)
}
/**
* Create a [[SetTableSerDeProperties]] command.
*
* For example:
* {{{
* ALTER TABLE multi_part_name [PARTITION spec] SET SERDE serde_name
* [WITH SERDEPROPERTIES props];
* ALTER TABLE multi_part_name [PARTITION spec] SET SERDEPROPERTIES serde_properties;
* }}}
*/
override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) {
SetTableSerDeProperties(
createUnresolvedTable(
ctx.identifierReference,
"ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]",
true),
Option(ctx.stringLit).map(x => string(visitStringLit(x))),
Option(ctx.propertyList).map(visitPropertyKeyValues),
// TODO a partition spec is allowed to have optional values. This is currently violated.
Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
}
/**
* Alter the query of a view. This creates an [[AlterViewAs]] command.
*
* For example:
* {{{
* ALTER VIEW multi_part_name AS SELECT ...;
* }}}
*/
override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) {
AlterViewAs(
createUnresolvedView(ctx.identifierReference, "ALTER VIEW ... AS"),
originalText = source(ctx.query),
query = plan(ctx.query))
}
/**
* Create a [[RenameTable]] command.
*
* For example:
* {{{
* ALTER TABLE multi_part_name1 RENAME TO multi_part_name2;
* ALTER VIEW multi_part_name1 RENAME TO multi_part_name2;
* }}}
*/
override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) {
val isView = ctx.VIEW != null
val relationStr = if (isView) "VIEW" else "TABLE"
RenameTable(
createUnresolvedTableOrView(ctx.from, s"ALTER $relationStr ... RENAME TO"),
visitMultipartIdentifier(ctx.to),
isView)
}
/**
* A command for users to list the properties for a table. If propertyKey is specified, the value
* for the propertyKey is returned. If propertyKey is not specified, all the keys and their
* corresponding values are returned.
* The syntax of using this command in SQL is:
* {{{
* SHOW TBLPROPERTIES multi_part_name[('propertyKey')];
* }}}
*/
override def visitShowTblProperties(
ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) {
ShowTableProperties(
createUnresolvedTableOrView(ctx.table, "SHOW TBLPROPERTIES"),
Option(ctx.key).map(visitPropertyKey))
}
/**
* Create a plan for a DESCRIBE FUNCTION statement.
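*
* For example:
* {{{
*   DESCRIBE FUNCTION [EXTENDED] func_name;
* }}}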
*/
override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) {
import ctx._
if (describeFuncName.identifierReference() == null) {
val functionName =
if (describeFuncName.stringLit() != null) {
Seq(string(visitStringLit(describeFuncName.stringLit())))
} else {
Seq(describeFuncName.getText)
}
DescribeFunction(
createUnresolvedFunctionName(
ctx.describeFuncName(),
functionName,
"DESCRIBE FUNCTION",
requirePersistent = false,
funcTypeMismatchHint = None),
EXTENDED != null)
} else {
DescribeFunction(
withIdentClause(
describeFuncName.identifierReference(),
createUnresolvedFunctionName(
describeFuncName.identifierReference,
_,
"DESCRIBE FUNCTION",
requirePersistent = false,
funcTypeMismatchHint = None)),
EXTENDED != null)
}
}
/**
* Create a plan for a SHOW FUNCTIONS command.
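*
* For example:
* {{{
*   SHOW [USER|SYSTEM|ALL] FUNCTIONS [(FROM|IN) namespace] [[LIKE] pattern];
* }}}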
*/
override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) {
val (userScope, systemScope) = Option(ctx.identifier)
.map(_.getText.toLowerCase(Locale.ROOT)) match {
case None | Some("all") => (true, true)
case Some("system") => (false, true)
case Some("user") => (true, false)
case Some(x) => throw QueryParsingErrors.showFunctionsUnsupportedError(x, ctx.identifier())
}
val legacy = Option(ctx.legacy).map(visitMultipartIdentifier)
val pattern = Option(ctx.pattern).map(x => string(visitStringLit(x))).orElse(legacy.map(_.last))
if (ctx.ns != null) {
if (legacy.isDefined) {
throw QueryParsingErrors.showFunctionsInvalidPatternError(ctx.legacy.getText, ctx.legacy)
}
ShowFunctions(
withIdentClause(ctx.ns, UnresolvedNamespace(_)),
userScope, systemScope, pattern)
} else if (legacy.isDefined) {
val ns = if (legacy.get.length > 1) {
UnresolvedNamespace(legacy.get.dropRight(1))
} else {
CurrentNamespace
}
ShowFunctions(ns, userScope, systemScope, pattern)
} else {
ShowFunctions(CurrentNamespace, userScope, systemScope, pattern)
}
}
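/**
* Create a [[RefreshFunction]] command.
*
* For example:
* {{{
*   REFRESH FUNCTION func_name;
* }}}
*/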
override def visitRefreshFunction(ctx: RefreshFunctionContext): LogicalPlan = withOrigin(ctx) {
RefreshFunction(
withIdentClause(
ctx.identifierReference,
createUnresolvedFunctionName(
ctx.identifierReference,
_,
"REFRESH FUNCTION",
requirePersistent = true,
funcTypeMismatchHint = None)))
}
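/**
* Create a [[CommentOnNamespace]] command.
*
* For example:
* {{{
*   COMMENT ON (DATABASE|SCHEMA|NAMESPACE) ns IS 'comment';
* }}}
*/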
override def visitCommentNamespace(ctx: CommentNamespaceContext): LogicalPlan = withOrigin(ctx) {
val comment = visitComment(ctx.comment)
CommentOnNamespace(withIdentClause(ctx.identifierReference, UnresolvedNamespace(_)), comment)
}
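/**
* Create a [[CommentOnTable]] command.
*
* For example:
* {{{
*   COMMENT ON TABLE multi_part_name IS 'comment';
* }}}
*/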
override def visitCommentTable(ctx: CommentTableContext): LogicalPlan = withOrigin(ctx) {
val comment = visitComment(ctx.comment)
CommentOnTable(createUnresolvedTable(ctx.identifierReference, "COMMENT ON TABLE"), comment)
}
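/**
* Extract the comment text from a COMMENT ON ... IS clause. A missing string literal
* (e.g. IS NULL) is mapped to an empty string.
*/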
override def visitComment (ctx: CommentContext): String = {
Option(ctx.stringLit()).map(s => string(visitStringLit(s))).getOrElse("")
}
/**
* Create an index, returning a [[CreateIndex]] logical plan.
* For example:
* {{{
* CREATE INDEX index_name ON [TABLE] table_name [USING index_type] (column_index_property_list)
* [OPTIONS indexPropertyList]
* column_index_property_list: column_name [OPTIONS(indexPropertyList)] [, ...]
* indexPropertyList: index_property_name [= index_property_value] [, ...]
* }}}
*/
override def visitCreateIndex(ctx: CreateIndexContext): LogicalPlan = withOrigin(ctx) {
val (indexName, indexType) = if (ctx.identifier.size() == 1) {
(ctx.identifier(0).getText, "")
} else {
(ctx.identifier(0).getText, ctx.identifier(1).getText)
}
val columns = ctx.columns.multipartIdentifierProperty.asScala
.map(_.multipartIdentifier).map(typedVisit[Seq[String]]).toSeq
val columnsProperties = ctx.columns.multipartIdentifierProperty.asScala
.map(x => (Option(x.options).map(visitPropertyKeyValues).getOrElse(Map.empty))).toSeq
val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
blockBang(ctx.errorCapturingNot)
CreateIndex(
createUnresolvedTable(ctx.identifierReference, "CREATE INDEX"),
indexName,
indexType,
ctx.EXISTS != null,
columns.map(UnresolvedFieldName(_)).zip(columnsProperties),
options)
}
/**
* Drop an index, returning a [[DropIndex]] logical plan.
* For example:
* {{{
* DROP INDEX [IF EXISTS] index_name ON [TABLE] table_name
* }}}
*/
override def visitDropIndex(ctx: DropIndexContext): LogicalPlan = withOrigin(ctx) {
val indexName = ctx.identifier.getText
DropIndex(
createUnresolvedTable(ctx.identifierReference, "DROP INDEX"),
indexName,
ctx.EXISTS != null)
}
/**
* Create a TimestampAdd expression.
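*
* For example:
* {{{
*   TIMESTAMPADD(unit, unitsAmount, timestamp)
* }}}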
*/
override def visitTimestampadd(ctx: TimestampaddContext): Expression = withOrigin(ctx) {
if (ctx.invalidUnit != null) {
throw QueryParsingErrors.invalidDatetimeUnitError(
ctx,
ctx.name.getText,
ctx.invalidUnit.getText)
} else {
TimestampAdd(ctx.unit.getText, expression(ctx.unitsAmount), expression(ctx.timestamp))
}
}
/**
* Create a TimestampDiff expression.
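*
* For example:
* {{{
*   TIMESTAMPDIFF(unit, startTimestamp, endTimestamp)
* }}}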
*/
override def visitTimestampdiff(ctx: TimestampdiffContext): Expression = withOrigin(ctx) {
if (ctx.invalidUnit != null) {
throw QueryParsingErrors.invalidDatetimeUnitError(
ctx,
ctx.name.getText,
ctx.invalidUnit.getText)
} else {
TimestampDiff(ctx.unit.getText, expression(ctx.startTimestamp), expression(ctx.endTimestamp))
}
}
/**
* Create a named parameter which represents a literal with a non-bound value and unknown type.
*/
override def visitNamedParameterLiteral(
ctx: NamedParameterLiteralContext): Expression = withOrigin(ctx) {
NamedParameter(ctx.identifier().getText)
}
/**
* Create a positional parameter which represents a literal
* with a non-bound value and unknown type.
*/
override def visitPosParameterLiteral(
ctx: PosParameterLiteralContext): Expression = withOrigin(ctx) {
PosParameter(ctx.QUESTION().getSymbol.getStartIndex)
}
/**
* Create a [[CreateVariable]] command.
*
* For example:
* {{{
* DECLARE [OR REPLACE] [VARIABLE] [db_name.]variable_name
* [dataType] [defaultExpression];
* }}}
*
* We will add CREATE VARIABLE for persisted variable definitions to this, hence the name.
*/
override def visitCreateVariable(ctx: CreateVariableContext): LogicalPlan = withOrigin(ctx) {
val dataTypeOpt = Option(ctx.dataType()).map(typedVisit[DataType])
val defaultExpression = if (ctx.variableDefaultExpression() == null) {
if (dataTypeOpt.isEmpty) {
throw new ParseException(
errorClass = "INVALID_SQL_SYNTAX.VARIABLE_TYPE_OR_DEFAULT_REQUIRED",
messageParameters = Map.empty,
ctx.identifierReference)
}
DefaultValueExpression(Literal(null, dataTypeOpt.get), "null")
} else {
val default = visitVariableDefaultExpression(ctx.variableDefaultExpression())
dataTypeOpt.map { dt => default.copy(child = Cast(default.child, dt)) }.getOrElse(default)
}
CreateVariable(
withIdentClause(ctx.identifierReference(), UnresolvedIdentifier(_)),
defaultExpression,
ctx.REPLACE() != null
)
}
/**
* Create a [[DropVariable]] command.
*
* For example:
* {{{
* DROP TEMPORARY VARIABLE [IF EXISTS] variable;
* }}}
*/
override def visitDropVariable(ctx: DropVariableContext): LogicalPlan = withOrigin(ctx) {
DropVariable(
withIdentClause(ctx.identifierReference(), UnresolvedIdentifier(_)),
ctx.EXISTS() != null
)
}
/**
* Create a [[SetVariable]] command.
*
* For example:
* {{{
* SET VARIABLE var1 = v1, var2 = v2, ...
* SET VARIABLE (var1, var2, ...) = (SELECT ...)
* }}}
*/
override def visitSetVariable(ctx: SetVariableContext): LogicalPlan = withOrigin(ctx) {
if (ctx.query() != null) {
// The SET variable source is a query.
val variables = ctx.multipartIdentifierList.multipartIdentifier.asScala.map { variableIdent =>
val varName = visitMultipartIdentifier(variableIdent)
UnresolvedAttribute(varName)
}.toSeq
SetVariable(variables, visitQuery(ctx.query()))
} else {
// The SET variable source is a list of expressions.
val (variables, values) = ctx.assignmentList().assignment().asScala.map { assign =>
val varIdent = visitMultipartIdentifier(assign.key)
val varExpr = expression(assign.value)
val varNamedExpr = varExpr match {
case n: NamedExpression => n
case e => Alias(e, varIdent.last)()
}
(UnresolvedAttribute(varIdent), varNamedExpr)
}.toSeq.unzip
SetVariable(variables, Project(values, OneRowRelation()))
}
}
}