| #!/usr/bin/env python |
| # coding=utf-8 |
| |
| # Copyright [2017] [B2W Digital] |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| from builtins import str |
| import joblib as serializer |
| import pytest |
| import os |
| import shutil |
| import copy |
| from mock import ANY |
| try: |
| import mock |
| except ImportError: |
| import unittest.mock as mock |
| |
| from marvin_python_toolbox.engine_base import EngineBaseBatchAction |
| from marvin_python_toolbox.engine_base import EngineBaseAction, EngineBaseOnlineAction |
| from marvin_python_toolbox.engine_base.stubs.actions_pb2 import HealthCheckResponse, HealthCheckRequest |
| from marvin_python_toolbox.engine_base.stubs.actions_pb2 import OnlineActionRequest, ReloadRequest, BatchActionRequest |
| |
| |
@pytest.fixture
def engine_action():
    """Provide a minimal concrete EngineBaseAction rooted at ``/tmp/.marvin``."""

    # Smallest possible subclass: ``execute`` just returns 1.
    class EngineAction(EngineBaseAction):
        def execute(self, params, **kwargs):
            return 1

    action = EngineAction(default_root_path="/tmp/.marvin")
    return action
| |
| |
@pytest.fixture
def batch_engine_action():
    """Provide a minimal concrete EngineBaseBatchAction rooted at ``/tmp/.marvin``."""

    # Smallest possible subclass: ``execute`` just returns 1.
    class BatchEngineAction(EngineBaseBatchAction):
        def execute(self, params, **kwargs):
            return 1

    action = BatchEngineAction(default_root_path="/tmp/.marvin")
    return action
