#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import tempfile

from pyspark.testing.sqlutils import ReusedSQLTestCase
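

# Exercises DataStreamWriter.foreach() against plain functions and writer objects.
# The inner ForeachWriterTester records each open/process/close callback by writing
# a small JSON file per event into a temp directory, then reads the files back with
# spark.read.json() to assert on what the writer observed.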
class StreamingTestsForeachMixin:
    class ForeachWriterTester:
        def __init__(self, spark):
            self.spark = spark

        def write_open_event(self, partitionId, epochId):
            self._write_event(self.open_events_dir, {"partition": partitionId, "epoch": epochId})

        def write_process_event(self, row):
            self._write_event(self.process_events_dir, {"value": "text"})

        def write_close_event(self, error):
            self._write_event(self.close_events_dir, {"error": str(error)})

        def write_input_file(self):
            self._write_event(self.input_dir, "text")

        def open_events(self):
            return self._read_events(self.open_events_dir, "partition INT, epoch INT")

        def process_events(self):
            return self._read_events(self.process_events_dir, "value STRING")

        def close_events(self):
            return self._read_events(self.close_events_dir, "error STRING")
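
        # Starts a text-file streaming query on input_dir with the given writer,
        # feeds it `num_files` input files, and blocks until each batch is fully
        # processed before stopping all active queries.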
        def run_streaming_query_on_writer(self, writer, num_files):
            self._reset()
            try:
                sdf = self.spark.readStream.format("text").load(self.input_dir)
                sq = sdf.writeStream.foreach(writer).start()
                for i in range(num_files):
                    self.write_input_file()
                    sq.processAllAvailable()
            finally:
                self.stop_all()
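
        # Runs a query with a writer that is expected to be rejected, and
        # optionally asserts that `msg` appears in the resulting error.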
        def assert_invalid_writer(self, writer, msg=None):
            self._reset()
            try:
                sdf = self.spark.readStream.format("text").load(self.input_dir)
                sq = sdf.writeStream.foreach(writer).start()
                self.write_input_file()
                sq.processAllAvailable()
                self.fail("invalid writer %s did not fail the query" % str(writer))  # not expected
            except Exception as e:
                if msg:
                    assert msg in str(e), "%s not in %s" % (msg, str(e))
            finally:
                self.stop_all()

        def stop_all(self):
            for q in self.spark.streams.active:
                q.stop()

        def _reset(self):
            self.input_dir = tempfile.mkdtemp()
            self.open_events_dir = tempfile.mkdtemp()
            self.process_events_dir = tempfile.mkdtemp()
            self.close_events_dir = tempfile.mkdtemp()

        def _read_events(self, dir, json):
            rows = self.spark.read.schema(json).json(dir).collect()
            dicts = [row.asDict() for row in rows]
            return dicts

        def _write_event(self, dir, event):
            import uuid

            with open(os.path.join(dir, str(uuid.uuid4())), "w") as f:
                f.write("%s\n" % str(event))
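
        # The writer (and anything it closes over, including this tester) is
        # pickled and shipped to executors, so only the event directory paths
        # are carried across; the SparkSession itself is not picklable.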
        def __getstate__(self):
            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)

        def __setstate__(self, state):
            self.open_events_dir, self.process_events_dir, self.close_events_dir = state

    # These foreach tests fail on macOS High Sierra due to the Objective-C fork-safety
    # rules described at
    # http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html
    # To work around this, set OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES in the environment
    # (e.g. export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES) before running the tests.
    def test_streaming_foreach_with_simple_function(self):
        tester = self.ForeachWriterTester(self.spark)

        def foreach_func(row):
            tester.write_process_event(row)

        tester.run_streaming_query_on_writer(foreach_func, 2)
        self.assertEqual(len(tester.process_events()), 2)

    def test_streaming_foreach_with_basic_open_process_close(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partitionId, epochId):
                tester.write_open_event(partitionId, epochId)
                return True

            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)

        open_events = tester.open_events()
        self.assertEqual(len(open_events), 2)
        self.assertSetEqual(set([e["epoch"] for e in open_events]), {0, 1})

        self.assertEqual(len(tester.process_events()), 2)

        close_events = tester.close_events()
        self.assertEqual(len(close_events), 2)
        self.assertSetEqual(set([e["error"] for e in close_events]), {"None"})
    def test_streaming_foreach_with_open_returning_false(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partition_id, epoch_id):
                tester.write_open_event(partition_id, epoch_id)
                return False

            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)

        self.assertEqual(len(tester.open_events()), 2)
        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
        close_events = tester.close_events()
        self.assertEqual(len(close_events), 2)
        self.assertSetEqual(set([e["error"] for e in close_events]), {"None"})

    def test_streaming_foreach_without_open_method(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 0)  # no open events
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 2)

    def test_streaming_foreach_without_close_method(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partition_id, epoch_id):
                tester.write_open_event(partition_id, epoch_id)
                return True

            def process(self, row):
                tester.write_process_event(row)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 2)  # open is still called
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 0)  # no close events

    def test_streaming_foreach_without_open_and_close_methods(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                tester.write_process_event(row)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 0)  # no open events
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 0)  # no close events
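
    # An exception raised in process() should fail the query with a
    # StreamingQueryException, and close() should still receive the error.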
    def test_streaming_foreach_with_process_throwing_error(self):
        from pyspark.errors import StreamingQueryException

        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                raise RuntimeError("test error")

            def close(self, error):
                tester.write_close_event(error)

        try:
            tester.run_streaming_query_on_writer(ForeachWriter(), 1)
            self.fail("bad writer did not fail the query")  # this is not expected
        except StreamingQueryException as e:
            err_msg = str(e)
            self.assertTrue("test error" in err_msg)
            self.assertTrue("FOREACH_USER_FUNCTION_ERROR" in err_msg)

        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
        close_events = tester.close_events()
        self.assertEqual(len(close_events), 1)
        # TODO: Verify whether original error message is inside the exception
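
    # foreach() accepts either a one-argument function taking a row, or an object
    # with a callable process(row) method (plus optional open/close); anything
    # else should be rejected.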
    def test_streaming_foreach_with_invalid_writers(self):
        tester = self.ForeachWriterTester(self.spark)

        def func_with_iterator_input(iter):
            for x in iter:
                print(x)

        tester.assert_invalid_writer(func_with_iterator_input)

        class WriterWithoutProcess:
            def open(self, partition):
                pass

        tester.assert_invalid_writer(WriterWithoutProcess(), "ATTRIBUTE_NOT_CALLABLE")

        class WriterWithNonCallableProcess:
            process = True

        tester.assert_invalid_writer(WriterWithNonCallableProcess(), "ATTRIBUTE_NOT_CALLABLE")

        class WriterWithNoParamProcess:
            def process(self):
                pass

        tester.assert_invalid_writer(WriterWithNoParamProcess())

        # Abstract class for tests below
        class WithProcess:
            def process(self, row):
                pass

        class WriterWithNonCallableOpen(WithProcess):
            open = True

        tester.assert_invalid_writer(WriterWithNonCallableOpen(), "ATTRIBUTE_NOT_CALLABLE")

        class WriterWithNoParamOpen(WithProcess):
            def open(self):
                pass

        tester.assert_invalid_writer(WriterWithNoParamOpen())

        class WriterWithNonCallableClose(WithProcess):
            close = True

        tester.assert_invalid_writer(WriterWithNonCallableClose(), "ATTRIBUTE_NOT_CALLABLE")


class StreamingTestsForeach(StreamingTestsForeachMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    import unittest

    from pyspark.sql.tests.streaming.test_streaming_foreach import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)