| #!/usr/bin/env python3 |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """PyPubSub - a simple publisher/subscriber service written in Python 3""" |
| import asyncio |
| import aiohttp.web |
| import aiofile |
| import os |
| import time |
| import json |
| import yaml |
| import netaddr |
| import binascii |
| import base64 |
| import argparse |
| import collections |
| import plugins.ldap |
| import plugins.sqs |
| import typing |
| |
# Version, protocol and default-setting constants
PUBSUB_VERSION = '0.7.2'
PUBSUB_CONTENT_TYPE = 'application/vnd.pypubsub-stream'  # Content-Type of the subscriber stream
PUBSUB_DEFAULT_IP = '0.0.0.0'
PUBSUB_DEFAULT_PORT = 2069
PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE = 102400  # Compared against the request's Content-Length
PUBSUB_DEFAULT_BACKLOG_SIZE = 0  # The backlog is only filled when the configured size is > 0
PUBSUB_DEFAULT_BACKLOG_AGE = 0  # The age clamp is only applied when the configured age is > 0
PUBSUB_WRITE_TIMEOUT = 0.35  # If we can't deliver to a pipe within N seconds, drop it.

# Canned response bodies
PUBSUB_BAD_REQUEST = (
    "I could not understand your request, sorry! Please see https://pubsub.apache.org/api.html "
    "for usage documentation.\n"
)
PUBSUB_PAYLOAD_RECEIVED = "Payload received, thank you very much!\n"
PUBSUB_NOT_ALLOWED = "You are not authorized to deliver payloads!\n"
PUBSUB_BAD_PAYLOAD = "Bad payload type. Payloads must be JSON dictionary objects, {..}!\n"
PUBSUB_PAYLOAD_TOO_LARGE = "Payload is too large for me to serve, please make it shorter.\n"
| |
| |
class ServerConfig(typing.NamedTuple):
    """Immutable record of the main HTTP listener settings."""
    # Address the listener binds to
    ip: str
    # TCP port the listener binds to
    port: int
    # Largest Content-Length accepted for a published payload
    payload_limit: int
| |
| |
class BacklogConfig(typing.NamedTuple):
    """Immutable record of the backlog (event replay) settings."""
    # Oldest event age, in seconds, that a subscriber may fetch; 0 means no clamp
    max_age: int
    # Number of events kept in the in-memory backlog; 0 means nothing is kept
    queue_size: int
    # Path of the file the backlog is persisted to, or None for no persistence
    storage: typing.Optional[str]
