blob: b98f2668cdc89f43eb89ca0a74fe5545c6dafc8e [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.testing.utils import ReusedPySparkTestCase
from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
class RDDSamplerTests(ReusedPySparkTestCase):
    """Exercise the ``func`` entry points of ``RDDSampler`` and ``RDDStratifiedSampler``.

    Each test hands a sampler's ``func`` to ``mapPartitionsWithIndex`` and checks
    that the size of the resulting sample falls inside fixed bounds.
    """

    def test_rdd_sampler_func(self):
        # SPARK-38879: coverage for RDDSampler.func.
        numbers = self.sc.parallelize(range(20), 2)

        # With RDDSampler(False, 0.4, 10), strictly between 3 and 10 of the
        # 20 input elements are expected to be kept.
        plain_sampler = RDDSampler(False, 0.4, 10)
        kept = numbers.mapPartitionsWithIndex(plain_sampler.func).count()
        self.assertGreater(kept, 3)
        self.assertLess(kept, 10)

        # With RDDSampler(True, 1, 10), the sample is expected to contain at
        # least one repeated element.
        repeating_sampler = RDDSampler(True, 1, 10)
        drawn = sorted(numbers.mapPartitionsWithIndex(repeating_sampler.func).collect())
        # After sorting, equal values are adjacent, so a single matching pair of
        # neighbours is enough to show that some element was drawn twice.
        has_repeat = any(current == previous for previous, current in zip(drawn, drawn[1:]))
        self.assertTrue(has_repeat)

    def test_rdd_stratified_sampler_func(self):
        # SPARK-38879: coverage for RDDStratifiedSampler.func.
        fractions = {"a": 0.8, "b": 0.2}

        # Pair every key with each integer in range(0, 100), giving 100 records
        # per key before sampling.
        keys_rdd = self.sc.parallelize(fractions.keys())
        values_rdd = self.sc.parallelize(range(0, 100))
        keyed = keys_rdd.cartesian(values_rdd)

        stratified_sampler = RDDStratifiedSampler(False, fractions, 10)
        sampled = keyed.mapPartitionsWithIndex(stratified_sampler.func, True)
        per_key = dict(sampled.countByKey())

        # "a" is sampled at a higher rate (0.8) than "b" (0.2), so it must
        # show up more often in the result.
        self.assertGreater(per_key["a"], per_key["b"])
        # Out of 100 records per key, each count has to stay within a band
        # around its configured fraction.
        self.assertGreater(per_key["a"], 60)
        self.assertLess(per_key["a"], 90)
        self.assertGreater(per_key["b"], 15)
        self.assertLess(per_key["b"], 30)
if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_rddsampler import *  # noqa: F401

    # Start from unittest's default runner (None) and switch to the XML
    # reporter only when the optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # xmlrunner is unavailable: keep the default runner chosen above.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)