| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import re |
| import threading |
| from re import Pattern |
| from typing import Any, Callable, List, NamedTuple, Optional |
| |
| from flask_babel import gettext as __ |
| from sqlalchemy.engine.reflection import Inspector |
| from sqlalchemy.orm import Session |
| |
# Need to try-catch here because pyocient may not be installed.
# The GIS helpers used further down (geojson, shapely's wkt) are imported in
# the same guarded block: if one import fails, the names imported by the
# statements after it are never bound.
try:
    import geojson
    import pyocient
    from shapely import wkt

    from superset import app

    # Ensure pyocient inherits Superset's logging level
    superset_log_level = app.config["LOG_LEVEL"]
    pyocient.logger.setLevel(superset_log_level)
except (ImportError, RuntimeError):
    # Best effort only: a failure here leaves the module importable.
    # NOTE(review): RuntimeError presumably covers accessing `app` when no
    # Superset application is available -- confirm.
    pass
| |
| from superset.constants import TimeGrain |
| from superset.db_engine_specs.base import BaseEngineSpec |
| from superset.errors import SupersetErrorType |
| from superset.models.core import Database |
| from superset.models.sql_lab import Query |
| |
# Regular expressions to catch custom errors.
#
# Each pattern targets one error message; the named groups are the values
# that get interpolated into the user-facing message paired with the pattern
# in OcientEngineSpec.custom_errors.

# Captures the unknown user name as `username`.
CONNECTION_INVALID_USERNAME_REGEX = re.compile(
    r"The referenced user does not exist \(User '(?P<username>.*?)' not found\)"
)
# No groups: the message carries no variable part.
CONNECTION_INVALID_PASSWORD_REGEX = re.compile(
    r"The userid/password combination was not valid \(Incorrect password for user\)"
)
# Captures `host` (everything up to the first colon).
# NOTE(review): the trailing lazy `port` group is followed by nothing, so it
# always captures the empty string.
CONNECTION_INVALID_HOSTNAME_REGEX = re.compile(
    r"Unable to connect to (?P<host>.*?):(?P<port>.*?)"
)
# Captures the missing database name as `database`.
CONNECTION_UNKNOWN_DATABASE_REGEX = re.compile(
    r"No database named '(?P<database>.*?)' exists"
)
# Literal match; despite the `_ERROR` suffix this is a compiled pattern like
# the other constants in this group.
CONNECTION_INVALID_PORT_ERROR = re.compile("Port out of range 0-65535")
# The two adjacent raw strings are concatenated into a single pattern.
INVALID_CONNECTION_STRING_REGEX = re.compile(
    r"An invalid connection string attribute was specified"
    r" \(failed to decrypt cipher text\)"
)
# Captures `qualifier`, the offending `input` and the `expected` tokens.
SYNTAX_ERROR_REGEX = re.compile(
    r"There is a syntax error in your statement \((?P<qualifier>.*?)"
    r" input '(?P<input>.*?)' expecting (?P<expected>.*?)\)"
)
# Captures the missing table or view name as `table`.
TABLE_DOES_NOT_EXIST_REGEX = re.compile(
    r"The referenced table or view '(?P<table>.*?)' does not exist"
)
# Captures the invalid column name as `column`.
COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
    r"The reference to column '(?P<column>.*?)' is not valid"
)
| |
| |
| # Custom datatype conversion functions |
| |
| |
| def _to_hex(data: bytes) -> str: |
| """ |
| Converts the bytes object into a string of hexadecimal digits. |
| |
| :param data: the bytes object |
| :returns: string of hexadecimal digits representing the bytes |
| """ |
| return data.hex() |
| |
| |
def _wkt_to_geo_json(geo_as_wkt: str) -> Any:
    """
    Builds a geoJSON Feature from a geometry given in WKT format.

    :param geo_as_wkt: the GIS object in WKT format
    :returns: a geoJSON Feature wrapping the parsed geometry, with empty
        properties
    """
    # `wkt` (shapely) and `geojson` come from the guarded import at the top
    # of the module; they are only bound when those packages are installed.
    return geojson.Feature(geometry=wkt.loads(geo_as_wkt), properties={})
| |
| |
| def _point_list_to_wkt( |
| points, # type: List[pyocient._STPoint] |
| ) -> str: |
| """ |
| Converts the list of pyocient._STPoint elements to a WKT LineString. |
| |
| :param points: the list of pyocient._STPoint objects |
| :returns: WKT LineString |
| """ |
| coords = [f"{p.long} {p.lat}" for p in points] |
| return f"LINESTRING({', '.join(coords)})" |
| |
| |
def _point_to_geo_json(
    point,  # type: pyocient._STPoint
) -> Any:
    """
    Converts the pyocient._STPoint object to the geoJSON format.

    :param point: the pyocient._STPoint instance
    :returns: the geoJSON encoding of this point
    """
    # The string form of the point is handed to the WKT parser as-is.
    return _wkt_to_geo_json(str(point))
