blob: ba859f519bf47bb4eae2be9ca1a6ed2f1aebcfc7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import enum
from typing import TYPE_CHECKING
from flask_appbuilder import Model
from markupsafe import escape
from sqlalchemy import (
Column,
Enum,
exists,
ForeignKey,
Integer,
orm,
String,
Table,
Text,
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from superset import security_manager
from superset.models.helpers import AuditMixinNullable
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import FavStar
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
# Unbound session factory; the mapper event handlers in this module bind it to
# the event's ``Connection`` via ``Session(bind=connection)``.
Session = sessionmaker()

# Many-to-many association between users (``ab_user``) and the tags they
# favorited; used as the ``secondary`` table of ``Tag.users_favorited``.
user_favorite_tag_table = Table(
    "user_favorite_tag",
    Model.metadata,  # pylint: disable=no-member
    Column("user_id", Integer, ForeignKey("ab_user.id")),
    Column("tag_id", Integer, ForeignKey("tag.id")),
)
class TagType(enum.Enum):
    """
    Types for tags.

    Objects (queries, charts, dashboards, and datasets) can carry explicit
    tags added by hand, as well as implicit tags generated from their
    metadata: their type, their owners, and the users who favorited them.
    Implicit owner/favorite tag names embed the user id, so every object owned
    by the user with id 1 can be found by querying for the tag ``owner:1``.
    """

    # pylint: disable=invalid-name

    # explicit tags, added manually by the owner
    custom = 1

    # implicit tags, generated automatically
    type = 2  # named ``type:<object type>``, e.g. ``type:chart``
    owner = 3  # named ``owner:<user id>``
    favorited_by = 4  # named ``favorited_by:<user id>``
class ObjectType(enum.Enum):
    """Kinds of objects that a ``TaggedObject`` row can point at."""

    # pylint: disable=invalid-name

    query = 1  # saved SQL Lab query
    chart = 2  # a ``Slice``
    dashboard = 3
    dataset = 4
class Tag(Model, AuditMixinNullable):
    """
    A tag attached to an object (query, chart, dashboard, or dataset).

    Objects are linked to a tag through ``TaggedObject`` rows, and users can
    favorite a tag through the ``user_favorite_tag`` association table.
    """

    __tablename__ = "tag"

    id = Column(Integer, primary_key=True)
    # Unique on ``name`` alone, not on (name, type).
    # NOTE(review): lookups by (name, type) that miss will still collide with
    # a same-named tag of a different type on insert — confirm intended.
    name = Column(String(250), unique=True)
    type = Column(Enum(TagType))
    description = Column(Text)

    # ``TaggedObject`` rows for this tag; the other side is ``TaggedObject.tag``.
    # ``overlaps`` silences SQLAlchemy's overlapping-relationship warning
    # (presumably against ``tags`` relationships declared on tagged models —
    # confirm).
    objects = relationship(
        "TaggedObject", back_populates="tag", overlaps="objects,tags"
    )
    # Users who favorited this tag, via ``user_favorite_tag_table``.
    users_favorited = relationship(
        security_manager.user_model, secondary=user_favorite_tag_table
    )
class TaggedObject(Model, AuditMixinNullable):
    """
    An association between an object and a tag.

    The tagged object is identified by the ``(object_type, object_id)`` pair.
    """

    __tablename__ = "tagged_object"

    id = Column(Integer, primary_key=True)
    tag_id = Column(Integer, ForeignKey("tag.id"))
    # Polymorphic id: which table it refers to depends on ``object_type``.
    # NOTE(review): no foreign key is declared for datasets although
    # ``ObjectType.dataset`` exists — confirm this is intentional.
    object_id = Column(
        Integer,
        ForeignKey("dashboards.id"),
        ForeignKey("slices.id"),
        ForeignKey("saved_query.id"),
    )
    object_type = Column(Enum(ObjectType))

    tag = relationship("Tag", back_populates="objects", overlaps="tags")

    # A given object can carry a given tag at most once.
    __table_args__ = (
        UniqueConstraint(
            "tag_id", "object_id", "object_type", name="uix_tagged_object"
        ),
    )

    def __str__(self) -> str:
        # Debug representation: object type, object id and tag id.
        return f"<TaggedObject: {self.object_type}:{self.object_id} TAG:{self.tag_id}>"
def get_tag(
    name: str,
    session: orm.Session,  # pylint: disable=disallowed-name
    type_: TagType,
) -> Tag:
    """
    Return the tag with the given name and type, creating it if missing.

    The name is stripped and HTML-escaped once, and that same value is used
    both for the lookup and for the insert. Tags are stored escaped, so a name
    containing characters such as ``&`` or ``<`` must also be looked up in its
    escaped form. Otherwise a repeat call misses the stored row and attempts a
    duplicate insert that violates the unique ``name`` constraint.

    :param name: raw tag name, e.g. ``"owner:1"``
    :param session: session used for the lookup and, if needed, the insert
    :param type_: the tag type to look up / assign
    :returns: the existing or newly created ``Tag``
    """
    tag_name = str(escape(name.strip()))
    tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none()
    if tag is None:
        tag = Tag(name=tag_name, type=type_)
        session.add(tag)
        # Commit so the new row is persisted and ``tag.id`` is populated for
        # callers that immediately reference it.
        session.commit()
    return tag
def get_object_type(class_name: str) -> ObjectType:
    """
    Translate a model class name into the matching ``ObjectType``.

    The comparison is case-insensitive; note that ``"slice"`` maps to
    ``ObjectType.chart``.

    :param class_name: class name to translate, e.g. ``"slice"``
    :returns: the corresponding ``ObjectType``
    :raises Exception: if ``class_name`` has no known object type
    """
    object_types_by_name = {
        "slice": ObjectType.chart,
        "dashboard": ObjectType.dashboard,
        "query": ObjectType.query,
        "dataset": ObjectType.dataset,
    }
    normalized = class_name.lower()
    try:
        object_type = object_types_by_name[normalized]
    except KeyError as err:
        raise Exception(  # pylint: disable=broad-exception-raised
            f"No mapping found for {class_name}"
        ) from err
    return object_type
class ObjectUpdater:
    """
    Base class of mapper-event handlers that maintain implicit tags.

    Subclasses set ``object_type`` and implement ``get_owners_ids``; the
    ``after_*`` hooks keep the ``type:`` and ``owner:`` tags of the target
    object in sync, each using its own session bound to the event connection.
    """

    object_type: str = "default"

    @classmethod
    def get_owners_ids(
        cls, target: Dashboard | FavStar | Slice | Query | SqlaTable
    ) -> list[int]:
        """Return the user ids owning ``target``; must be overridden."""
        raise NotImplementedError("Subclass should implement `get_owners_ids`")

    @classmethod
    def get_owner_tag_ids(
        cls,
        session: orm.Session,  # pylint: disable=disallowed-name
        target: Dashboard | FavStar | Slice | Query | SqlaTable,
    ) -> set[int]:
        """Return the ids of the ``owner:<id>`` tags for every owner of ``target``.

        Missing owner tags are created (and committed) by ``get_tag``.
        """
        return {
            get_tag(f"owner:{owner_id}", session, TagType.owner).id
            for owner_id in cls.get_owners_ids(target)
        }

    @classmethod
    def _add_owners(
        cls,
        session: orm.Session,  # pylint: disable=disallowed-name
        target: Dashboard | FavStar | Slice | Query | SqlaTable,
    ) -> None:
        """Link ``target`` to the ``owner:<id>`` tag of each of its owners."""
        for owner_id in cls.get_owners_ids(target):
            owner_tag = get_tag(f"owner:{owner_id}", session, TagType.owner)
            cls.add_tag_object_if_not_tagged(
                session,
                tag_id=owner_tag.id,
                object_id=target.id,
                object_type=cls.object_type,
            )

    @classmethod
    def add_tag_object_if_not_tagged(
        cls,
        session: orm.Session,  # pylint: disable=disallowed-name
        tag_id: int,
        object_id: int,
        object_type: str,
    ) -> None:
        """Stage a ``TaggedObject`` row unless an identical one already exists.

        Only ``session.add`` is called; committing is left to the caller.
        """
        is_tagged = session.query(
            exists().where(
                TaggedObject.tag_id == tag_id,
                TaggedObject.object_id == object_id,
                TaggedObject.object_type == object_type,
            )
        ).scalar()
        if is_tagged:
            return
        session.add(
            TaggedObject(
                tag_id=tag_id, object_id=object_id, object_type=object_type
            )
        )

    @classmethod
    def after_insert(
        cls,
        _mapper: Mapper,
        connection: Connection,
        target: Dashboard | FavStar | Slice | Query | SqlaTable,
    ) -> None:
        """Attach the ``owner:`` and ``type:`` tags to a newly inserted object."""
        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
            cls._add_owners(session, target)

            type_tag = get_tag(f"type:{cls.object_type}", session, TagType.type)
            cls.add_tag_object_if_not_tagged(
                session,
                tag_id=type_tag.id,
                object_id=target.id,
                object_type=cls.object_type,
            )
            session.commit()

    @classmethod
    def after_update(
        cls,
        _mapper: Mapper,
        connection: Connection,
        target: Dashboard | FavStar | Slice | Query | SqlaTable,
    ) -> None:
        """Reconcile the object's ``owner:`` tags with its current owners."""
        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
            # owner-tag links currently stored for this object
            current_links = (
                session.query(TaggedObject)
                .join(Tag)
                .filter(
                    TaggedObject.object_type == cls.object_type,
                    TaggedObject.object_id == target.id,
                    Tag.type == TagType.owner,
                )
                .all()
            )
            current_tag_ids = {link.tag_id for link in current_links}

            # owner tags the object should carry now
            wanted_tag_ids = cls.get_owner_tag_ids(session, target)

            for missing_tag_id in wanted_tag_ids - current_tag_ids:
                session.add(
                    TaggedObject(
                        tag_id=missing_tag_id,
                        object_id=target.id,
                        object_type=cls.object_type,
                    )
                )

            for link in current_links:
                if link.tag_id not in wanted_tag_ids:
                    session.delete(link)

            session.commit()

    @classmethod
    def after_delete(
        cls,
        _mapper: Mapper,
        connection: Connection,
        target: Dashboard | FavStar | Slice | Query | SqlaTable,
    ) -> None:
        """Drop every ``TaggedObject`` row pointing at the deleted object."""
        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
            stale_links = session.query(TaggedObject).filter(
                TaggedObject.object_type == cls.object_type,
                TaggedObject.object_id == target.id,
            )
            stale_links.delete()
            session.commit()
class ChartUpdater(ObjectUpdater):
    """Maintain implicit tags for charts (``Slice`` rows)."""

    object_type = "chart"

    @classmethod
    def get_owners_ids(cls, target: Slice) -> list[int]:
        """Return the id of every user listed in ``target.owners``."""
        owners = target.owners
        return [user.id for user in owners]
class DashboardUpdater(ObjectUpdater):
    """Maintain implicit tags for dashboards."""

    object_type = "dashboard"

    @classmethod
    def get_owners_ids(cls, target: Dashboard) -> list[int]:
        """Return the id of every user listed in ``target.owners``."""
        owners = target.owners
        return [user.id for user in owners]
class QueryUpdater(ObjectUpdater):
    """Maintain implicit tags for queries, which have a single owning user."""

    object_type = "query"

    @classmethod
    def get_owners_ids(cls, target: Query) -> list[int]:
        """
        Return the owner id of ``target`` as a one-element list.

        When ``target.user_id`` is None an empty list is returned, so that no
        ``owner:None`` tag is created or attached to the query.
        """
        if target.user_id is None:
            return []
        return [target.user_id]
class DatasetUpdater(ObjectUpdater):
    """Maintain implicit tags for datasets (``SqlaTable`` rows)."""

    object_type = "dataset"

    @classmethod
    def get_owners_ids(cls, target: SqlaTable) -> list[int]:
        """Return the id of every user listed in ``target.owners``."""
        owners = target.owners
        return [user.id for user in owners]
class FavStarUpdater:
    """Keep ``favorited_by:<user id>`` tags in sync with ``FavStar`` rows."""

    @classmethod
    def after_insert(
        cls, _mapper: Mapper, connection: Connection, target: FavStar
    ) -> None:
        """
        Tag the favorited object with ``favorited_by:<user id>``.

        The link is only added if it is not already present. This mirrors
        ``ObjectUpdater.add_tag_object_if_not_tagged`` and avoids violating
        the ``uix_tagged_object`` unique constraint inside the flush hook.
        """
        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
            name = f"favorited_by:{target.user_id}"
            tag = get_tag(name, session, TagType.favorited_by)
            object_type = get_object_type(target.class_name)
            already_tagged = session.query(
                exists().where(
                    TaggedObject.tag_id == tag.id,
                    TaggedObject.object_id == target.obj_id,
                    TaggedObject.object_type == object_type,
                )
            ).scalar()
            if not already_tagged:
                session.add(
                    TaggedObject(
                        tag_id=tag.id,
                        object_id=target.obj_id,
                        object_type=object_type,
                    )
                )
            session.commit()

    @classmethod
    def after_delete(
        cls, _mapper: Mapper, connection: Connection, target: FavStar
    ) -> None:
        """
        Remove the ``favorited_by:<user id>`` link from the unfavorited object.

        The match includes ``object_type``. Without it, objects of different
        types that share the same id (e.g. chart 5 and dashboard 5) would all
        lose the user's favorite tag.
        """
        with Session(bind=connection) as session:  # pylint: disable=disallowed-name
            name = f"favorited_by:{target.user_id}"
            query = (
                session.query(TaggedObject.id)
                .join(Tag)
                .filter(
                    TaggedObject.object_id == target.obj_id,
                    TaggedObject.object_type == get_object_type(target.class_name),
                    Tag.type == TagType.favorited_by,
                    Tag.name == name,
                )
            )
            # Collect ids first because the delete itself cannot use a join.
            ids = [row[0] for row in query]
            session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
                synchronize_session=False
            )
            session.commit()