Webserver: Sanitize values passed to origin param (#10334)
(cherry-picked from 5c2bb7b0b0e717b11f093910b443243330ad93ca)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index b496e72..6087356 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -54,7 +54,7 @@
from pygments import highlight, lexers
import six
from pygments.formatters.html import HtmlFormatter
-from six.moves.urllib.parse import quote, unquote
+from six.moves.urllib.parse import quote, unquote, urlparse
from sqlalchemy import or_, desc, and_, union_all
from wtforms import (
@@ -328,6 +328,23 @@
return 600 + len(dag.tasks) * 10
+def get_safe_url(url):
+    """Return ``url`` if it is relative or targets ``request.host``, else ``"/admin/"``."""
+    try:
+        # '' admits relative URLs, which carry neither a scheme nor a netloc.
+        valid_schemes = ['http', 'https', '']
+        valid_netlocs = [request.host, '']
+        parsed = urlparse(url)
+        # Browsers read '\' as '/', so "/\host" parses as a path here yet can leave the site.
+        if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs and '\\' not in url:
+            return url
+    except Exception as e:  # pylint: disable=broad-except
+        log.debug("Error validating value in origin parameter passed to URL: %s", url)
+        log.debug("Error: %s", e)
+
+    return "/admin/"
+
+
def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
dttm = request.args.get('execution_date')
if dttm:
@@ -1108,7 +1125,7 @@
def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
@@ -1179,7 +1196,7 @@
from airflow.exceptions import DagNotFound, DagFileExists
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or "/admin/"
+ origin = get_safe_url(request.values.get('origin'))
try:
delete_dag.delete_dag(dag_id)
@@ -1203,7 +1220,7 @@
@provide_session
def trigger(self, session=None):
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or "/admin/"
+ origin = get_safe_url(request.values.get('origin'))
if request.method == 'GET':
return self.render(
@@ -1304,7 +1321,7 @@
def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
execution_date = request.form.get('execution_date')
@@ -1334,7 +1351,7 @@
@wwwutils.notify_owner
def dagrun_clear(self):
dag_id = request.form.get('dag_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1437,7 +1454,7 @@
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_failed(dag_id, execution_date,
confirmed, origin)
@@ -1449,7 +1466,7 @@
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_success(dag_id, execution_date,
confirmed, origin)
@@ -1502,7 +1519,7 @@
def failed(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1522,7 +1539,7 @@
def success(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py
index f098b25..9d46d03 100644
--- a/airflow/www_rbac/views.py
+++ b/airflow/www_rbac/views.py
@@ -31,7 +31,7 @@
from urllib.parse import unquote
import six
-from six.moves.urllib.parse import quote
+from six.moves.urllib.parse import quote, urlparse
import pendulum
import sqlalchemy as sqla
@@ -89,6 +89,23 @@
dagbag = models.DagBag(os.devnull, include_examples=False)
+def get_safe_url(url):
+    """Return ``url`` if it is relative or targets ``request.host``, else the index URL."""
+    try:
+        # '' admits relative URLs, which carry neither a scheme nor a netloc.
+        valid_schemes = ['http', 'https', '']
+        valid_netlocs = [request.host, '']
+        parsed = urlparse(url)
+        # Browsers read '\' as '/', so "/\host" parses as a path here yet can leave the site.
+        if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs and '\\' not in url:
+            return url
+    except Exception as e:  # pylint: disable=broad-except
+        logging.debug("Error validating value in origin parameter passed to URL: %s", url)
+        logging.debug("Error: %s", e)
+
+    return url_for('Airflow.index')
+
+
def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
dttm = request.args.get('execution_date')
if dttm:
@@ -930,7 +947,7 @@
def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
@@ -1000,7 +1017,7 @@
from airflow.exceptions import DagNotFound, DagFileExists
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or url_for('Airflow.index')
+ origin = get_safe_url(request.values.get('origin'))
try:
delete_dag.delete_dag(dag_id)
@@ -1027,7 +1044,7 @@
def trigger(self, session=None):
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or url_for('Airflow.index')
+ origin = get_safe_url(request.values.get('origin'))
if request.method == 'GET':
return self.render_template(
@@ -1128,7 +1145,7 @@
def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
execution_date = request.form.get('execution_date')
@@ -1158,7 +1175,7 @@
@action_logging
def dagrun_clear(self):
dag_id = request.form.get('dag_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1280,7 +1297,7 @@
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_failed(dag_id, execution_date,
confirmed, origin)
@@ -1292,7 +1309,7 @@
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_success(dag_id, execution_date,
confirmed, origin)
@@ -1345,7 +1362,7 @@
def failed(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1365,7 +1382,7 @@
def success(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index ac71ebb..438830c 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -37,6 +37,7 @@
from airflow.operators.bash_operator import BashOperator
from airflow.utils import timezone
from airflow.utils.db import create_session
+from parameterized import parameterized
from tests.compat import mock
from six.moves.urllib.parse import quote_plus
@@ -1115,6 +1116,28 @@
'Triggered example_bash_operator, it should start any moment now.',
response.data.decode('utf-8'))
+    @parameterized.expand([
+        ("javascript:alert(1)", "/admin/"),
+        ("http://google.com", "/admin/"),
+        (
+            "%2Fadmin%2Fairflow%2Ftree%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
+            "/admin/airflow/tree?dag_id=example_bash_operator"
+        ),
+        (
+            "%2Fadmin%2Fairflow%2Fgraph%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
+            "/admin/airflow/graph?dag_id=example_bash_operator"
+        ),
+        ("", ""),
+    ])
+    def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
+        """GET the trigger form with ``test_origin``; its button must target ``expected_origin``."""
+        trigger_url = '/admin/airflow/trigger?dag_id={}&origin={}'.format(
+            "example_bash_operator", test_origin)
+        expected_button = '<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
+            expected_origin)
+        page = self.app.get(trigger_url).data.decode('utf-8')
+        self.assertIn(expected_button, page)
+
class HelpersTest(unittest.TestCase):
@classmethod
diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py
index 33a8338..4e06b57 100644
--- a/tests/www_rbac/test_views.py
+++ b/tests/www_rbac/test_views.py
@@ -2244,6 +2244,22 @@
self.check_content_in_response(
'Triggered example_bash_operator, it should start any moment now.', response)
+    @parameterized.expand([
+        ("javascript:alert(1)", "/home"),
+        ("http://google.com", "/home"),
+        ("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"),
+        ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"),
+        ("", ""),
+    ])
+    def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
+        """GET the trigger form with ``test_origin``; its button must target ``expected_origin``."""
+        trigger_url = 'trigger?dag_id={}&origin={}'.format("example_bash_operator", test_origin)
+        expected_button = '<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
+            expected_origin)
+
+        resp = self.client.get(trigger_url)
+        self.check_content_in_response(expected_button, resp)
+
@mock.patch('airflow.www_rbac.views.dagbag.get_dag')
def test_trigger_endpoint_uses_existing_dagbag(self, mock_get_dag):
"""