blob: d574615a243cf07a3d80c8e79ded278b28a3ad71 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import unittest
from unittest.mock import ANY, patch
from freezegun import freeze_time
from airflow.api.client.local_client import Client
from airflow.example_dags import example_bash_operator
from airflow.exceptions import AirflowException
from airflow.models import DAG, DagBag, DagModel, DagRun, Pool
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_pools
# Reference timestamp, captured once at import time and shared by every test below.
EXECDATE = timezone.utcnow()
# The same instant with the microsecond component zeroed out.
EXECDATE_NOFRACTIONS = EXECDATE.replace(microsecond=0)
# ISO-8601 string of the whole-second timestamp.
EXECDATE_ISO = EXECDATE.replace(microsecond=0).isoformat()
class TestLocalClient(unittest.TestCase):
    """Tests for the local API client's DAG-trigger, DAG-delete and pool operations."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Store example_bash_operator in the metadata DB; test_trigger_dag triggers this dag_id.
        DagBag(example_bash_operator.__file__).get_dag("example_bash_operator").sync_to_db()

    def setUp(self):
        super().setUp()
        clear_db_pools()
        self.client = Client(api_base_url=None, auth=None)

    def tearDown(self):
        clear_db_pools()
        super().tearDown()

    @staticmethod
    def _count_dag_models(dag_id):
        """Return the number of DagModel rows for ``dag_id``, read in a fresh session.

        A fresh session is used so the count reflects changes committed by the client
        in its own session, whatever the database's transaction isolation level.
        """
        with create_session() as session:
            return session.query(DagModel).filter(DagModel.dag_id == dag_id).count()

    @staticmethod
    def _count_pools():
        """Return the total number of Pool rows, read in a fresh session."""
        with create_session() as session:
            return session.query(Pool).count()

    @staticmethod
    def _assert_dagrun_created(mock_create_dagrun, run_id, conf=None):
        """Assert the patched ``DAG.create_dagrun`` was called exactly once as expected.

        :param mock_create_dagrun: the mock injected by ``patch.object(DAG, 'create_dagrun')``
        :param run_id: expected ``run_id`` keyword argument
        :param conf: expected ``conf`` keyword argument (already decoded)
        """
        mock_create_dagrun.assert_called_once_with(
            run_id=run_id,
            execution_date=EXECDATE_NOFRACTIONS,
            state=State.RUNNING,
            conf=conf,
            external_trigger=True,
            dag_hash=ANY,
        )

    @patch.object(DAG, 'create_dagrun')
    def test_trigger_dag(self, mock_create_dagrun):
        test_dag_id = "example_bash_operator"
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, EXECDATE_NOFRACTIONS)

        # NOTE(review): the resulting DagBag is discarded, so this line only matters
        # through side effects of parsing the example DAGs — confirm it is still needed.
        DagBag(include_examples=True)

        # An unknown dag_id must be rejected.
        with self.assertRaises(AirflowException):
            self.client.trigger_dag(dag_id="blablabla")

        # Freeze "now" so the default execution date is deterministic.
        with freeze_time(EXECDATE):
            # No execution date: it should be set automatically.
            self.client.trigger_dag(dag_id=test_dag_id)
            self._assert_dagrun_created(mock_create_dagrun, run_id)
            mock_create_dagrun.reset_mock()

            # Explicit execution date: microseconds are cut off.
            self.client.trigger_dag(dag_id=test_dag_id, execution_date=EXECDATE)
            self._assert_dagrun_created(mock_create_dagrun, run_id)
            mock_create_dagrun.reset_mock()

            # A custom run id is forwarded as-is.
            custom_run_id = "my_run_id"
            self.client.trigger_dag(dag_id=test_dag_id, run_id=custom_run_id)
            self._assert_dagrun_created(mock_create_dagrun, custom_run_id)
            mock_create_dagrun.reset_mock()

            # conf passed as a JSON string reaches create_dagrun decoded.
            conf = '{"name": "John"}'
            self.client.trigger_dag(dag_id=test_dag_id, conf=conf)
            self._assert_dagrun_created(mock_create_dagrun, run_id, conf=json.loads(conf))

    def test_delete_dag(self):
        key = "my_dag_id"
        self.assertEqual(self._count_dag_models(key), 0)
        with create_session() as session:
            session.add(DagModel(dag_id=key))

        try:
            self.assertEqual(self._count_dag_models(key), 1)
            self.client.delete_dag(dag_id=key)
            self.assertEqual(self._count_dag_models(key), 0)
        finally:
            # Remove the row even when an assertion fails, so re-runs and later tests start clean.
            with create_session() as session:
                session.query(DagModel).filter(DagModel.dag_id == key).delete()

    def test_get_pool(self):
        self.client.create_pool(name='foo', slots=1, description='')
        pool = self.client.get_pool(name='foo')
        self.assertEqual(pool, ('foo', 1, ''))

    def test_get_pools(self):
        self.client.create_pool(name='foo1', slots=1, description='')
        self.client.create_pool(name='foo2', slots=2, description='')
        # Sort by name so the comparison does not depend on the order of returned rows.
        pools = sorted(self.client.get_pools(), key=lambda p: p[0])
        self.assertEqual(pools, [('default_pool', 128, 'Default pool'), ('foo1', 1, ''), ('foo2', 2, '')])

    def test_create_pool(self):
        pool = self.client.create_pool(name='foo', slots=1, description='')
        self.assertEqual(pool, ('foo', 1, ''))
        # default_pool plus the newly created 'foo'.
        self.assertEqual(self._count_pools(), 2)

    def test_delete_pool(self):
        self.client.create_pool(name='foo', slots=1, description='')
        self.assertEqual(self._count_pools(), 2)
        self.client.delete_pool(name='foo')
        self.assertEqual(self._count_pools(), 1)