| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Test Task Sql.""" |
| from pathlib import Path |
| from unittest.mock import patch |
| |
| import pytest |
| |
| from pydolphinscheduler.models.datasource import TaskUsage |
| from pydolphinscheduler.resources_plugin import Local |
| from pydolphinscheduler.tasks.sql import Sql, SqlType |
| from pydolphinscheduler.utils import file |
| from tests.testing.file import delete_file |
| |
# Local SQL resource used by the resource-plugin test: a file named
# ``file_name`` holding ``file_content``, placed next to this test module.
file_name = "local_res.sql"
file_content = "select 1"
# Directory the ``Local`` resource plugin is rooted at (this test's folder).
res_plugin_prefix = Path(__file__).parent
file_path = res_plugin_prefix / file_name
| |
| |
@pytest.fixture
def setup_crt_first():
    """Provide the local SQL resource file for the duration of one test.

    Before the test body runs, ``file_content`` is written to ``file_path``.
    Once the test finishes, the file is removed again.
    """
    file.write(content=file_content, to_path=file_path)
    try:
        yield
    finally:
        # Teardown: remove the file created above.
        delete_file(file_path)
| |
| |
@pytest.mark.parametrize(
    "stm, expected",
    [
        pytest.param("select * from t_ds_version", ["select * from t_ds_version"]),
        pytest.param(None, []),
        pytest.param(
            ["select * from table1", "select * from table2"],
            ["select * from table1", "select * from table2"],
        ),
        pytest.param(
            ("select * from table1", "select * from table2"),
            ["select * from table1", "select * from table2"],
        ),
        pytest.param(
            {"select * from table1", "select * from table2"},
            ["select * from table1", "select * from table2"],
        ),
    ],
)
def test_get_stm_list(stm, expected) -> None:
    """Check ``Sql.get_stm_list`` for str, None, list, tuple and set input.

    A set has no defined iteration order, so both the actual and the
    expected statements are sorted before they are compared.
    """
    actual = Sql.get_stm_list(stm)
    assert sorted(expected) == sorted(actual)
| |
| |
@pytest.mark.parametrize(
    "sql, param_sql_type, sql_type",
    [
        ("select 1", None, SqlType.SELECT),
        (" select 1", None, SqlType.SELECT),
        (" select 1 ", None, SqlType.SELECT),
        (" select 'insert' ", None, SqlType.SELECT),
        (" select 'insert ' ", None, SqlType.SELECT),
        ("with tmp as (select 1) select * from tmp ", None, SqlType.SELECT),
        (
            "insert into table_name(col1, col2) value (val1, val2)",
            None,
            SqlType.NOT_SELECT,
        ),
        (
            "insert into table_name(select, col2) value ('select', val2)",
            None,
            SqlType.NOT_SELECT,
        ),
        ("update table_name SET col1=val1 where col1=val2", None, SqlType.NOT_SELECT),
        (
            "update table_name SET col1='select' where col1=val2",
            None,
            SqlType.NOT_SELECT,
        ),
        ("delete from table_name where id < 10", None, SqlType.NOT_SELECT),
        ("alter table table_name add column col1 int", None, SqlType.NOT_SELECT),
        ("create table table_name2 (col1 int)", None, SqlType.NOT_SELECT),
        ("truncate table table_name", None, SqlType.NOT_SELECT),
        # When ``sql_type`` is passed explicitly, the expected type is that
        # argument, even where it disagrees with the statement text.
        ("create table table_name2 (col1 int)", SqlType.SELECT, SqlType.SELECT),
        ("select 1", SqlType.NOT_SELECT, SqlType.NOT_SELECT),
        ("create table table_name2 (col1 int)", SqlType.NOT_SELECT, SqlType.NOT_SELECT),
        ("select 1", SqlType.SELECT, SqlType.SELECT),
    ],
)
@patch(
    "pydolphinscheduler.core.task.Task.gen_code_and_version",
    return_value=(123, 1),
)
@patch(
    "pydolphinscheduler.models.datasource.Datasource.get_task_usage_4j",
    return_value=TaskUsage(id=1, type="mock_type"),
)
def test_get_sql_type(
    mock_datasource, mock_code_version, sql, param_sql_type, sql_type
):
    """Test property sql_type could return correct type.

    :param sql: SQL statement handed to ``Sql``.
    :param param_sql_type: explicit ``sql_type`` argument, or ``None`` to let
        the task infer the type from ``sql``.
    :param sql_type: the type ``task.sql_type`` is expected to report.
    """
    name = "test_get_sql_type"
    datasource_name = "test_datasource"
    task = Sql(name, datasource_name, sql, sql_type=param_sql_type)
    assert (
        sql_type == task.sql_type
    ), f"Sql {sql} expect sql type is {sql_type} but got {task.sql_type}"
| |
| |
@pytest.mark.parametrize(
    "attr, expect",
    [
        (
            {"datasource_name": "datasource_name", "sql": "select 1"},
            {
                "type": "MYSQL",
                "datasource": 1,
                "sql": "select 1",
                "sqlType": "0",
                "displayRows": 10,
                "preStatements": [],
                "postStatements": [],
                "localParams": [],
                "resourceList": [],
                "dependence": {},
                "conditionResult": {"successNode": [""], "failedNode": [""]},
                "waitStartTimeout": {},
            },
        )
    ],
)
@patch(
    "pydolphinscheduler.core.task.Task.gen_code_and_version",
    return_value=(123, 1),
)
@patch(
    "pydolphinscheduler.models.datasource.Datasource.get_task_usage_4j",
    return_value=TaskUsage(id=1, type="MYSQL"),
)
def test_property_task_params(mock_datasource, mock_code_version, attr, expect):
    """Check the ``task_params`` payload built by an ``Sql`` task.

    The datasource lookup is patched to return id ``1`` and type ``MYSQL``,
    and those values must show up in the resulting parameters.
    """
    sql_task = Sql("test-sql-task-params", **attr)
    actual = sql_task.task_params
    assert actual == expect
| |
| |
@patch(
    "pydolphinscheduler.models.datasource.Datasource.get_task_usage_4j",
    return_value=TaskUsage(id=1, type="MYSQL"),
)
def test_sql_get_define(mock_datasource):
    """Test task sql function get_define.

    ``gen_code_and_version`` is patched to a known ``(code, version)`` pair,
    so the definition produced by ``get_define`` can be checked for that
    code and version, the task name and the SQL task parameters.
    """
    code = 123
    version = 1
    name = "test_sql_get_define"
    command = "select 1"
    datasource_name = "test_datasource"
    expect_task_params = {
        "type": "MYSQL",
        "datasource": 1,
        "sql": command,
        "sqlType": "0",
        "displayRows": 10,
        "preStatements": [],
        "postStatements": [],
        "localParams": [],
        "resourceList": [],
        "dependence": {},
        "conditionResult": {"successNode": [""], "failedNode": [""]},
        "waitStartTimeout": {},
    }
    with patch(
        "pydolphinscheduler.core.task.Task.gen_code_and_version",
        return_value=(code, version),
    ):
        task = Sql(name, datasource_name, command)
        assert task.task_params == expect_task_params

        # Previously this test never called ``get_define``; verify that the
        # patched code/version and the task params actually reach it.
        define = task.get_define()
        assert define["code"] == code
        assert define["version"] == version
        assert define["name"] == name
        assert define["taskParams"] == expect_task_params
| |
| |
@pytest.mark.parametrize(
    "attr, expect",
    [
        (
            {
                "name": "test-sql-local-res",
                "sql": file_name,
                "datasource_name": "test_datasource",
                "resource_plugin": Local(str(res_plugin_prefix)),
            },
            file_content,
        )
    ],
)
@patch(
    "pydolphinscheduler.core.task.Task.gen_code_and_version",
    return_value=(123, 1),
)
def test_resources_local_sql_command_content(
    mock_code_version, attr, expect, setup_crt_first
):
    """Test sql content through the local resource plug-in.

    ``attr["sql"]`` is a file name, not a statement. With a ``Local`` resource
    plugin rooted at this test's directory, the task's ``sql`` is expected to
    be that file's content. ``setup_crt_first`` creates the file for the
    duration of the test.
    """
    sql_task = Sql(**attr)
    # Plain attribute access; ``getattr`` with a constant name adds nothing.
    assert expect == sql_task.sql