/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.jdbc

import java.sql.{Connection, PreparedStatement, ResultSet}

import scala.util.control.NonFatal

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.CompletionIterator

/**
* Data corresponding to one partition of a JDBCRDD.
*/
case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
override def index: Int = idx
}

object JDBCRDD extends Logging {

/**
* Takes a (schema, table) specification and returns the table's Catalyst
* schema.
*
* @param options - JDBC options that contains url, table and other information.
*
* @return A StructType giving the table's Catalyst schema.
* @throws java.sql.SQLException if the table specification is invalid.
* @throws java.sql.SQLException if the table contains an unsupported type.
*/
def resolveTable(options: JDBCOptions): StructType = {
val url = options.url
val table = options.tableOrQuery
val dialect = JdbcDialects.get(url)
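// The default dialect builds a zero-row probe query (e.g. "SELECT * FROM t WHERE 1=0")
// so that only the result set metadata is fetched; individual dialects may override this.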
getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect)
}
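
// Illustrative sketch only: resolving a table schema with a hypothetical URL and table
// name (neither value is used anywhere in this file).
//
//   val options = new JDBCOptions(Map(
//     "url" -> "jdbc:postgresql://localhost/test",
//     "dbtable" -> "people"))
//   val schema: StructType = JDBCRDD.resolveTable(options)
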
def getQueryOutputSchema(
query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = {
val conn: Connection = JdbcUtils.createConnectionFactory(options)()
try {
val statement = conn.prepareStatement(query)
try {
statement.setQueryTimeout(options.queryTimeout)
val rs = statement.executeQuery()
try {
JdbcUtils.getSchema(rs, dialect, alwaysNullable = true)
} finally {
rs.close()
}
} finally {
statement.close()
}
} finally {
conn.close()
}
}

/**
* Prune all but the specified columns from the specified Catalyst schema.
*
* @param schema - The Catalyst schema of the master table
* @param columns - The list of desired columns
*
* @return A Catalyst schema corresponding to columns in the given order.
*/
private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
new StructType(columns.map(name => fieldMap(name)))
}
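
// For example (illustrative): pruning a schema with fields a, b, c using
// columns Array("c", "a") yields a StructType containing only c and a, in that order.
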
/**
* Turns a single Filter into a String representing a SQL expression.
* Returns None for an unhandled filter.
*/
def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = {
def quote(colName: String): String = dialect.quoteIdentifier(colName)
Option(f match {
case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}"
case EqualNullSafe(attr, value) =>
val col = quote(attr)
s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " +
s"${dialect.compileValue(value)} IS NULL) OR " +
s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))"
case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}"
case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}"
case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}"
case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}"
case IsNull(attr) => s"${quote(attr)} IS NULL"
case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'"
case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
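// An IN over an empty value list can never match, but SQL three-valued logic still
// requires a NULL input to yield NULL rather than FALSE, hence the CASE expression.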
case In(attr, value) if value.isEmpty =>
s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})"
case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
case Or(f1, f2) =>
// We can't compile an Or filter unless both sub-filters compile successfully.
// The same applies to the And filter below.
// If compileFilter ever supports all filters, this check can be removed.
val or = Seq(f1, f2).flatMap(compileFilter(_, dialect))
if (or.size == 2) {
or.map(p => s"($p)").mkString(" OR ")
} else {
null
}
case And(f1, f2) =>
val and = Seq(f1, f2).flatMap(compileFilter(_, dialect))
if (and.size == 2) {
and.map(p => s"($p)").mkString(" AND ")
} else {
null
}
case _ => null
})
}
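
// Illustrative sketch only (the exact text depends on the dialect's identifier quoting
// and value compilation; the URL below is a hypothetical placeholder):
//
//   val dialect = JdbcDialects.get("jdbc:postgresql://localhost/test")
//   JDBCRDD.compileFilter(GreaterThan("age", 21), dialect)
//   // yields Some of the SQL fragment: "age" > 21
//   JDBCRDD.compileFilter(And(IsNotNull("name"), LessThan("age", 65)), dialect)
//   // yields Some of the SQL fragment: ("name" IS NOT NULL) AND ("age" < 65)
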
def compileAggregates(
aggregates: Seq[AggregateFunc],
dialect: JdbcDialect): Option[Seq[String]] = {
def quote(colName: String): String = dialect.quoteIdentifier(colName)
Some(aggregates.map {
case min: Min =>
if (min.column.fieldNames.length != 1) return None
s"MIN(${quote(min.column.fieldNames.head)})"
case max: Max =>
if (max.column.fieldNames.length != 1) return None
s"MAX(${quote(max.column.fieldNames.head)})"
case count: Count =>
if (count.column.fieldNames.length != 1) return None
val distinct = if (count.isDistinct) "DISTINCT " else ""
val column = quote(count.column.fieldNames.head)
s"COUNT($distinct$column)"
case sum: Sum =>
if (sum.column.fieldNames.length != 1) return None
val distinct = if (sum.isDistinct) "DISTINCT " else ""
val column = quote(sum.column.fieldNames.head)
s"SUM($distinct$column)"
case _: CountStar =>
s"COUNT(*)"
case _ => return None
})
}
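
// Illustrative sketch: for a pushed-down Min over column "salary" plus a CountStar, this
// produces Some(Seq(MIN("salary"), COUNT(*))); an aggregate over a nested (multi-part)
// column, or any unsupported aggregate, makes the whole call return None.
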
/**
* Build and return JDBCRDD from the given information.
*
* @param sc - Your SparkContext.
* @param schema - The Catalyst schema of the underlying database table.
* @param requiredColumns - The names of the columns to SELECT.
* @param filters - The filters to include in all WHERE clauses.
* @param parts - An array of JDBCPartitions specifying partition ids and
* per-partition WHERE clauses.
* @param options - JDBC options that contains url, table and other information.
* @param outputSchema - The schema of the columns to SELECT.
* @param groupByColumns - The pushed down group by columns.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
def scanTable(
sc: SparkContext,
schema: StructType,
requiredColumns: Array[String],
filters: Array[Filter],
parts: Array[Partition],
options: JDBCOptions,
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
requiredColumns.map(colName => dialect.quoteIdentifier(colName))
} else {
// these are already quoted in JDBCScanBuilder
requiredColumns
}
new JDBCRDD(
sc,
JdbcUtils.createConnectionFactory(options),
outputSchema.getOrElse(pruneSchema(schema, requiredColumns)),
quotedColumns,
filters,
parts,
url,
options,
groupByColumns)
}
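
// Illustrative sketch only (hypothetical SparkContext, schema and options; in Spark
// itself this is invoked via JDBCRelation rather than called directly):
//
//   val rdd = JDBCRDD.scanTable(
//     sc,
//     schema,
//     requiredColumns = Array("name", "age"),
//     filters = Array(IsNotNull("name")),
//     parts = Array[Partition](JDBCPartition(null, 0)),
//     options)
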
}

/**
* An RDD representing a table in a database accessed via JDBC. Both the
* driver code and the workers must be able to access the database; the driver
* needs to fetch the schema while the workers need to fetch the data.
*/
private[jdbc] class JDBCRDD(
sc: SparkContext,
getConnection: () => Connection,
schema: StructType,
columns: Array[String],
filters: Array[Filter],
partitions: Array[Partition],
url: String,
options: JDBCOptions,
groupByColumns: Option[Array[String]])
extends RDD[InternalRow](sc, Nil) {

/**
* Retrieve the list of partitions corresponding to this RDD.
*/
override def getPartitions: Array[Partition] = partitions
/**
* `columns`, but as a String suitable for injection into a SQL query.
*/
private val columnList: String = {
val sb = new StringBuilder()
columns.foreach(x => sb.append(",").append(x))
if (sb.isEmpty) "1" else sb.substring(1)
}
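
// For example (illustrative): two quoted columns produce "NAME","AGE", while an empty
// column list degenerates to the literal 1 (i.e. SELECT 1 FROM ...), which still returns
// one row per matching row for scans that only need counts.
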
/**
* `filters`, but as a WHERE clause suitable for injection into a SQL query.
*/
private val filterWhereClause: String =
filters
.flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
.map(p => s"($p)").mkString(" AND ")
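
// For example (illustrative): filters Array(GreaterThan("age", 21), IsNotNull("name"))
// compile to ("age" > 21) AND ("name" IS NOT NULL); filters that cannot be compiled are
// dropped here and are instead re-evaluated by Spark after the rows are fetched.
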
/**
* A WHERE clause representing both `filters`, if any, and the current partition.
*/
private def getWhereClause(part: JDBCPartition): String = {
if (part.whereClause != null && filterWhereClause.length > 0) {
"WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
} else if (part.whereClause != null) {
"WHERE " + part.whereClause
} else if (filterWhereClause.length > 0) {
"WHERE " + filterWhereClause
} else {
""
}
}
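
// For example (illustrative): with a compiled filter clause "age" > 21 and a partition
// whose whereClause is "id >= 100 AND id < 200", this yields
// WHERE ("age" > 21) AND (id >= 100 AND id < 200).
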
/**
* A GROUP BY clause representing pushed-down grouping columns.
*/
private def getGroupByClause: String = {
if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
// The GROUP BY columns should already be quoted by the caller side.
s"GROUP BY ${groupByColumns.get.mkString(", ")}"
} else {
""
}
}
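
// For example (illustrative): pushed-down grouping columns "DEPT" and "CITY" yield
// GROUP BY "DEPT", "CITY"; when no aggregation is pushed down the clause is empty.
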
/**
* Runs the SQL query against the JDBC driver.
*/
override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
var closed = false
var rs: ResultSet = null
var stmt: PreparedStatement = null
var conn: Connection = null
def close(): Unit = {
if (closed) return
try {
if (null != rs) {
rs.close()
}
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
try {
if (null != stmt) {
stmt.close()
}
} catch {
case e: Exception => logWarning("Exception closing statement", e)
}
try {
if (null != conn) {
if (!conn.isClosed && !conn.getAutoCommit) {
try {
conn.commit()
} catch {
case NonFatal(e) => logWarning("Exception committing transaction", e)
}
}
conn.close()
}
logInfo("closed connection")
} catch {
case e: Exception => logWarning("Exception closing connection", e)
}
closed = true
}
context.addTaskCompletionListener[Unit] { _ => close() }
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
// This executes a generic SQL statement (or PL/SQL block) before reading
// the table/query via JDBC. Use this feature to initialize the database
// session environment, e.g. for optimizations and/or troubleshooting.
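// For example (illustrative), a caller might set
//   .option("sessionInitStatement", "SET search_path TO analytics")
// to run a one-off setup statement on each new connection before the scan below.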
options.sessionInitStatement match {
case Some(sql) =>
val statement = conn.prepareStatement(sql)
logInfo(s"Executing sessionInitStatement: $sql")
try {
statement.setQueryTimeout(options.queryTimeout)
statement.execute()
} finally {
statement.close()
}
case None =>
}
// H2's JDBC driver does not support the setSchema() method, so we pass a
// fully-qualified table name in the SELECT statement; there is no completely
// portable way to refer to a table across databases.
val myWhereClause = getWhereClause(part)
val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
s" $getGroupByClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
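// Convert the JDBC ResultSet into InternalRows lazily; InterruptibleIterator lets the
// task respond to cancellation between rows, and CompletionIterator ensures close() runs
// once the iterator is fully consumed (the task-completion listener above covers early
// termination and failures).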
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
CompletionIterator[InternalRow, Iterator[InternalRow]](
new InterruptibleIterator(context, rowsIterator), close())
}
}