| |
| |
class Configuration:
    """Typed view of the settings parsed from the YAML configuration file.

    Attributes set by the constructor:
      ldap         - plugins.ldap.LDAPConnection built from clients.ldap, or None
      sqs          - the raw top-level 'sqs' section, or None
      server       - bind address, port and payload size limit
      backlog      - backlog max age (seconds), queue size and storage path
      payloaders   - networks whose clients may post payloads
      oldschoolers - User-Agent fragments of backwards-compatible clients
    """
    server: ServerConfig
    backlog: BacklogConfig
    payloaders: typing.List[netaddr.ip.IPNetwork]
    oldschoolers: typing.List[str]

    # Seconds per unit suffix accepted in server.backlog.max_age
    _AGE_UNITS = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}

    def __init__(self, yml: dict):
        """Build the configuration from the parsed YAML document.

        Missing or empty 'server', 'server.backlog' and 'clients' sections
        fall back to the built-in defaults.

        Raises:
            ValueError: if a numeric setting (port, max_payload_size,
                backlog size or backlog max_age) cannot be parsed.
        """
        yml = yml or {}
        # A section that is present but empty parses as None, hence `or {}`.
        server_yml = yml.get('server') or {}
        backlog_yml = server_yml.get('backlog') or {}
        clients_yml = yml.get('clients') or {}

        # LDAP Settings
        self.ldap = None
        lyml = clients_yml.get('ldap')
        if isinstance(lyml, dict):
            self.ldap = plugins.ldap.LDAPConnection(lyml)

        # SQS?
        self.sqs = yml.get('sqs')

        # Main server config
        self.server = ServerConfig(
            ip=server_yml.get('bind', PUBSUB_DEFAULT_IP),
            port=int(server_yml.get('port', PUBSUB_DEFAULT_PORT)),
            payload_limit=int(server_yml.get('max_payload_size', PUBSUB_DEFAULT_MAX_PAYLOAD_SIZE)),
        )

        # Backlog settings
        queue_size = backlog_yml.get('size')
        if queue_size is None:
            queue_size = PUBSUB_DEFAULT_BACKLOG_SIZE
        self.backlog = BacklogConfig(
            max_age=self._parse_max_age(backlog_yml.get('max_age')),
            queue_size=int(queue_size),  # Later compared against ints, so a quoted number must be converted
            storage=backlog_yml.get('storage'),
        )

        # Payloaders - clients that can post payloads
        self.payloaders = [netaddr.IPNetwork(x) for x in clients_yml.get('payloaders') or []]

        # Binary backwards compatibility
        self.oldschoolers = clients_yml.get('oldschoolers') or []

    @classmethod
    def _parse_max_age(cls, value):
        """Convert a max_age setting into seconds.

        Accepts a number (returned unchanged), None (the default age), or a
        string holding an integer optionally followed by s, m, h or d.

        Raises:
            ValueError: if a string value is not a valid duration.
        """
        if value is None:
            return PUBSUB_DEFAULT_BACKLOG_AGE
        if not isinstance(value, str):
            return value
        text = value.strip().lower()
        multiplier = cls._AGE_UNITS.get(text[-1:])
        # Without a known suffix the whole string must be a plain number of
        # seconds; leaving it as a str would break every later `max_age > 0`.
        digits = text if multiplier is None else text[:-1]
        try:
            return int(digits) * (multiplier or 1)
        except ValueError:
            raise ValueError(
                f"Invalid backlog max_age {value!r}: expected a number of seconds, "
                f"optionally followed by s, m, h or d"
            ) from None
