| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Unit tests for SqlaTable.validate_expression""" |
| |
| from unittest.mock import MagicMock, patch |
| |
| from superset.connectors.sqla.models import SqlaTable |
| from superset.utils.core import SqlExpressionType |
| |
| |
class TestValidateExpression:
    """Unit tests for ``SqlaTable.validate_expression``.

    Every test patches ``SqlaTable._execute_validation_query``, so no SQL is
    run. The tests check how ``validate_expression`` turns the patched call's
    return value, or the exception it raises, into a result dict with
    ``valid`` and ``errors`` keys.
    """

    # Single source of truth for the patch target shared by every test below.
    _EXECUTE_VALIDATION_QUERY = (
        "superset.connectors.sqla.models.SqlaTable._execute_validation_query"
    )

    def setup_method(self):
        """Build a fresh ``SqlaTable`` with its database and helpers mocked."""
        self.table = SqlaTable()
        self.table.table_name = "test_table"
        self.table.schema = "test_schema"
        self.table.catalog = None

        engine_spec = MagicMock()
        # Pass the column through unchanged; the second positional argument
        # is accepted and ignored.
        engine_spec.make_sqla_column_compatible = lambda x, _: x
        self.table.database = MagicMock()
        self.table.database.db_engine_spec = engine_spec
        self.table.columns = []

        # Mock get_from_clause to return a (mock from-clause, None) pair
        self.table.get_from_clause = MagicMock(return_value=(MagicMock(), None))

        # Mock get_sqla_row_level_filters: no RLS filters by default
        self.table.get_sqla_row_level_filters = MagicMock(return_value=[])

        # Mock get_template_processor: no template processor
        self.table.get_template_processor = MagicMock(return_value=None)

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_column_expression(self, mock_execute):
        """A column expression is reported valid when the query succeeds."""
        mock_execute.return_value = {"valid": True, "errors": []}

        result = self.table.validate_expression(
            expression="test_col",
            expression_type=SqlExpressionType.COLUMN,
        )

        assert result["valid"] is True
        assert result["errors"] == []
        mock_execute.assert_called_once()

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_metric_expression(self, mock_execute):
        """A metric expression is reported valid when the query succeeds."""
        mock_execute.return_value = {"valid": True, "errors": []}

        result = self.table.validate_expression(
            expression="SUM(amount)",
            expression_type=SqlExpressionType.METRIC,
        )

        assert result["valid"] is True
        assert result["errors"] == []
        mock_execute.assert_called_once()

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_where_expression(self, mock_execute):
        """A WHERE clause expression is reported valid when the query succeeds."""
        mock_execute.return_value = {"valid": True, "errors": []}

        result = self.table.validate_expression(
            expression="status = 'active'",
            expression_type=SqlExpressionType.WHERE,
        )

        assert result["valid"] is True
        assert result["errors"] == []
        mock_execute.assert_called_once()

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_having_expression(self, mock_execute):
        """A HAVING clause expression is reported valid when the query succeeds."""
        mock_execute.return_value = {"valid": True, "errors": []}

        result = self.table.validate_expression(
            expression="SUM(amount) > 100",
            expression_type=SqlExpressionType.HAVING,
        )

        assert result["valid"] is True
        assert result["errors"] == []
        mock_execute.assert_called_once()

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_invalid_expression(self, mock_execute):
        """An exception from the validation query becomes a single error."""
        mock_execute.side_effect = Exception("Invalid SQL syntax")

        result = self.table.validate_expression(
            expression="INVALID SQL HERE",
            expression_type=SqlExpressionType.COLUMN,
        )

        assert result["valid"] is False
        assert len(result["errors"]) == 1
        # The exception text must be surfaced in the error message.
        assert "Invalid SQL syntax" in result["errors"][0]["message"]

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_having_with_non_aggregated_column(self, mock_execute):
        """A database error for a non-aggregated HAVING column is surfaced."""
        # Simulate the error a database raises for a non-aggregated column in HAVING
        mock_execute.side_effect = Exception(
            "column 'region' must appear in the GROUP BY clause "
            "or be used in an aggregate function"
        )

        result = self.table.validate_expression(
            expression="region = 'US'",
            expression_type=SqlExpressionType.HAVING,
        )

        assert result["valid"] is False
        assert len(result["errors"]) == 1
        assert "must appear in the GROUP BY clause" in result["errors"][0]["message"]

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_empty_expression(self, mock_execute):
        """An empty expression is reported invalid with an 'empty' message."""
        mock_execute.side_effect = Exception("Expression is empty")

        result = self.table.validate_expression(
            expression="",
            expression_type=SqlExpressionType.COLUMN,
        )

        assert result["valid"] is False
        assert len(result["errors"]) == 1
        # Only the word "empty" is checked. An implementation that rejects empty
        # input before querying would also pass, so no call count is asserted.
        assert "empty" in result["errors"][0]["message"].lower()

    @patch(_EXECUTE_VALIDATION_QUERY)
    def test_validate_expression_with_rls(self, mock_execute):
        """Validation still succeeds when row-level security filters exist."""
        mock_execute.return_value = {"valid": True, "errors": []}

        # Return one (mocked) RLS filter instead of the default empty list
        self.table.get_sqla_row_level_filters = MagicMock(return_value=[MagicMock()])

        result = self.table.validate_expression(
            expression="test_col",
            expression_type=SqlExpressionType.COLUMN,
        )

        assert result["valid"] is True
        assert result["errors"] == []
        mock_execute.assert_called_once()
        # NOTE(review): _execute_validation_query is patched, so this test does
        # not verify that the RLS filters end up in the validation query; only
        # that their presence does not break validation.