blob: cfa5ea739bab2d2570fad93c962636151581c490 [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Utilities for running PySpark tests against a Kafka cluster using Docker containers.
This module provides KafkaUtils class that launches a single-broker Kafka cluster
via Docker using testcontainers-python library. It's designed to be used with
Python unittest-based PySpark tests.
Example usage:
class MyKafkaTest(ReusedSQLTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.kafka_utils = KafkaUtils()
cls.kafka_utils.setup()
@classmethod
def tearDownClass(cls):
cls.kafka_utils.teardown()
super().tearDownClass()
def test_kafka_streaming(self):
topic = "test-topic"
self.kafka_utils.create_topics([topic])
# ... use self.kafka_utils.broker for bootstrap.servers
"""
import logging
import time
from typing import List, Tuple, Callable, Any
from pyspark.sql import SparkSession
# Module-level logger; used below for cleanup warnings and topic-creation debug messages.
logger = logging.getLogger(__name__)
class KafkaUtils:
    """
    Utility class for managing a Kafka test cluster using Docker containers.

    This class provides methods to:
    - Start/stop a single-broker Kafka cluster in a Docker container
    - Create and delete topics
    - Send messages to topics
    - Query topic data using Spark
    - Helper methods for testing (assert_eventually, wait_for_query_alive)

    Attributes:
        broker (str): The bootstrap server address (e.g., "localhost:9093");
            None until setup() succeeds and again after teardown()
        initialized (bool): Whether the Kafka cluster has been started
    """

    def __init__(self, kafka_version: str = "7.4.0"):
        """
        Initialize KafkaUtils without starting any container.

        Args:
            kafka_version: Tag of the confluentinc/cp-kafka image to use
                (default: 7.4.0 for stability)
        """
        self.kafka_version = kafka_version
        self.initialized = False
        self._kafka_container = None
        self._admin_client = None
        self._producer = None
        self.broker = None

    def setup(self) -> None:
        """
        Start the Kafka container and initialize admin client and producer.

        This method:
        1. Starts a Kafka container using testcontainers
        2. Creates an admin client for topic management
        3. Creates a producer for sending test messages

        Calling it again after a successful setup is a no-op.

        Raises:
            ImportError: If required dependencies (testcontainers, kafka-python)
                are not installed
            Exception: Any error raised while starting the container or creating
                the clients is re-raised unchanged after teardown() has released
                whatever was already created
        """
        if self.initialized:
            return

        # Dependencies are imported lazily so that merely importing this module
        # does not require the Kafka test extras to be installed.
        try:
            from testcontainers.kafka import KafkaContainer
        except ImportError as e:
            raise ImportError(
                "testcontainers is required for Kafka tests. "
                "Install it with: pip install testcontainers[kafka]"
            ) from e

        try:
            from kafka import KafkaProducer
            from kafka.admin import KafkaAdminClient
        except ImportError as e:
            raise ImportError(
                "kafka-python is required for Kafka tests. "
                "Install it with: pip install kafka-python"
            ) from e

        # Start Kafka container with specific version for test stability
        try:
            self._kafka_container = KafkaContainer(f"confluentinc/cp-kafka:{self.kafka_version}")
            self._kafka_container.start()

            # Get bootstrap server address
            self.broker = self._kafka_container.get_bootstrap_server()

            # Initialize admin client for topic management
            self._admin_client = KafkaAdminClient(
                bootstrap_servers=self.broker,
                request_timeout_ms=10000,
                api_version_auto_timeout_ms=10000,
            )

            # Initialize producer for sending test messages. Keys and values are
            # stringified and UTF-8 encoded; None is passed through untouched.
            self._producer = KafkaProducer(
                bootstrap_servers=self.broker,
                key_serializer=lambda k: str(k).encode("utf-8") if k is not None else None,
                value_serializer=lambda v: str(v).encode("utf-8") if v is not None else None,
                request_timeout_ms=10000,
                max_block_ms=10000,
            )
            self.initialized = True
        except Exception:
            # Release any partially created resources before propagating.
            self.teardown()
            raise

    def teardown(self) -> None:
        """
        Stop the Kafka container and clean up resources.

        This method closes the admin client, producer, and stops the Kafka container.
        Failures during cleanup are logged as warnings rather than raised, so it is
        safe to call multiple times and from error paths.
        """
        # Close admin client
        if self._admin_client is not None:
            try:
                self._admin_client.close()
            except Exception:
                logger.warning("Failed to close admin client", exc_info=True)
            self._admin_client = None

        # Close producer
        if self._producer is not None:
            try:
                self._producer.close(timeout=5)
            except Exception:
                logger.warning("Failed to close producer", exc_info=True)
            self._producer = None

        # Stop Kafka container
        if self._kafka_container is not None:
            try:
                self._kafka_container.stop()
            except Exception:
                logger.warning("Failed to stop Kafka container", exc_info=True)
            self._kafka_container = None

        self.broker = None
        self.initialized = False

    def _assert_initialized(self) -> None:
        """Check if KafkaUtils has been initialized, raise RuntimeError if not."""
        if not self.initialized:
            raise RuntimeError("KafkaUtils has not been initialized. Call setup() first.")

    def create_topics(
        self, topic_names: List[str], num_partitions: int = 1, replication_factor: int = 1
    ) -> None:
        """
        Create Kafka topics.

        Args:
            topic_names: List of topic names to create
            num_partitions: Number of partitions per topic (default: 1)
            replication_factor: Replication factor (default: 1, max: 1 for single broker)

        Raises:
            RuntimeError: If setup() has not been called

        Note:
            If a topic already exists, it will be silently ignored.
        """
        self._assert_initialized()
        from kafka.admin import NewTopic
        from kafka.errors import TopicAlreadyExistsError

        topics = [
            NewTopic(
                name=name, num_partitions=num_partitions, replication_factor=replication_factor
            )
            for name in topic_names
        ]
        try:
            self._admin_client.create_topics(new_topics=topics, validate_only=False)
        except TopicAlreadyExistsError:
            logger.debug("One or more topics already exist, skipping creation: %s", topic_names)

    def delete_topics(self, topic_names: List[str]) -> None:
        """
        Delete Kafka topics.

        Args:
            topic_names: List of topic names to delete

        Raises:
            RuntimeError: If setup() has not been called

        Note:
            If a topic doesn't exist, it will be silently ignored.
        """
        self._assert_initialized()
        from kafka.errors import UnknownTopicOrPartitionError

        try:
            self._admin_client.delete_topics(topics=topic_names)
        except UnknownTopicOrPartitionError:
            # Topic doesn't exist, ignore
            pass

    def send_messages(self, topic: str, messages: List[tuple]) -> None:
        """
        Send messages to a Kafka topic and wait for them to be acknowledged.

        Args:
            topic: Topic name to send messages to
            messages: List of (key, value) tuples to send

        Raises:
            RuntimeError: If setup() has not been called

        Example:
            kafka_utils.send_messages("test-topic", [
                ("key1", "value1"),
                ("key2", "value2"),
            ])
        """
        self._assert_initialized()
        # Queue every record first, then block on each result so a failed send
        # surfaces here instead of being lost.
        futures = [self._producer.send(topic, key=key, value=value) for key, value in messages]
        for future in futures:
            future.get(timeout=10)
        self._producer.flush()

    @staticmethod
    def _record_sort_key(record: tuple) -> tuple:
        """Return a sort key that orders None after any non-None field value."""
        # Pairing each field with an "is None" flag means None is never compared
        # directly against a str (which would raise TypeError).
        return tuple((field is None, field) for field in record)

    def get_all_records(
        self,
        spark: "SparkSession",
        topic: str,
        key_deserializer: str = "STRING",
        value_deserializer: str = "STRING",
    ) -> List[Tuple[str, str]]:
        """
        Read all records from a Kafka topic using Spark.

        Args:
            spark: SparkSession instance
            topic: Topic name to read from
            key_deserializer: SQL type the key column is CAST to (default: "STRING")
            value_deserializer: SQL type the value column is CAST to (default: "STRING")

        Returns:
            Sorted list of (key, value) tuples. Records with a null key or value
            are supported; nulls sort after non-null values.

        Raises:
            RuntimeError: If setup() has not been called

        Example:
            records = kafka_utils.get_all_records(self.spark, "test-topic")
            assert records == [("key1", "value1"), ("key2", "value2")]
        """
        self._assert_initialized()
        df = (
            spark.read.format("kafka")
            .option("kafka.bootstrap.servers", self.broker)
            .option("subscribe", topic)
            .option("startingOffsets", "earliest")
            .option("endingOffsets", "latest")
            .load()
        )
        df = df.selectExpr(
            f"CAST(key AS {key_deserializer}) AS key_str",
            f"CAST(value AS {value_deserializer}) AS value_str",
        )
        rows = df.collect()
        return sorted(
            ((row.key_str, row.value_str) for row in rows),
            key=self._record_sort_key,
        )

    def assert_eventually(
        self,
        result_func: Callable[[], Any],
        expected: Any,
        timeout: int = 60,
        interval: float = 1.0,
    ) -> None:
        """
        Assert that a condition becomes true within a timeout.

        This is useful for testing streaming queries where results are eventually
        consistent. The condition is always evaluated at least once, and once more
        when the deadline is reached.

        Args:
            result_func: Function that returns the current result
            expected: Expected result value
            timeout: Maximum time to wait in seconds (default: 60)
            interval: Time between checks in seconds (default: 1.0)

        Raises:
            AssertionError: If the condition doesn't become true within timeout

        Example:
            kafka_utils.assert_eventually(
                lambda: kafka_utils.get_all_records(spark, topic),
                [("key1", "value1"), ("key2", "value2")]
            )
        """
        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + timeout
        while True:
            last_result = result_func()
            if last_result == expected:
                return  # Success!
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            time.sleep(min(interval, remaining))

        # Timeout reached, raise assertion error
        raise AssertionError(
            f"Condition not met within {timeout}s. Expected: {expected}, Got: {last_result}"
        )

    def wait_for_query_alive(self, query, timeout: int = 60, interval: float = 1.0) -> None:
        """
        Wait for a streaming query to become active and ready to process data.

        The query counts as alive once its status reports either
        "isDataAvailable" or "isTriggerActive". The status is always checked at
        least once, and once more when the deadline is reached.

        Args:
            query: StreamingQuery instance
            timeout: Maximum time to wait in seconds (default: 60)
            interval: Time between checks in seconds (default: 1.0)

        Raises:
            AssertionError: If the query doesn't become active within timeout
            Exception: Whatever query.exception() returns, if the query has failed

        Example:
            query = df.writeStream.format("memory").start()
            kafka_utils.wait_for_query_alive(query)
        """
        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + timeout
        while True:
            # Fail fast if the query has already died.
            exc = query.exception()
            if exc is not None:
                raise exc
            status = query.status
            if status["isDataAvailable"] or status["isTriggerActive"]:
                return  # Query is alive!
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            time.sleep(min(interval, remaining))

        raise AssertionError(f"Query did not become active within {timeout}s.")

    @property
    def producer(self):
        """Get the Kafka producer instance for advanced usage."""
        self._assert_initialized()
        return self._producer

    @property
    def admin_client(self):
        """Get the Kafka admin client instance for advanced usage."""
        self._assert_initialized()
        return self._admin_client