blob: 978de77872d9ee38a8e2124216c226483effc9f7 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import time
import random
if sys.version_info[:2] <= (2, 6):
try:
import unittest2 as unittest
except ImportError:
sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
sys.exit(1)
else:
import unittest
from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.tests import PySparkStreamingTestCase
from mqtt import MQTTUtils
class MQTTStreamTests(PySparkStreamingTestCase):
timeout = 20 # seconds
duration = 1
def setUp(self):
super(MQTTStreamTests, self).setUp()
MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
.loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils")
self._MQTTTestUtils = MQTTTestUtilsClz.newInstance()
self._MQTTTestUtils.setup()
def tearDown(self):
if self._MQTTTestUtils is not None:
self._MQTTTestUtils.teardown()
self._MQTTTestUtils = None
super(MQTTStreamTests, self).tearDown()
def _randomTopic(self):
return "topic-%d" % random.randint(0, 10000)
def _startContext(self, topic):
# Start the StreamingContext and also collect the result
stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
result = []
def getOutput(_, rdd):
for data in rdd.collect():
result.append(data)
stream.foreachRDD(getOutput)
self.ssc.start()
return result
def test_mqtt_stream(self):
"""Test the Python MQTT stream API."""
sendData = "MQTT demo for spark streaming"
topic = self._randomTopic()
result = self._startContext(topic)
def retry():
self._MQTTTestUtils.publishData(topic, sendData)
# Because "publishData" sends duplicate messages, here we should use > 0
self.assertTrue(len(result) > 0)
self.assertEqual(sendData, result[0])
# Retry it because we don't know when the receiver will start.
self._retry_or_timeout(retry)
def _start_context_with_paired_stream(self, topics):
stream = MQTTUtils.createPairedStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topics)
# Keep a set because records can potentially be repeated.
result = set()
def getOutput(_, rdd):
for data in rdd.collect():
result.add(data)
stream.foreachRDD(getOutput)
self.ssc.start()
return result
def test_mqtt_pair_stream(self):
"""Test the Python MQTT stream API with multiple topics."""
data_records = ["random string 1", "random string 2", "random string 3"]
topics = [self._randomTopic(), self._randomTopic(), self._randomTopic()]
topics_and_records = zip(topics, data_records)
result = self._start_context_with_paired_stream(topics)
def retry():
for topic, data_record in topics_and_records:
self._MQTTTestUtils.publishData(topic, data_record)
# Sort the received records as they might be out of order.
self.assertEqual(topics_and_records, sorted(result, key=lambda x: x[1]))
# Retry it because we don't know when the receiver will start.
self._retry_or_timeout(retry)
def _retry_or_timeout(self, test_func):
start_time = time.time()
while True:
try:
test_func()
break
except:
if time.time() - start_time > self.timeout:
raise
time.sleep(0.01)
if __name__ == "__main__":
unittest.main()