blob: 6e8b0f697aa2c55d92695db5a473556263af44d1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.jdbc2
import org.apache.commons.lang3.StringUtils
import org.apache.spark.TaskContext
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, DriverWrapper}
import org.apache.spark.sql.execution.datasources.jdbc2.JDBCSaveMode.JDBCSaveMode
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.NextIterator
import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
import java.util.Locale
import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.control.NonFatal
/**
* Util functions for JDBC tables.
*/
object JdbcUtils extends Logging {
/**
 * Builds a zero-argument function that opens a new JDBC connection on every invocation.
 *
 * Each invocation registers the configured driver class through [[DriverRegistry]], looks the
 * driver up among the drivers known to [[DriverManager]] and connects with it.
 *
 * @param options - JDBC options that contains url, table and other information.
 * @return a function yielding a [[Connection]] to `options.url`; calling it throws
 *         `IllegalStateException` when no registered driver matches the driver class
 */
def createConnectionFactory(options: JDBCOptions): () => Connection = {
  val driverClass: String = options.driverClass
  () => {
    DriverRegistry.register(driverClass)
    // A candidate qualifies either through the class it wraps or through its own class.
    val registered: Option[Driver] = DriverManager.getDrivers.asScala.find {
      case wrapper: DriverWrapper
        if wrapper.wrapped.getClass.getCanonicalName == driverClass => true
      case candidate => candidate.getClass.getCanonicalName == driverClass
    }
    registered match {
      case Some(driver) =>
        driver.connect(options.url, options.asConnectionProperties)
      case None =>
        throw new IllegalStateException(
          s"Did not find registered driver with class $driverClass")
    }
  }
}
/**
 * Checks whether the table named in `options` can be queried through `conn`.
 *
 * The dialect's table-exists query is executed; the table is reported as existing when that
 * query runs without a non-fatal error.
 */
def tableExists(conn: Connection, options: JdbcOptionsInWrite): Boolean = {
  val dialect = JdbcDialects.get(options.url)
  // There is no portable JDBC metadata call that answers "does this table exist" for every
  // database (the "table" may also carry a database name), so we simply probe it with a query
  // that each dialect is free to override.
  try {
    val probe = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
    try {
      probe.setQueryTimeout(options.queryTimeout)
      probe.executeQuery()
    } finally {
      probe.close()
    }
    true
  } catch {
    case NonFatal(_) => false
  }
}
/**
 * Issues a `DROP TABLE` statement for `table` on the given connection.
 *
 * The statement honours `options.queryTimeout` and is always closed afterwards.
 */
def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
  val stmt = conn.createStatement
  try {
    stmt.setQueryTimeout(options.queryTimeout)
    val dropSql = s"DROP TABLE $table"
    stmt.executeUpdate(dropSql)
  } finally {
    stmt.close()
  }
}
/**
 * Truncates the table named in `options` using the dialect's truncate query.
 *
 * When the cascade-truncate option is set it is forwarded to the dialect; otherwise the
 * dialect's single-argument truncate query is used.
 */
def truncateTable(conn: Connection, options: JdbcOptionsInWrite): Unit = {
  val dialect = JdbcDialects.get(options.url)
  val stmt = conn.createStatement
  try {
    stmt.setQueryTimeout(options.queryTimeout)
    val sql = options.isCascadeTruncate match {
      case cascade @ Some(_) => dialect.getTruncateQuery(options.table, cascade)
      case None => dialect.getTruncateQuery(options.table)
    }
    stmt.executeUpdate(sql)
  } finally {
    stmt.close()
  }
}
/**
 * Returns the cascading-truncate setting reported by the dialect resolved for `url`.
 */
