Adds a feature where applications can define which URLs are safe to redirect to.
Thanks to Christian Iacullo for reporting it.
Fixes #24
diff --git a/ChangeLog b/ChangeLog
index f83b147..06c706e 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -10,6 +10,8 @@
- Added support for optional sreg keys. [Shay Erlichmen, Patrick Uiterwijk]
+- Added the option to declare which URL roots are safe to redirect to [Patrick Uiterwijk]
+
Fix
~~~
diff --git a/example/example.py b/example/example.py
index b155998..1ec7bb6 100644
--- a/example/example.py
+++ b/example/example.py
@@ -27,7 +27,7 @@
)
# setup flask-openid
-oid = OpenID(app)
+oid = OpenID(app, safe_roots=[])
# setup sqlalchemy
engine = create_engine(app.config['DATABASE_URI'])
diff --git a/flask_openid.py b/flask_openid.py
index 15501d4..e54fbff 100644
--- a/flask_openid.py
+++ b/flask_openid.py
@@ -326,7 +326,7 @@
"""
def __init__(self, app=None, fs_store_path=None, store_factory=None,
- fallback_endpoint=None, extension_responses=[]):
+ fallback_endpoint=None, extension_responses=[], safe_roots=None):
# backwards compatibility support
if isinstance(app, basestring):
from warnings import warn
@@ -350,6 +350,10 @@
self.after_login_func = None
self.fallback_endpoint = fallback_endpoint
self.extension_responses = extension_responses
+ if isinstance(safe_roots, basestring):
+ self.safe_roots = [safe_roots]
+ else:
+ self.safe_roots = safe_roots
def init_app(self, app):
"""This callback can be used to initialize an application for the
@@ -394,12 +398,26 @@
always return a valid URL.
"""
return (
- request.values.get('next') or
- request.referrer or
- (self.fallback_endpoint and url_for(self.fallback_endpoint)) or
+ self.check_safe_root(request.values.get('next')) or
+ self.check_safe_root(request.referrer) or
+ (self.fallback_endpoint and self.check_safe_root(url_for(self.fallback_endpoint))) or
request.url_root
)
+ def check_safe_root(self, url):
+ if url is None:
+ return None
+ if self.safe_roots is None:
+ return url
+ if url.startswith(request.url_root) or url.startswith('/'):
+ # A URL inside the same app is deemed to always be safe
+ return url
+ for safe_root in self.safe_roots:
+ if url.startswith(safe_root):
+ return url
+ print 'Unsafe url: %s' % url
+ return None
+
def get_current_url(self):
"""the current URL + next."""
return request.base_url + '?next=' + url_quote(self.get_next_url())