# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
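
"""
Unit tests for the workflow task executors: mock tasks are run through the thread-based and
subprocess-based executors and the task lifecycle signals they emit are verified (a
celery-based executor entry exists but is currently commented out).
"""
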
import pytest
import retrying
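
# celery is an optional dependency: if it is unavailable, app stays None and the mock tasks
# below are simply not registered as celery tasks.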
try:
    import celery as _celery
    app = _celery.Celery()
    app.conf.update(CELERY_RESULT_BACKEND='amqp://')
except ImportError:
    _celery = None
    app = None

import aria
from aria.modeling import models
from aria.orchestrator import events
from aria.orchestrator.workflows.executor import (
    thread,
    process,
    # celery
)

import tests
from .. import helpers
from . import MockContext
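

# Builds the dotted path of a module-level function (e.g. the mock tasks below) so that the
# executors can locate and import it by name.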
def _get_function(func):
    return '{module}.{func.__name__}'.format(module=__name__, func=func)
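

# Runs a successful task, a failing task and a task that receives an input through the given
# executor, then retries the assertions for up to 10 seconds until each mock context has
# recorded the expected lifecycle states and exceptions.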
def execute_and_assert(executor, storage=None):
    expected_value = 'value'
    successful_ctx = MockContext(
        storage, task_kwargs=dict(function=_get_function(mock_successful_task))
    )
    failing_ctx = MockContext(
        storage, task_kwargs=dict(function=_get_function(mock_failing_task))
    )
    task_with_inputs_ctx = MockContext(
        storage,
        task_kwargs=dict(function=_get_function(mock_task_with_input),
                         arguments={'input': models.Argument.wrap('input', 'value')})
    )

    for ctx in [successful_ctx, failing_ctx, task_with_inputs_ctx]:
        executor.execute(ctx)

    @retrying.retry(stop_max_delay=10000, wait_fixed=100)
    def assertion():
        assert successful_ctx.states == ['start', 'success']
        assert failing_ctx.states == ['start', 'failure']
        assert task_with_inputs_ctx.states == ['start', 'failure']
        assert isinstance(failing_ctx.exception, MockException)
        assert isinstance(task_with_inputs_ctx.exception, MockException)
        assert task_with_inputs_ctx.exception.message == expected_value
    assertion()


def test_thread_execute(thread_executor):
    execute_and_assert(thread_executor)


def test_process_execute(process_executor, storage):
    execute_and_assert(process_executor, storage)


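# The mock tasks live at module level so that _get_function() can address them by a dotted
# path and the subprocess-based executor can import them by name; when celery is installed
# they are additionally wrapped as celery tasks.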
def mock_successful_task(**_):
    pass


def mock_failing_task(**_):
    raise MockException


def mock_task_with_input(input, **_):
    raise MockException(input)


if app:
    mock_successful_task = app.task(mock_successful_task)
    mock_failing_task = app.task(mock_failing_task)
    mock_task_with_input = app.task(mock_task_with_input)


class MockException(Exception):
    pass


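# Provides a model storage backed by an SQLite database under the pytest tmpdir and releases
# it after each test; only the process executor test needs it.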
@pytest.fixture
def storage(tmpdir):
    _storage = aria.application_model_storage(aria.storage.sql_mapi.SQLAlchemyModelAPI,
                                              initiator_kwargs=dict(base_dir=str(tmpdir)))
    yield _storage
    tests.storage.release_sqlite_storage(_storage)


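# Parametrizes the thread-based executor with pool sizes of 1 and 2; the celery-based
# executor entry is kept commented out.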
@pytest.fixture(params=[
    (thread.ThreadExecutor, {'pool_size': 1}),
    (thread.ThreadExecutor, {'pool_size': 2}),
    # (celery.CeleryExecutor, {'app': app})
])
def thread_executor(request):
    executor_cls, executor_kwargs = request.param
    result = executor_cls(**executor_kwargs)
    yield result
    result.close()


@pytest.fixture
def process_executor():
    # the subprocess needs to load the tests module, so we explicitly add the root directory
    # as if the project had been installed in editable mode
    result = process.ProcessExecutor(python_path=tests.ROOT_DIR)
    yield result
    result.close()


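# Connects handlers that record each task's lifecycle ('start', 'success', 'failure') and any
# raised exception on the mock context, while ARIA's own task event handlers are disconnected
# for the duration of the test.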
@pytest.fixture(autouse=True)
def register_signals():
    def start_handler(ctx, *args, **kwargs):
        ctx.states.append('start')

    def success_handler(ctx, *args, **kwargs):
        ctx.states.append('success')

    def failure_handler(ctx, exception, *args, **kwargs):
        ctx.states.append('failure')
        ctx.exception = exception

    with helpers.disconnect_event_handlers():
        events.start_task_signal.connect(start_handler)
        events.on_success_task_signal.connect(success_handler)
        events.on_failure_task_signal.connect(failure_handler)
        yield
        events.start_task_signal.disconnect(start_handler)
        events.on_success_task_signal.disconnect(success_handler)
        events.on_failure_task_signal.disconnect(failure_handler)