def isCascadingTruncateTable(url: String): Option[Boolean] = {
  val dialect = JdbcDialects.get(url)
  dialect.isCascadingTruncateTable()
}
/**
 * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
 *
 * Column order follows `rddSchema`. When `tableSchema` is given, each column is emitted under
 * the name found in the table schema (resolved case-sensitively or not according to
 * `isCaseSensitive`); otherwise the RDD column names are used.
 *
 * SeaTunnel: for [[JDBCSaveMode.Update]] the statement gets an `ON DUPLICATE KEY UPDATE`
 * clause. Columns listed (comma separated) under `JDBCOptions.JDBC_DUPLICATE_INCS` are
 * incremented by the incoming value, all other columns are overwritten by it.
 *
 * @param table           target table name, inserted verbatim into the statement
 * @param rddSchema       schema of the rows that will be written
 * @param tableSchema     schema of the existing table, if known
 * @param isCaseSensitive whether column names are matched case-sensitively
 * @param dialect         dialect used to quote identifiers
 * @param mode            save mode; only `Update` changes the generated statement
 * @param options         JDBC options, read for the increment columns and the `showSql` flag
 * @return the parameterised INSERT statement
 * @throws AnalysisException if an RDD column cannot be found in `tableSchema`
 */
def getInsertStatement(
    table: String,
    rddSchema: StructType,
    tableSchema: Option[StructType],
    isCaseSensitive: Boolean,
    dialect: JdbcDialect,
    mode: JDBCSaveMode,
    options: JDBCOptions): String = {
  val columns = tableSchema match {
    case None =>
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    case Some(schema) =>
      val columnNameEquality =
        if (isCaseSensitive) {
          org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
        } else {
          org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
        }
      // The generated insert statement needs to follow rddSchema's column sequence and
      // tableSchema's column names. When appending data into some case-sensitive DBMSs like
      // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead
      // of RDD column names for user convenience.
      val tableColumnNames = schema.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName =
          tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
            throw new AnalysisException(
              s"""Column "${col.name}" not found in schema $tableSchema""")
          }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
  }
  val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
  // SeaTunnel: Create insert statement when savemode is update.
  mode match {
    case JDBCSaveMode.Update =>
      val props = options.asProperties
      // Quote the increment columns with the dialect, exactly like the column names they are
      // compared against below. Hard-coding backticks here would make the comparison fail
      // (and silently drop the increment) for every dialect that quotes differently.
      val duplicateIncs = props
        .getProperty(JDBCOptions.JDBC_DUPLICATE_INCS, "")
        .split(",")
        .filter(x => StringUtils.isNotBlank(x))
        .map(x => dialect.quoteIdentifier(x.trim))
        .toSet
      val duplicateSetting = rddSchema
        .fields
        .map(x => dialect.quoteIdentifier(x.name))
        .map { name =>
          if (duplicateIncs.contains(name)) s"$name=$name+VALUES($name)"
          else s"$name=VALUES($name)"
        }
        .mkString(",")
      val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders) " +
        s"ON DUPLICATE KEY UPDATE $duplicateSetting"
      if (props.getProperty("showSql", "false").equals("true")) {
        logInfo(s"${JDBCSaveMode.Update} => sql => $sql")
      }
      sql
    case _ => s"INSERT INTO $table ($columns) VALUES ($placeholders)"
  }
}
/**
 * Looks up the dialect-independent JDBC type used for a Catalyst data type.
 *
 * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
 * @return The default JdbcType for this DataType, or `None` when there is no default
 */
def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
  val defaults: PartialFunction[DataType, JdbcType] = {
    case IntegerType => JdbcType("INTEGER", java.sql.Types.INTEGER)
    case LongType => JdbcType("BIGINT", java.sql.Types.BIGINT)
    case DoubleType => JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)
    case FloatType => JdbcType("REAL", java.sql.Types.FLOAT)
    case ShortType => JdbcType("INTEGER", java.sql.Types.SMALLINT)
    case ByteType => JdbcType("BYTE", java.sql.Types.TINYINT)
    case BooleanType => JdbcType("BIT(1)", java.sql.Types.BIT)
    case StringType => JdbcType("TEXT", java.sql.Types.CLOB)
    case BinaryType => JdbcType("BLOB", java.sql.Types.BLOB)
    case TimestampType => JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)
    case DateType => JdbcType("DATE", java.sql.Types.DATE)
    case decimal: DecimalType =>
      JdbcType(s"DECIMAL(${decimal.precision},${decimal.scale})", java.sql.Types.DECIMAL)
  }
  // Types without an entry above have no common mapping.
  defaults.lift(dt)
}
/**
 * Resolves the JDBC type for `dt`, preferring the dialect's mapping over the common one.
 *
 * @throws IllegalArgumentException if neither the dialect nor the common mapping knows `dt`
 */
