// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
import java.math.BigDecimal
import java.math.BigInteger
import java.nio.charset.StandardCharsets
import org.apache.kudu.Schema
import org.apache.kudu.Type
import org.apache.kudu.client.PartialRow
import org.apache.kudu.client.SessionConfiguration
import org.apache.kudu.spark.kudu.KuduContext
import org.apache.kudu.spark.kudu.KuduWriteOptions
import org.apache.kudu.spark.kudu.RowConverter
import org.apache.kudu.spark.kudu.SparkUtil
import org.apache.kudu.util.DataGenerator
import org.apache.kudu.util.DateUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.LongAccumulator
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.yetus.audience.InterfaceAudience
import org.apache.yetus.audience.InterfaceStability
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import scopt.OptionParser
import scala.collection.JavaConverters._
/**
 * Pair of Spark accumulators reported by a generator run.
 *
 * @param rowsWritten accumulator for the number of rows written
 * @param collisions accumulator for the number of row collisions
 */
case class GeneratorMetrics(
    rowsWritten: LongAccumulator,
    collisions: LongAccumulator)
/** Factory for [[GeneratorMetrics]]. */
object GeneratorMetrics {

  /**
   * Creates the metrics for a run by registering two named long accumulators,
   * "rows_written" and "row_collisions", on the given context.
   *
   * @param sc the Spark context the accumulators are created on
   * @return a [[GeneratorMetrics]] wrapping the two new accumulators
   */
  def apply(sc: SparkContext): GeneratorMetrics = {
    GeneratorMetrics(sc.longAccumulator("rows_written"), sc.longAccumulator("row_collisions"))
  }
}
/**
 * Spark application that generates rows and loads them into a Kudu table.
 * `run` does the work; `main` and `testMain` parse command-line arguments first.
 */