| |
| |
class TestEngineBaseAction:
    """Tests for EngineBaseAction object persistence, health check and reload,
    plus the response serialization of EngineBaseOnlineAction._remote_execute."""

    def setup_method(self, method):
        """Start every test from an empty ``/tmp/.marvin`` artifact root."""
        # ``setup_method`` is pytest's xunit-style per-test hook; the nose-style
        # ``setup`` name is deprecated since pytest 7.2 and ignored by pytest 8.
        shutil.rmtree("/tmp/.marvin", ignore_errors=True)

    def test_retrieve_obj(self, tmpdir):
        """retrieve_obj reads back an object written with the serializer."""
        path = str(tmpdir.join('test-obj'))
        obj = [1, 2]
        # Close (and thereby flush) the file before reading it back rather
        # than relying on garbage collection to do it.
        with open(path, 'wb') as fp:
            serializer.dump(obj, fp)
        assert obj == EngineBaseAction.retrieve_obj(path)

    def test_constructor(self):
        """Constructor keyword arguments are stored on the instance."""
        class EngineAction(EngineBaseAction):
            def execute(self, params, **kwargs):
                return 1

        engine = EngineAction(params={"x": 1}, persistence_mode='x')

        assert engine._params == {"x": 1}
        assert engine._persistence_mode == 'x'

    def test_get_object_file_path(self, engine_action):
        """Object files are placed under <root>/test_base_action/<reference>."""
        assert engine_action._get_object_file_path(object_reference="xpath") == "/tmp/.marvin/test_base_action/xpath"

    def test_save_obj_memory_persistence(self, engine_action):
        """Without local persistence, saving sets the attribute but writes no file."""
        obj = [6, 5, 4]
        object_reference = '_params'
        engine_action._save_obj(object_reference, obj)

        assert obj == engine_action._params
        assert not os.path.exists("/tmp/.marvin/test_base_action/params")

    def test_save_obj_local_persistence(self, engine_action):
        """With local persistence, saving sets the attribute, writes a file and
        records the reference as locally saved."""
        obj = [6, 5, 4]
        object_reference = '_params'
        engine_action._persistence_mode = 'local'
        engine_action._save_obj(object_reference, obj)

        assert obj == engine_action._params
        # The file name drops the attribute's leading underscore.
        assert os.path.exists("/tmp/.marvin/test_base_action/params")
        assert list(engine_action._local_saved_objects.keys()) == [object_reference]

    def test_release_saved_objects(self, engine_action):
        """Releasing locally saved objects resets their attributes to None."""
        obj = [6, 5, 4]
        object_reference = '_params'
        engine_action._persistence_mode = 'local'
        engine_action._save_obj(object_reference, obj)

        assert list(engine_action._local_saved_objects.keys()) == [object_reference]
        engine_action._release_local_saved_objects()
        assert engine_action._params is None

    def test_save_two_times(self, engine_action):
        """Saving the same reference twice raises MultipleAssignException."""
        obj = [6, 5, 4]
        object_reference = '_params'
        engine_action._persistence_mode = 'local'

        engine_action._save_obj(object_reference, obj)
        # pytest.raises fails the test if nothing is raised. An ``assert False``
        # inside ``try/except Exception`` would be swallowed by the handler.
        with pytest.raises(Exception) as excinfo:
            engine_action._save_obj(object_reference, obj)

        assert str(excinfo.value.args[0]) == 'MultipleAssignException'
        assert str(excinfo.value.args[1]) == '_params'

    def test_load_obj_local_persistence(self, engine_action):
        """A second action loads a locally persisted object; in memory mode
        _load_obj leaves the current attribute value untouched."""
        engine_action2 = copy.copy(engine_action)

        obj = [6, 5, 4]
        object_reference = '_params'
        engine_action._persistence_mode = 'local'
        engine_action._save_obj(object_reference, obj)

        engine_action2._persistence_mode = 'local'
        engine_action2._load_obj(object_reference)

        assert obj == engine_action2._params

        engine_action2._persistence_mode = 'memory'
        engine_action2._params = [1]
        engine_action2._load_obj(object_reference)

        assert [1] == engine_action2._params

    def test_health_check_ok(self, engine_action):
        """Health check is OK when the requested artifact has been saved."""
        obj1_key = "obj1"
        engine_action._save_obj(obj1_key, "check")
        request = HealthCheckRequest(artifacts=obj1_key)
        expected_response = HealthCheckResponse(status=HealthCheckResponse.OK)
        response = engine_action._health_check(request=request, context=None)

        assert expected_response.status == response.status

    def test_health_check_ok_multiple(self, engine_action):
        """Health check is OK when every comma-separated artifact has been saved."""
        obj1_key = "obj1"
        engine_action._save_obj(obj1_key, "check")
        obj2_key = "obj2"
        engine_action._save_obj(obj2_key, "check")

        request = HealthCheckRequest(artifacts=obj1_key + "," + obj2_key)
        expected_response = HealthCheckResponse(status=HealthCheckResponse.OK)
        response = engine_action._health_check(request=request, context=None)

        assert expected_response.status == response.status

    def test_health_check_nok(self, engine_action):
        """Health check is NOK when any requested artifact is missing."""
        obj1_key = "obj1"
        request = HealthCheckRequest(artifacts=obj1_key)
        expected_response = HealthCheckResponse(status=HealthCheckResponse.NOK)
        response = engine_action._health_check(request=request, context=None)

        assert expected_response.status == response.status

        engine_action._save_obj(obj1_key, "check")
        request = HealthCheckRequest(artifacts=obj1_key + ", obj2")
        response = engine_action._health_check(request=request, context=None)

        assert expected_response.status == response.status

    def test_health_check_exception(self):
        """An error raised while reading an artifact makes the health check NOK."""
        class BadEngineAction(EngineBaseAction):
            def execute(self, params, **kwargs):
                return 1

            def __getattribute__(self, name):
                # Simulate an artifact attribute whose access blows up.
                if name == 'obj1':
                    raise Exception('I am Bad!')
                else:
                    return EngineBaseAction.__getattribute__(self, name)

        engine_action = BadEngineAction()

        obj1_key = "obj1"
        request = HealthCheckRequest(artifacts=obj1_key)
        expected_response = HealthCheckResponse(status=HealthCheckResponse.NOK)
        response = engine_action._health_check(request=request, context=None)

        assert expected_response.status == response.status

    def test_remote_execute_with_string_response(self):
        """A string result is returned as the response message unchanged."""
        class StringReturnedAction(EngineBaseOnlineAction):
            def execute(self, input_message, params, **kwargs):
                return "message 1"

        request = OnlineActionRequest(message="{\"k\": 1}", params="{\"k\": 1}")
        engine_action = StringReturnedAction()
        response = engine_action._remote_execute(request=request, context=None)

        assert response.message == "message 1"

    def test_remote_execute_with_int_response(self):
        """An int result is rendered as its text form."""
        class StringReturnedAction(EngineBaseOnlineAction):
            def execute(self, input_message, params, **kwargs):
                return 1

        request = OnlineActionRequest(message="{\"k\": 1}", params="{\"k\": 1}")
        engine_action = StringReturnedAction()
        response = engine_action._remote_execute(request=request, context=None)

        assert response.message == "1"

    def test_remote_execute_with_object_response(self):
        """A dict result is rendered as JSON text."""
        class StringReturnedAction(EngineBaseOnlineAction):
            def execute(self, input_message, params, **kwargs):
                return {"r": 1}

        request = OnlineActionRequest(message="{\"k\": 1}", params="{\"k\": 1}")
        engine_action = StringReturnedAction()
        response = engine_action._remote_execute(request=request, context=None)

        assert response.message == "{\"r\": 1}"

    def test_remote_execute_with_list_response(self):
        """A list result is rendered as JSON text."""
        class StringReturnedAction(EngineBaseOnlineAction):
            def execute(self, input_message, params, **kwargs):
                return [1, 2]

        request = OnlineActionRequest(message="{\"k\": 1}", params="{\"k\": 1}")
        engine_action = StringReturnedAction()
        response = engine_action._remote_execute(request=request, context=None)

        assert response.message == "[1, 2]"

    @mock.patch('marvin_python_toolbox.engine_base.engine_base_action.EngineBaseAction._load_obj')
    def test_remote_reload_with_artifacts(self, load_obj_mocked, engine_action):
        """Reload force-loads each requested artifact and reports "Reloaded"."""
        objs_key = "obj1"
        engine_action._save_obj(objs_key, "check")
        request = ReloadRequest(artifacts=objs_key, protocol='xyz')

        response = engine_action._remote_reload(request, None)
        load_obj_mocked.assert_called_once_with(force=True, object_reference=u'obj1')
        assert response.message == "Reloaded"

    @mock.patch('marvin_python_toolbox.engine_base.engine_base_action.EngineBaseAction._load_obj')
    def test_remote_reload_without_artifacts(self, load_obj_mocked, engine_action):
        """Reload without artifacts loads nothing and says so."""
        request = ReloadRequest(artifacts=None, protocol='xyz')

        response = engine_action._remote_reload(request, None)
        load_obj_mocked.assert_not_called()
        assert response.message == "Nothing to reload"

    def test_load_obj_dont_reload_without_force(self, engine_action):
        """An already-loaded object is only re-read from disk when force=True."""
        obj = [6, 5, 4]
        object_reference = '_params'
        engine_action._persistence_mode = 'local'
        engine_action._save_obj(object_reference, obj)

        assert obj == engine_action._params

        # Overwrite the persisted file behind the action's back.
        new_obj = [1, 2, 3]
        path = engine_action._get_object_file_path(object_reference)
        engine_action._serializer_dump(new_obj, path)
        engine_action._load_obj(object_reference)

        assert obj == engine_action._params

        engine_action._load_obj(object_reference, force=True)

        assert new_obj == engine_action._params
