/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.spark.v2

import com.google.common.base.{Supplier, Suppliers}
import org.apache.commons.dbcp2.BasicDataSource
import org.apache.druid.java.util.common.granularity.GranularityType
import org.apache.druid.java.util.common.{FileUtils, Intervals, StringUtils}
import org.apache.druid.metadata.{MetadataStorageConnectorConfig, MetadataStorageTablesConfig,
SQLMetadataConnector}
import org.apache.druid.spark.MAPPER
import org.apache.druid.spark.configuration.DruidConfigurationKeys
import org.apache.druid.spark.registries.SQLConnectorRegistry
import org.apache.druid.spark.utils.SchemaUtils
import org.apache.druid.timeline.DataSegment
import org.apache.druid.timeline.partition.NumberedShardSpec
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
import org.apache.spark.sql.types.{ArrayType, BinaryType, DoubleType, FloatType, LongType,
StringType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.unsafe.types.UTF8String
import org.joda.time.Interval
import org.skife.jdbi.v2.{DBI, Handle}
import org.skife.jdbi.v2.exceptions.UnableToObtainConnectionException
import java.io.File
import java.util.{Properties, UUID, List => JList, Map => JMap}
import scala.collection.JavaConverters.{asScalaIteratorConverter, mapAsJavaMapConverter,
seqAsJavaListConverter}
import scala.collection.mutable.ArrayBuffer
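
/**
* Fixtures and helpers shared by the Druid-Spark DataSource V2 tests: sample segments on local
* disk, a Spark schema matching those segments, and utilities for managing an embedded Derby
* metadata store and temporary working directories.
*/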
trait DruidDataSourceV2TestUtils {
val dataSource: String = "spark_druid_test"
val interval: Interval = Intervals.of("2020-01-01T00:00:00.000Z/2020-01-02T00:00:00.000Z")
val secondInterval: Interval = Intervals.of("2020-01-02T00:00:00.000Z/2020-01-03T00:00:00.000Z")
val version: String = "0"
val segmentsDir: File =
new File(makePath("src", "test", "resources", "segments")).getCanonicalFile
val firstSegmentPath: String =
makePath("spark_druid_test", "2020-01-01T00:00:00.000Z_2020-01-02T00:00:00.000Z", "0", "0", "index.zip")
val secondSegmentPath: String =
makePath("spark_druid_test", "2020-01-01T00:00:00.000Z_2020-01-02T00:00:00.000Z", "0", "1", "index.zip")
val thirdSegmentPath: String =
makePath("spark_druid_test", "2020-01-02T00:00:00.000Z_2020-01-03T00:00:00.000Z", "0", "0", "index.zip")
val loadSpec: String => JMap[String, AnyRef] = (path: String) =>
Map[String, AnyRef]("type" -> "local", "path" -> path).asJava
val dimensions: JList[String] = List("dim1", "dim2", "id1", "id2").asJava
val metrics: JList[String] = List(
"count", "sum_metric1","sum_metric2","sum_metric3","sum_metric4","uniq_id1").asJava
val metricsSpec: String =
"""[
| { "type": "count", "name": "count" },
| { "type": "longSum", "name": "sum_metric1", "fieldName": "sum_metric1" },
| { "type": "longSum", "name": "sum_metric2", "fieldName": "sum_metric2" },
| { "type": "doubleSum", "name": "sum_metric3", "fieldName": "sum_metric3" },
| { "type": "floatSum", "name": "sum_metric4", "fieldName": "sum_metric4" },
| { "type": "thetaSketch", "name": "uniq_id1", "fieldName": "uniq_id1", "isInputThetaSketch": true }
|]""".stripMargin
val binaryVersion: Integer = 9
val timestampColumn: String = "__time"
val timestampFormat: String = "auto"
val segmentGranularity: String = GranularityType.DAY.name
val firstSegment: DataSegment = new DataSegment(
dataSource,
interval,
version,
loadSpec(makePath(segmentsDir.getCanonicalPath, firstSegmentPath)),
dimensions,
metrics,
new NumberedShardSpec(0, 0),
binaryVersion,
3278L
)
val secondSegment: DataSegment = new DataSegment(
dataSource,
interval,
version,
loadSpec(makePath(segmentsDir.getCanonicalPath, secondSegmentPath)),
dimensions,
metrics,
new NumberedShardSpec(1, 0),
binaryVersion,
3299L
)
val thirdSegment: DataSegment = new DataSegment(
dataSource,
secondInterval,
version,
loadSpec(makePath(segmentsDir.getCanonicalPath, thirdSegmentPath)),
dimensions,
metrics,
new NumberedShardSpec(0, 0),
binaryVersion,
3409L
)
val firstSegmentString: String = MAPPER.writeValueAsString(firstSegment)
val secondSegmentString: String = MAPPER.writeValueAsString(secondSegment)
val thirdSegmentString: String = MAPPER.writeValueAsString(thirdSegment)
val idOneSketch: Array[Byte] = StringUtils.decodeBase64String("AQMDAAA6zJNV0wc7TCHDCQ==")
val idTwoSketch: Array[Byte] = StringUtils.decodeBase64String("AQMDAAA6zJNHlmybd5/laQ==")
val idThreeSketch: Array[Byte] = StringUtils.decodeBase64String("AQMDAAA6zJOppPrHQT61Dw==")
val firstTimeBucket: Long = 1577836800000L
val secondTimeBucket: Long = 1577923200000L
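// Spark schema matching the dimensions and metrics of the test segments above.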
val schema: StructType = StructType(Seq[StructField](
StructField("__time", LongType),
StructField("dim1", ArrayType(StringType, false)),
StructField("dim2", StringType),
StructField("id1", StringType),
StructField("id2", StringType),
StructField("count", LongType),
StructField("sum_metric1", LongType),
StructField("sum_metric2", LongType),
StructField("sum_metric3", DoubleType),
StructField("sum_metric4", FloatType),
StructField("uniq_id1", BinaryType)
))
val columnTypes: Option[Set[String]] =
Option(Set("LONG", "STRING", "FLOAT", "DOUBLE", "thetaSketch"))
private val tempDirs: ArrayBuffer[String] = new ArrayBuffer[String]()
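// Creates a new temporary working directory and tracks it so cleanUpWorkingDirectory() can delete it later.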
def testWorkingStorageDirectory: String = {
val tempDir = FileUtils.createTempDir("druid-spark-tests").getCanonicalPath
tempDirs += tempDir
tempDir
}
private val testDbUri = "jdbc:derby:memory:TestDatabase"
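// Appends a random suffix so each test gets its own in-memory Derby database.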
def generateUniqueTestUri(): String = testDbUri + dbSafeUUID
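// Metadata client properties pointing the connector at the embedded Derby database at the given URI.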
val metadataClientProps: String => Map[String, String] = (uri: String) => Map[String, String](
s"${DruidConfigurationKeys.metadataPrefix}.${DruidConfigurationKeys.metadataDbTypeKey}" -> "embedded_derby",
s"${DruidConfigurationKeys.metadataPrefix}.${DruidConfigurationKeys.metadataConnectUriKey}" -> uri
)
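// Helpers for creating, connecting to, and shutting down the in-memory Derby test database.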
def createTestDb(uri: String): Unit = new DBI(s"$uri;create=true").open().close()
def openDbiToTestDb(uri: String): Handle = new DBI(uri).open()
def tearDownTestDb(uri: String): Unit = {
try {
new DBI(s"$uri;shutdown=true").open().close()
} catch {
// Closing an in-memory Derby database throws an expected exception. It bubbles up as an
// UnableToObtainConnectionException from the skife JDBI library.
// TODO: Just open the connection directly and check the exception there
case _: UnableToObtainConnectionException =>
}
}
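// Registers a constructor for a Derby-backed SQLMetadataConnector under the key "embedded_derby"
// so tests can exercise metadata reads and writes against a local in-memory store. The segment
// table is created when the connector is constructed.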
def registerEmbeddedDerbySQLConnector(): Unit = {
SQLConnectorRegistry.register("embedded_derby",
(connectorConfigSupplier: Supplier[MetadataStorageConnectorConfig],
metadataTableConfigSupplier: Supplier[MetadataStorageTablesConfig]) => {
val connectorConfig = connectorConfigSupplier.get()
val amendedConnectorConfigSupplier =
new MetadataStorageConnectorConfig {
override def isCreateTables: Boolean = true
override def getHost: String = connectorConfig.getHost
override def getPort: Int = connectorConfig.getPort
override def getConnectURI: String = connectorConfig.getConnectURI
override def getUser: String = connectorConfig.getUser
override def getPassword: String = connectorConfig.getPassword
override def getDbcpProperties: Properties = connectorConfig.getDbcpProperties
}
val res: SQLMetadataConnector =
new SQLMetadataConnector(Suppliers.ofInstance(amendedConnectorConfigSupplier), metadataTableConfigSupplier) {
val datasource: BasicDataSource = getDatasource
datasource.setDriverClassLoader(getClass.getClassLoader)
datasource.setDriverClassName("org.apache.derby.jdbc.EmbeddedDriver")
private val dbi = new DBI(connectorConfigSupplier.get().getConnectURI)
private val SERIAL_TYPE = "BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)"
override def getSerialType: String = SERIAL_TYPE
override def getStreamingFetchSize: Int = 1
override def getQuoteString: String = "\\\""
override def tableExists(handle: Handle, tableName: String): Boolean =
!handle.createQuery("select * from SYS.SYSTABLES where tablename = :tableName")
.bind("tableName", StringUtils.toUpperCase(tableName)).list.isEmpty;
override def getDBI: DBI = dbi
}
res.createSegmentTable()
res
})
}
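// Deletes every temporary working directory created via testWorkingStorageDirectory.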
def cleanUpWorkingDirectory(): Unit = {
tempDirs.foreach(dir => FileUtils.deleteDirectory(new File(dir).getCanonicalFile))
}
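// Drains an InputPartitionReader into an in-memory Seq and closes the reader.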
def partitionReaderToSeq(reader: InputPartitionReader[InternalRow]): Seq[InternalRow] = {
val res = new ArrayBuffer[InternalRow]()
while (reader.next()) {
res += reader.get()
}
reader.close()
res
}
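// Columnar counterpart to partitionReaderToSeq; rows are copied out of each ColumnarBatch so they
// remain valid after the reader is closed.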
def columnarPartitionReaderToSeq(reader: InputPartitionReader[ColumnarBatch]): Seq[InternalRow] = {
val res = new ArrayBuffer[InternalRow]()
// ColumnarBatches return MutableColumnarRows, so we need to copy them before we close the reader
while (reader.next()) {
val batch = reader.get()
batch.rowIterator().asScala.foreach { row =>
// MutableColumnarRows don't support copying ArrayTypes, so we can't use row.copy()
val finalizedRow = new GenericInternalRow(batch.numCols())
(0 until batch.numCols()).foreach { col =>
if (row.isNullAt(col)) {
finalizedRow.setNullAt(col)
} else {
val dataType = batch.column(col).dataType()
dataType match {
case _: ArrayType =>
// Druid only supports multi-value String columns, so hard-code that assumption here for now
val finalizedArr = row.getArray(col).array.map(el => el.asInstanceOf[UTF8String].copy())
finalizedRow.update(col, ArrayData.toArrayData(finalizedArr))
case _ =>
finalizedRow.update(col, row.get(col, dataType))
}
}
}
res += finalizedRow
}
}
reader.close()
res
}
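// Converts a Seq of plain Scala values into an InternalRow conforming to the given schema,
// coercing scalar and array elements via SchemaUtils.parseToScala.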
def wrapSeqToInternalRow(seq: Seq[Any], schema: StructType): InternalRow = {
InternalRow.fromSeq(seq.zipWithIndex.map { case (elem, i) =>
if (elem == null) { // scalastyle:ignore null
null // scalastyle:ignore null
} else {
schema(i).dataType match {
case _: ArrayType =>
val baseType = schema(i).dataType.asInstanceOf[ArrayType].elementType
elem match {
case collection: Traversable[_] =>
ArrayData.toArrayData(collection.map { arrayElem =>
SchemaUtils.parseToScala(arrayElem, baseType)
})
case _ =>
// Single-element arrays
ArrayData.toArrayData(List(SchemaUtils.parseToScala(elem, baseType)))
}
case _ => SchemaUtils.parseToScala(elem, schema(i).dataType)
}
}
})
}
/**
* Given a DataFrame df, returns a collection of arrays of Rows where each array contains all rows
* for a single partition of df.
*
* @param df The DataFrame to extract partitions from.
* @return A Seq[Array[Row]] where each Array[Row] contains all rows for a corresponding partition of df.
*/
def getDataFramePartitions(df: DataFrame): Seq[Array[Row]] = {
df
.rdd
.map(row => TaskContext.getPartitionId() -> row)
.collect()
.groupBy(_._1)
.values
.map(_.map(_._2))
.toSeq
}
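// Joins path components using the platform-specific file separator.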
def makePath(components: String*): String = {
components.mkString(File.separator)
}
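// A random UUID with dashes removed, safe to embed in a Derby database name.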
def dbSafeUUID: String = StringUtils.removeChar(UUID.randomUUID.toString, '-')
}