blob: 3765f7d46d1a48cf0b0f411308ee8daba1a39039 [file] [log] [blame]
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import io
import os
import subprocess
import unittest
from unittest import mock
from airflow.contrib.hooks.pinot_hook import PinotAdminHook, PinotDbApiHook
from airflow.exceptions import AirflowException
class TestPinotAdminHook(unittest.TestCase):
    """Tests for PinotAdminHook: verifies the CLI argument lists each helper
    builds, and the subprocess handling inside ``run_cli``."""

    def setUp(self):
        super().setUp()
        # Fake Airflow connection; the subclass below returns it directly so
        # no metadata database is needed.
        conn = mock.MagicMock()
        conn.host = 'host'
        conn.port = '1000'
        conn.extra_dejson = {'cmd_path': './pinot-admin.sh'}
        self.conn = conn

        class PinotAdminHookTest(PinotAdminHook):
            def get_connection(self, conn_id):
                return conn

        self.db_hook = PinotAdminHookTest()

    @mock.patch('airflow.contrib.hooks.pinot_hook.PinotAdminHook.run_cli')
    def test_add_schema(self, run_cli_mock):
        """add_schema must issue an AddSchema command for the given file."""
        schema_file = "schema_file"
        self.db_hook.add_schema(schema_file, False)
        expected = ['AddSchema',
                    '-controllerHost', self.conn.host,
                    '-controllerPort', self.conn.port,
                    '-schemaFile', schema_file]
        run_cli_mock.assert_called_once_with(expected)

    @mock.patch('airflow.contrib.hooks.pinot_hook.PinotAdminHook.run_cli')
    def test_add_table(self, run_cli_mock):
        """add_table must issue an AddTable command for the given config."""
        config_file = "config_file"
        self.db_hook.add_table(config_file, False)
        expected = ['AddTable',
                    '-controllerHost', self.conn.host,
                    '-controllerPort', self.conn.port,
                    '-filePath', config_file]
        run_cli_mock.assert_called_once_with(expected)

    @mock.patch('airflow.contrib.hooks.pinot_hook.PinotAdminHook.run_cli')
    def test_create_segment(self, run_cli_mock):
        """create_segment must map every keyword argument onto its CLI flag.

        Note: enable_star_tree_index is False, so no flag is expected for it.
        """
        kwargs = {
            "generator_config_file": "a",
            "data_dir": "b",
            "format": "c",
            "out_dir": "d",
            "overwrite": True,
            "table_name": "e",
            "segment_name": "f",
            "time_column_name": "g",
            "schema_file": "h",
            "reader_config_file": "i",
            "enable_star_tree_index": False,
            "star_tree_index_spec_file": "j",
            "hll_size": 9,
            "hll_columns": "k",
            "hll_suffix": "l",
            "num_threads": 8,
            "post_creation_verification": True,
            "retry": 7,
        }
        self.db_hook.create_segment(**kwargs)
        expected = ['CreateSegment']
        expected += ['-generatorConfigFile', kwargs["generator_config_file"]]
        expected += ['-dataDir', kwargs["data_dir"]]
        expected += ['-format', kwargs["format"]]
        expected += ['-outDir', kwargs["out_dir"]]
        expected += ['-overwrite', kwargs["overwrite"]]
        expected += ['-tableName', kwargs["table_name"]]
        expected += ['-segmentName', kwargs["segment_name"]]
        expected += ['-timeColumnName', kwargs["time_column_name"]]
        expected += ['-schemaFile', kwargs["schema_file"]]
        expected += ['-readerConfigFile', kwargs["reader_config_file"]]
        expected += ['-starTreeIndexSpecFile', kwargs["star_tree_index_spec_file"]]
        expected += ['-hllSize', kwargs["hll_size"]]
        expected += ['-hllColumns', kwargs["hll_columns"]]
        expected += ['-hllSuffix', kwargs["hll_suffix"]]
        expected += ['-numThreads', kwargs["num_threads"]]
        expected += ['-postCreationVerification', kwargs["post_creation_verification"]]
        expected += ['-retry', kwargs["retry"]]
        run_cli_mock.assert_called_once_with(expected)

    @mock.patch('airflow.contrib.hooks.pinot_hook.PinotAdminHook.run_cli')
    def test_upload_segment(self, run_cli_mock):
        """upload_segment must issue an UploadSegment command for the dir."""
        segment_dir = "segment_dir"
        self.db_hook.upload_segment(segment_dir, False)
        expected = ['UploadSegment',
                    '-controllerHost', self.conn.host,
                    '-controllerPort', self.conn.port,
                    '-segmentDir', segment_dir]
        run_cli_mock.assert_called_once_with(expected)

    @mock.patch('subprocess.Popen')
    def test_run_cli_success(self, popen_mock):
        """run_cli must prepend cmd_path and spawn the process with the
        expected Popen options when the command succeeds."""
        proc = mock.MagicMock()
        proc.returncode = 0
        proc.stdout = io.BytesIO(b'')
        popen_mock.return_value = proc

        cli_args = ["foo", "bar", "baz"]
        self.db_hook.run_cli(cli_args)

        # The hook is expected to place the configured cmd_path first.
        cli_args.insert(0, self.conn.extra_dejson.get('cmd_path'))
        popen_mock.assert_called_once_with(
            cli_args,
            stderr=subprocess.STDOUT,
            stdout=subprocess.PIPE,
            close_fds=True,
            env=None)

    @mock.patch('subprocess.Popen')
    def test_run_cli_failure_error_message(self, popen_mock):
        """run_cli must raise AirflowException when the CLI output contains
        an error message, even though the exit code is 0."""
        msg = b"Exception caught"
        proc = mock.MagicMock()
        proc.returncode = 0  # success exit code: failure must come from stdout
        proc.stdout = io.BytesIO(msg)
        popen_mock.return_value = proc

        cli_args = ["foo", "bar", "baz"]
        with self.assertRaises(AirflowException, msg=msg):
            self.db_hook.run_cli(cli_args)

        cli_args.insert(0, self.conn.extra_dejson.get('cmd_path'))
        popen_mock.assert_called_once_with(
            cli_args,
            stderr=subprocess.STDOUT,
            stdout=subprocess.PIPE,
            close_fds=True,
            env=None)

    @mock.patch('subprocess.Popen')
    def test_run_cli_failure_status_code(self, popen_mock):
        """run_cli must raise on a non-zero exit code and, when
        pinot_admin_system_exit is set, pass the JAVA_OPTS override in env."""
        proc = mock.MagicMock()
        proc.returncode = 1
        proc.stdout = io.BytesIO(b'')
        popen_mock.return_value = proc

        self.db_hook.pinot_admin_system_exit = True
        cli_args = ["foo", "bar", "baz"]
        with self.assertRaises(AirflowException):
            self.db_hook.run_cli(cli_args)

        cli_args.insert(0, self.conn.extra_dejson.get('cmd_path'))
        env = os.environ.copy()
        env.update({"JAVA_OPTS": "-Dpinot.admin.system.exit=true "})
        popen_mock.assert_called_once_with(
            cli_args,
            stderr=subprocess.STDOUT,
            stdout=subprocess.PIPE,
            close_fds=True,
            env=env)
