| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| import json |
| import os |
| import sys |
| import unittest |
| from argparse import Namespace |
| from contextlib import contextmanager |
| from datetime import datetime |
| from unittest import mock |
| |
| import pytest |
| from parameterized import parameterized |
| |
| from airflow import settings |
| from airflow.exceptions import AirflowException |
| from airflow.utils import cli, cli_action_loggers |
| |
| |
class TestCliUtil(unittest.TestCase):
    """Tests for the helper functions exposed by ``airflow.utils.cli``."""

    def test_metrics_build(self):
        # The values placed on the namespace must show up both in the metrics
        # dict and on the log object stored under the 'log' key.
        execution_date = datetime.utcnow()
        cli_args = Namespace(dag_id='foo', task_id='bar', subcommand='test', execution_date=execution_date)
        metrics = cli._build_metrics('test', cli_args)

        expected_fields = {
            'user': os.environ.get('USER'),
            'sub_command': 'test',
            'dag_id': 'foo',
            'task_id': 'bar',
            'execution_date': execution_date,
        }
        observed_fields = {name: metrics.get(name) for name in expected_fields}
        assert observed_fields == expected_fields

        assert metrics.get('start_datetime') <= datetime.utcnow()
        assert metrics.get('full_command')

        log_row = metrics.get('log')
        assert log_row
        # (attribute on the log object, key in the metrics dict)
        mirrored = (
            ('dag_id', 'dag_id'),
            ('task_id', 'task_id'),
            ('execution_date', 'execution_date'),
            ('owner', 'user'),
        )
        for attribute, metric_key in mirrored:
            assert getattr(log_row, attribute) == metrics.get(metric_key)

    def test_fail_function(self):
        """
        The wrapped function itself raises, and that error must reach the caller.
        :return:
        """
        empty_args = Namespace()
        with pytest.raises(NotImplementedError):
            fail_func(empty_args)

    def test_success_function(self):
        """
        The wrapped function succeeds while a registered callback fails.
        The callback failure must not reach the caller.
        :return:
        """
        empty_args = Namespace()
        with fail_action_logger_callback():
            success_func(empty_args)

    def test_process_subdir_path_with_placeholder(self):
        # The 'DAGS_FOLDER' prefix is expected to be swapped for settings.DAGS_FOLDER.
        resolved = cli.process_subdir('DAGS_FOLDER/abc')
        assert resolved == os.path.join(settings.DAGS_FOLDER, 'abc')

    def test_get_dags(self):
        # Lookup by a full dag id yields exactly one DAG.
        single_match = cli.get_dags(None, "example_subdag_operator")
        assert len(single_match) == 1

        # With the third argument set to True (presumably pattern matching -
        # confirm against cli.get_dags), "subdag" yields several DAGs.
        several_matches = cli.get_dags(None, "subdag", True)
        assert len(several_matches) > 1

        # Nothing matches "foobar", which is reported as an AirflowException.
        with pytest.raises(AirflowException):
            cli.get_dags(None, "foobar", True)

    @parameterized.expand(
        [
            (
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password test",
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password ********",
            ),
            (
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p test",
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p ********",
            ),
            (
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=test",
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin --password=********",
            ),
            (
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=test",
                "airflow users create -u test2 -l doe -f jon -e jdoe@apache.org -r admin -p=********",
            ),
            (
                "airflow connections add dsfs --conn-login asd --conn-password test --conn-type google",
                "airflow connections add dsfs --conn-login asd --conn-password ******** --conn-type google",
            ),
            (
                "airflow scheduler -p",
                "airflow scheduler -p",
            ),
            (
                "airflow celery flower -p 8888",
                "airflow celery flower -p 8888",
            ),
        ]
    )
    def test_cli_create_user_supplied_password_is_masked(self, given_command, expected_masked_command):
        # A '-p' that does not carry a password ('airflow scheduler -p',
        # 'airflow celery flower -p 8888') has to come through unmasked.
        argv = given_command.split()
        expected_argv = expected_masked_command.split()

        cli_args = Namespace(
            dag_id='foo',
            task_id='bar',
            subcommand='test',
            execution_date=datetime.utcnow(),
        )
        with mock.patch.object(sys, "argv", argv):
            metrics = cli._build_metrics(argv[1], cli_args)

        assert metrics.get('start_datetime') <= datetime.utcnow()

        log_extra = json.loads(metrics.get('log').extra)
        logged_command = log_extra.get('full_command')  # type: str
        # The logged command is single-quoted, which JSON rejects; switch to
        # double quotes before decoding it.
        logged_argv = json.loads(logged_command.replace("'", '"'))
        assert logged_argv == expected_argv

    def test_setup_locations_relative_pid_path(self):
        # A relative pid path comes back joined onto the current working directory.
        pid_name = "fake.pid"
        pid, _, _, _ = cli.setup_locations(process="fake_process", pid=pid_name)
        assert pid == os.path.join(os.getcwd(), pid_name)

    def test_setup_locations_absolute_pid_path(self):
        # An absolute pid path comes back unchanged.
        absolute_pid = os.path.join(os.getcwd(), "fake.pid")
        pid, _, _, _ = cli.setup_locations(process="fake_process", pid=absolute_pid)
        assert pid == absolute_pid

    def test_setup_locations_none_pid_path(self):
        # Without a pid argument the path is AIRFLOW_HOME/airflow-<process>.pid.
        process = "fake_process"
        pid, _, _, _ = cli.setup_locations(process=process)
        assert pid == os.path.join(settings.AIRFLOW_HOME, f"airflow-{process}.pid")
| |
| |
@contextmanager
def fail_action_logger_callback():
    """
    Register an always-failing pre-exec callback for the duration of the block.

    The list of pre-exec callbacks is copied on entry and put back on exit.
    The restore runs in a ``finally`` clause so the failing callback is
    removed even when the managed block raises; otherwise it would stay
    registered and affect whatever runs next.

    :return: a context manager that yields nothing
    """
    # Module-level function, so the double-underscore name is not mangled and
    # this reads the attribute of cli_action_loggers directly.
    saved_callbacks = cli_action_loggers.__pre_exec_callbacks[:]

    def fail_callback(**_):
        raise NotImplementedError

    cli_action_loggers.register_pre_exec_callback(fail_callback)
    try:
        yield
    finally:
        cli_action_loggers.__pre_exec_callbacks = saved_callbacks
| |
| |
@cli.action_logging
def fail_func(_):
    """Stub wrapped by ``cli.action_logging`` whose body always raises ``NotImplementedError``."""
    raise NotImplementedError()
| |
| |
@cli.action_logging
def success_func(_):
    """Stub wrapped by ``cli.action_logging`` whose body does nothing."""
    return None