| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.sql.execution.datasources.jdbc |
| |
| import java.sql.{Connection, PreparedStatement, ResultSet} |
| |
| import scala.util.control.NonFatal |
| |
| import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} |
| import org.apache.spark.internal.Logging |
| import org.apache.spark.rdd.RDD |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum} |
| import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} |
| import org.apache.spark.sql.sources._ |
| import org.apache.spark.sql.types._ |
| import org.apache.spark.util.CompletionIterator |
| |
| /** |
| * Data corresponding to one partition of a JDBCRDD. |
| */ |
| case class JDBCPartition(whereClause: String, idx: Int) extends Partition { |
| override def index: Int = idx |
| } |
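| |
| // For illustration only (column and bounds hypothetical): a table split into |
| // two ranges on a numeric column would be described by partitions such as |
| //   JDBCPartition("THEID < 500", 0) |
| //   JDBCPartition("THEID >= 500", 1) |
| // where each whereClause fragment is appended to that partition's query. |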
| |
| object JDBCRDD extends Logging { |
| |
| /** |
| * Takes a (schema, table) specification and returns the table's Catalyst |
| * schema. |
| * |
| * @param options - JDBC options containing the url, table and other information. |
| * |
| * @return A StructType giving the table's Catalyst schema. |
| * @throws java.sql.SQLException if the table specification is invalid. |
| * @throws java.sql.SQLException if the table contains an unsupported type. |
| */ |
| def resolveTable(options: JDBCOptions): StructType = { |
| val url = options.url |
| val table = options.tableOrQuery |
| val dialect = JdbcDialects.get(url) |
| getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect) |
| } |
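| |
| // Note: the schema query is expected to return no rows; the default |
| // JdbcDialect builds "SELECT * FROM <table> WHERE 1=0", so only the |
| // ResultSet metadata is read to derive the StructType. |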
| |
| def getQueryOutputSchema( |
| query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = { |
| val conn: Connection = JdbcUtils.createConnectionFactory(options)() |
| try { |
| val statement = conn.prepareStatement(query) |
| try { |
| statement.setQueryTimeout(options.queryTimeout) |
| val rs = statement.executeQuery() |
| try { |
| JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) |
| } finally { |
| rs.close() |
| } |
| } finally { |
| statement.close() |
| } |
| } finally { |
| conn.close() |
| } |
| } |
| |
| /** |
| * Prune all but the specified columns from the specified Catalyst schema. |
| * |
| * @param schema - The Catalyst schema of the underlying table |
| * @param columns - The list of desired columns |
| * |
| * @return A Catalyst schema corresponding to columns in the given order. |
| */ |
| private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { |
| val fieldMap = Map(schema.fields.map(x => x.name -> x): _*) |
| new StructType(columns.map(name => fieldMap(name))) |
| } |
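| |
| // For example, pruning a schema (a INT, b STRING, c DOUBLE) with |
| // columns = Array("c", "a") yields StructType(c DOUBLE, a INT): the output |
| // follows the requested column order, not the table's declared order. |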
| |
| /** |
| * Turns a single Filter into a String representing a SQL expression. |
| * Returns None for an unhandled filter. |
| */ |
| def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = { |
| def quote(colName: String): String = dialect.quoteIdentifier(colName) |
| |
| Option(f match { |
| case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}" |
| case EqualNullSafe(attr, value) => |
| val col = quote(attr) |
| s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " + |
| s"${dialect.compileValue(value)} IS NULL) OR " + |
| s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))" |
| case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}" |
| case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}" |
| case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}" |
| case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}" |
| case IsNull(attr) => s"${quote(attr)} IS NULL" |
| case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL" |
| case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'" |
| case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'" |
| case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'" |
| case In(attr, value) if value.isEmpty => |
| s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END" |
| case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})" |
| case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null) |
| case Or(f1, f2) => |
| // An Or filter can only be pushed down if both sub-filters compile |
| // successfully; the same restriction applies to the And filter below. |
| // If compileFilter is ever guaranteed to handle every filter, this check |
| // can be removed. |
| val or = Seq(f1, f2).flatMap(compileFilter(_, dialect)) |
| if (or.size == 2) { |
| or.map(p => s"($p)").mkString(" OR ") |
| } else { |
| null |
| } |
| case And(f1, f2) => |
| val and = Seq(f1, f2).flatMap(compileFilter(_, dialect)) |
| if (and.size == 2) { |
| and.map(p => s"($p)").mkString(" AND ") |
| } else { |
| null |
| } |
| case _ => null |
| }) |
| } |
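| |
| // Illustrative only; quoting is dialect-specific (the default dialect uses |
| // double quotes): |
| //   compileFilter(And(GreaterThan("id", 5), EqualTo("name", "fred")), dialect) |
| //     returns Some of: ("id" > 5) AND ("name" = 'fred') |
| //   compileFilter(StringContains("name", "x"), dialect) |
| //     returns Some of: "name" LIKE '%x%' |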
| |
| def compileAggregates( |
| aggregates: Seq[AggregateFunc], |
| dialect: JdbcDialect): Option[Seq[String]] = { |
| def quote(colName: String): String = dialect.quoteIdentifier(colName) |
| |
| Some(aggregates.map { |
| case min: Min => |
| if (min.column.fieldNames.length != 1) return None |
| s"MIN(${quote(min.column.fieldNames.head)})" |
| case max: Max => |
| if (max.column.fieldNames.length != 1) return None |
| s"MAX(${quote(max.column.fieldNames.head)})" |
| case count: Count => |
| if (count.column.fieldNames.length != 1) return None |
| val distinct = if (count.isDistinct) "DISTINCT " else "" |
| val column = quote(count.column.fieldNames.head) |
| s"COUNT($distinct$column)" |
| case sum: Sum => |
| if (sum.column.fieldNames.length != 1) return None |
| val distinct = if (sum.isDistinct) "DISTINCT " else "" |
| val column = quote(sum.column.fieldNames.head) |
| s"SUM($distinct$column)" |
| case _: CountStar => |
| s"COUNT(*)" |
| case _ => return None |
| }) |
| } |
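| |
| // For illustration: a Min over column "rate" compiles to MIN("rate") (with |
| // dialect-specific quoting), a distinct Count over "id" to |
| // COUNT(DISTINCT "id"), and CountStar to COUNT(*). Note the non-local |
| // `return None` above: if any one aggregate cannot be compiled the whole |
| // result is None, so aggregate pushdown is all-or-nothing. |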
| |
| /** |
| * Build and return JDBCRDD from the given information. |
| * |
| * @param sc - The SparkContext. |
| * @param schema - The Catalyst schema of the underlying database table. |
| * @param requiredColumns - The names of the columns to SELECT. |
| * @param filters - The filters to include in all WHERE clauses. |
| * @param parts - An array of JDBCPartitions specifying partition ids and |
| * per-partition WHERE clauses. |
| * @param options - JDBC options containing the url, table and other information. |
| * @param outputSchema - The schema of the columns to SELECT. |
| * @param groupByColumns - The pushed down group by columns. |
| * |
| * @return An RDD representing "SELECT requiredColumns FROM fqTable". |
| */ |
| def scanTable( |
| sc: SparkContext, |
| schema: StructType, |
| requiredColumns: Array[String], |
| filters: Array[Filter], |
| parts: Array[Partition], |
| options: JDBCOptions, |
| outputSchema: Option[StructType] = None, |
| groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = { |
| val url = options.url |
| val dialect = JdbcDialects.get(url) |
| val quotedColumns = if (groupByColumns.isEmpty) { |
| requiredColumns.map(colName => dialect.quoteIdentifier(colName)) |
| } else { |
| // these are already quoted in JDBCScanBuilder |
| requiredColumns |
| } |
| new JDBCRDD( |
| sc, |
| JdbcUtils.createConnectionFactory(options), |
| outputSchema.getOrElse(pruneSchema(schema, requiredColumns)), |
| quotedColumns, |
| filters, |
| parts, |
| url, |
| options, |
| groupByColumns) |
| } |
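| |
| // A minimal usage sketch (values hypothetical): scan two range partitions, |
| // selecting two columns with one pushed filter. |
| //   val parts: Array[Partition] = Array( |
| //     JDBCPartition("id < 500", 0), |
| //     JDBCPartition("id >= 500", 1)) |
| //   val rdd = JDBCRDD.scanTable( |
| //     sc, schema, Array("id", "name"), Array(GreaterThan("id", 0)), |
| //     parts, options) |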
| } |
| |
| /** |
| * An RDD representing a table in a database accessed via JDBC. Both the |
| * driver code and the workers must be able to access the database; the driver |
| * needs to fetch the schema while the workers need to fetch the data. |
| */ |
| private[jdbc] class JDBCRDD( |
| sc: SparkContext, |
| getConnection: () => Connection, |
| schema: StructType, |
| columns: Array[String], |
| filters: Array[Filter], |
| partitions: Array[Partition], |
| url: String, |
| options: JDBCOptions, |
| groupByColumns: Option[Array[String]]) |
| extends RDD[InternalRow](sc, Nil) { |
| |
| /** |
| * Retrieve the list of partitions corresponding to this RDD. |
| */ |
| override def getPartitions: Array[Partition] = partitions |
| |
| /** |
| * `columns`, but as a String suitable for injection into a SQL query. |
| */ |
| private val columnList: String = { |
| if (columns.isEmpty) "1" else columns.mkString(",") |
| } |
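| |
| // For example, Array("\"a\"", "\"b\"") becomes "a","b" joined by a comma; an |
| // empty array becomes "1", which keeps a no-column scan (e.g. a bare row |
| // count) syntactically valid: SELECT 1 FROM table. |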
| |
| /** |
| * `filters`, but as a WHERE clause suitable for injection into a SQL query. |
| */ |
| private val filterWhereClause: String = |
| filters |
| .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url))) |
| .map(p => s"($p)").mkString(" AND ") |
| |
| /** |
| * A WHERE clause representing both `filters`, if any, and the current partition. |
| */ |
| private def getWhereClause(part: JDBCPartition): String = { |
| if (part.whereClause != null && filterWhereClause.length > 0) { |
| "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})" |
| } else if (part.whereClause != null) { |
| "WHERE " + part.whereClause |
| } else if (filterWhereClause.length > 0) { |
| "WHERE " + filterWhereClause |
| } else { |
| "" |
| } |
| } |
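| |
| // For illustration, with one compiled filter ("id" > 5) and a partition |
| // clause "id >= 0 AND id < 100", this yields: |
| //   WHERE (("id" > 5)) AND (id >= 0 AND id < 100) |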
| |
| /** |
| * A GROUP BY clause representing pushed-down grouping columns. |
| */ |
| private def getGroupByClause: String = { |
| if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) { |
| // The GROUP BY columns should already be quoted by the caller side. |
| s"GROUP BY ${groupByColumns.get.mkString(", ")}" |
| } else { |
| "" |
| } |
| } |
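| |
| // For example, groupByColumns = Some(Array("\"dept\"", "\"role\"")) yields |
| // GROUP BY "dept", "role"; None or Some(Array.empty) yields the empty |
| // string. |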
| |
| /** |
| * Runs the SQL query against the JDBC driver. |
| */ |
| override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = { |
| var closed = false |
| var rs: ResultSet = null |
| var stmt: PreparedStatement = null |
| var conn: Connection = null |
| |
| def close(): Unit = { |
| if (closed) return |
| try { |
| if (null != rs) { |
| rs.close() |
| } |
| } catch { |
| case e: Exception => logWarning("Exception closing resultset", e) |
| } |
| try { |
| if (null != stmt) { |
| stmt.close() |
| } |
| } catch { |
| case e: Exception => logWarning("Exception closing statement", e) |
| } |
| try { |
| if (null != conn) { |
| if (!conn.isClosed && !conn.getAutoCommit) { |
| try { |
| conn.commit() |
| } catch { |
| case NonFatal(e) => logWarning("Exception committing transaction", e) |
| } |
| } |
| conn.close() |
| } |
| logInfo("closed connection") |
| } catch { |
| case e: Exception => logWarning("Exception closing connection", e) |
| } |
| closed = true |
| } |
| |
| context.addTaskCompletionListener[Unit] { _ => close() } |
| |
| val inputMetrics = context.taskMetrics().inputMetrics |
| val part = thePart.asInstanceOf[JDBCPartition] |
| conn = getConnection() |
| val dialect = JdbcDialects.get(url) |
| import scala.collection.JavaConverters._ |
| dialect.beforeFetch(conn, options.asProperties.asScala.toMap) |
| |
| // This executes a generic SQL statement (or PL/SQL block) before reading |
| // the table/query via JDBC. Use this feature to initialize the database |
| // session environment, e.g. for optimizations and/or troubleshooting. |
| options.sessionInitStatement match { |
| case Some(sql) => |
| val statement = conn.prepareStatement(sql) |
| logInfo(s"Executing sessionInitStatement: $sql") |
| try { |
| statement.setQueryTimeout(options.queryTimeout) |
| statement.execute() |
| } finally { |
| statement.close() |
| } |
| case None => |
| } |
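| |
| // For example, an Oracle session can be initialized (option value taken |
| // from the JDBC data source documentation) with: |
| //   option("sessionInitStatement", |
| //     """BEGIN execute immediate 'alter session set "_serial_direct_read"=true'; END;""") |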
| |
| // H2's JDBC driver does not support the setSchema() method, so we pass a |
| // fully-qualified table name in the SELECT statement; there is no fully |
| // portable way to reference a table across databases. |
| |
| val myWhereClause = getWhereClause(part) |
| |
| val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" + |
| s" $getGroupByClause" |
| stmt = conn.prepareStatement(sqlText, |
| ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) |
| stmt.setFetchSize(options.fetchSize) |
| stmt.setQueryTimeout(options.queryTimeout) |
| rs = stmt.executeQuery() |
| val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) |
| |
| CompletionIterator[InternalRow, Iterator[InternalRow]]( |
| new InterruptibleIterator(context, rowsIterator), close()) |
| } |
| } |