| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import os |
| import tempfile |
| import time |
| import unittest |
| import json |
| |
| from pyspark.sql.datasource import ( |
| DataSource, |
| DataSourceStreamReader, |
| InputPartition, |
| DataSourceStreamWriter, |
| DataSourceStreamArrowWriter, |
| SimpleDataSourceStreamReader, |
| WriterCommitMessage, |
| ) |
| from pyspark.sql.streaming.datasource import ( |
| ReadAllAvailable, |
| ReadLimit, |
| ReadMaxRows, |
| SupportsTriggerAvailableNow, |
| ) |
| from pyspark.sql.streaming import StreamingQueryException |
| from pyspark.sql.types import Row |
| from pyspark.testing.sqlutils import ( |
| have_pyarrow, |
| pyarrow_requirement_message, |
| ReusedSQLTestCase, |
| ) |
| from pyspark.testing import assertDataFrameEqual |
| from pyspark.testing.utils import eventually |
| |
| |
| def wait_for_condition(query, condition_fn, timeout_sec=30): |
| """ |
| Wait for a condition on a streaming query to be met, with timeout and error context. |
| |
| :param query: StreamingQuery object |
| :param condition_fn: Function that takes query and returns True when condition is met |
| :param timeout_sec: Timeout in seconds (default 30) |
| :raises TimeoutError: If condition is not met within timeout, with query context |
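| |
| Example:: |
| |
| wait_for_condition(query, lambda q: len(q.recentProgress) > 0) |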
| """ |
| start_time = time.time() |
| sleep_interval = 0.2 |
| |
| while not condition_fn(query): |
| elapsed = time.time() - start_time |
| if elapsed >= timeout_sec: |
| # Collect context for debugging |
| exception_info = query.exception() |
| recent_progresses = query.recentProgress |
| |
| error_msg = ( |
| f"Timeout after {timeout_sec} seconds waiting for condition. " |
| f"Query exception: {exception_info}. " |
| f"Recent progress count: {len(recent_progresses)}. " |
| ) |
| |
| if recent_progresses: |
| error_msg += f"Last progress: {recent_progresses[-1]}. " |
| error_msg += f"All recent progresses: {recent_progresses}" |
| else: |
| error_msg += "No progress recorded." |
| |
| raise TimeoutError(error_msg) |
| |
| time.sleep(sleep_interval) |
| |
| |
| @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) |
| class BasePythonStreamingDataSourceTestsMixin: |
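| """ |
| Tests for Python streaming data sources: DataSourceStreamReader and |
| SimpleDataSourceStreamReader, admission control via ReadLimit, |
| Trigger.AvailableNow support, and DataSourceStreamWriter / DataSourceStreamArrowWriter. |
| """ |
| |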
| def test_basic_streaming_data_source_class(self): |
| class MyDataSource(DataSource): |
| ... |
| |
| options = dict(a=1, b=2) |
| ds = MyDataSource(options=options) |
| self.assertEqual(ds.options, options) |
| self.assertEqual(ds.name(), "MyDataSource") |
| with self.assertRaises(NotImplementedError): |
| ds.schema() |
| with self.assertRaises(NotImplementedError): |
| ds.streamReader(None) |
| with self.assertRaises(NotImplementedError): |
| ds.streamWriter(None, None) |
| |
| def test_basic_data_source_stream_reader_class(self): |
| class MyDataSourceStreamReader(DataSourceStreamReader): |
| def read(self, partition): |
| yield (1, "abc") |
| |
| stream_reader = MyDataSourceStreamReader() |
| self.assertEqual(list(stream_reader.read(None)), [(1, "abc")]) |
| |
| def _get_test_data_source(self): |
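| """ |
| Returns a DataSource whose stream reader emits two consecutive ids per microbatch |
| and whose stream writer records a `{batchId}.json` status file on commit and a |
| `{batchId}.txt` file on abort under the configured path. |
| """ |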
| class RangePartition(InputPartition): |
| def __init__(self, start, end): |
| self.start = start |
| self.end = end |
| |
| class TestStreamReader(DataSourceStreamReader): |
| current = 0 |
| |
| def initialOffset(self): |
| return {"offset": 0} |
| |
| def latestOffset(self, start, limit): |
| return {"offset": start["offset"] + 2} |
| |
| def partitions(self, start, end): |
| return [RangePartition(start["offset"], end["offset"])] |
| |
| def commit(self, end): |
| pass |
| |
| def read(self, partition): |
| start, end = partition.start, partition.end |
| for i in range(start, end): |
| yield (i,) |
| |
| import json |
| import os |
| from dataclasses import dataclass |
| |
| @dataclass |
| class SimpleCommitMessage(WriterCommitMessage): |
| partition_id: int |
| count: int |
| |
| class TestStreamWriter(DataSourceStreamWriter): |
| def __init__(self, options): |
| self.options = options |
| self.path = self.options.get("path") |
| assert self.path is not None |
| |
| def write(self, iterator): |
| from pyspark import TaskContext |
| |
| context = TaskContext.get() |
| partition_id = context.partitionId() |
| cnt = 0 |
| for row in iterator: |
| if row.id > 50: |
| raise Exception("invalid value") |
| cnt += 1 |
| return SimpleCommitMessage(partition_id=partition_id, count=cnt) |
| |
| def commit(self, messages, batchId) -> None: |
| status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages)) |
| |
| with open(os.path.join(self.path, f"{batchId}.json"), "a") as file: |
| file.write(json.dumps(status) + "\n") |
| |
| def abort(self, messages, batchId) -> None: |
| with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file: |
| file.write(f"failed in batch {batchId}") |
| |
| class TestDataSource(DataSource): |
| def schema(self): |
| return "id INT" |
| |
| def streamReader(self, schema): |
| return TestStreamReader() |
| |
| def streamWriter(self, schema, overwrite): |
| return TestStreamWriter(self.options) |
| |
| return TestDataSource |
| |
| def _get_test_data_source_old_latest_offset_signature(self): |
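| """ |
| Same reader as in _get_test_data_source, but latestOffset() uses the legacy |
| zero-argument signature instead of (start, limit). |
| """ |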
| class RangePartition(InputPartition): |
| def __init__(self, start, end): |
| self.start = start |
| self.end = end |
| |
| class TestStreamReader(DataSourceStreamReader): |
| current = 0 |
| |
| def initialOffset(self): |
| return {"offset": 0} |
| |
| def latestOffset(self): |
| self.current += 2 |
| return {"offset": self.current} |
| |
| def partitions(self, start, end): |
| return [RangePartition(start["offset"], end["offset"])] |
| |
| def commit(self, end): |
| pass |
| |
| def read(self, partition): |
| start, end = partition.start, partition.end |
| for i in range(start, end): |
| yield (i,) |
| |
| class TestDataSource(DataSource): |
| def schema(self): |
| return "id INT" |
| |
| def streamReader(self, schema): |
| return TestStreamReader() |
| |
| return TestDataSource |
| |
| def _get_test_data_source_for_admission_control(self): |
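| """ |
| Returns a DataSource whose reader reports ReadMaxRows(2) as its default read limit |
| and caps each microbatch accordingly, while honoring ReadAllAvailable; used to |
| exercise admission control under different triggers. |
| """ |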
| class TestDataStreamReader(DataSourceStreamReader): |
| def initialOffset(self): |
| return {"partition-1": 0} |
| |
| def getDefaultReadLimit(self): |
| return ReadMaxRows(2) |
| |
| def latestOffset(self, start: dict, limit: ReadLimit): |
| start_idx = start["partition-1"] |
| if isinstance(limit, ReadAllAvailable): |
| end_offset = start_idx + 10 |
| else: |
| assert isinstance( |
| limit, ReadMaxRows |
| ), "Expected ReadMaxRows read limit but got " + str(type(limit)) |
| end_offset = start_idx + limit.max_rows |
| return {"partition-1": end_offset} |
| |
| def reportLatestOffset(self): |
| return {"partition-1": 1000000} |
| |
| def partitions(self, start: dict, end: dict): |
| start_index = start["partition-1"] |
| end_index = end["partition-1"] |
| return [InputPartition(i) for i in range(start_index, end_index)] |
| |
| def read(self, partition): |
| yield (partition.value,) |
| |
| class TestDataSource(DataSource): |
| def schema(self) -> str: |
| return "id INT" |
| |
| def streamReader(self, schema): |
| return TestDataStreamReader() |
| |
| return TestDataSource |
| |
| def _get_test_data_source_for_trigger_available_now(self): |
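| """ |
| Like the admission-control source, but the reader also implements |
| SupportsTriggerAvailableNow: prepareForTriggerAvailableNow() fixes the end offset |
| at 10 so Trigger.AvailableNow drains exactly 10 rows and stops. |
| """ |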
| class TestDataStreamReader(DataSourceStreamReader, SupportsTriggerAvailableNow): |
| def initialOffset(self): |
| return {"partition-1": 0} |
| |
| def getDefaultReadLimit(self): |
| return ReadMaxRows(2) |
| |
| def latestOffset(self, start: dict, limit: ReadLimit): |
| start_idx = start["partition-1"] |
| if isinstance(limit, ReadAllAvailable): |
| end_offset = start_idx + 10 |
| else: |
| assert isinstance( |
| limit, ReadMaxRows |
| ), "Expected ReadMaxRows read limit but got " + str(type(limit)) |
| end_offset = min( |
| start_idx + limit.max_rows, self.desired_end_offset["partition-1"] |
| ) |
| return {"partition-1": end_offset} |
| |
| def reportLatestOffset(self): |
| return {"partition-1": 1000000} |
| |
| def prepareForTriggerAvailableNow(self) -> None: |
| self.desired_end_offset = {"partition-1": 10} |
| |
| def partitions(self, start: dict, end: dict): |
| start_index = start["partition-1"] |
| end_index = end["partition-1"] |
| return [InputPartition(i) for i in range(start_index, end_index)] |
| |
| def read(self, partition): |
| yield (partition.value,) |
| |
| class TestDataSource(DataSource): |
| def schema(self) -> str: |
| return "id INT" |
| |
| def streamReader(self, schema): |
| return TestDataStreamReader() |
| |
| return TestDataSource |
| |
| def _test_stream_reader(self, test_data_source): |
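| """ |
| Runs the given data source through foreachBatch and verifies that every microbatch |
| contains exactly the two expected consecutive ids. |
| """ |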
| self.spark.dataSource.register(test_data_source) |
| df = self.spark.readStream.format("TestDataSource").load() |
| |
| def check_batch(df, batch_id): |
| assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)]) |
| |
| q = df.writeStream.foreachBatch(check_batch).start() |
| wait_for_condition(q, lambda query: len(query.recentProgress) >= 10) |
| q.stop() |
| q.awaitTermination() |
| self.assertIsNone(q.exception(), "No exception should be propagated.") |
| |
| def test_stream_reader(self): |
| self._test_stream_reader(self._get_test_data_source()) |
| |
| def test_stream_reader_old_latest_offset_signature(self): |
| self._test_stream_reader(self._get_test_data_source_old_latest_offset_signature()) |
| |
| def test_stream_reader_pyarrow(self): |
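| """ |
| A stream reader's read() may yield pyarrow.RecordBatch objects instead of tuples; |
| verify the batches round-trip through a JSON sink. |
| """ |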
| import pyarrow as pa |
| |
| class TestStreamReader(DataSourceStreamReader): |
| def initialOffset(self): |
| return {"offset": 0} |
| |
| def latestOffset(self): |
| return {"offset": 2} |
| |
| def partitions(self, start, end): |
| # hardcoded number of partitions |
| num_part = 1 |
| return [InputPartition(i) for i in range(num_part)] |
| |
| def read(self, partition): |
| keys = pa.array([1, 2, 3, 4, 5], type=pa.int32()) |
| values = pa.array(["one", "two", "three", "four", "five"], type=pa.string()) |
| schema = pa.schema([("key", pa.int32()), ("value", pa.string())]) |
| record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema) |
| yield record_batch |
| |
| class TestDataSourcePyarrow(DataSource): |
| @classmethod |
| def name(cls): |
| return "testdatasourcepyarrow" |
| |
| def schema(self): |
| return "key int, value string" |
| |
| def streamReader(self, schema): |
| return TestStreamReader() |
| |
| self.spark.dataSource.register(TestDataSourcePyarrow) |
| df = self.spark.readStream.format("testdatasourcepyarrow").load() |
| |
| output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output") |
| checkpoint_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_checkpoint") |
| |
| q = ( |
| df.writeStream.format("json") |
| .option("checkpointLocation", checkpoint_dir.name) |
| .start(output_dir.name) |
| ) |
| wait_for_condition(q, lambda query: len(query.recentProgress) > 0) |
| q.stop() |
| q.awaitTermination() |
| |
| expected_data = [ |
| Row(key=1, value="one"), |
| Row(key=2, value="two"), |
| Row(key=3, value="three"), |
| Row(key=4, value="four"), |
| Row(key=5, value="five"), |
| ] |
| df = self.spark.read.json(output_dir.name) |
| |
| assertDataFrameEqual(df, expected_data) |
| |
| def test_stream_reader_admission_control_trigger_once(self): |
| self.spark.dataSource.register(self._get_test_data_source_for_admission_control()) |
| df = self.spark.readStream.format("TestDataSource").load() |
| |
| def check_batch(df, batch_id): |
| assertDataFrameEqual(df, [Row(x) for x in range(10)]) |
| |
| q = df.writeStream.trigger(once=True).foreachBatch(check_batch).start() |
| q.awaitTermination() |
| self.assertIsNone(q.exception(), "No exception should be propagated.") |
| self.assertEqual(len(q.recentProgress), 1) |
| self.assertEqual(q.lastProgress.numInputRows, 10) |
| self.assertEqual(q.lastProgress.sources[0].numInputRows, 10) |
| self.assertEqual( |
| json.loads(q.lastProgress.sources[0].latestOffset), {"partition-1": 1000000} |
| ) |
| |
| def test_stream_reader_admission_control_processing_time_trigger(self): |
| self.spark.dataSource.register(self._get_test_data_source_for_admission_control()) |
| df = self.spark.readStream.format("TestDataSource").load() |
| |
| def check_batch(df, batch_id): |
| assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)]) |
| |
| q = df.writeStream.foreachBatch(check_batch).start() |
| wait_for_condition(q, lambda query: len(query.recentProgress) >= 10) |
| q.stop() |
| q.awaitTermination() |
| self.assertIsNone(q.exception(), "No exception should be propagated.") |
| |
| def test_stream_reader_trigger_available_now(self): |
| self.spark.dataSource.register(self._get_test_data_source_for_trigger_available_now()) |
| df = self.spark.readStream.format("TestDataSource").load() |
| |
| def check_batch(df, batch_id): |
| assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)]) |
| |
| q = df.writeStream.foreachBatch(check_batch).trigger(availableNow=True).start() |
| q.awaitTermination(timeout=30) |
| self.assertIsNone(q.exception(), "No exception should be propagated.") |
| # 10 available rows read 2 at a time => 5 microbatches of 2 rows each |
| self.assertEqual(len(q.recentProgress), 5) |
| for progress in q.recentProgress: |
| self.assertEqual(progress.numInputRows, 2) |
| self.assertEqual(q.lastProgress.sources[0].numInputRows, 2) |
| self.assertEqual( |
| json.loads(q.lastProgress.sources[0].latestOffset), {"partition-1": 1000000} |
| ) |
| |
| def test_simple_stream_reader(self): |
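| """ |
| SimpleDataSourceStreamReader plans data on the driver: read() returns an |
| (iterator, next_offset) pair and readBetweenOffsets() replays a committed range. |
| """ |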
| class SimpleStreamReader(SimpleDataSourceStreamReader): |
| def initialOffset(self): |
| return {"offset": 0} |
| |
| def read(self, start: dict): |
| start_idx = start["offset"] |
| it = iter([(i,) for i in range(start_idx, start_idx + 2)]) |
| return (it, {"offset": start_idx + 2}) |
| |
| def commit(self, end): |
| pass |
| |
| def readBetweenOffsets(self, start: dict, end: dict): |
| start_idx = start["offset"] |
| end_idx = end["offset"] |
| return iter([(i,) for i in range(start_idx, end_idx)]) |
| |
| class SimpleDataSource(DataSource): |
| def schema(self): |
| return "id INT" |
| |
| def simpleStreamReader(self, schema): |
| return SimpleStreamReader() |
| |
| self.spark.dataSource.register(SimpleDataSource) |
| df = self.spark.readStream.format("SimpleDataSource").load() |
| |
| def check_batch(df, batch_id): |
| assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)]) |
| |
| q = df.writeStream.foreachBatch(check_batch).start() |
| wait_for_condition(q, lambda query: len(query.recentProgress) >= 10) |
| q.stop() |
| q.awaitTermination() |
| self.assertIsNone(q.exception(), "No exception should be propagated.") |
| |
| def test_simple_stream_reader_trigger_available_now(self): |
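| """ |
| A SimpleDataSourceStreamReader may also implement SupportsTriggerAvailableNow; |
| the end offset set in prepareForTriggerAvailableNow() caps how far read() advances. |
| """ |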
| class SimpleStreamReader(SimpleDataSourceStreamReader, SupportsTriggerAvailableNow): |
| def initialOffset(self): |
| return {"offset": 0} |
| |
| def read(self, start: dict): |
| start_idx = start["offset"] |
| end_offset = min(start_idx + 2, self.desired_end_offset["offset"]) |
| it = iter([(i,) for i in range(start_idx, end_offset)]) |
| return (it, {"offset": end_offset}) |
| |
| def commit(self, end): |
| pass |
| |
| def readBetweenOffsets(self, start: dict, end: dict): |
| start_idx = start["offset"] |
| end_idx = end["offset"] |
| return iter([(i,) for i in range(start_idx, end_idx)]) |
| |
| def prepareForTriggerAvailableNow(self) -> None: |
| self.desired_end_offset = {"offset": 10} |
| |
| class SimpleDataSource(DataSource): |
| def schema(self): |
| return "id INT" |
| |
| def simpleStreamReader(self, schema): |
| return SimpleStreamReader() |
| |
| self.spark.dataSource.register(SimpleDataSource) |
| df = self.spark.readStream.format("SimpleDataSource").load() |
| |
| def check_batch(df, batch_id): |
| # The last row id is 9 since the desired end offset is 10. |
| # A batch is not triggered when there is no data, so each batch has one or two rows. |
| if batch_id * 2 + 1 > 9: |
| assertDataFrameEqual(df, [Row(batch_id * 2)]) |
| else: |
| assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)]) |
| |
| q = df.writeStream.foreachBatch(check_batch).trigger(availableNow=True).start() |
| q.awaitTermination(timeout=30) |
| self.assertIsNone(q.exception(), "No exception should be propagated.") |
| |
| def test_stream_writer(self): |
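| """ |
| Exercises DataSourceStreamWriter end to end: the first microbatch commits and |
| writes a status file, while a later batch with row ids > 50 fails, triggering |
| abort() and stopping the query. |
| """ |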
| input_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_input") |
| output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output") |
| checkpoint_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_checkpoint") |
| |
| try: |
| self.spark.range(0, 30).repartition(2).write.format("json").mode("append").save( |
| input_dir.name |
| ) |
| self.spark.dataSource.register(self._get_test_data_source()) |
| df = self.spark.readStream.schema("id int").json(input_dir.name) |
| q = ( |
| df.writeStream.format("TestDataSource") |
| .option("checkpointLocation", checkpoint_dir.name) |
| .start(output_dir.name) |
| ) |
| wait_for_condition(q, lambda query: len(query.recentProgress) > 0) |
| |
| # Test stream writer write and commit. |
| # The first microbatch contains 30 rows across 2 partitions. |
| # The number of rows and partitions is written by TestStreamWriter.commit(). |
| assertDataFrameEqual(self.spark.read.json(output_dir.name), [Row(2, 30)]) |
| |
| self.spark.range(50, 80).repartition(2).write.format("json").mode("append").save( |
| input_dir.name |
| ) |
| |
| # Test stream writer write and abort. |
| # When row id > 50, the write tasks throw an exception and fail. |
| # 1.txt is written by TestStreamWriter.abort() to record the failure. |
| wait_for_condition(q, lambda query: query.exception() is not None) |
| assertDataFrameEqual( |
| self.spark.read.text(os.path.join(output_dir.name, "1.txt")), |
| [Row("failed in batch 1")], |
| ) |
| q.awaitTermination() |
| except StreamingQueryException as e: |
| self.assertIn("invalid value", str(e)) |
| finally: |
| input_dir.cleanup() |
| output_dir.cleanup() |
| checkpoint_dir.cleanup() |
| |
| def test_stream_arrow_writer(self): |
| """Test DataSourceStreamArrowWriter with Arrow RecordBatch format.""" |
| import tempfile |
| import shutil |
| import json |
| import os |
| import pyarrow as pa |
| from dataclasses import dataclass |
| |
| @dataclass |
| class ArrowCommitMessage(WriterCommitMessage): |
| partition_id: int |
| batch_count: int |
| total_rows: int |
| |
| class TestStreamArrowWriter(DataSourceStreamArrowWriter): |
| def __init__(self, options): |
| self.options = options |
| self.path = self.options.get("path") |
| assert self.path is not None |
| |
| def write(self, iterator): |
| from pyspark import TaskContext |
| |
| context = TaskContext.get() |
| partition_id = context.partitionId() |
| batch_count = 0 |
| total_rows = 0 |
| |
| for batch in iterator: |
| assert isinstance(batch, pa.RecordBatch) |
| batch_count += 1 |
| total_rows += batch.num_rows |
| |
| # Convert the RecordBatch to pandas and persist it under self.path as JSON |
| df = batch.to_pandas() |
| |
| filename = f"partition_{partition_id}_batch_{batch_count}.json" |
| filepath = os.path.join(self.path, filename) |
| df.to_json(filepath, orient="records") |
| |
| commit_msg = ArrowCommitMessage( |
| partition_id=partition_id, batch_count=batch_count, total_rows=total_rows |
| ) |
| return commit_msg |
| |
| def commit(self, messages, batchId): |
| """Write commit metadata for successful batch.""" |
| total_batches = sum(m.batch_count for m in messages if m) |
| total_rows = sum(m.total_rows for m in messages if m) |
| |
| status = { |
| "batch_id": batchId, |
| "num_partitions": len([m for m in messages if m]), |
| "total_batches": total_batches, |
| "total_rows": total_rows, |
| } |
| |
| with open(os.path.join(self.path, f"commit_{batchId}.json"), "w") as f: |
| json.dump(status, f) |
| |
| def abort(self, messages, batchId): |
| """Handle batch failure.""" |
| with open(os.path.join(self.path, f"abort_{batchId}.txt"), "w") as f: |
| f.write(f"Batch {batchId} aborted") |
| |
| class TestDataSource(DataSource): |
| @classmethod |
| def name(cls): |
| return "TestArrowStreamWriter" |
| |
| def schema(self): |
| return "id INT, name STRING, value DOUBLE" |
| |
| def streamWriter(self, schema, overwrite): |
| return TestStreamArrowWriter(self.options) |
| |
| # Create temporary directory for test |
| temp_dir = tempfile.mkdtemp() |
| try: |
| # Register the data source |
| self.spark.dataSource.register(TestDataSource) |
| |
| # Create test data |
| df = ( |
| self.spark.readStream.format("rate") |
| .option("rowsPerSecond", 10) |
| .option("numPartitions", 3) |
| .load() |
| .selectExpr("value as id", "concat('name_', value) as name", "value * 2.5 as value") |
| ) |
| |
| # Write using streaming with Arrow writer |
| query = ( |
| df.writeStream.format("TestArrowStreamWriter") |
| .option("path", temp_dir) |
| .option("checkpointLocation", os.path.join(temp_dir, "checkpoint")) |
| .trigger(processingTime="1 seconds") |
| .start() |
| ) |
| |
| @eventually( |
| timeout=20, |
| interval=2.0, |
| catch_assertions=True, |
| expected_exceptions=(json.JSONDecodeError,), |
| ) |
| def check(): |
| # Since we're writing actual JSON files, verify commit metadata and written files |
| commit_files = [f for f in os.listdir(temp_dir) if f.startswith("commit_")] |
| self.assertTrue(len(commit_files) > 0, "No commit files were created") |
| |
| # Read and verify commit metadata - check all commits for any with data |
| total_committed_rows = 0 |
| total_committed_batches = 0 |
| |
| for commit_file in commit_files: |
| with open(os.path.join(temp_dir, commit_file), "r") as f: |
| commit_data = json.load(f) |
| total_committed_rows += commit_data.get("total_rows", 0) |
| total_committed_batches += commit_data.get("total_batches", 0) |
| |
| self.assertTrue( |
| total_committed_rows > 0, |
| f"Expected committed data but got {total_committed_rows} rows", |
| ) |
| |
| check() |
| |
| query.stop() |
| query.awaitTermination() |
| |
| json_files = [ |
| f |
| for f in os.listdir(temp_dir) |
| if f.startswith("partition_") and f.endswith(".json") |
| ] |
| |
| self.assertTrue( |
| len(json_files) > 0, f"Expected JSON files but found {len(json_files)} files" |
| ) |
| |
| # Verify JSON files contain valid data |
| for json_file in json_files: |
| with open(os.path.join(temp_dir, json_file), "r") as f: |
| data = json.load(f) |
| self.assertTrue(len(data) > 0, f"JSON file {json_file} is empty") |
| |
| finally: |
| # Clean up |
| shutil.rmtree(temp_dir, ignore_errors=True) |
| |
| |
| class PythonStreamingDataSourceTests(BasePythonStreamingDataSourceTestsMixin, ReusedSQLTestCase): |
| pass |
| |
| |
| if __name__ == "__main__": |
| from pyspark.testing import main |
| |
| main() |