blob: 975ada7028ad9baf36586879b7c77e5a47cfc7c0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SQLAlchemy implementation of the storage model API ("MAPI").
"""
import os
import platform
from sqlalchemy import (
create_engine,
orm,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.exc import StaleDataError
from aria.utils.collections import OrderedDict
from . import (
api,
exceptions,
collection_instrumentation
)
_predicates = {'ge': '__ge__',
'gt': '__gt__',
'lt': '__lt__',
'le': '__le__',
'eq': '__eq__',
'ne': '__ne__'}
class SQLAlchemyModelAPI(api.ModelAPI):
    """
    SQLAlchemy implementation of the storage model API ("MAPI").
    """

    def __init__(self,
                 engine,
                 session,
                 **kwargs):
        """
        :param engine: SQLAlchemy engine, used when creating and dropping tables
        :param session: SQLAlchemy session, used for querying and committing
        :param kwargs: forwarded as-is to the base ``ModelAPI`` initializer
        """
        super(SQLAlchemyModelAPI, self).__init__(**kwargs)
        self._engine = engine
        self._session = session

    def get(self, entry_id, include=None, **kwargs):
        """
        Returns a single result based on the model class and element ID.

        :param entry_id: value to match against the ``id`` field of the model
        :param include: optional list of column names to include in the query
        :return: the (possibly instrumented) first matching result
        :raises: :class:`exceptions.NotFoundError` if nothing matches
        """
        query = self._get_query(include, {'id': entry_id})
        result = query.first()
        # ``first()`` signals "no row" with None; testing identity keeps a result that
        # merely evaluates as false from being reported as missing.
        if result is None:
            raise exceptions.NotFoundError(
                'Requested `{0}` with ID `{1}` was not found'
                .format(self.model_cls.__name__, entry_id)
            )
        return self._instrument(result)

    def get_by_name(self, entry_name, include=None, **kwargs):
        """
        Returns the single result whose ``name`` field equals ``entry_name``.

        :param entry_name: value to match against the ``name`` field of the model
        :param include: optional list of column names to include in the query
        :raises: :class:`exceptions.NotFoundError` if nothing matches
        :raises: :class:`exceptions.StorageError` if more than one result matches
        """
        assert hasattr(self.model_cls, 'name')
        result = self.list(include=include, filters={'name': entry_name})
        if not result:
            raise exceptions.NotFoundError(
                'Requested {0} with name `{1}` was not found'
                .format(self.model_cls.__name__, entry_name)
            )
        elif len(result) > 1:
            raise exceptions.StorageError(
                'Requested {0} with name `{1}` returned more than 1 value'
                .format(self.model_cls.__name__, entry_name)
            )
        else:
            return result[0]

    def list(self,
             include=None,
             filters=None,
             pagination=None,
             sort=None,
             **kwargs):
        """
        Returns a (possibly empty) :class:`ListResult` of results.

        :param include: optional list of column names to include in the query
        :param filters: optional dictionary where keys are column names to filter by, and values
         are values applicable for those columns (or lists of such values, or dictionaries
         mapping a predicate name to its operand)
        :param pagination: optional dict with ``size`` and ``offset`` keys
        :param sort: optional dictionary where keys are column names to sort by, and values are
         the order (asc/desc)
        :return: the results, with ``total``, ``size`` and ``offset`` stored as their metadata
        """
        query = self._get_query(include, filters, sort)
        results, total, size, offset = self._paginate(query, pagination)
        return ListResult(
            dict(total=total, size=size, offset=offset),
            [self._instrument(result) for result in results]
        )

    def iter(self,
             include=None,
             filters=None,
             sort=None,
             **kwargs):
        """
        Yields the (possibly instrumented) results one at a time, without building a list.

        Takes the same ``include``, ``filters`` and ``sort`` arguments as :meth:`list`.
        """
        for result in self._get_query(include, filters, sort):
            yield self._instrument(result)

    def put(self, entry, **kwargs):
        """
        Adds ``entry`` to the database session, and attempts to commit.

        :param entry: instance to add to the session
        :return: the same ``entry``
        :raises: :class:`exceptions.StorageError` if the commit fails
        """
        self._session.add(entry)
        self._safe_commit()
        return entry

    def delete(self, entry, **kwargs):
        """
        Deletes ``entry`` from the database session, and attempts to commit.

        :param entry: instance to delete
        :return: the deleted ``entry``
        :raises: :class:`exceptions.StorageError` if the commit fails
        """
        # Touch the relationships before deleting (see ``_load_relationships``)
        self._load_relationships(entry)
        self._session.delete(entry)
        self._safe_commit()
        return entry

    def update(self, entry, **kwargs):
        """
        Adds ``entry`` to the database session, and attempts to commit.

        :return: updated instance
        """
        return self.put(entry)

    def refresh(self, entry):
        """
        Reloads the instance with fresh information from the database.

        :param entry: instance to be re-loaded from the database
        :return: refreshed instance
        """
        self._session.refresh(entry)
        self._load_relationships(entry)
        return entry

    def _destroy_connection(self):
        # Nothing to do in this implementation
        pass

    def _establish_connection(self):
        # Nothing to do in this implementation
        pass

    def create(self, checkfirst=True, create_all=True, **kwargs):
        """
        Creates the table of the model class.

        :param checkfirst: passed on to the SQLAlchemy ``create`` / ``create_all`` calls
        :param create_all: whether to also create every table known to the model's metadata
        """
        self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)

        if create_all:
            # In order to create any models created dynamically (e.g. many-to-many helper tables
            # are created at runtime).
            self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)

    def drop(self):
        """
        Drops the table.
        """
        self.model_cls.__table__.drop(self._engine)

    def _safe_commit(self):
        """
        Tries to commit the changes in the session. If a SQLAlchemy error (or a ``ValueError``)
        is raised, rolls the session back and re-raises it as a storage error.

        :raises: :class:`exceptions.StorageError` if the commit fails
        """
        try:
            self._session.commit()
        except StaleDataError as e:
            # Must be caught before the more general SQLAlchemyError clause below
            self._session.rollback()
            raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
        except (SQLAlchemyError, ValueError) as e:
            self._session.rollback()
            raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))

    def _get_base_query(self, include, joins):
        """
        Creates the initial query from the model class and included columns.

        :param include: (possibly empty) list of columns to include in the query
        :param joins: join targets, applied to the query in the given order
        :return: SQLAlchemy query object
        """
        # If only some columns are included, query through the session object
        if include:
            # Make sure that attributes come before association proxies. A sorted copy is used
            # so that the list handed in by the caller is left untouched.
            ordered_include = sorted(include, key=lambda x: x.is_clause_element)
            query = self._session.query(*ordered_include)
        else:
            # If all columns should be returned, query directly from the model
            query = self._session.query(self.model_cls)

        query = query.join(*joins)
        return query

    @staticmethod
    def _get_joins(model_class, columns):
        """
        Gets a list of all the tables on which we need to join.

        :param model_class: class on which the column names are looked up
        :param columns: set of all attributes involved in the query
        :return: the distinct join attributes, in the order in which they were first met
        """
        # Using an ordered dict instead of a set because order is important
        joins = OrderedDict()
        for column_name in columns:
            column = getattr(model_class, column_name)
            while not column.is_attribute:
                join_attr = column.local_attr
                # This is a hack, to deal with the fact that SQLA doesn't
                # fully support doing something like: `if join_attr in joins`,
                # because some SQLA elements have their own comparators
                join_attr_name = str(join_attr)
                if join_attr_name not in joins:
                    joins[join_attr_name] = join_attr
                column = column.remote_attr

        return joins.values()

    @staticmethod
    def _sort_query(query, sort=None):
        """
        Adds sorting clauses to the query.

        :param query: base SQL query
        :param sort: optional dictionary where keys are columns to sort by, and values are
         the order (asc/desc)
        :return: SQLAlchemy query object
        """
        if sort:
            for column, order in sort.items():
                # Any order other than 'desc' leaves the column as is
                if order == 'desc':
                    column = column.desc()
                query = query.order_by(column)
        return query

    def _filter_query(self, query, filters):
        """
        Adds filter clauses to the query.

        :param query: base SQL query
        :param filters: dictionary where keys are columns to filter by, and values are values
         applicable for those columns (or lists of such values)
        :return: SQLAlchemy query object
        """
        return self._add_value_filter(query, filters)

    @staticmethod
    def _add_value_filter(query, filters):
        """
        Adds one filter clause to the query per entry in ``filters``.

        * a dict value maps method names to operands; each named method is called on the column
          with its operand
        * a list or tuple value becomes an ``in_`` clause
        * any other value becomes an equality clause
        """
        for column, value in filters.items():
            if isinstance(value, dict):
                for predicate, operand in value.items():
                    query = query.filter(getattr(column, predicate)(operand))
            elif isinstance(value, (list, tuple)):
                query = query.filter(column.in_(value))
            else:
                query = query.filter(column == value)

        return query

    def _get_query(self,
                   include=None,
                   filters=None,
                   sort=None):
        """
        Gets a SQL query object based on the params passed.

        :param include: optional list of columns to include in the query
        :param filters: optional dictionary where keys are column names to filter by, and values
         are values applicable for those columns (or lists of such values)
        :param sort: optional dictionary where keys are column names to sort by, and values are the
         order (asc/desc)
        :return: sorted and filtered query with only the relevant columns
        """
        include, filters, sort, joins = self._get_joins_and_converted_columns(
            include, filters, sort
        )
        filters = self._convert_operands(filters)

        query = self._get_base_query(include, joins)
        query = self._filter_query(query, filters)
        query = self._sort_query(query, sort)
        return query

    @staticmethod
    def _convert_operands(filters):
        """
        Replaces the predicate names used in dictionary-valued filters (e.g. ``ge``) with the
        names of the matching comparison methods (e.g. ``__ge__``).

        New dictionaries are built instead of editing the given ones: removing and adding keys
        of a dictionary while iterating over it is an error, and the nested dictionaries belong
        to the caller, who may want to use them again for another query.

        :param filters: dictionary of filters; only dictionary values are converted
        :return: new dictionary with the same keys and the converted values
        :raises: :class:`exceptions.StorageError` if an unknown predicate is used
        """
        converted = {}
        for column, conditions in filters.items():
            if not isinstance(conditions, dict):
                converted[column] = conditions
                continue

            operators = {}
            for predicate, operand in conditions.items():
                if predicate not in _predicates:
                    raise exceptions.StorageError(
                        "{0} is not a valid predicate for filtering. Valid predicates are {1}"
                        .format(predicate, ', '.join(_predicates.keys())))
                operators[_predicates[predicate]] = operand
            converted[column] = operators

        return converted

    def _get_joins_and_converted_columns(self,
                                         include,
                                         filters,
                                         sort):
        """
        Gets a list of tables on which we need to join and the converted ``include``, ``filters``
        and ``sort`` arguments (converted to actual SQLAlchemy column/label objects instead of
        column names).
        """
        include = include or []
        filters = filters or dict()
        sort = sort or OrderedDict()

        all_columns = set(include) | set(filters.keys()) | set(sort.keys())
        joins = self._get_joins(self.model_cls, all_columns)

        include, filters, sort = self._get_columns_from_field_names(
            include, filters, sort
        )
        return include, filters, sort, joins

    def _get_columns_from_field_names(self,
                                      include,
                                      filters,
                                      sort):
        """
        Goes over the optional parameters (include, filters, sort), and replaces column names with
        actual SQLAlchemy column objects.
        """
        include = [self._get_column(c) for c in include]
        filters = dict((self._get_column(c), filters[c]) for c in filters)
        sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)

        return include, filters, sort

    def _get_column(self, column_name):
        """
        Returns the column on which an action (filtering, sorting, etc.) would need to be performed.
        Can be either an attribute of the class, or an association proxy linked to a relationship
        in the class.
        """
        column = getattr(self.model_cls, column_name)
        if column.is_attribute:
            return column
        else:
            # We need to get to the underlying attribute, so we move on to the
            # next remote_attr until we reach one
            while not column.remote_attr.is_attribute:
                column = column.remote_attr
            # Put a label on the remote attribute with the name of the column
            return column.remote_attr.label(column_name)

    @staticmethod
    def _paginate(query, pagination):
        """
        Paginates the query by size and offset.

        :param query: current SQLAlchemy query object
        :param pagination: optional dict with size and offset keys
        :return: tuple with four elements:
         * results: ``size`` items starting from ``offset``
         * the total count of items
         * ``size`` [default: 0]
         * ``offset`` [default: 0]
        """
        if pagination:
            # NOTE: a missing ``size`` falls back to 0, which is handed to ``limit`` as is
            size = pagination.get('size', 0)
            offset = pagination.get('offset', 0)
            total = query.order_by(None).count()  # Fastest way to count
            results = query.limit(size).offset(offset).all()
            return results, total, size, offset
        else:
            results = query.all()
            return results, len(results), 0, 0

    @staticmethod
    def _load_relationships(instance):
        """
        Helper method used to overcome a problem where the relationships that rely on joins aren't
        being loaded automatically.
        """
        for rel in instance.__mapper__.relationships:
            # The value is discarded; the attribute is only accessed
            getattr(instance, rel.key)

    def _instrument(self, model):
        """
        Wraps ``model`` using ``collection_instrumentation`` if this API has any instrumentation
        set; otherwise returns ``model`` unchanged.
        """
        if self._instrumentation:
            return collection_instrumentation.instrument(self._instrumentation, model, self)
        else:
            return model
def init_storage(base_dir, filename='db.sqlite'):
    """
    Built-in ModelStorage initiator.

    Builds a SQLite URI for the file ``filename`` inside ``base_dir``, then creates a SQLAlchemy
    engine for it and a scoped session bound to that engine, both to be passed to the MAPI.

    ``initiator_kwargs`` must be passed to the ModelStorage which must hold the ``base_dir`` for the
    location of the database file, and an optional filename.

    :param base_dir: directory of the database
    :param filename: database file name.
    :return: dict with the ``engine`` and ``session`` keys
    """
    database_path = os.path.join(base_dir, filename)

    # Windows paths have no root, only drives, so no extra leading slash is added there;
    # on every other platform one is added in front of the path.
    if 'Windows' in platform.system():
        leading_char = ''
    else:
        leading_char = '/'

    uri = 'sqlite:///{platform_char}{path}'.format(platform_char=leading_char,
                                                   path=database_path)

    engine = create_engine(uri, connect_args={'timeout': 15})
    session = orm.scoped_session(session_factory=orm.sessionmaker(bind=engine))

    return {'engine': engine, 'session': session}
class ListResult(list):
    """
    A list of the requested items, which also carries metadata about the result.
    """

    def __init__(self, metadata, *args, **qwargs):
        """
        :param metadata: stored as is on the ``metadata`` attribute
        :param args: forwarded to the ``list`` initializer (e.g. the iterable of items)
        :param qwargs: forwarded to the ``list`` initializer
        """
        self.metadata = metadata
        super(ListResult, self).__init__(*args, **qwargs)
        # The result refers to itself, so the items can also be reached through ``.items``
        self.items = self