private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
  val resolved = dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt))
  resolved match {
    case Some(jdbcType) => jdbcType
    case None =>
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}")
  }
}
/**
 * Maps a JDBC type to a Catalyst type. This is the generic mapping; `getSchema` falls back
 * to it when the dialect does not supply a Catalyst type for the column.
 *
 * @param sqlType   - A field of java.sql.Types
 * @param precision column precision, used for DECIMAL/NUMERIC
 * @param scale     column scale, used for DECIMAL/NUMERIC
 * @param signed    whether the column is signed; unsigned BIGINT/INTEGER are widened
 * @return The Catalyst type corresponding to sqlType.
 * @throws SQLException if the type is unsupported or is not a known java.sql.Types constant
 */
private def getCatalystType(
    sqlType: Int,
    precision: Int,
    scale: Int,
    signed: Boolean): DataType = {
  import java.sql.Types
  // scalastyle:off
  sqlType match {
    // Integral types.
    case Types.BIGINT => if (signed) LongType else DecimalType(20, 0)
    case Types.INTEGER => if (signed) IntegerType else LongType
    case Types.SMALLINT | Types.TINYINT => IntegerType
    case Types.ROWID => LongType

    // Fractional types.
    case Types.DOUBLE | Types.REAL => DoubleType
    case Types.FLOAT => FloatType
    case Types.DECIMAL | Types.NUMERIC if precision != 0 || scale != 0 =>
      DecimalType.bounded(precision, scale)
    case Types.DECIMAL | Types.NUMERIC => DecimalType.SYSTEM_DEFAULT

    // Booleans; see JdbcDialect for quirks around BIT.
    case Types.BIT | Types.BOOLEAN => BooleanType

    // Binary types.
    case Types.BINARY | Types.BLOB | Types.LONGVARBINARY | Types.VARBINARY => BinaryType

    // Character types, plus the types that are surfaced as strings.
    case Types.CHAR | Types.CLOB | Types.LONGNVARCHAR | Types.LONGVARCHAR | Types.NCHAR |
         Types.NCLOB | Types.NVARCHAR | Types.REF | Types.SQLXML | Types.STRUCT |
         Types.VARCHAR => StringType

    // Date and time types.
    case Types.DATE => DateType
    case Types.TIME | Types.TIMESTAMP => TimestampType

    // Known JDBC types that have no Catalyst counterpart here.
    case Types.ARRAY | Types.DATALINK | Types.DISTINCT | Types.JAVA_OBJECT | Types.NULL |
         Types.OTHER | Types.REF_CURSOR | Types.TIME_WITH_TIMEZONE |
         Types.TIMESTAMP_WITH_TIMEZONE =>
      throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName)

    case _ =>
      throw new SQLException("Unrecognized SQL type " + sqlType)
  }
  // scalastyle:on
}
/**
 * Returns the Catalyst schema of `options.tableOrQuery`, or `None` when preparing, running or
 * closing the dialect's schema query raises a [[SQLException]].
 */
def getSchemaOption(conn: Connection, options: JDBCOptions): Option[StructType] = {
  val dialect = JdbcDialects.get(options.url)

  // Runs the prepared schema query; the statement is closed whatever the outcome.
  def readSchema(stmt: PreparedStatement): Option[StructType] = {
    try {
      stmt.setQueryTimeout(options.queryTimeout)
      val resultSet = stmt.executeQuery()
      Some(getSchema(resultSet, dialect))
    } catch {
      case _: SQLException => None
    } finally {
      stmt.close()
    }
  }

  try {
    readSchema(conn.prepareStatement(dialect.getSchemaQuery(options.tableOrQuery)))
  } catch {
    // Raised by prepareStatement itself or by closing the statement.
    case _: SQLException => None
  }
}
/**
 * Derives the Catalyst schema from the metadata of a [[ResultSet]].
 *
 * For every column the dialect is asked for a Catalyst type first; the generic JDBC mapping
 * is used when the dialect has none.
 *
 * @param alwaysNullable If true, all the columns are nullable.
 * @return A [[StructType]] giving the Catalyst schema.
 * @throws SQLException if the schema contains an unsupported type.
 */
