package org.apache.druid.spark.v2.reader
import{FileUtils, IAE, ISE, StringUtils}
import org.apache.druid.query.dimension.DefaultDimensionSpec
import org.apache.druid.query.filter.DimFilter
import org.apache.druid.segment.column.ValueType
import org.apache.druid.segment.vector.{VectorColumnSelectorFactory, VectorCursor}
import org.apache.druid.segment.VirtualColumns
import org.apache.druid.spark.configuration.{Configuration, SerializableHadoopConfiguration}
import org.apache.druid.spark.mixins.Logging
import org.apache.druid.spark.registries.ComplexTypeRegistry
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType, LongType, StringType,
StructField, StructType, TimestampType}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
/**
 * Spark DataSourceV2 InputPartitionReader that reads a Druid segment through a vectorized Druid cursor and
 * exposes the data to Spark as ColumnarBatches backed by on-heap Spark column vectors.
 *
 * One OnHeapColumnVector is allocated per field of `schema`, each with capacity `batchSize`. The fill* methods
 * populate vector i for schema field i.
 *
 * @param segmentStr                 String form of the segment to read; presumably a serialized segment descriptor
 *                                   that DruidBaseInputPartitionReader loads (TODO confirm).
 * @param schema                     Spark schema of the produced data; drives vector allocation and per-column
 *                                   fill logic.
 * @param filter                     Optional Druid filter; presumably applied when the vector cursor is created
 *                                   (TODO confirm).
 * @param columnTypes                Optional set of column type names; usage is not visible in this class
 *                                   (TODO confirm).
 * @param broadcastedHadoopConf      Broadcast Hadoop configuration, presumably used to fetch the segment.
 * @param conf                       Connector configuration.
 * @param useSparkConfForDeepStorage Presumably whether deep-storage access uses Spark's configuration (TODO confirm).
 * @param useCompactSketches         Presumably controls sketch handling for complex columns (TODO confirm).
 * @param useDefaultNullHandling     If true:
 *                                   - null numeric values are written as 0;
 *                                   - columns with no Druid capabilities are filled with type defaults
 *                                     (0 / empty string) instead of Spark nulls.
 * @param batchSize                  Capacity, in rows, of each allocated Spark column vector.
 */
