blob: 3ffe4adbee7b039e2e2d3295d0448c4ef3bae377 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import datetime
from unittest.mock import ANY, Mock, patch
from pytest import raises
from sqlalchemy.exc import OperationalError
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.base_job import BaseJob
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from tests.test_utils.config import conf_vars
class MockJob(BaseJob):
    """Minimal concrete ``BaseJob`` whose work is whatever callable it is given.

    The tests below pass ``func`` to control what happens when ``run()`` executes
    the job: return normally, call ``sys.exit``, or raise.
    """

    # SQLAlchemy mapper argument giving this subclass its own polymorphic identity.
    __mapper_args__ = dict(polymorphic_identity='MockJob')

    def __init__(self, func, **kwargs):
        # Keep the callable first, then let BaseJob initialise the remaining fields.
        self.func = func
        super().__init__(**kwargs)

    def _execute(self):
        """Invoke the stored callable and return its result unchanged."""
        callback = self.func
        return callback()
class TestBaseJob:
    """Tests for ``BaseJob`` behaviour, exercised through ``MockJob``."""

    def test_state_success(self):
        """A job whose callable returns normally ends in SUCCESS with an end date."""
        job = MockJob(lambda: True)
        job.run()

        assert job.state == State.SUCCESS
        assert job.end_date is not None

    def test_state_sysexit(self):
        """A job whose callable calls ``sys.exit(0)`` is still recorded as SUCCESS."""
        import sys

        job = MockJob(lambda: sys.exit(0))
        job.run()

        assert job.state == State.SUCCESS
        assert job.end_date is not None

    def test_state_failed(self):
        """An exception from the callable propagates out of ``run()`` and marks the job FAILED."""

        def abort():
            raise RuntimeError("fail")

        job = MockJob(abort)
        with raises(RuntimeError):
            job.run()

        assert job.state == State.FAILED
        assert job.end_date is not None

    def test_most_recent_job(self):
        """``most_recent_job`` returns the job with the newer ``latest_heartbeat``."""
        with create_session() as session:
            old_job = MockJob(None, heartrate=10)
            # Age the first job's heartbeat by 20s so the second job is clearly the newest.
            old_job.latest_heartbeat = old_job.latest_heartbeat - datetime.timedelta(seconds=20)
            job = MockJob(None, heartrate=10)
            session.add(job)
            session.add(old_job)
            session.flush()

            assert MockJob.most_recent_job(session=session) == job

            # Discard the flushed rows so they do not leak into other tests.
            session.rollback()

    def test_is_alive(self):
        """``is_alive`` depends on heartbeat age and on the job still running.

        With ``heartrate=10`` the assertions below pin the liveness cutoff
        between a 20s-old heartbeat (alive) and a 21s-old one (dead).
        """
        job = MockJob(None, heartrate=10, state=State.RUNNING)
        assert job.is_alive() is True

        job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=20)
        assert job.is_alive() is True

        job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=21)
        assert job.is_alive() is False

        # Regression check: the heartbeat age used to be read from
        # ``timedelta.seconds`` instead of ``total_seconds()``. A timedelta is
        # stored as (days, seconds, microseconds), so a heartbeat about one day
        # old has ``.seconds`` near 0 and would wrongly look fresh.
        job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(days=1)
        assert job.is_alive() is False

        # A finished job is never alive, however recent its heartbeat.
        job.state = State.SUCCESS
        job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
        assert job.is_alive() is False, "Completed jobs even with recent heartbeat should not be alive"

    @patch('airflow.jobs.base_job.create_session')
    def test_heartbeat_failed(self, mock_create_session):
        """A failing commit must not raise out of ``heartbeat()`` nor advance ``latest_heartbeat``."""
        when = timezone.utcnow() - datetime.timedelta(seconds=60)
        # Only ``airflow.jobs.base_job.create_session`` is patched; the name
        # imported into this test module is the real one, used here just to
        # obtain a real session to serve as the mock's spec.
        with create_session() as session:
            mock_session = Mock(spec_set=session, name="MockSession")
            mock_create_session.return_value.__enter__.return_value = mock_session

            job = MockJob(None, heartrate=10, state=State.RUNNING)
            job.latest_heartbeat = when

            # Simulate the database rejecting the heartbeat update.
            mock_session.commit.side_effect = OperationalError("Force fail", {}, None)

            job.heartbeat()

            assert job.latest_heartbeat == when, "attribute not updated when heartbeat fails"

    @conf_vars({('scheduler', 'max_tis_per_query'): '100'})
    @patch('airflow.jobs.base_job.ExecutorLoader.get_default_executor')
    @patch('airflow.jobs.base_job.get_hostname')
    @patch('airflow.jobs.base_job.getpass.getuser')
    def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor):
        """Attributes set at construction come from kwargs, the config override and the patched helpers."""
        # ``patch`` decorators are applied bottom-up, so the innermost one
        # (``getpass.getuser``) supplies the first mock argument.
        mock_sequential_executor = SequentialExecutor()
        mock_hostname.return_value = "test_hostname"
        mock_getuser.return_value = "testuser"
        mock_default_executor.return_value = mock_sequential_executor

        test_job = MockJob(None, heartrate=10, dag_id="example_dag", state=State.RUNNING)

        assert test_job.executor_class == "SequentialExecutor"
        assert test_job.heartrate == 10
        assert test_job.dag_id == "example_dag"
        assert test_job.hostname == "test_hostname"
        assert test_job.max_tis_per_query == 100
        assert test_job.unixname == "testuser"
        assert test_job.state == "running"
        assert test_job.executor == mock_sequential_executor

    def test_heartbeat(self, frozen_sleep, monkeypatch):
        """``heartbeat()`` fires ``heartbeat_callback``; an immediate ``only_if_necessary`` call does not."""
        # ``frozen_sleep`` is a fixture defined outside this file; presumably it
        # stands in for ``sleep`` so the heartbeat does not block in real time.
        monkeypatch.setattr('airflow.jobs.base_job.sleep', frozen_sleep)
        with create_session() as session:
            job = MockJob(None, heartrate=10)
            job.latest_heartbeat = timezone.utcnow()
            session.add(job)
            session.commit()

            hb_callback = Mock()
            job.heartbeat_callback = hb_callback

            job.heartbeat()
            hb_callback.assert_called_once_with(session=ANY)

            # An only-if-necessary heartbeat issued straight after a successful
            # one is expected to be skipped, so the callback must not fire again.
            hb_callback.reset_mock()
            job.heartbeat(only_if_necessary=True)
            hb_callback.assert_not_called()