| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """ |
| PySpark tests for Kafka streaming integration using Docker test containers. |
| |
| These tests demonstrate how to use KafkaUtils to test Spark streaming with Kafka. |
| Tests require Docker to be running and the following Python packages: |
| - testcontainers[kafka] |
| - kafka-python |
| """ |
| |
| import os |
| import shutil |
| import tempfile |
| import unittest |
| import uuid |
| |
| from pyspark.sql.tests.streaming.kafka_utils import KafkaUtils |
| from pyspark.testing.sqlutils import ReusedSQLTestCase, search_jar, read_classpath |
| from pyspark.testing.utils import ( |
| have_kafka, |
| have_testcontainers, |
| kafka_requirement_message, |
| testcontainers_requirement_message, |
| ) |
| |
| |
class StreamingKafkaTestsMixin:
    """
    Base mixin for Kafka streaming tests that provides KafkaUtils setup/teardown
    and topic management.

    The mixin relies on cooperative ``super()`` calls, so it must be listed
    before the test-case base class that provides ``setUpClass``,
    ``tearDownClass``, ``setUp`` and ``tearDown``.

    Class attributes set by ``setUpClass``:

    - ``original_pyspark_submit_args``: value of ``PYSPARK_SUBMIT_ARGS`` before
      it was modified, or ``None`` when the variable was not set.
    - ``kafka_utils``: the ``KafkaUtils`` instance shared by all tests.

    Instance attributes set by ``setUp``:

    - ``source_topic`` / ``sink_topic``: per-test topic names.
    """

    @classmethod
    def setUpClass(cls):
        """
        Add the Kafka connector JARs to ``PYSPARK_SUBMIT_ARGS``, run the parent
        class setup, and start the Kafka container.

        Raises
        ------
        RuntimeError
            If the Kafka SQL connector JAR cannot be found.
        """
        # Setup Kafka JAR on classpath before SparkSession is created
        # This follows the same pattern as streamingutils.py for Kinesis
        kafka_sql_jar = search_jar(
            "connector/kafka-0-10-sql",
            "spark-sql-kafka-0-10_",
            "spark-sql-kafka-0-10_",
            return_first=True,
        )

        # Fail before anything has been started or modified, so there is
        # nothing to clean up on this path.
        if kafka_sql_jar is None:
            raise RuntimeError(
                "Kafka SQL connector JAR was not found. "
                "To run these tests, you need to build Spark with "
                "'build/mvn package' or 'build/sbt Test/package' "
                "before running this test."
            )

        # Read the full classpath including all dependencies
        # This works for both Maven builds (reads classpath.txt) and SBT builds (queries SBT)
        # Define the project name mapping for SBT builds
        kafka_project_name_map = {
            "connector/kafka-0-10-sql": "sql-kafka-0-10",
        }
        kafka_classpath = read_classpath("connector/kafka-0-10-sql", kafka_project_name_map)
        all_jars = f"{kafka_sql_jar},{kafka_classpath}"

        # Add Kafka JAR to PYSPARK_SUBMIT_ARGS before SparkSession is created
        cls.original_pyspark_submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS")

        if cls.original_pyspark_submit_args is None:
            pyspark_submit_args = "pyspark-shell"
        else:
            pyspark_submit_args = cls.original_pyspark_submit_args
        jars_args = "--jars %s" % all_jars

        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, pyspark_submit_args])

        try:
            # The parent setup runs only now, after the submit arguments have
            # been extended, so that whatever it starts sees the Kafka JARs.
            super().setUpClass()

            # Start Kafka container - this may take 10-30 seconds on first run
            cls.kafka_utils = KafkaUtils()
            cls.kafka_utils.setup()
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # the environment change has to be undone here.
            cls._restore_pyspark_submit_args()
            raise

    @classmethod
    def _restore_pyspark_submit_args(cls):
        """Put ``PYSPARK_SUBMIT_ARGS`` back to the state captured by ``setUpClass``."""
        if not hasattr(cls, "original_pyspark_submit_args"):
            # setUpClass never got far enough to touch the environment.
            return

        if cls.original_pyspark_submit_args is None:
            # The variable was unset originally; os.environ rejects None, so
            # remove the key instead of assigning to it.
            os.environ.pop("PYSPARK_SUBMIT_ARGS", None)
        else:
            os.environ["PYSPARK_SUBMIT_ARGS"] = cls.original_pyspark_submit_args

    @classmethod
    def tearDownClass(cls):
        """
        Restore ``PYSPARK_SUBMIT_ARGS``, stop the Kafka container, and run the
        parent class teardown (even if the Kafka cleanup fails).
        """
        try:
            cls._restore_pyspark_submit_args()

            # Stop Kafka container and clean up resources
            if hasattr(cls, "kafka_utils"):
                cls.kafka_utils.teardown()
        finally:
            super().tearDownClass()

    def setUp(self):
        """Create a fresh pair of source/sink topics for the test about to run."""
        super().setUp()
        # Create unique topics for each test to avoid interference
        self.source_topic = f"source-{uuid.uuid4().hex}"
        self.sink_topic = f"sink-{uuid.uuid4().hex}"
        self.kafka_utils.create_topics([self.source_topic, self.sink_topic])

    def tearDown(self):
        """Delete the per-test topics, then run the parent teardown."""
        try:
            # Clean up topics after each test
            self.kafka_utils.delete_topics([self.source_topic, self.sink_topic])
        finally:
            # Parent cleanup must run even when topic deletion fails.
            super().tearDown()
