| #!/usr/bin/env python3 |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Apache Pony Mail, Codename Foal - A Python variant of Pony Mail""" |
| import argparse |
| import asyncio |
| import importlib |
| import json |
| import os |
| import sys |
| from time import sleep |
| import traceback |
| import typing |
| |
| import aiohttp.web |
| import yaml |
| import uuid |
| |
| import plugins.background |
| import plugins.configuration |
| import plugins.database |
| import plugins.formdata |
| import plugins.offloader |
| import plugins.server |
| import plugins.session |
| |
from server_version import PONYMAIL_SERVER_VERSION

# Version of this Foal codebase; reported in the startup banner and in the
# "Server" header of every HTTP response.
PONYMAIL_FOAL_VERSION = "0.1.0"
| |
| |
# Certain environments, such as MinGW-w64, do not register stdout as a TTY and
# therefore use buffered output. In such cases we force a flush after every
# print, or nothing will show up.
# The getattr() guard means a replaced or missing stdout (e.g. None, or an
# object without .buffer/.isatty()) cannot raise AttributeError at import time;
# it is simply treated as "not a TTY".
if not getattr(sys.stdout, "isatty", lambda: False)():
    import functools
    print = functools.partial(print, flush=True)
| |
| |
class Server(plugins.server.BaseServer):
    """Main server class, responsible for handling requests and scheduling offloader threads."""

    def _load_endpoint(self, subdir):
        """Import every ``*.py`` module in *subdir* and register the endpoints it provides.

        Modules are imported as ``<subdir>.<module>`` in sorted filename order.
        When a module exposes ``register()``, it is called with this server and the
        result is stored in ``self.handlers`` under the module name. Modules without
        ``register()`` are reported and skipped.
        """
        for endpoint_file in sorted(os.listdir(subdir)):
            if not endpoint_file.endswith(".py"):
                continue
            endpoint = endpoint_file[:-3]
            m = importlib.import_module(f"{subdir}.{endpoint}")
            if hasattr(m, "register"):
                self.handlers[endpoint] = m.register(self)
                print(f"Registered endpoint /api/{endpoint}")
            else:
                print(
                    f"Could not find entry point 'register()' in {endpoint_file}, skipping!"
                )

    @staticmethod
    def _stream_logger(name: str, level: str):
        """Return the logger called *name*, set to *level*, with a StreamHandler attached."""
        import logging

        logger = logging.getLogger(name)
        logger.setLevel(level)
        logger.addHandler(logging.StreamHandler())
        return logger

    def _release_database(self, session) -> None:
        """Return the session's checked-out database connection (if any) to the pool.

        Safe to call more than once: ``session.database`` is cleared after release.
        """
        if session.database:
            self.dbpool.put_nowait(session.database)
            # NOTE(review): presumably balances the dbpool.get() performed when the
            # session acquired the connection (plugins.session) — confirm.
            self.dbpool.task_done()
            session.database = None

    def __init__(self, args: argparse.Namespace):
        """Load configuration, build the database pool and register all endpoints.

        :param args: parsed command-line options.
        :raises ValueError: if the configured database pool size is below 1.
        """
        print(
            "==== Apache Pony Mail (Foal v/%s ~%s) starting... ====" % (PONYMAIL_FOAL_VERSION, PONYMAIL_SERVER_VERSION)
        )
        # Load configuration; the context manager closes the file once parsed.
        with open(args.config) as config_file:
            yml = yaml.safe_load(config_file)
        self.config = plugins.configuration.Configuration(yml)
        self.data = plugins.configuration.InterData()
        self.handlers = {}
        self.dbpool = asyncio.Queue()
        self.runners = plugins.offloader.ExecutorPool()
        self.server = None
        self.streamlock = asyncio.Lock()
        self.api_logger = None
        self.foal_version = PONYMAIL_FOAL_VERSION
        self.server_version = PONYMAIL_SERVER_VERSION
        self.stoppable = False  # allow remote stop for tests
        self.refreshable = False  # allow remote refresh for tests
        self.background_event = asyncio.Event()  # for background task to wait on

        # Make a pool of database connections for async queries
        pool_size = self.config.database.pool_size
        if pool_size < 1:
            raise ValueError(f"pool_size {pool_size} must be > 0")
        for _ in range(pool_size):
            self.dbpool.put_nowait(plugins.database.Database(self.config.database))

        # Load each URL endpoint; optional testing endpoints are loaded first.
        if args.testendpoints:
            print("** Loading additional testing endpoints **")
            self._load_endpoint("testendpoints")
            print()
        self._load_endpoint("endpoints")

        # Optional loggers, each enabled only when a level was given.
        if args.logger:
            self._stream_logger("elasticsearch", args.logger)
        if args.trace:
            self._stream_logger("elasticsearch.trace", args.trace)
        if args.apilog:
            self.api_logger = self._stream_logger("ponymail.apilog", args.apilog)
        self.stoppable = args.stoppable
        self.refreshable = args.refreshable

    async def handle_request(
        self, request: aiohttp.web.BaseRequest
    ) -> typing.Union[aiohttp.web.Response, aiohttp.web.StreamResponse]:
        """Generic handler for all incoming HTTP requests.

        The endpoint name is the second path segment, so ``/api/<handler>/extra?query``
        is routed to ``self.handlers[<handler>]``. The handler name may carry a suffix:
        - ``.lua`` (legacy Lua URLs) is stripped and the body is parsed as form data.
        - ``.json`` is stripped and the body is parsed as JSON.

        Returns:
        - the endpoint's own response object, if it returned one;
        - a JSON response for any other truthy output;
        - 404 for unknown endpoints or empty output;
        - 400 when request data cannot be parsed;
        - 500 when the endpoint raised.
        """
        # Define response headers first...
        headers = {
            "Server": "Apache Pony Mail (Foal/%s ~%s)" % (PONYMAIL_FOAL_VERSION, PONYMAIL_SERVER_VERSION),
        }

        if self.api_logger:
            self.api_logger.info(request.raw_path)

        # Figure out who is going to handle this request, if any.
        # Support URLs of form /api/handler/extra?query
        parts = request.path.split("/")
        if len(parts) < 3:
            return aiohttp.web.Response(
                headers=headers, status=404, text="API Endpoint not found!"
            )
        handler = parts[2]

        # Hooks for the test suite; only active when enabled on the command line.
        if self.stoppable and handler == 'stop':
            self.background_event.set()
            return aiohttp.web.Response(headers=headers, status=200, text='Stop requested\n')
        if self.refreshable and handler == 'refresh':
            await plugins.background.get_data(self)
            return aiohttp.web.Response(headers=headers, status=200, text='Refresh performed\n')

        # We are backwards compatible with the old Lua interface URLs (form data);
        # a .json suffix selects a JSON request body instead.
        body_type = "form"
        if handler.endswith(".lua"):
            handler = handler[:-4]
        if handler.endswith(".json"):
            body_type = "json"
            handler = handler[:-5]

        # Parse form data if any; unparseable input is reported as a client error.
        try:
            indata = await plugins.formdata.parse_formdata(body_type, request)
        except ValueError as e:
            return aiohttp.web.Response(headers=headers, status=400, text=str(e))
        if self.api_logger:
            self.api_logger.info(indata)

        # Find a handler, or 404
        if handler not in self.handlers:
            return aiohttp.web.Response(
                headers=headers, status=404, text="API Endpoint not found!"
            )

        session = await plugins.session.get_session(self, request)
        try:
            # Wait for endpoint response. This is typically JSON in case of success,
            # but could be an exception (that needs a traceback) OR
            # it could be a custom response, which we just pass along to the client.
            xhandler = self.handlers[handler]
            if isinstance(xhandler, plugins.server.StreamingEndpoint):
                output = await xhandler.exec(self, request, session, indata)
            elif isinstance(xhandler, plugins.server.Endpoint):
                output = await xhandler.exec(self, session, indata)
            else:
                raise TypeError(f"Handler for {handler!r} is not a recognised endpoint type")
            # Hand the connection back as soon as the endpoint is done with it.
            self._release_database(session)
            if isinstance(output, (aiohttp.web.Response, aiohttp.web.StreamResponse)):
                return output
            if output:
                jsout = await self.runners.run(json.dumps, output, indent=2)
                headers["content-type"] = "application/json"
                # json.dumps escapes non-ASCII by default, so len() equals the byte length.
                headers["Content-Length"] = str(len(jsout))
                return aiohttp.web.Response(headers=headers, status=200, text=jsout)
            return aiohttp.web.Response(
                headers=headers, status=404, text="Content not found"
            )
        except Exception:
            # If a handler hit an exception, we need to print that exception somewhere,
            # either to the web client or stderr. Catching Exception (not a bare except)
            # lets asyncio.CancelledError / KeyboardInterrupt propagate as they should.
            err = traceback.format_exc()
            # By default, we print the traceback to the user, for easy debugging.
            if self.config.ui.traceback:
                return aiohttp.web.Response(
                    headers=headers, status=500, text="API error occurred: \n" + err
                )
            # If client traceback is disabled, we print it to stderr instead, but leave an
            # error ID for the client to report back to the admin. Every line of the traceback
            # will have this error ID at the beginning of the line, for easy grepping.
            # We only need a short ID here, let's pick 18 chars.
            eid = str(uuid.uuid4())[:18]
            sys.stderr.write("API Endpoint %s got into trouble (%s): \n" % (request.path, eid))
            for line in err.splitlines():
                sys.stderr.write("%s: %s\n" % (eid, line))
            return aiohttp.web.Response(
                headers=headers, status=500, text="API error occurred. The application journal will have "
                                                  "information. Error ID: %s" % eid
            )
        finally:
            # Covers the error and cancellation paths; a no-op if already released.
            self._release_database(session)

    async def server_loop(self):
        """Serve HTTP on the configured ip/port until the background tasks finish, then shut down."""
        self.server = aiohttp.web.Server(self.handle_request)
        runner = aiohttp.web.ServerRunner(self.server)
        await runner.setup()
        site = aiohttp.web.TCPSite(
            runner, self.config.server.ip, self.config.server.port
        )
        await site.start()
        print(
            "==== Serving up Pony goodness at %s:%s ===="
            % (self.config.server.ip, self.config.server.port)
        )
        await plugins.background.run_tasks(self)
        # Stop the HTTP side first (runner.cleanup() stops every site and shuts the
        # server down) so no new request can pick up a database client that the
        # pool cleanup below is about to close.
        await runner.cleanup()
        await self.cleanup()

    async def cleanup(self):
        """Close the client of every database connection currently idle in the pool.

        Connections checked out by a request at this moment are not in the queue
        and are therefore not closed here.
        """
        while not self.dbpool.empty():
            await self.dbpool.get_nowait().client.close()

    def run(self):
        """Run the server on an asyncio event loop until it stops or is interrupted."""
        # On Python < 3.10 the asyncio Queue/Lock/Event created in __init__ bind to the
        # default loop at construction time, so that same loop must be used here.
        # From 3.10 they bind lazily, and calling get_event_loop() without a running
        # loop is deprecated, so a fresh loop is created instead.
        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.server_loop())
        except KeyboardInterrupt:
            self.background_event.set()
            loop.run_until_complete(self.cleanup())
        finally:
            loop.close()
| |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "--config", |
| help="Configuration file to load (default: ponymail.yaml)", |
| default="ponymail.yaml", |
| ) |
| parser.add_argument( |
| "--logger", |
| help="elasticsearch level (e.g. INFO or DEBUG)", |
| ) |
| parser.add_argument( |
| "--trace", |
| help="elasticsearch.trace level (e.g. INFO or DEBUG)", |
| ) |
| parser.add_argument( |
| "--apilog", |
| help="api log level (e.g. INFO or DEBUG)", |
| ) |
| parser.add_argument( |
| "--stoppable", |
| action='store_true', |
| help="Allow remote stop for testing", |
| ) |
| parser.add_argument( |
| "--refreshable", |
| action='store_true', |
| help="Allow remote refresh for testing", |
| ) |
| parser.add_argument( |
| "--testendpoints", |
| action='store_true', |
| help="Enable test endpoints", |
| ) |
| cliargs = parser.parse_args() |
| Server(cliargs).run() |