/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.datasources.jdbc2

import java.sql.{Date, Timestamp}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Partition
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, DateType, NumericType, StructType, TimestampType}
import org.apache.spark.util.Utils

/**
 * Instructions on how to partition the table among workers.
 */
private[sql] case class JDBCPartitioningInfo(
    column: String,
    columnType: DataType,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int)

private[sql] object JDBCRelation extends Logging {

  /**
   * Given a partitioning schematic (a column of integral type, a number of
   * partitions, and upper and lower bounds on the column's value), generate
   * WHERE clauses for each partition so that each row in the table appears
   * exactly once. The lower and upper bounds are advisory in that incorrect
   * values may cause the partitioning to be poor, but no data will fail to be
   * represented.
   *
   * A null-value predicate is added to the first partition's WHERE clause so
   * that rows with a null value in the partition column are also included.
   *
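   * For illustration (tracing the clause generation below, and assuming a numeric
   * partition column named `id` with lowerBound = 0, upperBound = 10 and
   * numPartitions = 4), the generated WHERE clauses are:
   *
   *   "id < 2 OR id IS NULL", "id >= 2 AND id < 4",
   *   "id >= 4 AND id < 6", "id >= 6"
   *
   * The first partition also captures rows below the lower bound and the last one
   * rows at or above the upper bound, which is why the bounds only affect the
   * stride, never which rows are returned.
   *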
   * @param schema resolved schema of a JDBC table
   * @param resolver function used to determine if two identifiers are equal
   * @param timeZoneId timezone ID to be used if a partition column type is date or timestamp
   * @param jdbcOptions JDBC options that contain the url
   * @return an array of partitions, each carrying its WHERE clause
   */
  def columnPartition(
      schema: StructType,
      resolver: Resolver,
      timeZoneId: String,
      jdbcOptions: JDBCOptions): Array[Partition] = {
    val partitioning = {
      import JDBCOptions._

      val partitionColumn = jdbcOptions.partitionColumn
      val lowerBound = jdbcOptions.lowerBound
      val upperBound = jdbcOptions.upperBound
      val numPartitions = jdbcOptions.numPartitions

      if (partitionColumn.isEmpty) {
        assert(
          lowerBound.isEmpty && upperBound.isEmpty,
          "When 'partitionColumn' is not " +
            s"specified, '$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
        null
      } else {
        assert(
          lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
          s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
            s"'$JDBC_NUM_PARTITIONS' are also required")

        val (column, columnType) = verifyAndGetNormalizedPartitionColumn(
          schema,
          partitionColumn.get,
          resolver,
          jdbcOptions)

        val lowerBoundValue = toInternalBoundValue(lowerBound.get, columnType)
        val upperBoundValue = toInternalBoundValue(upperBound.get, columnType)
        JDBCPartitioningInfo(
          column,
          columnType,
          lowerBoundValue,
          upperBoundValue,
          numPartitions.get)
      }
    }

    if (partitioning == null || partitioning.numPartitions <= 1 ||
      partitioning.lowerBound == partitioning.upperBound) {
      Array[Partition](JDBCPartition(null, 0))
    } else {
      val lowerBound = partitioning.lowerBound
      val upperBound = partitioning.upperBound
      require(
        lowerBound <= upperBound,
        "Operation not allowed: the lower bound of partitioning column is larger than the upper " +
          s"bound. Lower bound: $lowerBound; Upper bound: $upperBound")
      val boundValueToString: Long => String =
        toBoundValueInWhereClause(_, partitioning.columnType, timeZoneId)
      val numPartitions =
        if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */
          (upperBound - lowerBound) < 0) {
          partitioning.numPartitions
        } else {
          logWarning("The number of partitions is reduced because the specified number of " +
            "partitions is less than the difference between upper bound and lower bound. " +
            s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " +
            s"partitions: ${partitioning.numPartitions}; " +
            s"Lower bound: ${boundValueToString(lowerBound)}; " +
            s"Upper bound: ${boundValueToString(upperBound)}.")
          upperBound - lowerBound
        }
      // Overflow and silliness can happen if you subtract then divide.
      // Here we get a little roundoff, but that's (hopefully) OK.
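      // For instance, with lowerBound = Long.MinValue and upperBound = Long.MaxValue,
      // (upperBound - lowerBound) wraps around to -1, which is exactly what the "< 0"
      // overflow check above guards against; dividing each bound first keeps the
      // stride finite at the cost of a small rounding error.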
      val stride: Long = upperBound / numPartitions - lowerBound / numPartitions

      var i: Int = 0
      val column = partitioning.column
      var currentValue = lowerBound
      val ans = new ArrayBuffer[Partition]()
      while (i < numPartitions) {
        val lBoundValue = boundValueToString(currentValue)
        val lBound = if (i != 0) s"$column >= $lBoundValue" else null
        currentValue += stride
        val uBoundValue = boundValueToString(currentValue)
        val uBound = if (i != numPartitions - 1) s"$column < $uBoundValue" else null
        val whereClause =
          if (uBound == null) {
            lBound
          } else if (lBound == null) {
            s"$uBound OR $column IS NULL"
          } else {
            s"$lBound AND $uBound"
          }
        ans += JDBCPartition(whereClause, i)
        i = i + 1
      }
      val partitions = ans.toArray
      logInfo(s"Number of partitions: $numPartitions, WHERE clauses of these partitions: " +
        partitions.map(_.asInstanceOf[JDBCPartition].whereClause).mkString(", "))
      partitions
    }
  }

  // Verify column name and type based on the JDBC resolved schema
  private def verifyAndGetNormalizedPartitionColumn(
      schema: StructType,
      columnName: String,
      resolver: Resolver,
      jdbcOptions: JDBCOptions): (String, DataType) = {
    val dialect = JdbcDialects.get(jdbcOptions.url)
    val column = schema.find { f =>
      resolver(f.name, columnName) || resolver(dialect.quoteIdentifier(f.name), columnName)
    }.getOrElse {
      throw new AnalysisException(s"User-defined partition column $columnName not " +
        s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
    }
    column.dataType match {
      case _: NumericType | DateType | TimestampType =>
      case _ =>
        throw new AnalysisException(
          s"Partition column type should be ${NumericType.simpleString}, " +
            s"${DateType.catalogString}, or ${TimestampType.catalogString}, but " +
            s"${column.dataType.catalogString} found.")
    }
    (dialect.quoteIdentifier(column.name), column.dataType)
  }

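  // Converts a user-supplied bound string into the internal Long representation used
  // for stride arithmetic: numeric strings parse directly, dates become days since
  // the epoch (e.g. "2020-01-01" -> 18262), and timestamps become microseconds since
  // the epoch, interpreted in the JVM default time zone by Timestamp.valueOf.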
  private def toInternalBoundValue(value: String, columnType: DataType): Long = columnType match {
    case _: NumericType => value.toLong
    case DateType => DateTimeUtils.fromJavaDate(Date.valueOf(value)).toLong
    case TimestampType => DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(value))
  }

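  // Renders an internal bound back into a SQL literal for the generated WHERE clauses:
  // numeric values print as-is, while dates and timestamps are formatted in the given
  // time zone and quoted (e.g. 18262 with DateType becomes '2020-01-01').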
  private def toBoundValueInWhereClause(
      value: Long,
      columnType: DataType,
      timeZoneId: String): String = {
    def dateTimeToString(): String = {
      val timeZone = DateTimeUtils.getTimeZone(timeZoneId)
      val dateTimeStr = columnType match {
        case DateType => DateTimeUtils.dateToString(value.toInt, timeZone)
        case TimestampType => DateTimeUtils.timestampToString(value, timeZone)
      }
      s"'$dateTimeStr'"
    }

    columnType match {
      case _: NumericType => value.toString
      case DateType | TimestampType => dateTimeToString()
    }
  }

  /**
   * Takes a (schema, table) specification and returns the table's Catalyst schema.
   * If `customSchema` is defined in the JDBC options, the data types of the matching
   * fields are replaced with the types from the custom schema.
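   * For example (an illustrative value), a customSchema of
   * "id DECIMAL(38, 0), name STRING" overrides the resolved types of the `id` and
   * `name` columns and leaves every other column as resolved from the database.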
   *
   * @param resolver function used to determine if two identifiers are equal
   * @param jdbcOptions JDBC options that contain the url, table and other information
   * @return resolved Catalyst schema of a JDBC table
   */
  def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
    val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
    jdbcOptions.customSchema match {
      case Some(customSchema) => JdbcUtils.getCustomSchema(
        tableSchema,
        customSchema,
        resolver)
      case None => tableSchema
    }
  }

  /**
   * Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema.
   */
  def apply(
      parts: Array[Partition],
      jdbcOptions: JDBCOptions)(
      sparkSession: SparkSession): JDBCRelation = {
    val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
    JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
  }
}

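/**
 * A relation that reads from and writes to a JDBC table, scanning it with the partition
 * WHERE clauses computed by [[JDBCRelation.columnPartition]].
 *
 * A minimal usage sketch (the format name below is illustrative, since it depends on how
 * this `jdbc2` source is registered; the URL and table name are placeholders):
 *
 * {{{
 *   val df = spark.read
 *     .format("org.apache.spark.sql.execution.datasources.jdbc2")
 *     .option("url", "jdbc:postgresql://localhost/db") // placeholder connection URL
 *     .option("dbtable", "events")                     // placeholder table name
 *     .option("partitionColumn", "id")
 *     .option("lowerBound", "0")
 *     .option("upperBound", "10")
 *     .option("numPartitions", "4")
 *     .load()
 * }}}
 */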
private[sql] case class JDBCRelation(
    override val schema: StructType,
    parts: Array[Partition],
    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
  extends BaseRelation
  with PrunedFilteredScan
  with InsertableRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override val needConversion: Boolean = false

  // Check if JDBCRDD.compileFilter can accept input filters
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
    if (jdbcOptions.pushDownPredicate) {
      filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
    } else {
      filters
    }
  }

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    JDBCRDD.scanTable(
      sparkSession.sparkContext,
      schema,
      requiredColumns,
      filters,
      parts,
      jdbcOptions).asInstanceOf[RDD[Row]]
  }

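  // Writes go through the standard DataFrameWriter JDBC path; the partition WHERE
  // clauses computed for reads play no role here.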
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.write
      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
      .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties)
  }

  override def toString: String = {
    val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
    // Credentials should not be included in the plan output; table information is sufficient.
    s"JDBCRelation(${jdbcOptions.tableOrQuery})" + partitioningInfo
  }
}