# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
| from __future__ import annotations |
| |
| import logging |
| from unittest.mock import MagicMock, patch |
| |
| import pytest |
| |
| from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning |
| from airflow.models import Connection |
| from airflow.models.dag import DAG |
| from airflow.providers.apache.livy.hooks.livy import BatchState |
| from airflow.providers.apache.livy.operators.livy import LivyOperator |
| from airflow.utils import db, timezone |
| |
# Apply the ``db_test`` marker to every test in this module; the class-level
# ``setup_method`` below writes a connection through ``db.merge_conn``.
pytestmark = pytest.mark.db_test


# Start date placed in the ``default_args`` of the DAG built for each test.
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
# Batch identifier the tests use as the mocked ``post_batch`` return value.
BATCH_ID = 100

# Application id carried by the mocked ``get_batch`` payload below.
APP_ID = "application_1433865536131_34483"

# Payload used as the mocked ``get_batch`` return value.
GET_BATCH = {
    "appId": APP_ID,
}

# Payload used as the mocked ``get_batch_logs`` return value.
LOG_RESPONSE = {
    "total": 3,
    "log": [
        "first_line",
        "second_line",
        "third_line",
    ],
}
| |
| |
class TestLivyOperator:
    """Tests for ``LivyOperator`` in which every ``LivyHook`` call is replaced by a mock."""

    def setup_method(self):
        """Build the DAG, register the ``livyunittest`` connection and create a mock context."""
        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
        self.dag = DAG("test_dag_id", default_args=args)
        db.merge_conn(
            Connection(
                # ``port`` is an integer field on ``Connection``; pass an int rather than a string.
                conn_id="livyunittest", conn_type="livy", host="localhost:8998", port=8998, schema="http"
            )
        )
        self.mock_context = {"ti": MagicMock()}

    @staticmethod
    def _running_then(final_state):
        """Return a ``get_batch_state`` side effect: RUNNING twice, then *final_state*.

        Any further call raises ``AssertionError``, so a test fails when the
        operator keeps polling after *final_state* has been returned.

        :param final_state: the ``BatchState`` returned on the third call.
        :return: a callable suitable for ``Mock.side_effect``.
        """
        states = iter([BatchState.RUNNING, BatchState.RUNNING, final_state])

        # The signature mirrors how the tests expect ``get_batch_state`` to be called:
        # the batch id plus a ``retry_args`` argument.
        def side_effect(_, retry_args):
            try:
                return next(states)
            except StopIteration:
                # fail if does not stop right before
                raise AssertionError("get_batch_state was polled after the final state") from None

        return side_effect

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
    def test_poll_for_termination(self, mock_livy, mock_dump_logs):
        """Polling ends on SUCCESS after exactly three state checks and dumps the batch logs."""
        mock_livy.side_effect = self._running_then(BatchState.SUCCESS)

        task = LivyOperator(file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example")
        task.poll_for_termination(BATCH_ID)

        mock_livy.assert_called_with(BATCH_ID, retry_args=None)
        mock_dump_logs.assert_called_with(BATCH_ID)
        assert mock_livy.call_count == 3

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
    def test_poll_for_termination_fail(self, mock_livy, mock_dump_logs):
        """Polling raises ``AirflowException`` on ERROR after three checks, still dumping logs."""
        mock_livy.side_effect = self._running_then(BatchState.ERROR)

        task = LivyOperator(file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example")

        with pytest.raises(AirflowException):
            task.poll_for_termination(BATCH_ID)

        mock_livy.assert_called_with(BATCH_ID, retry_args=None)
        mock_dump_logs.assert_called_with(BATCH_ID)
        assert mock_livy.call_count == 3

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
        return_value=BatchState.SUCCESS,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    def test_execution(self, mock_get_batch, mock_post, mock_get, mock_dump_logs):
        """``execute`` submits the batch, checks its state once, dumps logs and pushes ``app_id``."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            polling_interval=1,
            dag=self.dag,
            task_id="livy_example",
        )
        task.execute(context=self.mock_context)

        # Only truthy keyword arguments matter here: nothing but ``file`` was set on the operator.
        call_args = {k: v for k, v in mock_post.call_args.kwargs.items() if v}
        assert call_args == {"file": "sparkapp"}
        mock_get.assert_called_once_with(BATCH_ID, retry_args=None)
        mock_dump_logs.assert_called_once_with(BATCH_ID)
        mock_get_batch.assert_called_once_with(BATCH_ID)
        self.mock_context["ti"].xcom_push.assert_called_once_with(key="app_id", value=APP_ID)

    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch")
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    def test_execution_with_extra_options(self, mock_get_batch, mock_post):
        """``extra_options`` given to the operator are exposed unchanged on its hook."""
        extra_options = {"check_response": True}
        task = LivyOperator(
            file="sparkapp", dag=self.dag, task_id="livy_example", extra_options=extra_options
        )

        task.execute(context=self.mock_context)

        assert task.hook.extra_options == extra_options

    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    def test_deletion(self, mock_get_batch, mock_post, mock_delete):
        """``kill`` after ``execute`` deletes the submitted batch exactly once."""
        task = LivyOperator(
            livy_conn_id="livyunittest", file="sparkapp", dag=self.dag, task_id="livy_example"
        )
        task.execute(context=self.mock_context)
        task.kill()

        mock_delete.assert_called_once_with(BATCH_ID)

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
        return_value=BatchState.SUCCESS,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_logs", return_value=LOG_RESPONSE)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    def test_log_dump(self, mock_get_batch, mock_post, mock_get_logs, mock_get, caplog):
        """Every line returned by ``get_batch_logs`` is emitted through the hook's logger."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
            polling_interval=1,
        )
        caplog.clear()
        with caplog.at_level(level=logging.INFO, logger=task.hook.log.name):
            task.execute(context=self.mock_context)

            assert "first_line" in caplog.messages
            assert "second_line" in caplog.messages
            assert "third_line" in caplog.messages

        mock_get.assert_called_once_with(BATCH_ID, retry_args=None)
        mock_get_logs.assert_called_once_with(BATCH_ID, 0, 100)

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
    def test_poll_for_termination_deferrable(self, mock_livy, mock_dump_logs):
        """With ``deferrable=True``, direct polling still ends on SUCCESS after three checks."""
        mock_livy.side_effect = self._running_then(BatchState.SUCCESS)

        task = LivyOperator(
            file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example", deferrable=True
        )
        task.poll_for_termination(BATCH_ID)

        mock_livy.assert_called_with(BATCH_ID, retry_args=None)
        mock_dump_logs.assert_called_with(BATCH_ID)
        assert mock_livy.call_count == 3

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
    def test_poll_for_termination_fail_deferrable(self, mock_livy, mock_dump_logs):
        """With ``deferrable=True``, direct polling still raises ``AirflowException`` on ERROR."""
        mock_livy.side_effect = self._running_then(BatchState.ERROR)

        task = LivyOperator(
            file="sparkapp", polling_interval=1, dag=self.dag, task_id="livy_example", deferrable=True
        )

        with pytest.raises(AirflowException):
            task.poll_for_termination(BATCH_ID)

        mock_livy.assert_called_with(BATCH_ID, retry_args=None)
        mock_dump_logs.assert_called_with(BATCH_ID)
        assert mock_livy.call_count == 3

    @patch("airflow.providers.apache.livy.operators.livy.LivyOperator.defer")
    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
        return_value=BatchState.SUCCESS,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    def test_execution_deferrable(self, mock_get_batch, mock_post, mock_get, mock_dump_logs, mock_defer):
        """A deferrable operator does not defer when the batch state is already SUCCESS."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            polling_interval=1,
            dag=self.dag,
            task_id="livy_example",
            deferrable=True,
        )
        task.execute(context=self.mock_context)

        mock_defer.assert_not_called()
        # Only truthy keyword arguments matter here: nothing but ``file`` was set on the operator.
        call_args = {k: v for k, v in mock_post.call_args.kwargs.items() if v}
        assert call_args == {"file": "sparkapp"}
        mock_get.assert_called_once_with(BATCH_ID, retry_args=None)
        mock_dump_logs.assert_called_once_with(BATCH_ID)
        mock_get_batch.assert_called_once_with(BATCH_ID)
        self.mock_context["ti"].xcom_push.assert_called_once_with(key="app_id", value=APP_ID)

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
        return_value=BatchState.SUCCESS,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    def test_execution_with_extra_options_deferrable(
        self, mock_get_batch, mock_post, mock_get_batch_state, mock_dump_logs
    ):
        """``extra_options`` are exposed on the hook when the operator is deferrable as well."""
        extra_options = {"check_response": True}
        task = LivyOperator(
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
            extra_options=extra_options,
            deferrable=True,
        )

        task.execute(context=self.mock_context)
        assert task.hook.extra_options == extra_options

    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
    def test_when_kill_is_called_right_after_construction_it_should_not_raise_attribute_error(
        self, mock_delete_batch
    ):
        """``kill`` on an operator that was never executed neither raises nor deletes a batch."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
        )
        task.kill()
        mock_delete_batch.assert_not_called()

    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
        return_value=BatchState.SUCCESS,
    )
    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
        return_value=None,
    )
    def test_deletion_deferrable(
        self, mock_dump_logs, mock_get_batch_state, mock_get_batch, mock_post, mock_delete
    ):
        """``kill`` after ``execute`` on a deferrable operator deletes the batch exactly once."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
            deferrable=True,
        )
        task.execute(context=self.mock_context)
        task.kill()

        mock_delete.assert_called_once_with(BATCH_ID)

    @patch(
        "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
        return_value=BatchState.SUCCESS,
    )
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_logs", return_value=LOG_RESPONSE)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    def test_log_dump_deferrable(self, mock_get_batch, mock_post, mock_get_logs, mock_get, caplog):
        """A deferrable operator also emits every ``get_batch_logs`` line through the hook's logger."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
            polling_interval=1,
            deferrable=True,
        )
        caplog.clear()

        with caplog.at_level(level=logging.INFO, logger=task.hook.log.name):
            task.execute(context=self.mock_context)

            assert "first_line" in caplog.messages
            assert "second_line" in caplog.messages
            assert "third_line" in caplog.messages

        mock_get.assert_called_once_with(BATCH_ID, retry_args=None)
        mock_get_logs.assert_called_once_with(BATCH_ID, 0, 100)

    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    def test_execute_complete_success(self, mock_post, mock_get):
        """A ``success`` event makes ``execute_complete`` return the batch id and push ``app_id``."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
            polling_interval=1,
            deferrable=True,
        )
        result = task.execute_complete(
            context=self.mock_context,
            event={
                "status": "success",
                "log_lines": None,
                "batch_id": BATCH_ID,
                "response": "mock success",
            },
        )

        assert result == BATCH_ID
        self.mock_context["ti"].xcom_push.assert_called_once_with(key="app_id", value=APP_ID)

    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    def test_execute_complete_error(self, mock_post):
        """An ``error`` event makes ``execute_complete`` raise without pushing to XCom."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
            polling_interval=1,
            deferrable=True,
        )
        with pytest.raises(AirflowException):
            task.execute_complete(
                context=self.mock_context,
                event={
                    "status": "error",
                    "log_lines": ["mock log"],
                    "batch_id": BATCH_ID,
                    "response": "mock error",
                },
            )
        self.mock_context["ti"].xcom_push.assert_not_called()

    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
    @patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
    def test_execute_complete_timeout(self, mock_delete, mock_post):
        """A ``timeout`` event makes ``execute_complete`` raise, delete the batch and skip XCom."""
        task = LivyOperator(
            livy_conn_id="livyunittest",
            file="sparkapp",
            dag=self.dag,
            task_id="livy_example",
            polling_interval=1,
            deferrable=True,
        )
        with pytest.raises(AirflowException):
            task.execute_complete(
                context=self.mock_context,
                event={
                    "status": "timeout",
                    "log_lines": ["mock log"],
                    "batch_id": BATCH_ID,
                    "response": "mock timeout",
                },
            )
        mock_delete.assert_called_once_with(BATCH_ID)
        self.mock_context["ti"].xcom_push.assert_not_called()

    def test_deprecated_get_hook(self):
        """``get_hook`` emits a deprecation warning and returns the same object as ``hook``."""
        op = LivyOperator(task_id="livy_example", file="sparkapp")
        with pytest.warns(AirflowProviderDeprecationWarning, match="use `hook` property instead"):
            hook = op.get_hook()
        assert hook is op.hook