| |
| |
class TestEngineBaseBatchAction:
    """Tests for EngineBaseBatchAction pipeline and remote execution, and for
    JSON (de)serialization of the ``_metrics`` artifact."""

    def setup_method(self, method):
        """Start every test from an empty ``/tmp/.marvin`` artifact root."""
        # ``setup_method`` is pytest's xunit-style per-test hook; the nose-style
        # ``setup`` name is deprecated since pytest 7.2 and ignored by pytest 8.
        shutil.rmtree("/tmp/.marvin", ignore_errors=True)

    def test_pipeline_execute_without_previous_steps(self, batch_engine_action):
        """With no previous step, _pipeline_execute only calls execute(params)."""
        batch_engine_action.execute = mock.MagicMock()
        batch_engine_action._pipeline_execute(params=123)

        batch_engine_action.execute.assert_called_once_with(123)

    def test_pipeline_execute_with_previous_steps(self, batch_engine_action):
        """With a previous step, both its pipeline and this step's execute are
        called once with the same params."""
        previous = copy.copy(batch_engine_action)
        previous._pipeline_execute = mock.MagicMock()
        batch_engine_action._previous_step = previous
        batch_engine_action.execute = mock.MagicMock()

        batch_engine_action._pipeline_execute(params=123)

        previous._pipeline_execute.assert_called_once_with(123)
        batch_engine_action.execute.assert_called_once_with(123)

    def test_remote_execute_without_request_params(self, batch_engine_action):
        """A request without params falls back to the action's own _params."""
        batch_engine_action._params = 123
        batch_engine_action._pipeline_execute = mock.MagicMock()

        request = BatchActionRequest()
        batch_engine_action._remote_execute(request, None)

        batch_engine_action._pipeline_execute.assert_called_once_with(params=123)

    def test_remote_execute_with_request_params(self, batch_engine_action):
        """JSON params in the request are decoded and used instead of _params."""
        batch_engine_action._params = 123
        batch_engine_action._pipeline_execute = mock.MagicMock()

        request = BatchActionRequest(params='{"test": 123}')
        batch_engine_action._remote_execute(request, None)

        batch_engine_action._pipeline_execute.assert_called_once_with(params={u"test": 123})

    # NOTE: the two tests below exercise EngineBaseAction, not the batch action.
    @mock.patch("json.load")
    def test__serializer_load_metrics(self, mocked_load):
        """Loading ``_metrics`` in local mode goes through json.load."""
        obj = {"key": 1}
        mocked_load.return_value = obj
        object_reference = "_metrics"

        class _EAction(EngineBaseAction):
            _metrics = None

            def execute(self, params, **kwargs):
                pass

        mocked_open = mock.mock_open()
        # Replace the module's ``open`` so no real file is opened.
        with mock.patch('marvin_python_toolbox.engine_base.engine_base_action.open', mocked_open, create=False):
            loaded = _EAction(default_root_path="/tmp/.marvin", persistence_mode="local")._load_obj(object_reference)

        mocked_load.assert_called_once_with(mock.ANY)
        mocked_open.assert_called_once()
        assert obj == loaded

    @mock.patch("json.dump")
    def test__serializer_dump_metrics(self, mocked_dump):
        """Saving ``_metrics`` in local mode goes through json.dump with
        indent=4, ', '/': ' separators and sorted keys."""
        obj = {"key": 1}
        object_reference = "_metrics"

        class _EAction(EngineBaseAction):
            _metrics = None

            def execute(self, params, **kwargs):
                pass

        mocked_open = mock.mock_open()
        # Replace the module's ``open`` so no real file is opened.
        with mock.patch('marvin_python_toolbox.engine_base.engine_base_action.open', mocked_open, create=False):
            _EAction(default_root_path="/tmp/.marvin", persistence_mode="local")._save_obj(object_reference, obj)

        mocked_dump.assert_called_once_with(obj, mock.ANY, indent=4, separators=(u',', u': '), sort_keys=True)
        mocked_open.assert_called_once()
| |
| |