#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
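"""
Tests for DataStreamWriter.foreachBatch, which runs a user-supplied function
against each micro-batch of a streaming query.

A minimal usage sketch (the names here are illustrative only, not part of the
tests below):

    def handle_batch(batch_df, batch_id):
        batch_df.write.mode("append").saveAsTable("sink_table")

    query = df.writeStream.foreachBatch(handle_batch).start()
"""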
import time

from pyspark.sql.dataframe import DataFrame
from pyspark.testing.sqlutils import ReusedSQLTestCase


def my_test_function_1():
    return 1


class StreamingTestsForeachBatchMixin:
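    """Tests covering the behavior of DataStreamWriter.foreachBatch."""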
    def test_streaming_foreach_batch(self):
        q = None

        def collectBatch(batch_df, batch_id):
            batch_df.write.format("parquet").saveAsTable("test_table1")

        try:
            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
            q = df.writeStream.foreachBatch(collectBatch).start()
            q.processAllAvailable()
            collected = self.spark.sql("SELECT * FROM test_table1").collect()
            self.assertEqual(len(collected), 2)
        finally:
            if q:
                q.stop()
            self.spark.sql("DROP TABLE IF EXISTS test_table1")
    def test_streaming_foreach_batch_tempview(self):
        q = None

        def collectBatch(batch_df, batch_id):
            batch_df.createOrReplaceTempView("updates")
            # It must use the Spark session bound to the given DataFrame: micro-batch
            # execution clones the session, so it is no longer the same session that
            # was used to start the streaming query.
            assert len(batch_df.sparkSession.sql("SELECT * FROM updates").collect()) == 2
            # Write a table to verify on the repl/client side.
            batch_df.write.format("parquet").saveAsTable("test_table2")

        try:
            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
            q = df.writeStream.foreachBatch(collectBatch).start()
            q.processAllAvailable()
            collected = self.spark.sql("SELECT * FROM test_table2").collect()
            self.assertEqual(len(collected), 2)
        finally:
            if q:
                q.stop()
            self.spark.sql("DROP TABLE IF EXISTS test_table2")
    def test_streaming_foreach_batch_propagates_python_errors(self):
        from pyspark.errors import StreamingQueryException

        q = None

        def collectBatch(df, id):
            raise RuntimeError("this should fail the query")

        try:
            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
            q = df.writeStream.foreachBatch(collectBatch).start()
            q.processAllAvailable()
            self.fail("Expected a failure")
        except StreamingQueryException as e:
            self.assertIn("this should fail", str(e))
        finally:
            if q:
                q.stop()

    def test_streaming_foreach_batch_graceful_stop(self):
        # SPARK-39218: Make foreachBatch streaming query stop gracefully
        def func(batch_df, _):
            batch_df.sparkSession._jvm.java.lang.Thread.sleep(10000)

        q = self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
        time.sleep(3)  # 'rowsPerSecond' defaults to 1; wait 3 seconds for some input.
        q.stop()
        self.assertIsNone(q.exception(), "No exception should be propagated.")
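
    # The session attached to the batch DataFrame should support creating new
    # DataFrames and writing tables from inside the foreachBatch function.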
    def test_streaming_foreach_batch_spark_session(self):
        table_name = "testTable_foreach_batch"
        with self.table(table_name):

            def func(df: DataFrame, batch_id: int):
                if batch_id > 0:  # only process once
                    return
                spark = df.sparkSession
                df1 = spark.createDataFrame([("structured",), ("streaming",)])
                df1.union(df).write.mode("append").saveAsTable(table_name)

            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
            q = df.writeStream.foreachBatch(func).start()
            q.processAllAvailable()
            q.stop()

            actual = self.spark.read.table(table_name)
            df = (
                self.spark.read.format("text")
                .load(path="python/test_support/sql/streaming/")
                .union(self.spark.createDataFrame([("structured",), ("streaming",)]))
            )
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
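
    # Batch reads against driver-local paths should also work from inside the
    # foreachBatch function.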
    def test_streaming_foreach_batch_path_access(self):
        table_name = "testTable_foreach_batch_path"
        with self.table(table_name):

            def func(df: DataFrame, batch_id: int):
                if batch_id > 0:  # only process once
                    return
                spark = df.sparkSession
                df1 = spark.read.format("text").load("python/test_support/sql/streaming")
                df1.union(df).write.mode("append").saveAsTable(table_name)

            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
            q = df.writeStream.foreachBatch(func).start()
            q.processAllAvailable()
            q.stop()

            actual = self.spark.read.table(table_name)
            df = self.spark.read.format("text").load(path="python/test_support/sql/streaming/")
            df = df.union(df)
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))

    @staticmethod
    def my_test_function_2():
        return 2
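
    # Module-level, static-method, and locally defined functions should all be
    # serialized with the foreachBatch function and callable inside it.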
    def test_streaming_foreach_batch_function_calling(self):
        def my_test_function_3():
            return 3

        table_name = "testTable_foreach_batch_function"
        with self.table(table_name):

            def func(df: DataFrame, batch_id: int):
                if batch_id > 0:  # only process once
                    return
                spark = df.sparkSession
                df1 = spark.createDataFrame(
                    [
                        (my_test_function_1(),),
                        (StreamingTestsForeachBatchMixin.my_test_function_2(),),
                        (my_test_function_3(),),
                    ]
                )
                df1.write.mode("append").saveAsTable(table_name)

            df = self.spark.readStream.format("rate").load()
            q = df.writeStream.foreachBatch(func).start()
            q.processAllAvailable()
            q.stop()

            actual = self.spark.read.table(table_name)
            df = self.spark.createDataFrame(
                [
                    (my_test_function_1(),),
                    (StreamingTestsForeachBatchMixin.my_test_function_2(),),
                    (my_test_function_3(),),
                ]
            )
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
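
    # Modules imported by the test (here, time) are not pre-imported in the
    # foreach_batch worker, so the import must travel with the function.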
    def test_streaming_foreach_batch_import(self):
        import time  # not imported in foreach_batch_worker

        table_name = "testTable_foreach_batch_import"
        with self.table(table_name):

            def func(df: DataFrame, batch_id: int):
                if batch_id > 0:  # only process once
                    return
                time.sleep(1)
                spark = df.sparkSession
                df1 = spark.read.format("text").load("python/test_support/sql/streaming")
                df1.write.mode("append").saveAsTable(table_name)

            df = self.spark.readStream.format("rate").load()
            q = df.writeStream.foreachBatch(func).start()
            q.processAllAvailable()
            q.stop()

            actual = self.spark.read.table(table_name)
            df = self.spark.read.format("text").load("python/test_support/sql/streaming")
            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
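
    # A Column object captured from the enclosing scope must be serialized
    # along with the foreachBatch function and remain usable inside it.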
    def test_streaming_foreach_batch_external_column(self):
        from pyspark.sql import functions as sf

        table_name = "testTable_foreach_batch_external_column"
        with self.table(table_name):
            # Define 'col' outside the `func` below, so it'd have to be serialized.
            col = sf.col("value")

            def func(df: DataFrame, batch_id: int):
                result_df = df.select(col.alias("result"))
                result_df.write.mode("append").saveAsTable(table_name)

            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
            q = df.writeStream.foreachBatch(func).start()
            q.processAllAvailable()
            q.stop()

            collected = self.spark.sql("SELECT * FROM " + table_name).collect()
            results = [row["result"] for row in collected]
            self.assertEqual(sorted(results), ["hello", "this"])


class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    from pyspark.testing import main

    main()