fix: benchmark migration script (#15032)
diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py
index 0faa92a..d226efb 100644
--- a/scripts/benchmark_migration.py
+++ b/scripts/benchmark_migration.py
@@ -25,10 +25,12 @@
from typing import Dict, List, Set, Type
import click
+from flask import current_app
from flask_appbuilder import Model
from flask_migrate import downgrade, upgrade
from graphlib import TopologicalSorter # pylint: disable=wrong-import-order
-from sqlalchemy import inspect
+from sqlalchemy import create_engine, inspect, Table
+from sqlalchemy.ext.automap import automap_base
from superset import db
from superset.utils.mock_data import add_sample_rows
@@ -83,11 +85,18 @@
elif isinstance(obj, dict):
queue.extend(obj.values())
- # add implicit models
- # pylint: disable=no-member, protected-access
- for obj in Model._decl_class_registry.values():
- if hasattr(obj, "__table__") and obj.__table__.fullname in tables:
- models.append(obj)
+ # build models by automapping the existing tables, instead of using current
+ # code; this is needed for migrations that modify schemas (eg, add a column),
+ # where the current model is out-of-sync with the existing table after a
+ # downgrade
+ sqlalchemy_uri = current_app.config["SQLALCHEMY_DATABASE_URI"]
+ engine = create_engine(sqlalchemy_uri)
+ Base = automap_base()
+ Base.prepare(engine, reflect=True)
+ for table in tables:
+ model = getattr(Base.classes, table)
+ model.__tablename__ = table
+ models.append(model)
# sort topologically so we can create entities in order and
# maintain relationships (eg, create a database before creating
@@ -133,15 +142,6 @@
).scalar()
print(f"Current version of the DB is {current_revision}")
- print("\nIdentifying models used in the migration:")
- models = find_models(module)
- model_rows: Dict[Type[Model], int] = {}
- for model in models:
- rows = session.query(model).count()
- print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
- model_rows[model] = rows
- session.close()
-
if current_revision != down_revision:
if not force:
click.confirm(
@@ -152,6 +152,15 @@
)
downgrade(revision=down_revision)
+ print("\nIdentifying models used in the migration:")
+ models = find_models(module)
+ model_rows: Dict[Type[Model], int] = {}
+ for model in models:
+ rows = session.query(model).count()
+ print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
+ model_rows[model] = rows
+ session.close()
+
print("Benchmarking migration")
results: Dict[str, float] = {}
start = time.time()
diff --git a/superset/migrations/versions/27ae655e4247_make_creator_owners.py b/superset/migrations/versions/27ae655e4247_make_creator_owners.py
index 561a8ca..c373c0f 100644
--- a/superset/migrations/versions/27ae655e4247_make_creator_owners.py
+++ b/superset/migrations/versions/27ae655e4247_make_creator_owners.py
@@ -27,10 +27,10 @@
down_revision = "d8bc074f7aad"
from alembic import op
+from flask import g
from flask_appbuilder import Model
-from flask_appbuilder.models.mixins import AuditMixin
from sqlalchemy import Column, ForeignKey, Integer, Table
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import relationship
from superset import db
@@ -62,6 +62,29 @@
)
+class AuditMixin:  # migration-local mixin; this patch drops the flask_appbuilder.models.mixins import of the same name
+ @classmethod
+ def get_user_id(cls):  # returns g.user.id, or None when it cannot be read
+ try:
+ return g.user.id
+ except Exception:  # best-effort: any error while reading g.user.id yields None
+ return None
+
+ @declared_attr
+ def created_by_fk(cls):  # non-nullable Integer foreign key to ab_user.id
+ return Column(
+ Integer, ForeignKey("ab_user.id"), default=cls.get_user_id, nullable=False  # default is the get_user_id callable itself
+ )
+
+ @declared_attr
+ def created_by(cls):  # relationship to "User" joined through created_by_fk
+ return relationship(
+ "User",
+ primaryjoin="%s.created_by_fk == User.id" % cls.__name__,  # join text is built from the name of the class this is evaluated on
+ enable_typechecks=False,
+ )
+
+
class Slice(Base, AuditMixin):
"""Declarative class to do query in upgrade"""
diff --git a/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py b/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py
index 42fdb7e..3bab3f6 100644
--- a/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py
+++ b/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py
@@ -26,16 +26,46 @@
revision = "c82ee8a39623"
down_revision = "c617da68de7d"
-from alembic import op
-from sqlalchemy import Column, Enum, ForeignKey, Integer, String
-from sqlalchemy.ext.declarative import declarative_base
+from datetime import datetime
-from superset.models.helpers import AuditMixinNullable
+from alembic import op
+from flask_appbuilder.models.mixins import AuditMixin
+from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String
+from sqlalchemy.ext.declarative import declarative_base, declared_attr
+
from superset.models.tags import ObjectTypes, TagTypes
Base = declarative_base()
+class AuditMixinNullable(AuditMixin):  # migration-local copy; this patch drops the superset.models.helpers import of the same name
+ """Altering the AuditMixin to use nullable fields
+
+ Allows creating objects programmatically outside of CRUD
+ """
+
+ created_on = Column(DateTime, default=datetime.now, nullable=True)  # default is the datetime.now callable; NULL allowed
+ changed_on = Column(
+ DateTime, default=datetime.now, onupdate=datetime.now, nullable=True  # same default, plus datetime.now as onupdate
+ )
+
+ @declared_attr
+ def created_by_fk(self) -> Column:  # nullable Integer foreign key to ab_user.id
+ return Column(
+ Integer, ForeignKey("ab_user.id"), default=self.get_user_id, nullable=True,  # get_user_id is not defined here; presumably inherited from AuditMixin
+ )
+
+ @declared_attr
+ def changed_by_fk(self) -> Column:  # nullable Integer foreign key to ab_user.id
+ return Column(
+ Integer,
+ ForeignKey("ab_user.id"),
+ default=self.get_user_id,
+ onupdate=self.get_user_id,  # same callable is used for both default and onupdate
+ nullable=True,
+ )
+
+
class Tag(Base, AuditMixinNullable):
"""A tag attached to an object (query, chart or dashboard)."""
diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py
index 06327ef..84981ca 100644
--- a/superset/utils/mock_data.py
+++ b/superset/utils/mock_data.py
@@ -29,6 +29,7 @@
import sqlalchemy_utils
from flask_appbuilder import Model
from sqlalchemy import Column, inspect, MetaData, Table
+from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from sqlalchemy.sql.visitors import VisitableType
@@ -146,6 +147,9 @@
if isinstance(sqltype, sqlalchemy_utils.types.uuid.UUIDType):
return uuid4
+ if isinstance(sqltype, postgresql.base.UUID):
+ return lambda: str(uuid4())
+
if isinstance(sqltype, sqlalchemy.sql.sqltypes.BLOB):
length = random.randrange(sqltype.length or 255)
return lambda: os.urandom(length)
@@ -153,7 +157,7 @@
logger.warning(
"Unknown type %s. Please add it to `get_type_generator`.", type(sqltype)
)
- return lambda: "UNKNOWN TYPE"
+ return lambda: b"UNKNOWN TYPE"
def add_data(