Kafka Testing with PySpark

This document explains how to write PySpark unit tests that interact with a Kafka cluster using Docker containers.

Overview

The kafka_utils.py module provides a KafkaUtils class that launches a single-broker Kafka cluster in Docker using the testcontainers-python library. This enables end-to-end integration testing of PySpark Kafka streaming applications without manual Kafka setup.

Prerequisites

1. Docker

Docker must be installed and running on your system. The Kafka test container requires Docker to launch the Kafka broker.
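
Because every test in this document depends on a reachable Docker daemon, it can help to detect that up front and skip cleanly instead of failing during container startup. The following is a minimal sketch, not part of KafkaUtils: the helper name docker_available is hypothetical, and it assumes the docker CLI is the way Docker is exposed on the machine.

```python
import shutil
import subprocess


def docker_available(timeout: float = 10.0) -> bool:
    """Return True if the docker CLI is on PATH and `docker info` succeeds.

    Hypothetical helper for illustration; KafkaUtils does not provide it.
    """
    if shutil.which("docker") is None:
        return False
    try:
        result = subprocess.run(
            ["docker", "info"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        # The CLI exists but could not be run or did not answer in time.
        return False
    return result.returncode == 0
```

A test class could then be decorated with @unittest.skipUnless(docker_available(), "Docker is not running") so that machines without Docker report skipped tests rather than errors.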

2. Python Dependencies

Install the required Python packages:

pip install "testcontainers[kafka]" kafka-python-ng

Or install all dev dependencies:

cd $SPARK_HOME
pip install -r dev/requirements.txt
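
To check whether the packages above are importable before running the tests, a small probe such as the following may be useful. It is a sketch under two assumptions: that testcontainers[kafka] installs the top-level module testcontainers and that kafka-python-ng installs the top-level module kafka. The helper name missing_modules is hypothetical.

```python
import importlib.util
import unittest

# Assumed top-level import names for the two pip packages.
REQUIRED_MODULES = ["testcontainers", "kafka"]


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)


@unittest.skipIf(missing, f"Missing Kafka test dependencies: {missing}")
class KafkaDependencyGuardedTest(unittest.TestCase):
    def test_dependencies_present(self):
        # Only runs when both modules were found.
        self.assertEqual(missing_modules(REQUIRED_MODULES), [])
```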

Quick Start

Here's a minimal example of writing a Kafka test:

import unittest
from pyspark.sql.tests.streaming.kafka_utils import KafkaUtils
from pyspark.testing.sqlutils import ReusedSQLTestCase

class MyKafkaTest(ReusedSQLTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.kafka_utils = KafkaUtils()
        cls.kafka_utils.setup()

    @classmethod
    def tearDownClass(cls):
        cls.kafka_utils.teardown()
        super().tearDownClass()

    def test_kafka_read_write(self):
        # Create a topic
        topic = "test-topic"
        self.kafka_utils.create_topics([topic])

        # Send test data
        messages = [("key1", "value1"), ("key2", "value2")]
        self.kafka_utils.send_messages(topic, messages)

        # Read with Spark
        df = (
            self.spark.read
            .format("kafka")
            .option("kafka.bootstrap.servers", self.kafka_utils.broker)
            .option("subscribe", topic)
            .option("startingOffsets", "earliest")
            .load()
        )

        # Verify data
        results = df.selectExpr(
            "CAST(key AS STRING) as key",
            "CAST(value AS STRING) as value"
        ).collect()

        self.assertEqual(len(results), 2)

KafkaUtils API Reference

Initialization and Lifecycle

__init__(kafka_version="7.4.0")

Create a new KafkaUtils instance.

  • kafka_version: Version of Confluent Kafka to use (default: “7.4.0” for stability)

setup()

Start the Kafka container and initialize clients. This must be called before using any other methods.

Raises:

  • ImportError: If required dependencies are not installed
  • RuntimeError: If Kafka container fails to start

Note: Container startup can take 10-30 seconds on first run while Docker pulls the image.

teardown()

Stop the Kafka container and clean up resources. Safe to call multiple times.
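
For tests that need a broker only inside a single block, the setup()/teardown() pair can be wrapped in a context manager so that teardown runs even when the body raises. This is a minimal sketch and not part of KafkaUtils; the name running is hypothetical, and the wrapper relies only on the two lifecycle methods described above.

```python
from contextlib import contextmanager


@contextmanager
def running(kafka_utils):
    """Start the given KafkaUtils-like object and always tear it down.

    Hypothetical wrapper: it only calls setup() and teardown().
    """
    kafka_utils.setup()
    try:
        yield kafka_utils
    finally:
        # Runs on normal exit and when the body raises.
        kafka_utils.teardown()
```

Usage would look like `with running(KafkaUtils()) as kafka_utils: ...`. If setup() itself raises, teardown() is not called in this sketch.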

Topic Management

create_topics(topic_names, num_partitions=1, replication_factor=1)

Create one or more Kafka topics.

Parameters:

  • topic_names (List[str]): List of topic names to create
  • num_partitions (int): Number of partitions per topic (default: 1)
  • replication_factor (int): Replication factor (default: 1, max: 1 for single broker)

Example:

# Create single partition topics
kafka_utils.create_topics(["topic1", "topic2"])

# Create multi-partition topic
kafka_utils.create_topics(["multi-partition-topic"], num_partitions=3)
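
When one broker is shared by all tests in a class, as in the Quick Start, a fixed topic name such as "test-topic" can carry records over from one test method to the next. One way to avoid that is to generate a fresh name per test. The helper below is a hypothetical sketch; it uses only characters that Kafka allows in topic names.

```python
import uuid


def unique_topic(prefix: str = "test") -> str:
    """Return a topic name such as 'test-3f9a1c2e' that is unlikely to collide."""
    return f"{prefix}-{uuid.uuid4().hex[:8]}"
```

A test could then call `topic = unique_topic("orders")` followed by `kafka_utils.create_topics([topic])`.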

delete_topics(topic_names)

Delete one or more Kafka topics.

Parameters:

  • topic_names (List[str]): List of topic names to delete

Producing Data

send_messages(topic, messages)

Send messages to a Kafka topic.

Parameters:

  • topic (str): Topic name
  • messages (List[tuple]): List of (key, value) tuples

Example:

kafka_utils.send_messages("test-topic", [
    ("user1", "login"),
    ("user2", "logout"),
    ("user1", "purchase"),
])
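
Test data often starts out as structured records rather than ready-made strings. The sketch below converts a list of dictionaries into the (key, value) tuples that send_messages expects, encoding each record as JSON. It assumes keys and values are passed as strings, as in the example above; the helper name to_messages is hypothetical.

```python
import json


def to_messages(records, key_field):
    """Turn dict records into (key, value) string tuples.

    The key is taken from `key_field`; the value is the whole record as
    JSON with sorted keys, so the output is deterministic.
    """
    return [
        (str(record[key_field]), json.dumps(record, sort_keys=True))
        for record in records
    ]
```

For example, `kafka_utils.send_messages("test-topic", to_messages(events, "user"))` would send one JSON-encoded message per event.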

Reading Data

get_all_records(spark, topic, key_deserializer="STRING", value_deserializer="STRING")

Read all records from a Kafka topic using Spark batch read.

Parameters:

  • spark: SparkSession instance
  • topic (str): Topic name
  • key_deserializer (str): How to deserialize keys (default: “STRING”)
  • value_deserializer (str): How to deserialize values (default: “STRING”)

Returns: Sorted list of (key, value) tuples.

Example:

records = kafka_utils.get_all_records(self.spark, "test-topic")
assert records == [("key1", "value1"), ("key2", "value2")]

Testing Utilities

assert_eventually(result_func, expected, timeout=60, interval=1.0)

Assert that a condition becomes true within a timeout. Useful for testing streaming queries with eventually consistent results.

Parameters:

  • result_func (Callable): Function that returns the current result
  • expected: Expected result value
  • timeout (int): Maximum time to wait in seconds (default: 60)
  • interval (float): Time between checks in seconds (default: 1.0)

Raises: AssertionError if condition doesn't become true within timeout

Example:

kafka_utils.assert_eventually(
    lambda: kafka_utils.get_all_records(self.spark, "sink-topic"),
    [("key1", "processed-value1")],
    timeout=30
)
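
The polling behavior described above can be sketched as a standalone function. This is an illustration of the general technique, not the actual KafkaUtils implementation; the name poll_until_equal is hypothetical and the parameters simply mirror those of assert_eventually.

```python
import time


def poll_until_equal(result_func, expected, timeout=60, interval=1.0):
    """Call result_func repeatedly until it returns `expected`.

    Raises AssertionError if the deadline passes first. The function is
    always evaluated at least once, even with a zero timeout.
    """
    deadline = time.monotonic() + timeout
    while True:
        last = result_func()
        if last == expected:
            return
        if time.monotonic() >= deadline:
            raise AssertionError(
                f"Expected {expected!r} within {timeout}s, last result was {last!r}"
            )
        time.sleep(interval)
```

Including the last observed result in the error message makes a timed-out streaming test much easier to diagnose than a bare timeout.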

wait_for_query_alive(query, timeout=60, interval=1.0)

Wait for a streaming query to become active and ready to process data.

Parameters:

  • query: StreamingQuery instance
  • timeout (int): Maximum time to wait in seconds (default: 60)
  • interval (float): Time between checks in seconds (default: 1.0)

Raises: AssertionError if query doesn't become active within timeout

Example:

query = df.writeStream.format("memory").start()
kafka_utils.wait_for_query_alive(query, timeout=30)
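
As a rough illustration of what such a wait involves, the sketch below polls a StreamingQuery through its isActive property and exception() method, both of which are part of the PySpark StreamingQuery API. It is an assumption about one reasonable approach, not the actual KafkaUtils code, and the name wait_until_active is hypothetical.

```python
import time


def wait_until_active(query, timeout=60, interval=1.0):
    """Block until `query.isActive` is true, failing fast on query errors."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        error = query.exception()
        if error is not None:
            # The query terminated with an error; waiting longer cannot help.
            raise AssertionError(f"Streaming query failed: {error}")
        if query.isActive:
            return
        time.sleep(interval)
    raise AssertionError(f"Streaming query not active within {timeout}s")
```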

Properties

broker

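Get the Kafka bootstrap server address in host:port form (e.g., “localhost:9093”). Because testcontainers assigns the host port, read the address from this property rather than hard-coding it.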

producer

Get the underlying KafkaProducer instance for advanced usage.

admin_client

Get the underlying KafkaAdminClient instance for advanced usage.
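
The two client properties cover cases the helper methods do not, such as sending record headers or checking whether a topic exists. The sketch below uses KafkaProducer.send (with its headers argument), KafkaProducer.flush, and KafkaAdminClient.list_topics from the kafka-python API. The function names are hypothetical, and passing keys and values as bytes assumes the producer was created without custom serializers.

```python
def send_with_headers(producer, topic, key, value, headers):
    """Send one record with headers and wait until it has been flushed.

    `headers` is a list of (str, bytes) tuples, as kafka-python expects.
    """
    future = producer.send(topic, key=key, value=value, headers=headers)
    producer.flush()
    return future


def topic_exists(admin_client, topic):
    """Return True if the broker currently reports a topic with this name."""
    return topic in admin_client.list_topics()
```

For example, `send_with_headers(kafka_utils.producer, "test-topic", b"k", b"v", [("trace-id", b"123")])`.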

Common Testing Patterns

Pattern 1: Batch Read/Write

Test basic Kafka read and write with Spark batch processing:

def test_kafka_batch(self):
    topic = "test-topic"
    self.kafka_utils.create_topics([topic])

    # Write with Spark DataFrame
    df = self.spark.createDataFrame([("key1", "value1")], ["key", "value"])
    (
        df.selectExpr("CAST(key AS BINARY)", "CAST(value AS BINARY)")
        .write
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("topic", topic)
        .save()
    )

    # Read back
    records = self.kafka_utils.get_all_records(self.spark, topic)
    assert records == [("key1", "value1")]

Pattern 2: Streaming Queries

Test streaming queries with checkpoint management:

def test_kafka_streaming(self):
    import tempfile
    import os

    # Setup topics
    source_topic = "source"
    sink_topic = "sink"
    self.kafka_utils.create_topics([source_topic, sink_topic])

    # Produce initial data
    self.kafka_utils.send_messages(source_topic, [("k1", "v1")])

    # Start streaming query
    df = (
        self.spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("subscribe", source_topic)
        .option("startingOffsets", "earliest")
        .load()
    )

    checkpoint_dir = os.path.join(tempfile.mkdtemp(), "checkpoint")
    query = (
        df.writeStream
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("topic", sink_topic)
        .option("checkpointLocation", checkpoint_dir)
        .start()
    )

    try:
        self.kafka_utils.wait_for_query_alive(query)
        self.kafka_utils.assert_eventually(
            lambda: self.kafka_utils.get_all_records(self.spark, sink_topic),
            [("k1", "v1")]
        )
    finally:
        query.stop()
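
The pattern above creates its checkpoint directory with tempfile.mkdtemp(), which is never removed. A small context manager can handle cleanup instead. This is a sketch with a hypothetical name, temp_checkpoint_dir; the query should be stopped before the block exits so that Spark is no longer writing to the directory when it is deleted.

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def temp_checkpoint_dir():
    """Yield a checkpoint path inside a temporary directory.

    The parent directory, and anything Spark wrote under it, is removed
    when the block exits.
    """
    with tempfile.TemporaryDirectory(prefix="kafka-checkpoint-") as root:
        yield os.path.join(root, "checkpoint")
```

In the test, the query definition and the try/finally block that stops it would sit inside `with temp_checkpoint_dir() as checkpoint_dir:`.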

Pattern 3: Stateful Aggregations

Test streaming aggregations:

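def test_kafka_aggregation(self):
    import tempfile

    from pyspark.sql.functions import col

    source_topic, sink_topic = "source", "sink"
    self.kafka_utils.create_topics([source_topic, sink_topic])

    # Send data for aggregation
    self.kafka_utils.send_messages(source_topic, [
        ("user1", "1"),
        ("user2", "1"),
        ("user1", "1"),
    ])

    # Aggregate by key. Read from the earliest offset so the messages
    # sent above are included (streaming reads default to "latest").
    df = (
        self.spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("subscribe", source_topic)
        .option("startingOffsets", "earliest")
        .load()
        .groupBy(col("key"))
        .count()
        .selectExpr("CAST(key AS BINARY)", "CAST(count AS STRING) AS value")
    )

    # Streaming aggregations require "update" or "complete" output mode
    query = (
        df.writeStream
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("topic", sink_topic)
        .option("checkpointLocation", tempfile.mkdtemp())
        .outputMode("update")
        .start()
    )

    try:
        # Verify aggregated results
        self.kafka_utils.assert_eventually(
            lambda: self.kafka_utils.get_all_records(self.spark, sink_topic),
            [("user1", "2"), ("user2", "1")]
        )
    finally:
        query.stop()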

Pattern 4: Multiple Topics

Test writing to multiple topics based on data:

def test_multiple_topics(self):
    topic1, topic2 = "topic1", "topic2"
    self.kafka_utils.create_topics([topic1, topic2])

    # Write with topic column
    df = self.spark.createDataFrame([
        (topic1, "key1", "value1"),
        (topic2, "key2", "value2"),
    ], ["topic", "key", "value"])

    (
        df.selectExpr("topic", "CAST(key AS BINARY)", "CAST(value AS BINARY)")
        .write
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .save()
    )

    # Verify data in each topic
    assert self.kafka_utils.get_all_records(self.spark, topic1) == [("key1", "value1")]
    assert self.kafka_utils.get_all_records(self.spark, topic2) == [("key2", "value2")]

Running Tests

Run All Kafka Tests

cd $SPARK_HOME/python
python -m pytest pyspark/sql/tests/streaming/test_streaming_kafka_rtm.py -v

Run Specific Test

python -m pytest pyspark/sql/tests/streaming/test_streaming_kafka_rtm.py::StreamingKafkaTests::test_streaming_stateless -v

Run with unittest

cd $SPARK_HOME/python
python -m unittest pyspark.sql.tests.streaming.test_streaming_kafka_rtm

Troubleshooting

Docker Not Running

Error:

Cannot connect to the Docker daemon at unix:///var/run/docker.sock

Solution: Start Docker Desktop or the Docker daemon.

Container Startup Timeout

Error:

Kafka container failed to start within timeout

Solutions:

  1. Increase timeout in test code
  2. Check Docker resource allocation (CPU/memory)
  3. Check Docker logs: docker logs <container-id>

Port Conflicts

Error:

Port 9093 already in use

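Solution: testcontainers maps the broker to a random free host port, so this error usually means a test binds a port manually. Remove any fixed port bindings from the tests and connect through kafka_utils.broker.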

Missing Dependencies

Error:

ImportError: testcontainers is required for Kafka tests

Solution:

pip install "testcontainers[kafka]" kafka-python-ng