| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import copy |
| import json |
| import os |
| |
| import sqlalchemy.event |
| |
| from ..modeling import models as _models |
| from ..storage.exceptions import StorageError |
| |
| |
| _VERSION_ID_COL = 'version' |
| _STUB = object() |
| _INSTRUMENTED = { |
| 'modified': { |
| _models.Node.runtime_properties: dict, |
| _models.Node.state: str, |
| _models.Task.status: str, |
| }, |
| 'new': (_models.Log, ) |
| |
| } |
| |
| _NEW_INSTANCE = 'NEW_INSTANCE' |
| |
| |
| def track_changes(model=None, instrumented=None): |
| """Track changes in the specified model columns |
| |
| This call will register event listeners using sqlalchemy's event mechanism. The listeners |
| instrument all returned objects such that the attributes specified in ``instrumented``, will |
| be replaced with a value that is stored in the returned instrumentation context |
| ``tracked_changes`` property. |
| |
| Why should this be implemented when sqlalchemy already does a fantastic job at tracking changes |
| you ask? Well, when sqlalchemy is used with sqlite, due to how sqlite works, only one process |
| can hold a write lock to the database. This does not work well when ARIA runs tasks in |
| subprocesses (by the process executor) and these tasks wish to change some state as well. These |
| tasks certainly deserve a chance to do so! |
| |
| To enable this, the subprocess calls ``track_changes()`` before any state changes are made. |
| At the end of the subprocess execution, it should return the ``tracked_changes`` attribute of |
| the instrumentation context returned from this call, to the parent process. The parent process |
| will then call ``apply_tracked_changes()`` that resides in this module as well. |
| At that point, the changes will actually be written back to the database. |
| |
| :param model: the model storage. it should hold a mapi for each model. the session of each mapi |
| is needed to setup events |
| :param instrumented: A dict from model columns to their python native type |
| :return: The instrumentation context |
| """ |
| return _Instrumentation(model, instrumented or _INSTRUMENTED) |
| |
| |
| class _Instrumentation(object): |
| |
| def __init__(self, model, instrumented): |
| self.tracked_changes = {} |
| self.new_instances = {} |
| self.listeners = [] |
| self._instances_to_expunge = [] |
| self._model = model |
| self._track_changes(instrumented) |
| |
| @property |
| def _new_instance_id(self): |
| return '{prefix}_{index}'.format(prefix=_NEW_INSTANCE, |
| index=len(self._instances_to_expunge)) |
| |
| def expunge_session(self): |
| for new_instance in self._instances_to_expunge: |
| self._get_session_from_model(new_instance.__tablename__).expunge(new_instance) |
| |
| def _get_session_from_model(self, tablename): |
| mapi = getattr(self._model, tablename, None) |
| if mapi: |
| return mapi._session |
| raise StorageError("Could not retrieve session for {0}".format(tablename)) |
| |
| def _track_changes(self, instrumented): |
| instrumented_attribute_classes = {} |
| # Track any newly-set attributes. |
| for instrumented_attribute, attribute_type in instrumented.get('modified', {}).items(): |
| self._register_set_attribute_listener( |
| instrumented_attribute=instrumented_attribute, |
| attribute_type=attribute_type) |
| instrumented_class = instrumented_attribute.parent.entity |
| instrumented_class_attributes = instrumented_attribute_classes.setdefault( |
| instrumented_class, {}) |
| instrumented_class_attributes[instrumented_attribute.key] = attribute_type |
| |
| # Track any global instance update such as 'refresh' or 'load' |
| for instrumented_class, instrumented_attributes in instrumented_attribute_classes.items(): |
| self._register_instance_listeners(instrumented_class=instrumented_class, |
| instrumented_attributes=instrumented_attributes) |
| |
| # Track any newly created instances. |
| for instrumented_class in instrumented.get('new', {}): |
| self._register_new_instance_listener(instrumented_class) |
| |
| def _register_new_instance_listener(self, instrumented_class): |
| if self._model is None: |
| raise StorageError("In order to keep track of new instances, a ctx is needed") |
| |
| def listener(_, instance): |
| if not isinstance(instance, instrumented_class): |
| return |
| self._instances_to_expunge.append(instance) |
| tracked_instances = self.new_instances.setdefault(instance.__modelname__, {}) |
| tracked_attributes = tracked_instances.setdefault(self._new_instance_id, {}) |
| instance_as_dict = instance.to_dict() |
| instance_as_dict.update((k, getattr(instance, k)) |
| for k in getattr(instance, '__private_fields__', [])) |
| tracked_attributes.update(instance_as_dict) |
| session = self._get_session_from_model(instrumented_class.__tablename__) |
| listener_args = (session, 'after_attach', listener) |
| sqlalchemy.event.listen(*listener_args) |
| self.listeners.append(listener_args) |
| |
| def _register_set_attribute_listener(self, instrumented_attribute, attribute_type): |
| def listener(target, value, *_): |
| mapi_name = target.__modelname__ |
| tracked_instances = self.tracked_changes.setdefault(mapi_name, {}) |
| tracked_attributes = tracked_instances.setdefault(target.id, {}) |
| if value is None: |
| current = None |
| else: |
| current = copy.deepcopy(attribute_type(value)) |
| tracked_attributes[instrumented_attribute.key] = _Value(_STUB, current) |
| return current |
| listener_args = (instrumented_attribute, 'set', listener) |
| sqlalchemy.event.listen(*listener_args, retval=True) |
| self.listeners.append(listener_args) |
| |
| def _register_instance_listeners(self, instrumented_class, instrumented_attributes): |
| def listener(target, *_): |
| mapi_name = instrumented_class.__modelname__ |
| tracked_instances = self.tracked_changes.setdefault(mapi_name, {}) |
| tracked_attributes = tracked_instances.setdefault(target.id, {}) |
| if hasattr(target, _VERSION_ID_COL): |
| # We want to keep track of the initial version id so it can be compared |
| # with the committed version id when the tracked changes are applied |
| tracked_attributes.setdefault(_VERSION_ID_COL, |
| _Value(_STUB, getattr(target, _VERSION_ID_COL))) |
| for attribute_name, attribute_type in instrumented_attributes.items(): |
| if attribute_name not in tracked_attributes: |
| initial = getattr(target, attribute_name) |
| if initial is None: |
| current = None |
| else: |
| current = copy.deepcopy(attribute_type(initial)) |
| tracked_attributes[attribute_name] = _Value(initial, current) |
| target.__dict__[attribute_name] = tracked_attributes[attribute_name].current |
| for listener_args in ((instrumented_class, 'load', listener), |
| (instrumented_class, 'refresh', listener), |
| (instrumented_class, 'refresh_flush', listener)): |
| sqlalchemy.event.listen(*listener_args) |
| self.listeners.append(listener_args) |
| |
| def clear(self, target=None): |
| if target: |
| mapi_name = target.__modelname__ |
| tracked_instances = self.tracked_changes.setdefault(mapi_name, {}) |
| tracked_instances.pop(target.id, None) |
| else: |
| self.tracked_changes.clear() |
| |
| self.new_instances.clear() |
| self._instances_to_expunge = [] |
| |
| def restore(self): |
| """Remove all listeners registered by this instrumentation""" |
| for listener_args in self.listeners: |
| if sqlalchemy.event.contains(*listener_args): |
| sqlalchemy.event.remove(*listener_args) |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| self.restore() |
| |
| |
| class _Value(object): |
| # You may wonder why is this a full blown class and not a named tuple. The reason is that |
| # jsonpickle that is used to serialize the tracked_changes, does not handle named tuples very |
| # well. At the very least, I could not get it to behave. |
| |
| def __init__(self, initial, current): |
| self.initial = initial |
| self.current = current |
| |
| def __eq__(self, other): |
| if not isinstance(other, _Value): |
| return False |
| return self.initial == other.initial and self.current == other.current |
| |
| def __hash__(self): |
| return hash((self.initial, self.current)) |
| |
| @property |
| def dict(self): |
| return {'initial': self.initial, 'current': self.current}.copy() |
| |
| |
| def apply_tracked_changes(tracked_changes, new_instances, model): |
| """Write tracked changes back to the database using provided model storage |
| |
| :param tracked_changes: The ``tracked_changes`` attribute of the instrumentation context |
| returned by calling ``track_changes()`` |
| :param model: The model storage used to actually apply the changes |
| """ |
| successfully_updated_changes = dict() |
| try: |
| # handle instance updates |
| for mapi_name, tracked_instances in tracked_changes.items(): |
| successfully_updated_changes[mapi_name] = dict() |
| mapi = getattr(model, mapi_name) |
| for instance_id, tracked_attributes in tracked_instances.items(): |
| successfully_updated_changes[mapi_name][instance_id] = dict() |
| instance = None |
| for attribute_name, value in tracked_attributes.items(): |
| if value.initial != value.current: |
| instance = instance or mapi.get(instance_id) |
| setattr(instance, attribute_name, value.current) |
| if instance: |
| _validate_version_id(instance, mapi) |
| mapi.update(instance) |
| successfully_updated_changes[mapi_name][instance_id] = [ |
| v.dict for v in tracked_attributes.values()] |
| |
| # Handle new instances |
| for mapi_name, new_instance in new_instances.items(): |
| successfully_updated_changes[mapi_name] = dict() |
| mapi = getattr(model, mapi_name) |
| for new_instance_kwargs in new_instance.values(): |
| instance = mapi.model_cls(**new_instance_kwargs) |
| mapi.put(instance) |
| successfully_updated_changes[mapi_name][instance.id] = new_instance_kwargs |
| except BaseException: |
| for key, value in successfully_updated_changes.items(): |
| if not value: |
| del successfully_updated_changes[key] |
| # TODO: if the successful has _STUB, the logging fails because it can't serialize the object |
| model.logger.error( |
| 'Registering all the changes to the storage has failed. {0}' |
| 'The successful updates were: {0} ' |
| '{1}'.format(os.linesep, json.dumps(successfully_updated_changes, indent=4))) |
| |
| raise |
| |
| |
| def _validate_version_id(instance, mapi): |
| version_id = sqlalchemy.inspect(instance).committed_state.get(_VERSION_ID_COL) |
| # There are two version conflict code paths: |
| # 1. The instance committed state loaded already holds a newer version, |
| # in this case, we manually raise the error |
| # 2. The UPDATE statement is executed with version validation and sqlalchemy |
| # will raise a StateDataError if there is a version mismatch. |
| if version_id and getattr(instance, _VERSION_ID_COL) != version_id: |
| object_version_id = getattr(instance, _VERSION_ID_COL) |
| mapi._session.rollback() |
| raise StorageError( |
| 'Version conflict: committed and object {0} differ ' |
| '[committed {0}={1}, object {0}={2}]' |
| .format(_VERSION_ID_COL, |
| version_id, |
| object_version_id)) |