def getSchema(
    resultSet: ResultSet,
    dialect: JdbcDialect,
    alwaysNullable: Boolean = false): StructType = {
  val rsmd = resultSet.getMetaData
  val columnCount = rsmd.getColumnCount
  val fields = Array.tabulate(columnCount) { idx =>
    // JDBC column indices are 1-based.
    val col = idx + 1
    val columnName = rsmd.getColumnLabel(col)
    val dataType = rsmd.getColumnType(col)
    val typeName = rsmd.getColumnTypeName(col)
    val fieldSize = rsmd.getPrecision(col)
    val fieldScale = rsmd.getScale(col)
    val isSigned =
      try {
        rsmd.isSigned(col)
      } catch {
        // Workaround for HIVE-14684:
        case e: SQLException
          if e.getMessage == "Method not supported" &&
            rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
      }
    val nullable = alwaysNullable || rsmd.isNullable(col) != ResultSetMetaData.columnNoNulls
    val metadata = new MetadataBuilder().putLong("scale", fieldScale)
    val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata) match {
      case Some(fromDialect) => fromDialect
      case None => getCatalystType(dataType, fieldSize, fieldScale, isSigned)
    }
    StructField(columnName, columnType, nullable)
  }
  new StructType(fields)
}
/**
 * Exposes a [[ResultSet]] as an iterator of external [[Row]]s matching `schema`.
 *
 * Records read are counted on the input metrics of the current task when there is one, and
 * on a fresh [[InputMetrics]] instance otherwise.
 */
def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
  val inputMetrics = Option(TaskContext.get()) match {
    case Some(context) => context.taskMetrics().inputMetrics
    case None => new InputMetrics
  }
  val encoder = RowEncoder(schema).resolveAndBind()
  resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
    .map(internalRow => encoder.fromRow(internalRow))
}
/**
 * Exposes a [[ResultSet]] as an iterator of [[InternalRow]]s matching `schema`.
 *
 * The same mutable row instance is handed out for every record, so callers must copy it if
 * they need to retain a row across calls to `next()`. The result set is closed by the
 * iterator's `close()`.
 */
