blob: a2fc02fe751fb85c85689aef96f21b15c01031a0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
from datetime import datetime, timedelta
import json
import random
import string
import pytest
import prison
from sqlalchemy.sql import func
import tests.test_app
from superset import db, security_manager
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database, QueryStatus
from superset.models.sql_lab import Query
from tests.base_tests import SupersetTestCase
# Total number of Query rows inserted by the ``create_queries`` fixture
# (QUERIES_FIXTURE_COUNT - 1 admin queries plus one alpha query); the list
# endpoint tests assert the returned ``count`` against this value.
QUERIES_FIXTURE_COUNT = 10
class TestQueryApi(SupersetTestCase):
    """Tests for the ``api/v1/query`` REST endpoints."""

    def insert_query(
        self,
        database_id: int,
        user_id: int,
        client_id: str,
        sql: str = "",
        select_sql: str = "",
        executed_sql: str = "",
        limit: int = 100,
        progress: int = 100,
        rows: int = 100,
        tab_name: str = "",
        status: str = "success",
    ) -> Query:
        """
        Create, commit and return a ``Query`` row for the given database and user.

        ``changed_on`` is pinned to 2020-01-01 so that date-range filters on
        ``changed_on`` match deterministically.
        """
        database = db.session.query(Database).get(database_id)
        user = db.session.query(security_manager.user_model).get(user_id)
        query = Query(
            database=database,
            user=user,
            client_id=client_id,
            sql=sql,
            select_sql=select_sql,
            executed_sql=executed_sql,
            limit=limit,
            progress=progress,
            rows=rows,
            tab_name=tab_name,
            status=status,
            changed_on=datetime(2020, 1, 1),
        )
        db.session.add(query)
        db.session.commit()
        return query

    @pytest.fixture()
    def create_queries(self):
        """
        Insert ``QUERIES_FIXTURE_COUNT`` queries and delete them on teardown.

        All but the last query are run by ``admin`` on the example database,
        alternating between success and running status; the last one is run by
        ``alpha`` on the main database.
        """
        with self.create_app().app_context():
            queries = []
            admin_id = self.get_user("admin").id
            alpha_id = self.get_user("alpha").id
            example_database_id = get_example_database().id
            main_database_id = get_main_database().id

            for cx in range(QUERIES_FIXTURE_COUNT - 1):
                queries.append(
                    self.insert_query(
                        example_database_id,
                        admin_id,
                        self.get_random_string(),
                        sql=f"SELECT col1, col2 from table{cx}",
                        rows=cx,
                        status=QueryStatus.SUCCESS
                        if (cx % 2) == 0
                        else QueryStatus.RUNNING,
                    )
                )
            queries.append(
                self.insert_query(
                    main_database_id,
                    alpha_id,
                    self.get_random_string(),
                    sql=f"SELECT col1, col2 from table{QUERIES_FIXTURE_COUNT}",
                    rows=QUERIES_FIXTURE_COUNT,
                    status=QueryStatus.SUCCESS,
                )
            )

            yield queries

            # rollback changes (pytest runs this teardown even if the test fails)
            for query in queries:
                db.session.delete(query)
            db.session.commit()

    @staticmethod
    def get_random_string(length: int = 10) -> str:
        """Return a random string of ``length`` ASCII letters (used as ``client_id``)."""
        letters = string.ascii_letters
        return "".join(random.choice(letters) for _ in range(length))

    def test_get_query(self):
        """
        Query API: Test get query
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        example_db = get_example_database()
        query = self.insert_query(
            example_db.id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
            select_sql="SELECT col1, col2 from table1",
            executed_sql="SELECT col1, col2 from table1 LIMIT 100",
        )
        try:
            self.login(username="admin")
            uri = f"api/v1/query/{query.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)

            expected_result = {
                "database": {"id": example_db.id},
                "client_id": client_id,
                "end_result_backend_time": None,
                "error_message": None,
                "executed_sql": "SELECT col1, col2 from table1 LIMIT 100",
                "limit": 100,
                "progress": 100,
                "results_key": None,
                "rows": 100,
                "schema": None,
                "select_as_cta": None,
                "select_as_cta_used": False,
                "select_sql": "SELECT col1, col2 from table1",
                "sql": "SELECT col1, col2 from table1",
                "sql_editor_id": None,
                "status": "success",
                "tab_name": "",
                "tmp_schema_name": None,
                "tmp_table_name": None,
                "tracking_url": None,
            }
            # Timestamps and the generated id are not deterministic.
            unasserted_keys = {
                "changed_on",
                "end_time",
                "start_running_time",
                "start_time",
                "id",
            }
            data = json.loads(rv.data.decode("utf-8"))
            self.assertIn("changed_on", data["result"])
            for key, value in data["result"].items():
                if key in unasserted_keys:
                    continue
                # Fail with a readable assertion (not a KeyError) if the API
                # returns a field this test does not know about.
                self.assertIn(key, expected_result)
                self.assertEqual(value, expected_result[key])
        finally:
            # rollback changes
            db.session.delete(query)
            db.session.commit()

    def test_get_query_not_found(self):
        """
        Query API: Test get query not found
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(get_example_database().id, admin.id, client_id)
        try:
            # The query inserted above guarantees MAX(id) is not NULL, so
            # max_id + 1 is an id that does not exist.
            max_id = db.session.query(func.max(Query.id)).scalar()
            self.login(username="admin")
            uri = f"api/v1/query/{max_id + 1}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 404)
        finally:
            # rollback changes
            db.session.delete(query)
            db.session.commit()

    def test_get_query_no_data_access(self):
        """
        Query API: Test get query without data access
        """
        gamma1 = self.create_user(
            "gamma_1", "password", "Gamma", email="gamma1@superset.org"
        )
        gamma2 = self.create_user(
            "gamma_2", "password", "Gamma", email="gamma2@superset.org"
        )
        example_db_id = get_example_database().id
        gamma1_client_id = self.get_random_string()
        gamma2_client_id = self.get_random_string()
        query_gamma1 = self.insert_query(example_db_id, gamma1.id, gamma1_client_id)
        query_gamma2 = self.insert_query(example_db_id, gamma2.id, gamma2_client_id)
        try:
            # gamma_1 only sees their own queries
            self.login(username="gamma_1", password="password")
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 404)
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)

            # gamma_2 only sees their own queries
            self.logout()
            self.login(username="gamma_2", password="password")
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 404)
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)

            # Admins have the "all query access" permission
            self.logout()
            self.login(username="admin")
            uri = f"api/v1/query/{query_gamma1.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)
            uri = f"api/v1/query/{query_gamma2.id}"
            rv = self.client.get(uri)
            self.assertEqual(rv.status_code, 200)
        finally:
            # rollback changes; always runs so the gamma_1/gamma_2 users do not
            # survive a failed run and break create_user on the next one
            db.session.delete(query_gamma1)
            db.session.delete(query_gamma2)
            db.session.delete(gamma1)
            db.session.delete(gamma2)
            db.session.commit()

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query(self):
        """
        Query API: Test get list query
        """
        self.login(username="admin")
        uri = "api/v1/query/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == QUERIES_FIXTURE_COUNT
        # check expected columns
        first_result = data["result"][0]
        assert sorted(first_result) == [
            "changed_on",
            "database",
            "end_time",
            "executed_sql",
            "id",
            "rows",
            "schema",
            "sql",
            "sql_tables",
            "start_time",
            "status",
            "tab_name",
            "tmp_table_name",
            "tracking_url",
            "user",
        ]
        assert sorted(first_result["user"]) == [
            "first_name",
            "id",
            "last_name",
            "username",
        ]
        assert list(first_result["database"]) == [
            "database_name",
        ]

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_sql(self):
        """
        Query API: Test get list query filter sql
        """
        self.login(username="admin")
        arguments = {"filters": [{"col": "sql", "opr": "ct", "value": "table2"}]}
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_database(self):
        """
        Query API: Test get list query filter database
        """
        self.login(username="admin")
        database_id = get_main_database().id
        arguments = {
            "filters": [{"col": "database", "opr": "rel_o_m", "value": database_id}]
        }
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_user(self):
        """
        Query API: Test get list query filter user
        """
        self.login(username="admin")
        alpha_id = self.get_user("alpha").id
        arguments = {"filters": [{"col": "user", "opr": "rel_o_m", "value": alpha_id}]}
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == 1

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_filter_changed_on(self):
        """
        Query API: Test get list query filter changed_on
        """
        self.login(username="admin")
        # insert_query pins changed_on to 2020-01-01, inside this window.
        arguments = {
            "filters": [
                {"col": "changed_on", "opr": "lt", "value": "2020-02-01T00:00:00Z"},
                {"col": "changed_on", "opr": "gt", "value": "2019-12-30T00:00:00Z"},
            ]
        }
        uri = f"api/v1/query/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == QUERIES_FIXTURE_COUNT

    @pytest.mark.usefixtures("create_queries")
    def test_get_list_query_order(self):
        """
        Query API: Test get list query order
        """
        self.login(username="admin")
        order_columns = [
            "changed_on",
            "database.database_name",
            "rows",
            "schema",
            "sql",
            "tab_name",
            "user.first_name",
        ]
        for order_column in order_columns:
            arguments = {"order_column": order_column, "order_direction": "asc"}
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200

    def test_get_list_query_no_data_access(self):
        """
        Query API: Test get queries no data access
        """
        admin = self.get_user("admin")
        client_id = self.get_random_string()
        query = self.insert_query(
            get_example_database().id,
            admin.id,
            client_id,
            sql="SELECT col1, col2 from table1",
        )
        try:
            # A gamma user must not see the admin's query in the list.
            self.login(username="gamma")
            arguments = {
                "filters": [{"col": "sql", "opr": "sw", "value": "SELECT col1"}]
            }
            uri = f"api/v1/query/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            assert rv.status_code == 200
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # rollback changes
            db.session.delete(query)
            db.session.commit()