blob: eac483ae8df4844603114d94d9e91d8a2ab29019 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources
import java.sql.{Connection, PreparedStatement, ResultSet}
import java.util.Locale
import org.apache.spark.executor.InputMetrics
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.NextIterator
object SparkJdbcUtil {
/**
 * Converts an [[InternalRow]] into an external [[Row]] for the given schema.
 *
 * A `RowEncoder` is created, resolved and bound on every invocation.
 *
 * @param schema      schema describing the fields of `internalRow`
 * @param internalRow the row to convert
 * @return the external [[Row]] produced by the encoder
 */
def toRow(schema: StructType, internalRow: InternalRow) : Row =
  RowEncoder(schema).resolveAndBind().fromRow(internalRow)
// A `JDBCValueGetter` reads one column value from a `ResultSet` and stores it into a field
// of an `InternalRow`. The trailing `Int` is the zero-based index of the field to set in
// the row; the getters built in this file read the `ResultSet` column at that index plus
// one.
private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit
/**
 * Applies `f` to `input` unless `input` is null, in which case null is returned and `f`
 * is never invoked.
 *
 * @param input the possibly-null value to convert
 * @param f     conversion applied to a non-null `input`
 * @tparam T    type of the value being converted
 * @return `f(input)`, or null when `input` is null
 */
private def nullSafeConvert[T](input: T, f: T => Any): Any =
  if (input != null) f(input) else null
/**
 * Builds one `JDBCValueGetter` per field of the given [[StructType]], in field order, by
 * delegating to `makeGetter` with each field's data type and metadata.
 */
private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
  for (field <- schema.fields) yield makeGetter(field.dataType, field.metadata)
/**
 * Builds the `JDBCValueGetter` for a single column of the given Catalyst type.
 *
 * Each returned getter reads the `ResultSet` column at `pos + 1` and writes the converted
 * value into field `pos` of the row. Getters that go through a reference-typed JDBC
 * accessor (date, timestamp, decimal, array, binary-encoded long) store null in the row
 * when the driver returns null; getters that use primitive accessors (`getInt`,
 * `getLong`, ...) do not check for null themselves.
 *
 * @param dt       Catalyst data type of the column
 * @param metadata field metadata; the presence of the key "binarylong" selects the
 *                 byte-array decoding for `LongType`
 * @return a getter for the column
 * @throws IllegalArgumentException for unsupported types, nested arrays, and arrays whose
 *                                  element type is `LongType` with "binarylong" metadata
 */
private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
  case BooleanType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setBoolean(pos, rs.getBoolean(pos + 1))

  case DateType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
      val dateVal = rs.getDate(pos + 1)
      if (dateVal != null) {
        row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
      } else {
        row.update(pos, null)
      }

  // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
  // object returned by ResultSet.getBigDecimal is not correctly matched to the table
  // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
  // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
  // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
  // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
  // retrieve it, you will get wrong result 199.99.
  // So it is needed to set precision and scale for Decimal based on JDBC metadata.
  case DecimalType.Fixed(p, s) =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val decimal =
        nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
      row.update(pos, decimal)

  case DoubleType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setDouble(pos, rs.getDouble(pos + 1))

  case FloatType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setFloat(pos, rs.getFloat(pos + 1))

  case IntegerType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setInt(pos, rs.getInt(pos + 1))

  case LongType if metadata.contains("binarylong") =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val bytes = rs.getBytes(pos + 1)
      if (bytes == null) {
        // getBytes returns null for SQL NULL; store null instead of dereferencing it.
        row.update(pos, null)
      } else {
        // Fold the bytes into a long, most significant byte first, treating each byte
        // as unsigned.
        var ans = 0L
        var j = 0
        while (j < bytes.length) {
          ans = 256 * ans + (255 & bytes(j))
          j = j + 1
        }
        row.setLong(pos, ans)
      }

  case LongType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setLong(pos, rs.getLong(pos + 1))

  case ShortType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.setShort(pos, rs.getShort(pos + 1))

  case StringType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
      row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

  case TimestampType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val t = rs.getTimestamp(pos + 1)
      if (t != null) {
        row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
      } else {
        row.update(pos, null)
      }

  case BinaryType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.update(pos, rs.getBytes(pos + 1))

  case ByteType =>
    (rs: ResultSet, row: InternalRow, pos: Int) =>
      row.update(pos, rs.getByte(pos + 1))

  case ArrayType(et, _) =>
    // Chosen once per column: converts the Java array obtained from the JDBC array into
    // an array of Catalyst values for the element type.
    val elementConversion = et match {
      case TimestampType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
            nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
          }