private[spark] def resultSetToSparkInternalRows(
    resultSet: ResultSet,
    schema: StructType,
    inputMetrics: InputMetrics): Iterator[InternalRow] = {
  new NextIterator[InternalRow] {
    private[this] val source = resultSet
    private[this] val columnReaders: Array[JDBCValueGetter] = makeGetters(schema)
    private[this] val reusedRow = new SpecificInternalRow(schema.fields.map(_.dataType))

    override protected def close(): Unit = {
      try {
        source.close()
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
    }

    override protected def getNext(): InternalRow = {
      if (!source.next()) {
        finished = true
        null.asInstanceOf[InternalRow]
      } else {
        inputMetrics.incRecordsRead(1)
        var col = 0
        while (col < columnReaders.length) {
          columnReaders(col).apply(source, reusedRow, col)
          // The getter has just read this column, so wasNull refers to it.
          if (source.wasNull) reusedRow.setNullAt(col)
          col += 1
        }
        reusedRow
      }
    }
  }
}
// A `JDBCValueGetter` copies one column value from a `ResultSet` into a field of an
// `InternalRow`. The last argument is the zero-based index of the field to set in the row;
// the getters built in this file read the `ResultSet` column at that index plus one.
private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit
/**
 * Builds one `JDBCValueGetter` per field of the given [[StructType]], in field order, so
 * that each `ResultSet` column can be copied into the matching [[InternalRow]] field.
 */
private def makeGetters(schema: StructType): Array[JDBCValueGetter] = {
  for (field <- schema.fields) yield makeGetter(field.dataType, field.metadata)
}
/**
 * Builds the `JDBCValueGetter` for a single column of Catalyst type `dt`.
 *
 * The returned function reads column `pos + 1` of the `ResultSet` and stores the converted
 * value in field `pos` of the row. Getters for object-valued types store `null` themselves
 * when the column is SQL NULL.
 *
 * @param dt       Catalyst type of the column
 * @param metadata column metadata; the key "binarylong" selects the byte-array decoding of
 *                 `LongType`
 * @throws IllegalArgumentException for unsupported types, nested arrays, or arrays of
 *                                  binary-encoded longs
 */
private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
  case BooleanType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setBoolean(pos, rs.getBoolean(pos + 1))

  case DateType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
      val dateVal = rs.getDate(pos + 1)
      if (dateVal != null) {
        row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
      } else {
        row.update(pos, null)
      }

  // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
  // object returned by ResultSet.getBigDecimal is not correctly matched to the table
  // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
  // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
  // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
  // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
  // retrieve it, you will get wrong result 199.99.
  // So it is needed to set precision and scale for Decimal based on JDBC metadata.
  case DecimalType.Fixed(p, s) =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val decimal =
        nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
      row.update(pos, decimal)

  case DoubleType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setDouble(pos, rs.getDouble(pos + 1))

  case FloatType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setFloat(pos, rs.getFloat(pos + 1))

  case IntegerType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setInt(pos, rs.getInt(pos + 1))

  case LongType if metadata.contains("binarylong") =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val bytes = rs.getBytes(pos + 1)
      if (bytes == null) {
        // getBytes yields null for SQL NULL; dereferencing it would throw a
        // NullPointerException, so mark the field as null instead.
        row.setNullAt(pos)
      } else {
        // Interpret the bytes as an unsigned big-endian number.
        var ans = 0L
        var j = 0
        while (j < bytes.length) {
          ans = 256 * ans + (255 & bytes(j))
          j = j + 1
        }
        row.setLong(pos, ans)
      }

  case LongType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setLong(pos, rs.getLong(pos + 1))

  case ShortType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setShort(pos, rs.getShort(pos + 1))

  case StringType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
      row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

  case TimestampType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val t = rs.getTimestamp(pos + 1)
      if (t != null) {
        row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
      } else {
        row.update(pos, null)
      }

  case BinaryType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.update(pos, rs.getBytes(pos + 1))

  case ArrayType(et, _) =>
    // Converts the Java array behind a java.sql.Array into Catalyst element values.
    val elementConversion = et match {
      case TimestampType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
            nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
          }

      case StringType =>
        (array: Object) =>
          // some underlying types are not String such as uuid, inet, cidr, etc.
          array.asInstanceOf[Array[java.lang.Object]]
            .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))

      case DateType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.sql.Date]].map { date =>
            nullSafeConvert(date, DateTimeUtils.fromJavaDate)
          }

      case dt: DecimalType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
            nullSafeConvert[java.math.BigDecimal](
              decimal,
              d => Decimal(d, dt.precision, dt.scale))
          }

      case LongType if metadata.contains("binarylong") =>
        throw new IllegalArgumentException(s"Unsupported array element " +
          s"type ${dt.catalogString} based on binary")

      case ArrayType(_, _) =>
        throw new IllegalArgumentException("Nested arrays unsupported")

      case _ => (array: Object) => array.asInstanceOf[Array[Any]]
    }

    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val array = nullSafeConvert[java.sql.Array](
        input = rs.getArray(pos + 1),
        array => new GenericArrayData(elementConversion.apply(array.getArray)))
      row.update(pos, array)

  case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.catalogString}")
}
/**
 * Applies `f` to `input` unless `input` is null, in which case null is returned and `f` is
 * not called.
 */
private def nullSafeConvert[T](input: T, f: T => Any): Any = {
  Option(input).fold[Any](null)(f)
}
// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
// `PreparedStatement`. The last argument `Int` means the index for the value to be set
// in the SQL statement and also used for the value in `Row`.
// Written with the ASCII arrow `=>` like the rest of this file; the Unicode arrow is
// deprecated as of Scala 2.13.
private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
/**
 * Builds the `JDBCValueSetter` for a column of the given Catalyst type.
 *
 * The returned function binds field `pos` of the row to statement parameter `pos + 1`.
 * For unsupported types a setter is still returned; it throws `IllegalArgumentException`
 * when it is applied.
 *
 * @param conn     connection used to create SQL arrays for `ArrayType` columns
 * @param dialect  dialect used to resolve the element type name of arrays
 * @param dataType Catalyst type of the column
 */
