# blob: 5256cf8e93886675a186f6d22bb97b3c570bc621 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import json
import logging
import os
import pytest
import fabric.api
from fabric.contrib import files
from fabric import context_managers
from aria.modeling import models
from aria.orchestrator import events
from aria.orchestrator import workflow
from aria.orchestrator.workflows import api
from aria.orchestrator.workflows.executor import process
from aria.orchestrator.workflows.core import engine, graph_compiler
from aria.orchestrator.workflows.exceptions import ExecutorException
from aria.orchestrator.exceptions import TaskAbortException, TaskRetryException
from aria.orchestrator.execution_plugin import operations
from aria.orchestrator.execution_plugin import constants
from aria.orchestrator.execution_plugin.exceptions import ProcessException, TaskException
from aria.orchestrator.execution_plugin.ssh import operations as ssh_operations
from tests import mock, storage, resources
from tests.orchestrator.workflows.helpers import events_collector
import tests

# Alternative base directory passed via ``process['base_dir']`` in
# test_run_script_process_config (and cleaned up before every test).
_CUSTOM_BASE_DIR = '/tmp/new-aria-ctx'

# Private key registered with the mock SSH server for the ``test`` user.
KEY_FILENAME = os.path.join(tests.ROOT_DIR, 'tests/resources/keys/test')

# Fabric settings shared by the SSH tests. Note: no ``host_string`` here --
# the host/port of the mock server is only known at runtime.
_FABRIC_ENV = dict(
    disable_known_hosts=True,
    user='test',
    key_filename=KEY_FILENAME,
)
import mockssh
@pytest.fixture(scope='session')
def server():
    """Run one in-process mock SSH server for the whole test session.

    The server accepts ``KEY_FILENAME`` for the user that ``_FABRIC_ENV``
    connects as, and the running server object is yielded so tests can read
    its ``port``.
    """
    # Take the user name from _FABRIC_ENV so the server and the client
    # settings cannot drift apart.
    with mockssh.Server({_FABRIC_ENV['user']: KEY_FILENAME}) as ssh_server:
        yield ssh_server
