# Kafka Testing with PySpark
This document explains how to write PySpark unit tests that interact with a Kafka cluster using Docker containers.
## Overview
The `kafka_utils.py` module provides the `KafkaUtils` class, which launches a single-broker Kafka cluster in Docker using the `testcontainers-python` library. This enables end-to-end integration testing of PySpark Kafka streaming applications without requiring a manual Kafka setup.
## Prerequisites
### 1. Docker
Docker must be installed and running on your system. The Kafka test container requires Docker to launch the Kafka broker.
- **Install Docker**: https://docs.docker.com/get-docker/
- **Verify Docker is running**: `docker ps`
### 2. Python Dependencies
Install the required Python packages:
```bash
# Quoting keeps shells such as zsh from treating the brackets as a glob pattern
pip install "testcontainers[kafka]" kafka-python-ng
```
Or install all dev dependencies:
```bash
cd $SPARK_HOME
pip install -r dev/requirements.txt
```
## Quick Start
Here's a minimal example of writing a Kafka test:
```python
import unittest
from pyspark.sql.tests.streaming.kafka_utils import KafkaUtils
from pyspark.testing.sqlutils import ReusedSQLTestCase


class MyKafkaTest(ReusedSQLTestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.kafka_utils = KafkaUtils()
        cls.kafka_utils.setup()

    @classmethod
    def tearDownClass(cls):
        cls.kafka_utils.teardown()
        super().tearDownClass()

    def test_kafka_read_write(self):
        # Create a topic
        topic = "test-topic"
        self.kafka_utils.create_topics([topic])
        # Send test data
        messages = [("key1", "value1"), ("key2", "value2")]
        self.kafka_utils.send_messages(topic, messages)
        # Read with Spark
        df = (
            self.spark.read
            .format("kafka")
            .option("kafka.bootstrap.servers", self.kafka_utils.broker)
            .option("subscribe", topic)
            .option("startingOffsets", "earliest")
            .load()
        )
        # Verify data
        results = df.selectExpr(
            "CAST(key AS STRING) as key",
            "CAST(value AS STRING) as value"
        ).collect()
        self.assertEqual(len(results), 2)
```
## KafkaUtils API Reference
### Initialization and Lifecycle
#### `__init__(kafka_version="7.4.0")`
Create a new KafkaUtils instance.
- **kafka_version**: Version of Confluent Kafka to use (default: "7.4.0" for stability)
#### `setup()`
Start the Kafka container and initialize clients. This must be called before using any other methods.
**Raises:**
- `ImportError`: If required dependencies are not installed
- `RuntimeError`: If Kafka container fails to start
**Note:** Container startup can take 10-30 seconds on first run while Docker pulls the image.
#### `teardown()`
Stop the Kafka container and clean up resources. Safe to call multiple times.
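`unittest` does not run `tearDownClass` when `setUpClass` raises (for example, if `setup()` fails after the container was created), which may leave the container running longer than needed. Class cleanups registered before the failure do run. The sketch below registers `teardown` with `unittest.TestCase.addClassCleanup` (Python 3.8+). `FakeKafkaUtils` is a hypothetical stand-in so the example runs without Docker; in a real test you would use `KafkaUtils()` instead.

```python
import unittest


class FakeKafkaUtils:
    """Hypothetical stand-in for KafkaUtils so this sketch runs without Docker."""

    def __init__(self):
        self.started = False
        self.teardown_calls = 0

    def setup(self):
        self.started = True

    def teardown(self):
        # Mirrors the documented contract: safe to call more than once
        self.teardown_calls += 1
        self.started = False


class KafkaLifecycleExample(unittest.TestCase):
    kafka_utils = None

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.kafka_utils = FakeKafkaUtils()
        # Register the cleanup before setup(), so it runs even if setup()
        # or any later step of setUpClass raises
        cls.addClassCleanup(cls.kafka_utils.teardown)
        cls.kafka_utils.setup()

    def test_broker_started(self):
        self.assertTrue(self.kafka_utils.started)
```

Because `teardown()` is documented as safe to call multiple times, keeping an explicit `tearDownClass` alongside the registered cleanup should be harmless.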
### Topic Management
#### `create_topics(topic_names, num_partitions=1, replication_factor=1)`
Create one or more Kafka topics.
**Parameters:**
- **topic_names** (List[str]): List of topic names to create
- **num_partitions** (int): Number of partitions per topic (default: 1)
- **replication_factor** (int): Replication factor (default: 1, max: 1 for single broker)
**Example:**
```python
# Create single partition topics
kafka_utils.create_topics(["topic1", "topic2"])
# Create multi-partition topic
kafka_utils.create_topics(["multi-partition-topic"], num_partitions=3)
```
#### `delete_topics(topic_names)`
Delete one or more Kafka topics.
**Parameters:**
- **topic_names** (List[str]): List of topic names to delete
### Producing Data
#### `send_messages(topic, messages)`
Send messages to a Kafka topic.
**Parameters:**
- **topic** (str): Topic name
- **messages** (List[tuple]): List of (key, value) tuples
**Example:**
```python
kafka_utils.send_messages("test-topic", [
    ("user1", "login"),
    ("user2", "logout"),
    ("user1", "purchase"),
])
```
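Kafka stores keys and values as raw bytes, so string tuples have to be encoded before they reach the broker. Assuming `send_messages` uses UTF-8 (which is what Spark's `CAST(... AS STRING)` decodes on the read side), the conversion can be sketched as below. `encode_messages` is a hypothetical helper, not the `KafkaUtils` implementation.

```python
from typing import List, Optional, Tuple


def encode_messages(
    messages: List[Tuple[Optional[str], Optional[str]]],
) -> List[Tuple[Optional[bytes], Optional[bytes]]]:
    """UTF-8 encode (key, value) pairs; None stays None (null key or null value)."""

    def encode(text: Optional[str]) -> Optional[bytes]:
        return None if text is None else text.encode("utf-8")

    return [(encode(key), encode(value)) for key, value in messages]
```

With the `producer` property (a kafka-python `KafkaProducer`), each encoded pair could then be sent with `producer.send(topic, key=key, value=value)`, followed by `producer.flush()` so the records are on the broker before Spark reads them.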
### Reading Data
#### `get_all_records(spark, topic, key_deserializer="STRING", value_deserializer="STRING")`
Read all records from a Kafka topic using Spark batch read.
**Parameters:**
- **spark**: SparkSession instance
- **topic** (str): Topic name
- **key_deserializer** (str): How to deserialize keys (default: "STRING")
- **value_deserializer** (str): How to deserialize values (default: "STRING")
**Returns:** Sorted list of (key, value) tuples.
**Example:**
```python
records = kafka_utils.get_all_records(self.spark, "test-topic")
assert records == [("key1", "value1"), ("key2", "value2")]
```
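Sorting is what makes the comparison above independent of partition and offset order. A sketch of the two steps such a helper plausibly performs, assuming the deserializer arguments are Spark SQL type names used in `CAST` (suggested by the `"STRING"` default but not confirmed here). `kafka_projection` and `sorted_records` are hypothetical names.

```python
from typing import Iterable, List, Optional, Sequence, Tuple


def kafka_projection(key_deserializer: str = "STRING", value_deserializer: str = "STRING") -> List[str]:
    """selectExpr() arguments that decode Kafka's binary key/value columns.

    Assumption: the deserializer arguments are Spark SQL type names used in CAST.
    """
    return [
        f"CAST(key AS {key_deserializer}) AS key",
        f"CAST(value AS {value_deserializer}) AS value",
    ]


def sorted_records(rows: Iterable[Sequence]) -> List[Tuple[Optional[str], Optional[str]]]:
    """Sort (key, value) pairs so assertions do not depend on read order.

    A null key or value (None) sorts first instead of raising TypeError.
    """
    pairs = [(row[0], row[1]) for row in rows]
    return sorted(
        pairs,
        key=lambda kv: (kv[0] is not None, kv[0] or "", kv[1] is not None, kv[1] or ""),
    )
```

In a Spark test this would look like `sorted_records(df.selectExpr(*kafka_projection()).collect())`, where `df` is a batch Kafka read such as the one in the Quick Start.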
### Testing Utilities
#### `assert_eventually(result_func, expected, timeout=60, interval=1.0)`
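Repeatedly call `result_func` until its return value equals `expected`, failing if that does not happen within `timeout` seconds. Useful for testing streaming queries, whose output only becomes visible after one or more micro-batches complete.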
**Parameters:**
- **result_func** (Callable): Function that returns the current result
- **expected**: Expected result value
- **timeout** (int): Maximum time to wait in seconds (default: 60)
- **interval** (float): Time between checks in seconds (default: 1.0)
**Raises:** `AssertionError` if condition doesn't become true within timeout
**Example:**
```python
kafka_utils.assert_eventually(
    lambda: kafka_utils.get_all_records(self.spark, "sink-topic"),
    [("key1", "processed-value1")],
    timeout=30
)
```
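The polling behavior can be sketched in plain Python, assuming the helper re-evaluates `result_func` every `interval` seconds and compares with `==`. This is an illustration, not the `KafkaUtils` source. It also shows why `result_func` must be a callable (such as the `lambda` above) rather than a precomputed value: each poll has to re-read the sink.

```python
import time
from typing import Any, Callable


def assert_eventually(
    result_func: Callable[[], Any],
    expected: Any,
    timeout: float = 60,
    interval: float = 1.0,
) -> None:
    """Poll result_func until it returns `expected` or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while True:
        last_result = result_func()
        if last_result == expected:
            return
        if time.monotonic() >= deadline:
            # Include the last observed value to make failures easier to debug
            raise AssertionError(
                f"Expected {expected!r} within {timeout}s, last result was {last_result!r}"
            )
        time.sleep(interval)
```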
#### `wait_for_query_alive(query, timeout=60, interval=1.0)`
Wait for a streaming query to become active and ready to process data.
**Parameters:**
- **query**: StreamingQuery instance
- **timeout** (int): Maximum time to wait in seconds (default: 60)
- **interval** (float): Time between checks in seconds (default: 1.0)
**Raises:** `AssertionError` if query doesn't become active within timeout
**Example:**
```python
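# The memory sink requires a query name; it becomes the in-memory table name
query = df.writeStream.format("memory").queryName("kafka_events").start()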
kafka_utils.wait_for_query_alive(query, timeout=30)
```
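What "alive" means is not spelled out above. One plausible reading, sketched here as an assumption, is polling `StreamingQuery.isActive` and surfacing `StreamingQuery.exception()` early if the query has already failed. The real helper may check additional status fields.

```python
import time


def wait_for_query_alive(query, timeout: float = 60, interval: float = 1.0) -> None:
    """Block until `query.isActive` is True, or fail after `timeout` seconds.

    Assumption: "alive" is approximated by StreamingQuery.isActive.
    """
    deadline = time.monotonic() + timeout
    while not query.isActive:
        # An inactive query that already recorded an exception will not recover
        error = query.exception()
        if error is not None:
            raise AssertionError(f"Streaming query failed: {error}")
        if time.monotonic() >= deadline:
            raise AssertionError(f"Streaming query not active after {timeout}s")
        time.sleep(interval)
```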
### Properties
#### `broker`
Get the Kafka bootstrap server address (e.g., "localhost:9093").
#### `producer`
Get the underlying KafkaProducer instance for advanced usage.
#### `admin_client`
Get the underlying KafkaAdminClient instance for advanced usage.
## Common Testing Patterns
### Pattern 1: Batch Read/Write
Test basic Kafka read and write with Spark batch processing:
```python
def test_kafka_batch(self):
    topic = "test-topic"
    self.kafka_utils.create_topics([topic])
    # Write with Spark DataFrame
    df = self.spark.createDataFrame([("key1", "value1")], ["key", "value"])
    (
        df.selectExpr("CAST(key AS BINARY)", "CAST(value AS BINARY)")
        .write
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("topic", topic)
        .save()
    )
    # Read back
    records = self.kafka_utils.get_all_records(self.spark, topic)
    assert records == [("key1", "value1")]
```
### Pattern 2: Streaming Queries
Test streaming queries with checkpoint management:
```python
def test_kafka_streaming(self):
    import os
    import tempfile

    # Set up topics
    source_topic = "source"
    sink_topic = "sink"
    self.kafka_utils.create_topics([source_topic, sink_topic])
    # Produce initial data
    self.kafka_utils.send_messages(source_topic, [("k1", "v1")])
    # Start streaming query; keep only key and value so the source's
    # metadata columns (topic, partition, offset, ...) are not forwarded
    df = (
        self.spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("subscribe", source_topic)
        .option("startingOffsets", "earliest")
        .load()
        .selectExpr("key", "value")
    )
    checkpoint_dir = os.path.join(tempfile.mkdtemp(), "checkpoint")
    query = (
        df.writeStream
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("topic", sink_topic)
        .option("checkpointLocation", checkpoint_dir)
        .start()
    )
    try:
        self.kafka_utils.wait_for_query_alive(query)
        self.kafka_utils.assert_eventually(
            lambda: self.kafka_utils.get_all_records(self.spark, sink_topic),
            [("k1", "v1")],
        )
    finally:
        query.stop()
```
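`tempfile.mkdtemp()` in the pattern above never removes the checkpoint directory. A sketch of a hypothetical `checkpoint_location` context manager that deletes it after the test:

```python
import os
import shutil
import tempfile
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def checkpoint_location() -> Iterator[str]:
    """Yield a fresh checkpoint path; remove its parent directory afterwards."""
    root = tempfile.mkdtemp(prefix="kafka-test-")
    try:
        # Spark creates the "checkpoint" subdirectory itself on first use
        yield os.path.join(root, "checkpoint")
    finally:
        shutil.rmtree(root, ignore_errors=True)
```

To use it, wrap the `start()` call and the `try`/`finally` block in `with checkpoint_location() as checkpoint_dir:` so `query.stop()` runs before the directory is deleted.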
### Pattern 3: Stateful Aggregations
Test streaming aggregations:
```python
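def test_kafka_aggregation(self):
    import os
    import tempfile
    from pyspark.sql.functions import col

    source_topic, sink_topic = "source", "sink"
    self.kafka_utils.create_topics([source_topic, sink_topic])
    # Send data for aggregation
    self.kafka_utils.send_messages(source_topic, [
        ("user1", "1"),
        ("user2", "1"),
        ("user1", "1"),
    ])
    # Aggregate by key; start from the earliest offset so the messages
    # produced above are included (streaming reads default to "latest")
    df = (
        self.spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("subscribe", source_topic)
        .option("startingOffsets", "earliest")
        .load()
        .groupBy(col("key"))
        .count()
        .selectExpr("CAST(key AS BINARY) AS key", "CAST(count AS STRING) AS value")
    )
    checkpoint_dir = os.path.join(tempfile.mkdtemp(), "checkpoint")
    # Aggregations without a watermark need "update" or "complete" output mode
    query = (
        df.writeStream
        .format("kafka")
        .outputMode("update")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .option("topic", sink_topic)
        .option("checkpointLocation", checkpoint_dir)
        .start()
    )
    try:
        # Verify aggregated results
        self.kafka_utils.assert_eventually(
            lambda: self.kafka_utils.get_all_records(self.spark, sink_topic),
            [("user1", "2"), ("user2", "1")],
        )
    finally:
        query.stop()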
```
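Rather than hard-coding the expected counts, they can be derived from the input with a small reference model. The sketch below (`expected_counts` is a hypothetical helper) renders counts as strings to match the `CAST(count AS STRING)` in the query. It only models the sink contents when all input lands in a single micro-batch, as it does when the messages are produced before the query starts. In update mode, input spread over several batches also leaves intermediate counts in the sink topic.

```python
from collections import Counter
from typing import List, Tuple


def expected_counts(messages: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Final per-key counts, sorted and formatted like the sink records."""
    counts = Counter(key for key, _ in messages)
    return sorted((key, str(n)) for key, n in counts.items())
```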
### Pattern 4: Multiple Topics
Test writing to multiple topics based on data:
```python
def test_multiple_topics(self):
    topic1, topic2 = "topic1", "topic2"
    self.kafka_utils.create_topics([topic1, topic2])
    # Write with topic column
    df = self.spark.createDataFrame([
        (topic1, "key1", "value1"),
        (topic2, "key2", "value2"),
    ], ["topic", "key", "value"])
    (
        df.selectExpr("topic", "CAST(key AS BINARY)", "CAST(value AS BINARY)")
        .write
        .format("kafka")
        .option("kafka.bootstrap.servers", self.kafka_utils.broker)
        .save()
    )
    # Verify data in each topic
    assert self.kafka_utils.get_all_records(self.spark, topic1) == [("key1", "value1")]
    assert self.kafka_utils.get_all_records(self.spark, topic2) == [("key2", "value2")]
```
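With more than a couple of topics, the per-topic assertions can be derived from the same rows passed to `createDataFrame`. A sketch with a hypothetical `expected_by_topic` helper that groups `(topic, key, value)` rows into the sorted per-topic lists that `get_all_records` is documented to return:

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def expected_by_topic(rows: Iterable[Tuple[str, str, str]]) -> Dict[str, List[Tuple[str, str]]]:
    """Group (topic, key, value) rows into sorted (key, value) lists per topic."""
    grouped: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
    for topic, key, value in rows:
        grouped[topic].append((key, value))
    return {topic: sorted(pairs) for topic, pairs in grouped.items()}
```

The test could then loop over the result and assert `get_all_records(self.spark, topic) == records` for each topic.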
## Running Tests
### Run All Kafka Tests
```bash
cd $SPARK_HOME/python
python -m pytest pyspark/sql/tests/streaming/test_streaming_kafka_rtm.py -v
```
### Run Specific Test
```bash
python -m pytest pyspark/sql/tests/streaming/test_streaming_kafka_rtm.py::StreamingKafkaTests::test_streaming_stateless -v
```
### Run with unittest
```bash
cd $SPARK_HOME/python
python -m unittest pyspark.sql.tests.streaming.test_streaming_kafka_rtm
```
## Troubleshooting
### Docker Not Running
**Error:**
```
Cannot connect to the Docker daemon at unix:///var/run/docker.sock
```
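**Solution:** Start Docker Desktop, or on Linux start the Docker daemon (for example `sudo systemctl start docker` on systemd-based distributions), then confirm with `docker ps`.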
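To skip the Kafka suite instead of failing when Docker is unavailable, a test module can probe the daemon first. A sketch using the standard library only; `docker_available` is a hypothetical helper that treats a zero exit status from `docker info` as "daemon reachable":

```python
import shutil
import subprocess


def docker_available(timeout: float = 10.0) -> bool:
    """Return True if the docker CLI exists and `docker info` reaches the daemon."""
    if shutil.which("docker") is None:
        return False
    try:
        result = subprocess.run(
            ["docker", "info"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0
```

A test class can then be decorated with `@unittest.skipUnless(docker_available(), "Docker daemon is not reachable")`.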
### Container Startup Timeout
**Error:**
```
Kafka container failed to start within timeout
```
**Solutions:**
1. Increase timeout in test code
2. Check Docker resource allocation (CPU/memory)
3. Check Docker logs: `docker logs <container-id>`
### Port Conflicts
**Error:**
```
Port 9093 already in use
```
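**Solution:** testcontainers maps the container's Kafka port to a random free host port, so a conflict usually means a test binds a fixed port itself. Read the bootstrap address from `kafka_utils.broker` instead of hard-coding a port such as 9093.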
### Missing Dependencies
**Error:**
```
ImportError: testcontainers is required for Kafka tests
```
**Solution:**
```bash
pip install "testcontainers[kafka]" kafka-python-ng
```