/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.jdbc
import scala.collection.mutable.ArrayBuffer
import scala.math.BigDecimal.RoundingMode
import org.apache.spark.Partition
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
/**
* Instructions on how to partition the table among workers.
*/
private[sql] case class JDBCPartitioningInfo(
column: String,
columnType: DataType,
lowerBound: Long,
upperBound: Long,
numPartitions: Int)
private[sql] object JDBCRelation extends Logging {
/**
   * Given a partitioning schematic (a partition column of numeric, date or timestamp type,
   * a number of partitions, and upper and lower bounds on the column's value), generate
   * WHERE clauses for each partition so that each row in the table appears exactly once.
   * The lowerBound and upperBound values are advisory in that incorrect values may cause
   * the partitioning to be poor, but no data will fail to be represented.
*
   * A null-value predicate is added to the first partition's WHERE clause so that rows
   * with a null value in the partition column are also included.
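   *
   * For example (identifier quoting omitted), with partition column `id`, lowerBound 0,
   * upperBound 100 and numPartitions 4, the generated clauses are:
   * {{{
   *   "id < 25 or id is null", "id >= 25 AND id < 50", "id >= 50 AND id < 75", "id >= 75"
   * }}}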
*
* @param schema resolved schema of a JDBC table
* @param resolver function used to determine if two identifiers are equal
* @param timeZoneId timezone ID to be used if a partition column type is date or timestamp
   * @param jdbcOptions JDBC options that contain the url
   * @return an array of partitions, each with its own WHERE clause
*/
def columnPartition(
schema: StructType,
resolver: Resolver,
timeZoneId: String,
jdbcOptions: JDBCOptions): Array[Partition] = {
val partitioning = {
import JDBCOptions._
val partitionColumn = jdbcOptions.partitionColumn
val lowerBound = jdbcOptions.lowerBound
val upperBound = jdbcOptions.upperBound
val numPartitions = jdbcOptions.numPartitions
if (partitionColumn.isEmpty) {
assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not " +
s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
null
} else {
assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
s"'$JDBC_NUM_PARTITIONS' are also required")
val (column, columnType) = verifyAndGetNormalizedPartitionColumn(
schema, partitionColumn.get, resolver, jdbcOptions)
val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType, timeZoneId)
val upperBoundValue = toInternalBoundValue(upperBound.get, columnType, timeZoneId)
JDBCPartitioningInfo(
column, columnType, lowerBoundValue, upperBoundValue, numPartitions.get)
}
}
if (partitioning == null || partitioning.numPartitions <= 1 ||
partitioning.lowerBound == partitioning.upperBound) {
return Array[Partition](JDBCPartition(null, 0))
}
val lowerBound = partitioning.lowerBound
val upperBound = partitioning.upperBound
    require(lowerBound <= upperBound,
      "Operation not allowed: the lower bound of the partitioning column is larger than " +
      s"the upper bound. Lower bound: $lowerBound; Upper bound: $upperBound")
val boundValueToString: Long => String =
toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId)
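    // Cap the number of partitions at the bound difference, e.g. with lowerBound = 0,
    // upperBound = 3 and 10 requested partitions, only 3 partitions are generated.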
val numPartitions =
if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */
(upperBound - lowerBound) < 0) {
partitioning.numPartitions
} else {
logWarning("The number of partitions is reduced because the specified number of " +
"partitions is less than the difference between upper bound and lower bound. " +
s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " +
s"partitions: ${partitioning.numPartitions}; " +
s"Lower bound: ${boundValueToString(lowerBound)}; " +
s"Upper bound: ${boundValueToString(upperBound)}.")
upperBound - lowerBound
}
// Overflow can happen if you subtract then divide. For example:
// (Long.MaxValue - Long.MinValue) / (numPartitions - 2).
// Also, using fixed-point decimals here to avoid possible inaccuracy from floating point.
val upperStride = (upperBound / BigDecimal(numPartitions))
.setScale(18, RoundingMode.HALF_EVEN)
val lowerStride = (lowerBound / BigDecimal(numPartitions))
.setScale(18, RoundingMode.HALF_EVEN)
val preciseStride = upperStride - lowerStride
val stride = preciseStride.toLong
// Determine the number of strides the last partition will fall short of compared to the
// supplied upper bound. Take half of those strides, and then add them to the lower bound
// for better distribution of the first and last partitions.
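    // E.g. with lowerBound = 0, upperBound = 10 and 4 partitions: stride = 2, so the four
    // strides cover only 8 of the 10 units and the excess would land entirely in the last
    // partition; shifting the start by half the remainder (1) yields the ranges
    // (< 3), [3, 5), [5, 7), (>= 7) instead of (< 2), [2, 4), [4, 6), (>= 6).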
val lostNumOfStrides = (preciseStride - stride) * numPartitions / stride
val lowerBoundWithStrideAlignment = lowerBound +
((lostNumOfStrides / 2) * stride).setScale(0, RoundingMode.HALF_UP).toLong
var i: Int = 0
val column = partitioning.column
var currentValue = lowerBoundWithStrideAlignment
val ans = new ArrayBuffer[Partition]()
while (i < numPartitions) {
val lBoundValue = boundValueToString(currentValue)
val lBound = if (i != 0) s"$column >= $lBoundValue" else null
currentValue += stride
val uBoundValue = boundValueToString(currentValue)
val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null
val whereClause =
if (uBound == null) {
lBound
} else if (lBound == null) {
s"$uBound or $column is null"
} else {
s"$lBound AND $uBound"
}
ans += JDBCPartition(whereClause, i)
i = i + 1
}
val partitions = ans.toArray
logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " +
partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", "))
partitions
}
// Verify column name and type based on the JDBC resolved schema
private def verifyAndGetNormalizedPartitionColumn(
schema: StructType,
columnName: String,
resolver: Resolver,
jdbcOptions: JDBCOptions): (String, DataType) = {
val dialect = JdbcDialects.get(jdbcOptions.url)
val column = schema.find { f =>
resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName)
}.getOrElse {
val maxNumToStringFields = SQLConf.get.maxToStringFields
throw QueryCompilationErrors.userDefinedPartitionNotFoundInJDBCRelationError(
columnName, schema.simpleString(maxNumToStringFields))
}
column.dataType match {
case _: NumericType | DateType | TimestampType =>
case _ =>
throw QueryCompilationErrors.invalidPartitionColumnTypeError(column)
}
(dialect.quoteIdentifier(column.name), column.dataType)
}
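  // Converts a user-specified bound string into Spark's internal representation: numeric
  // bounds are parsed as Long, dates become days since the epoch, and timestamps become
  // microseconds since the epoch.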
private def toInternalBoundValue(
value: String,
columnType: DataType,
timeZoneId: String): Long = {
def parse[T](f: UTF8String => Option[T]): T = {
f(UTF8String.fromString(value)).getOrElse {
throw new IllegalArgumentException(
s"Cannot parse the bound value $value as ${columnType.catalogString}")
}
}
columnType match {
case _: NumericType => value.toLong
case DateType => parse(stringToDate).toLong
case TimestampType => parse(stringToTimestamp(_, getZoneId(timeZoneId)))
}
}
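  // Renders an internal bound value as a SQL literal for the generated WHERE clause:
  // numeric values are emitted as-is, while date and timestamp values are formatted as
  // strings (timestamps in the given time zone) and single-quoted.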
private def toBoundValueInWhereClause(
value: Long,
columnType: DataType,
timeZoneId: String): String = {
def dateTimeToString(): String = {
val dateTimeStr = columnType match {
case DateType =>
DateFormatter().format(value.toInt)
case TimestampType =>
val timestampFormatter = TimestampFormatter.getFractionFormatter(
DateTimeUtils.getZoneId(timeZoneId))
timestampFormatter.format(value)
}
s"'$dateTimeStr'"
}
columnType match {
case _: NumericType => value.toString
case DateType | TimestampType => dateTimeToString()
}
}
/**
* Takes a (schema, table) specification and returns the table's Catalyst schema.
   * If `customSchema` is defined in the JDBC options, the matching fields' data types are
   * replaced with the custom schema's types.
*
* @param resolver function used to determine if two identifiers are equal
   * @param jdbcOptions JDBC options that contain the url, table and other information.
* @return resolved Catalyst schema of a JDBC table
*/
def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
jdbcOptions.customSchema match {
case Some(customSchema) => JdbcUtils.getCustomSchema(
tableSchema, customSchema, resolver)
case None => tableSchema
}
}
/**
   * Resolves the Catalyst schema of a JDBC table and returns a [[JDBCRelation]] with that schema.
*/
def apply(
parts: Array[Partition],
jdbcOptions: JDBCOptions)(
sparkSession: SparkSession): JDBCRelation = {
val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
}
}
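
/**
 * A relation backed by a JDBC table (or query), scanned with the given partitions.
 * A minimal usage sketch via the public DataFrame reader (the URL, table name and
 * bounds below are placeholders):
 * {{{
 *   spark.read.format("jdbc")
 *     .option("url", "jdbc:postgresql://host/db")
 *     .option("dbtable", "my_table")
 *     .option("partitionColumn", "id")
 *     .option("lowerBound", "0")
 *     .option("upperBound", "100")
 *     .option("numPartitions", "4")
 *     .load()
 * }}}
 */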
private[sql] case class JDBCRelation(
override val schema: StructType,
parts: Array[Partition],
jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedFilteredScan
with InsertableRelation {
override def sqlContext: SQLContext = sparkSession.sqlContext
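  // Rows are produced in Spark's internal row format (see the type-erasure note in
  // buildScan), so no extra conversion is needed.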
override val needConversion: Boolean = false
// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
if (jdbcOptions.pushDownPredicate) {
filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
} else {
filters
}
}
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
schema,
requiredColumns,
filters,
parts,
jdbcOptions).asInstanceOf[RDD[Row]]
}
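  // Variant of buildScan used when aggregates are pushed down to the JDBC source:
  // `finalSchema` is the output schema after aggregation and `groupByColumns` holds
  // the grouping columns, if any.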
def buildScan(
requiredColumns: Array[String],
finalSchema: StructType,
filters: Array[Filter],
groupByColumns: Option[Array[String]]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
schema,
requiredColumns,
filters,
parts,
jdbcOptions,
Some(finalSchema),
groupByColumns).asInstanceOf[RDD[Row]]
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties)
}
override def toString: String = {
val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
    // Credentials should not be included in the plan output; table information is sufficient.
s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo
}
}