blob: 0f140ab11bede0e9c8f5fa39c56372642982f5e0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integration tests for datasource validate_expression API endpoint"""
from unittest.mock import patch
from superset.utils import json
from superset.utils.core import SqlExpressionType
from tests.integration_tests.base_tests import SupersetTestCase
# Note: Where a test depends on the validation outcome, that outcome is mocked, so the
# energy table fixture is not needed; requests target the birth_names example dataset.
class TestDatasourceValidateExpressionApi(SupersetTestCase):
    """Test the datasource validate_expression API endpoint"""

    # Model method patched by the tests that need to control the validation
    # outcome, so those tests do not depend on executing SQL against a database.
    _VALIDATE_TARGET = (
        "superset.connectors.sqla.models.SqlaTable.validate_expression"
    )

    def _post_validate_expression(
        self,
        datasource_id: int,
        payload: dict[str, str],
        datasource_type: str = "table",
    ):
        """POST ``payload`` to the validate_expression endpoint of a datasource.

        :param datasource_id: id placed in the URL (need not exist).
        :param payload: JSON body sent with the request.
        :param datasource_type: datasource type segment of the URL.
        :returns: the Flask test client response.
        """
        url = (
            f"/api/v1/datasource/{datasource_type}/{datasource_id}"
            "/validate_expression/"
        )
        return self.client.post(url, json=payload)

    def _birth_names_id(self) -> int:
        """Return the id of the birth_names example dataset used by these tests."""
        return self.get_birth_names_dataset().id

    def _assert_expression_valid(
        self, mock_validate, expression: str, expression_type: str
    ) -> None:
        """Check that an expression the (mocked) model accepts is reported as valid.

        :param mock_validate: the patched ``SqlaTable.validate_expression``.
        :param expression: SQL expression sent to the endpoint.
        :param expression_type: one of the ``SqlExpressionType`` values.
        """
        mock_validate.return_value = {"valid": True, "errors": []}
        rv = self._post_validate_expression(
            self._birth_names_id(),
            {"expression": expression, "expression_type": expression_type},
        )
        assert rv.status_code == 200
        data = json.loads(rv.data.decode("utf-8"))
        assert "result" in data
        assert data["result"] == []  # Empty array means success
        # Guards against the endpoint answering without consulting the model.
        mock_validate.assert_called_once()

    def _post_with_mocked_errors(
        self, errors: list, expression: str, expression_type: str
    ) -> list:
        """POST an expression while the model reports it invalid with ``errors``.

        :returns: the ``result`` list from the (HTTP 200) response body.
        """
        datasource_id = self._birth_names_id()
        with patch(self._VALIDATE_TARGET) as mock_validate:
            mock_validate.return_value = {"valid": False, "errors": errors}
            rv = self._post_validate_expression(
                datasource_id,
                {"expression": expression, "expression_type": expression_type},
            )
        # Validation failures are reported in the body, not via the status code.
        assert rv.status_code == 200
        mock_validate.assert_called_once()
        data = json.loads(rv.data.decode("utf-8"))
        assert "result" in data
        return data["result"]

    @patch(_VALIDATE_TARGET)
    def test_validate_expression_column_success(self, mock_validate) -> None:
        """Test successful validation of a column expression"""
        self.login("admin")
        self._assert_expression_valid(
            mock_validate, "test_col", SqlExpressionType.COLUMN.value
        )

    @patch(_VALIDATE_TARGET)
    def test_validate_expression_metric_success(self, mock_validate) -> None:
        """Test successful validation of a metric expression"""
        self.login("admin")
        self._assert_expression_valid(
            mock_validate, "SUM(amount)", SqlExpressionType.METRIC.value
        )

    @patch(_VALIDATE_TARGET)
    def test_validate_expression_where_success(self, mock_validate) -> None:
        """Test successful validation of a WHERE clause expression"""
        self.login("admin")
        self._assert_expression_valid(
            mock_validate, "status = 'active'", SqlExpressionType.WHERE.value
        )

    @patch(_VALIDATE_TARGET)
    def test_validate_expression_having_success(self, mock_validate) -> None:
        """Test successful validation of a HAVING clause expression"""
        self.login("admin")
        self._assert_expression_valid(
            mock_validate, "SUM(amount) > 100", SqlExpressionType.HAVING.value
        )

    def test_validate_expression_invalid_sql(self) -> None:
        """Test validation of invalid SQL expression"""
        self.login("admin")
        result = self._post_with_mocked_errors(
            [{"message": "Invalid SQL syntax"}],
            "INVALID SQL HERE",
            SqlExpressionType.COLUMN.value,
        )
        assert len(result) == 1
        assert result[0]["message"] == "Invalid SQL syntax"

    def test_validate_expression_having_with_non_aggregated_column(self) -> None:
        """Test that HAVING clause fails for non-aggregated columns"""
        self.login("admin")
        result = self._post_with_mocked_errors(
            [
                {
                    "message": (
                        "column 'source' must appear in the GROUP BY clause "
                        "or be used in an aggregate function"
                    )
                }
            ],
            "source = 'energy_source1'",
            SqlExpressionType.HAVING.value,
        )
        assert len(result) == 1
        assert "must appear in the GROUP BY clause" in result[0]["message"]

    def test_validate_expression_empty(self) -> None:
        """Test validation of empty expression"""
        self.login("admin")
        rv = self._post_validate_expression(
            self._birth_names_id(),
            {"expression": "", "expression_type": SqlExpressionType.COLUMN.value},
        )
        assert rv.status_code == 400  # Bad request for empty expression

    def test_validate_expression_missing_parameters(self) -> None:
        """Test validation with missing required parameters"""
        self.login("admin")
        datasource_id = self._birth_names_id()

        # Missing expression_type - defaults to "where", so the request succeeds.
        # The model call is mocked so the outcome does not depend on whether
        # ``test_col`` can be validated against the example database.
        with patch(self._VALIDATE_TARGET) as mock_validate:
            mock_validate.return_value = {"valid": True, "errors": []}
            rv = self._post_validate_expression(
                datasource_id, {"expression": "test_col"}
            )
        assert rv.status_code == 200
        mock_validate.assert_called_once()

        # Missing expression - this should fail
        rv = self._post_validate_expression(
            datasource_id, {"expression_type": SqlExpressionType.COLUMN.value}
        )
        assert rv.status_code == 400  # Missing expression is an error

    def test_validate_expression_datasource_not_found(self) -> None:
        """Test validation with non-existent datasource"""
        self.login("admin")
        rv = self._post_validate_expression(
            99999,
            {
                "expression": "test_col",
                "expression_type": SqlExpressionType.COLUMN.value,
            },
        )
        assert rv.status_code == 404

    def test_validate_expression_no_permission(self) -> None:
        """Test validation without permission to access datasource"""
        # Log in as the gamma user, which does not have admin privileges.
        self.login("gamma")
        rv = self._post_validate_expression(
            self._birth_names_id(),
            {
                "expression": "test_col",
                "expression_type": SqlExpressionType.COLUMN.value,
            },
        )
        # Should get 403 Forbidden or 404 if datasource is hidden
        assert rv.status_code in [403, 404]

    def test_validate_expression_invalid_datasource_type(self) -> None:
        """Test validation with invalid datasource type"""
        self.login("admin")
        rv = self._post_validate_expression(
            1,
            {
                "expression": "test_col",
                "expression_type": SqlExpressionType.COLUMN.value,
            },
            datasource_type="invalid_type",
        )
        assert rv.status_code == 400  # Returns 400 for invalid datasource type