# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import datetime
import itertools
import os
import unittest
from collections import OrderedDict, namedtuple

import pandas as pd
from hmsclient import HMSClient
from six import PY2

from airflow.exceptions import AirflowException
from airflow.hooks.hive_hooks import HiveMetastoreHook, HiveServer2Hook
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.utils import timezone
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
from tests.compat import patch, MagicMock, call, mock
from tests.test_utils.mock_hooks import MockHiveCliHook, MockHiveServer2Hook
from tests.test_utils.mock_process import MockSubProcess

DEFAULT_DATE = timezone.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]


class TestHiveEnvironment(unittest.TestCase):
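    """Shared fixture: builds a HiveMetastoreHook with a mocked metastore client
    plus the database, table and partition-key names used by the tests below."""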
def setUp(self):
self.next_day = (DEFAULT_DATE +
datetime.timedelta(days=1)).isoformat()[:10]
self.database = 'airflow'
self.partition_by = 'ds'
self.table = 'static_babynames_partitioned'
with patch('airflow.hooks.hive_hooks.HiveMetastoreHook.get_metastore_client') \
as get_metastore_mock:
get_metastore_mock.return_value = MagicMock()
self.hook = HiveMetastoreHook()


class TestHiveCliHook(unittest.TestCase):
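    """Tests for HiveCliHook, exercising run_cli, load_file and load_df
    against mocked subprocesses."""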
@patch('tempfile.tempdir', '/tmp/')
@patch('tempfile._RandomNameSequence.next' if PY2 else 'tempfile._RandomNameSequence.__next__')
@patch('subprocess.Popen')
def test_run_cli(self, mock_popen, mock_temp_dir):
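        """run_cli should launch the hive CLI with the Airflow context
        variables and the configured queue names passed as -hiveconf settings."""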
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
mock_temp_dir.return_value = "test_run_cli"
with patch.dict('os.environ', {
'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
'AIRFLOW_CTX_TASK_ID': 'test_task_id',
'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
'AIRFLOW_CTX_DAG_RUN_ID': '55',
'AIRFLOW_CTX_DAG_OWNER': 'airflow',
'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com',
}):
hook = MockHiveCliHook()
hook.run_cli("SHOW DATABASES")
hive_cmd = [u'hive', u'-hiveconf', u'airflow.ctx.dag_email=test@airflow.com', u'-hiveconf',
u'airflow.ctx.dag_id=test_dag_id', u'-hiveconf', u'airflow.ctx.dag_owner=airflow',
u'-hiveconf', u'airflow.ctx.dag_run_id=55', u'-hiveconf',
u'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00', u'-hiveconf',
u'airflow.ctx.task_id=test_task_id', u'-hiveconf', u'mapreduce.job.queuename=airflow',
u'-hiveconf', u'mapred.job.queue.name=airflow', u'-hiveconf', u'tez.queue.name=airflow',
u'-f', u'/tmp/airflow_hiveop_test_run_cli/tmptest_run_cli']
mock_popen.assert_called_with(
hive_cmd,
stdout=mock_subprocess.PIPE,
stderr=mock_subprocess.STDOUT,
cwd="/tmp/airflow_hiveop_test_run_cli",
close_fds=True
)
@patch('subprocess.Popen')
def test_run_cli_with_hive_conf(self, mock_popen):
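        """Values from hive_conf and the Airflow context environment variables
        should appear both in the generated hive command and in the CLI output."""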
hql = "set key;\n" \
"set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \
"set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"
dag_id_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
task_id_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
execution_date_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
'env_var_format']
dag_run_id_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
'env_var_format']
mock_output = ['Connecting to jdbc:hive2://localhost:10000/default',
'log4j:WARN No appenders could be found for logger (org.apache.hive.jdbc.Utils).',
'log4j:WARN Please initialize the log4j system properly.',
'log4j:WARN See http://logging.apache.org/log4j/1.2/faq.html#noconfig for more info.',
'Connected to: Apache Hive (version 1.2.1.2.3.2.0-2950)',
'Driver: Hive JDBC (version 1.2.1.spark2)',
'Transaction isolation: TRANSACTION_REPEATABLE_READ',
'0: jdbc:hive2://localhost:10000/default> USE default;',
'No rows affected (0.37 seconds)',
'0: jdbc:hive2://localhost:10000/default> set key;',
'+------------+--+',
'| set |',
'+------------+--+',
'| key=value |',
'+------------+--+',
'1 row selected (0.133 seconds)',
'0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_id;',
'+---------------------------------+--+',
'| set |',
'+---------------------------------+--+',
'| airflow.ctx.dag_id=test_dag_id |',
'+---------------------------------+--+',
'1 row selected (0.008 seconds)',
'0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_run_id;',
'+-----------------------------------------+--+',
'| set |',
'+-----------------------------------------+--+',
'| airflow.ctx.dag_run_id=test_dag_run_id |',
'+-----------------------------------------+--+',
'1 row selected (0.007 seconds)',
'0: jdbc:hive2://localhost:10000/default> set airflow.ctx.task_id;',
'+-----------------------------------+--+',
'| set |',
'+-----------------------------------+--+',
'| airflow.ctx.task_id=test_task_id |',
'+-----------------------------------+--+',
'1 row selected (0.009 seconds)',
'0: jdbc:hive2://localhost:10000/default> set airflow.ctx.execution_date;',
'+-------------------------------------------------+--+',
'| set |',
'+-------------------------------------------------+--+',
'| airflow.ctx.execution_date=test_execution_date |',
'+-------------------------------------------------+--+',
'1 row selected (0.006 seconds)',
'0: jdbc:hive2://localhost:10000/default> ',
'0: jdbc:hive2://localhost:10000/default> ',
'Closing: 0: jdbc:hive2://localhost:10000/default',
'']
with patch.dict('os.environ', {
dag_id_ctx_var_name: 'test_dag_id',
task_id_ctx_var_name: 'test_task_id',
execution_date_ctx_var_name: 'test_execution_date',
dag_run_id_ctx_var_name: 'test_dag_run_id',
}):
hook = MockHiveCliHook()
mock_popen.return_value = MockSubProcess(output=mock_output)
output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
process_inputs = " ".join(mock_popen.call_args_list[0][0][0])
self.assertIn('value', process_inputs)
self.assertIn('test_dag_id', process_inputs)
self.assertIn('test_task_id', process_inputs)
self.assertIn('test_execution_date', process_inputs)
self.assertIn('test_dag_run_id', process_inputs)
self.assertIn('value', output)
self.assertIn('test_dag_id', output)
self.assertIn('test_task_id', output)
self.assertIn('test_execution_date', output)
self.assertIn('test_dag_run_id', output)
@patch('airflow.hooks.hive_hooks.HiveCliHook.run_cli')
def test_load_file_without_create_table(self, mock_run_cli):
filepath = "/path/to/input/file"
table = "output_table"
hook = MockHiveCliHook()
hook.load_file(filepath=filepath, table=table, create=False)
query = (
"LOAD DATA LOCAL INPATH '{filepath}' "
"OVERWRITE INTO TABLE {table} ;\n"
.format(filepath=filepath, table=table)
)
calls = [
call(query)
]
mock_run_cli.assert_has_calls(calls, any_order=True)
@patch('airflow.hooks.hive_hooks.HiveCliHook.run_cli')
def test_load_file_create_table(self, mock_run_cli):
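        """load_file with create/recreate should issue DROP/CREATE DDL built
        from field_dict in addition to the LOAD DATA statement."""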
filepath = "/path/to/input/file"
table = "output_table"
field_dict = OrderedDict([("name", "string"), ("gender", "string")])
fields = ",\n ".join(
['`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()])
hook = MockHiveCliHook()
hook.load_file(filepath=filepath, table=table,
field_dict=field_dict, create=True, recreate=True)
create_table = (
"DROP TABLE IF EXISTS {table};\n"
"CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
"ROW FORMAT DELIMITED\n"
"FIELDS TERMINATED BY ','\n"
"STORED AS textfile\n;".format(table=table, fields=fields)
)
load_data = (
"LOAD DATA LOCAL INPATH '{filepath}' "
"OVERWRITE INTO TABLE {table} ;\n"
.format(filepath=filepath, table=table)
)
calls = [
call(create_table),
call(load_data)
]
mock_run_cli.assert_has_calls(calls, any_order=True)
@patch('airflow.hooks.hive_hooks.HiveCliHook.load_file')
@patch('pandas.DataFrame.to_csv')
def test_load_df(self, mock_to_csv, mock_load_file):
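        """load_df should write the DataFrame to CSV without header or index and
        pass the inferred field_dict and table name on to load_file."""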
df = pd.DataFrame({"c": ["foo", "bar", "baz"]})
table = "t"
delimiter = ","
encoding = "utf-8"
hook = MockHiveCliHook()
hook.load_df(df=df,
table=table,
delimiter=delimiter,
encoding=encoding)
assert mock_to_csv.call_count == 1
kwargs = mock_to_csv.call_args[1]
self.assertEqual(kwargs["header"], False)
self.assertEqual(kwargs["index"], False)
self.assertEqual(kwargs["sep"], delimiter)
assert mock_load_file.call_count == 1
kwargs = mock_load_file.call_args[1]
self.assertEqual(kwargs["delimiter"], delimiter)
self.assertEqual(kwargs["field_dict"], {"c": "STRING"})
self.assertTrue(isinstance(kwargs["field_dict"], OrderedDict))
self.assertEqual(kwargs["table"], table)
@patch('airflow.hooks.hive_hooks.HiveCliHook.load_file')
@patch('pandas.DataFrame.to_csv')
def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
hook = MockHiveCliHook()
bools = (True, False)
for create, recreate in itertools.product(bools, bools):
mock_load_file.reset_mock()
hook.load_df(df=pd.DataFrame({"c": range(0, 10)}),
table="t",
create=create,
recreate=recreate)
assert mock_load_file.call_count == 1
kwargs = mock_load_file.call_args[1]
self.assertEqual(kwargs["create"], create)
self.assertEqual(kwargs["recreate"], recreate)
@patch('airflow.hooks.hive_hooks.HiveCliHook.run_cli')
def test_load_df_with_data_types(self, mock_run_cli):
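        """load_df should map pandas dtypes to the expected Hive column types
        in the generated CREATE TABLE statement."""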
ord_dict = OrderedDict()
ord_dict['b'] = [True]
ord_dict['i'] = [-1]
ord_dict['t'] = [1]
ord_dict['f'] = [0.0]
ord_dict['c'] = ['c']
ord_dict['M'] = [datetime.datetime(2018, 1, 1)]
ord_dict['O'] = [object()]
ord_dict['S'] = [b'STRING']
ord_dict['U'] = ['STRING']
ord_dict['V'] = [None]
df = pd.DataFrame(ord_dict)
hook = MockHiveCliHook()
hook.load_df(df, 't')
query = """
CREATE TABLE IF NOT EXISTS t (
`b` BOOLEAN,
`i` BIGINT,
`t` BIGINT,
`f` DOUBLE,
`c` STRING,
`M` TIMESTAMP,
`O` STRING,
`S` STRING,
`U` STRING,
`V` STRING)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ','
STORED AS textfile
;
"""
        # Compare the generated DDL with the expected CREATE TABLE statement,
        # ignoring differences in whitespace.
        self.assertEqual(
            " ".join(mock_run_cli.call_args_list[0][0][0].split()),
            " ".join(query.split()))


class TestHiveMetastoreHook(TestHiveEnvironment):
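    """Tests for HiveMetastoreHook with the metastore client mocked out."""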
VALID_FILTER_MAP = {'key2': 'value2'}
def test_get_max_partition_from_empty_part_specs(self):
max_partition = \
HiveMetastoreHook._get_max_partition_from_part_specs([],
'key1',
self.VALID_FILTER_MAP)
self.assertIsNone(max_partition)
# @patch('airflow.hooks.hive_hooks.HiveMetastoreHook', 'get_metastore_client')
def test_get_max_partition_from_valid_part_specs_and_invalid_filter_map(self):
with self.assertRaises(AirflowException):
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'},
{'key1': 'value3', 'key2': 'value4'}],
'key1',
{'key3': 'value5'})
def test_get_max_partition_from_valid_part_specs_and_invalid_partition_key(self):
with self.assertRaises(AirflowException):
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'},
{'key1': 'value3', 'key2': 'value4'}],
'key3',
self.VALID_FILTER_MAP)
def test_get_max_partition_from_valid_part_specs_and_none_partition_key(self):
with self.assertRaises(AirflowException):
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'},
{'key1': 'value3', 'key2': 'value4'}],
None,
self.VALID_FILTER_MAP)
def test_get_max_partition_from_valid_part_specs_and_none_filter_map(self):
max_partition = \
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'},
{'key1': 'value3', 'key2': 'value4'}],
'key1',
None)
# No partition will be filtered out.
self.assertEqual(max_partition, 'value3')
def test_get_max_partition_from_valid_part_specs(self):
max_partition = \
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'},
{'key1': 'value3', 'key2': 'value4'}],
'key1',
self.VALID_FILTER_MAP)
self.assertEqual(max_partition, 'value1')
def test_get_max_partition_from_valid_part_specs_return_type(self):
max_partition = \
HiveMetastoreHook._get_max_partition_from_part_specs(
[{'key1': 'value1', 'key2': 'value2'},
{'key1': 'value3', 'key2': 'value4'}],
'key1',
self.VALID_FILTER_MAP)
self.assertIsInstance(max_partition, str)
@patch("airflow.hooks.hive_hooks.HiveMetastoreHook.get_connection",
return_value=[Connection(host="localhost", port="9802")])
@patch("airflow.hooks.hive_hooks.socket")
def test_error_metastore_client(self, socket_mock, _find_valid_server_mock):
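        """get_metastore_client should succeed when the socket probe reports
        the configured metastore as reachable (connect_ex returning 0)."""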
socket_mock.socket.return_value.connect_ex.return_value = 0
self.hook.get_metastore_client()
def test_get_conn(self):
with patch('airflow.hooks.hive_hooks.HiveMetastoreHook._find_valid_server') \
as find_valid_server:
find_valid_server.return_value = MagicMock(return_value={})
metastore_hook = HiveMetastoreHook()
self.assertIsInstance(metastore_hook.get_conn(), HMSClient)
def test_check_for_partition(self):
        # Check for an existing partition.
FakePartition = namedtuple('FakePartition', ['values'])
fake_partition = FakePartition(['2015-01-01'])
metastore = self.hook.metastore.__enter__()
partition = "{p_by}='{date}'".format(date=DEFAULT_DATE_DS,
p_by=self.partition_by)
metastore.get_partitions_by_filter = MagicMock(
return_value=[fake_partition])
self.assertTrue(
self.hook.check_for_partition(self.database, self.table,
partition)
)
        metastore.get_partitions_by_filter.assert_called_with(
            self.database, self.table, partition, 1)
# Check for non-existent partition.
missing_partition = "{p_by}='{date}'".format(date=self.next_day,
p_by=self.partition_by)
metastore.get_partitions_by_filter = MagicMock(return_value=[])
self.assertFalse(
self.hook.check_for_partition(self.database, self.table,
missing_partition)
)
metastore.get_partitions_by_filter.assert_called_with(
self.database, self.table, missing_partition, 1)
def test_check_for_named_partition(self):
# Check for existing partition.
partition = "{p_by}={date}".format(date=DEFAULT_DATE_DS,
p_by=self.partition_by)
self.hook.metastore.__enter__(
).check_for_named_partition = MagicMock(return_value=True)
self.assertTrue(
self.hook.check_for_named_partition(self.database,
self.table,
partition))
self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
self.database, self.table, partition)
# Check for non-existent partition
missing_partition = "{p_by}={date}".format(date=self.next_day,
p_by=self.partition_by)
self.hook.metastore.__enter__().check_for_named_partition = MagicMock(
return_value=False)
self.assertFalse(
self.hook.check_for_named_partition(self.database,
self.table,
missing_partition)
)
self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
self.database, self.table, missing_partition)
def test_get_table(self):
self.hook.metastore.__enter__().get_table = MagicMock()
self.hook.get_table(db=self.database, table_name=self.table)
self.hook.metastore.__enter__().get_table.assert_called_with(
dbname=self.database, tbl_name=self.table)
def test_get_tables(self): # static_babynames_partitioned
self.hook.metastore.__enter__().get_tables = MagicMock(
return_value=['static_babynames_partitioned'])
self.hook.get_tables(db=self.database, pattern=self.table + "*")
self.hook.metastore.__enter__().get_tables.assert_called_with(
db_name='airflow', pattern='static_babynames_partitioned*')
self.hook.metastore.__enter__().get_table_objects_by_name.assert_called_with(
'airflow', ['static_babynames_partitioned'])
def test_get_databases(self):
metastore = self.hook.metastore.__enter__()
metastore.get_databases = MagicMock()
self.hook.get_databases(pattern='*')
metastore.get_databases.assert_called_with('*')
def test_get_partitions(self):
FakeFieldSchema = namedtuple('FakeFieldSchema', ['name'])
fake_schema = FakeFieldSchema('ds')
FakeTable = namedtuple('FakeTable', ['partitionKeys'])
fake_table = FakeTable([fake_schema])
FakePartition = namedtuple('FakePartition', ['values'])
fake_partition = FakePartition(['2015-01-01'])
metastore = self.hook.metastore.__enter__()
metastore.get_table = MagicMock(return_value=fake_table)
metastore.get_partitions = MagicMock(
return_value=[fake_partition])
partitions = self.hook.get_partitions(schema=self.database,
table_name=self.table)
self.assertEqual(len(partitions), 1)
self.assertEqual(partitions, [{self.partition_by: DEFAULT_DATE_DS}])
metastore.get_table.assert_called_with(
dbname=self.database, tbl_name=self.table)
metastore.get_partitions.assert_called_with(
db_name=self.database, tbl_name=self.table, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
def test_max_partition(self):
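        """max_partition should resolve partition names through the metastore
        and return the value of the requested field for the filtered partition."""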
FakeFieldSchema = namedtuple('FakeFieldSchema', ['name'])
fake_schema = FakeFieldSchema('ds')
FakeTable = namedtuple('FakeTable', ['partitionKeys'])
fake_table = FakeTable([fake_schema])
metastore = self.hook.metastore.__enter__()
metastore.get_table = MagicMock(return_value=fake_table)
metastore.get_partition_names = MagicMock(
return_value=['ds=2015-01-01'])
metastore.partition_name_to_spec = MagicMock(
return_value={'ds': '2015-01-01'})
filter_map = {self.partition_by: DEFAULT_DATE_DS}
partition = self.hook.max_partition(schema=self.database,
table_name=self.table,
field=self.partition_by,
filter_map=filter_map)
self.assertEqual(partition, DEFAULT_DATE_DS)
metastore.get_table.assert_called_with(
dbname=self.database, tbl_name=self.table)
metastore.get_partition_names.assert_called_with(
self.database, self.table, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
metastore.partition_name_to_spec.assert_called_with('ds=2015-01-01')
def test_table_exists(self):
        # Test with an existing table.
self.hook.metastore.__enter__().get_table = MagicMock(return_value=True)
self.assertTrue(self.hook.table_exists(self.table, db=self.database))
self.hook.metastore.__enter__().get_table.assert_called_with(
dbname='airflow', tbl_name='static_babynames_partitioned')
# Test with non-existent table.
self.hook.metastore.__enter__().get_table = MagicMock(side_effect=Exception())
self.assertFalse(
self.hook.table_exists("does-not-exist")
)
@mock.patch('airflow.hooks.hive_hooks.random.shuffle')
@mock.patch('airflow.hooks.hive_hooks.HiveMetastoreHook.get_connections')
@mock.patch('airflow.hooks.hive_hooks.HiveMetastoreHook.get_metastore_client')
def test_check_hms_clients_load_balance(self, mock_client, mock_get_conn, mock_shuffle):
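        """_find_valid_server should shuffle the metastore connections so load
        is spread across the configured metastore instances."""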
HiveMetastoreHook()._find_valid_server()
mock_shuffle.assert_called_once_with(mock_get_conn.return_value)


class TestHiveServer2Hook(unittest.TestCase):
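    """Tests for HiveServer2Hook query execution using MockHiveServer2Hook's
    mocked connection and cursor."""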
def _upload_dataframe(self):
df = pd.DataFrame({'a': [1, 2], 'b': [1, 2]})
self.local_path = '/tmp/TestHiveServer2Hook.csv'
df.to_csv(self.local_path, header=False, index=False)
def setUp(self):
self._upload_dataframe()
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG('test_dag_id', default_args=args)
self.database = 'airflow'
self.table = 'hive_server_hook'
self.hql = """
CREATE DATABASE IF NOT EXISTS {{ params.database }};
USE {{ params.database }};
DROP TABLE IF EXISTS {{ params.table }};
CREATE TABLE IF NOT EXISTS {{ params.table }} (
a int,
b int)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ',';
LOAD DATA LOCAL INPATH '{{ params.csv_path }}'
OVERWRITE INTO TABLE {{ params.table }};
"""
self.columns = ['{}.a'.format(self.table),
'{}.b'.format(self.table)]
with patch('airflow.hooks.hive_hooks.HiveMetastoreHook.get_metastore_client') \
as get_metastore_mock:
get_metastore_mock.return_value = MagicMock()
self.hook = HiveMetastoreHook()
def test_get_conn(self):
hook = MockHiveServer2Hook()
hook.get_conn()
@patch('pyhive.hive.connect')
def test_get_conn_with_password(self, mock_connect):
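        """Credentials and the LDAP auth mechanism from the connection URI
        should be passed through to pyhive.hive.connect."""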
conn_id = "conn_with_password"
conn_env = CONN_ENV_PREFIX + conn_id.upper()
with patch.dict(
'os.environ',
{conn_env: "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"}
):
HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn()
mock_connect.assert_called_once_with(
host='localhost',
port=10000,
auth='LDAP',
kerberos_service_name=None,
username='conn_id',
password='conn_pass',
database='default')
def test_get_records(self):
hook = MockHiveServer2Hook()
query = "SELECT * FROM {}".format(self.table)
with patch.dict('os.environ', {
'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835',
'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
'AIRFLOW_CTX_DAG_RUN_ID': '55',
'AIRFLOW_CTX_DAG_OWNER': 'airflow',
'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com',
}):
results = hook.get_records(query, schema=self.database)
self.assertListEqual(results, [(1, 1), (2, 2)])
hook.get_conn.assert_called_with(self.database)
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_id=test_dag_id')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.task_id=HiveHook_3835')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_run_id=55')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_owner=airflow')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_email=test@airflow.com')
def test_get_pandas_df(self):
hook = MockHiveServer2Hook()
query = "SELECT * FROM {}".format(self.table)
with patch.dict('os.environ', {
'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835',
'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
'AIRFLOW_CTX_DAG_RUN_ID': '55',
'AIRFLOW_CTX_DAG_OWNER': 'airflow',
'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com',
}):
df = hook.get_pandas_df(query, schema=self.database)
self.assertEqual(len(df), 2)
self.assertListEqual(df["hive_server_hook.a"].values.tolist(), [1, 2])
hook.get_conn.assert_called_with(self.database)
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_id=test_dag_id')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.task_id=HiveHook_3835')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_run_id=55')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_owner=airflow')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_email=test@airflow.com')
def test_get_results_header(self):
hook = MockHiveServer2Hook()
query = "SELECT * FROM {}".format(self.table)
results = hook.get_results(query, schema=self.database)
self.assertListEqual([col[0] for col in results['header']],
self.columns)
def test_get_results_data(self):
hook = MockHiveServer2Hook()
query = "SELECT * FROM {}".format(self.table)
results = hook.get_results(query, schema=self.database)
self.assertListEqual(results['data'], [(1, 1), (2, 2)])
def test_to_csv(self):
hook = MockHiveServer2Hook()
hook._get_results = MagicMock(return_value=iter([
[
('hive_server_hook.a', 'INT_TYPE', None, None, None, None, True),
('hive_server_hook.b', 'INT_TYPE', None, None, None, None, True)
], (1, 1), (2, 2)
]))
query = "SELECT * FROM {}".format(self.table)
csv_filepath = 'query_results.csv'
hook.to_csv(query, csv_filepath, schema=self.database,
delimiter=',', lineterminator='\n', output_header=True, fetch_size=2)
df = pd.read_csv(csv_filepath, sep=',')
self.assertListEqual(df.columns.tolist(), self.columns)
self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
self.assertEqual(len(df), 2)
def test_multi_statements(self):
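        """A list of statements should be executed one by one on the same
        cursor, alongside the airflow.ctx context settings."""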
sqls = [
"CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)",
"SELECT * FROM {}".format(self.table),
"DROP TABLE test_multi_statements",
]
hook = MockHiveServer2Hook()
with patch.dict('os.environ', {
'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
'AIRFLOW_CTX_TASK_ID': 'HiveHook_3835',
'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
'AIRFLOW_CTX_DAG_RUN_ID': '55',
'AIRFLOW_CTX_DAG_OWNER': 'airflow',
'AIRFLOW_CTX_DAG_EMAIL': 'test@airflow.com',
}):
# df = hook.get_pandas_df(query, schema=self.database)
results = hook.get_records(sqls, schema=self.database)
self.assertListEqual(results, [(1, 1), (2, 2)])
# self.assertEqual(len(df), 2)
# self.assertListEqual(df["hive_server_hook.a"].values.tolist(), [1, 2])
hook.get_conn.assert_called_with(self.database)
hook.mock_cursor.execute.assert_any_call(
'CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)')
hook.mock_cursor.execute.assert_any_call(
'SELECT * FROM {}'.format(self.table))
hook.mock_cursor.execute.assert_any_call(
'DROP TABLE test_multi_statements')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_id=test_dag_id')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.task_id=HiveHook_3835')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.execution_date=2015-01-01T00:00:00+00:00')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_run_id=55')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_owner=airflow')
hook.mock_cursor.execute.assert_any_call(
'set airflow.ctx.dag_email=test@airflow.com')
def test_get_results_with_hive_conf(self):
hql = ["set key",
"set airflow.ctx.dag_id",
"set airflow.ctx.dag_run_id",
"set airflow.ctx.task_id",
"set airflow.ctx.execution_date"]
dag_id_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
task_id_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
execution_date_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
'env_var_format']
dag_run_id_ctx_var_name = \
AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
'env_var_format']
with patch.dict('os.environ', {
dag_id_ctx_var_name: 'test_dag_id',
task_id_ctx_var_name: 'test_task_id',
execution_date_ctx_var_name: 'test_execution_date',
dag_run_id_ctx_var_name: 'test_dag_run_id',
}):
hook = MockHiveServer2Hook()
hook._get_results = MagicMock(return_value=iter(
["header", ("value", "test"), ("test_dag_id", "test"), ("test_task_id", "test"),
("test_execution_date", "test"), ("test_dag_run_id", "test")]
))
output = '\n'.join(res_tuple[0] for res_tuple in hook.get_results(
hql=hql, hive_conf={'key': 'value'})['data'])
self.assertIn('value', output)
self.assertIn('test_dag_id', output)
self.assertIn('test_task_id', output)
self.assertIn('test_execution_date', output)
self.assertIn('test_dag_run_id', output)


class TestHiveCli(unittest.TestCase):
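    """Tests for HiveCliHook beeline command construction with kerberos
    security enabled."""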
def setUp(self):
self.nondefault_schema = "nondefault"
os.environ["AIRFLOW__CORE__SECURITY"] = "kerberos"
def tearDown(self):
del os.environ["AIRFLOW__CORE__SECURITY"]
def test_get_proxy_user_value(self):
hook = MockHiveCliHook()
returner = MagicMock()
returner.extra_dejson = {'proxy_user': 'a_user_proxy'}
hook.use_beeline = True
hook.conn = returner
# Run
result = hook._prepare_cli_cmd()
# Verify
self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])