# The former TRAVIS-only skip was dropped: these tests connect to the
# in-process mockssh server from the ``server`` fixture instead of a real
# SSH daemon.
class TestWithActualSSHServer(object):
    """End-to-end tests of the ssh execution-plugin operations.

    Each test uploads ``test_ssh.sh`` as a service resource, runs it (or plain
    commands) through a ``ProcessExecutor`` against the mock SSH server, and
    then inspects the dependency node's attributes.
    """

    def test_run_script_basic(self):
        expected_attribute_value = 'some_value'
        props = self._execute(env={'test_value': expected_attribute_value})
        assert props['test_value'].value == expected_attribute_value

    @pytest.mark.skip(reason='sudo privileges are required')
    def test_run_script_as_sudo(self):
        self._execute(use_sudo=True)
        with self._ssh_env():
            assert files.exists('/opt/test_dir')
            fabric.api.sudo('rm -rf /opt/test_dir')

    def test_run_script_default_base_dir(self):
        props = self._execute()
        assert props['work_dir'].value == '{0}/work'.format(constants.DEFAULT_BASE_DIR)

    @pytest.mark.skip(reason='Re-enable once output from process executor can be captured')
    @pytest.mark.parametrize('hide_groups', [[], ['everything']])
    def test_run_script_with_hide(self, hide_groups):
        self._execute(hide_output=hide_groups)
        output = 'TODO'
        expected_log_message = ('[localhost] run: source {0}/scripts/'
                                .format(constants.DEFAULT_BASE_DIR))
        if hide_groups:
            assert expected_log_message not in output
        else:
            assert expected_log_message in output

    def test_run_script_process_config(self):
        expected_env_value = 'test_value_env'
        expected_arg1_value = 'test_value_arg1'
        expected_arg2_value = 'test_value_arg2'
        expected_cwd = '/tmp'
        expected_base_dir = _CUSTOM_BASE_DIR
        props = self._execute(
            env={'test_value_env': expected_env_value},
            process={
                'args': [expected_arg1_value, expected_arg2_value],
                'cwd': expected_cwd,
                'base_dir': expected_base_dir
            })
        assert props['env_value'].value == expected_env_value
        assert len(props['bash_version'].value) > 0
        assert props['arg1_value'].value == expected_arg1_value
        assert props['arg2_value'].value == expected_arg2_value
        assert props['cwd'].value == expected_cwd
        assert props['ctx_path'].value == '{0}/ctx'.format(expected_base_dir)

    def test_run_script_command_prefix(self):
        props = self._execute(process={'command_prefix': 'bash -i'})
        assert 'i' in props['dollar_dash'].value

    def test_run_script_reuse_existing_ctx(self):
        expected_test_value_1 = 'test_value_1'
        expected_test_value_2 = 'test_value_2'
        props = self._execute(
            test_operations=['{0}_1'.format(self.test_name),
                             '{0}_2'.format(self.test_name)],
            env={'test_value1': expected_test_value_1,
                 'test_value2': expected_test_value_2})
        assert props['test_value1'].value == expected_test_value_1
        assert props['test_value2'].value == expected_test_value_2

    def test_run_script_download_resource_plain(self, tmpdir):
        resource = tmpdir.join('resource')
        resource.write('content')
        self._upload(str(resource), 'test_resource')
        props = self._execute()
        assert props['test_value'].value == 'content'

    def test_run_script_download_resource_and_render(self, tmpdir):
        resource = tmpdir.join('resource')
        resource.write('{{ctx.service.name}}')
        self._upload(str(resource), 'test_resource')
        props = self._execute()
        assert props['test_value'].value == self._workflow_context.service.name

    @pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
    def test_run_script_inputs_as_env_variables_no_override(self, value):
        props = self._execute(custom_input=value)
        return_value = props['test_value'].value
        expected = return_value if isinstance(value, basestring) else json.loads(return_value)
        assert value == expected

    @pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
    def test_run_script_inputs_as_env_variables_process_env_override(self, value):
        props = self._execute(custom_input='custom-input-value',
                              env={'custom_env_var': value})
        return_value = props['test_value'].value
        expected = return_value if isinstance(value, basestring) else json.loads(return_value)
        assert value == expected

    def test_run_script_error_in_script(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskException)

    def test_run_script_abort_immediate(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskAbortException)
        assert exception.message == 'abort-message'

    def test_run_script_retry(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskRetryException)
        assert exception.message == 'retry-message'

    def test_run_script_abort_error_ignored_by_script(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskAbortException)
        assert exception.message == 'abort-message'

    def test_run_commands(self):
        temp_file_path = '/tmp/very_temporary_file'
        with self._ssh_env():
            if files.exists(temp_file_path):
                fabric.api.run('rm {0}'.format(temp_file_path))
        self._execute(commands=['touch {0}'.format(temp_file_path)])
        with self._ssh_env():
            assert files.exists(temp_file_path)
            fabric.api.run('rm {0}'.format(temp_file_path))

    @pytest.fixture(autouse=True)
    def _setup(self, request, workflow_context, executor, capfd, server):
        """Store per-test fixtures and remove leftover base dirs on the server."""
        self._workflow_context = workflow_context
        self._executor = executor
        self._capfd = capfd
        # Kept on the instance so _ssh_env() and _execute() can target the
        # mock server without every call site passing it along.
        self._server = server
        self.test_name = request.node.originalname or request.node.name
        with self._ssh_env():
            for directory in [constants.DEFAULT_BASE_DIR, _CUSTOM_BASE_DIR]:
                if files.exists(directory):
                    fabric.api.run('rm -rf {0}'.format(directory))

    def _host_string(self, server=None):
        """Return the fabric ``host_string`` (``localhost:<port>``) of the mock server.

        :param server: server to address; defaults to the one stored by ``_setup``.
        """
        return 'localhost:{0}'.format((server or self._server).port)

    @contextlib.contextmanager
    def _ssh_env(self, server=None):
        """Fabric settings for running commands directly on the mock server.

        Output capturing is disabled and all fabric output hidden while active.

        :param server: server to connect to; defaults to the one stored by
            ``_setup`` (the existing no-argument call sites rely on this).
        """
        with self._capfd.disabled():
            with context_managers.settings(fabric.api.hide('everything'),
                                           host_string=self._host_string(server),
                                           **_FABRIC_ENV):
                yield

    def _execute(self,
                 env=None,
                 use_sudo=False,
                 hide_output=None,
                 process=None,
                 custom_input='',
                 test_operations=None,
                 commands=None):
        """Run ``test_ssh.sh`` (or ``commands``) on the mock server via a workflow.

        :param env: extra environment variables merged into ``process['env']``.
        :param use_sudo: forwarded as the operation's ``use_sudo`` argument.
        :param hide_output: fabric output groups to hide; forwarded only if truthy.
        :param process: ``process`` argument of the operation; not mutated.
        :param custom_input: value of the ``custom_env_var`` operation argument.
        :param test_operations: values of the ``test_operation`` argument, one
            sequential task per value; defaults to the current test's name.
        :param commands: if given, ``run_commands_with_ssh`` is used instead of
            ``run_script_with_ssh``.
        :return: the dependency node's ``attributes`` after the workflow ran.
        """
        # Copy (including the nested env) so caller-supplied dicts stay untouched.
        process = dict(process or {})
        if env:
            merged_env = dict(process.get('env') or {})
            merged_env.update(env)
            process['env'] = merged_env
        test_operations = test_operations or [self.test_name]
        local_script_path = os.path.join(resources.DIR, 'scripts', 'test_ssh.sh')
        script_path = os.path.basename(local_script_path)
        self._upload(local_script_path, script_path)
        if commands:
            operation = operations.run_commands_with_ssh
        else:
            operation = operations.run_script_with_ssh
        node = self._workflow_context.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME)
        # _FABRIC_ENV has no host_string and the mock server listens on
        # server.port, so the target must be supplied explicitly here.
        fabric_env = dict(_FABRIC_ENV, host_string=self._host_string())
        arguments = {
            'script_path': script_path,
            'fabric_env': fabric_env,
            'process': process,
            'use_sudo': use_sudo,
            'custom_env_var': custom_input,
            'test_operation': '',
        }
        if hide_output:
            arguments['hide_output'] = hide_output
        if commands:
            arguments['commands'] = commands
        interface = mock.models.create_interface(
            node.service,
            'test',
            'op',
            operation_kwargs=dict(
                function='{0}.{1}'.format(
                    operations.__name__,
                    operation.__name__),
                arguments=arguments)
        )
        node.interfaces[interface.name] = interface

        @workflow
        def mock_workflow(ctx, graph):
            ops = []
            for test_operation in test_operations:
                op_arguments = arguments.copy()
                op_arguments['test_operation'] = test_operation
                ops.append(api.task.OperationTask(
                    node,
                    interface_name='test',
                    operation_name='op',
                    arguments=op_arguments))
            graph.sequence(*ops)
            return graph
        tasks_graph = mock_workflow(ctx=self._workflow_context)  # pylint: disable=no-value-for-parameter
        graph_compiler.GraphCompiler(
            self._workflow_context, self._executor.__class__).compile(tasks_graph)
        eng = engine.Engine({self._executor.__class__: self._executor})
        eng.execute(self._workflow_context)
        return self._workflow_context.model.node.get_by_name(
            mock.models.DEPENDENCY_NODE_NAME).attributes

    def _execute_and_get_task_exception(self, *args, **kwargs):
        """Run ``_execute`` expecting failure; return the exception of the first failed task."""
        signal = events.on_failure_task_signal
        with events_collector(signal) as collected:
            with pytest.raises(ExecutorException):
                self._execute(*args, **kwargs)
        return collected[signal][0]['kwargs']['exception']

    def _upload(self, source, path):
        """Upload ``source`` as a resource of the current service under ``path``."""
        self._workflow_context.resource.service.upload(
            entry_id=str(self._workflow_context.service.id),
            source=source,
            path=path)

    @pytest.fixture
    def executor(self):
        result = process.ProcessExecutor()
        try:
            yield result
        finally:
            result.close()

    @pytest.fixture
    def workflow_context(self, tmpdir):
        workflow_context = mock.context.simple(str(tmpdir))
        workflow_context.states = []
        workflow_context.exception = None
        yield workflow_context
        storage.release_sqlite_storage(workflow_context.model)
