| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.spark.sql.execution.datasources |
| |
| import java.sql.{Connection, PreparedStatement, ResultSet} |
| import java.util.Locale |
| |
| import org.apache.spark.executor.InputMetrics |
| import org.apache.spark.sql.Row |
| import org.apache.spark.sql.catalyst.InternalRow |
| import org.apache.spark.sql.catalyst.encoders.RowEncoder |
| import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow |
| import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} |
| import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils |
| import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._ |
| import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType} |
| import org.apache.spark.sql.types._ |
| import org.apache.spark.unsafe.types.UTF8String |
| import org.apache.spark.util.NextIterator |
| |
| object SparkJdbcUtil { |
| |
  def toRow(schema: StructType, internalRow: InternalRow): Row = {
    // Note: building and binding an encoder is not cheap; cache the encoder if
    // converting many rows that share the same schema.
    val encoder = RowEncoder(schema).resolveAndBind()
    encoder.fromRow(internalRow)
  }
| |
  // A `JDBCValueGetter` reads a single value from a `ResultSet` and writes it into a
  // field of a mutable `InternalRow`. The final `Int` argument is the zero-based field
  // index in the row; the same index shifted by one addresses the (one-based)
  // `ResultSet` column.
| private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit |
| |
| private def nullSafeConvert[T](input: T, f: T => Any): Any = { |
| if (input == null) { |
| null |
| } else { |
| f(input) |
| } |
| } |
| |
| /** |
| * Creates `JDBCValueGetter`s according to [[StructType]], which can set |
| * each value from `ResultSet` to each field of [[InternalRow]] correctly. |
| */ |
| private def makeGetters(schema: StructType): Array[JDBCValueGetter] = |
| schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata)) |
| |
| private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match { |
| case BooleanType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.setBoolean(pos, rs.getBoolean(pos + 1)) |
| |
| case DateType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. |
| val dateVal = rs.getDate(pos + 1) |
| if (dateVal != null) { |
| row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal)) |
| } else { |
| row.update(pos, null) |
| } |
| |
    // When connecting to Oracle through JDBC, the precision and scale of the BigDecimal
    // returned by ResultSet.getBigDecimal do not always match the table schema reported
    // by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. For example,
    // inserting 19999 into a NUMBER(12, 2) column yields a BigDecimal with scale 0,
    // while the DataFrame schema correctly reports DecimalType(12, 2). Saving that
    // DataFrame to Parquet and reading it back would then produce the wrong value
    // 199.99. So the precision and scale must be set on the Decimal from the JDBC
    // metadata.
| case DecimalType.Fixed(p, s) => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| val decimal = |
| nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s)) |
| row.update(pos, decimal) |
| |
| case DoubleType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.setDouble(pos, rs.getDouble(pos + 1)) |
| |
| case FloatType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.setFloat(pos, rs.getFloat(pos + 1)) |
| |
| case IntegerType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.setInt(pos, rs.getInt(pos + 1)) |
| |
    case LongType if metadata.contains("binarylong") =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        // Interpret the raw bytes as an unsigned big-endian integer, accumulating one
        // byte at a time: e.g. Array(0x01, 0x02) yields 1 * 256 + 2 = 258.
        val bytes = rs.getBytes(pos + 1)
        var ans = 0L
        var j = 0
        while (j < bytes.length) {
          ans = 256 * ans + (255 & bytes(j))
          j = j + 1
        }
        row.setLong(pos, ans)
| |
| case LongType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.setLong(pos, rs.getLong(pos + 1)) |
| |
| case ShortType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.setShort(pos, rs.getShort(pos + 1)) |
| |
| case StringType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 |
| row.update(pos, UTF8String.fromString(rs.getString(pos + 1))) |
| |
| case TimestampType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| val t = rs.getTimestamp(pos + 1) |
| if (t != null) { |
| row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t)) |
| } else { |
| row.update(pos, null) |
| } |
| |
| case BinaryType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.update(pos, rs.getBytes(pos + 1)) |
| |
| case ByteType => |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| row.update(pos, rs.getByte(pos + 1)) |
| |
| case ArrayType(et, _) => |
| val elementConversion = et match { |
| case TimestampType => |
| (array: Object) => |
| array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => |
| nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) |
| } |
| |
| case StringType => |
| (array: Object) => |
            // some underlying types are not String, e.g. uuid, inet, cidr