      case StringType =>
        (array: Object) =>
          // some underlying types are not String such as uuid, inet, cidr, etc.
          array.asInstanceOf[Array[java.lang.Object]]
            .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))

      case DateType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.sql.Date]].map { date =>
            nullSafeConvert(date, DateTimeUtils.fromJavaDate)
          }

      case dt: DecimalType =>
        (array: Object) =>
          array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
            nullSafeConvert[java.math.BigDecimal](
              decimal, d => Decimal(d, dt.precision, dt.scale))
          }

      case LongType if metadata.contains("binarylong") =>
        throw new IllegalArgumentException(s"Unsupported array element " +
          s"type ${dt.catalogString} based on binary")

      case ArrayType(_, _) =>
        throw new IllegalArgumentException("Nested arrays unsupported")

      case _ => (array: Object) => array.asInstanceOf[Array[Any]]
    }

    (rs: ResultSet, row: InternalRow, pos: Int) =>
      val array = nullSafeConvert[java.sql.Array](
        input = rs.getArray(pos + 1),
        array => new GenericArrayData(elementConversion.apply(array.getArray)))
      row.update(pos, array)

  case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.catalogString}")
}
// TODO just use JdbcUtils.resultSetToSparkInternalRows in Spark 3.0 (see SPARK-26499)
/**
 * Exposes a JDBC `ResultSet` as an iterator of [[InternalRow]]s.
 *
 * A single mutable row is allocated up front and returned for every record, so the
 * contents of a returned row are overwritten by the next call. After each getter runs,
 * `ResultSet.wasNull` is consulted and the field is set to null when it reports true.
 * Every record read increments `inputMetrics` by one. Closing the iterator closes the
 * result set; any `Exception` thrown while closing is ignored.
 *
 * @param resultSet    the JDBC result set to read from
 * @param schema       schema describing the columns of `resultSet`
 * @param inputMetrics metrics object whose records-read counter is incremented per row
 * @return an iterator over the rows of `resultSet`
 */
def resultSetToSparkInternalRows(
    resultSet: ResultSet,
    schema: StructType,
    inputMetrics: InputMetrics): Iterator[InternalRow] = {
  val columnGetters: Array[JDBCValueGetter] = makeGetters(schema)
  val reusedRow = new SpecificInternalRow(schema.fields.map(_.dataType))

  new NextIterator[InternalRow] {
    // Best-effort close: a failure to close the result set is deliberately ignored.
    override protected def close(): Unit =
      try {
        resultSet.close()
      } catch {
        case _: Exception =>
      }

    override protected def getNext(): InternalRow =
      if (!resultSet.next()) {
        finished = true
        null.asInstanceOf[InternalRow]
      } else {
        inputMetrics.incRecordsRead(1)
        var col = 0
        while (col < columnGetters.length) {
          columnGetters(col)(resultSet, reusedRow, col)
          // Primitive JDBC accessors cannot return null, so null is detected afterwards.
          if (resultSet.wasNull) reusedRow.setNullAt(col)
          col += 1
        }
        reusedRow
      }
  }
}
// A `JDBCValueSetter` copies one field of a `Row` into a parameter of a
// `PreparedStatement`. The trailing `Int` is the zero-based index of the field to read
// from the `Row`; the setters built in this file bind the statement parameter at that
// index plus one.
private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
// Taken from Spark's JdbcUtils.scala; it cannot be used directly because the method is
// private there.
/**
 * Builds the `JDBCValueSetter` that binds a `Row` field of the given type to a
 * `PreparedStatement` parameter.
 *
 * Every returned setter reads field `pos` of the row and binds statement parameter
 * `pos + 1`. Short and byte fields are bound with `setInt`. For array types the JDBC
 * element type name is resolved once, when the setter is built, via `getJdbcType`.
 * For any other unsupported type a setter is still returned, but invoking it throws
 * `IllegalArgumentException`.
 *
 * @param conn     connection used to create JDBC arrays for array-typed fields
 * @param dialect  dialect consulted when resolving the JDBC element type of arrays
 * @param dataType Catalyst type of the field to bind
 * @return a setter for the field
 */
def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = {
  dataType match {
    case IntegerType =>
      (stmt, row, pos) => stmt.setInt(pos + 1, row.getInt(pos))
    case LongType =>
      (stmt, row, pos) => stmt.setLong(pos + 1, row.getLong(pos))
    case DoubleType =>
      (stmt, row, pos) => stmt.setDouble(pos + 1, row.getDouble(pos))
    case FloatType =>
      (stmt, row, pos) => stmt.setFloat(pos + 1, row.getFloat(pos))
    case ShortType =>
      (stmt, row, pos) => stmt.setInt(pos + 1, row.getShort(pos))
    case ByteType =>
      (stmt, row, pos) => stmt.setInt(pos + 1, row.getByte(pos))
    case BooleanType =>
      (stmt, row, pos) => stmt.setBoolean(pos + 1, row.getBoolean(pos))
    case StringType =>
      (stmt, row, pos) => stmt.setString(pos + 1, row.getString(pos))
    case BinaryType =>
      (stmt, row, pos) => stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
    case TimestampType =>
      (stmt, row, pos) => stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
    case DateType =>
      (stmt, row, pos) => stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
    case _: DecimalType =>
      (stmt, row, pos) => stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
    case ArrayType(elementType, _) =>
      // remove type length parameters from end of type name
      val elementTypeName = getJdbcType(elementType, dialect)
        .databaseTypeDefinition
        .toLowerCase(Locale.ROOT)
        .split("\\(")(0)
      (stmt, row, pos) => {
        val elements: Array[AnyRef] = row.getSeq[AnyRef](pos).toArray
        stmt.setArray(pos + 1, conn.createArrayOf(elementTypeName, elements))
      }
    case _ =>
      (_, _, pos) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }
}
// taken from Spark JdbcUtils
/**
 * Resolves the [[JdbcType]] for a Catalyst data type.
 *
 * The dialect is asked first; `getCommonJDBCType` is consulted only when the dialect
 * returns no mapping.
 *
 * @param dt      Catalyst data type to map
 * @param dialect JDBC dialect consulted first
 * @return the resolved JDBC type
 * @throws IllegalArgumentException when neither lookup yields a mapping
 */
def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
  val resolved = dialect.getJDBCType(dt) orElse getCommonJDBCType(dt)
  resolved match {
    case Some(jdbcType) => jdbcType
    case None =>
      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}")
  }
}
}