class TestFabricEnvHideGroupsAndRunCommands(object):
    """Unit tests of ``run_commands_with_ssh`` with ``fabric.api`` patched out.

    ``MockFabricApi`` records the keyword settings the operation enters and the
    commands it runs, so tests can assert on them without any SSH connection.
    """

    def test_fabric_env_default_override(self):
        # first sanity for no override
        self._run()
        assert self.mock.settings_merged['timeout'] == constants.FABRIC_ENV_DEFAULTS['timeout']
        # now override
        invocation_fabric_env = self.default_fabric_env.copy()
        timeout = 1000000
        invocation_fabric_env['timeout'] = timeout
        self._run(fabric_env=invocation_fabric_env)
        assert self.mock.settings_merged['timeout'] == timeout

    def test_implicit_host_string(self, mocker):
        expected_host_address = '1.1.1.1'
        mocker.patch.object(self._Ctx.task.actor, 'host')
        mocker.patch.object(self._Ctx.task.actor.host, 'host_address', expected_host_address)
        fabric_env = self.default_fabric_env.copy()
        del fabric_env['host_string']
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['host_string'] == expected_host_address

    def test_explicit_host_string(self):
        fabric_env = self.default_fabric_env.copy()
        host_string = 'explicit_host_string'
        fabric_env['host_string'] = host_string
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['host_string'] == host_string

    def test_override_warn_only(self):
        fabric_env = self.default_fabric_env.copy()
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['warn_only'] is True
        fabric_env = self.default_fabric_env.copy()
        fabric_env['warn_only'] = False
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['warn_only'] is False

    def test_missing_host_string(self):
        with pytest.raises(TaskAbortException) as exc_ctx:
            fabric_env = self.default_fabric_env.copy()
            del fabric_env['host_string']
            self._run(fabric_env=fabric_env)
        assert '`host_string` not supplied' in str(exc_ctx.value)

    def test_missing_user(self):
        with pytest.raises(TaskAbortException) as exc_ctx:
            fabric_env = self.default_fabric_env.copy()
            del fabric_env['user']
            self._run(fabric_env=fabric_env)
        assert '`user` not supplied' in str(exc_ctx.value)

    def test_missing_key_or_password(self):
        with pytest.raises(TaskAbortException) as exc_ctx:
            fabric_env = self.default_fabric_env.copy()
            del fabric_env['key_filename']
            self._run(fabric_env=fabric_env)
        assert 'Access credentials not supplied' in str(exc_ctx.value)

    def test_hide_in_settings_and_non_viable_groups(self):
        groups = ('running', 'stdout')
        self._run(hide_output=groups)
        assert set(self.mock.settings_merged['hide_output']) == set(groups)
        with pytest.raises(TaskAbortException) as exc_ctx:
            self._run(hide_output=('running', 'bla'))
        assert '`hide_output` must be a subset of' in str(exc_ctx.value)

    def test_run_commands(self):
        def check(use_sudo):
            # Run two commands and verify settings/commands recorded by the
            # mock, then reset the mock for the next round.
            commands = ['command1', 'command2']
            self._run(
                commands=commands,
                use_sudo=use_sudo)
            assert all(item in self.mock.settings_merged.items() for
                       item in self.default_fabric_env.items())
            assert self.mock.settings_merged['warn_only'] is True
            assert self.mock.settings_merged['use_sudo'] == use_sudo
            assert self.mock.commands == commands
            self.mock.settings_merged = {}
            self.mock.commands = []
        check(use_sudo=False)
        check(use_sudo=True)

    def test_failed_command(self):
        with pytest.raises(ProcessException) as exc_ctx:
            self._run(commands=['fail'])
        exception = exc_ctx.value
        assert exception.stdout == self.MockCommandResult.stdout
        assert exception.stderr == self.MockCommandResult.stderr
        assert exception.command == self.MockCommandResult.command
        assert exception.exit_code == self.MockCommandResult.return_code

    class MockCommandResult(object):
        """Fake fabric command result with fixed output; ``failed`` is set per call."""
        stdout = 'mock_stdout'
        stderr = 'mock_stderr'
        command = 'mock_command'
        return_code = 1

        def __init__(self, failed):
            self.failed = failed

    class MockFabricApi(object):
        """Stand-in for ``fabric.api`` that records settings and commands.

        A command equal to ``'fail'`` produces a failed ``MockCommandResult``.
        """

        def __init__(self):
            self.commands = []
            self.settings_merged = {}

        @contextlib.contextmanager
        def settings(self, *args, **kwargs):
            self.settings_merged.update(kwargs)
            if args:
                # The first positional arg is whatever hide() returned.
                groups = args[0]
                self.settings_merged.update({'hide_output': groups})
            yield

        def run(self, command):
            self.commands.append(command)
            self.settings_merged['use_sudo'] = False
            return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')

        def sudo(self, command):
            self.commands.append(command)
            self.settings_merged['use_sudo'] = True
            return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')

        def hide(self, *groups):
            return groups

        def exists(self, *args, **kwargs):
            raise RuntimeError

    class _Ctx(object):
        """Minimal operation context exposing only what these tests need."""
        INSTRUMENTATION_FIELDS = ()

        class Task(object):
            @staticmethod
            def abort(message=None):
                models.Task.abort(message)
            actor = None

        class Actor(object):
            host = None

        class Model(object):
            @contextlib.contextmanager
            def instrument(self, *args, **kwargs):
                yield
        task = Task
        task.actor = Actor
        model = Model()
        logger = logging.getLogger()

    @staticmethod
    @contextlib.contextmanager
    def _mock_self_logging(*args, **kwargs):
        yield
    _Ctx.logging_handlers = _mock_self_logging

    @pytest.fixture(autouse=True)
    def _setup(self, mocker):
        """Give each test a fresh default env and a fresh patched ``fabric.api``."""
        self.default_fabric_env = {
            'host_string': 'test',
            'user': 'test',
            'key_filename': 'test',
        }
        self.mock = self.MockFabricApi()
        mocker.patch('fabric.api', self.mock)

    def _run(self,
             commands=(),
             fabric_env=None,
             process=None,
             use_sudo=False,
             hide_output=None):
        """Invoke ``run_commands_with_ssh`` with the fake ``_Ctx``.

        :param fabric_env: env to pass through; only ``None`` selects
            ``self.default_fabric_env`` (an explicit empty dict is kept as-is).
        """
        if fabric_env is None:
            fabric_env = self.default_fabric_env
        operations.run_commands_with_ssh(
            ctx=self._Ctx,
            commands=commands,
            process=process,
            fabric_env=fabric_env,
            use_sudo=use_sudo,
            hide_output=hide_output)
