| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import enum |
| from typing import TYPE_CHECKING |
| |
| from flask_appbuilder import Model |
| from markupsafe import escape |
| from sqlalchemy import ( |
| Column, |
| Enum, |
| exists, |
| ForeignKey, |
| Integer, |
| orm, |
| String, |
| Table, |
| Text, |
| ) |
| from sqlalchemy.engine.base import Connection |
| from sqlalchemy.orm import relationship, sessionmaker |
| from sqlalchemy.orm.mapper import Mapper |
| from sqlalchemy.schema import UniqueConstraint |
| |
| from superset import security_manager |
| from superset.models.helpers import AuditMixinNullable |
| |
| if TYPE_CHECKING: |
| from superset.connectors.sqla.models import SqlaTable |
| from superset.models.core import FavStar |
| from superset.models.dashboard import Dashboard |
| from superset.models.slice import Slice |
| from superset.models.sql_lab import Query |
| |
| Session = sessionmaker() |
| |
| user_favorite_tag_table = Table( |
| "user_favorite_tag", |
| Model.metadata, # pylint: disable=no-member |
| Column("user_id", Integer, ForeignKey("ab_user.id")), |
| Column("tag_id", Integer, ForeignKey("tag.id")), |
| ) |
| |
| |
| class TagType(enum.Enum): |
| """ |
| Types for tags. |
| |
| Objects (queries, charts, dashboards, and datasets) will have with implicit tags based |
| on metadata: types, owners and who favorited them. This way, user "alice" |
| can find all their objects by querying for the tag `owner:alice`. |
| """ |
| |
| # pylint: disable=invalid-name |
| # explicit tags, added manually by the owner |
| custom = 1 |
| |
| # implicit tags, generated automatically |
| type = 2 |
| owner = 3 |
| favorited_by = 4 |
| |
| |
| class ObjectType(enum.Enum): |
| """Object types.""" |
| |
| # pylint: disable=invalid-name |
| query = 1 |
| chart = 2 |
| dashboard = 3 |
| dataset = 4 |
| |
| |
| class Tag(Model, AuditMixinNullable): |
| """A tag attached to an object (query, chart, dashboard, or dataset).""" |
| |
| __tablename__ = "tag" |
| id = Column(Integer, primary_key=True) |
| name = Column(String(250), unique=True) |
| type = Column(Enum(TagType)) |
| description = Column(Text) |
| |
| objects = relationship( |
| "TaggedObject", back_populates="tag", overlaps="objects,tags" |
| ) |
| |
| users_favorited = relationship( |
| security_manager.user_model, secondary=user_favorite_tag_table |
| ) |
| |
| |
| class TaggedObject(Model, AuditMixinNullable): |
| """An association between an object and a tag.""" |
| |
| __tablename__ = "tagged_object" |
| id = Column(Integer, primary_key=True) |
| tag_id = Column(Integer, ForeignKey("tag.id")) |
| object_id = Column( |
| Integer, |
| ForeignKey("dashboards.id"), |
| ForeignKey("slices.id"), |
| ForeignKey("saved_query.id"), |
| ) |
| object_type = Column(Enum(ObjectType)) |
| |
| tag = relationship("Tag", back_populates="objects", overlaps="tags") |
| __table_args__ = ( |
| UniqueConstraint( |
| "tag_id", "object_id", "object_type", name="uix_tagged_object" |
| ), |
| ) |
| |
| def __str__(self) -> str: |
| return f"<TaggedObject: {self.object_type}:{self.object_id} TAG:{self.tag_id}>" |
| |
| |
| def get_tag( |
| name: str, |
| session: orm.Session, # pylint: disable=disallowed-name |
| type_: TagType, |
| ) -> Tag: |
| tag_name = name.strip() |
| tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none() |
| if tag is None: |
| tag = Tag(name=escape(tag_name), type=type_) |
| session.add(tag) |
| session.commit() |
| return tag |
| |
| |
| def get_object_type(class_name: str) -> ObjectType: |
| mapping = { |
| "slice": ObjectType.chart, |
| "dashboard": ObjectType.dashboard, |
| "query": ObjectType.query, |
| "dataset": ObjectType.dataset, |
| } |
| try: |
| return mapping[class_name.lower()] |
| except KeyError as ex: |
| raise Exception( # pylint: disable=broad-exception-raised |
| f"No mapping found for {class_name}" |
| ) from ex |
| |
| |
| class ObjectUpdater: |
| object_type: str = "default" |
| |
| @classmethod |
| def get_owners_ids( |
| cls, target: Dashboard | FavStar | Slice | Query | SqlaTable |
| ) -> list[int]: |
| raise NotImplementedError("Subclass should implement `get_owners_ids`") |
| |
| @classmethod |
| def get_owner_tag_ids( |
| cls, |
| session: orm.Session, # pylint: disable=disallowed-name |
| target: Dashboard | FavStar | Slice | Query | SqlaTable, |
| ) -> set[int]: |
| tag_ids = set() |
| for owner_id in cls.get_owners_ids(target): |
| name = f"owner:{owner_id}" |
| tag = get_tag(name, session, TagType.owner) |
| tag_ids.add(tag.id) |
| return tag_ids |
| |
| @classmethod |
| def _add_owners( |
| cls, |
| session: orm.Session, # pylint: disable=disallowed-name |
| target: Dashboard | FavStar | Slice | Query | SqlaTable, |
| ) -> None: |
| for owner_id in cls.get_owners_ids(target): |
| name: str = f"owner:{owner_id}" |
| tag = get_tag(name, session, TagType.owner) |
| cls.add_tag_object_if_not_tagged( |
| session, tag_id=tag.id, object_id=target.id, object_type=cls.object_type |
| ) |
| |
| @classmethod |
| def add_tag_object_if_not_tagged( |
| cls, |
| session: orm.Session, # pylint: disable=disallowed-name |
| tag_id: int, |
| object_id: int, |
| object_type: str, |
| ) -> None: |
| # Check if the object is already tagged |
| exists_query = exists().where( |
| TaggedObject.tag_id == tag_id, |
| TaggedObject.object_id == object_id, |
| TaggedObject.object_type == object_type, |
| ) |
| already_tagged = session.query(exists_query).scalar() |
| |
| # Add TaggedObject to the session if it isn't already tagged |
| if not already_tagged: |
| tagged_object = TaggedObject( |
| tag_id=tag_id, object_id=object_id, object_type=object_type |
| ) |
| session.add(tagged_object) |
| |
| @classmethod |
| def after_insert( |
| cls, |
| _mapper: Mapper, |
| connection: Connection, |
| target: Dashboard | FavStar | Slice | Query | SqlaTable, |
| ) -> None: |
| with Session(bind=connection) as session: # pylint: disable=disallowed-name |
| # add `owner:` tags |
| cls._add_owners(session, target) |
| |
| # add `type:` tags |
| tag = get_tag(f"type:{cls.object_type}", session, TagType.type) |
| cls.add_tag_object_if_not_tagged( |
| session, tag_id=tag.id, object_id=target.id, object_type=cls.object_type |
| ) |
| session.commit() |
| |
| @classmethod |
| def after_update( |
| cls, |
| _mapper: Mapper, |
| connection: Connection, |
| target: Dashboard | FavStar | Slice | Query | SqlaTable, |
| ) -> None: |
| with Session(bind=connection) as session: # pylint: disable=disallowed-name |
| # Fetch current owner tags |
| existing_tags = ( |
| session.query(TaggedObject) |
| .join(Tag) |
| .filter( |
| TaggedObject.object_type == cls.object_type, |
| TaggedObject.object_id == target.id, |
| Tag.type == TagType.owner, |
| ) |
| .all() |
| ) |
| existing_owner_tag_ids = {tag.tag_id for tag in existing_tags} |
| |
| # Determine new owner IDs |
| new_owner_tag_ids = cls.get_owner_tag_ids(session, target) |
| |
| # Add missing tags |
| for owner_tag_id in new_owner_tag_ids - existing_owner_tag_ids: |
| tagged_object = TaggedObject( |
| tag_id=owner_tag_id, |
| object_id=target.id, |
| object_type=cls.object_type, |
| ) |
| session.add(tagged_object) |
| |
| # Remove unnecessary tags |
| for tag in existing_tags: |
| if tag.tag_id not in new_owner_tag_ids: |
| session.delete(tag) |
| session.commit() |
| |
| @classmethod |
| def after_delete( |
| cls, |
| _mapper: Mapper, |
| connection: Connection, |
| target: Dashboard | FavStar | Slice | Query | SqlaTable, |
| ) -> None: |
| with Session(bind=connection) as session: # pylint: disable=disallowed-name |
| # delete row from `tagged_objects` |
| session.query(TaggedObject).filter( |
| TaggedObject.object_type == cls.object_type, |
| TaggedObject.object_id == target.id, |
| ).delete() |
| |
| session.commit() |
| |
| |
| class ChartUpdater(ObjectUpdater): |
| object_type = "chart" |
| |
| @classmethod |
| def get_owners_ids(cls, target: Slice) -> list[int]: |
| return [owner.id for owner in target.owners] |
| |
| |
| class DashboardUpdater(ObjectUpdater): |
| object_type = "dashboard" |
| |
| @classmethod |
| def get_owners_ids(cls, target: Dashboard) -> list[int]: |
| return [owner.id for owner in target.owners] |
| |
| |
| class QueryUpdater(ObjectUpdater): |
| object_type = "query" |
| |
| @classmethod |
| def get_owners_ids(cls, target: Query) -> list[int]: |
| return [target.user_id] |
| |
| |
| class DatasetUpdater(ObjectUpdater): |
| object_type = "dataset" |
| |
| @classmethod |
| def get_owners_ids(cls, target: SqlaTable) -> list[int]: |
| return [owner.id for owner in target.owners] |
| |
| |
| class FavStarUpdater: |
| @classmethod |
| def after_insert( |
| cls, _mapper: Mapper, connection: Connection, target: FavStar |
| ) -> None: |
| with Session(bind=connection) as session: # pylint: disable=disallowed-name |
| name = f"favorited_by:{target.user_id}" |
| tag = get_tag(name, session, TagType.favorited_by) |
| tagged_object = TaggedObject( |
| tag_id=tag.id, |
| object_id=target.obj_id, |
| object_type=get_object_type(target.class_name), |
| ) |
| session.add(tagged_object) |
| session.commit() |
| |
| @classmethod |
| def after_delete( |
| cls, _mapper: Mapper, connection: Connection, target: FavStar |
| ) -> None: |
| with Session(bind=connection) as session: # pylint: disable=disallowed-name |
| name = f"favorited_by:{target.user_id}" |
| query = ( |
| session.query(TaggedObject.id) |
| .join(Tag) |
| .filter( |
| TaggedObject.object_id == target.obj_id, |
| Tag.type == TagType.favorited_by, |
| Tag.name == name, |
| ) |
| ) |
| ids = [row[0] for row in query] |
| session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( |
| synchronize_session=False |
| ) |
| |
| session.commit() |