fix: url shortener invalid input (#13461)
* fix: url shortner invalid input
* fix lint
(cherry picked from commit c3c73763d0a64b06e18662c95937c00221c6afbd)
diff --git a/superset/views/redirects.py b/superset/views/redirects.py
index 02dc587..1be79b6 100644
--- a/superset/views/redirects.py
+++ b/superset/views/redirects.py
@@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import logging
+from typing import Optional
+
from flask import flash, request, Response
from flask_appbuilder import expose
from flask_appbuilder.security.decorators import has_access_api
@@ -24,11 +27,22 @@
from superset.typing import FlaskResponse
from superset.views.base import BaseSupersetView
+logger = logging.getLogger(__name__)
+
class R(BaseSupersetView): # pylint: disable=invalid-name
"""used for short urls"""
+ @staticmethod
+ def _validate_url(url: Optional[str] = None) -> bool:
+ if url and (
+ url.startswith("//superset/dashboard/")
+ or url.startswith("//superset/explore/")
+ ):
+ return True
+ return False
+
@event_logger.log_this
@expose("/<int:url_id>")
def index(self, url_id: int) -> FlaskResponse: # pylint: disable=no-self-use
@@ -38,8 +52,9 @@
if url.url.startswith(explore_url):
explore_url += f"r={url_id}"
return redirect(explore_url[1:])
-
- return redirect(url.url[1:])
+ if self._validate_url(url.url):
+ return redirect(url.url[1:])
+ return redirect("/")
flash("URL to nowhere...", "danger")
return redirect("/")
@@ -49,6 +64,9 @@
@expose("/shortner/", methods=["POST"])
def shortner(self) -> FlaskResponse: # pylint: disable=no-self-use
url = request.form.get("data")
+ if not self._validate_url(url):
+ logger.warning("Invalid URL: %s", url)
+ return Response(f"Invalid URL: {url}", 400)
obj = models.Url(url=url)
db.session.add(obj)
db.session.commit()
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 3bc230a..111964a 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -634,6 +634,28 @@
resp = self.client.post("/r/shortner/", data=dict(data=data))
assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8"))
+ def test_shortner_invalid(self):
+ self.login(username="admin")
+ invalid_urls = [
+ "hhttp://invalid.com",
+ "hhttps://invalid.com",
+ "www.invalid.com",
+ ]
+ for invalid_url in invalid_urls:
+ resp = self.client.post("/r/shortner/", data=dict(data=invalid_url))
+ assert resp.status_code == 400
+
+ def test_redirect_invalid(self):
+ model_url = models.Url(url="hhttp://invalid.com")
+ db.session.add(model_url)
+ db.session.commit()
+
+ self.login(username="admin")
+ response = self.client.get(f"/r/{model_url.id}")
+ assert response.headers["Location"] == "http://localhost/"
+ db.session.delete(model_url)
+ db.session.commit()
+
@skipUnless(
(is_feature_enabled("KV_STORE")), "skipping as /kv/ endpoints are not enabled"
)