#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import tempfile
import time
from pyspark.sql import Row
from pyspark.sql.functions import lit
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.errors import PySparkValueError
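# Tests for PySpark Structured Streaming: DataStreamReader/DataStreamWriter options,
# StreamingQuery lifecycle, progress and exceptions, the StreamingQueryManager, and
# streaming reads/writes against tables.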
class StreamingTestsMixin:
def test_streaming_query_functions_basic(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
query = (
df.writeStream.format("memory")
.queryName("test_streaming_query_functions_basic")
.start()
)
try:
self.assertEqual(query.name, "test_streaming_query_functions_basic")
self.assertTrue(isinstance(query.id, str))
self.assertTrue(isinstance(query.runId, str))
self.assertTrue(query.isActive)
self.assertEqual(query.exception(), None)
self.assertFalse(query.awaitTermination(1))
query.processAllAvailable()
lastProgress = query.lastProgress
recentProgress = query.recentProgress
self.assertEqual(lastProgress["name"], query.name)
self.assertEqual(lastProgress["id"], query.id)
self.assertTrue(any(p == lastProgress for p in recentProgress))
query.explain()
except Exception as e:
self.fail(
"Streaming query functions sanity check shouldn't throw any error. "
"Error message: " + str(e)
)
finally:
query.stop()
def test_streaming_progress(self):
"""
Fields of lastProgress / recentProgress should be accessible as attributes,
e.g. q.lastProgress.id
"""
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
query = df.writeStream.format("noop").start()
try:
query.processAllAvailable()
lastProgress = query.lastProgress
recentProgress = query.recentProgress
self.assertEqual(lastProgress["name"], query.name)
# Return str when accessed using dict get.
self.assertEqual(lastProgress["id"], query.id)
# SPARK-48567 Use attribute to access fields in q.lastProgress
self.assertEqual(lastProgress.name, query.name)
# Return uuid when accessed using attribute.
self.assertEqual(str(lastProgress.id), query.id)
self.assertTrue(any(p == lastProgress for p in recentProgress))
self.assertTrue(lastProgress.numInputRows > 0)
# Also access source / sink progress with attributes
self.assertTrue(len(lastProgress.sources) > 0)
self.assertTrue(lastProgress.sources[0].numInputRows > 0)
self.assertTrue(lastProgress["sources"][0]["numInputRows"] > 0)
self.assertTrue(lastProgress.sink.numOutputRows > 0)
self.assertTrue(lastProgress["sink"]["numOutputRows"] > 0)
# In Python, for historical reasons, changing field value
# in StreamingQueryProgress is allowed.
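# Mutating the progress object only changes this local copy; it has no effect on the
# running query.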
new_name = "myNewQuery"
lastProgress["name"] = new_name
self.assertEqual(lastProgress.name, new_name)
except Exception as e:
self.fail(
"Streaming query functions sanity check shouldn't throw any error. "
"Error message: " + str(e)
)
finally:
query.stop()
def test_streaming_query_name_edge_case(self):
# Query name should be None when not specified
q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
self.assertEqual(q1.name, None)
# Cannot set query name to be an empty string
error_thrown = False
try:
(
self.spark.readStream.format("rate")
.load()
.writeStream.format("noop")
.queryName("")
.start()
)
except PySparkValueError:
error_thrown = True
self.assertTrue(error_thrown)
def test_stream_trigger(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
# Should take at least one arg
try:
df.writeStream.trigger()
except ValueError:
pass
# Should not take multiple args
try:
df.writeStream.trigger(once=True, processingTime="5 seconds")
except ValueError:
pass
# Should not take multiple args
try:
df.writeStream.trigger(processingTime="5 seconds", continuous="1 second")
except ValueError:
pass
# Should take only keyword args
try:
df.writeStream.trigger("5 seconds")
self.fail("Should have thrown an exception")
except TypeError:
pass
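# For reference, trigger() takes exactly one of several mutually exclusive keyword
# arguments. A minimal sketch of the valid forms (not executed here):
#   df.writeStream.trigger(processingTime="5 seconds")
#   df.writeStream.trigger(once=True)
#   df.writeStream.trigger(availableNow=True)
#   df.writeStream.trigger(continuous="1 second")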
def test_stream_read_options(self):
schema = StructType([StructField("data", StringType(), False)])
df = (
self.spark.readStream.format("text")
.option("path", "python/test_support/sql/streaming")
.schema(schema)
.load()
)
self.assertTrue(df.isStreaming)
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
def test_stream_read_options_overwrite(self):
bad_schema = StructType([StructField("test", IntegerType(), False)])
schema = StructType([StructField("data", StringType(), False)])
# SPARK-32516 disables the overwrite behavior by default.
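# With the conf off (the default), specifying both the "path" option and a path argument
# to load() raises an error; enabling the legacy conf restores the old behavior in which
# the arguments passed to load() overwrite the earlier options, as exercised below.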
with self.sql_conf({"spark.sql.legacy.pathOptionBehavior.enabled": True}):
df = (
self.spark.readStream.format("csv")
.option("path", "python/test_support/sql/fake")
.schema(bad_schema)
.load(path="python/test_support/sql/streaming", schema=schema, format="text")
)
self.assertTrue(df.isStreaming)
self.assertEqual(df.schema.simpleString(), "struct<data:string>")
def test_stream_save_options(self):
df = (
self.spark.readStream.format("text")
.load("python/test_support/sql/streaming")
.withColumn("id", lit(1))
)
for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
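# Remove the temp dir itself so the sink and checkpoint locations below are created
# from scratch by the query.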
self.assertTrue(df.isStreaming)
out = os.path.join(tmpPath, "out")
chk = os.path.join(tmpPath, "chk")
q = (
df.writeStream.option("checkpointLocation", chk)
.queryName("this_query")
.format("parquet")
.partitionBy("id")
.outputMode("append")
.option("path", out)
.start()
)
try:
self.assertEqual(q.name, "this_query")
self.assertTrue(q.isActive)
q.processAllAvailable()
output_files = []
for _, _, files in os.walk(out):
output_files.extend([f for f in files if not f.startswith(".")])
self.assertTrue(len(output_files) > 0)
self.assertTrue(len(os.listdir(chk)) > 0)
finally:
q.stop()
shutil.rmtree(tmpPath)
def test_stream_save_options_overwrite(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
self.assertTrue(df.isStreaming)
out = os.path.join(tmpPath, "out")
chk = os.path.join(tmpPath, "chk")
fake1 = os.path.join(tmpPath, "fake1")
fake2 = os.path.join(tmpPath, "fake2")
# SPARK-32516 disables the overwrite behavior by default.
with self.sql_conf({"spark.sql.legacy.pathOptionBehavior.enabled": True}):
q = (
df.writeStream.option("checkpointLocation", fake1)
.format("memory")
.option("path", fake2)
.queryName("fake_query")
.outputMode("append")
.start(path=out, format="parquet", queryName="this_query", checkpointLocation=chk)
)
try:
self.assertEqual(q.name, "this_query")
self.assertTrue(q.isActive)
q.processAllAvailable()
output_files = []
for _, _, files in os.walk(out):
output_files.extend([f for f in files if not f.startswith(".")])
self.assertTrue(len(output_files) > 0)
self.assertTrue(len(os.listdir(chk)) > 0)
self.assertFalse(os.path.isdir(fake1)) # should not have been created
self.assertFalse(os.path.isdir(fake2)) # should not have been created
finally:
q.stop()
shutil.rmtree(tmpPath)
def test_stream_status_and_progress(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
self.assertTrue(df.isStreaming)
out = os.path.join(tmpPath, "out")
chk = os.path.join(tmpPath, "chk")
def func(x):
time.sleep(1)
return x
from pyspark.sql.functions import col, udf
sleep_udf = udf(func)
# Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there
# were no updates.
q = df.select(sleep_udf(col("value")).alias("value")).writeStream.start(
path=out, format="parquet", queryName="this_query", checkpointLocation=chk
)
try:
# "lastProgress" will return None in most cases. However, as it may be flaky when
# Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress"
# may throw error with a high chance and make this test flaky, so we should still be
# able to detect broken codes.
q.lastProgress
q.processAllAvailable()
lastProgress = q.lastProgress
recentProgress = q.recentProgress
status = q.status
self.assertEqual(lastProgress["name"], q.name)
self.assertEqual(lastProgress["id"], q.id)
self.assertTrue(any(p == lastProgress for p in recentProgress))
self.assertTrue(
"message" in status and "isDataAvailable" in status and "isTriggerActive" in status
)
finally:
q.stop()
shutil.rmtree(tmpPath)
def test_stream_await_termination(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
self.assertTrue(df.isStreaming)
out = os.path.join(tmpPath, "out")
chk = os.path.join(tmpPath, "chk")
q = df.writeStream.start(
path=out, format="parquet", queryName="this_query", checkpointLocation=chk
)
try:
self.assertTrue(q.isActive)
try:
q.awaitTermination("hello")
self.fail("Expected a value exception")
except ValueError:
pass
now = time.time()
# test should take at least 2 seconds
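# awaitTermination(timeout) takes the timeout in seconds and returns False if the
# query is still active when the timeout expires.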
res = q.awaitTermination(2.6)
duration = time.time() - now
self.assertTrue(duration >= 2)
self.assertFalse(res)
q.processAllAvailable()
q.stop()
# Sanity check when no parameter is set
q.awaitTermination()
self.assertFalse(q.isActive)
finally:
q.stop()
shutil.rmtree(tmpPath)
def test_stream_exception(self):
with self.sql_conf({"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled": True}):
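# Simplified UDF tracebacks trim PySpark-internal frames from the error, which keeps
# the ZeroDivisionError raised further below easy to locate in the captured exception
# chain.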
sdf = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
sq = sdf.writeStream.format("memory").queryName("query_explain").start()
try:
sq.processAllAvailable()
self.assertEqual(sq.exception(), None)
finally:
sq.stop()
from pyspark.sql.functions import col, udf
from pyspark.errors import StreamingQueryException
bad_udf = udf(lambda x: 1 / 0)
sq = (
sdf.select(bad_udf(col("value")))
.writeStream.format("memory")
.queryName("this_query")
.start()
)
try:
# Process some data to fail the query
sq.processAllAvailable()
self.fail("bad udf should fail the query")
except StreamingQueryException as e:
# This is expected
self._assert_exception_tree_contains_msg(e, "ZeroDivisionError")
finally:
exception = sq.exception()
sq.stop()
self.assertIsInstance(exception, StreamingQueryException)
self._assert_exception_tree_contains_msg(exception, "ZeroDivisionError")
def test_query_manager_no_recreation(self):
# SPARK-46873: There should not be a new StreamingQueryManager created every time
# spark.streams is called.
for i in range(5):
self.assertTrue(self.spark.streams == self.spark.streams)
def test_query_manager_get(self):
df = self.spark.readStream.format("rate").load()
for q in self.spark.streams.active:
q.stop()
q = df.writeStream.format("noop").start()
self.assertTrue(q.isActive)
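# StreamingQueryManager.get(id) returns the matching active query; once the query is
# stopped it returns None, as checked below.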
self.assertTrue(q.id == self.spark.streams.get(q.id).id)
q.stop()
self.assertIsNone(self.spark.streams.get(q.id))
def test_query_manager_await_termination(self):
df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
for q in self.spark.streams.active:
q.stop()
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
self.assertTrue(df.isStreaming)
out = os.path.join(tmpPath, "out")
chk = os.path.join(tmpPath, "chk")
q = df.writeStream.start(
path=out, format="parquet", queryName="this_query", checkpointLocation=chk
)
try:
self.assertTrue(q.isActive)
try:
self.spark.streams.awaitAnyTermination("hello")
self.fail("Expected a value exception")
except ValueError:
pass
now = time.time()
# test should take at least 2 seconds
res = self.spark.streams.awaitAnyTermination(2.6)
duration = time.time() - now
self.assertTrue(duration >= 2)
self.assertFalse(res)
finally:
q.processAllAvailable()
q.stop()
shutil.rmtree(tmpPath)
def test_streaming_read_from_table(self):
with self.table("input_table", "this_query"):
self.spark.sql("CREATE TABLE input_table (value string) USING parquet")
self.spark.sql("INSERT INTO input_table VALUES ('aaa'), ('bbb'), ('ccc')")
df = self.spark.readStream.table("input_table")
self.assertTrue(df.isStreaming)
q = df.writeStream.format("memory").queryName("this_query").start()
q.processAllAvailable()
q.stop()
result = self.spark.sql("SELECT * FROM this_query ORDER BY value").collect()
self.assertEqual(
set([Row(value="aaa"), Row(value="bbb"), Row(value="ccc")]), set(result)
)
def test_streaming_write_to_table(self):
with self.table("output_table"), tempfile.TemporaryDirectory(prefix="to_table") as tmpdir:
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
q = df.writeStream.toTable("output_table", format="parquet", checkpointLocation=tmpdir)
self.assertTrue(q.isActive)
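# Let the rate source run for a while so at least one micro-batch is committed to the
# table before the query is stopped.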
time.sleep(10)
q.stop()
result = self.spark.sql("SELECT value FROM output_table").collect()
self.assertTrue(len(result) > 0)
def test_streaming_write_to_table_cluster_by(self):
with self.table("output_table"), tempfile.TemporaryDirectory(prefix="to_table") as tmpdir:
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
q = df.writeStream.clusterBy("value").toTable(
"output_table", format="parquet", checkpointLocation=tmpdir
)
self.assertTrue(q.isActive)
time.sleep(10)
q.stop()
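# clusterBy("value") should be reflected in DESCRIBE output as a "# Clustering
# Information" section listing the clustering column.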
result = self.spark.sql("DESCRIBE output_table").collect()
self.assertEqual(
set(
[
Row(col_name="timestamp", data_type="timestamp", comment=None),
Row(col_name="value", data_type="bigint", comment=None),
Row(col_name="# Clustering Information", data_type="", comment=""),
Row(col_name="# col_name", data_type="data_type", comment="comment"),
Row(col_name="value", data_type="bigint", comment=None),
]
),
set(result),
)
result = self.spark.sql("SELECT value FROM output_table").collect()
self.assertTrue(len(result) > 0)
def test_streaming_with_temporary_view(self):
"""
This verifies that createOrReplaceTempView() works with a streaming DataFrame. A SQL
SELECT query on such a view yields another streaming DataFrame, and the resulting
streaming query works as expected.
"""
with self.table("input_table", "this_query"):
self.spark.sql("CREATE TABLE input_table (value string) USING parquet")
self.spark.sql("INSERT INTO input_table VALUES ('a'), ('b'), ('c')")
df = self.spark.readStream.table("input_table")
self.assertTrue(df.isStreaming)
# Create a temp view
df.createOrReplaceTempView("test_view")
# Create a select query
view_df = self.spark.sql("SELECT CONCAT('view_', value) as vv from test_view")
self.assertTrue(view_df.isStreaming)
q = view_df.writeStream.format("memory").queryName("this_query").start()
q.processAllAvailable()
q.stop()
result = self.spark.sql("SELECT * FROM this_query ORDER BY vv").collect()
self.assertEqual(
set([Row(value="view_a"), Row(value="view_b"), Row(value="view_c")]), set(result)
)
def test_streaming_drop_duplicate_within_watermark(self):
"""
This verifies that dropDuplicatesWithinWatermark() works with a streaming DataFrame.
"""
user_schema = StructType().add("time", TimestampType()).add("id", "integer")
df = (
self.spark.readStream.option("sep", ";")
.schema(user_schema)
.csv("python/test_support/sql/streaming/time")
)
q1 = (
df.withWatermark("time", "2 seconds")
.dropDuplicatesWithinWatermark(["id"])
.writeStream.outputMode("update")
.format("memory")
.queryName("test_streaming_drop_duplicates_within_wm")
.start()
)
self.assertTrue(q1.isActive)
q1.processAllAvailable()
q1.stop()
result = self.spark.sql("SELECT * FROM test_streaming_drop_duplicates_within_wm").collect()
self.assertTrue(len(result) >= 6 and len(result) <= 9)
class StreamingTests(StreamingTestsMixin, ReusedSQLTestCase):
def _assert_exception_tree_contains_msg(self, exception, msg):
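# Walk the exception and its chain of causes, via the captured exception's internal
# _desc / _cause fields, until the expected message is found.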
e = exception
contains = msg in e._desc
while e._cause is not None and not contains:
e = e._cause
contains = msg in e._desc
self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)
if __name__ == "__main__":
import unittest
from pyspark.sql.tests.streaming.test_streaming import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)