# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import unittest

from airflow.contrib.hooks.spark_jdbc_hook import SparkJDBCHook
from airflow.models import Connection
from airflow.utils import db


class TestSparkJDBCHook(unittest.TestCase):

    _config = {
        'cmd_type': 'spark_to_jdbc',
        'jdbc_table': 'tableMcTableFace',
        'jdbc_driver': 'org.postgresql.Driver',
        'metastore_table': 'hiveMcHiveFace',
        'jdbc_truncate': False,
        'save_mode': 'append',
        'save_format': 'parquet',
        'batch_size': 100,
        'fetch_size': 200,
        'num_partitions': 10,
        'partition_column': 'columnMcColumnFace',
        'lower_bound': '10',
        'upper_bound': '20',
        'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64), '
                                     'comments VARCHAR(1024)'
    }
    # This config is invalid: if any of [partitionColumn, lowerBound, upperBound]
    # is set, Spark requires all of them to be set.
    _invalid_config = {
        'cmd_type': 'spark_to_jdbc',
        'jdbc_table': 'tableMcTableFace',
        'jdbc_driver': 'org.postgresql.Driver',
        'metastore_table': 'hiveMcHiveFace',
        'jdbc_truncate': False,
        'save_mode': 'append',
        'save_format': 'parquet',
        'batch_size': 100,
        'fetch_size': 200,
        'num_partitions': 10,
        'partition_column': 'columnMcColumnFace',
        'upper_bound': '20',
        'create_table_column_types': 'columnMcColumnFace INTEGER(100), name CHAR(64), '
                                     'comments VARCHAR(1024)'
    }
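
    # For context, a minimal sketch (commented out, not exercised by these tests)
    # of the Spark-side read that such options map onto; it assumes a live
    # SparkSession named `spark`. PySpark's DataFrameReader.jdbc() rejects a
    # partition column unless both bounds and the partition count are supplied
    # alongside it, which is why the config above is invalid:
    #
    #     df = spark.read.jdbc(
    #         url='jdbc:postgresql://localhost:5432/default',
    #         table='tableMcTableFace',
    #         column='columnMcColumnFace',  # partition column
    #         lowerBound=10,
    #         upperBound=20,
    #         numPartitions=10,
    #         properties={'user': 'user', 'password': 'supersecret'},
    #     )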

    def setUp(self):
        # Register the Spark and JDBC test connections in the metadata
        # database so the hook can resolve them by conn_id.
        db.merge_conn(
            Connection(
                conn_id='spark-default', conn_type='spark',
                host='yarn://yarn-master',
                extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
        )
        db.merge_conn(
            Connection(
                conn_id='jdbc-default', conn_type='postgres',
                host='localhost', schema='default', port=5432,
                login='user', password='supersecret',
                extra='{"conn_prefix":"jdbc:postgresql://"}'
            )
        )

    def test_resolve_jdbc_connection(self):
        # Given
        hook = SparkJDBCHook(jdbc_conn_id='jdbc-default')
        expected_connection = {
            'url': 'localhost:5432',
            'schema': 'default',
            'conn_prefix': 'jdbc:postgresql://',
            'user': 'user',
            'password': 'supersecret'
        }

        # When
        connection = hook._resolve_jdbc_connection()

        # Then
        self.assertEqual(connection, expected_connection)
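
    # Note: the pieces resolved above are what _build_jdbc_application_arguments
    # composes into the full JDBC URL, i.e. conn_prefix + url + '/' + schema
    # yields 'jdbc:postgresql://localhost:5432/default' in the test below.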

    def test_build_jdbc_arguments(self):
        # Given
        hook = SparkJDBCHook(**self._config)

        # When
        cmd = hook._build_jdbc_application_arguments(hook._resolve_jdbc_connection())

        # Then
        expected_jdbc_arguments = [
            '-cmdType', 'spark_to_jdbc',
            '-url', 'jdbc:postgresql://localhost:5432/default',
            '-user', 'user',
            '-password', 'supersecret',
            '-metastoreTable', 'hiveMcHiveFace',
            '-jdbcTable', 'tableMcTableFace',
            '-jdbcDriver', 'org.postgresql.Driver',
            '-batchsize', '100',
            '-fetchsize', '200',
            '-numPartitions', '10',
            '-partitionColumn', 'columnMcColumnFace',
            '-lowerBound', '10',
            '-upperBound', '20',
            '-saveMode', 'append',
            '-saveFormat', 'parquet',
            '-createTableColumnTypes', 'columnMcColumnFace INTEGER(100), name CHAR(64), '
                                       'comments VARCHAR(1024)'
        ]
        self.assertEqual(expected_jdbc_arguments, cmd)
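
    # For context: SparkJDBCHook extends SparkSubmitHook, and an argument list
    # like the one above is passed (roughly) as application arguments to a
    # bundled JDBC job script via spark-submit, along the lines of:
    #
    #     spark-submit --master yarn://yarn-master ... spark_jdbc_script.py \
    #         -cmdType spark_to_jdbc -url jdbc:postgresql://localhost:5432/default ...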

    def test_build_jdbc_arguments_invalid(self):
        # Given
        hook = SparkJDBCHook(**self._invalid_config)

        # When/Then: the incomplete partition settings (no lower_bound) must be
        # handled gracefully when building the arguments
        hook._build_jdbc_application_arguments(hook._resolve_jdbc_connection())
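
    # A hedged end-to-end sketch (commented out; the constructor parameter names
    # are assumed from the hook and not asserted by these tests): in a real DAG
    # the hook is typically driven through submit_jdbc_job(), which wraps
    # spark-submit:
    #
    #     hook = SparkJDBCHook(spark_conn_id='spark-default',
    #                          jdbc_conn_id='jdbc-default',
    #                          **TestSparkJDBCHook._config)
    #     hook.submit_jdbc_job()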


if __name__ == '__main__':
    unittest.main()