| |
| |
def _linestring_to_geo_json(
    linestring,  # type: pyocient._STLinestring
) -> Any:
    """
    Converts the pyocient._STLinestring object to geoJSON, the GIS format
    consumed by the Superset visualization toolkit (powered by Deck.gl).

    :param linestring: the pyocient._STLinestring instance
    :returns: the geoJSON of this linestring, or of its only point when the
        linestring holds exactly one point
    """
    points = linestring.points
    if len(points) != 1:
        # Regular case: the string form of the linestring is parsed as WKT.
        return _wkt_to_geo_json(str(linestring))

    # A one-point linestring is technically invalid, but Ocient permits
    # ST_LINESTRING containers to contain a single point. The flexibility
    # allows the database to encode geometry collections as an array of the
    # highest dimensional element in the collection (i.e. ST_LINESTRING[] or
    # ST_POLYGON[]). Such a value is emitted as a point.
    return _point_to_geo_json(points[0])
| |
| |
def _polygon_to_geo_json(
    polygon,  # type: pyocient._STPolygon
) -> Any:
    """
    Converts the pyocient._STPolygon object to geoJSON, the GIS format
    consumed by the Superset visualization toolkit (powered by Deck.gl).

    A polygon without holes whose exterior ring is a single point, or whose
    exterior ring is not closed, is emitted as a point or a linestring
    respectively instead of a polygon.

    :param polygon: the pyocient._STPolygon instance
    :returns: the geoJSON encoding of this polygon
    """
    exterior = polygon.exterior
    if len(exterior) == 0 or len(polygon.holes) != 0:
        # Empty exterior ring, or holes present: always treated as a polygon.
        return _wkt_to_geo_json(str(polygon))

    if len(exterior) == 1:
        # The exterior ring contains a single ST_POINT
        return _point_to_geo_json(exterior[0])

    if exterior[0] != exterior[-1]:
        # The exterior ring contains an open ST_LINESTRING
        return _wkt_to_geo_json(_point_list_to_wkt(exterior))

    # Closed exterior ring: this is a valid ST_POLYGON
    return _wkt_to_geo_json(str(polygon))
| |
| |
# Signature shared by all sanitization functions: takes a raw column value
# and returns the sanitized replacement.
SanitizeFunc = Callable[[Any], Any]


class PlacedSanitizeFunc(NamedTuple):
    """
    Pairs a column index with the sanitization function to apply to the
    values of that column.
    """

    # Index of the column within the cursor description / result row
    column_index: int
    # Function applied to every value of that column
    sanitize_func: SanitizeFunc
| |
| |
| # This map contains functions used to sanitize values for column types |
| # that cannot be processed natively by Superset. |
| # |
| # Superset serializes temporal objects using a custom serializer |
| # defined in superset/utils/core.py (#json_int_dttm_ser(...)). Other |
| # are serialized by the default JSON encoder. |
| # |
| # Need to try-catch here because pyocient may not be installed |
| try: |
| from pyocient import TypeCodes |
| |
| _sanitized_ocient_type_codes: dict[int, SanitizeFunc] = { |
| TypeCodes.BINARY: _to_hex, |
| TypeCodes.ST_POINT: _point_to_geo_json, |
| TypeCodes.IP: str, |
| TypeCodes.IPV4: str, |
| TypeCodes.ST_LINESTRING: _linestring_to_geo_json, |
| TypeCodes.ST_POLYGON: _polygon_to_geo_json, |
| } |
| except ImportError as e: |
| _sanitized_ocient_type_codes = {} |
| |
| |
def _find_columns_to_sanitize(cursor: Any) -> list[PlacedSanitizeFunc]:
    """
    Determines which columns of the result set need their values sanitized.

    The type code of every column (second item of each `cursor.description`
    entry) is looked up in `_sanitized_ocient_type_codes`; columns whose type
    code is not in that map are skipped.

    :param cursor: the result set cursor
    :returns: the list of tuples consisting of the column index and
        sanitization function, in column order; empty when no column needs
        sanitization or when the cursor has no description
    """
    placed: list[PlacedSanitizeFunc] = []
    # PEP 249 allows `description` to be None when the last operation did not
    # return rows; treat that the same as "no columns".
    for column_index, column in enumerate(cursor.description or ()):
        sanitize_func = _sanitized_ocient_type_codes.get(column[1])
        if sanitize_func is not None:
            placed.append(PlacedSanitizeFunc(column_index, sanitize_func))
    return placed
