feat(jinja): improve url parameter formatting (#16711)
* feat(jinja): improve url parameter formatting
* add UPDATING.md
* fix test
diff --git a/UPDATING.md b/UPDATING.md
index 711ccad..b319c19 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -25,6 +25,9 @@
## Next
### Breaking Changes
+
+- [16711](https://github.com/apache/incubator-superset/pull/16711): The `url_param` Jinja function will now by default escape the result. For instance, the value `O'Brien` will now be changed to `O''Brien`. To disable this behavior, call `url_param` with `escape_result` set to `False`: `url_param("my_key", "my default", escape_result=False)`.
+
### Potential Downtime
### Deprecations
### Other
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index ffcf497..e6a4cab 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -34,6 +34,8 @@
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
+from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.types import String
from typing_extensions import TypedDict
from superset.exceptions import SupersetTemplateException
@@ -95,9 +97,11 @@
self,
extra_cache_keys: Optional[List[Any]] = None,
removed_filters: Optional[List[str]] = None,
+ dialect: Optional[Dialect] = None,
):
self.extra_cache_keys = extra_cache_keys
self.removed_filters = removed_filters if removed_filters is not None else []
+ self.dialect = dialect
def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
"""
@@ -145,7 +149,11 @@
return key
def url_param(
- self, param: str, default: Optional[str] = None, add_to_cache_keys: bool = True
+ self,
+ param: str,
+ default: Optional[str] = None,
+ add_to_cache_keys: bool = True,
+ escape_result: bool = True,
) -> Optional[str]:
"""
Read a url or post parameter and use it in your SQL Lab query.
@@ -166,6 +174,7 @@
:param param: the parameter to lookup
:param default: the value to return in the absence of the parameter
:param add_to_cache_keys: Whether the value should be included in the cache key
+ :param escape_result: Should special characters in the result be escaped
:returns: The URL parameters
"""
@@ -178,6 +187,11 @@
form_data, _ = get_form_data()
url_params = form_data.get("url_params") or {}
result = url_params.get(param, default)
+ if result and escape_result and self.dialect:
+ # use the dialect specific quoting logic to escape string
+ result = String().literal_processor(dialect=self.dialect)(value=result)[
+ 1:-1
+ ]
if add_to_cache_keys:
self.cache_key_wrapper(result)
return result
@@ -430,7 +444,11 @@
class JinjaTemplateProcessor(BaseTemplateProcessor):
def set_context(self, **kwargs: Any) -> None:
super().set_context(**kwargs)
- extra_cache = ExtraCache(self._extra_cache_keys, self._removed_filters)
+ extra_cache = ExtraCache(
+ extra_cache_keys=self._extra_cache_keys,
+ removed_filters=self._removed_filters,
+ dialect=self._database.get_dialect(),
+ )
self._context.update(
{
"url_param": partial(safe_proxy, extra_cache.url_param),
diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py
index 7e4ebfd..e808bad 100644
--- a/tests/integration_tests/base_tests.py
+++ b/tests/integration_tests/base_tests.py
@@ -28,9 +28,11 @@
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
+from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
+from sqlalchemy.dialects.mysql import dialect
from tests.integration_tests.test_app import app
from superset.sql_parse import CtasMethod
@@ -422,7 +424,7 @@
self.login(username="admin")
database_name = "db_for_macros_testing"
db_id = 200
- return self.get_or_create(
+ database = self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
@@ -430,7 +432,14 @@
id=db_id,
)
- def delete_fake_db_for_macros(self):
+ def mock_get_dialect() -> Dialect:
+ return dialect()
+
+ database.get_dialect = mock_get_dialect
+ return database
+
+ @staticmethod
+ def delete_fake_db_for_macros():
database = (
db.session.query(Database)
.filter(Database.database_name == "db_for_macros_testing")
diff --git a/tests/integration_tests/jinja_context_tests.py b/tests/integration_tests/jinja_context_tests.py
index a990968..b82adfa 100644
--- a/tests/integration_tests/jinja_context_tests.py
+++ b/tests/integration_tests/jinja_context_tests.py
@@ -20,6 +20,7 @@
from unittest import mock
import pytest
+from sqlalchemy.dialects.postgresql import dialect
import tests.integration_tests.test_app
from superset import app
@@ -199,6 +200,36 @@
cache = ExtraCache()
self.assertEqual(cache.url_param("foo"), "bar")
+ def test_url_param_escaped_form_data(self) -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ self.assertEqual(cache.url_param("foo"), "O''Brien")
+
+ def test_url_param_escaped_default_form_data(self) -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ self.assertEqual(cache.url_param("bar", "O'Malley"), "O''Malley")
+
+ def test_url_param_unescaped_form_data(self) -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ self.assertEqual(cache.url_param("foo", escape_result=False), "O'Brien")
+
+ def test_url_param_unescaped_default_form_data(self) -> None:
+ with app.test_request_context(
+ query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+ ):
+ cache = ExtraCache(dialect=dialect())
+ self.assertEqual(
+ cache.url_param("bar", "O'Malley", escape_result=False), "O'Malley"
+ )
+
def test_safe_proxy_primitive(self) -> None:
def func(input: Any) -> Any:
return input