private def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType =>
    (ps: PreparedStatement, record: Row, idx: Int) => ps.setInt(idx + 1, record.getInt(idx))

  case LongType =>
    (ps: PreparedStatement, record: Row, idx: Int) => ps.setLong(idx + 1, record.getLong(idx))

  case DoubleType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setDouble(idx + 1, record.getDouble(idx))

  case FloatType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setFloat(idx + 1, record.getFloat(idx))

  // Short and byte values are bound as ints.
  case ShortType =>
    (ps: PreparedStatement, record: Row, idx: Int) => ps.setInt(idx + 1, record.getShort(idx))

  case ByteType =>
    (ps: PreparedStatement, record: Row, idx: Int) => ps.setInt(idx + 1, record.getByte(idx))

  case BooleanType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setBoolean(idx + 1, record.getBoolean(idx))

  case StringType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setString(idx + 1, record.getString(idx))

  case BinaryType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setBytes(idx + 1, record.getAs[Array[Byte]](idx))

  case TimestampType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setTimestamp(idx + 1, record.getAs[java.sql.Timestamp](idx))

  case DateType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setDate(idx + 1, record.getAs[java.sql.Date](idx))

  case _: DecimalType =>
    (ps: PreparedStatement, record: Row, idx: Int) =>
      ps.setBigDecimal(idx + 1, record.getDecimal(idx))

  case ArrayType(elementType, _) =>
    // remove type length parameters from end of type name
    val elementTypeName = getJdbcType(elementType, dialect).databaseTypeDefinition
      .toLowerCase(Locale.ROOT)
      .split("\\(")(0)
    (ps: PreparedStatement, record: Row, idx: Int) => {
      val elements: Array[AnyRef] = record.getSeq[AnyRef](idx).toArray
      val sqlArray = conn.createArrayOf(elementTypeName, elements)
      ps.setArray(idx + 1, sqlArray)
    }

  case _ =>
    (_: PreparedStatement, _: Row, idx: Int) =>
      throw new IllegalArgumentException(
        s"Can't translate non-null value for field $idx")
}
/**
 * Saves a partition of a DataFrame to the JDBC database. This is done in
 * a single database transaction (unless isolation level is "NONE")
 * in order to avoid repeatedly inserting data as much as possible.
 *
 * It is still theoretically possible for rows in a DataFrame to be
 * inserted into the database more than once if a stage somehow fails after
 * the commit occurs but before the stage can return successfully.
 *
 * This is not a closure inside saveTable() because apparently cosmetic
 * implementation changes elsewhere might easily render such a closure
 * non-Serializable. Instead, we explicitly close over all variables that
 * are used.
 *
 * @param getConnection  factory for the connection used by this partition
 * @param table          target table name (not used by the implementation itself)
 * @param iterator       rows of the partition
 * @param rddSchema      schema of the rows
 * @param insertStmt     parameterised statement with one placeholder per schema field
 * @param batchSize      number of rows sent per `executeBatch` call
 * @param dialect        dialect used to resolve JDBC types
 * @param isolationLevel requested transaction isolation level
 * @param options        JDBC options, read for the query timeout
 * @return an empty iterator
 */