class TestUtilityFunctions(object):
    """Tests of the private path and env-script helpers of the ssh operations."""

    def test_paths(self):
        script = '/local/script/path.py'
        paths = ssh_operations._Paths(base_dir='/path', local_script_path=script)

        # Attributes with exact, deterministic values.
        assert paths.local_script_path == script
        assert paths.remote_ctx_dir == '/path'
        assert paths.base_script_path == 'path.py'
        assert paths.remote_ctx_path == '/path/ctx'
        assert paths.remote_scripts_dir == '/path/scripts'
        assert paths.remote_work_dir == '/path/work'

        # Remote script paths carry a suffix, so only their prefix is checked.
        assert paths.remote_env_script_path.startswith('/path/scripts/env-path.py-')
        assert paths.remote_script_path.startswith('/path/scripts/path.py-')

    def test_write_environment_script_file(self):
        paths = ssh_operations._Paths(base_dir='/path',
                                      local_script_path='/local/script/path.py')
        local_url, remote_url = 'local_socket_url', 'remote_socket_url'

        script = ssh_operations._write_environment_script_file(
            process={'env': {'one': "'1'"}},
            paths=paths,
            local_socket_url=local_url,
            remote_socket_url=remote_url)
        # Compare as a set of non-empty lines: order is not part of the contract.
        actual = {line for line in script.getvalue().split('\n') if line}

        expected = {
            'export PATH=/path:$PATH',
            'export PYTHONPATH=/path:$PYTHONPATH',
            'chmod +x /path/ctx',
            'chmod +x {0}'.format(paths.remote_script_path),
            'export CTX_SOCKET_URL={0}'.format(remote_url),
            'export LOCAL_CTX_SOCKET_URL={0}'.format(local_url),
            'export one=\'1\'',
        }
        assert actual == expected