blob: 74c3715f28a9265ed00646a4c570c8fbd545ea7d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from sqlalchemy import column, table
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql import DATE, NTEXT, NVARCHAR, TEXT, VARCHAR
from sqlalchemy.sql import select
from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.utils.core import GenericDataType
from tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestMssqlEngineSpec(TestDbEngineSpec):
def test_mssql_column_types(self):
def assert_type(type_string, type_expected, generic_type_expected):
if type_expected is None:
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
self.assertIsNone(type_assigned)
else:
column_spec = MssqlEngineSpec.get_column_spec(type_string)
if column_spec != None:
self.assertIsInstance(column_spec.sqla_type, type_expected)
self.assertEquals(column_spec.generic_type, generic_type_expected)
assert_type("STRING", String, GenericDataType.STRING)
assert_type("CHAR(10)", String, GenericDataType.STRING)
assert_type("VARCHAR(10)", String, GenericDataType.STRING)
assert_type("TEXT", String, GenericDataType.STRING)
assert_type("NCHAR(10)", UnicodeText, GenericDataType.STRING)
assert_type("NVARCHAR(10)", UnicodeText, GenericDataType.STRING)
assert_type("NTEXT", UnicodeText, GenericDataType.STRING)
def test_where_clause_n_prefix(self):
dialect = mssql.dialect()
spec = MssqlEngineSpec
type_, _ = spec.get_sqla_column_type("VARCHAR(10)")
str_col = column("col", type_=type_)
type_, _ = spec.get_sqla_column_type("NTEXT")
unicode_col = column("unicode_col", type_=type_)
tbl = table("tbl")
sel = (
select([str_col, unicode_col])
.select_from(tbl)
.where(str_col == "abc")
.where(unicode_col == "abc")
)
query = str(
sel.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
)
query_expected = (
"SELECT col, unicode_col \n"
"FROM tbl \n"
"WHERE col = 'abc' AND unicode_col = N'abc'"
)
self.assertEqual(query, query_expected)
def test_time_exp_mixd_case_col_1y(self):
col = column("MixedCase")
expr = MssqlEngineSpec.get_timestamp_expr(col, None, "P1Y")
result = str(expr.compile(None, dialect=mssql.dialect()))
self.assertEqual(result, "DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)")
def test_convert_dttm(self):
dttm = self.get_dttm()
test_cases = (
(
MssqlEngineSpec.convert_dttm("DATE", dttm),
"CONVERT(DATE, '2019-01-02', 23)",
),
(
MssqlEngineSpec.convert_dttm("DATETIME", dttm),
"CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)",
),
(
MssqlEngineSpec.convert_dttm("SMALLDATETIME", dttm),
"CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",
),
)
for actual, expected in test_cases:
self.assertEqual(actual, expected)
def test_extract_error_message(self):
test_mssql_exception = Exception(
"(8155, b\"No column name was specified for column 1 of 'inner_qry'."
"DB-Lib error message 20018, severity 16:\\nGeneral SQL Server error: "
'Check messages from the SQL Server\\n")'
)
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
expected_message = (
"mssql error: All your SQL functions need to "
"have an alias on MSSQL. For example: SELECT COUNT(*) AS C1 FROM TABLE1"
)
self.assertEqual(expected_message, error_message)
test_mssql_exception = Exception(
'(8200, b"A correlated expression is invalid because it is not in a '
"GROUP BY clause.\\n\")'"
)
error_message = MssqlEngineSpec.extract_error_message(test_mssql_exception)
expected_message = "mssql error: " + MssqlEngineSpec._extract_error_message(
test_mssql_exception
)
self.assertEqual(expected_message, error_message)
@mock.patch.object(
MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted"
)
def test_fetch_data(self, mock_pyodbc_rows_to_tuples):
data = [(1, "foo")]
with mock.patch.object(
BaseEngineSpec, "fetch_data", return_value=data
) as mock_fetch:
result = MssqlEngineSpec.fetch_data(None, 0)
mock_pyodbc_rows_to_tuples.assert_called_once_with(data)
self.assertEqual(result, "converted")
def test_column_datatype_to_string(self):
test_cases = (
(DATE(), "DATE"),
(VARCHAR(length=255), "VARCHAR(255)"),
(VARCHAR(length=255, collation="utf8_general_ci"), "VARCHAR(255)"),
(NVARCHAR(length=128), "NVARCHAR(128)"),
(TEXT(), "TEXT"),
(NTEXT(collation="utf8_general_ci"), "NTEXT"),
)
for original, expected in test_cases:
actual = MssqlEngineSpec.column_datatype_to_string(
original, mssql.dialect()
)
self.assertEqual(actual, expected)