/*
 * Copyright 2019 WeBank
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.webank.wedatasphere.linkis.engine.imexport

import java.io.{BufferedOutputStream, FileOutputStream}
import java.math.BigDecimal
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale

import com.webank.wedatasphere.linkis.common.utils.{Logging, Utils}
import com.webank.wedatasphere.linkis.engine.imexport.util.ImExportUtils
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{IOUtils, LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

import scala.collection.mutable.ArrayBuffer
import scala.util.Try
import scala.util.control.Exception._

/**
 * Reads delimiter-separated text files into Spark DataFrames and writes DataFrames back out as
 * delimiter-separated text, driven by the options supplied in `source`.
 *
 * Created by allenlliu on 7/12/18.
 */
class CsvRelation(@transient private val source: Map[String, Any]) extends Serializable {

  import LoadData._

  // Reader/writer options, all taken from the job's `source` map with the defaults shown here.
  val fieldDelimiter = getMapValue[String](source, "fieldDelimiter", "\t")
  val encoding = getMapValue[String](source, "encoding", "utf-8")
  //val hasHeader = getNodeValue[Boolean](source, "hasHeader", false)
  val nullValue = getMapValue[String](source, "nullValue", " ")
  val nanValue = getMapValue[String](source, "nanValue", "null")

  val quote = getMapValue[String](source, "quote", "\"")
  val escape = getMapValue[String](source, "escape", "\\")
  val escapeQuotes = getMapValue[Boolean](source, "escapeQuotes", false)
  val dateFormat: SimpleDateFormat = new SimpleDateFormat(getMapValue[String](source, "dateFormat", "yyyy-MM-dd"), Locale.US)
  val timestampFormat: SimpleDateFormat = new SimpleDateFormat(getMapValue[String](source, "timestampFormat", "yyyy-MM-dd HH:mm:ss"), Locale.US)
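  // A minimal sketch of the kind of `source` map this relation reads. The keys come from the
  // getMapValue calls above; the concrete values below are illustrative assumptions only:
  //   Map(
  //     "fieldDelimiter" -> ",",
  //     "encoding" -> "GBK",
  //     "nullValue" -> "\\N",
  //     "nanValue" -> "null",
  //     "quote" -> "\"",
  //     "escapeQuotes" -> false,
  //     "dateFormat" -> "yyyy-MM-dd",
  //     "timestampFormat" -> "yyyy-MM-dd HH:mm:ss"
  //   )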


  /** Loads `path` as an RDD of lines, decoding each record with the given encoding
    * (SparkContext.textFile assumes UTF-8, so non-UTF-8 files go through this path). */
  def transfer(sc: SparkContext, path: String, encoding: String): RDD[String] = {
    sc.hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], 1)
      .map(p => new String(p._2.getBytes, 0, p._2.getLength, encoding))
  }

  def csvToDataFrame(spark: SparkSession, schema: StructType, hasHeader: Boolean, path: String, columns: List[Map[String, Any]]): DataFrame = {
    val rdd = if ("utf-8".equalsIgnoreCase(encoding)) {
      spark.sparkContext.textFile(path)
    } else {
      transfer(spark.sparkContext, path, encoding)
    }
    // Header handling: if the file carries a header row, drop the first line of the first partition before splitting fields.
    val tokenRdd = if (hasHeader) {
      rdd.mapPartitionsWithIndex((index, iter) => if (index == 0) iter.drop(1) else iter).map(_.split(fieldDelimiter))
    } else {
      rdd.map(_.split(fieldDelimiter))
    }

    val rowRdd = buildScan(tokenRdd, schema, columns)
    spark.createDataFrame(rowRdd, schema)
  }
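
  // Illustrative usage only; the schema, per-column metadata and path below are assumptions,
  // not values taken from a real Linkis job:
  //   val schema = StructType(Seq(StructField("name", StringType), StructField("birthday", DateType)))
  //   val columns = List(Map[String, Any](), Map[String, Any]("dateFormat" -> "yyyy/MM/dd"))
  //   val df = new CsvRelation(Map("fieldDelimiter" -> ",")).csvToDataFrame(spark, schema, hasHeader = true, "/tmp/demo.csv", columns)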


  def buildScan(tokenRdd: RDD[Array[String]], schema: StructType, columns: List[Map[String, Any]]): RDD[Row] = {

    tokenRdd.map(att => {
      val row = ArrayBuffer[Any]()
      for (i <- schema.indices) {
        // A short line may not contain every column; missing trailing fields are treated as "null".
        val field = if ((allCatch opt att(i)).isDefined) {
          att(i)
        } else {
          "null"
        }
        val data = if (nanValue.equalsIgnoreCase(field)) {
          null
        } else if (schema(i).dataType != StringType && (field.isEmpty || nullValue.equals(field))) {
          null
        } else {
          // Dates are parsed with the per-column format when one is supplied in `columns`.
          val dateFormat = columns(i).getOrElse("dateFormat", "yyyy-MM-dd").toString
          val format: SimpleDateFormat = new SimpleDateFormat(dateFormat, Locale.US)
          castTo(field, schema(i).dataType, format, null)
        }
        row += data
      }
      Row.fromSeq(row)
    })
  }
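
  // For example, with the default options (nanValue = "null", fieldDelimiter = "\t") and a schema of
  // (StringType, IntegerType, IntegerType):
  //   "Tom\t\t25".split("\t") => Array("Tom", "", "25") => Row("Tom", null, 25)
  //   "Tom".split("\t")       => Array("Tom")           => Row("Tom", null, null)  // missing fields become "null", then null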

  def castTo(field: String, dataType: DataType, dateFormatP: SimpleDateFormat, timestampFormatP: SimpleDateFormat): Any = {
    // When escapeQuotes is set, the raw field is wrapped in quote characters, which are stripped before casting.
    val value = if (escapeQuotes) field.substring(1, field.length - 1) else field
    dataType match {
      case _: ByteType => value.toByte
      case _: ShortType => value.toShort
      case _: IntegerType => value.toInt
      case _: LongType => value.toLong
      case _: FloatType => value.toFloat
      case _: DoubleType => value.toDouble
      case _: BooleanType => value.toBoolean
      case dt: DecimalType =>
        val datum = new BigDecimal(value.replaceAll(",", ""))
        Decimal(datum, dt.precision, dt.scale)
      case _: TimestampType =>
        // Prefer the caller-supplied format, falling back to the relation-level default; both parse
        // paths yield epoch milliseconds for java.sql.Timestamp.
        val fmt = if (timestampFormatP != null) timestampFormatP else timestampFormat
        new Timestamp(Try(fmt.parse(value).getTime).getOrElse(DateTimeUtils.stringToTime(value).getTime))
      case _: DateType => new Date(Try(dateFormatP.parse(value).getTime).getOrElse(DateTimeUtils.stringToTime(value).getTime))
      case _: StringType => value.replaceAll("\n|\t", " ")
      case t => throw new RuntimeException(s"Unsupported cast from $value to $t")
    }
  }
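
  // Examples of the conversions above (illustrative values only):
  //   castTo("42", IntegerType, dateFormat, timestampFormat)              // => 42
  //   castTo("1,234.56", DecimalType(10, 2), dateFormat, timestampFormat) // commas are stripped before the decimal parse
  //   castTo("2020-01-01", DateType, dateFormat, timestampFormat)         // parsed with the supplied date format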

  def saveDFToCsv(spark: SparkSession, df: DataFrame, path: String, hasHeader: Boolean = true, isOverwrite: Boolean = false): Boolean = {
    val filesystemPath = new Path(path)
    spark.sparkContext.hadoopConfiguration.setBoolean("fs.hdfs.impl.disable.cache", true)
    val fs = filesystemPath.getFileSystem(spark.sparkContext.hadoopConfiguration)
    fs.setVerifyChecksum(false)
    fs.setWriteChecksum(false)
    // Overwrite recreates the file; otherwise append to the existing one, separating old and new content with a newline.
    val out = if (fs.exists(filesystemPath)) {
      if (isOverwrite) {
        new BufferedOutputStream(fs.create(filesystemPath, isOverwrite))
      } else {
        val bufferedOutputStream = if (path.startsWith("file:")) {
          // Local files are appended with a plain FileOutputStream (assumes a "file://" prefix).
          new BufferedOutputStream(new FileOutputStream(path.substring("file://".length), true))
        } else {
          new BufferedOutputStream(fs.append(filesystemPath))
        }
        bufferedOutputStream.write("\n".getBytes())
        bufferedOutputStream
      }
    } else {
      new BufferedOutputStream(fs.create(filesystemPath, isOverwrite))
    }
    val iterator = ImExportUtils.tryAndThrowError(df.toLocalIterator, _ => spark.sparkContext.clearJobGroup())
    try {
      val schema = df.schema
      val header = new StringBuilder
      var index = 0
      for (col <- schema) {
        header ++= col.name ++ fieldDelimiter
      }
      if (hasHeader) {
        out.write(header.substring(0, header.lastIndexOf(fieldDelimiter)).getBytes)
      } else {
        // Without a header, the first row is written immediately so that every later row can be prefixed with "\n".
        if (iterator.hasNext) {
          out.write(getLine(schema, iterator.next()).getBytes)
          index += 1
        }
      }

      while (index < Int.MaxValue && iterator.hasNext) {
        val msg = "\n" + getLine(schema, iterator.next())
        out.write(msg.getBytes())
        index += 1
      }
      warn(s"Fetched ${df.columns.length} col(s) : ${index} row(s).")
    } finally {
      spark.sparkContext.clearJobGroup()
      IOUtils.closeStream(out)
      fs.close()
    }
    true
  }

  def getLine(schema: StructType, row: Row): String = {
    val msg = new StringBuilder
    schema.indices.foreach { i =>
      val data = row(i) match {
        case value: String => value.replaceAll("\n|\t", " ")
        case value: Any => value.toString
        case _ => "NULL" // nulls fall through the type patterns above and are written as "NULL"
      }
      msg.append(data)
      msg.append(fieldDelimiter)
    }
    msg.substring(0, msg.lastIndexOf(fieldDelimiter))
  }
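
  // With the default "\t" delimiter, a Row("hadoop", 10, null) is rendered as the line "hadoop\t10\tNULL".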
}

object CsvRelation extends Logging {

  def saveDFToCsv(spark: SparkSession, df: DataFrame, path: String,
                  hasHeader: Boolean = true, isOverwrite: Boolean = false,
                  option: Map[String, Any] = Map()): Boolean = {
    new CsvRelation(option).saveDFToCsv(spark, df, path, hasHeader, isOverwrite)
  }
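
  // Illustrative call only (the path and options here are assumptions, not method defaults):
  //   CsvRelation.saveDFToCsv(spark, df, "hdfs:///tmp/result.csv", hasHeader = true, isOverwrite = true,
  //     option = Map("fieldDelimiter" -> ","))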

  def csvToDF(spark: SparkSession, schema: StructType, hasHeader: Boolean, path: String,
              source: Map[String, Any] = Map(), columns: List[Map[String, Any]]): DataFrame = {
    new CsvRelation(source).csvToDataFrame(spark, schema, hasHeader, path, columns)
  }
}