| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| import logging |
| import uuid |
| from contextlib import contextmanager |
| |
| import aria |
| from aria.modeling import models |
| from aria.orchestrator.context.common import BaseContext |
| |
| |
class MockContext(object):
    """Minimal stand-in for an operation context.

    It exposes a logger, a ``MockTask`` built from ``task_kwargs``, a model
    storage (real or a ``MagicMock``), and the ``serialization_dict`` /
    ``instantiate_from_dict`` pair used to rebuild an equivalent context from
    plain keyword arguments.
    """

    INSTRUMENTATION_FIELDS = BaseContext.INSTRUMENTATION_FIELDS

    def __init__(self, storage=None, task_kwargs=None):
        """Build the mock context.

        :param storage: model storage to expose through ``model``; when falsy,
            a ``mock.MagicMock`` is used instead.
        :param task_kwargs: keyword arguments forwarded to ``MockTask``;
            ``None`` is treated as an empty dict.
        """
        self.logger = logging.getLogger('mock_logger')
        self._task_kwargs = task_kwargs or {}
        # Imported at call time, exactly as the original did.
        import mock
        self._storage = storage or mock.MagicMock()
        # Only a caller-supplied storage has serialization kwargs to record;
        # the MagicMock fallback is recorded as ``None``.
        self._storage_kwargs = self._storage.serialization_dict if storage else None
        # Unpack the normalised dict: unpacking the raw ``task_kwargs`` fails
        # with "argument after ** must be a mapping" when the ``None`` default
        # is used, masking any real problem in the task arguments.
        self.task = MockTask(storage, **self._task_kwargs)
        self.states = []
        self.exception = None

    @property
    def serialization_dict(self):
        """Return the class and kwargs that ``instantiate_from_dict`` accepts."""
        return {
            'context_cls': self.__class__,
            'context': {
                'storage_kwargs': self._storage_kwargs,
                'task_kwargs': self._task_kwargs
            }
        }

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup fails: any attribute this
        # mock does not define reads as ``None`` instead of raising.
        return None

    def close(self):
        """No-op."""
        pass

    @property
    def model(self):
        """Return the storage given at construction (or the MagicMock fallback)."""
        return self._storage

    @classmethod
    def instantiate_from_dict(cls, storage_kwargs=None, task_kwargs=None):
        """Create a context from the ``'context'`` part of ``serialization_dict``.

        :param storage_kwargs: keyword arguments for
            ``aria.application_model_storage``; when falsy no storage is built
            and the context falls back to its mock storage.
        :param task_kwargs: keyword arguments forwarded to ``MockTask``.
        """
        if not storage_kwargs:
            return cls(task_kwargs=task_kwargs)
        return cls(storage=aria.application_model_storage(**storage_kwargs),
                   task_kwargs=task_kwargs)
| |
| |
class MockActor(object):
    """Placeholder actor whose only attribute is a fixed ``name``."""

    def __init__(self):
        fixed_name = 'actor_name'
        self.name = fixed_name
| |
| |
class MockTask(object):
    """Plain-attribute stand-in for a task model.

    Each instance gets a fresh UUID string as ``id``, fixed retry settings and
    interface/operation names, a ``MockActor`` exposed as both ``actor`` and
    ``node``, and one upper-cased attribute per entry of
    ``models.Task.STATES`` holding the state value itself.
    """

    INFINITE_RETRIES = models.Task.INFINITE_RETRIES

    def __init__(self, model, function, arguments=None, plugin_fk=None):
        """Populate the task attributes.

        :param model: object stored as ``self.model``; ``plugin`` reads
            ``model.plugin`` from it.
        :param function: stored as both ``function`` and ``name``.
        :param arguments: task arguments; a falsy value becomes ``{}``.
        :param plugin_fk: key used by ``plugin`` to look the plugin up.
        """
        # Identity and invocation details.
        self.id = str(uuid.uuid4())
        self.function = function
        self.name = function
        self.arguments = arguments if arguments else {}
        self.plugin_fk = plugin_fk
        self.model = model

        # Execution bookkeeping.
        self.states = []
        self.exception = None
        self.attempts_count = 1
        self.max_attempts = 1
        self.ignore_failure = False
        self.logger = logging.getLogger()

        # Fixed interface/operation names and a shared actor object.
        self.interface_name = 'interface_name'
        self.operation_name = 'operation_name'
        mock_actor = MockActor()
        self.actor = mock_actor
        self.node = mock_actor

        # Expose every task state as an upper-cased instance attribute.
        for state_name in models.Task.STATES:
            setattr(self, state_name.upper(), state_name)

    @property
    def plugin(self):
        """Return ``model.plugin.get(plugin_fk)``, or ``None`` if ``plugin_fk`` is falsy."""
        if not self.plugin_fk:
            return None
        return self.model.plugin.get(self.plugin_fk)