/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.seatunnel.spark.fake.source

import java.security.SecureRandom

import scala.collection.JavaConversions._

import org.apache.seatunnel.common.config.CheckResult
import org.apache.seatunnel.shade.com.typesafe.config.{Config, ConfigFactory}
import org.apache.seatunnel.spark.SparkEnvironment
import org.apache.seatunnel.spark.stream.SparkStreamingSource
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row, RowFactory, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
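
/**
 * A fake streaming source: each record is a string picked at random from the
 * configured `content` list, emitted at `rate` records per second by a
 * custom receiver.
 *
 * Illustrative configuration sketch (the keys match this source; the exact
 * surrounding job-file layout is an assumption):
 *
 * {{{
 * FakeStream {
 *   content = ["Hello World", "Hello SeaTunnel"]
 *   rate = 1
 * }
 * }}}
 */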
class FakeStream extends SparkStreamingSource[String] {
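
  /** Merge default values (currently only `rate`) beneath the user-supplied config. */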
  override def prepare(env: SparkEnvironment): Unit = {
    val defaultConfig = ConfigFactory.parseMap(
      Map(
        "rate" -> 1 // default emit rate: 1 record per second
      ))
    // Fall back to the defaults for any setting the user did not provide
    config = config.withFallback(defaultConfig)
  }
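
  /** Build the input DStream by registering a [[FakeReceiver]] with the streaming context. */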
  override def getData(env: SparkEnvironment): DStream[String] = {
    env.getStreamingContext.receiverStream(new FakeReceiver(config))
  }
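
  /** Convert each micro-batch of raw strings into a single-column ("raw_message") DataFrame. */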
  override def rdd2dataset(sparkSession: SparkSession, rdd: RDD[String]): Dataset[Row] = {
    val rowsRDD = rdd.map(element => RowFactory.create(element))
    val schema = StructType(Array(StructField("raw_message", DataTypes.StringType)))
    sparkSession.createDataFrame(rowsRDD, schema)
  }
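
  /** Require a non-empty `content` string list; all other settings have defaults. */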
  override def checkConfig(): CheckResult = {
    if (config.hasPath("content") && config.getStringList("content").nonEmpty) {
      CheckResult.success()
    } else {
      CheckResult.error("please make sure [content] is a non-empty string array")
    }
  }

  override def getPluginName: String = "FakeStream"
}
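
/**
 * A custom Spark [[Receiver]] that fabricates data instead of reading from an
 * external system: it stores random entries from the `content` list, sleeping
 * between records so the emit rate stays close to `rate` records per second.
 */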
private class FakeReceiver(config: Config)
  extends Receiver[String](StorageLevel.MEMORY_AND_DISK_2) {

  val secRandom = new SecureRandom()
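
  /** Pick a random element from the configured `content` list. */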
  def generateData(): String = {
    val contentList = config.getStringList("content")
    val n = secRandom.nextInt(contentList.length)
    contentList.get(n)
  }

  def onStart(): Unit = {
    // Start the thread that generates the fake data
    new Thread("FakeReceiver Source") {
      override def run(): Unit = {
        receive()
      }
    }.start()
  }

  def onStop(): Unit = {
    // There is nothing much to do, as the thread calling receive()
    // is designed to stop by itself once isStopped() returns true
  }

  /** Generate fake records until the receiver is stopped. */
  private def receive(): Unit = {
    while (!isStopped()) {
      store(generateData())
      // Sleep so that the overall emit rate stays near `rate` records per second
      Thread.sleep((1000.toDouble / config.getInt("rate")).toInt)
    }
  }
}