def savePartition(
    getConnection: () => Connection,
    table: String,
    iterator: Iterator[Row],
    rddSchema: StructType,
    insertStmt: String,
    batchSize: Int,
    dialect: JdbcDialect,
    isolationLevel: Int,
    options: JDBCOptions): Iterator[Byte] = {
  val conn = getConnection()
  var committed = false

  var finalIsolationLevel = Connection.TRANSACTION_NONE
  if (isolationLevel != Connection.TRANSACTION_NONE) {
    try {
      val metadata = conn.getMetaData
      if (metadata.supportsTransactions()) {
        // Update to at least use the default isolation, if any transaction level
        // has been chosen and transactions are supported
        val defaultIsolation = metadata.getDefaultTransactionIsolation
        finalIsolationLevel = defaultIsolation
        if (metadata.supportsTransactionIsolationLevel(isolationLevel)) {
          // Finally update to actually requested level if possible
          finalIsolationLevel = isolationLevel
        } else {
          logWarning(s"Requested isolation level $isolationLevel is not supported; " +
            s"falling back to default isolation level $defaultIsolation")
        }
      } else {
        logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
      }
    } catch {
      case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
    }
  }
  val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

  try {
    if (supportsTransactions) {
      conn.setAutoCommit(false) // Everything in the same db transaction.
      conn.setTransactionIsolation(finalIsolationLevel)
    }
    val stmt = conn.prepareStatement(insertStmt)
    val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
    val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val numFields = rddSchema.fields.length

    try {
      // Number of rows added to the current, not yet executed batch.
      var rowCount = 0

      stmt.setQueryTimeout(options.queryTimeout)

      while (iterator.hasNext) {
        val row = iterator.next()
        var i = 0
        while (i < numFields) {
          if (row.isNullAt(i)) {
            stmt.setNull(i + 1, nullTypes(i))
          } else {
            setters(i).apply(stmt, row, i)
          }
          i = i + 1
        }
        stmt.addBatch()
        rowCount += 1
        if (rowCount % batchSize == 0) {
          stmt.executeBatch()
          rowCount = 0
        }
      }
      // Flush the trailing, partially filled batch.
      if (rowCount > 0) {
        stmt.executeBatch()
      }
    } finally {
      stmt.close()
    }
    if (supportsTransactions) {
      conn.commit()
    }
    committed = true
    Iterator.empty
  } catch {
    case e: SQLException =>
      val cause = e.getNextException
      if (cause != null && e.getCause != cause) {
        // If there is no cause already, set 'next exception' as cause. If cause is null,
        // it *may* be because no cause was set yet
        if (e.getCause == null) {
          try {
            e.initCause(cause)
          } catch {
            // Or it may be null because the cause *was* explicitly initialized, to *null*,
            // in which case this fails. There is no other way to detect it.
            // addSuppressed in this case as well.
            case _: IllegalStateException => e.addSuppressed(cause)
          }
        } else {
          e.addSuppressed(cause)
        }
      }
      throw e
  } finally {
    if (!committed) {
      // The stage must fail. We got here through an exception path, so
      // let the exception through unless rollback() or close() want to
      // tell the user about another problem.
      try {
        if (supportsTransactions) {
          conn.rollback()
        }
      } finally {
        // Close even when rollback() throws, otherwise the connection would leak.
        conn.close()
      }
    } else {
      // The stage must succeed. We cannot propagate any exception close() might throw.
      try {
        conn.close()
      } catch {
        case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
      }
    }
  }
}
/**
 * Renders the column definitions of `df` for use inside a `CREATE TABLE` statement.
 *
 * Each field becomes `<quoted name> <type> <NOT NULL or empty>`; definitions are joined with
 * ", ". A type given for the column in `createTableColumnTypes` takes precedence over the
 * type resolved from the dialect.
 *
 * @param df                     DataFrame whose schema is rendered
 * @param url                    JDBC url used to pick the dialect
 * @param createTableColumnTypes optional user-specified column types, in DDL format
 */
def schemaString(
    df: DataFrame,
    url: String,
    createTableColumnTypes: Option[String] = None): String = {
  val dialect = JdbcDialects.get(url)
  val userSpecifiedColTypesMap = createTableColumnTypes match {
    case Some(columnTypes) => parseUserSpecifiedCreateTableColumnTypes(df, columnTypes)
    case None => Map.empty[String, String]
  }
  val columnDefinitions = df.schema.fields.map { field =>
    val name = dialect.quoteIdentifier(field.name)
    val typ = userSpecifiedColTypesMap
      .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition)
    val nullable = if (field.nullable) "" else "NOT NULL"
    s"$name $typ $nullable"
  }
  columnDefinitions.mkString(", ")
}
/**
 * Parses the `createTableColumnTypes` option value, which uses the same format as the column
 * list of a CREATE TABLE statement, into a map from column name to the type string to use
 * instead of the default one.
 *
 * The returned map is case-insensitive unless the session uses case-sensitive analysis.
 *
 * @throws AnalysisException if a column is listed twice or is absent from the DataFrame schema
 */
