#!/usr/bin/env python
# coding=utf-8
# Copyright [2017] [B2W Digital]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import findspark
findspark.init()

# It is important to import these classes only after the findspark.init() call.
from pyspark.tests import ReusedPySparkTestCase

from marvin_python_toolbox.common.data_source_provider import get_spark_session

try:
    import mock
except ImportError:
    import unittest.mock as mock
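
# NOTE: get_spark_session (exercised below) is assumed to build the session through a
# builder chain roughly like
#   SparkSession.builder.appName(app_name)[.config(k, v)...][.enableHiveSupport()].getOrCreate()
# The mocked assertions below only check that presumed builder call sequence; no real
# Spark session is created in these unit tests.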


class TestDataSourceProvider:
    @mock.patch("pyspark.sql.SparkSession")
    def test_get_spark_session(self, mocked_session):
        # Default app name.
        spark = get_spark_session()
        assert spark
        mocked_session.assert_has_calls([
            mock.call.builder.appName('marvin-engine'),
            mock.call.builder.appName().getOrCreate()]
        )

        # Custom app name.
        spark = get_spark_session(app_name='TestEngine')
        assert spark
        mocked_session.assert_has_calls([
            mock.call.builder.appName('TestEngine'),
            mock.call.builder.appName().getOrCreate()]
        )

        # Extra Spark configs. The mock's call history is cumulative, so the
        # appName('TestEngine') calls recorded above still satisfy this assertion;
        # this mainly checks that passing configs does not break session creation.
        spark = get_spark_session(configs=[("spark.xxx", "true")])
        assert spark
        mocked_session.assert_has_calls([
            mock.call.builder.appName('TestEngine'),
            mock.call.builder.appName().getOrCreate()]
        )
@mock.patch("pyspark.sql.SparkSession")
def test_get_spark_session_with_hive(self, mocked_session):
spark = get_spark_session(enable_hive=True)
assert spark
mocked_session.assert_has_calls([
mock.call.builder.appName('marvin-engine'),
mock.call.builder.appName().enableHiveSupport(),
mock.call.builder.appName().enableHiveSupport().getOrCreate()]
)


class TestSparkDataSource(ReusedPySparkTestCase):
    def test_spark_initialization(self):
        # Word count over a tiny RDD to confirm the shared SparkContext works.
        rdd = self.sc.parallelize(['Hi there', 'Hi'])
        counted = (rdd.flatMap(lambda line: line.split(' '))
                      .map(lambda word: (word, 1))
                      .reduceByKey(lambda acc, n: acc + n))
        assert counted.collectAsMap() == {'Hi': 2, 'there': 1}