blob: 44aad34e53ddd9b93cbee0ed85cf32b4dfa62e77 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import unittest
import mock
from airflow.exceptions import AirflowException
try:
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.providers.docker.hooks.docker import DockerHook
from docker import APIClient
except ImportError:
pass
class TestDockerOperator(unittest.TestCase):
@mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute(self, client_class_mock, tempdir_mock):
host_config = mock.Mock()
tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = host_config
client_mock.images.return_value = []
client_mock.attach.return_value = ['container log']
client_mock.logs.return_value = ['container log']
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.wait.return_value = {"StatusCode": 0}
client_class_mock.return_value = client_mock
operator = DockerOperator(api_version='1.19', command='env', environment={'UNIT': 'TEST'},
private_environment={'PRIVATE': 'MESSAGE'}, image='ubuntu:latest',
network_mode='bridge', owner='unittest', task_id='unittest',
volumes=['/host/path:/container/path'],
working_dir='/container/path', shm_size=1000,
host_tmp_dir='/host/airflow', container_name='test_container',
tty=True)
operator.execute(None)
client_class_mock.assert_called_once_with(base_url='unix://var/run/docker.sock', tls=None,
version='1.19')
client_mock.create_container.assert_called_once_with(command='env',
name='test_container',
environment={
'AIRFLOW_TMP_DIR': '/tmp/airflow',
'UNIT': 'TEST',
'PRIVATE': 'MESSAGE'
},
host_config=host_config,
image='ubuntu:latest',
user=None,
working_dir='/container/path',
tty=True
)
client_mock.create_host_config.assert_called_once_with(binds=['/host/path:/container/path',
'/mkdtemp:/tmp/airflow'],
network_mode='bridge',
shm_size=1000,
cpu_shares=1024,
mem_limit=None,
auto_remove=False,
dns=None,
dns_search=None)
tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp')
client_mock.images.assert_called_once_with(name='ubuntu:latest')
client_mock.attach.assert_called_once_with(container='some_id', stdout=True,
stderr=True, stream=True)
client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True,
decode=True)
client_mock.wait.assert_called_once_with('some_id')
def test_private_environment_is_private(self):
operator = DockerOperator(private_environment={'PRIVATE': 'MESSAGE'},
image='ubuntu:latest',
task_id='unittest')
self.assertEqual(
operator._private_environment, {'PRIVATE': 'MESSAGE'},
"To keep this private, it must be an underscored attribute."
)
@mock.patch('airflow.providers.docker.operators.docker.tls.TLSConfig')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_tls(self, client_class_mock, tls_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
client_class_mock.return_value = client_mock
tls_mock = mock.Mock()
tls_class_mock.return_value = tls_mock
operator = DockerOperator(docker_url='tcp://127.0.0.1:2376', image='ubuntu',
owner='unittest', task_id='unittest', tls_client_cert='cert.pem',
tls_ca_cert='ca.pem', tls_client_key='key.pem')
operator.execute(None)
tls_class_mock.assert_called_once_with(assert_hostname=None, ca_cert='ca.pem',
client_cert=('cert.pem', 'key.pem'),
ssl_version=None, verify=True)
client_class_mock.assert_called_once_with(base_url='https://127.0.0.1:2376',
tls=tls_mock, version=None)
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_unicode_logs(self, client_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = ['unicode container log 😁']
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
client_class_mock.return_value = client_mock
originalRaiseExceptions = logging.raiseExceptions # pylint: disable=invalid-name
logging.raiseExceptions = True
operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest')
with mock.patch('traceback.print_exception') as print_exception_mock:
operator.execute(None)
logging.raiseExceptions = originalRaiseExceptions
print_exception_mock.assert_not_called()
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_container_fails(self, client_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 1}
client_class_mock.return_value = client_mock
operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest')
with self.assertRaises(AirflowException):
operator.execute(None)
@staticmethod
def test_on_kill():
client_mock = mock.Mock(spec=APIClient)
operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest')
operator.cli = client_mock
operator.container = {'Id': 'some_id'}
operator.on_kill()
client_mock.stop.assert_called_once_with('some_id')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
# Mock out a Docker client, so operations don't raise errors
client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
operator_client_mock.return_value = client_mock
# Create the DockerOperator
operator = DockerOperator(
image='publicregistry/someimage',
owner='unittest',
task_id='unittest'
)
# Mock out the DockerHook
hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
hook_mock.get_conn.return_value = client_mock
operator.get_hook = mock.Mock(
name='DockerOperator.get_hook mock',
spec=DockerOperator.get_hook,
return_value=hook_mock
)
operator.execute(None)
self.assertEqual(
operator.get_hook.call_count, 0,
'Hook called though no docker_conn_id configured'
)
@mock.patch('airflow.providers.docker.operators.docker.DockerHook')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock,
operator_docker_hook):
# Mock out a Docker client, so operations don't raise errors
client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
operator_client_mock.return_value = client_mock
# Create the DockerOperator
operator = DockerOperator(
image='publicregistry/someimage',
owner='unittest',
task_id='unittest',
docker_conn_id='some_conn_id'
)
# Mock out the DockerHook
hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
hook_mock.get_conn.return_value = client_mock
operator_docker_hook.return_value = hook_mock
operator.execute(None)
self.assertEqual(
operator_client_mock.call_count, 0,
'Client was called on the operator instead of the hook'
)
self.assertEqual(
operator_docker_hook.call_count, 1,
'Hook was not called although docker_conn_id configured'
)
self.assertEqual(
client_mock.pull.call_count, 1,
'Image was not pulled using operator client'
)
@mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_xcom_behavior(self, client_class_mock, tempdir_mock):
tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'
client_mock = mock.Mock(spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = ['container log']
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.wait.return_value = {"StatusCode": 0}
client_class_mock.return_value = client_mock
kwargs = {
'api_version': '1.19',
'command': 'env',
'environment': {'UNIT': 'TEST'},
'private_environment': {'PRIVATE': 'MESSAGE'},
'image': 'ubuntu:latest',
'network_mode': 'bridge',
'owner': 'unittest',
'task_id': 'unittest',
'volumes': ['/host/path:/container/path'],
'working_dir': '/container/path',
'shm_size': 1000,
'host_tmp_dir': '/host/airflow',
'container_name': 'test_container',
'tty': True,
}
xcom_push_operator = DockerOperator(**kwargs, do_xcom_push=True)
no_xcom_push_operator = DockerOperator(**kwargs, do_xcom_push=False)
xcom_push_result = xcom_push_operator.execute(None)
no_xcom_push_result = no_xcom_push_operator.execute(None)
self.assertEqual(xcom_push_result, b'container log')
self.assertIs(no_xcom_push_result, None)
if __name__ == "__main__":
unittest.main()