Add in secure topics feature
diff --git a/pypubsub.py b/pypubsub.py
index 9f63562..a0c562d 100644
--- a/pypubsub.py
+++ b/pypubsub.py
@@ -67,6 +67,7 @@
backlog: BacklogConfig
payloaders: typing.List[netaddr.ip.IPNetwork]
oldschoolers: typing.List[str]
+ secure_topics: typing.Optional[typing.List[str]]
def __init__(self, yml: dict):
@@ -108,6 +109,9 @@
# Binary backwards compatibility
self.oldschoolers = yml['clients'].get('oldschoolers', [])
+ # Secure topics, if any
+ self.secure_topics = set(yml['clients'].get('secure_topics', []) or [])
+
class Server:
"""Main server class, responsible for handling requests and publishing events """
@@ -158,6 +162,12 @@
'X-Requests': str(self.server.requests_count),
}
+ subscriber = Subscriber(self, request)
+ # Is there a basic auth in this request? If so, set up ACL
+ auth = request.headers.get('Authorization')
+ if auth:
+ await subscriber.parse_acl(auth)
+
# Are we handling a publisher payload request? (PUT/POST)
if request.method in ['PUT', 'POST']:
ip = netaddr.IPAddress(request.remote)
@@ -166,6 +176,16 @@
if ip in network:
allowed = True
break
+ # Check for secure topics
+ payload_topics = set(request.path.split("/"))
+ if any(x in self.config.secure_topics for x in payload_topics):
+ allowed = False
+ # Figure out which secure topics we need permission for:
+ which_secure = [x for x in self.config.secure_topics if x in payload_topics]
+ # Is the user allowed to post to all of these secure topics?
+ if subscriber.secure_topics and all(x in subscriber.secure_topics for x in which_secure):
+ allowed = True
+
if not allowed:
resp = aiohttp.web.Response(headers=headers, status=403, text=PUBSUB_NOT_ALLOWED)
return resp
@@ -200,14 +220,9 @@
# We do not support HTTP 1.0 here...
if request.version.major == 1 and request.version.minor == 0:
return resp
- subscriber = Subscriber(self, resp, request)
-
- # Is there a basic auth in this request? If so, set up ACL
- auth = request.headers.get('Authorization')
- if auth:
- subscriber.acl = await subscriber.parse_acl(auth)
# Subscribe the user before we deal with the potential backlog request and pings
+ subscriber.connection = resp
self.subscribers.append(subscriber)
resp.content_type = PUBSUB_CONTENT_TYPE
try:
@@ -322,11 +337,12 @@
acl: dict
topics: typing.List[typing.List[str]]
- def __init__(self, server: Server, connection: aiohttp.web.StreamResponse, request: aiohttp.web.BaseRequest):
- self.connection = connection
+ def __init__(self, server: Server, request: aiohttp.web.BaseRequest):
+ self.connection: typing.Optional[aiohttp.web.StreamResponse] = None
self.acl = {}
self.server = server
self.lock = asyncio.Lock()
+ self.secure_topics = []
# Set topics subscribed to
self.topics = []
@@ -357,7 +373,8 @@
for k, v in acl.items():
assert isinstance(v, list), f"ACL segment {k} for user {u} is not a list of topics!"
print(f"Client {u} successfully authenticated (and ACL is valid).")
- return acl
+ self.acl = acl
+ self.secure_topics = set(self.server.acl[u].get('topics', []) or [])
elif self.server.config.ldap:
acl = {}
groups = await self.server.config.ldap.get_groups(u,p)
@@ -370,7 +387,8 @@
assert isinstance(topics,
list), f"ACL segment {segment} for user {u} is not a list of topics!"
acl[segment] = topics
- return acl
+ self.acl = acl
+
except binascii.Error as e:
pass # Bad Basic Auth params, bail quietly
except AssertionError as e:
@@ -378,7 +396,7 @@
print(f"ACL configuration error: ACL scheme for {u} contains errors, setting ACL to nothing.")
except Exception as e:
print(f"Basic unknown exception occurred: {e}")
- return {}
+
async def ping(self):
"""Generic ping-back to the client"""