#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import random
import tempfile
import unittest
try:
import xmlrunner
except ImportError:
xmlrunner = None
from pyspark.broadcast import Broadcast
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import ChunkedStream
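# These tests cover two areas: end-to-end reads of broadcast variables on
# executors (with and without spark.io.encryption.enabled), and the chunked
# framing protocol used to stream broadcast data between Python and the JVM.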


class BroadcastTest(unittest.TestCase):

    def tearDown(self):
if getattr(self, "sc", None) is not None:
self.sc.stop()
self.sc = None

    def _test_encryption_helper(self, vs):
"""
Creates a broadcast variables for each value in vs, and runs a simple job to make sure the
value is the same when it's read in the executors. Also makes sure there are no task
failures.
"""
bs = [self.sc.broadcast(value=v) for v in vs]
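        # run a job that reads every broadcast value on the executors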
exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
for ev in exec_values:
self.assertEqual(ev, vs)
# make sure there are no task failures
status = self.sc.statusTracker()
for jid in status.getJobIdsForGroup():
for sid in status.getJobInfo(jid).stageIds:
stage_info = status.getStageInfo(sid)
self.assertEqual(0, stage_info.numFailedTasks)

    def _test_multiple_broadcasts(self, *extra_confs):
"""
Test broadcast variables make it OK to the executors. Tests multiple broadcast variables,
and also multiple jobs.
"""
conf = SparkConf()
for key, value in extra_confs:
conf.set(key, value)
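        # local-cluster[2,1,1024]: 2 worker processes, 1 core and 1024 MB each,
        # so tasks run outside the driver JVM and broadcasts really travel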
conf.setMaster("local-cluster[2,1,1024]")
self.sc = SparkContext(conf=conf)
self._test_encryption_helper([5])
self._test_encryption_helper([5, 10, 20])

    def test_broadcast_with_encryption(self):
self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))

    def test_broadcast_no_encryption(self):
self._test_multiple_broadcasts()

    def _test_broadcast_on_driver(self, *extra_confs):
conf = SparkConf()
for key, value in extra_confs:
conf.set(key, value)
conf.setMaster("local-cluster[2,1,1024]")
self.sc = SparkContext(conf=conf)
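        # reading the value back on the driver exercises a different code path
        # than the executor-side read, notably when I/O encryption is enabled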
bs = self.sc.broadcast(value=5)
self.assertEqual(5, bs.value)

    def test_broadcast_value_driver_no_encryption(self):
self._test_broadcast_on_driver()

    def test_broadcast_value_driver_encryption(self):
self._test_broadcast_on_driver(("spark.io.encryption.enabled", "true"))


class BroadcastFrameProtocolTest(unittest.TestCase):

    @classmethod
def setUpClass(cls):
gateway = launch_gateway(SparkConf())
cls._jvm = gateway.jvm
cls.longMessage = True
random.seed(42)
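        # fixed seed so the random payloads in test_chunked_stream are reproducible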

    def _test_chunked_stream(self, data, py_buf_size):
# write data using the chunked protocol from python.
chunked_file = tempfile.NamedTemporaryFile(delete=False)
dechunked_file = tempfile.NamedTemporaryFile(delete=False)
dechunked_file.close()
try:
out = ChunkedStream(chunked_file, py_buf_size)
out.write(data)
out.close()
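            # ChunkedStream frames each flush as a 4-byte big-endian length followed
            # by the payload; close() writes a -1 length as the end-of-stream marker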
# now try to read it in java
jin = self._jvm.java.io.FileInputStream(chunked_file.name)
jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
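            # DechunkedInputStream lives in org.apache.spark.api.python, which
            # launch_gateway java-imports, so it is reachable at the JVM root here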
self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
# java should have decoded it back to the original data
self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
with open(dechunked_file.name, "rb") as f:
byte = f.read(1)
idx = 0
while byte:
self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
byte = f.read(1)
idx += 1
finally:
os.unlink(chunked_file.name)
os.unlink(dechunked_file.name)

    def test_chunked_stream(self):
def random_bytes(n):
return bytearray(random.getrandbits(8) for _ in range(n))

        for data_length in [1, 10, 100, 10000]:
for buffer_length in [1, 2, 5, 8192]:
self._test_chunked_stream(random_bytes(data_length), buffer_length)


if __name__ == '__main__':
from pyspark.test_broadcast import *
if xmlrunner:
unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'), verbosity=2)
else:
unittest.main(verbosity=2)