| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """ |
| This module contains utilities to auto-generate an |
| Entity-Relationship Diagram (ERD) from SQLAlchemy models |
| and write it to a PlantUML file. |
| """ |
| |
| import json |
| import os |
| from collections import defaultdict |
| from collections.abc import Iterable |
| from typing import Any, Optional |
| |
| import click |
| import jinja2 |
| |
| from superset import db |
| |
# Maps a group title to the names of the database tables that belong to it.
# Tables that are not listed here are bucketed by ``introspect_models`` under
# its fallback group.
GROUPINGS: dict[str, Iterable[str]] = {
    "Core": [
        "css_templates",
        "dynamic_plugin",
        "favstar",
        "dashboards",
        "slices",
        "user_attribute",
        "embedded_dashboards",
        "annotation",
        "annotation_layer",
        "tag",
        "tagged_object",
    ],
    "System": ["ssh_tunnels", "keyvalue", "cache_keys", "key_value", "logs"],
    "Alerts & Reports": ["report_recipient", "report_execution_log", "report_schedule"],
    "Inherited from Flask App Builder (FAB)": [
        "ab_user",
        "ab_permission",
        "ab_permission_view",
        "ab_view_menu",
        "ab_role",
        "ab_register_user",
    ],
    "SQL Lab": ["query", "saved_query", "tab_state", "table_schema"],
    "Data Assets": [
        "dbs",
        "table_columns",
        "sql_metrics",
        "tables",
        "row_level_security_filters",
        "sl_tables",
        "sl_datasets",
        "sl_columns",
        "database_user_oauth2_tokens",
    ],
}
# Table name to group name mapping (reversing the above one for easy lookup).
# Built with a comprehension so that no loop variables leak into the module
# namespace. If a table were listed under several groups, the last group wins.
TABLE_TO_GROUP_MAP: dict[str, str] = {
    table: group for group, tables in GROUPINGS.items() for table in tables
}
| |
| |
def sort_data_structure(data):  # type: ignore
    """
    Returns a copy of ``data`` in which every dictionary has its keys sorted

    The copy is produced by a JSON round trip: ``data`` is serialized with
    ``sort_keys=True`` and parsed back, so it must be JSON serializable.

    Parameters:
    -----------
    data: JSON-serializable data structure

    Returns:
    --------
    The parsed-back structure with sorted dictionary keys
    """
    return json.loads(json.dumps(data, sort_keys=True))
| |
| |
def introspect_sqla_model(mapper: Any, seen: set[str]) -> dict[str, Any]:
    """
    Describes a single SQLAlchemy model as a plain data structure
    (class name, table name, fields and relationships) that can be
    passed to a jinja2 template for instance

    Parameters:
    -----------
    mapper: SQLAlchemy model mapper
    seen: set of relationship identifiers already emitted; it is updated
        in place so that a relationship between the same pair of tables
        is only reported once across calls

    Returns:
    --------
    Dict[str, Any]: data structure for jinja2 template, with sorted keys
    """
    class_name = mapper.class_.__name__
    table_name = mapper.persist_selectable.name

    # One entry per mapped column: its key and the string form of its type
    fields: list[dict[str, str]] = [
        {"field_name": col.key, "type": str(col.type)} for col in mapper.columns
    ]

    relationships: list[dict[str, str]] = []
    for rel_name, rel in mapper.relationships.items():
        target = rel.mapper
        related_table = target.persist_selectable.name

        # The identifier is order-independent, so A->B and B->A collapse
        # into a single relationship
        pair_key = "-".join(sorted((table_name, related_table)))
        if pair_key in seen:
            continue
        seen.add(pair_key)

        direction = rel.direction.name
        entry: dict[str, str] = {
            "relationship_name": rel_name,
            "related_model": target.class_.__name__,
            "related_table": related_table,
        }
        if rel.secondary is not None:
            # A secondary (association) table marks a many-to-many relationship
            entry["type"] = "many-to-many"
            entry["secondary_table"] = rel.secondary.name
            entry["squiggle"] = "}|--|{"
        elif direction == "MANYTOONE":
            entry["type"] = direction
            entry["squiggle"] = "}|--||"
        else:
            entry["type"] = direction
            entry["squiggle"] = "||--|{"
        relationships.append(entry)

    model_info: dict[str, Any] = {
        "class_name": class_name,
        "table_name": table_name,
        "fields": fields,
        "relationships": relationships,
    }
    return sort_data_structure(model_info)  # type: ignore
| |
| |
def introspect_models() -> dict[str, list[dict[str, Any]]]:
    """
    Introspects SQLAlchemy models and returns a data structure that
    can be passed to a jinja2 template for rendering an ERD.

    Models are grouped by the group their table belongs to in
    ``TABLE_TO_GROUP_MAP``; tables without a group end up under
    "Uncategorized Models".

    Returns:
    --------
    Dict[str, List[Dict[str, Any]]]: data structure for jinja2 template,
        mapping a group name to the list of introspected models
    """
    # ``registry.mappers`` is an unordered collection. Iterate it in a stable
    # order so that both the order of the models and the relationship
    # de-duplication (the first model visited keeps the relationship) are
    # reproducible from one run to the next.
    mappers = sorted(
        db.Model.registry.mappers,
        key=lambda mapper: (
            str(mapper.persist_selectable.name),
            mapper.class_.__name__,
        ),
    )

    data: dict[str, list[dict[str, Any]]] = defaultdict(list)
    # Relationship identifiers shared across all models
    seen_models: set[str] = set()
    for mapper in mappers:
        group_name = (
            TABLE_TO_GROUP_MAP.get(mapper.persist_selectable.name)
            or "Uncategorized Models"
        )
        data[group_name].append(introspect_sqla_model(mapper, seen_models))
    return data
| |
| |
def generate_erd(file_path: str) -> None:
    """
    Generates a PlantUML ERD of the models/database

    The models are introspected, rendered through the
    ``erd.template.puml`` jinja2 template located next to this file,
    and the result is written to ``file_path``.

    Parameters:
    -----------
    file_path: str
        File path to write the ERD to
    """
    data = introspect_models()
    templates_path = os.path.dirname(__file__)
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_path))

    # Load the template
    template = env.get_template("erd.template.puml")
    rendered = template.render(data=data)
    # Write explicit UTF-8 instead of the locale-dependent default encoding so
    # the output does not depend on the machine generating it.
    with open(file_path, "w", encoding="utf-8") as f:
        click.secho(f"Writing to {file_path}...", fg="green")
        f.write(rendered)
| |
| |
@click.command()
@click.option(
    "--output",
    "-o",
    type=click.Path(dir_okay=False, writable=True),
    help="File to write the ERD to",
)
def erd(output: Optional[str] = None) -> None:
    """
    Generates a PlantUML ERD of the models/database

    Parameters:
    -----------
    output: str, optional
        File to write the ERD to, defaults to erd.puml in this script's
        directory if not provided
    """
    path = os.path.dirname(__file__)
    output = output or os.path.join(path, "erd.puml")

    # NOTE(review): imported inside the command rather than at module level,
    # presumably to defer loading the application until it is needed — confirm.
    from superset.app import create_app

    # The models are introspected inside the application context
    app = create_app()
    with app.app_context():
        generate_erd(output)


# Run the command when this file is executed directly as a script
if __name__ == "__main__":
    erd()