class DruidColumnarInputPartitionReader(
segmentStr: String,
schema: StructType,
filter: Option[DimFilter],
columnTypes: Option[Set[String]],
broadcastedHadoopConf: Broadcast[SerializableHadoopConfiguration],
conf: Configuration,
useSparkConfForDeepStorage: Boolean,
useCompactSketches: Boolean,
useDefaultNullHandling: Boolean,
batchSize: Int
extends DruidBaseInputPartitionReader(
) with InputPartitionReader[ColumnarBatch] with Logging {
// Vectorized Druid cursor over the segment. fillVectors reads through it; next() checks cursor.isDone.
// `segment` is not declared in this class; presumably inherited from DruidBaseInputPartitionReader.
private val cursor: VectorCursor = segment.asStorageAdapter().makeVectorCursor(,
null) // scalastyle:ignore null
// One on-heap vector per schema field, each with capacity for batchSize rows.
private val columnVectors: Array[OnHeapColumnVector] = OnHeapColumnVector.allocateColumns(batchSize, schema)
// Batch returned to Spark; presumably wraps columnVectors (TODO confirm).
private val resultBatch: ColumnarBatch = new ColumnarBatch([ColumnVector]))
/**
 * InputPartitionReader hook: advances to the next batch, branching on whether the Druid cursor is exhausted
 * (`cursor.isDone`). Presumably, while data remains, it fills the column vectors, advances the cursor and
 * returns true. Once the cursor is done it presumably returns false. TODO confirm.
 */
override def next(): Boolean = {
if (!cursor.isDone) {
} else {
/**
 * InputPartitionReader hook: returns the current ColumnarBatch to Spark. Presumably this is `resultBatch`,
 * with its row count set to the number of rows filled by the most recent next() (TODO confirm).
 */
override def get(): ColumnarBatch = {
/**
 * Releases the resources held by this reader.
 *
 * The cursor, segment and temporary directory are each touched only when non-null (`Option(x).nonEmpty`).
 * `segment` and `tmpDir` are not declared in this class; presumably they are inherited from
 * DruidBaseInputPartitionReader.
 *
 * If any step throws, this method:
 *  - logs a warning;
 *  - checks the temporary directory again (if it still exists) so it does not leak;
 *  - rethrows the original exception.
 */
override def close(): Unit = {
try {
if (Option(cursor).nonEmpty) {
if (Option(segment).nonEmpty) {
if (Option(tmpDir).nonEmpty) {
} catch {
case e: Exception =>
// Since we're just going to rethrow e and tearing down the JVM will clean up the result batch, column vectors,
// cursor, and segment even if we can't, the only leak we have to worry about is the temp file. Spark should
// clean up temp files as well, but rather than rely on that we'll try to take care of it ourselves.
logWarn("Encountered exception attempting to close a DruidColumnarInputPartitionReader!")
if (Option(tmpDir).nonEmpty && tmpDir.exists()) {
throw e
// TODO: Maybe ColumnProcessors can help here? Need to investigate
/**
 * Populates every Spark column vector from the Druid vector cursor's current batch.
 *
 * Each schema field is paired by index with its vector in `columnVectors`. The Druid column capabilities,
 * reported by the cursor's selector factory, choose the fill strategy:
 *  - no capabilities (null): [[fillNullVector]]
 *  - FLOAT, LONG, DOUBLE: [[fillNumericVector]]
 *  - STRING: [[fillStringVector]], told whether the column may hold multiple values per row
 *  - COMPLEX: [[fillComplexVector]]
 *  - any other ValueType: an IAE is thrown
 */
private[reader] def fillVectors(): Unit = {
val selectorFactory = cursor.getColumnSelectorFactory
schema.fields.zipWithIndex.foreach{case(col, i) =>
val capabilities = selectorFactory.getColumnCapabilities(
val columnVector = columnVectors(i)
// No capabilities reported (presumably the column is absent from this segment): fill with nulls/defaults.
if (capabilities == null) { // scalastyle:ignore null
fillNullVector(columnVector, col)
} else {
capabilities.getType match {
case ValueType.FLOAT | ValueType.LONG | ValueType.DOUBLE =>
fillNumericVector(capabilities.getType, selectorFactory, columnVector,
case ValueType.STRING =>
fillStringVector(selectorFactory, columnVector, col, capabilities.hasMultipleValues.isMaybeTrue)
case ValueType.COMPLEX =>
fillComplexVector(selectorFactory, columnVector, col)
case _ => throw new IAE(s"Unrecognized ValueType ${capabilities.getType}!")
* Fill a Spark ColumnVector with the values from a Druid VectorSelector containing numeric rows.
* The general pattern is:
* 1) If there are no null values in the Druid data, just copy the backing array over
* 2) If there are nulls (the null vector is not itself null), for each index in the Druid vector check
* if the source vector is null at that index and if so insert the appropriate null value into the
* Spark vector. Otherwise, copy over the value at that index from the Druid vector.
* @param valueType The ValueType of the Druid column to fill COLUMNVECTOR from.
* @param selectorFactory The Druid SelectorFactory backed by the data read in from segment files.
* @param columnVector The Spark ColumnVector to fill with the data from SELECTORFACTORY.
* @param name The name of the column in Druid to source data from.
private[reader] def fillNumericVector(
valueType: ValueType,
selectorFactory: VectorColumnSelectorFactory,
columnVector: WritableColumnVector,
name: String
): Unit = {
val selector = selectorFactory.makeValueSelector(name)
val vectorLength = selector.getCurrentVectorSize
// When getNullVector returns null, no row in the current batch is null, which enables the bulk-copy fast path.
val nulls = selector.getNullVector
valueType match {
case ValueType.FLOAT =>
val vector = selector.getFloatVector
if (nulls == null) { // scalastyle:ignore null
// Fast path: copy the whole primitive array in a single call.
columnVector.putFloats(0, vectorLength, vector, 0)
} else {
// Slow path: copy row by row so null rows can be handled individually.
(0 until vectorLength).foreach { i =>
if (nulls(i)) {
// With default null handling, a null value is stored as 0.
if (useDefaultNullHandling) {
columnVector.putFloat(i, 0)
} else {
} else {
columnVector.putFloat(i, vector(i))
case ValueType.LONG =>
// Same fast-path/slow-path pattern as FLOAT, using the selector's long vector.
val vector = selector.getLongVector
if (nulls == null) { // scalastyle:ignore null
columnVector.putLongs(0, vectorLength, vector, 0)
} else {
(0 until vectorLength).foreach { i =>
if (nulls(i)) {
if (useDefaultNullHandling) {
columnVector.putLong(i, 0)
} else {
} else {
columnVector.putLong(i, vector(i))
case ValueType.DOUBLE =>
// Same fast-path/slow-path pattern as FLOAT, using the selector's double vector.
val vector = selector.getDoubleVector
if (nulls == null) { // scalastyle:ignore null
columnVector.putDoubles(0, vectorLength, vector, 0)
} else {
(0 until vectorLength).foreach { i =>
if (nulls(i)) {
if (useDefaultNullHandling) {
columnVector.putDouble(i, 0)
} else {
} else {
columnVector.putDouble(i, vector(i))
// NOTE(review): the message below says "will a numeric value type"; "with" was presumably intended.
case _ => throw new IAE(s"Must call fillNumericVector will a numeric value type; called with $valueType!")
* Fill a Spark ColumnVector with the values from a Druid VectorSelector containing string rows.
* In theory, we could define a ColumnVector implementation that handled single- and multi-valued strings
* intelligently while falling back to the existing behavior for other data types. Unfortunately, Spark marks
* OnHeapColumnVector as final so we'd need to copy the underlying logic and maintain it ourselves or abuse
* reflection. Additionally, Spark doesn't really do anything clever with columnar dataframes in 2.4. Specifically
* for multi-valued string columns this means that under the hood Spark will immediately convert each sub-array
* (i.e. row) into an Object[], so we won't gain anything by maintaining the value dictionary. Instead, we define
* a SingleValueDimensionDictionary to handle the single-valued case and reify multi-valued dimensions ourselves to
* reduce complexity.
* There are also a couple of open questions to investigate:
* First, how does Spark expect nulls to be flagged from dictionaries? If dictionaries can happily return null, then
* we can just drop the row vector in the dictionary creation and be on our way. If Spark expects nulls to be flagged
* explicitly, then we'll need to figure out how the different Druid null handling strategies change both what gets
* stored on disk and what we read here from the SingleValueDimensionSelector. In this case, based on
* PossiblyNullDimensionSelector we'll likely need to iterate over the row vector returned by the selector and call
* either putNull if the value at the index is 0 or putInt otherwise.
* Second, can Druid dictionaries change between parts of the segment file (i.e. in different smooshes)? If they can,
* we need to add checks for that case and fall back to putting byte arrays into the column vector directly for
* single-valued dimensions.
* @param selectorFactory The Druid SelectorFactory backed by the data read in from segment files.
* @param columnVector The Spark ColumnVector to fill with the data from SELECTORFACTORY.
* @param column The Spark column schema we're filling.
* @param maybeHasMultipleValues Whether or not the Druid column we're reading from may contain multiple values.
private[reader] def fillStringVector(
selectorFactory: VectorColumnSelectorFactory,
columnVector: WritableColumnVector,
column: StructField,
maybeHasMultipleValues: Boolean
): Unit = {
if (maybeHasMultipleValues) {
// The column's capabilities say rows may hold multiple values. Each row's values are copied into the child
// arrayData vector, and the row is pointed at that slice with putArray.
val selector = selectorFactory.makeMultiValueDimensionSelector(DefaultDimensionSpec.of(
val vector = selector.getRowVector
val vectorLength = selector.getCurrentVectorSize
// This will store repeated strings multiple times. CPU should be more important than storage here, but
// if the repeated strings are a problem and reducing the batch size doesn't help, we could implement our
// own ColumnVector that tracks the row for each string in the lookup dict and then stores arrays of rowIds.
// We'd need two vectors (the main ColumnVector, which would store an array of ints for each actual row id
// and an arrayData column vector, which would store strings at each internal row id.) When we read in an
// array of IndexedInts, we'd check to see if we'd already stored the corresponding string in arrayData and
// if so just use the existing internal row. The ints in the main vector would point to the internal row ids
// and we'd override ColumnVector#getArray(rowId: Int) to follow the logic on read. This would preserve the
// space savings of the dictionary-encoding at the cost of possibly more CPU at read.
val arrayData = columnVector.arrayData()
// Note that offsets here are in rows, not bytes
var columnVectorOffset = 0
var arrayDataOffset = 0
// Iterating over the populated elements of vector twice is faster than reserving additional capacity as
// each new row is processed since reserving more capacity means copying arrays.
// NOTE(review): vector(_).size() throws a NullPointerException for a null row, yet the loop below explicitly
// checks `arr == null`. If null rows are possible, this count needs the same guard.
val numberOfValuesInBatch = (0 until vectorLength).map(vector(_).size()).sum
(0 until vectorLength).foreach{i =>
val arr = vector(i)
if (arr == null) {
// TODO: Is this possible? Need to test
} else {
val numberOfValuesInRow = arr.size() // Number of values in this row
(0 until numberOfValuesInRow).foreach { idx =>
val id = arr.get(idx)
val bytes = StringUtils.toUtf8(selector.lookupName(id))
arrayData.putByteArray(arrayDataOffset, bytes)
arrayDataOffset += 1
columnVector.putArray(i, columnVectorOffset, numberOfValuesInRow)
columnVectorOffset += numberOfValuesInRow
} else {
// The column's capabilities say each row holds at most one value.
val selector = selectorFactory.makeSingleValueDimensionSelector(DefaultDimensionSpec.of(
val vector = selector.getRowVector
val vectorLength = selector.getCurrentVectorSize
if (column.dataType.isInstanceOf[ArrayType]) {
// The Spark schema expects an array column, so wrap each single value in a one-element array (arrayData slot i).
val arrayData = columnVector.arrayData()
// TODO: Work out null handling (see SingleValueDimensionDictionary as well)
(0 until vectorLength).foreach{i =>
val bytes = StringUtils.toUtf8(selector.lookupName(vector(i)))
arrayData.putByteArray(i, bytes)
columnVector.putArray(i, i,1)
} else {
// Scalar string column: build a SingleValueDimensionDictionary from Druid's lookups and write Druid's
// dictionary ids into the vector's dictionary-id slots.
// TODO: This materializes every dictionary entry (0 until cardinality) on each call; find a cheaper way to
// extract the lookups.
val cardinality = selector.getValueCardinality
if (cardinality == -1) {
throw new ISE("Encountered dictionary with unknown cardinality, vectorized reading not supported!")
val lookupMap = (0 until cardinality).map { id =>
id -> selector.lookupName(id)
val colDict = new SingleValueDimensionDictionary(lookupMap)
val dictionaryIds = columnVector.reserveDictionaryIds(vectorLength)
dictionaryIds.appendInts(vectorLength, vector, 0)
/**
 * Fills `columnVector` with byte-array values read from a Druid object selector for a COMPLEX column.
 *
 * Each object in the current batch is handled as follows:
 *  - null objects take a separate branch;
 *  - objects whose class is registered with [[ComplexTypeRegistry]] are converted to bytes with
 *    `ComplexTypeRegistry.deserialize`;
 *  - raw `Array[Byte]` values are copied as-is;
 *  - anything else throws an IllegalArgumentException that suggests registering a Complex Type Plugin.
 *
 * @param selectorFactory The Druid SelectorFactory backed by the data read in from segment files.
 * @param columnVector    The Spark ColumnVector to fill.
 * @param column          The Spark column schema being filled.
 */
private[reader] def fillComplexVector(
selectorFactory: VectorColumnSelectorFactory,
columnVector: WritableColumnVector,
column: StructField
): Unit = {
val selector = selectorFactory.makeObjectSelector(
val vector = selector.getObjectVector
val vectorLength = selector.getCurrentVectorSize
(0 until vectorLength).foreach{i =>
val obj = vector(i)
if (obj == null) { // scalastyle:ignore null
} else if (ComplexTypeRegistry.getRegisteredSerializedClasses.contains(obj.getClass)) {
val bytes = ComplexTypeRegistry.deserialize(obj)
columnVector.putByteArray(i, bytes)
} else {
// Objects of unregistered classes are only accepted if they are already raw bytes.
obj match {
case arr: Array[Byte] =>
columnVector.putByteArray(i, arr)
// NOTE(review): the message interpolates `column.getClass`, which is always StructField, not the offending
// value's class. `obj.getClass` was presumably intended.
case _ => throw new IllegalArgumentException(
s"Unable to parse ${column.getClass.toString} into a ByteArray! Try registering a Complex Type Plugin."
/**
 * Fills the current batch's rows of `columnVector` for a column with no Druid capabilities (see fillVectors).
 *
 * With default null handling, rows receive type defaults:
 *  - 0 for float, long, timestamp and double columns;
 *  - an empty string for string columns;
 *  - a one-element array holding an empty string for string-array columns.
 * Complex and unrecognized types are nulled in either mode. Without default null handling, every row is null.
 *
 * @param columnVector The Spark ColumnVector to fill.
 * @param column       The Spark column schema being filled; its dataType selects the default value.
 */
private[reader] def fillNullVector(columnVector: WritableColumnVector, column: StructField): Unit = {
val vectorLength = cursor.getCurrentVectorSize
if (useDefaultNullHandling) {
column.dataType match {
case FloatType =>
columnVector.putFloats(0, vectorLength, 0)
case LongType | TimestampType =>
columnVector.putLongs(0, vectorLength, 0)
case DoubleType =>
columnVector.putDoubles(0, vectorLength, 0)
case StringType =>
(0 until vectorLength).foreach{i =>
columnVector.putByteArray(i, Array.emptyByteArray)
case ArrayType(StringType, _) =>
// Each row becomes a one-element array whose single entry is an empty string (arrayData slot i).
val arrayData = columnVector.arrayData()
(0 until vectorLength).foreach{i =>
arrayData.putByteArray(i, Array.emptyByteArray)
columnVector.putArray(i, i,1)
case _ => // Complex Types use nulls regardless of null handling mode. Also nulling unknown types.
columnVector.putNulls(0, vectorLength)
} else {
columnVector.putNulls(0, cursor.getCurrentVectorSize)