| |
| |
class Server:
    """Main server class, responsible for handling requests and publishing events.

    Attributes:
      yaml           - the raw parsed configuration document
      config         - the Configuration built from `yaml`
      subscribers    - currently connected Subscriber objects
      pending_events - queue of Payload objects waiting to be published
      backlog        - most recent payloads, capped at config.backlog.queue_size
      last_ping      - time.time() at construction
      server         - the aiohttp low-level server, set by server_loop()
      acl            - user ACL document loaded from the ACL file ({} if absent)
    """
    yaml: dict
    config: Configuration
    subscribers: list
    pending_events: asyncio.Queue
    backlog: list
    last_ping: float
    server: aiohttp.web.Server
    acl: dict

    def __init__(self, args: argparse.Namespace):
        """Load the configuration and ACL files named by args.config and args.acl.

        A missing ACL file is not an error: a notice is printed and the ACL
        stays empty.
        """
        with open(args.config) as config_file:
            self.yaml = yaml.safe_load(config_file)
        self.config = Configuration(self.yaml)
        self.subscribers = []
        self.pending_events = asyncio.Queue()
        self.backlog = []
        self.last_ping = time.time()
        self.acl = {}
        try:
            with open(args.acl) as acl_file:
                # An empty file parses as None; keep the ACL a dict regardless.
                self.acl = yaml.safe_load(acl_file) or {}
        except FileNotFoundError:
            print(f"ACL configuration file {args.acl} not found, private events will not be broadcast.")

    async def poll(self):
        """Polls for new stuff to publish, and if found, publishes to whomever wants it."""
        while True:
            payload: Payload = await self.pending_events.get()
            bad_subs: list = await payload.publish(self.subscribers)
            self.pending_events.task_done()

            # Cull subscribers we couldn't deliver payload to.
            for bad_sub in bad_subs:
                print("Culling %r due to connection errors" % bad_sub)
                try:
                    self.subscribers.remove(bad_sub)
                except ValueError:  # Already removed elsewhere
                    pass

    async def handle_request(self, request: aiohttp.web.BaseRequest):
        """Generic handler for all incoming HTTP requests.

        PUT/POST publish a payload, GET opens a subscriber stream, HEAD
        answers 204, and any other method gets a 400 response.
        """
        # Define response headers first...
        headers = {
            'Server': 'PyPubSub/%s' % PUBSUB_VERSION,
            'X-Subscribers': str(len(self.subscribers)),
            'X-Requests': str(self.server.requests_count),
        }

        if request.method in ('PUT', 'POST'):
            return await self._handle_publish(request, headers)
        if request.method == 'GET':
            return await self._handle_subscribe(request, headers)
        if request.method == 'HEAD':
            return aiohttp.web.Response(headers=headers, status=204, text="")
        # I don't know this type of request :/ (DELETE, PATCH, etc)
        return aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)

    async def _handle_publish(self, request: aiohttp.web.BaseRequest, headers: dict) -> aiohttp.web.Response:
        """Handle a publisher request (PUT/POST): validate, queue and backlog the payload."""
        ip = netaddr.IPAddress(request.remote)
        if not any(ip in network for network in self.config.payloaders):
            return aiohttp.web.Response(headers=headers, status=403, text=PUBSUB_NOT_ALLOWED)

        # A publish request without a body used to fall through and return
        # None to aiohttp; answer it as a bad request instead.
        if not request.can_read_body:
            return aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)

        limit = self.config.server.payload_limit
        if request.content_length and request.content_length > limit:
            return aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_PAYLOAD_TOO_LARGE)

        try:
            # Bodies sent without a Content-Length header bypass the check
            # above, so measure what was actually received. read() caches the
            # bytes, so the text() call below does not read the stream again.
            raw_body = await request.read()
            if len(raw_body) > limit:
                return aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_PAYLOAD_TOO_LARGE)
            as_json = json.loads(await request.text())
        except (json.decoder.JSONDecodeError, UnicodeDecodeError):
            return aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_REQUEST)

        # Payload MUST be a dictionary object, {...}. Checked explicitly
        # because `assert` statements are removed under `python -O`.
        if not isinstance(as_json, dict):
            return aiohttp.web.Response(headers=headers, status=400, text=PUBSUB_BAD_PAYLOAD)

        pl = Payload(request.path, as_json)
        self.pending_events.put_nowait(pl)

        # Add to backlog?
        if self.config.backlog.queue_size > 0:
            self.backlog.append(pl)
            # If backlog has grown too large, delete the oldest item(s) in it.
            excess = len(self.backlog) - self.config.backlog.queue_size
            if excess > 0:
                del self.backlog[:excess]

        return aiohttp.web.Response(headers=headers, status=202, text=PUBSUB_PAYLOAD_RECEIVED)

    async def _handle_subscribe(self, request: aiohttp.web.BaseRequest, headers: dict) -> aiohttp.web.StreamResponse:
        """Handle a subscriber request (GET): stream backlog, events and pings until the client goes away."""
        resp = aiohttp.web.StreamResponse(headers=headers)
        # We do not support HTTP 1.0 here...
        if request.version.major == 1 and request.version.minor == 0:
            return resp
        subscriber = Subscriber(self, resp, request)

        # Is there a basic auth in this request? If so, set up ACL
        auth = request.headers.get('Authorization')
        if auth:
            subscriber.acl = await subscriber.parse_acl(auth)

        # Subscribe the user before we deal with the potential backlog request and pings
        self.subscribers.append(subscriber)
        resp.content_type = PUBSUB_CONTENT_TYPE
        try:
            resp.enable_chunked_encoding()
            await resp.prepare(request)

            # Is the client requesting a backlog of items?
            backlog = request.headers.get('X-Fetch-Since')
            if backlog:
                try:
                    backlog_ts = int(backlog)
                except ValueError:  # Default to 0 if we can't parse the epoch
                    backlog_ts = 0
                # If max age is specified, force the TS to minimum that age
                if self.config.backlog.max_age > 0:
                    backlog_ts = max(backlog_ts, int(time.time() - self.config.backlog.max_age))
                # For each item, publish to client if new enough. Iterate a
                # snapshot: publishers may append to/trim the backlog while we
                # are suspended in the awaits below.
                for item in list(self.backlog):
                    if item.timestamp >= backlog_ts:
                        await item.publish([subscriber])

            while True:
                await subscriber.ping()
                if subscriber not in self.subscribers:  # If we got dislodged somehow, end session
                    break
                await asyncio.sleep(5)
        # We may get exception types we don't have imported, so any error ends
        # the session. Cancellation is deliberately NOT swallowed here; the
        # `finally` clause still unregisters the subscriber in that case.
        except Exception:
            pass
        finally:
            if subscriber in self.subscribers:
                self.subscribers.remove(subscriber)
        return resp

    async def write_backlog_storage(self):
        """Every 10 seconds, persist the backlog to the storage file if it changed.

        The file holds one JSON object per line with the keys 'timestamp',
        'topics', 'json' and 'private'. Write errors are printed and retried
        on the next round.
        """
        previous_backlog = []
        while True:
            if self.config.backlog.storage:
                try:
                    backlog_list = self.backlog.copy()
                    if backlog_list != previous_backlog:
                        async with aiofile.AIOFile(self.config.backlog.storage, 'w+') as afp:
                            offset = 0
                            for item in backlog_list:
                                js = json.dumps({
                                    'timestamp': item.timestamp,
                                    'topics': item.topics,
                                    'json': item.json,
                                    'private': item.private
                                }) + '\n'
                                await afp.write(js, offset=offset)
                                # json.dumps escapes non-ASCII by default, so
                                # the character count equals the byte count.
                                offset += len(js)
                            await afp.fsync()
                        # Only remember the snapshot once it is safely on
                        # disk, so a failed write is retried next round.
                        previous_backlog = backlog_list
                except Exception as e:
                    print(f"Could not write to backlog file {self.config.backlog.storage}: {e}")
            await asyncio.sleep(10)

    def read_backlog_storage(self):
        """Load a previously persisted backlog from the storage file, if one exists.

        At most config.backlog.queue_size of the newest entries are kept.
        Read/parse errors are printed and stop the import at that point.
        """
        storage = self.config.backlog.storage
        if not (storage and os.path.exists(storage)):
            return
        readlines = 0
        try:
            with open(storage, 'r') as fp:
                for line in fp:
                    if not line.strip():  # Tolerate blank lines in the file
                        continue
                    js = json.loads(line)
                    readlines += 1
                    ppath = "/".join(js['topics'])
                    if js['private']:
                        ppath = '/private/' + ppath
                    payload = Payload(ppath, js['json'], js['timestamp'])
                    self.backlog.append(payload)
                    if self.config.backlog.queue_size < len(self.backlog):
                        self.backlog.pop(0)
        except Exception as e:
            print(f"Error while reading backlog: {e}")

        print(f"Read {readlines} objects from {self.config.backlog.storage}, applied {len(self.backlog)} to backlog.")

    async def server_loop(self, loop: asyncio.BaseEventLoop):
        """Start the HTTP listener and background tasks, then publish events forever."""
        self.server = aiohttp.web.Server(self.handle_request)
        runner = aiohttp.web.ServerRunner(self.server)
        await runner.setup()
        site = aiohttp.web.TCPSite(runner, self.config.server.ip, self.config.server.port)
        await site.start()
        print("==== PyPubSub v/%s starting... ====" % PUBSUB_VERSION)
        print("==== Serving up PubSub goodness at %s:%s ====" % (
            self.config.server.ip, self.config.server.port))

        # The event loop only holds weak references to tasks; keep strong
        # ones here so the background tasks cannot be garbage collected.
        background_tasks = []
        if self.config.sqs:
            for sqs_config in self.config.sqs.values():
                background_tasks.append(loop.create_task(plugins.sqs.get_payloads(self, sqs_config)))
        self.read_backlog_storage()
        background_tasks.append(loop.create_task(self.write_backlog_storage()))
        await self.poll()

    def run(self):
        """Run the server on the current event loop until interrupted (Ctrl+C)."""
        loop = asyncio.get_event_loop()
        try:
            loop.run_until_complete(self.server_loop(loop))
        except KeyboardInterrupt:
            pass
        finally:
            loop.close()
