blob: 16d196a39b8e1a412c31ab0497a3b159887652d1 [file] [log] [blame]
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from hdfs import HdfsError
from mock import patch, call
from airflow.hooks.webhdfs_hook import WebHDFSHook, AirflowWebHDFSHookException
from airflow.models.connection import Connection
class TestWebHDFSHook(unittest.TestCase):
    """Unit tests for ``WebHDFSHook``; HDFS clients and Airflow connections are mocked."""

    def setUp(self):
        # A fresh hook per test so no state carries over between test methods.
        self.webhdfs_hook = WebHDFSHook()

    @patch('airflow.hooks.webhdfs_hook.InsecureClient')
    @patch('airflow.hooks.webhdfs_hook.WebHDFSHook.get_connections', return_value=[
        Connection(host='host_1', port=123),
        Connection(host='host_2', port=321, login='user')
    ])
    def test_get_conn(self, mock_get_connections, mock_insecure_client):
        """get_conn moves on to the next namenode when building a client raises HdfsError."""
        # The first namenode's client construction fails; the second yields the mock client.
        mock_insecure_client.side_effect = [HdfsError('Error'), mock_insecure_client.return_value]

        conn = self.webhdfs_hook.get_conn()

        # Materialise the expected calls as a list rather than passing a generator:
        # assert_has_calls re-reads its ``calls`` argument to build the failure
        # message, and an exhausted generator would report an empty expected list.
        expected_calls = [
            call('http://{host}:{port}'.format(host=connection.host, port=connection.port),
                 user=connection.login)
            for connection in mock_get_connections.return_value
        ]
        mock_insecure_client.assert_has_calls(expected_calls)
        # Only the second client was actually constructed, so status('/') runs exactly once.
        mock_insecure_client.return_value.status.assert_called_once_with('/')
        self.assertEqual(conn, mock_insecure_client.return_value)

    @patch('airflow.hooks.webhdfs_hook.KerberosClient', create=True)
    @patch('airflow.hooks.webhdfs_hook.WebHDFSHook.get_connections', return_value=[
        Connection(host='host_1', port=123)
    ])
    @patch('airflow.hooks.webhdfs_hook._kerberos_security_mode', return_value=True)
    def test_get_conn_kerberos_security_mode(self,
                                             mock_kerberos_security_mode,
                                             mock_get_connections,
                                             mock_kerberos_client):
        """With Kerberos security mode patched on, get_conn builds a KerberosClient."""
        # create=True lets patch() install KerberosClient even if the hook module
        # does not define that name (presumably when the optional Kerberos client
        # import is unavailable -- confirm against the hook module).
        # NOTE(review): patch() replaces _kerberos_security_mode with a MagicMock,
        # which is truthy on its own; return_value=True only matters if the hook
        # calls it rather than reading it as a flag -- confirm against the hook.
        conn = self.webhdfs_hook.get_conn()

        connection = mock_get_connections.return_value[0]
        mock_kerberos_client.assert_called_once_with(
            'http://{host}:{port}'.format(host=connection.host, port=connection.port))
        self.assertEqual(conn, mock_kerberos_client.return_value)

    @patch('airflow.hooks.webhdfs_hook.WebHDFSHook.get_connections', return_value=[])
    def test_get_conn_no_connection_found(self, mock_get_connections):
        """get_conn raises AirflowWebHDFSHookException when get_connections returns nothing."""
        with self.assertRaises(AirflowWebHDFSHookException):
            self.webhdfs_hook.get_conn()

    @patch('airflow.hooks.webhdfs_hook.WebHDFSHook.get_conn')
    def test_check_for_path(self, mock_get_conn):
        """check_for_path returns the truthiness of a non-strict status() lookup."""
        hdfs_path = 'path'

        exists_path = self.webhdfs_hook.check_for_path(hdfs_path)

        mock_get_conn.assert_called_once_with()
        mock_status = mock_get_conn.return_value.status
        # strict=False: in the hdfs client, a missing path yields None instead of
        # raising, which is what lets the hook turn the result into a boolean.
        mock_status.assert_called_once_with(hdfs_path, strict=False)
        self.assertEqual(exists_path, bool(mock_status.return_value))

    @patch('airflow.hooks.webhdfs_hook.WebHDFSHook.get_conn')
    def test_load_file(self, mock_get_conn):
        """load_file uploads with overwrite=True and n_threads=1 when called with defaults."""
        source = 'source'
        destination = 'destination'

        self.webhdfs_hook.load_file(source, destination)

        mock_get_conn.assert_called_once_with()
        mock_upload = mock_get_conn.return_value.upload
        mock_upload.assert_called_once_with(
            hdfs_path=destination,
            local_path=source,
            overwrite=True,
            n_threads=1
        )