blob: 659ec3fdfcc1ac30b7c9066063238e9a4e3c2289 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union
from flask import current_app, g
from flask_appbuilder import Model
from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound
def validate_owner(value: int) -> None:
    """
    Field validator that rejects ids with no matching user row.

    :param value: Primary key of the user to look up
    :raises ValidationError: If no user with id ``value`` exists
    """
    appbuilder = current_app.appbuilder
    user_ids = appbuilder.get_session.query(appbuilder.sm.user_model.id)
    try:
        user_ids.filter_by(id=value).one()
    except NoResultFound:
        raise ValidationError(f"User {value} does not exist")
class BaseSupersetSchema(Schema):
    """
    Marshmallow schema whose ``load`` accepts an optional Model instance.

    This follows the marshmallow-sqlalchemy pattern. When ``load`` receives an
    ``instance``, the post-load hook merges the loaded fields onto it, which
    supports partial model merges on HTTP PUT. Otherwise the hook creates a
    fresh ``__class_model__`` object.
    """

    __class_model__: Model = None

    def __init__(self, **kwargs: Any) -> None:
        # Target object for the current load: set by ``load``, consumed and
        # possibly created by ``make_object``.
        self.instance: Optional[Model] = None
        super().__init__(**kwargs)

    def load(  # pylint: disable=arguments-differ
        self,
        data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
        many: Optional[bool] = None,
        partial: Union[bool, Sequence[str], Set[str], None] = None,
        instance: Optional[Model] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Remember ``instance``, then delegate to marshmallow's ``Schema.load``.

        :param data: Payload to deserialize
        :param many: Treat ``data`` as a collection. ``None`` means False.
        :param partial: Partial-load setting. ``None`` means False.
        :param instance: Existing model to merge into (e.g. on PUT)
        :return: Whatever ``Schema.load`` returns for this payload
        """
        self.instance = instance
        # An omitted many/partial is pinned to False here rather than
        # deferring to any schema-level setting.
        return super().load(
            data,
            many=False if many is None else many,
            partial=False if partial is None else partial,
            **kwargs,
        )

    @post_load
    def make_object(
        self, data: Dict[Any, Any], discard: Optional[List[str]] = None
    ) -> Model:
        """
        Copy loaded fields onto ``self.instance`` and return it.

        On PUT, ``self.instance`` is the object previously fetched by the
        endpoint handler. On POST it is None, and a new ``__class_model__``
        is created.

        :param data: Schema data payload
        :param discard: Names of fields that must not be set on the model
        :return: The created or updated model object
        """
        skipped = discard or []
        if not self.instance:
            self.instance = self.__class_model__()  # pylint: disable=not-callable
        for key, value in data.items():
            if key not in skipped:
                setattr(self.instance, key, value)
        return self.instance
class BaseOwnedSchema(BaseSupersetSchema):
    """
    Adds owner handling to a BaseSupersetSchema.

    A pre_load hook defaults the owners field on create. The post_load hook
    resolves the submitted owner ids into user objects, always keeping the
    current user (``g.user``) among the owners.
    """

    # Name of the schema field that carries the list of owner user ids.
    owners_field_name = "owners"

    @post_load
    def make_object(
        self, data: Dict[str, Any], discard: Optional[List[str]] = None
    ) -> Model:
        """
        Build or merge the model like the parent, then populate its owners.

        The owners field is never copied verbatim onto the model. If the
        payload contains it, ``instance.owners`` is replaced via
        ``set_owners``. Otherwise the current user is appended to the
        existing owners if missing.

        :param data: Schema data payload
        :param discard: Extra fields not to set on the model (not mutated)
        :return: The created or updated model object
        """
        # Build a new list so a caller-supplied ``discard`` is left untouched.
        discard = [*(discard or []), self.owners_field_name]
        instance = super().make_object(data, discard)
        if self.owners_field_name in data:
            self.set_owners(instance, data[self.owners_field_name])
        elif g.user not in instance.owners:
            instance.owners.append(g.user)
        return instance

    @pre_load
    def pre_load(self, data: Dict[Any, Any]) -> None:
        """
        Default the owners field to an empty list when no instance is set.

        :param data: Raw input payload; modified in place
        """
        # If PUT request (instance provided) don't set owners to empty list.
        if not self.instance:
            data.setdefault(self.owners_field_name, [])

    @staticmethod
    def set_owners(instance: Model, owners: List[int]) -> None:
        """
        Replace ``instance.owners`` with the users for ``owners``.

        The current user's id is added if absent.

        :param instance: Model whose ``owners`` attribute is replaced
        :param owners: Owner user ids. Not mutated; ``None`` is treated as
            an empty list.
        """
        # Copy so the caller's list (e.g. the loaded payload) is not mutated.
        owner_ids = list(owners or [])
        if g.user.id not in owner_ids:
            owner_ids.append(g.user.id)
        # Build the query once; each ``.get`` is a lookup by primary key.
        user_query = current_app.appbuilder.get_session.query(
            current_app.appbuilder.sm.user_model
        )
        # NOTE(review): ``Query.get`` returns None for an unknown id; this
        # assumes ids were validated beforehand (e.g. by validate_owner).
        instance.owners = [user_query.get(owner_id) for owner_id in owner_ids]