| array.asInstanceOf[Array[java.lang.Object]] |
| .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString)) |
| |
| case DateType => |
| (array: Object) => |
| array.asInstanceOf[Array[java.sql.Date]].map { date => |
| nullSafeConvert(date, DateTimeUtils.fromJavaDate) |
| } |
| |
| case dt: DecimalType => |
| (array: Object) => |
| array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => |
| nullSafeConvert[java.math.BigDecimal]( |
| decimal, d => Decimal(d, dt.precision, dt.scale)) |
| } |
| |
        case LongType if metadata.contains("binarylong") =>
          throw new IllegalArgumentException(
            s"Unsupported array element type ${dt.catalogString} based on binary")
| |
| case ArrayType(_, _) => |
| throw new IllegalArgumentException("Nested arrays unsupported") |
| |
| case _ => (array: Object) => array.asInstanceOf[Array[Any]] |
| } |
| |
| (rs: ResultSet, row: InternalRow, pos: Int) => |
| val array = nullSafeConvert[java.sql.Array]( |
| input = rs.getArray(pos + 1), |
| array => new GenericArrayData(elementConversion.apply(array.getArray))) |
| row.update(pos, array) |
| |
| case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.catalogString}") |
| } |
| |
| // TODO just use JdbcUtils.resultSetToSparkInternalRows in Spark 3.0 (see SPARK-26499) |
| def resultSetToSparkInternalRows( |
| resultSet: ResultSet, |
| schema: StructType, |
| inputMetrics: InputMetrics): Iterator[InternalRow] = { |
| new NextIterator[InternalRow] { |
| private[this] val rs = resultSet |
| private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema) |
| private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) |
| |
      override protected def close(): Unit = {
        try {
          rs.close()
        } catch {
          case _: Exception => // best effort: ignore failures while closing the ResultSet
        }
      }
| |
| override protected def getNext(): InternalRow = { |
| if (rs.next()) { |
| inputMetrics.incRecordsRead(1) |
| var i = 0 |
| while (i < getters.length) { |
| getters(i).apply(rs, mutableRow, i) |
| if (rs.wasNull) mutableRow.setNullAt(i) |
| i = i + 1 |
| } |
| mutableRow |
| } else { |
| finished = true |
| null.asInstanceOf[InternalRow] |
| } |
| } |
| } |
| } |
| |
  // A `JDBCValueSetter` copies a single value from a `Row` into a `PreparedStatement`
  // parameter. The final `Int` argument is the zero-based field index in the row; the
  // same index shifted by one addresses the (one-based) statement parameter.
| private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit |
| |
  // Taken from Spark's JdbcUtils.scala; it cannot be reused directly because the method
  // there is private.
| def makeSetter( |
| conn: Connection, |
| dialect: JdbcDialect, |
| dataType: DataType): JDBCValueSetter = dataType match { |
| case IntegerType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setInt(pos + 1, row.getInt(pos)) |
| |
| case LongType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setLong(pos + 1, row.getLong(pos)) |
| |
| case DoubleType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setDouble(pos + 1, row.getDouble(pos)) |
| |
| case FloatType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setFloat(pos + 1, row.getFloat(pos)) |
| |
| case ShortType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setInt(pos + 1, row.getShort(pos)) |
| |
| case ByteType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setInt(pos + 1, row.getByte(pos)) |
| |
| case BooleanType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setBoolean(pos + 1, row.getBoolean(pos)) |
| |
| case StringType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setString(pos + 1, row.getString(pos)) |
| |
| case BinaryType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos)) |
| |
| case TimestampType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos)) |
| |
| case DateType => |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos)) |
| |
    case _: DecimalType =>
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| stmt.setBigDecimal(pos + 1, row.getDecimal(pos)) |
| |
| case ArrayType(et, _) => |
| // remove type length parameters from end of type name |
| val typeName = getJdbcType(et, dialect).databaseTypeDefinition |
| .toLowerCase(Locale.ROOT).split("\\(")(0) |
| (stmt: PreparedStatement, row: Row, pos: Int) => |
| val array = conn.createArrayOf( |
| typeName, |
| row.getSeq[AnyRef](pos).toArray) |
| stmt.setArray(pos + 1, array) |
| |
| case _ => |
| (_: PreparedStatement, _: Row, pos: Int) => |
| throw new IllegalArgumentException( |
| s"Can't translate non-null value for field $pos") |
| } |
| |
| // taken from Spark JdbcUtils |
| def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { |
| dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( |
| throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}")) |
| } |
| |
| } |