| |
| |
| def _is_docker_available(): |
| """Check if Docker daemon is running and accessible.""" |
| try: |
| import subprocess |
| |
| result = subprocess.run(["docker", "info"], capture_output=True, timeout=10) |
| return result.returncode == 0 |
| except (FileNotFoundError, subprocess.TimeoutExpired, OSError): |
| return False |
| |
| |
@unittest.skipUnless(have_kafka, kafka_requirement_message)
@unittest.skipUnless(have_testcontainers, testcontainers_requirement_message)
@unittest.skipUnless(_is_docker_available(), "Docker is not available")
class StreamingKafkaTests(StreamingKafkaTestsMixin, ReusedSQLTestCase):
    """
    Kafka-to-Kafka streaming tests for PySpark.

    The whole class is skipped when the Kafka client package, the
    testcontainers package, or Docker is unavailable.
    """

    def test_streaming_stateless(self):
        """
        Copy ten records from the source topic to the sink topic with a
        stateless query that uses a real-time trigger and reads from the
        earliest offset, then check that the sink eventually holds all of them.
        """
        kafka = self.kafka_utils

        # Seed the source topic before the query starts.
        seed_records = [(n, n) for n in range(10)]
        kafka.send_messages(self.source_topic, seed_records)

        # Source side: subscribe to the per-test source topic from the beginning.
        reader = self.spark.readStream.format("kafka")
        reader = reader.option("kafka.bootstrap.servers", kafka.broker)
        reader = reader.option("subscribe", self.source_topic)
        reader = reader.option("startingOffsets", "earliest")
        source_df = reader.load()

        # Scratch directory for the checkpoint; removed when the test finishes.
        scratch_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, scratch_dir, True)
        checkpoint_path = os.path.join(scratch_dir, "checkpoint")

        # Sink side: write every record to the per-test sink topic.
        writer = source_df.writeStream.format("kafka")
        writer = writer.option("kafka.bootstrap.servers", kafka.broker)
        writer = writer.option("topic", self.sink_topic)
        writer = writer.option("checkpointLocation", checkpoint_path)
        writer = writer.outputMode("update")
        writer = writer.trigger(realTime="30 seconds")
        query = writer.start()

        expected = sorted([(str(n), str(n)) for n in range(10)])

        def read_sink():
            # Snapshot of everything currently in the sink topic.
            return self.kafka_utils.get_all_records(self.spark, self.sink_topic)

        try:
            # Give the query time to come up, then poll the sink until it matches.
            kafka.wait_for_query_alive(query)
            kafka.assert_eventually(result_func=read_sink, expected=expected)
        finally:
            query.stop()
| |
| |
if __name__ == "__main__":
    # Script entry point: hand control to the test runner exported by pyspark.testing.
    from pyspark.testing import main as run_tests

    run_tests()