class TestPinotDbApiHook(unittest.TestCase):
    """Tests for PinotDbApiHook against a mocked DB-API connection/cursor."""

    def setUp(self):
        super().setUp()
        # Fake connection metadata plus a mocked cursor; __enter__/__exit__
        # are wired so the hook can use the connection as a context manager.
        self.conn = conn = mock.MagicMock()
        self.conn.host = 'host'
        self.conn.port = '1000'
        self.conn.conn_type = 'http'
        self.conn.extra_dejson = {'endpoint': 'pql'}
        self.cur = mock.MagicMock()
        self.conn.cursor.return_value = self.cur
        self.conn.__enter__.return_value = self.cur
        self.conn.__exit__.return_value = None

        class TestPinotDBApiHook(PinotDbApiHook):
            def get_conn(self):
                return conn

            def get_connection(self, conn_id):
                return conn

        # Deliberately stores the class, not an instance: each test
        # instantiates a fresh hook via self.db_hook().
        self.db_hook = TestPinotDBApiHook

    def test_get_uri(self):
        """
        Test on getting a pinot connection uri
        """
        db_hook = self.db_hook()
        self.assertEqual(db_hook.get_uri(), 'http://host:1000/pql')

    def test_get_conn(self):
        """
        Test on getting a pinot connection
        """
        conn = self.db_hook().get_conn()
        self.assertEqual(conn.host, 'host')
        self.assertEqual(conn.port, '1000')
        self.assertEqual(conn.conn_type, 'http')
        self.assertEqual(conn.extra_dejson.get('endpoint'), 'pql')

    def test_get_records(self):
        """get_records must return all rows produced by the cursor."""
        statement = 'SQL'
        result_sets = [('row1',), ('row2',)]
        self.cur.fetchall.return_value = result_sets
        self.assertEqual(result_sets, self.db_hook().get_records(statement))

    def test_get_first(self):
        """get_first must return only the first row from the cursor."""
        statement = 'SQL'
        result_sets = [('row1',), ('row2',)]
        self.cur.fetchone.return_value = result_sets[0]
        self.assertEqual(result_sets[0], self.db_hook().get_first(statement))

    def test_get_pandas_df(self):
        """get_pandas_df must build a DataFrame with the cursor's column
        names and all of its rows, in order."""
        statement = 'SQL'
        column = 'col'
        result_sets = [('row1',), ('row2',)]
        self.cur.description = [(column,)]
        self.cur.fetchall.return_value = result_sets
        df = self.db_hook().get_pandas_df(statement)
        self.assertEqual(column, df.columns[0])
        # Guard against a short DataFrame before the row-wise comparison,
        # then iterate pairwise instead of indexing by range(len(...)).
        self.assertEqual(len(result_sets), len(df.values.tolist()))
        for expected_row, actual_row in zip(result_sets, df.values.tolist()):
            self.assertEqual(expected_row[0], actual_row[0])