#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import tempfile
import time
import unittest
from pyspark.sql.datasource import (
    DataSource,
    DataSourceStreamReader,
    InputPartition,
    DataSourceStreamWriter,
    DataSourceStreamArrowWriter,
    SimpleDataSourceStreamReader,
    WriterCommitMessage,
)
from pyspark.sql.streaming import StreamingQueryException
from pyspark.sql.types import Row
from pyspark.testing.sqlutils import (
    have_pyarrow,
    pyarrow_requirement_message,
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase


@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class BasePythonStreamingDataSourceTestsMixin:
    def test_basic_streaming_data_source_class(self):
        class MyDataSource(DataSource):
            ...

        options = dict(a=1, b=2)
        ds = MyDataSource(options=options)
        self.assertEqual(ds.options, options)
        self.assertEqual(ds.name(), "MyDataSource")
        with self.assertRaises(NotImplementedError):
            ds.schema()
        with self.assertRaises(NotImplementedError):
            ds.streamReader(None)
        with self.assertRaises(NotImplementedError):
            ds.streamWriter(None, None)

    def test_basic_data_source_stream_reader_class(self):
        class MyDataSourceStreamReader(DataSourceStreamReader):
            def read(self, partition):
                yield (1, "abc")

        stream_reader = MyDataSourceStreamReader()
        self.assertEqual(list(stream_reader.read(None)), [(1, "abc")])
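
    # Helper shared by test_stream_reader and test_stream_writer below: builds a
    # DataSource whose stream reader emits two new rows per microbatch and whose
    # stream writer records a JSON status file per committed batch (and a .txt
    # marker when a batch is aborted) under the configured output path.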
    def _get_test_data_source(self):
        class RangePartition(InputPartition):
            def __init__(self, start, end):
                self.start = start
                self.end = end

        class TestStreamReader(DataSourceStreamReader):
            current = 0

            def initialOffset(self):
                return {"offset": 0}

            def latestOffset(self):
                self.current += 2
                return {"offset": self.current}

            def partitions(self, start, end):
                return [RangePartition(start["offset"], end["offset"])]

            def commit(self, end):
                pass

            def read(self, partition):
                start, end = partition.start, partition.end
                for i in range(start, end):
                    yield (i,)

        import json
        import os
        from dataclasses import dataclass

        @dataclass
        class SimpleCommitMessage(WriterCommitMessage):
            partition_id: int
            count: int

        class TestStreamWriter(DataSourceStreamWriter):
            def __init__(self, options):
                self.options = options
                self.path = self.options.get("path")
                assert self.path is not None

            def write(self, iterator):
                from pyspark import TaskContext

                context = TaskContext.get()
                partition_id = context.partitionId()
                cnt = 0
                for row in iterator:
                    if row.id > 50:
                        raise Exception("invalid value")
                    cnt += 1
                return SimpleCommitMessage(partition_id=partition_id, count=cnt)

            def commit(self, messages, batchId) -> None:
                status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
                with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
                    file.write(json.dumps(status) + "\n")

            def abort(self, messages, batchId) -> None:
                with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
                    file.write(f"failed in batch {batchId}")

        class TestDataSource(DataSource):
            def schema(self):
                return "id INT"

            def streamReader(self, schema):
                return TestStreamReader()

            def streamWriter(self, schema, overwrite):
                return TestStreamWriter(self.options)

        return TestDataSource
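
    # TestStreamReader advances its offset by 2 on every latestOffset() call, so
    # microbatch N is expected to contain exactly the rows [2 * N, 2 * N + 1].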
    def test_stream_reader(self):
        self.spark.dataSource.register(self._get_test_data_source())
        df = self.spark.readStream.format("TestDataSource").load()

        def check_batch(df, batch_id):
            assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])

        q = df.writeStream.foreachBatch(check_batch).start()
        while len(q.recentProgress) < 10:
            time.sleep(0.2)
        q.stop()
        q.awaitTermination()
        self.assertIsNone(q.exception(), "No exception should be propagated.")
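
    # A stream reader's read() may also yield pyarrow.RecordBatch objects instead
    # of tuples; this test routes such batches through a JSON sink and checks
    # that the written output matches the Arrow data.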
    def test_stream_reader_pyarrow(self):
        import pyarrow as pa

        class TestStreamReader(DataSourceStreamReader):
            def initialOffset(self):
                return {"offset": 0}

            def latestOffset(self):
                return {"offset": 2}

            def partitions(self, start, end):
                # hardcoded number of partitions
                num_part = 1
                return [InputPartition(i) for i in range(num_part)]

            def read(self, partition):
                keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
                values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
                schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
                record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
                yield record_batch

        class TestDataSourcePyarrow(DataSource):
            @classmethod
            def name(cls):
                return "testdatasourcepyarrow"

            def schema(self):
                return "key int, value string"

            def streamReader(self, schema):
                return TestStreamReader()

        self.spark.dataSource.register(TestDataSourcePyarrow)
        df = self.spark.readStream.format("testdatasourcepyarrow").load()

        output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output")
        checkpoint_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_checkpoint")
        q = (
            df.writeStream.format("json")
            .option("checkpointLocation", checkpoint_dir.name)
            .start(output_dir.name)
        )
        while not q.recentProgress:
            time.sleep(0.2)
        q.stop()
        q.awaitTermination()

        expected_data = [
            Row(key=1, value="one"),
            Row(key=2, value="two"),
            Row(key=3, value="three"),
            Row(key=4, value="four"),
            Row(key=5, value="five"),
        ]
        df = self.spark.read.json(output_dir.name)
        assertDataFrameEqual(df, expected_data)
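
    # SimpleDataSourceStreamReader variant: read(start) returns an (iterator,
    # next_offset) pair and readBetweenOffsets(start, end) replays a committed
    # range, so no explicit partition planning is needed.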
    def test_simple_stream_reader(self):
        class SimpleStreamReader(SimpleDataSourceStreamReader):
            def initialOffset(self):
                return {"offset": 0}

            def read(self, start: dict):
                start_idx = start["offset"]
                it = iter([(i,) for i in range(start_idx, start_idx + 2)])
                return (it, {"offset": start_idx + 2})

            def commit(self, end):
                pass

            def readBetweenOffsets(self, start: dict, end: dict):
                start_idx = start["offset"]
                end_idx = end["offset"]
                return iter([(i,) for i in range(start_idx, end_idx)])

        class SimpleDataSource(DataSource):
            def schema(self):
                return "id INT"

            def simpleStreamReader(self, schema):
                return SimpleStreamReader()

        self.spark.dataSource.register(SimpleDataSource)
        df = self.spark.readStream.format("SimpleDataSource").load()

        def check_batch(df, batch_id):
            assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])

        q = df.writeStream.foreachBatch(check_batch).start()
        while len(q.recentProgress) < 10:
            time.sleep(0.2)
        q.stop()
        q.awaitTermination()
        self.assertIsNone(q.exception(), "No exception should be propagated.")
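
    # End-to-end writer test: the first batch (ids 0-29) should commit and
    # produce 0.json, while the second batch (ids 50-79) should fail in write()
    # and trigger abort(), which records 1.txt.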
    def test_stream_writer(self):
        input_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_input")
        output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output")
        checkpoint_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_checkpoint")
        try:
            self.spark.range(0, 30).repartition(2).write.format("json").mode("append").save(
                input_dir.name
            )
            self.spark.dataSource.register(self._get_test_data_source())
            df = self.spark.readStream.schema("id int").json(input_dir.name)
            q = (
                df.writeStream.format("TestDataSource")
                .option("checkpointLocation", checkpoint_dir.name)
                .start(output_dir.name)
            )
            while not q.recentProgress:
                time.sleep(0.2)

            # Test stream writer write and commit.
            # The first microbatch contains 30 rows and 2 partitions.
            # The number of rows and partitions is written by StreamWriter.commit().
            assertDataFrameEqual(self.spark.read.json(output_dir.name), [Row(2, 30)])

            self.spark.range(50, 80).repartition(2).write.format("json").mode("append").save(
                input_dir.name
            )

            # Test StreamWriter write and abort.
            # When row id > 50, write tasks throw an exception and fail.
            # 1.txt is written by StreamWriter.abort() to record the failure.
            while q.exception() is None:
                time.sleep(0.2)
            assertDataFrameEqual(
                self.spark.read.text(os.path.join(output_dir.name, "1.txt")),
                [Row("failed in batch 1")],
            )
            q.awaitTermination()
        except StreamingQueryException as e:
            self.assertIn("invalid value", str(e))
        finally:
            input_dir.cleanup()
            output_dir.cleanup()
            checkpoint_dir.cleanup()
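
    # Unlike DataSourceStreamWriter, DataSourceStreamArrowWriter.write() receives
    # pyarrow.RecordBatch objects rather than Rows; the writer below persists each
    # batch as a JSON file and records per-batch commit metadata.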
    def test_stream_arrow_writer(self):
        """Test DataSourceStreamArrowWriter with Arrow RecordBatch format."""
        import tempfile
        import shutil
        import json
        import os
        import pyarrow as pa
        from dataclasses import dataclass

        @dataclass
        class ArrowCommitMessage(WriterCommitMessage):
            partition_id: int
            batch_count: int
            total_rows: int

        class TestStreamArrowWriter(DataSourceStreamArrowWriter):
            def __init__(self, options):
                self.options = options
                self.path = self.options.get("path")
                assert self.path is not None

            def write(self, iterator):
                from pyspark import TaskContext

                context = TaskContext.get()
                partition_id = context.partitionId()
                batch_count = 0
                total_rows = 0
                for batch in iterator:
                    assert isinstance(batch, pa.RecordBatch)
                    batch_count += 1
                    total_rows += batch.num_rows
                    # Convert to pandas and write each batch to a JSON file.
                    df = batch.to_pandas()
                    filename = f"partition_{partition_id}_batch_{batch_count}.json"
                    filepath = os.path.join(self.path, filename)
                    df.to_json(filepath, orient="records")
                commit_msg = ArrowCommitMessage(
                    partition_id=partition_id, batch_count=batch_count, total_rows=total_rows
                )
                return commit_msg

            def commit(self, messages, batchId):
                """Write commit metadata for a successful batch."""
                total_batches = sum(m.batch_count for m in messages if m)
                total_rows = sum(m.total_rows for m in messages if m)
                status = {
                    "batch_id": batchId,
                    "num_partitions": len([m for m in messages if m]),
                    "total_batches": total_batches,
                    "total_rows": total_rows,
                }
                with open(os.path.join(self.path, f"commit_{batchId}.json"), "w") as f:
                    json.dump(status, f)

            def abort(self, messages, batchId):
                """Handle batch failure."""
                with open(os.path.join(self.path, f"abort_{batchId}.txt"), "w") as f:
                    f.write(f"Batch {batchId} aborted")

        class TestDataSource(DataSource):
            @classmethod
            def name(cls):
                return "TestArrowStreamWriter"

            def schema(self):
                return "id INT, name STRING, value DOUBLE"

            def streamWriter(self, schema, overwrite):
                return TestStreamArrowWriter(self.options)

        # Create a temporary directory for the test output.
        temp_dir = tempfile.mkdtemp()
        try:
            # Register the data source.
            self.spark.dataSource.register(TestDataSource)

            # Create test data with the rate source.
            df = (
                self.spark.readStream.format("rate")
                .option("rowsPerSecond", 10)
                .option("numPartitions", 3)
                .load()
                .selectExpr("value as id", "concat('name_', value) as name", "value * 2.5 as value")
            )

            # Write using streaming with the Arrow writer.
            query = (
                df.writeStream.format("TestArrowStreamWriter")
                .option("path", temp_dir)
                .option("checkpointLocation", os.path.join(temp_dir, "checkpoint"))
                .trigger(processingTime="1 seconds")
                .start()
            )

            # Wait a bit for data to be processed, then stop.
            time.sleep(6)  # Allow a few batches to run
            query.stop()
            query.awaitTermination()

            # Verify the commit metadata and the JSON files written by the writer.
            commit_files = [f for f in os.listdir(temp_dir) if f.startswith("commit_")]
            self.assertTrue(len(commit_files) > 0, "No commit files were created")

            # Read the commit metadata and aggregate counts across all commits.
            total_committed_rows = 0
            total_committed_batches = 0
            for commit_file in commit_files:
                with open(os.path.join(temp_dir, commit_file), "r") as f:
                    commit_data = json.load(f)
                    total_committed_rows += commit_data.get("total_rows", 0)
                    total_committed_batches += commit_data.get("total_batches", 0)

            # Both committed data and per-partition JSON files should exist.
            json_files = [
                f
                for f in os.listdir(temp_dir)
                if f.startswith("partition_") and f.endswith(".json")
            ]
            has_committed_data = total_committed_rows > 0
            has_json_files = len(json_files) > 0
            self.assertTrue(
                has_committed_data, f"Expected committed data but got {total_committed_rows} rows"
            )
            self.assertTrue(
                has_json_files, f"Expected JSON files but found {len(json_files)} files"
            )

            # Verify the JSON files contain valid data.
            for json_file in json_files:
                with open(os.path.join(temp_dir, json_file), "r") as f:
                    data = json.load(f)
                    self.assertTrue(len(data) > 0, f"JSON file {json_file} is empty")
        finally:
            # Clean up.
            shutil.rmtree(temp_dir, ignore_errors=True)


class PythonStreamingDataSourceTests(BasePythonStreamingDataSourceTestsMixin, ReusedSQLTestCase):
    pass


if __name__ == "__main__":
    from pyspark.sql.tests.test_python_streaming_datasource import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)