// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.kudu.spark.tools
import java.math.BigDecimal
import java.math.BigInteger
import java.nio.charset.StandardCharsets
import org.apache.kudu.Schema
import org.apache.kudu.Type
import org.apache.kudu.client.PartialRow
import org.apache.kudu.client.SessionConfiguration
import org.apache.kudu.spark.kudu.KuduContext
import org.apache.kudu.spark.kudu.KuduWriteOptions
import org.apache.kudu.spark.kudu.RowConverter
import org.apache.kudu.spark.kudu.SparkUtil
import org.apache.kudu.spark.tools.DistributedDataGeneratorOptions._
import org.apache.kudu.util.DataGenerator
import org.apache.kudu.util.DateUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.LongAccumulator
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.yetus.audience.InterfaceAudience
import org.apache.yetus.audience.InterfaceStability
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import scopt.OptionParser
import scala.collection.JavaConverters._
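/**
 * Spark accumulators tracking the number of rows successfully written and the number of
 * rows that collided with an already-present primary key and were therefore dropped.
 */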
case class GeneratorMetrics(rowsWritten: LongAccumulator, collisions: LongAccumulator)
object GeneratorMetrics {
def apply(sc: SparkContext): GeneratorMetrics = {
GeneratorMetrics(sc.longAccumulator("rows_written"), sc.longAccumulator("row_collisions"))
}
}
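/**
 * A Spark job that fills a Kudu table with generated data. Each task gets its own seeded
 * [[DataGenerator]] and writes either a sequential range of rows or random rows, while the
 * accumulators in [[GeneratorMetrics]] count successful writes and primary-key collisions.
 */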
object DistributedDataGenerator {
val log: Logger = LoggerFactory.getLogger(getClass)
def run(options: DistributedDataGeneratorOptions, ss: SparkSession): GeneratorMetrics = {
log.info(s"Running a DistributedDataGenerator with options: $options")
val sc = ss.sparkContext
val context = new KuduContext(options.masterAddresses, sc)
val metrics = GeneratorMetrics(sc)
    // Generate the rows to insert.
var rdd = sc
.parallelize(0 until options.numTasks, numSlices = options.numTasks)
.mapPartitions(
{ taskNumIter =>
// We know there is only 1 task per partition because numSlices = options.numTasks above.
val taskNum = taskNumIter.next()
val generator = new DataGenerator.DataGeneratorBuilder()
          // Add taskNum to the seed; otherwise each task would generate the same rows.
.random(new java.util.Random(options.seed + taskNum))
.stringLength(options.stringLength)
.binaryLength(options.binaryLength)
.build()
val table = context.syncClient.openTable(options.tableName)
val schema = table.getSchema
val numRows = options.numRows / options.numTasks
val startRow: Long = numRows * taskNum
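          // Each task produces its own contiguous range of row values, so the sequential
          // generator never overlaps across tasks. Note that the integer division above drops
          // the remainder when numRows is not evenly divisible by numTasks.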
new GeneratedRowIterator(generator, options.generatorType, schema, startRow, numRows)
},
true
)
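    // Optionally shuffle the generated rows by Kudu partition so that each Spark task ends up
    // writing to a minimal set of tablet servers.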
if (options.repartition) {
val table = context.syncClient.openTable(options.tableName)
val sparkSchema = SparkUtil.sparkSchema(table.getSchema)
rdd = context
.repartitionRows(rdd, options.tableName, sparkSchema, KuduWriteOptions(ignoreNull = true))
}
// Write the rows to Kudu.
// TODO: Use context.writeRows while still tracking inserts/collisions.
rdd.foreachPartition { rows =>
val kuduClient = context.syncClient
val table = kuduClient.openTable(options.tableName)
val kuduSchema = table.getSchema
val sparkSchema = SparkUtil.sparkSchema(kuduSchema)
val converter = new RowConverter(kuduSchema, sparkSchema, ignoreNull = true)
val session = kuduClient.newSession()
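      // AUTO_FLUSH_BACKGROUND buffers operations and flushes them asynchronously; per-row
      // failures are collected and surfaced later through session.getPendingErrors.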
session.setFlushMode(SessionConfiguration.FlushMode.AUTO_FLUSH_BACKGROUND)
var rowsWritten = 0
rows.foreach { row =>
val insert = table.newInsert()
val partialRow = converter.toPartialRow(row)
insert.setRow(partialRow)
session.apply(insert)
rowsWritten += 1
}
// Synchronously flush after the last record is written.
session.flush()
// Track the collisions.
var collisions = 0
for (error <- session.getPendingErrors.getRowErrors) {
if (error.getErrorStatus.isAlreadyPresent) {
          // Collisions are only detected when the buffered rows are flushed, so subtract
          // each "already present" row from the written count here.
rowsWritten -= 1
collisions += 1
} else {
throw new RuntimeException("Kudu write error: " + error.getErrorStatus.toString)
}
}
metrics.rowsWritten.add(rowsWritten)
metrics.collisions.add(collisions)
session.close()
}
metrics
}
/**
* Entry point for testing. SparkContext is a singleton,
* so tests must create and manage their own.
*/
@InterfaceAudience.LimitedPrivate(Array("Test"))
def testMain(args: Array[String], ss: SparkSession): GeneratorMetrics = {
DistributedDataGeneratorOptions.parse(args) match {
case None => throw new IllegalArgumentException("Could not parse arguments")
case Some(config) => run(config, ss)
}
}
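  /**
   * Example invocation (the jar name, table name, and master address below are illustrative):
   *
   * {{{
   * spark-submit --class org.apache.kudu.spark.tools.DistributedDataGenerator \
   *   kudu-spark-tools.jar my_table kudu-master:7051 \
   *   --type random --num-rows 1000000 --num-tasks 20
   * }}}
   */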
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("DistributedDataGenerator")
val ss = SparkSession.builder().config(conf).getOrCreate()
val metrics = testMain(args, ss)
log.info(s"Rows written: ${metrics.rowsWritten.value}")
log.info(s"Collisions: ${metrics.collisions.value}")
}
}
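/**
 * Lazily generates `numRows` Spark [[Row]]s for the given schema, starting at `startRow`.
 * Sequential generation fills every column with the current row number; random generation
 * delegates to the Kudu [[DataGenerator]].
 */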
private class GeneratedRowIterator(
generator: DataGenerator,
generatorType: String,
schema: Schema,
startRow: Long,
numRows: Long)
extends Iterator[Row] {
val sparkSchema = SparkUtil.sparkSchema(schema)
  // Use ignoreNull so unset/defaulted columns can be passed through.
val converter = new RowConverter(schema, sparkSchema, ignoreNull = true)
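  // currentRow is the value written by the sequential generator; rowsGenerated bounds hasNext.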
var currentRow: Long = startRow
var rowsGenerated: Long = 0
override def hasNext: Boolean = rowsGenerated < numRows
override def next(): Row = {
if (rowsGenerated >= numRows) {
throw new IllegalStateException("Already generated all of the rows.")
}
val partialRow = schema.newPartialRow()
if (generatorType == SequentialGenerator) {
setRow(partialRow, currentRow)
} else if (generatorType == RandomGenerator) {
generator.randomizeRow(partialRow)
} else {
throw new IllegalArgumentException(s"Generator type of $generatorType is unsupported")
}
currentRow += 1
rowsGenerated += 1
converter.toRow(partialRow)
}
/**
   * Sets every column in the given row to the given value, converted to the column's type.
* TODO(ghenke): Consider failing when value doesn't fit into the type.
*/
private def setRow(row: PartialRow, value: Long): Unit = {
val schema = row.getSchema
val columns = schema.getColumns.asScala
columns.indices.foreach { i =>
val col = columns(i)
col.getType match {
case Type.BOOL =>
row.addBoolean(i, value % 2 == 1)
case Type.INT8 =>
row.addByte(i, value.toByte)
case Type.INT16 =>
row.addShort(i, value.toShort)
case Type.INT32 =>
row.addInt(i, value.toInt)
case Type.INT64 =>
row.addLong(i, value)
case Type.UNIXTIME_MICROS =>
row.addLong(i, value)
case Type.DATE =>
row.addDate(i, DateUtil.epochDaysToSqlDate(value.toInt))
case Type.FLOAT =>
row.addFloat(i, value.toFloat)
case Type.DOUBLE =>
row.addDouble(i, value.toDouble)
case Type.DECIMAL =>
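          // BigDecimal(unscaledValue, scale) interprets the long as value * 10^-scale,
          // matching the column's scale attribute.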
row.addDecimal(
i,
new BigDecimal(BigInteger.valueOf(value), col.getTypeAttributes.getScale))
case Type.VARCHAR =>
row.addVarchar(i, String.valueOf(value))
case Type.STRING =>
row.addString(i, String.valueOf(value))
case Type.BINARY =>
val bytes: Array[Byte] = String.valueOf(value).getBytes(StandardCharsets.UTF_8)
row.addBinary(i, bytes)
case _ =>
throw new UnsupportedOperationException("Unsupported type " + col.getType)
}
}
}
}
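/**
 * Parsed command-line options for [[DistributedDataGenerator]].
 */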
@InterfaceAudience.Private
@InterfaceStability.Unstable
case class DistributedDataGeneratorOptions(
tableName: String,
masterAddresses: String,
generatorType: String = DistributedDataGeneratorOptions.DefaultGeneratorType,
numRows: Long = DistributedDataGeneratorOptions.DefaultNumRows,
numTasks: Int = DistributedDataGeneratorOptions.DefaultNumTasks,
stringLength: Int = DistributedDataGeneratorOptions.DefaultStringLength,
    binaryLength: Int = DistributedDataGeneratorOptions.DefaultBinaryLength,
seed: Long = System.currentTimeMillis(),
repartition: Boolean = DistributedDataGeneratorOptions.DefaultRepartition)
@InterfaceAudience.Private
@InterfaceStability.Unstable
object DistributedDataGeneratorOptions {
val DefaultNumRows: Long = 10000
val DefaultNumTasks: Int = 1
val DefaultStringLength: Int = 128
val DefaultBinaryLength: Int = 128
val RandomGenerator: String = "random"
val SequentialGenerator: String = "sequential"
val DefaultGeneratorType: String = SequentialGenerator
val DefaultRepartition: Boolean = false
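  // scopt parser: two required positional arguments (table name and master addresses) followed
  // by optional flags whose defaults come from the constants above.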
private val parser: OptionParser[DistributedDataGeneratorOptions] =
new OptionParser[DistributedDataGeneratorOptions]("LoadRandomData") {
arg[String]("table-name")
.action((v, o) => o.copy(tableName = v))
.text("The table to load with random data")
arg[String]("master-addresses")
.action((v, o) => o.copy(masterAddresses = v))
.text("Comma-separated addresses of Kudu masters")
opt[String]("type")
.action((v, o) => o.copy(generatorType = v))
.text(s"The type of data generator. Must be one of 'random' or 'sequential'. " +
s"Default: ${DefaultGeneratorType}")
.optional()
opt[Long]("num-rows")
.action((v, o) => o.copy(numRows = v))
.text(s"The total number of unique rows to generate. Default: ${DefaultNumRows}")
.optional()
opt[Int]("num-tasks")
.action((v, o) => o.copy(numTasks = v))
.text(s"The total number of Spark tasks to use when generating data. " +
s"Default: ${DefaultNumTasks}")
.optional()
opt[Int]("string-length")
.action((v, o) => o.copy(stringLength = v))
.text(s"The length of generated string fields. Default: ${DefaultStringLength}")
.optional()
opt[Int]("binary-length")
.action((v, o) => o.copy(binaryLength = v))
.text(s"The length of generated binary fields. Default: ${DefaultBinaryLength}")
.optional()
opt[Long]("seed")
.action((v, o) => o.copy(seed = v))
.text(s"The seed to use in the random data generator. " +
s"Default: `System.currentTimeMillis()`")
opt[Boolean]("repartition")
.action((v, o) => o.copy(repartition = v))
.text(s"Repartition the data to ensure each spark task talks to a minimal " +
s"set of tablet servers.")
}
def parse(args: Seq[String]): Option[DistributedDataGeneratorOptions] = {
parser.parse(args, DistributedDataGeneratorOptions("", ""))
}
}