object DistributedDataGenerator {
// SLF4J logger named after this object's runtime class.
val log: Logger = LoggerFactory.getLogger(getClass)
/**
 * Generates rows across `options.numTasks` Spark partitions and writes them to the
 * Kudu table named by `options.tableName`.
 *
 * Each partition builds a GeneratedRowIterator for `options.numRows / options.numTasks`
 * rows. When `options.repartition` is set, the rows are first passed through
 * `KuduContext.repartitionRows`. In the write phase each row is converted to a partial
 * row; pending row errors that are "already present" are counted as collisions and any
 * other row error raises a RuntimeException.
 *
 * NOTE(review): this method is incomplete in this view. The logging call prefix, the
 * call that receives the partition lambda, the right-hand side of `taskNum`, the
 * insert/flush calls, the accumulator updates, the closing braces and the returned
 * expression are not visible. Confirm against the full file before relying on the
 * description above.
 *
 * @param options generator options (table, masters, sizes, seed, repartition flag)
 * @param ss Spark session whose SparkContext is used for the job
 * @return declared as GeneratorMetrics; the returning expression is not visible here
 */
def run(options: DistributedDataGeneratorOptions, ss: SparkSession): GeneratorMetrics = {"Running a DistributedDataGenerator with options: $options")
// NOTE(review): the string trailing the opening brace above looks like the argument of
// a logging call whose prefix is missing in this view.
val sc = ss.sparkContext
val context = new KuduContext(options.masterAddresses, sc)
val metrics = GeneratorMetrics(sc)
// Generate the Inserts.
var rdd = sc
.parallelize(0 until options.numTasks, numSlices = options.numTasks)
// NOTE(review): the method this lambda is passed to is not visible in this view.
{ taskNumIter =>
// We know there is only 1 task per partition because numSlices = options.numTasks above.
// NOTE(review): the right-hand side of taskNum is not visible; presumably it is the
// single element of taskNumIter. Confirm.
val taskNum =
val generator = new DataGenerator.DataGeneratorBuilder()
// Add taskNum to the seed otherwise each task will try to generate the same rows.
.random(new java.util.Random(options.seed + taskNum))
val table = context.syncClient.openTable(options.tableName)
val schema = table.getSchema
// Integral division: any remainder of numRows / numTasks is dropped, so fewer than
// options.numRows rows are generated when numTasks does not divide numRows evenly.
val numRows = options.numRows / options.numTasks
// Each task starts numRows after the previous task's start, giving disjoint ranges.
val startRow: Long = numRows * taskNum
new GeneratedRowIterator(generator, options.generatorType, schema, startRow, numRows)
if (options.repartition) {
val table = context.syncClient.openTable(options.tableName)
val sparkSchema = SparkUtil.sparkSchema(table.getSchema)
rdd = context
.repartitionRows(rdd, options.tableName, sparkSchema, KuduWriteOptions(ignoreNull = true))
// Write the rows to Kudu.
// TODO: Use context.writeRows while still tracking inserts/collisions.
rdd.foreachPartition { rows =>
val kuduClient = context.syncClient
val table = kuduClient.openTable(options.tableName)
val kuduSchema = table.getSchema
val sparkSchema = SparkUtil.sparkSchema(kuduSchema)
val converter = new RowConverter(kuduSchema, sparkSchema, ignoreNull = true)
val session = kuduClient.newSession()
// Per-partition count of rows handed to the session; corrected below for collisions.
var rowsWritten = 0
rows.foreach { row =>
val insert = table.newInsert()
val partialRow = converter.toPartialRow(row)
// NOTE(review): no statement attaching partialRow to insert or applying the insert to
// the session is visible in this view.
rowsWritten += 1
// Synchronously flush after the last record is written.
// NOTE(review): the flush call this comment refers to is not visible in this view.
// Track the collisions.
var collisions = 0
for (error <- session.getPendingErrors.getRowErrors) {
if (error.getErrorStatus.isAlreadyPresent) {
// Because we can't check for collisions every time, but instead
// only when the rows are flushed, we subtract any rows that may
// have failed from the counter.
rowsWritten -= 1
collisions += 1
} else {
// Any row error other than "already present" aborts the partition.
throw new RuntimeException("Kudu write error: " + error.getErrorStatus.toString)
/**
 * Entry point for testing. SparkContext is a singleton,
 * so tests must create and manage their own.
 *
 * @param args command-line arguments, parsed by `DistributedDataGeneratorOptions.parse`
 * @param ss the Spark session the caller created and manages
 * @return the metrics returned by `run` for the parsed options
 * @throws IllegalArgumentException if the arguments cannot be parsed
 */
def testMain(args: Array[String], ss: SparkSession): GeneratorMetrics = {
  DistributedDataGeneratorOptions.parse(args) match {
    case None => throw new IllegalArgumentException("Could not parse arguments")
    case Some(config) => run(config, ss)
  }
}
/**
 * Command-line entry point: builds a Spark session named "DistributedDataGenerator",
 * runs the generator through `testMain` and logs the resulting counters.
 *
 * @param args command-line arguments forwarded to `testMain`
 */
def main(args: Array[String]): Unit = {
  val conf = new SparkConf().setAppName("DistributedDataGenerator")
  val ss = SparkSession.builder().config(conf).getOrCreate()
  val metrics = testMain(args, ss)
  // The two summary strings were fused onto the previous statement without a call
  // prefix; they are emitted through this object's logger.
  log.info(s"Rows written: ${metrics.rowsWritten.value}")
  log.info(s"Collisions: ${metrics.collisions.value}")
}
/**
 * Iterator that produces `numRows` rows, numbering them from `startRow`.
 *
 * For the sequential generator type every column of a fresh partial row is set from
 * the current row number via `setRow`; any generator type other than the sequential
 * and random ones raises an IllegalArgumentException.
 *
 * NOTE(review): the body of the random-generator branch, the expression `next()`
 * returns and the closing braces are not visible in this view. `generator` and
 * `converter` are not referenced by any visible line; presumably they are used by
 * the missing statements. Confirm against the full file.
 *
 * @param generator data generator supplied by the caller
 * @param generatorType name of the generator type to use
 * @param schema Kudu schema the partial rows are created from
 * @param startRow number assigned to the first generated row
 * @param numRows how many rows this iterator yields
 */
private class GeneratedRowIterator(
generator: DataGenerator,
generatorType: String,
schema: Schema,
startRow: Long,
numRows: Long)
extends Iterator[Row] {
val sparkSchema = SparkUtil.sparkSchema(schema)
// ignoreNull values so unset/defaulted rows can be passed through.
val converter = new RowConverter(schema, sparkSchema, ignoreNull = true)
// Number used for the next sequential row; advances by one per call to next().
var currentRow: Long = startRow
// Count of rows produced so far; the iterator is exhausted once it reaches numRows.
var rowsGenerated: Long = 0
override def hasNext: Boolean = rowsGenerated < numRows
override def next(): Row = {
// Guard against callers that ignore hasNext.
if (rowsGenerated >= numRows) {
throw new IllegalStateException("Already generated all of the rows.")
val partialRow = schema.newPartialRow()
if (generatorType == SequentialGenerator) {
setRow(partialRow, currentRow)
} else if (generatorType == RandomGenerator) {
// NOTE(review): this branch is empty in this view.
} else {
throw new IllegalArgumentException(s"Generator type of $generatorType is unsupported")
currentRow += 1
rowsGenerated += 1
/**
 * Sets all the columns in the passed row to the passed value.
 *
 * The value is converted to each column's type: narrowed for the smaller integer
 * types, rendered as decimal text for VARCHAR/STRING/BINARY, and used as the
 * unscaled value for DECIMAL with the column's scale.
 *
 * TODO(ghenke): Consider failing when value doesn't fit into the type.
 *
 * @param row the partial row to populate; its own schema determines the columns
 * @param value the number every column value is derived from
 * @throws UnsupportedOperationException if a column has a type not handled below
 */
private def setRow(row: PartialRow, value: Long): Unit = {
  val columns = row.getSchema.getColumns.asScala
  columns.zipWithIndex.foreach {
    case (col, i) =>
      col.getType match {
        case Type.BOOL =>
          row.addBoolean(i, value % 2 == 1)
        // The narrowing conversions below keep only the low-order bits of value.
        case Type.INT8 =>
          row.addByte(i, value.toByte)
        case Type.INT16 =>
          row.addShort(i, value.toShort)
        case Type.INT32 =>
          row.addInt(i, value.toInt)
        case Type.INT64 =>
          row.addLong(i, value)
        // NOTE(review): a second, unlabeled addLong followed the INT64 branch;
        // UNIXTIME_MICROS is the other long-backed Kudu type, so the case label is
        // restored on that assumption.
        case Type.UNIXTIME_MICROS =>
          row.addLong(i, value)
        case Type.DATE =>
          row.addDate(i, DateUtil.epochDaysToSqlDate(value.toInt))
        case Type.FLOAT =>
          row.addFloat(i, value.toFloat)
        case Type.DOUBLE =>
          row.addDouble(i, value.toDouble)
        case Type.DECIMAL =>
          // value is the unscaled digits; the column's scale places the decimal point.
          row.addDecimal(
            i,
            new BigDecimal(BigInteger.valueOf(value), col.getTypeAttributes.getScale))
        case Type.VARCHAR =>
          row.addVarchar(i, String.valueOf(value))
        case Type.STRING =>
          row.addString(i, String.valueOf(value))
        case Type.BINARY =>
          val bytes: Array[Byte] = String.valueOf(value).getBytes(StandardCharsets.UTF_8)
          row.addBinary(i, bytes)
        case _ =>
          throw new UnsupportedOperationException("Unsupported type " + col.getType)
      }
  }
}
/**
 * Options for a DistributedDataGenerator run.
 *
 * @param tableName the table to load with generated data
 * @param masterAddresses comma-separated addresses of Kudu masters
 * @param generatorType the type of data generator, 'random' or 'sequential'
 * @param numRows the total number of unique rows to generate
 * @param numTasks the total number of Spark tasks to use when generating data
 * @param stringLength the length of generated string fields
 * @param binaryLength the length of generated binary fields
 * @param seed the seed for the random data generator; defaults to the wall-clock time
 *             in milliseconds at the moment the options value is constructed
 * @param repartition whether to repartition the data so that each Spark task talks to
 *                    a minimal set of tablet servers
 */
case class DistributedDataGeneratorOptions(
    tableName: String,
    masterAddresses: String,
    generatorType: String = DistributedDataGeneratorOptions.DefaultGeneratorType,
    numRows: Long = DistributedDataGeneratorOptions.DefaultNumRows,
    numTasks: Int = DistributedDataGeneratorOptions.DefaultNumTasks,
    stringLength: Int = DistributedDataGeneratorOptions.DefaultStringLength,
    // Uses the binary-specific default so the two lengths can be tuned independently.
    binaryLength: Int = DistributedDataGeneratorOptions.DefaultBinaryLength,
    seed: Long = System.currentTimeMillis(),
    repartition: Boolean = DistributedDataGeneratorOptions.DefaultRepartition)
/** Default values, generator-type names and the argument parser for the options. */
object DistributedDataGeneratorOptions {
// Default total number of rows to generate.
val DefaultNumRows: Long = 10000
// Default number of Spark tasks.
val DefaultNumTasks: Int = 1
// Default length of generated string fields.
val DefaultStringLength: Int = 128
// Default length of generated binary fields.
val DefaultBinaryLength: Int = 128
// Names accepted for the generator type.
val RandomGenerator: String = "random"
val SequentialGenerator: String = "sequential"
// The sequential generator is used unless another type is requested.
val DefaultGeneratorType: String = SequentialGenerator
// Repartitioning is off by default.
val DefaultRepartition: Boolean = false
// scopt parser, created with the program name "LoadRandomData". Each `.action` below
// copies a parsed value into the matching field of the options value, and each `.text`
// is the help text for that option.
// NOTE(review): the option/argument declarations that each `.action` chains from, and
// the closing brace of the parser body, are not visible in this view.
private val parser: OptionParser[DistributedDataGeneratorOptions] =
new OptionParser[DistributedDataGeneratorOptions]("LoadRandomData") {
.action((v, o) => o.copy(tableName = v))
.text("The table to load with random data")
.action((v, o) => o.copy(masterAddresses = v))
.text("Comma-separated addresses of Kudu masters")
.action((v, o) => o.copy(generatorType = v))
.text(s"The type of data generator. Must be one of 'random' or 'sequential'. " +
s"Default: ${DefaultGeneratorType}")
.action((v, o) => o.copy(numRows = v))
.text(s"The total number of unique rows to generate. Default: ${DefaultNumRows}")
.action((v, o) => o.copy(numTasks = v))
.text(s"The total number of Spark tasks to use when generating data. " +
s"Default: ${DefaultNumTasks}")
.action((v, o) => o.copy(stringLength = v))
.text(s"The length of generated string fields. Default: ${DefaultStringLength}")
.action((v, o) => o.copy(binaryLength = v))
.text(s"The length of generated binary fields. Default: ${DefaultBinaryLength}")
.action((v, o) => o.copy(seed = v))
.text(s"The seed to use in the random data generator. " +
s"Default: `System.currentTimeMillis()`")
.action((v, o) => o.copy(repartition = v))
.text(s"Repartition the data to ensure each spark task talks to a minimal " +
s"set of tablet servers.")
// Runs the parser over args, starting from an options value whose tableName and
// masterAddresses are empty strings, and returns the parser's Option result.
// NOTE(review): the closing braces of this method and its object are not visible here.
def parse(args: Seq[String]): Option[DistributedDataGeneratorOptions] = {
parser.parse(args, DistributedDataGeneratorOptions("", ""))