/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.streaming

import java.io.{File, FileWriter}
import java.nio.file.Paths
import java.util.concurrent.TimeUnit

import scala.jdk.CollectionConverters._

import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkException
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.functions.{col, lit, udf, window}
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
import org.apache.spark.sql.test.{IntegrationTestUtils, QueryTest, SQLHelper}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.SparkFileUtils

class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
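// Location of static test data (csv and txt subdirectories) used by the file-based source tests.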
private val testDataPath = Paths
.get(
IntegrationTestUtils.sparkHome,
"connector",
"connect",
"common",
"src",
"test",
"resources",
"query-tests",
"test-data",
"streaming")
test("Streaming API with windowed aggregate query") {
// This verifies the standard streaming API by starting a streaming query with a windowed count.
withSQLConf(
"spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
) {
val readDF = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
// Verify schema (results in sending an RPC)
assert(readDF.schema.toDDL == "timestamp TIMESTAMP,value BIGINT")
val countsDF = readDF
.withWatermark("timestamp", "10 seconds")
.groupBy(window(col("timestamp"), "5 seconds"))
.count()
.selectExpr("window.start as timestamp", "count as num_events")
assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT NOT NULL")
// Start the query
val queryName = "sparkConnectStreamingQuery"
val query = countsDF.writeStream
.format("memory")
.queryName(queryName)
.trigger(Trigger.ProcessingTime("1 second"))
.start()
try {
// Verify some of the StreamingQuery API.
assert(query.isActive)
eventually(timeout(30.seconds)) {
assert(query.status.isDataAvailable)
assert(query.recentProgress.nonEmpty) // Query made progress.
}
val lastProgress = query.lastProgress
assert(lastProgress != null)
assert(lastProgress.name == queryName)
assert(!lastProgress.durationMs.isEmpty)
assert(!lastProgress.eventTime.isEmpty)
assert(lastProgress.stateOperators.nonEmpty)
assert(
lastProgress.stateOperators.head.customMetrics.keySet().asScala == Set(
"loadedMapCacheHitCount",
"loadedMapCacheMissCount",
"stateOnCurrentVersionSizeBytes"))
assert(lastProgress.sources.nonEmpty)
assert(lastProgress.sink.description == "MemorySink")
assert(lastProgress.observedMetrics.isEmpty)
query.recentProgress.foreach { p =>
assert(p.id == lastProgress.id)
assert(p.runId == lastProgress.runId)
assert(p.name == lastProgress.name)
}
query.explain() // Prints the plan to console.
// Consider verifying explain output by capturing stdout similar to
// test("Dataset explain") in ClientE2ETestSuite.
} finally {
// Don't wait for any processed data. Otherwise the test could take multiple seconds.
query.stop()
// The query should still be accessible after stopped.
assert(!query.isActive)
assert(query.recentProgress.nonEmpty)
}
}
}
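// Verifies DataStreamWriter.toTable() and DataStreamReader.table(): q1 streams rate-source rows
// into a managed table, and q2 streams that table into a memory sink.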
test("Streaming table API") {
withSQLConf(
"spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
) {
spark.sql("DROP TABLE IF EXISTS my_table").collect()
withTempPath { ckpt =>
val q1 = spark.readStream
.format("rate")
.load()
.writeStream
.option("checkpointLocation", ckpt.getCanonicalPath)
.toTable("my_table")
val q2 = spark.readStream
.table("my_table")
.writeStream
.format("memory")
.queryName("my_sink")
.start()
try {
q1.processAllAvailable()
q2.processAllAvailable()
eventually(timeout(30.seconds)) {
assert(spark.table("my_sink").count() > 0)
}
} finally {
q1.stop()
q2.stop()
spark.sql("DROP TABLE my_table")
}
}
}
}
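// Reads the static CSV test data with read options, an explicit schema and Trigger.AvailableNow,
// and expects the two input records to land in the memory sink.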
test("stream read options with csv source and Trigger.AvailableNow") {
withTempPath { ckpt =>
val q = spark.readStream
.format("csv")
.option("sep", ";")
.option("header", "true")
.option("path", testDataPath.resolve("csv").toString)
.schema(StructType(Array(
StructField("name", StringType),
StructField("age", IntegerType),
StructField("job", StringType))))
.load()
.writeStream
.option("checkpointLocation", ckpt.getCanonicalPath)
.format("memory")
.queryName("my_sink_csv")
.trigger(Trigger.AvailableNow())
.start()
try {
q.processAllAvailable()
eventually(timeout(30.seconds)) {
assert(spark.table("my_sink_csv").count() == 2)
}
} finally {
q.stop()
}
}
}
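// Streams the static text data into a partitioned parquet sink and checks that output files
// (other than metadata files) appear under the output path.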
test("stream save options with txt source") {
withTempPath { path =>
val checkpointPath = s"${path.getCanonicalPath}/_checkpoint"
val outputPath = s"${path.getCanonicalPath}/out"
val q = spark.readStream
.format("text")
.load(testDataPath.resolve("txt").toString)
.withColumn("age", lit(1))
.writeStream
.option("checkpointLocation", checkpointPath)
.format("parquet")
.partitionBy("age")
.outputMode("append")
.option("path", outputPath)
.start()
try {
q.processAllAvailable()
eventually(timeout(30.seconds)) {
val file = new File(outputPath)
assert(file.listFiles().exists(!_.getName.startsWith("_")))
}
} finally {
q.stop()
}
}
}
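// A temp view created from a streaming DataFrame should itself be streaming and usable as a
// source for a streaming query.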
test("streaming with temp view") {
spark.sql("CREATE TABLE input_table (value string) USING parquet")
spark.sql("INSERT INTO input_table VALUES ('a'), ('b'), ('c')")
val df = spark.readStream.table("input_table")
assert(df.isStreaming)
df.createOrReplaceTempView("test_view")
val viewDf = spark.sql("SELECT * FROM test_view")
assert(viewDf.isStreaming)
val q = viewDf.writeStream.format("memory").queryName("test_view_sink").start()
try {
q.processAllAvailable()
eventually(timeout(30.seconds)) {
assert(spark.table("test_view_sink").count() == 3)
}
} finally {
q.stop()
spark.sql("DROP VIEW IF EXISTS test_view")
spark.sql("DROP TABLE IF EXISTS input_table")
}
}
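// awaitTermination(timeoutMs) should block for at least the timeout and return false while the
// query is active; after stop() the no-arg awaitTermination() should return.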
test("awaitTermination") {
withSQLConf(
"spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
) {
val q = spark.readStream
.format("rate")
.load()
.writeStream
.format("memory")
.queryName("test")
.start()
val start = System.nanoTime
val terminated = q.awaitTermination(500)
val end = System.nanoTime
assert((end - start) / 1e6 >= 500)
assert(!terminated)
q.stop()
eventually(timeout(1.minute)) {
q.awaitTermination()
}
}
}
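// A UDF that throws should fail the query, and awaitTermination() should surface a
// StreamingQueryException carrying the query id, run id, offsets and the root cause.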
test("throw exception in streaming") {
try {
val session = spark
import session.implicits._
val checkForTwo = udf((value: Int) => {
if (value == 2) {
throw new RuntimeException("Number 2 encountered!")
}
value
})
val query = spark.readStream
.format("rate")
.option("rowsPerSecond", "1")
.load()
.select(checkForTwo($"value").as("checkedValue"))
.writeStream
.outputMode("append")
.format("console")
.start()
val exception = intercept[StreamingQueryException] {
query.awaitTermination()
}
assert(exception.getErrorClass != null)
assert(exception.getMessageParameters().get("id") == query.id.toString)
assert(exception.getMessageParameters().get("runId") == query.runId.toString)
assert(!exception.getMessageParameters().get("startOffset").isEmpty)
assert(!exception.getMessageParameters().get("endOffset").isEmpty)
assert(exception.getCause.isInstanceOf[SparkException])
assert(exception.getCause.getCause.isInstanceOf[SparkException])
assert(
exception.getCause.getCause.getMessage
.contains("java.lang.RuntimeException: Number 2 encountered!"))
} finally {
spark.streams.resetTerminated()
}
}
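// Same failure scenario as above, but the exception is observed via
// StreamingQueryManager.awaitAnyTermination().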
test("throw exception in streaming, check with StreamingQueryManager") {
val session = spark
import session.implicits._
val checkForTwo = udf((value: Int) => {
if (value == 2) {
throw new RuntimeException("Number 2 encountered!")
}
value
})
val query = spark.readStream
.format("rate")
.option("rowsPerSecond", "1")
.load()
.select(checkForTwo($"value").as("checkedValue"))
.writeStream
.outputMode("append")
.format("console")
.start()
val exception = intercept[StreamingQueryException] {
spark.streams.awaitAnyTermination()
}
assert(exception.getErrorClass != null)
assert(exception.getMessageParameters().get("id") == query.id.toString)
assert(exception.getMessageParameters().get("runId") == query.runId.toString)
assert(!exception.getMessageParameters().get("startOffset").isEmpty)
assert(!exception.getMessageParameters().get("endOffset").isEmpty)
assert(exception.getCause.isInstanceOf[SparkException])
assert(exception.getCause.getCause.isInstanceOf[SparkException])
assert(
exception.getCause.getCause.getMessage
.contains("java.lang.RuntimeException: Number 2 encountered!"))
}
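// foreach() sink with an untyped ForeachWriter[Row]; only checks that the query starts cleanly.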
test("foreach Row") {
val writer = new TestForeachWriter[Row]
val df = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.load()
val query = df.writeStream
.foreach(writer)
.outputMode("update")
.start()
assert(query.isActive)
assert(query.exception.isEmpty)
query.stop()
}
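// foreach() sink with a typed ForeachWriter over Dataset[Int].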
test("foreach Int") {
val session: SparkSession = spark
import session.implicits._
val writer = new TestForeachWriter[Int]
val df = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.load()
val query = df
.selectExpr("CAST(value AS INT)")
.as[Int]
.writeStream
.foreach(writer)
.outputMode("update")
.start()
assert(query.isActive)
assert(query.exception.isEmpty)
query.stop()
}
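// foreach() sink with a typed ForeachWriter over a Dataset of a user-defined case class.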
test("foreach Custom class") {
val session: SparkSession = spark
import session.implicits._
val writer = new TestForeachWriter[TestClass]
val df = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.load()
val query = df
.selectExpr("CAST(value AS INT)")
.as[TestClass]
.writeStream
.foreach(writer)
.outputMode("update")
.start()
assert(query.isActive)
assert(query.exception.isEmpty)
query.stop()
}
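// StreamingQueryManager lookups: get(id) and active should return handles to the same query,
// awaitAnyTermination(timeoutMs) should time out while the query runs, and get(id) should return
// null after the query is stopped.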
test("streaming query manager") {
assert(spark.streams.active.isEmpty)
val q = spark.readStream
.format("rate")
.load()
.writeStream
.format("console")
.start()
assert(q.name == null)
val q1 = spark.streams.get(q.id)
val q2 = spark.streams.active(0)
assert(q.id == q1.id && q.id == q2.id)
assert(q.runId == q1.runId && q.runId == q2.runId)
assert(q1.name == null && q2.name == null)
spark.streams.resetTerminated()
val start = System.nanoTime
// Same setting as in test_query_manager_await_termination in test_streaming.py
val terminated = spark.streams.awaitAnyTermination(2600)
val end = System.nanoTime
assert((end - start) >= TimeUnit.MILLISECONDS.toNanos(2000))
assert(!terminated)
q.stop()
assert(!q1.isActive)
assert(spark.streams.get(q.id) == null)
}
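// Exercises both versions of the listener interface (without and with onQueryIdle).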
test("streaming query listener") {
testStreamingQueryListener(new EventCollectorV1, "_v1")
testStreamingQueryListener(new EventCollectorV2, "_v2")
}
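// Registers the listener, runs a query and checks that start/progress/terminated events are
// recorded into the listener's tables, then validates addListener/removeListener/listListeners
// bookkeeping.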
private def testStreamingQueryListener(
listener: StreamingQueryListener,
tablePostfix: String): Unit = {
assert(spark.streams.listListeners().length == 0)
spark.streams.addListener(listener)
val q = spark.readStream
.format("rate")
.load()
.writeStream
.format("console")
.start()
try {
q.processAllAvailable()
eventually(timeout(30.seconds)) {
assert(q.isActive)
assert(!spark.table(s"listener_start_events$tablePostfix").toDF().isEmpty)
assert(!spark.table(s"listener_progress_events$tablePostfix").toDF().isEmpty)
}
} finally {
q.stop()
eventually(timeout(30.seconds)) {
assert(!q.isActive)
assert(!spark.table(s"listener_terminated_events$tablePostfix").toDF().isEmpty)
}
spark.sql(s"DROP TABLE IF EXISTS listener_start_events$tablePostfix")
spark.sql(s"DROP TABLE IF EXISTS listener_progress_events$tablePostfix")
spark.sql(s"DROP TABLE IF EXISTS listener_terminated_events$tablePostfix")
}
// After adding the listener above, listListeners() should return exactly one listener.
val listeners = spark.streams.listListeners()
assert(listeners.length == 1)
// Add listener1 as another instance of EventCollector and validate
val listener1 = new EventCollectorV2
spark.streams.addListener(listener1)
assert(spark.streams.listListeners().length == 2)
spark.streams.removeListener(listener1)
assert(spark.streams.listListeners().length == 1)
// Add the same listener again and validate; this verifies that the listener cache is
// correctly stored and cleaned.
spark.streams.addListener(listener)
assert(spark.streams.listListeners().length == 2)
spark.streams.removeListener(listener)
assert(spark.streams.listListeners().length == 1)
// Remove the remaining listener, length should drop to 0.
spark.streams.removeListener(listener)
assert(spark.streams.listListeners().length == 0)
}
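// Listener callbacks should fire: exactly one start event, and at least one progress and one
// terminated event.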
test("listener events") {
val listener = new MyListener()
spark.streams.addListener(listener)
val q = spark.readStream
.format("rate")
.load()
.writeStream
.format("console")
.start()
try {
q.processAllAvailable()
eventually(timeout(30.seconds)) {
assert(q.isActive)
assert(listener.start.length == 1)
assert(listener.progress.nonEmpty)
}
} finally {
q.stop()
eventually(timeout(30.seconds)) {
assert(!q.isActive)
assert(listener.terminate.nonEmpty)
}
}
}
test("foreachBatch") {
// Starts a streaming query with a foreachBatch function, which writes the batchId and row count
// to a global temp view. The test verifies that the view is populated with data.
val viewName = "test_view"
val tableName = s"global_temp.$viewName"
withTable(tableName) {
val q = spark.readStream
.format("rate")
.option("rowsPerSecond", "10")
.option("numPartitions", "1")
.load()
.writeStream
.foreachBatch(new ForeachBatchFn(viewName))
.start()
eventually(timeout(30.seconds)) { // Wait for first progress.
assert(q.lastProgress != null, "Failed to make progress")
assert(q.lastProgress.numInputRows > 0)
}
eventually(timeout(30.seconds)) {
// There should be row(s) in the global temp view created by foreachBatch.
val rows = spark
.sql(s"select * from $tableName")
.collect()
.toSeq
assert(rows.size > 0)
logInfo(s"Rows in $tableName: $rows")
}
q.stop()
}
}
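/**
* Listener that appends each event's JSON to a per-event-type table suffixed with `tablePostfix`,
* so that tests can assert on the recorded events via spark.table().
*/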
abstract class EventCollector extends StreamingQueryListener {
protected def tablePostfix: String
protected def handleOnQueryStarted(event: QueryStartedEvent): Unit = {
val df = spark.createDataFrame(Seq((event.json, 0)))
df.write.mode("append").saveAsTable(s"listener_start_events$tablePostfix")
}
protected def handleOnQueryProgress(event: QueryProgressEvent): Unit = {
val df = spark.createDataFrame(Seq((event.json, 0)))
df.write.mode("append").saveAsTable(s"listener_progress_events$tablePostfix")
}
protected def handleOnQueryTerminated(event: QueryTerminatedEvent): Unit = {
val df = spark.createDataFrame(Seq((event.json, 0)))
df.write.mode("append").saveAsTable(s"listener_terminated_events$tablePostfix")
}
}
/**
* V1: Initial interface of StreamingQueryListener, containing the methods `onQueryStarted`,
* `onQueryProgress` and `onQueryTerminated`. This is the interface as it existed prior to
* Spark 3.5.
*/
class EventCollectorV1 extends EventCollector {
override protected def tablePostfix: String = "_v1"
override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
handleOnQueryTerminated(event)
}
/**
* V2: The interface after the method `onQueryIdle` was added in Spark 3.5.
*/
class EventCollectorV2 extends EventCollector {
override protected def tablePostfix: String = "_v2"
override def onQueryStarted(event: QueryStartedEvent): Unit = handleOnQueryStarted(event)
override def onQueryProgress(event: QueryProgressEvent): Unit = handleOnQueryProgress(event)
override def onQueryIdle(event: QueryIdleEvent): Unit = {}
override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
handleOnQueryTerminated(event)
}
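/**
* Listener that accumulates event JSON in memory for direct assertions in the "listener events"
* test.
*/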
class MyListener extends StreamingQueryListener {
var start: Seq[String] = Seq.empty
var progress: Seq[String] = Seq.empty
var terminate: Seq[String] = Seq.empty
override def onQueryStarted(event: QueryStartedEvent): Unit = {
start = start :+ event.json
}
override def onQueryProgress(event: QueryProgressEvent): Unit = {
progress = progress :+ event.json
}
override def onQueryIdle(event: QueryIdleEvent): Unit = {
// Do nothing
}
override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {
terminate = terminate :+ event.json
}
}
}
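/**
* ForeachWriter used by the foreach tests. It writes each row's string representation under a
* temporary directory created in open() and deletes it in close().
*/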
class TestForeachWriter[T] extends ForeachWriter[T] {
var fileWriter: FileWriter = _
var path: File = _
def open(partitionId: Long, version: Long): Boolean = {
path = SparkFileUtils.createTempDir()
// Open a file inside the temp directory; a FileWriter cannot be opened on the directory itself.
fileWriter = new FileWriter(new File(path, s"$partitionId-$version"), true)
true
}
def process(row: T): Unit = {
fileWriter.write(row.toString)
fileWriter.write("\n")
}
def close(errorOrNull: Throwable): Unit = {
fileWriter.close()
SparkFileUtils.deleteRecursively(path)
}
}
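// Simple wrapper whose toString returns the wrapped value, used by the typed foreach test.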
case class TestClass(value: Int) {
override def toString: String = value.toString
}
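/**
* foreachBatch function that records the (batchId, row count) of each micro-batch in a global
* temp view, which the "foreachBatch" test then queries.
*/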
class ForeachBatchFn(val viewName: String)
extends VoidFunction2[DataFrame, java.lang.Long]
with Serializable {
override def call(df: DataFrame, batchId: java.lang.Long): Unit = {
val count = df.count()
df.sparkSession
.createDataFrame(Seq((batchId.toLong, count)))
.createOrReplaceGlobalTempView(viewName)
}
}