blob: 9ddaef4710c400b00e21e6c63ed23f37652d086b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import retrying
try:
import celery as _celery
app = _celery.Celery()
app.conf.update(CELERY_RESULT_BACKEND='amqp://')
except ImportError:
_celery = None
app = None
import aria
from aria.modeling import models
from aria.orchestrator import events
from aria.orchestrator.workflows.executor import (
thread,
process,
# celery
)
import tests
from . import MockTask
def _get_function(func):
return '{module}.{func.__name__}'.format(module=__name__, func=func)
def execute_and_assert(executor, storage=None):
expected_value = 'value'
successful_task = MockTask(_get_function(mock_successful_task), storage=storage)
failing_task = MockTask(_get_function(mock_failing_task), storage=storage)
task_with_inputs = MockTask(_get_function(mock_task_with_input),
arguments={'input': models.Parameter.wrap('input', 'value')},
storage=storage)
for task in [successful_task, failing_task, task_with_inputs]:
executor.execute(task)
@retrying.retry(stop_max_delay=10000, wait_fixed=100)
def assertion():
assert successful_task.states == ['start', 'success']
assert failing_task.states == ['start', 'failure']
assert task_with_inputs.states == ['start', 'failure']
assert isinstance(failing_task.exception, MockException)
assert isinstance(task_with_inputs.exception, MockException)
assert task_with_inputs.exception.message == expected_value
assertion()
def test_thread_execute(thread_executor):
execute_and_assert(thread_executor)
def test_process_execute(process_executor, storage):
execute_and_assert(process_executor, storage)
def mock_successful_task(**_):
pass
def mock_failing_task(**_):
raise MockException
def mock_task_with_input(input, **_):
raise MockException(input)
if app:
mock_successful_task = app.task(mock_successful_task)
mock_failing_task = app.task(mock_failing_task)
mock_task_with_input = app.task(mock_task_with_input)
class MockException(Exception):
pass
@pytest.fixture
def storage(tmpdir):
return aria.application_model_storage(
aria.storage.sql_mapi.SQLAlchemyModelAPI,
initiator_kwargs=dict(base_dir=str(tmpdir))
)
@pytest.fixture(params=[
(thread.ThreadExecutor, {'pool_size': 1}),
(thread.ThreadExecutor, {'pool_size': 2}),
# subprocess needs to load a tests module so we explicitly add the root directory as if
# the project has been installed in editable mode
# (celery.CeleryExecutor, {'app': app})
])
def thread_executor(request):
executor_cls, executor_kwargs = request.param
result = executor_cls(**executor_kwargs)
yield result
result.close()
@pytest.fixture
def process_executor():
result = process.ProcessExecutor(python_path=tests.ROOT_DIR)
yield result
result.close()
@pytest.fixture(autouse=True)
def register_signals():
def start_handler(task, *args, **kwargs):
task.states.append('start')
def success_handler(task, *args, **kwargs):
task.states.append('success')
def failure_handler(task, exception, *args, **kwargs):
task.states.append('failure')
task.exception = exception
events.start_task_signal.connect(start_handler)
events.on_success_task_signal.connect(success_handler)
events.on_failure_task_signal.connect(failure_handler)
yield
events.start_task_signal.disconnect(start_handler)
events.on_success_task_signal.disconnect(success_handler)
events.on_failure_task_signal.disconnect(failure_handler)