private def parseUserSpecifiedCreateTableColumnTypes(
    df: DataFrame,
    createTableColumnTypes: String): Map[String, String] = {
  // char/varchar gets translated to string type. Real data type specified by the user
  // is available in the field metadata as HIVE_TYPE_STRING
  def typeName(field: StructField): String = {
    if (!field.metadata.contains(HIVE_TYPE_STRING)) {
      field.dataType.catalogString
    } else {
      field.metadata.getString(HIVE_TYPE_STRING)
    }
  }

  val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
  val nameEquality = df.sparkSession.sessionState.conf.resolver

  // Reject column names that appear more than once in the option value.
  SchemaUtils.checkColumnNameDuplication(
    userSchema.map(_.name),
    "in the createTableColumnTypes option value",
    nameEquality)

  // Every user-specified column must exist in the DataFrame schema.
  for (col <- userSchema.fieldNames) {
    if (!df.schema.exists(f => nameEquality(f.name, col))) {
      throw new AnalysisException(
        s"createTableColumnTypes option column $col not found in schema " +
          df.schema.catalogString)
    }
  }

  val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap
  val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis
  if (!isCaseSensitive) CaseInsensitiveMap(userSchemaMap) else userSchemaMap
}
/**
 * Applies the user-specified `customSchema` option value to a table schema.
 *
 * Columns of `tableSchema` whose name matches a custom column (according to `nameEquality`)
 * get the custom data type; all other columns are kept as they are. A null or empty
 * `customSchema` returns `tableSchema` unchanged.
 */
def getCustomSchema(
    tableSchema: StructType,
    customSchema: String,
    nameEquality: Resolver): StructType = {
  if (customSchema == null || customSchema.isEmpty) {
    tableSchema
  } else {
    val userSchema = CatalystSqlParser.parseTableSchema(customSchema)

    SchemaUtils.checkColumnNameDuplication(
      userSchema.map(_.name),
      "in the customSchema option value",
      nameEquality)

    // Columns are resolved by name; a matching custom field only overrides the data type.
    val overridden = tableSchema.map { col =>
      userSchema
        .find(custom => nameEquality(custom.name, col.name))
        .fold(col)(custom => col.copy(dataType = custom.dataType))
    }
    StructType(overridden)
  }
}
/**
 * Writes `df` to the table named in `options`, one `savePartition` call per partition.
 *
 * SeaTunnel: a non-blank `customUpdateStmt` option is used as the insert statement;
 * otherwise the statement is generated from the schemas and the save mode. When
 * `numPartitions` is set below the DataFrame's partition count, the DataFrame is coalesced
 * to that many partitions first.
 *
 * @throws IllegalArgumentException if `numPartitions` is set to a value below 1
 */
def saveTable(
    df: DataFrame,
    tableSchema: Option[StructType],
    isCaseSensitive: Boolean,
    options: JdbcOptionsInWrite,
    saveMode: JDBCSaveMode): Unit = {
  val url = options.url
  val table = options.table
  val dialect = JdbcDialects.get(url)
  val rddSchema = df.schema
  val getConnection: () => Connection = createConnectionFactory(options)
  val batchSize = options.batchSize
  val isolationLevel = options.isolationLevel

  // SeaTunnel: prefer the customUpdateStmt parameter when present, generate the default
  // statement otherwise.
  val insertStmt = options.customUpdateStmt
    .filter(stmt => StringUtils.isNotBlank(stmt))
    .getOrElse(
      getInsertStatement(
        table, rddSchema, tableSchema, isCaseSensitive, dialect, saveMode, options))

  val repartitionedDF = options.numPartitions.fold(df) { n =>
    if (n <= 0) {
      throw new IllegalArgumentException(
        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
          "via JDBC. The minimum value is 1.")
    } else if (n < df.rdd.getNumPartitions) {
      df.coalesce(n)
    } else {
      df
    }
  }

  repartitionedDF.rdd.foreachPartition(rows =>
    savePartition(
      getConnection, table, rows, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
      options))
}
/**
 * Creates the table named in `options` with columns derived from the schema of `df`.
 *
 * The `createTableOptions` option is appended after the column list, which allows
 * table_options or partition_options to be passed through, e.g.
 * "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8".
 */
def createTable(
    conn: Connection,
    df: DataFrame,
    options: JdbcOptionsInWrite): Unit = {
  val columnDefinitions = schemaString(df, options.url, options.createTableColumnTypes)
  val ddl =
    s"CREATE TABLE ${options.table} ($columnDefinitions) ${options.createTableOptions}"
  val stmt = conn.createStatement
  try {
    stmt.setQueryTimeout(options.queryTimeout)
    stmt.executeUpdate(ddl)
  } finally {
    stmt.close()
  }
}
}