blob: bb6223a11ddcd57b0bc5873eeadf84d5ff863bb1 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SQLAlchemy based MAPI
"""
import os
import platform
from sqlalchemy import (
create_engine,
orm,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.exc import StaleDataError
from aria.utils.collections import OrderedDict
from . import (
api,
exceptions,
collection_instrumentation
)
_predicates = {'ge': '__ge__',
'gt': '__gt__',
'lt': '__lt__',
'le': '__le__',
'eq': '__eq__',
'ne': '__ne__'}
class SQLAlchemyModelAPI(api.ModelAPI):
"""
SQL based MAPI.
"""
def __init__(self,
engine,
session,
**kwargs):
super(SQLAlchemyModelAPI, self).__init__(**kwargs)
self._engine = engine
self._session = session
def get(self, entry_id, include=None, **kwargs):
"""Return a single result based on the model class and element ID
"""
query = self._get_query(include, {'id': entry_id})
result = query.first()
if not result:
raise exceptions.NotFoundError(
'Requested `{0}` with ID `{1}` was not found'
.format(self.model_cls.__name__, entry_id)
)
return self._instrument(result)
def get_by_name(self, entry_name, include=None, **kwargs):
assert hasattr(self.model_cls, 'name')
result = self.list(include=include, filters={'name': entry_name})
if not result:
raise exceptions.NotFoundError(
'Requested {0} with name `{1}` was not found'
.format(self.model_cls.__name__, entry_name)
)
elif len(result) > 1:
raise exceptions.StorageError(
'Requested {0} with name `{1}` returned more than 1 value'
.format(self.model_cls.__name__, entry_name)
)
else:
return result[0]
def list(self,
include=None,
filters=None,
pagination=None,
sort=None,
**kwargs):
query = self._get_query(include, filters, sort)
results, total, size, offset = self._paginate(query, pagination)
return ListResult(
dict(total=total, size=size, offset=offset),
[self._instrument(result) for result in results]
)
def iter(self,
include=None,
filters=None,
sort=None,
**kwargs):
"""Return a (possibly empty) list of `model_class` results
"""
for result in self._get_query(include, filters, sort):
yield self._instrument(result)
def put(self, entry, **kwargs):
"""Create a `model_class` instance from a serializable `model` object
:param entry: A dict with relevant kwargs, or an instance of a class
that has a `to_dict` method, and whose attributes match the columns
of `model_class` (might also my just an instance of `model_class`)
:return: An instance of `model_class`
"""
self._session.add(entry)
self._safe_commit()
return entry
def delete(self, entry, **kwargs):
"""Delete a single result based on the model class and element ID
"""
self._load_relationships(entry)
self._session.delete(entry)
self._safe_commit()
return entry
def update(self, entry, **kwargs):
"""Add `instance` to the DB session, and attempt to commit
:return: The updated instance
"""
return self.put(entry)
def refresh(self, entry):
"""Reload the instance with fresh information from the DB
:param entry: Instance to be re-loaded from the DB
:return: The refreshed instance
"""
self._session.refresh(entry)
self._load_relationships(entry)
return entry
def _destroy_connection(self):
pass
def _establish_connection(self):
pass
def create(self, checkfirst=True, create_all=True, **kwargs):
self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)
if create_all:
# In order to create any models created dynamically (e.g. many-to-many helper tables are
# created at runtime).
self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)
def drop(self):
"""
Drop the table from the storage.
:return:
"""
self.model_cls.__table__.drop(self._engine)
def _safe_commit(self):
"""Try to commit changes in the session. Roll back if exception raised
Excepts SQLAlchemy errors and rollbacks if they're caught
"""
try:
self._session.commit()
except StaleDataError as e:
self._session.rollback()
raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
except (SQLAlchemyError, ValueError) as e:
self._session.rollback()
raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
def _get_base_query(self, include, joins):
"""Create the initial query from the model class and included columns
:param include: A (possibly empty) list of columns to include in
the query
:return: An SQLAlchemy AppenderQuery object
"""
# If only some columns are included, query through the session object
if include:
# Make sure that attributes come before association proxies
include.sort(key=lambda x: x.is_clause_element)
query = self._session.query(*include)
else:
# If all columns should be returned, query directly from the model
query = self._session.query(self.model_cls)
query = query.join(*joins)
return query
@staticmethod
def _get_joins(model_class, columns):
"""Get a list of all the tables on which we need to join
:param columns: A set of all attributes involved in the query
"""
# Using a list instead of a set because order is important
joins = OrderedDict()
for column_name in columns:
column = getattr(model_class, column_name)
while not column.is_attribute:
join_attr = column.local_attr
# This is a hack, to deal with the fact that SQLA doesn't
# fully support doing something like: `if join_attr in joins`,
# because some SQLA elements have their own comparators
join_attr_name = str(join_attr)
if join_attr_name not in joins:
joins[join_attr_name] = join_attr
column = column.remote_attr
return joins.values()
@staticmethod
def _sort_query(query, sort=None):
"""Add sorting clauses to the query
:param query: Base SQL query
:param sort: An optional dictionary where keys are column names to
sort by, and values are the order (asc/desc)
:return: An SQLAlchemy AppenderQuery object
"""
if sort:
for column, order in sort.items():
if order == 'desc':
column = column.desc()
query = query.order_by(column)
return query
def _filter_query(self, query, filters):
"""Add filter clauses to the query
:param query: Base SQL query
:param filters: An optional dictionary where keys are column names to
filter by, and values are values applicable for those columns (or lists
of such values)
:return: An SQLAlchemy AppenderQuery object
"""
return self._add_value_filter(query, filters)
@staticmethod
def _add_value_filter(query, filters):
for column, value in filters.items():
if isinstance(value, dict):
for predicate, operand in value.items():
query = query.filter(getattr(column, predicate)(operand))
elif isinstance(value, (list, tuple)):
query = query.filter(column.in_(value))
else:
query = query.filter(column == value)
return query
def _get_query(self,
include=None,
filters=None,
sort=None):
"""Get an SQL query object based on the params passed
:param model_class: SQL DB table class
:param include: An optional list of columns to include in the query
:param filters: An optional dictionary where keys are column names to
filter by, and values are values applicable for those columns (or lists
of such values)
:param sort: An optional dictionary where keys are column names to
sort by, and values are the order (asc/desc)
:return: A sorted and filtered query with only the relevant
columns
"""
include, filters, sort, joins = self._get_joins_and_converted_columns(
include, filters, sort
)
filters = self._convert_operands(filters)
query = self._get_base_query(include, joins)
query = self._filter_query(query, filters)
query = self._sort_query(query, sort)
return query
@staticmethod
def _convert_operands(filters):
for column, conditions in filters.items():
if isinstance(conditions, dict):
for predicate, operand in conditions.items():
if predicate not in _predicates:
raise exceptions.StorageError(
"{0} is not a valid predicate for filtering. Valid predicates are {1}"
.format(predicate, ', '.join(_predicates.keys())))
del filters[column][predicate]
filters[column][_predicates[predicate]] = operand
return filters
def _get_joins_and_converted_columns(self,
include,
filters,
sort):
"""Get a list of tables on which we need to join and the converted
`include`, `filters` and `sort` arguments (converted to actual SQLA
column/label objects instead of column names)
"""
include = include or []
filters = filters or dict()
sort = sort or OrderedDict()
all_columns = set(include) | set(filters.keys()) | set(sort.keys())
joins = self._get_joins(self.model_cls, all_columns)
include, filters, sort = self._get_columns_from_field_names(
include, filters, sort
)
return include, filters, sort, joins
def _get_columns_from_field_names(self,
include,
filters,
sort):
"""Go over the optional parameters (include, filters, sort), and
replace column names with actual SQLA column objects
"""
include = [self._get_column(c) for c in include]
filters = dict((self._get_column(c), filters[c]) for c in filters)
sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)
return include, filters, sort
def _get_column(self, column_name):
"""Return the column on which an action (filtering, sorting, etc.)
would need to be performed. Can be either an attribute of the class,
or an association proxy linked to a relationship the class has
"""
column = getattr(self.model_cls, column_name)
if column.is_attribute:
return column
else:
# We need to get to the underlying attribute, so we move on to the
# next remote_attr until we reach one
while not column.remote_attr.is_attribute:
column = column.remote_attr
# Put a label on the remote attribute with the name of the column
return column.remote_attr.label(column_name)
@staticmethod
def _paginate(query, pagination):
"""Paginate the query by size and offset
:param query: Current SQLAlchemy query object
:param pagination: An optional dict with size and offset keys
:return: A tuple with four elements:
- res ults: `size` items starting from `offset`
- the total count of items
- `size` [default: 0]
- `offset` [default: 0]
"""
if pagination:
size = pagination.get('size', 0)
offset = pagination.get('offset', 0)
total = query.order_by(None).count() # Fastest way to count
results = query.limit(size).offset(offset).all()
return results, total, size, offset
else:
results = query.all()
return results, len(results), 0, 0
@staticmethod
def _load_relationships(instance):
"""A helper method used to overcome a problem where the relationships
that rely on joins aren't being loaded automatically
"""
for rel in instance.__mapper__.relationships:
getattr(instance, rel.key)
def _instrument(self, model):
if self._instrumentation:
return collection_instrumentation.instrument(self._instrumentation, model, self)
else:
return model
def init_storage(base_dir, filename='db.sqlite'):
"""
A builtin ModelStorage initiator.
Creates a sqlalchemy engine and a session to be passed to the mapi.
Initiator_kwargs must be passed to the ModelStorage which must hold the base_dir for the
location of the db file, and an option filename. This would create an sqlite db.
:param base_dir: the dir of the db
:param filename: the db file name.
:return:
"""
uri = 'sqlite:///{platform_char}{path}'.format(
# Handles the windows behavior where there is not root, but drivers.
# Thus behaving as relative path.
platform_char='' if 'Windows' in platform.system() else '/',
path=os.path.join(base_dir, filename))
engine = create_engine(uri, connect_args=dict(timeout=15))
session_factory = orm.sessionmaker(bind=engine)
session = orm.scoped_session(session_factory=session_factory)
return dict(engine=engine, session=session)
class ListResult(list):
"""
a ListResult contains results about the requested items.
"""
def __init__(self, metadata, *args, **qwargs):
super(ListResult, self).__init__(*args, **qwargs)
self.metadata = metadata
self.items = self