| |
| |
class Subscriber:
    """Basic subscriber (client) class. Holds information about the connection and ACL.

    Attributes:
      connection - the stream response events are written to
      acl        - mapping of ACL segment name to a list of private topics
      server     - the owning Server
      lock       - serialises writes to `connection`
      topics     - one topic list per comma-separated batch in the request path
      old_school - True if the client expects zero-terminated messages
    """
    acl: dict
    topics: typing.List[typing.List[str]]

    def __init__(self, server: Server, connection: aiohttp.web.StreamResponse, request: aiohttp.web.BaseRequest):
        self.connection = connection
        self.acl = {}
        self.server = server
        self.lock = asyncio.Lock()

        # Set topics subscribed to: "/a/b,/c" becomes [['a', 'b'], ['c']]
        self.topics = [
            [x for x in topic_batch.split('/') if x]
            for topic_batch in request.path.split(',')
        ]

        # Is the client old and expecting zero-terminators?
        user_agent = request.headers.get('User-Agent', '')
        self.old_school = any(ua in user_agent for ua in self.server.config.oldschoolers)

    @staticmethod
    def _require(condition: bool, message: str):
        """Raise AssertionError(message) unless condition holds.

        Used instead of `assert` so that ACL vetting still happens when
        Python runs with -O, which strips assert statements.
        """
        if not condition:
            raise AssertionError(message)

    async def parse_acl(self, basic: str):
        """Return the ACL granted by a Basic Auth header value, or {} if none applies.

        Users listed in the server's ACL file are checked against their stored
        password; other users are looked up through LDAP when it is configured.
        Errors never propagate: they are printed (except malformed base64,
        which is ignored quietly) and result in an empty ACL.
        """
        u = None
        try:
            # The auth scheme is case-insensitive; strip it only when it leads
            # the header. A value without a scheme is decoded as-is.
            scheme, sep, rest = basic.strip().partition(' ')
            encoded = rest.strip() if sep and scheme.lower() == 'basic' else basic
            decoded = str(base64.decodebytes(bytes(encoded, 'ascii')), 'utf-8')
            u, p = decoded.split(':', 1)
            if u in self.server.acl:
                acl_pass = self.server.acl[u].get('password')
                if acl_pass and acl_pass == p:
                    acl = self.server.acl[u].get('acl', {})
                    # Vet ACL for user
                    self._require(isinstance(acl, dict),
                                  f"ACL for user {u} "
                                  f"must be a dictionary of sub-IDs and topics, but is not.")
                    # Make sure each ACL segment is a list of topics
                    for k, v in acl.items():
                        self._require(isinstance(v, list),
                                      f"ACL segment {k} for user {u} is not a list of topics!")
                    print(f"Client {u} successfully authenticated (and ACL is valid).")
                    return acl
            elif self.server.config.ldap:
                acl = {}
                groups = await self.server.config.ldap.get_groups(u, p)
                # Merge the ACL segments of every LDAP group the user is in
                for k, v in self.server.config.ldap.acl.items():
                    if k in groups:
                        self._require(isinstance(v, dict),
                                      f"ACL segment {k} for user {u} is not a dictionary of segments!")
                        for segment, topics in v.items():
                            print(f"Enabling ACL segment {segment} for user {u}")
                            self._require(isinstance(topics, list),
                                          f"ACL segment {segment} for user {u} is not a list of topics!")
                            acl[segment] = topics
                return acl
        except binascii.Error:
            pass  # Bad Basic Auth params, bail quietly
        except AssertionError as e:
            print(e)
            print(f"ACL configuration error: ACL scheme for {u} contains errors, setting ACL to nothing.")
        except Exception as e:
            print(f"Basic unknown exception occurred: {e}")
        return {}

    async def ping(self):
        """Generic ping-back to the client.

        Writes a {"stillalive": <now>} JSON line; the write is given
        PUBSUB_WRITE_TIMEOUT seconds, after which asyncio.wait_for raises.
        """
        js = b"%s\n" % json.dumps({"stillalive": time.time()}).encode('utf-8')
        if self.old_school:
            js += b"\0"
        async with self.lock:
            await asyncio.wait_for(self.connection.write(js), timeout=PUBSUB_WRITE_TIMEOUT)
