/*
* Copyright 2019 WeBank
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.webank.wedatasphere.linkis.engine.imexport
import java.io.{BufferedOutputStream, FileOutputStream}
import java.math.BigDecimal
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale
import com.webank.wedatasphere.linkis.common.utils.{Logging, Utils}
import com.webank.wedatasphere.linkis.engine.imexport.util.ImExportUtils
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{IOUtils, LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
import scala.util.control.Exception._
/**
* Created by allenlliu on 7/12/18.
*/
class CsvRelation(@transient private val source: Map[String, Any]) extends Serializable {
import LoadData._
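// Import/export options read from the `source` map; each falls back to the default shown when the key is absent.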
val fieldDelimiter = getMapValue[String](source, "fieldDelimiter", "\t")
val encoding = getMapValue[String](source, "encoding", "utf-8")
//val hasHeader = getNodeValue[Boolean](source, "hasHeader", false)
val nullValue = getMapValue[String](source, "nullValue", " ")
val nanValue = getMapValue[String](source, "nanValue", "null")
val quote = getMapValue[String](source, "quote", "\"")
val escape = getMapValue[String](source, "escape", "\\")
val escapeQuotes = getMapValue[Boolean](source, "escapeQuotes", false)
val dateFormat: SimpleDateFormat = new SimpleDateFormat(getMapValue[String](source, "dateFormat", "yyyy-MM-dd"), Locale.US)
val timestampFormat: SimpleDateFormat = new SimpleDateFormat(getMapValue[String](source, "timestampFormat", "yyyy-MM-dd HH:mm:ss"), Locale.US)
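/**
 * Reads the file at `path` through Hadoop's TextInputFormat (with a minimum of one partition)
 * and decodes each line with the given character encoding; used for files that are not UTF-8.
 */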
def transfer(sc: SparkContext, path: String, encoding: String): RDD[String] = {
sc.hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], 1)
.map(p => new String(p._2.getBytes, 0, p._2.getLength, encoding))
}
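/**
 * Loads a delimited text file into a DataFrame: reads the file (re-decoding it when the encoding
 * is not UTF-8), optionally drops the header line from the first partition, splits each line on
 * `fieldDelimiter`, and casts every token to the corresponding field type of `schema`.
 */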
def csvToDataFrame(spark: SparkSession, schema: StructType, hasHeader: Boolean, path: String, columns: List[Map[String, Any]]): DataFrame = {
val rdd = if ("utf-8".equalsIgnoreCase(encoding)) {
spark.sparkContext.textFile(path)
} else {
transfer(spark.sparkContext, path, encoding)
}
// handle the header row if present
val tokenRdd = if (hasHeader) {
rdd.mapPartitionsWithIndex((index, iter) => if (index == 0) iter.drop(1) else iter).map(_.split(fieldDelimiter))
} else {
rdd.map(_.split(fieldDelimiter))
}
val rowRdd = buildScan(tokenRdd, schema, columns)
spark.createDataFrame(rowRdd, schema)
}
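/**
 * Converts tokenized lines into Rows. Missing tokens are replaced with the string "null";
 * values matching `nanValue` (case-insensitive), or empty/`nullValue` values in non-string
 * columns, become SQL NULL; everything else is cast to the field's declared type, using the
 * per-column "dateFormat" taken from `columns` for date parsing.
 */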
def buildScan(tokenRdd: RDD[Array[String]], schema: StructType, columns: List[Map[String, Any]]): RDD[Row] = {
tokenRdd.map(att => {
val row = ArrayBuffer[Any]()
for (i <- schema.indices) {
val field = if ((allCatch opt att(i)).isDefined) {
att(i)
} else {
"null"
}
val data = if (nanValue.equalsIgnoreCase(field)) {
null
} else if (schema(i).dataType != StringType && (field.isEmpty || nullValue.equals(field))) {
null
} else {
val dateFormat = columns(i).getOrElse("dateFormat", "yyyy-MM-dd").toString
val format: SimpleDateFormat = new SimpleDateFormat(dateFormat, Locale.US)
castTo(field, schema(i).dataType,format,null)
}
row += data
}
Row.fromSeq(row)
})
}
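/**
 * Casts a single string token to the target Spark SQL type. Date values use the per-column
 * `dateFormatP`; timestamp values always use the class-level `timestampFormat`, so the
 * `timestampFormatP` parameter is currently unused.
 */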
def castTo(field: String, dataType: DataType, dateFormatP: SimpleDateFormat, timestampFormatP: SimpleDateFormat): Any = {
val value = if (escapeQuotes) field.substring(1, field.length - 1) else field
dataType match {
case _: ByteType => value.toByte
case _: ShortType => value.toShort
case _: IntegerType => value.toInt
case _: LongType => value.toLong
case _: FloatType => value.toFloat
case _: DoubleType => value.toDouble
case _: BooleanType => value.toBoolean
case dt: DecimalType => val datum = new BigDecimal(value.replaceAll(",", ""))
Decimal(datum, dt.precision, dt.scale)
case _: TimestampType => new Timestamp(Try(timestampFormat.parse(value).getTime).getOrElse(DateTimeUtils.stringToTime(value).getTime))
case _: DateType => new Date(Try(dateFormatP.parse(value).getTime).getOrElse(DateTimeUtils.stringToTime(value).getTime))
case _: StringType => value.replaceAll("\n|\t", " ")
case t => throw new RuntimeException(s"Unsupported cast from $value to $t")
}
}
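/**
 * Writes the DataFrame to `path` as delimited text via a local iterator. An existing file is
 * either overwritten or appended to (a separating newline is written first) depending on
 * `isOverwrite`; a header line built from the schema is written when `hasHeader` is true.
 */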
def saveDFToCsv(spark: SparkSession, df: DataFrame, path: String, hasHeader: Boolean = true, isOverwrite: Boolean = false): Boolean = {
val filesystemPath = new Path(path)
spark.sparkContext.hadoopConfiguration.setBoolean("fs.hdfs.impl.disable.cache", true)
val fs = filesystemPath.getFileSystem(spark.sparkContext.hadoopConfiguration)
fs.setVerifyChecksum(false)
fs.setWriteChecksum(false)
val out = if(fs.exists(filesystemPath)){
if(isOverwrite){
new BufferedOutputStream(fs.create(filesystemPath, isOverwrite))
} else {
val bufferedOutputStream = if(path.startsWith("file:")){
new BufferedOutputStream(new FileOutputStream(path.substring("file://".length), true))
} else {
new BufferedOutputStream(fs.append(filesystemPath))
}
bufferedOutputStream.write("\n".getBytes())
bufferedOutputStream
}
} else {
new BufferedOutputStream(fs.create(filesystemPath, isOverwrite))
}
val iterator = ImExportUtils.tryAndThrowError(df.toLocalIterator, _ => spark.sparkContext.clearJobGroup())
try{
val schema = df.schema
val header = new StringBuilder
var index = 0
for (col <- schema) {
header ++= col.name ++ fieldDelimiter
}
if (hasHeader) {
out.write(header.substring(0, header.lastIndexOf(fieldDelimiter)).getBytes)
} else {
if (iterator.hasNext) {
out.write(getLine(schema, iterator.next()).getBytes)
index += 1
}
}
while (index < Int.MaxValue && iterator.hasNext) {
val msg = "\n" + getLine(schema, iterator.next())
out.write(msg.getBytes())
index += 1
}
warn(s"Fetched ${df.columns.length} col(s) : ${index} row(s).")
} finally {
spark.sparkContext.clearJobGroup()
IOUtils.closeStream(out)
fs.close()
}
true
}
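/**
 * Serializes one Row as a delimited line: strings have newlines and tabs replaced with spaces,
 * nulls are written as "NULL", and other values use their toString form.
 */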
def getLine(schema: StructType, row: Row): String = {
val msg = new StringBuilder
schema.indices.foreach{ i =>
val data = row(i) match {
case null => "NULL"
case value: String => value.replaceAll("\n|\t", " ")
case value => value.toString
}
msg.append(data)
msg.append(fieldDelimiter)
}
msg.substring(0, msg.lastIndexOf(fieldDelimiter))
}
}
object CsvRelation extends Logging {
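/**
 * Writes `df` to `path` as delimited text. A minimal usage sketch (the path below is an
 * illustrative assumption; "fieldDelimiter" and "encoding" are option keys defined above):
 * {{{
 *   val opts = Map[String, Any]("fieldDelimiter" -> ",", "encoding" -> "utf-8")
 *   CsvRelation.saveDFToCsv(spark, df, "hdfs:///tmp/output.csv", hasHeader = true,
 *     isOverwrite = true, option = opts)
 * }}}
 */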
def saveDFToCsv(spark: SparkSession, df: DataFrame, path: String,
hasHeader: Boolean = true, isOverwrite: Boolean = false,
option: Map[String, Any] = Map()): Boolean = {
new CsvRelation(option).saveDFToCsv(spark, df, path, hasHeader, isOverwrite)
}
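/**
 * Loads a delimited file into a DataFrame with the supplied schema. A minimal sketch (the schema,
 * path, and empty per-column maps below are illustrative assumptions):
 * {{{
 *   val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
 *   val columns = List(Map[String, Any](), Map[String, Any]())
 *   val df = CsvRelation.csvToDF(spark, schema, hasHeader = true,
 *     path = "hdfs:///tmp/input.csv", columns = columns)
 * }}}
 */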
def csvToDF(spark: SparkSession, schema: StructType, hasHeader: Boolean, path: String, source: Map[String, Any] = Map(), columns: List[Map[String, Any]]): DataFrame = {
new CsvRelation(source).csvToDataFrame(spark, schema, hasHeader, path,columns)
}
}