| |
| |
@pytest.mark.db_test
def test_spark_params_templating(create_task_instance_of_operator):
    """Check that templated Spark fields are rendered into ``spark_params``.

    Each field receives a Jinja expression that evaluates to a string literal;
    after ``render_templates()`` the operator's ``spark_params`` must contain
    the rendered literal for every field.
    """
    # Templated fields
    templated_fields = {
        "file": "{{ 'literal-file' }}",
        "class_name": "{{ 'literal-class-name' }}",
        "args": "{{ 'literal-args' }}",
        "jars": "{{ 'literal-jars' }}",
        "py_files": "{{ 'literal-py-files' }}",
        "files": "{{ 'literal-files' }}",
        "driver_memory": "{{ 'literal-driver-memory' }}",
        "driver_cores": "{{ 'literal-driver-cores' }}",
        "executor_memory": "{{ 'literal-executor-memory' }}",
        "executor_cores": "{{ 'literal-executor-cores' }}",
        "num_executors": "{{ 'literal-num-executors' }}",
        "archives": "{{ 'literal-archives' }}",
        "queue": "{{ 'literal-queue' }}",
        "name": "{{ 'literal-name' }}",
        "conf": "{{ 'literal-conf' }}",
        "proxy_user": "{{ 'literal-proxy-user' }}",
    }
    # What ``spark_params`` must contain once the templates have been rendered.
    expected_spark_params = {
        "archives": "literal-archives",
        "args": "literal-args",
        "class_name": "literal-class-name",
        "conf": "literal-conf",
        "driver_cores": "literal-driver-cores",
        "driver_memory": "literal-driver-memory",
        "executor_cores": "literal-executor-cores",
        "executor_memory": "literal-executor-memory",
        "file": "literal-file",
        "files": "literal-files",
        "jars": "literal-jars",
        "name": "literal-name",
        "num_executors": "literal-num-executors",
        "proxy_user": "literal-proxy-user",
        "py_files": "literal-py-files",
        "queue": "literal-queue",
    }

    task_instance = create_task_instance_of_operator(
        LivyOperator,
        # Other parameters
        dag_id="test_template_body_templating_dag",
        task_id="test_template_body_templating_task",
        execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
        **templated_fields,
    )
    task_instance.render_templates()

    rendered_task: LivyOperator = task_instance.task
    assert rendered_task.spark_params == expected_spark_params