| |
| |
class Payload:
    """A payload (event) object sent by a registered publisher.

    Attributes:
      json      - the event dictionary (the caller's dict, annotated in place)
      timestamp - event time as a UNIX epoch
      topics    - non-empty path segments, without a leading 'private'
      private   - True if the path started with the 'private' segment
    """

    def __init__(self, path: str, data: dict, timestamp: typing.Optional[float] = None):
        """Create a payload for `path` carrying `data`.

        `data` is modified in place: the keys 'pubsub_timestamp',
        'pubsub_topics' and 'pubsub_path' are added. When `timestamp` is None
        the current time is used.
        """
        self.json = data
        # Compare against None so that an explicit timestamp of 0 is kept.
        self.timestamp = time.time() if timestamp is None else timestamp
        self.topics = [x for x in path.split('/') if x]

        # Private payload? If so, remove the private bit from topics now.
        self.private = bool(self.topics) and self.topics[0] == 'private'
        if self.private:
            del self.topics[0]

        self.json['pubsub_timestamp'] = self.timestamp
        self.json['pubsub_topics'] = self.topics
        self.json['pubsub_path'] = path

    async def publish(self, subscribers: typing.List["Subscriber"]):
        """Publishes an object to all subscribers using those topics (or a sub-set thereof).

        Private payloads are only sent to subscribers with an ACL segment
        whose topics are all present in this payload. Returns the list of
        subscribers the payload could not be written to.
        """
        js = b"%s\n" % json.dumps(self.json).encode('utf-8')
        ojs = js + b"\0"  # Zero-terminated variant for old-school clients
        bad_subs = []
        for sub in subscribers:
            # If a private payload, check ACL and bail if not a match
            if self.private and not any(
                all(el in self.topics for el in private_topics)
                for private_topics in sub.acl.values()
            ):
                continue
            # If subscribed to all the topics, tell a subscriber about this
            for topic_batch in sub.topics:
                if all(el in self.topics for el in topic_batch):
                    message = ojs if sub.old_school else js
                    try:
                        async with sub.lock:
                            await asyncio.wait_for(sub.connection.write(message), timeout=PUBSUB_WRITE_TIMEOUT)
                    except Exception:
                        bad_subs.append(sub)
                    # One delivery attempt per subscriber, even if several
                    # of its topic batches match.
                    break
        return bad_subs
| |
| |
if __name__ == '__main__':
    # Read the command line, then serve until interrupted.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config",
        default="pypubsub.yaml",
        help="Configuration file to load (default: pypubsub.yaml)",
    )
    arg_parser.add_argument(
        "--acl",
        default="pypubsub_acl.yaml",
        help="ACL Configuration file to load (default: pypubsub_acl.yaml)",
    )
    cliargs = arg_parser.parse_args()
    pubsub_server = Server(cliargs)
    pubsub_server.run()