Merge pull request #3 from rafaelnovello/bugfix-reload-action
Ensure artifact reload works. Resolves issue #MARVIN-11.
diff --git a/python-toolbox/marvin_python_toolbox/engine_base/engine_base_action.py b/python-toolbox/marvin_python_toolbox/engine_base/engine_base_action.py
index daead75..fcabdb1 100644
--- a/python-toolbox/marvin_python_toolbox/engine_base/engine_base_action.py
+++ b/python-toolbox/marvin_python_toolbox/engine_base/engine_base_action.py
@@ -95,6 +95,7 @@
self._local_saved_objects[object_reference] = object_file_path
def _load_obj(self, object_reference, force=False):
+ object_reference = object_reference if object_reference.startswith('_') else '_%s' % object_reference
if (getattr(self, object_reference, None) is None and self._persistence_mode == 'local') or force:
object_file_path = self._get_object_file_path(object_reference)
logger.info("Loading object from {}".format(object_file_path))
diff --git a/python-toolbox/tests/engine_base/test_engine_base_prediction.py b/python-toolbox/tests/engine_base/test_engine_base_prediction.py
index bba43a3..4e3c3fe 100644
--- a/python-toolbox/tests/engine_base/test_engine_base_prediction.py
+++ b/python-toolbox/tests/engine_base/test_engine_base_prediction.py
@@ -16,6 +16,10 @@
# limitations under the License.
import pytest
+try:
+ import mock
+except ImportError:
+ import unittest.mock as mock
from marvin_python_toolbox.engine_base import EngineBasePrediction
@@ -26,7 +30,10 @@
def execute(self, **kwargs):
return 1
- return EngineAction(default_root_path="/tmp/.marvin")
+ return EngineAction(
+ default_root_path="/tmp/.marvin",
+ persistence_mode="local"
+ )
class TestEngineBasePrediction:
@@ -38,3 +45,49 @@
     def test_metrics(self, engine_action):
         engine_action.marvin_metrics = [3]
         assert engine_action.marvin_metrics == engine_action._metrics == [3]
+
+
+class TestEnsureReloadActionReplaceObjectAttr:
+    """Regression tests for MARVIN-11: loading and reloading artifacts.
+
+    ``_load_obj`` is called with the bare artifact name (``"model"``); the
+    loaded object must land on ``_model`` and be visible via ``marvin_model``.
+    """
+
+    # Patch target for the deserialization hook; each test sets its return value.
+    _SERIALIZER_LOAD = (
+        'marvin_python_toolbox.engine_base.engine_base_action.'
+        'EngineBaseAction._serializer_load'
+    )
+
+    @mock.patch(_SERIALIZER_LOAD)
+    def test_first_load_from_artifact_works(self, mock_serializer, engine_action):
+        """A first ``_load_obj("model")`` populates ``_model``/``marvin_model``."""
+        mock_serializer.return_value = "MOCKED"
+
+        assert engine_action._model is None
+
+        engine_action._load_obj(object_reference="model")
+
+        assert engine_action._model == engine_action.marvin_model == "MOCKED"
+
+    @mock.patch(_SERIALIZER_LOAD)
+    def test_reload_works_before_first_load(self, mock_serializer, engine_action):
+        """``force=True`` replaces an artifact that is already loaded.
+
+        Despite the method name, the forced reload happens *after* a first load.
+        """
+        mock_serializer.return_value = "MOCKED"
+
+        assert engine_action._model is None
+
+        engine_action._load_obj(object_reference="model")
+
+        assert engine_action._model == engine_action.marvin_model == "MOCKED"
+
+        # ``_model`` is no longer None, so only a forced load re-reads it.
+        mock_serializer.return_value = "NEW MOCKED"
+
+        engine_action._load_obj(object_reference="model", force=True)
+
+        assert engine_action._model == engine_action.marvin_model == "NEW MOCKED"