| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from __future__ import annotations |
| |
| import logging |
| import re |
| import urllib |
| from datetime import datetime |
| from re import Pattern |
| from typing import Any, TYPE_CHECKING, TypedDict |
| |
| import pandas as pd |
| from apispec import APISpec |
| from apispec.ext.marshmallow import MarshmallowPlugin |
| from flask_babel import gettext as __ |
| from marshmallow import fields, Schema |
| from marshmallow.exceptions import ValidationError |
| from sqlalchemy import column, func, types |
| from sqlalchemy.engine.base import Engine |
| from sqlalchemy.engine.interfaces import Dialect |
| from sqlalchemy.engine.reflection import Inspector |
| from sqlalchemy.engine.url import URL |
| from sqlalchemy.sql import column as sql_column, select, sqltypes |
| from sqlalchemy.sql.expression import table as sql_table |
| |
| from superset.constants import TimeGrain |
| from superset.databases.schemas import encrypted_field_properties, EncryptedString |
| from superset.databases.utils import make_url_safe |
| from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType |
| from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError |
| from superset.errors import SupersetError, SupersetErrorType |
| from superset.exceptions import SupersetException |
| from superset.sql.parse import SQLScript, Table |
| from superset.superset_typing import ResultSetColumnType |
| from superset.utils import core as utils, json |
| from superset.utils.hashing import md5_sha_from_str |
| |
| if TYPE_CHECKING: |
| from sqlalchemy.sql.expression import Select |
| |
| logger = logging.getLogger(__name__) |
| |
| try: |
| import google.auth |
| from google.cloud import bigquery |
| from google.oauth2 import service_account |
| |
| dependencies_installed = True |
| except ImportError: |
| dependencies_installed = False |
| |
| try: |
| import pandas_gbq |
| |
| can_upload = True |
| except ModuleNotFoundError: |
| can_upload = False |
| |
| if TYPE_CHECKING: |
| from superset.models.core import Database # pragma: no cover |
| |
| |
| CONNECTION_DATABASE_PERMISSIONS_REGEX = re.compile( |
| "Access Denied: Project (?P<project_name>.+?): User does not have " |
| + "bigquery.jobs.create permission in project (?P<project>.+?)" |
| ) |
| |
| TABLE_DOES_NOT_EXIST_REGEX = re.compile( |
| 'Table name "(?P<table>.*?)" missing dataset while no default ' |
| "dataset is set in the request" |
| ) |
| |
| COLUMN_DOES_NOT_EXIST_REGEX = re.compile( |
| r"Unrecognized name: (?P<column>.*?) at \[(?P<location>.+?)\]" |
| ) |
| |
| SCHEMA_DOES_NOT_EXIST_REGEX = re.compile( |
| r"bigquery error: 404 Not found: Dataset (?P<dataset>.*?):" |
| r"(?P<schema>.*?) was not found in location" |
| ) |
| |
| SYNTAX_ERROR_REGEX = re.compile( |
| 'Syntax error: Expected end of input but got identifier "(?P<syntax_error>.+?)"' |
| ) |
| |
| ma_plugin = MarshmallowPlugin() |
| |
| |
| class BigQueryParametersSchema(Schema): |
| credentials_info = EncryptedString( |
| required=False, |
| metadata={"description": "Contents of BigQuery JSON credentials."}, |
| ) |
| query = fields.Dict(required=False) |
| |
| |
| class BigQueryParametersType(TypedDict): |
| credentials_info: dict[str, Any] |
| query: dict[str, Any] |
| |
| |
| class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods |
| """Engine spec for Google's BigQuery |
| |
| As contributed by @mxmzdlv on issue #945""" |
| |
| engine = "bigquery" |
| engine_name = "Google BigQuery" |
| max_column_name_length = 128 |
| disable_ssh_tunneling = True |
| |
| parameters_schema = BigQueryParametersSchema() |
| default_driver = "bigquery" |
| sqlalchemy_uri_placeholder = "bigquery://{project_id}" |
| |
| # BigQuery doesn't maintain context when running multiple statements in the |
| # same cursor, so we need to run all statements at once |
| run_multiple_statements_as_one = True |
| |
| allows_hidden_cc_in_orderby = True |
| |
| supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True |
| |
| # when editing the database, mask this field in `encrypted_extra` |
| # pylint: disable=invalid-name |
| encrypted_extra_sensitive_fields = {"$.credentials_info.private_key"} |
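| # e.g. for an `encrypted_extra` payload shaped like |
| # {"credentials_info": {"private_key": "...", ...}} (shape shown only for |
| # illustration), the private_key value is masked |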
| |
| """ |
| https://www.python.org/dev/peps/pep-0249/#arraysize |
| raw_connections bypass the sqlalchemy-bigquery query execution context and deal with |
| raw dbapi connection directly. |
| If this value is not set, the default value is set to 1, as described here, |
| https://googlecloudplatform.github.io/google-cloud-python/latest/_modules/google/cloud/bigquery/dbapi/cursor.html#Cursor |
| |
| The default value of 5000 is derived from the sqlalchemy-bigquery. |
| https://github.com/googleapis/python-bigquery-sqlalchemy/blob/4e17259088f89eac155adc19e0985278a29ecf9c/sqlalchemy_bigquery/base.py#L762 |
| """ |
| arraysize = 5000 |
| |
| _date_trunc_functions = { |
| "DATE": "DATE_TRUNC", |
| "DATETIME": "DATETIME_TRUNC", |
| "TIME": "TIME_TRUNC", |
| "TIMESTAMP": "TIMESTAMP_TRUNC", |
| } |
| |
| _time_grain_expressions = { |
| None: "{col}", |
| TimeGrain.SECOND: "CAST(TIMESTAMP_SECONDS(" |
| "UNIX_SECONDS(CAST({col} AS TIMESTAMP))" |
| ") AS {type})", |
| TimeGrain.MINUTE: "CAST(TIMESTAMP_SECONDS(" |
| "60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 60)" |
| ") AS {type})", |
| TimeGrain.FIVE_MINUTES: "CAST(TIMESTAMP_SECONDS(" |
| "5*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 5*60)" |
| ") AS {type})", |
| TimeGrain.TEN_MINUTES: "CAST(TIMESTAMP_SECONDS(" |
| "10*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 10*60)" |
| ") AS {type})", |
| TimeGrain.FIFTEEN_MINUTES: "CAST(TIMESTAMP_SECONDS(" |
| "15*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 15*60)" |
| ") AS {type})", |
| TimeGrain.THIRTY_MINUTES: "CAST(TIMESTAMP_SECONDS(" |
| "30*60 * DIV(UNIX_SECONDS(CAST({col} AS TIMESTAMP)), 30*60)" |
| ") AS {type})", |
| TimeGrain.HOUR: "{func}({col}, HOUR)", |
| TimeGrain.DAY: "{func}({col}, DAY)", |
| TimeGrain.WEEK: "{func}({col}, WEEK)", |
| TimeGrain.WEEK_STARTING_MONDAY: "{func}({col}, ISOWEEK)", |
| TimeGrain.MONTH: "{func}({col}, MONTH)", |
| TimeGrain.QUARTER: "{func}({col}, QUARTER)", |
| TimeGrain.YEAR: "{func}({col}, YEAR)", |
| } |
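| # Illustrative expansion (column name and target type assumed): the SECOND |
| # grain applied to a column "ts" of type TIMESTAMP renders as |
| # CAST(TIMESTAMP_SECONDS(UNIX_SECONDS(CAST(ts AS TIMESTAMP))) AS TIMESTAMP) |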
| |
| custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { |
| CONNECTION_DATABASE_PERMISSIONS_REGEX: ( |
| __( |
| "Unable to connect. Verify that the following roles are set " |
| 'on the service account: "BigQuery Data Viewer", ' |
| '"BigQuery Metadata Viewer", "BigQuery Job User" ' |
| "and the following permissions are set " |
| '"bigquery.readsessions.create", ' |
| '"bigquery.readsessions.getData"' |
| ), |
| SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR, |
| {}, |
| ), |
| TABLE_DOES_NOT_EXIST_REGEX: ( |
| __( |
| 'The table "%(table)s" does not exist. ' |
| "A valid table must be used to run this query.", |
| ), |
| SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, |
| {}, |
| ), |
| COLUMN_DOES_NOT_EXIST_REGEX: ( |
| __('We can\'t seem to resolve column "%(column)s" at line %(location)s.'), |
| SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR, |
| {}, |
| ), |
| SCHEMA_DOES_NOT_EXIST_REGEX: ( |
| __( |
| 'The schema "%(schema)s" does not exist. ' |
| "A valid schema must be used to run this query." |
| ), |
| SupersetErrorType.SCHEMA_DOES_NOT_EXIST_ERROR, |
| {}, |
| ), |
| SYNTAX_ERROR_REGEX: ( |
| __( |
| "Please check your query for syntax errors at or near " |
| '"%(syntax_error)s". Then, try running your query again.' |
| ), |
| SupersetErrorType.SYNTAX_ERROR, |
| {}, |
| ), |
| } |
| |
| @classmethod |
| def convert_dttm( |
| cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None |
| ) -> str | None: |
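| # Illustrative: a DATE target type with dttm = datetime(2024, 1, 1) |
| # (example value) produces "CAST('2024-01-01' AS DATE)" |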
| sqla_type = cls.get_sqla_column_type(target_type) |
| if isinstance(sqla_type, types.Date): |
| return f"CAST('{dttm.date().isoformat()}' AS DATE)" |
| if isinstance(sqla_type, types.TIMESTAMP): |
| return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)""" |
| if isinstance(sqla_type, types.DateTime): |
| return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS DATETIME)""" |
| if isinstance(sqla_type, types.Time): |
| return f"""CAST('{dttm.strftime("%H:%M:%S.%f")}' AS TIME)""" |
| return None |
| |
| @classmethod |
| def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: |
| data = super().fetch_data(cursor, limit) |
| # Support the BigQuery Row type (google.cloud.bigquery.table.Row), |
| # introduced in PR #4071 |
| if data and type(data[0]).__name__ == "Row": |
| data = [r.values() for r in data] # type: ignore |
| return data |
| |
| @staticmethod |
| def _mutate_label(label: str) -> str: |
| """ |
| BigQuery field names must start with a letter or underscore and may contain |
| only alphanumeric characters and underscores. Labels that start with a number |
| are prefixed with an underscore, any unsupported characters are replaced with |
| underscores, and a short prefix of the md5 hash of the original label is |
| appended to avoid possible collisions. |
| |
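| Illustrative example ("xxxxx" stands in for the first five characters of |
| the md5 hash of the original label): |
| |
| "123 col%" -> "_123_col__xxxxx" |
| |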
| :param label: Expected expression label |
| :return: Conditionally mutated label |
| """ |
| label_hashed = "_" + md5_sha_from_str(label) |
| |
| # if label starts with number, add underscore as first character |
| label_mutated = "_" + label if re.match(r"^\d", label) else label |
| |
| # replace non-alphanumeric characters with underscores |
| label_mutated = re.sub(r"[^\w]+", "_", label_mutated) |
| if label_mutated != label: |
| # add first 5 chars from md5 hash to label to avoid possible collisions |
| label_mutated += label_hashed[:6] |
| |
| return label_mutated |
| |
| @classmethod |
| def _truncate_label(cls, label: str) -> str: |
| """BigQuery requires column names start with either a letter or |
| underscore. To make sure this is always the case, an underscore is prefixed |
| to the md5 hash of the original label. |
| |
| :param label: expected expression label |
| :return: truncated label |
| """ |
| return "_" + md5_sha_from_str(label) |
| |
| @classmethod |
| def where_latest_partition( |
| cls, |
| database: Database, |
| table: Table, |
| query: Select, |
| columns: list[ResultSetColumnType] | None = None, |
| ) -> Select | None: |
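| # Illustrative outcome, assuming a table partitioned on a column named "ds" |
| # whose latest partition id is "20240101" (name and value are made up): the |
| # query gains a predicate equivalent to ds = PARSE_DATE('%Y%m%d', '20240101') |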
| if partition_column := cls.get_time_partition_column(database, table): |
| max_partition_id = cls.get_max_partition_id(database, table) |
| query = query.where( |
| column(partition_column) == func.PARSE_DATE("%Y%m%d", max_partition_id) |
| ) |
| |
| return query |
| |
| @classmethod |
| def get_max_partition_id( |
| cls, |
| database: Database, |
| table: Table, |
| ) -> str | None: |
| # Compose schema from catalog and schema |
| schema_parts = [] |
| if table.catalog: |
| schema_parts.append(table.catalog) |
| if table.schema: |
| schema_parts.append(table.schema) |
| schema_parts.append("INFORMATION_SCHEMA") |
| schema = ".".join(schema_parts) |
| # Define a virtual table reference to INFORMATION_SCHEMA.PARTITIONS |
| partitions_table = sql_table( |
| "PARTITIONS", |
| sql_column("partition_id"), |
| sql_column("table_name"), |
| schema=schema, |
| ) |
| |
| # Build the query |
| query = select( |
| func.max(partitions_table.c.partition_id).label("max_partition_id") |
| ).where(partitions_table.c.table_name == table.table) |
| |
| # Compile to BigQuery SQL |
| compiled_query = query.compile( |
| dialect=database.get_dialect(), |
| compile_kwargs={"literal_binds": True}, |
| ) |
| |
| # Run the query and handle result |
| with database.get_raw_connection( |
| catalog=table.catalog, |
| schema=table.schema, |
| ) as conn: |
| cursor = conn.cursor() |
| cursor.execute(str(compiled_query)) |
| if row := cursor.fetchone(): |
| return row[0] |
| return None |
| |
| @classmethod |
| def get_time_partition_column( |
| cls, |
| database: Database, |
| table: Table, |
| ) -> str | None: |
| with cls.get_engine( |
| database, catalog=table.catalog, schema=table.schema |
| ) as engine: |
| client = cls._get_client(engine, database) |
| bq_table = client.get_table(f"{table.schema}.{table.table}") |
| |
| if bq_table.time_partitioning: |
| return bq_table.time_partitioning.field |
| return None |
| |
| @classmethod |
| def get_extra_table_metadata( |
| cls, |
| database: Database, |
| table: Table, |
| ) -> dict[str, Any]: |
| payload = {} |
| partition_column = cls.get_time_partition_column(database, table) |
| with cls.get_engine( |
| database, catalog=table.catalog, schema=table.schema |
| ) as engine: |
| if partition_column: |
| max_partition_id = cls.get_max_partition_id(database, table) |
| sql = cls.select_star( |
| database, |
| table, |
| engine, |
| indent=False, |
| show_cols=False, |
| latest_partition=True, |
| ) |
| payload.update( |
| { |
| "partitions": { |
| "cols": [partition_column], |
| "latest": {partition_column: max_partition_id}, |
| "partitionQuery": sql, |
| }, |
| "indexes": [ |
| { |
| "name": "partitioned", |
| "cols": [partition_column], |
| "type": "partitioned", |
| } |
| ], |
| } |
| ) |
| return payload |
| |
| @classmethod |
| def epoch_to_dttm(cls) -> str: |
| return "TIMESTAMP_SECONDS({col})" |
| |
| @classmethod |
| def epoch_ms_to_dttm(cls) -> str: |
| return "TIMESTAMP_MILLIS({col})" |
| |
| @classmethod |
| def df_to_sql( |
| cls, |
| database: Database, |
| table: Table, |
| df: pd.DataFrame, |
| to_sql_kwargs: dict[str, Any], |
| ) -> None: |
| """ |
| Upload data from a Pandas DataFrame to a database. |
| |
| Calls `pandas_gbq.to_gbq`, which requires `pandas_gbq` to be installed. |
| |
| Note this method does not create metadata for the table. |
| |
| :param database: The database to upload the data to |
| :param table: The table to upload the data to |
| :param df: The dataframe with data to be uploaded |
| :param to_sql_kwargs: The kwargs to be passed to the `pandas.DataFrame.to_sql` method |
| """ |
| if not can_upload: |
| raise SupersetException( |
| "Could not import libraries needed to upload data to BigQuery." |
| ) |
| |
| if not table.schema: |
| raise SupersetException("The table schema must be defined") |
| |
| to_gbq_kwargs = {} |
| with cls.get_engine( |
| database, |
| catalog=table.catalog, |
| schema=table.schema, |
| ) as engine: |
| to_gbq_kwargs = { |
| "destination_table": str(table), |
| "project_id": engine.url.host, |
| } |
| |
| # Add credentials if they are set on the SQLAlchemy dialect. |
| |
| if creds := engine.dialect.credentials_info: |
| to_gbq_kwargs["credentials"] = ( |
| service_account.Credentials.from_service_account_info(creds) |
| ) |
| |
| # Only pass through supported kwargs. |
| supported_kwarg_keys = {"if_exists"} |
| |
| for key in supported_kwarg_keys: |
| if key in to_sql_kwargs: |
| to_gbq_kwargs[key] = to_sql_kwargs[key] |
| |
| pandas_gbq.to_gbq(df, **to_gbq_kwargs) |
| |
| @classmethod |
| def _get_client( |
| cls, |
| engine: Engine, |
| database: Database, # pylint: disable=unused-argument |
| ) -> bigquery.Client: |
| """ |
| Return the BigQuery client associated with an engine. |
| """ |
| if not dependencies_installed: |
| raise SupersetException( |
| "Could not import libraries needed to connect to BigQuery." |
| ) |
| |
| if credentials_info := engine.dialect.credentials_info: |
| credentials = service_account.Credentials.from_service_account_info( |
| credentials_info |
| ) |
| return bigquery.Client(credentials=credentials) |
| |
| try: |
| credentials = google.auth.default()[0] |
| return bigquery.Client(credentials=credentials) |
| except google.auth.exceptions.DefaultCredentialsError as ex: |
| raise SupersetDBAPIConnectionError( |
| "The database credentials could not be found." |
| ) from ex |
| |
| @classmethod |
| def estimate_query_cost( # pylint: disable=too-many-arguments |
| cls, |
| database: Database, |
| catalog: str | None, |
| schema: str, |
| sql: str, |
| source: utils.QuerySource | None = None, |
| ) -> list[dict[str, Any]]: |
| """ |
| Estimate the cost of a multiple statement SQL query. |
| |
| :param database: Database instance |
| :param catalog: Database project |
| :param schema: Database schema |
| :param sql: SQL query with possibly multiple statements |
| :param source: Source of the query (eg, "sql_lab") |
| """ |
| extra = database.get_extra(source) or {} |
| if not cls.get_allow_cost_estimate(extra): |
| raise SupersetException("Database does not support cost estimation") |
| |
| parsed_script = SQLScript(sql, engine=cls.engine) |
| |
| with cls.get_engine( |
| database, |
| catalog=catalog, |
| schema=schema, |
| source=source, |
| ) as engine: |
| client = cls._get_client(engine, database) |
| return [ |
| cls.custom_estimate_statement_cost( |
| cls.process_statement(statement, database), |
| client, |
| ) |
| for statement in parsed_script.statements |
| ] |
| |
| @classmethod |
| def get_default_catalog(cls, database: Database) -> str: |
| """ |
| Get the default catalog. |
| """ |
| url = database.url_object |
| |
| # The SQLAlchemy driver accepts both `bigquery://project` (where the project is |
| # technically a host) and `bigquery:///project` (where it's a database). But |
| # both can be missing, and the project is inferred from the authentication |
| # credentials. |
| if project := url.host or url.database: |
| return project |
| |
| with database.get_sqla_engine() as engine: |
| client = cls._get_client(engine, database) |
| return client.project |
| |
| @classmethod |
| def get_catalog_names( |
| cls, |
| database: Database, |
| inspector: Inspector, |
| ) -> set[str]: |
| """ |
| Get all catalogs. |
| |
| In BigQuery, a catalog is called a "project". |
| """ |
| engine: Engine |
| with database.get_sqla_engine() as engine: |
| try: |
| client = cls._get_client(engine, database) |
| except SupersetDBAPIConnectionError: |
| logger.warning( |
| "Could not connect to database to get catalogs due to missing " |
| "credentials. This is normal in certain circustances, for example, " |
| "doing an import." |
| ) |
| # return {} here, since it will be repopulated when creds are added |
| return set() |
| |
| projects = client.list_projects() |
| |
| return {project.project_id for project in projects} |
| |
| @classmethod |
| def adjust_engine_params( |
| cls, |
| uri: URL, |
| connect_args: dict[str, Any], |
| catalog: str | None = None, |
| schema: str | None = None, |
| ) -> tuple[URL, dict[str, Any]]: |
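| # Illustrative: with catalog "other-project" (name made up), a URI such as |
| # bigquery://original-project becomes bigquery://other-project |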
| if catalog: |
| uri = uri.set(host=catalog, database="") |
| |
| return uri, connect_args |
| |
| @classmethod |
| def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: |
| return True |
| |
| @classmethod |
| def custom_estimate_statement_cost( |
| cls, |
| statement: str, |
| client: bigquery.Client, |
| ) -> dict[str, Any]: |
| """ |
| Custom version that receives a client instead of a cursor. |
| """ |
| job_config = bigquery.QueryJobConfig(dry_run=True) |
| query_job = client.query(statement, job_config=job_config) |
| |
| # Format bytes. |
| # TODO: if more db engine specs need byte humanization, this should be |
| # extracted into a shared helper function outside this scope. |
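| # Illustrative (values assumed): 1_500_000 bytes processed is reported |
| # as {"MB Processed": 1.43} |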
| byte_division = 1024 |
| if hasattr(query_job, "total_bytes_processed"): |
| query_bytes_processed = query_job.total_bytes_processed |
| if query_bytes_processed // byte_division == 0: |
| byte_type = "B" |
| total_bytes_processed = query_bytes_processed |
| elif query_bytes_processed // (byte_division**2) == 0: |
| byte_type = "KB" |
| total_bytes_processed = round(query_bytes_processed / byte_division, 2) |
| elif query_bytes_processed // (byte_division**3) == 0: |
| byte_type = "MB" |
| total_bytes_processed = round( |
| query_bytes_processed / (byte_division**2), 2 |
| ) |
| else: |
| byte_type = "GB" |
| total_bytes_processed = round( |
| query_bytes_processed / (byte_division**3), 2 |
| ) |
| |
| return {f"{byte_type} Processed": total_bytes_processed} |
| return {} |
| |
| @classmethod |
| def query_cost_formatter( |
| cls, raw_cost: list[dict[str, Any]] |
| ) -> list[dict[str, str]]: |
| return [{k: str(v) for k, v in row.items()} for row in raw_cost] |
| |
| @classmethod |
| def build_sqlalchemy_uri( |
| cls, |
| parameters: BigQueryParametersType, |
| encrypted_extra: dict[str, Any] | None = None, |
| ) -> str: |
| query = parameters.get("query", {}) |
| query_params = urllib.parse.urlencode(query) |
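| # Illustrative result (values made up): with project_id "my-project" and |
| # query {"location": "US"}, this returns "bigquery://my-project/?location=US" |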
| |
| if not encrypted_extra: |
| raise ValidationError("Missing service credentials") |
| |
| credentials_info = encrypted_extra.get("credentials_info") |
| if isinstance(credentials_info, str): |
| credentials_info = json.loads(credentials_info) |
| project_id = credentials_info.get("project_id") |
| |
| if project_id: |
| return f"{cls.default_driver}://{project_id}/?{query_params}" |
| |
| raise ValidationError("Invalid service credentials") |
| |
| @classmethod |
| def get_parameters_from_uri( |
| cls, |
| uri: str, |
| encrypted_extra: dict[str, Any] | None = None, |
| ) -> Any: |
| value = make_url_safe(uri) |
| |
| # Building parameters from encrypted_extra and uri |
| if encrypted_extra: |
| # ``value.query`` needs to be explicitly converted into a dict (from an |
| # ``immutabledict``) so that it can be JSON serialized |
| return {**encrypted_extra, "query": dict(value.query)} |
| |
| raise ValidationError("Invalid service credentials") |
| |
| @classmethod |
| def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: |
| # pylint: disable=import-outside-toplevel |
| from google.auth.exceptions import DefaultCredentialsError |
| |
| return {DefaultCredentialsError: SupersetDBAPIConnectionError} |
| |
| @classmethod |
| def validate_parameters( |
| cls, |
| properties: BasicPropertiesType, # pylint: disable=unused-argument |
| ) -> list[SupersetError]: |
| return [] |
| |
| @classmethod |
| def parameters_json_schema(cls) -> Any: |
| """ |
| Return configuration parameters as OpenAPI. |
| """ |
| if not cls.parameters_schema: |
| return None |
| |
| spec = APISpec( |
| title="Database Parameters", |
| version="1.0.0", |
| openapi_version="3.0.0", |
| plugins=[ma_plugin], |
| ) |
| |
| ma_plugin.init_spec(spec) |
| ma_plugin.converter.add_attribute_function(encrypted_field_properties) |
| spec.components.schema(cls.__name__, schema=cls.parameters_schema) |
| return spec.to_dict()["components"]["schemas"][cls.__name__] |
| |
| @classmethod |
| def select_star( # pylint: disable=too-many-arguments |
| cls, |
| database: Database, |
| table: Table, |
| dialect: Dialect, |
| limit: int = 100, |
| show_cols: bool = False, |
| indent: bool = True, |
| latest_partition: bool = True, |
| cols: list[ResultSetColumnType] | None = None, |
| ) -> str: |
| """ |
| Remove array structures from `SELECT *`. |
| |
| BigQuery supports structures and arrays of structures, eg: |
| |
| author STRUCT<name STRING, email STRING> |
| trailer ARRAY<STRUCT<key STRING, value STRING>> |
| |
| When loading metadata for a table each key in the struct is displayed as a |
| separate pseudo-column, eg: |
| |
| - author |
| - author.name |
| - author.email |
| - trailer |
| - trailer.key |
| - trailer.value |
| |
| When generating the `SELECT *` statement we want to remove any keys from |
| structs inside an array, since selecting them results in an error. The correct |
| select statement should look like this: |
| |
| SELECT |
| `author`, |
| `author`.`name`, |
| `author`.`email`, |
| `trailer` |
| FROM |
| table |
| |
| Unlike `author.name`, selecting `trailer.key` or `trailer.value` results in |
| an error, since they are keys in a struct inside an array. |
| |
| This method removes any array pseudo-columns. |
| """ |
| if cols: |
| # For arrays of structs, remove the child columns, otherwise the query |
| # will fail. |
| array_prefixes = { |
| col["column_name"] |
| for col in cols |
| if isinstance(col["type"], sqltypes.ARRAY) |
| } |
| cols = [ |
| col |
| for col in cols |
| if "." not in col["column_name"] |
| or col["column_name"].split(".")[0] not in array_prefixes |
| ] |
| |
| return super().select_star( |
| database, |
| table, |
| dialect, |
| limit, |
| show_cols, |
| indent, |
| latest_partition, |
| cols, |
| ) |
| |
| @classmethod |
| def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: |
| """ |
| Label columns using their fully qualified name. |
| |
| BigQuery supports columns of type `struct`, which are basically dictionaries. |
| When loading metadata for a table with struct columns, each key in the struct |
| is displayed as a separate pseudo-column, eg: |
| |
| author STRUCT<name STRING, email STRING> |
| |
| Will be shown as 3 columns: |
| |
| - author |
| - author.name |
| - author.email |
| |
| If we select those fields: |
| |
| SELECT `author`, `author`.`name`, `author`.`email` FROM table |
| |
| The resulting columns will be called "author", "name", and "email". This may |
| result in a clash with other columns. To prevent that, we explicitly label |
| the columns using their fully qualified name, so we end up with "author", |
| "author__name" and "author__email", respectively. |
| """ |
| return [ |
| column(c["column_name"]).label(c["column_name"].replace(".", "__")) |
| for c in cols |
| ] |
| |
| @classmethod |
| def parse_error_exception(cls, exception: Exception) -> Exception: |
| try: |
| return type(exception)(str(exception).splitlines()[0].strip()) |
| except Exception: # pylint: disable=broad-except |
| # If for some reason we get an exception (for example, there is no |
| # newline), return the original exception |
| return exception |
| |
| @classmethod |
| def get_materialized_view_names( |
| cls, |
| database: Database, |
| inspector: Inspector, |
| schema: str | None, |
| ) -> set[str]: |
| """ |
| Get all materialized views from BigQuery. |
| |
| BigQuery materialized views are not returned by the standard |
| get_view_names() method, so we need to query INFORMATION_SCHEMA directly. |
| """ |
| if not schema: |
| return set() |
| |
| # Construct the query to get materialized views from INFORMATION_SCHEMA |
| if catalog := database.get_default_catalog(): |
| information_schema = f"`{catalog}.{schema}.INFORMATION_SCHEMA.TABLES`" |
| else: |
| information_schema = f"`{schema}.INFORMATION_SCHEMA.TABLES`" |
| |
| # Use string formatting for the table name since it's not user input |
| # The catalog and schema are from trusted sources (database configuration) |
| query = f""" |
| SELECT table_name |
| FROM {information_schema} |
| WHERE table_type = 'MATERIALIZED VIEW' |
| """ # noqa: S608 |
| |
| materialized_views = set() |
| try: |
| with database.get_raw_connection(catalog=catalog, schema=schema) as conn: |
| cursor = conn.cursor() |
| cursor.execute(query) |
| materialized_views = {row[0] for row in cursor.fetchall()} |
| except Exception: |
| # If we can't fetch materialized views, return empty set |
| logger.warning( |
| "Unable to fetch materialized views for schema %s", |
| schema, |
| exc_info=True, |
| ) |
| |
| return materialized_views |
| |
| @classmethod |
| def get_view_names( |
| cls, |
| database: Database, |
| inspector: Inspector, |
| schema: str | None, |
| ) -> set[str]: |
| """ |
| Get all views from BigQuery, excluding materialized views. |
| |
| BigQuery's standard view discovery includes materialized views, |
| but we want to separate them for proper categorization. |
| """ |
| if not schema: |
| return set() |
| |
| # Construct the query to get regular views from INFORMATION_SCHEMA |
| catalog = database.get_default_catalog() |
| if catalog: |
| information_schema = f"`{catalog}.{schema}.INFORMATION_SCHEMA.TABLES`" |
| else: |
| information_schema = f"`{schema}.INFORMATION_SCHEMA.TABLES`" |
| |
| # Use string formatting for the table name since it's not user input |
| # The catalog and schema are from trusted sources (database configuration) |
| query = f""" |
| SELECT table_name |
| FROM {information_schema} |
| WHERE table_type = 'VIEW' |
| """ # noqa: S608 |
| |
| views = set() |
| try: |
| with database.get_raw_connection(catalog=catalog, schema=schema) as conn: |
| cursor = conn.cursor() |
| cursor.execute(query) |
| views = {row[0] for row in cursor.fetchall()} |
| except Exception: |
| # If we can't fetch views, fall back to the default implementation |
| logger.warning( |
| "Unable to fetch views for schema %s, falling back to default", |
| schema, |
| exc_info=True, |
| ) |
| return super().get_view_names(database, inspector, schema) |
| |
| return views |