| |
| |
class OcientEngineSpec(BaseEngineSpec):
    """
    Superset engine spec for the Ocient database.

    Besides the static engine settings, this spec translates Ocient error
    messages into Superset errors, sanitizes result values that Superset
    cannot serialize natively, and keeps track of the Ocient query id of
    running queries so that they can be cancelled.
    """

    engine = "ocient"
    engine_name = "Ocient"
    # limit_method = LimitMethod.WRAP_SQL
    force_column_alias_quotes = True
    max_column_name_length = 30

    allows_cte_in_subquery = False
    # Ocient does not support cte names starting with underscores
    cte_alias = "cte__"
    # Store mapping of superset Query id -> Ocient ID
    # Entries are inserted by handle_cursor() and removed again by
    # fetch_data() (query finished or failed) or cancel_query() (cancelled).
    # The mapping is shared by the whole process, hence the lock.
    # NOTE(review): keys are whatever `query.id` is and values whatever
    # `cursor.query_id` is -- confirm both match the `str` annotation.
    query_id_mapping: dict[str, str] = {}
    query_id_mapping_lock = threading.Lock()

    # Maps an error-message pattern to the (message template, error type,
    # extra payload) reported to the user. The named groups of the pattern
    # feed the %(name)s placeholders of the template.
    custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
        CONNECTION_INVALID_USERNAME_REGEX: (
            __('The username "%(username)s" does not exist.'),
            SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR,
            {},
        ),
        CONNECTION_INVALID_PASSWORD_REGEX: (
            __(
                "The user/password combination is not valid"
                " (Incorrect password for user)."
            ),
            SupersetErrorType.CONNECTION_INVALID_PASSWORD_ERROR,
            {},
        ),
        CONNECTION_UNKNOWN_DATABASE_REGEX: (
            __('Could not connect to database: "%(database)s"'),
            SupersetErrorType.CONNECTION_UNKNOWN_DATABASE_ERROR,
            {},
        ),
        CONNECTION_INVALID_HOSTNAME_REGEX: (
            __('Could not resolve hostname: "%(host)s".'),
            SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
            {},
        ),
        CONNECTION_INVALID_PORT_ERROR: (
            __("Port out of range 0-65535"),
            SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
            {},
        ),
        INVALID_CONNECTION_STRING_REGEX: (
            __(
                "Invalid Connection String: Expecting String of"
                " the form 'ocient://user:pass@host:port/database'."
            ),
            SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
            {},
        ),
        # NOTE(review): the closing double quote after %(expected)s is
        # missing. Left untouched because the string is also the gettext
        # lookup key, so changing it would orphan existing translations.
        SYNTAX_ERROR_REGEX: (
            __('Syntax Error: %(qualifier)s input "%(input)s" expecting "%(expected)s'),
            SupersetErrorType.SYNTAX_ERROR,
            {},
        ),
        TABLE_DOES_NOT_EXIST_REGEX: (
            __('Table or View "%(table)s" does not exist.'),
            SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR,
            {},
        ),
        COLUMN_DOES_NOT_EXIST_REGEX: (
            __('Invalid reference to column: "%(column)s"'),
            SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
            {},
        ),
    }
    # SQL template per time grain; `{col}` is the placeholder for the
    # temporal column expression.
    _time_grain_expressions = {
        None: "{col}",
        TimeGrain.SECOND: "ROUND({col}, 'SECOND')",
        TimeGrain.MINUTE: "ROUND({col}, 'MINUTE')",
        TimeGrain.HOUR: "ROUND({col}, 'HOUR')",
        TimeGrain.DAY: "ROUND({col}, 'DAY')",
        TimeGrain.WEEK: "ROUND({col}, 'WEEK')",
        TimeGrain.MONTH: "ROUND({col}, 'MONTH')",
        TimeGrain.QUARTER_YEAR: "ROUND({col}, 'QUARTER')",
        TimeGrain.YEAR: "ROUND({col}, 'YEAR')",
    }

    @classmethod
    def get_table_names(
        cls, database: Database, inspector: Inspector, schema: Optional[str]
    ) -> set[str]:
        """
        Returns the names of the tables in `schema`.

        :param database: the database being inspected (unused here)
        :param inspector: the SQLAlchemy inspector bound to the database
        :param schema: the schema to list, or None
        :returns: the set of table names
        """
        # Inspector.get_table_names() returns a list; convert it so that the
        # result honours the declared `set[str]` return type.
        return set(inspector.get_table_names(schema))

    @classmethod
    def _forget_query_id(cls, cursor: Any) -> None:
        """
        Drops the query id mapping entry that handle_cursor() registered for
        `cursor`, if there is one.

        Never raises when the cursor was not registered or when the entry is
        already gone (e.g. removed by cancel_query()).

        :param cursor: the cursor whose query is no longer running
        """
        superset_query_id = getattr(cursor, "superset_query_id", None)
        if superset_query_id is None:
            # handle_cursor() was never called for this cursor.
            return
        with OcientEngineSpec.query_id_mapping_lock:
            OcientEngineSpec.query_id_mapping.pop(superset_query_id, None)

    @classmethod
    def fetch_data(
        cls, cursor: Any, limit: Optional[int] = None
    ) -> list[tuple[Any, ...]]:
        """
        Fetches the rows of the result set and sanitizes the values of the
        columns whose type Superset cannot process natively.

        Whether the fetch succeeds or fails, the query is finished afterwards,
        so its entry is removed from the query id mapping. A failure of the
        underlying fetch is propagated unchanged.

        :param cursor: the result set cursor
        :param limit: maximum number of rows to fetch, passed through to the
            base implementation
        :returns: the fetched rows, sanitized where required
        """
        try:
            rows: list[tuple[Any, ...]] = super().fetch_data(cursor, limit)
        finally:
            # Tolerant cleanup: the entry may already have been removed by
            # cancel_query(), and a cleanup error must not mask the real one.
            cls._forget_query_id(cursor)

        # Only rows of a type named "Row" are sanitized; anything else is
        # returned untouched.
        # TODO: Unsure if we need to verify that we are receiving rows.
        if not rows or type(rows[0]).__name__ != "Row":
            return rows

        # Peek at the schema to determine which column values, if any,
        # require sanitization.
        columns_to_sanitize: list[PlacedSanitizeFunc] = _find_columns_to_sanitize(
            cursor
        )
        if not columns_to_sanitize:
            return rows

        # At least 1 column has to be sanitized.

        def identity(x: Any) -> Any:
            return x

        # Use the identity function if the column type doesn't need to be
        # sanitized.
        sanitization_functions: list[SanitizeFunc] = [identity] * len(
            cursor.description
        )
        for info in columns_to_sanitize:
            sanitization_functions[info.column_index] = info.sanitize_func

        # pyocient returns a list of NamedTuple objects which represent a
        # single row. The rows have to be rebuilt because NamedTuples are
        # immutable.
        return [
            tuple(
                sanitize_func(val)
                for sanitize_func, val in zip(sanitization_functions, row)
            )
            for row in rows
        ]

    @classmethod
    def epoch_to_dttm(cls) -> str:
        """
        :returns: the SQL template converting `{col}`, an epoch in seconds,
            to a timestamp
        """
        return "DATEADD(S, {col}, '1970-01-01')"

    @classmethod
    def epoch_ms_to_dttm(cls) -> str:
        """
        :returns: the SQL template converting `{col}`, an epoch in
            milliseconds, to a timestamp
        """
        return "DATEADD(MS, {col}, '1970-01-01')"

    @classmethod
    def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
        """
        Returns a placeholder cancel id.

        The real Ocient query id is looked up in `query_id_mapping` by
        cancel_query(), so the value returned here is never used.

        :param cursor: the cursor running the query (unused)
        :param query: the Superset query (unused)
        :returns: a constant, non-None placeholder
        """
        # Return a Non-None value
        # If None is returned, Superset will not call cancel_query
        return "DUMMY_VALUE"

    @classmethod
    def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
        """
        Registers the Ocient query id of the running query so that it can be
        cancelled, then delegates to the base implementation.

        :param cursor: the cursor running the query
        :param query: the Superset query
        :param session: the database session, passed through to the base
            implementation
        """
        with OcientEngineSpec.query_id_mapping_lock:
            OcientEngineSpec.query_id_mapping[query.id] = cursor.query_id

        # Add the query id to the cursor so that fetch_data(), which only
        # receives the cursor, can find the mapping entry again.
        cursor.superset_query_id = query.id
        return super().handle_cursor(cursor, query, session)

    @classmethod
    def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
        """
        Cancels the query on the Ocient side if it is still registered.

        :param cursor: a cursor used to issue the CANCEL statement
        :param query: the Superset query to cancel
        :param cancel_query_id: the value from get_cancel_query_id() (unused)
        :returns: True if a CANCEL statement was issued, False if the query
            was not registered (any more)
        """
        with OcientEngineSpec.query_id_mapping_lock:
            if query.id in OcientEngineSpec.query_id_mapping:
                cursor.execute(f"CANCEL {OcientEngineSpec.query_id_mapping[query.id]}")
                # Query has been cancelled, so we can safely remove the cursor
                # from the cache
                del OcientEngineSpec.query_id_mapping[query.id]
                return True
            # If the query is not in the cache, it must have either been
            # cancelled elsewhere or completed
            return False