blob: 660761c4b7feb5411cc8392fec6f1af08926717e [file] [log] [blame]
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Apache Pony Mail, Codename Foal - A Python variant of Pony Mail"""
import argparse
import asyncio
import importlib
import json
import os
import sys
from time import sleep
import traceback
import typing
import aiohttp.web
import yaml
import uuid
import plugins.background
import plugins.configuration
import plugins.database
import plugins.formdata
import plugins.offloader
import plugins.server
import plugins.session
PONYMAIL_FOAL_VERSION = "0.1.0"
from server_version import PONYMAIL_SERVER_VERSION
# Certain environments such as MinGW-w64 will not register as a TTY and use buffered output.
# In such cases, we need to force a flush of each print, or nothing will show.
# The check is made on the text stream itself rather than on ``sys.stdout.buffer``:
# a replaced stdout (for example an ``io.StringIO``) has no ``buffer`` attribute,
# and ``sys.stdout`` can be ``None`` when the process has no console attached.
if sys.stdout is None or not sys.stdout.isatty():
    import functools

    print = functools.partial(print, flush=True)
class Server(plugins.server.BaseServer):
    """Main server class, responsible for handling requests and scheduling offloader threads"""

    def _load_endpoint(self, subdir):
        """Import every ``*.py`` file found in *subdir* and register it as an API endpoint.

        Each module is imported as ``<subdir>.<name>``. If it exposes a ``register``
        attribute, that is called with this server and the result is stored in
        ``self.handlers`` under the module name; otherwise the file is skipped
        with a notice on stdout.
        """
        for endpoint_file in sorted(os.listdir(subdir)):
            if not endpoint_file.endswith(".py"):
                continue
            endpoint = endpoint_file[:-3]
            module = importlib.import_module(f"{subdir}.{endpoint}")
            register = getattr(module, "register", None)
            if register is None:
                print(
                    f"Could not find entry point 'register()' in {endpoint_file}, skipping!"
                )
                continue
            self.handlers[endpoint] = register(self)
            print(f"Registered endpoint /api/{endpoint}")

    @staticmethod
    def _enable_logger(name: str, level: str):
        """Set *level* on the logger called *name*, attach a stream handler, and return it."""
        import logging

        logger = logging.getLogger(name)
        logger.setLevel(level)
        logger.addHandler(logging.StreamHandler())
        return logger

    def _release_database(self, session) -> None:
        """Hand the session's database connection (if it holds one) back to the pool.

        Safe to call more than once: the session's reference is cleared after the
        first call, so later calls are no-ops.
        """
        if session.database:
            self.dbpool.put_nowait(session.database)
            # Every put on an asyncio.Queue counts as an unfinished task;
            # mark it done straight away so the counter stays balanced.
            self.dbpool.task_done()
            session.database = None

    def __init__(self, args: argparse.Namespace):
        """Load the configuration, fill the database pool and register the endpoints.

        :param args: parsed command line; reads ``config``, ``testendpoints``,
            ``logger``, ``trace``, ``apilog``, ``stoppable`` and ``refreshable``.
        :raises ValueError: if the configured database pool size is below 1.
        """
        print(
            "==== Apache Pony Mail (Foal v/%s ~%s) starting... ====" % (PONYMAIL_FOAL_VERSION, PONYMAIL_SERVER_VERSION)
        )
        # Load configuration; the context manager makes sure the file is closed again
        with open(args.config) as config_file:
            yml = yaml.safe_load(config_file)
        self.config = plugins.configuration.Configuration(yml)
        self.data = plugins.configuration.InterData()
        self.handlers = dict()
        self.dbpool = asyncio.Queue()
        self.runners = plugins.offloader.ExecutorPool()
        self.server = None
        self.streamlock = asyncio.Lock()
        self.api_logger = None
        self.foal_version = PONYMAIL_FOAL_VERSION
        self.server_version = PONYMAIL_SERVER_VERSION
        self.stoppable = False  # allow remote stop for tests
        self.background_event = asyncio.Event()  # for background task to wait on

        # Make a pool of database connections for async queries
        pool_size = self.config.database.pool_size
        if pool_size < 1:
            raise ValueError(f"pool_size {pool_size} must be > 0")
        for _ in range(pool_size):
            self.dbpool.put_nowait(plugins.database.Database(self.config.database))

        # Load each URL endpoint
        if args.testendpoints:
            print("** Loading additional testing endpoints **")
            self._load_endpoint("testendpoints")
            print()
        self._load_endpoint("endpoints")

        # Optional logging, each enabled by passing a level on the command line
        if args.logger:
            self._enable_logger('elasticsearch', args.logger)
        if args.trace:
            self._enable_logger('elasticsearch.trace', args.trace)
        if args.apilog:
            self.api_logger = self._enable_logger('ponymail.apilog', args.apilog)

        self.stoppable = args.stoppable
        self.refreshable = args.refreshable

    async def handle_request(
        self, request: aiohttp.web.BaseRequest
    ) -> typing.Union[aiohttp.web.Response, aiohttp.web.StreamResponse]:
        """Generic handler for all incoming HTTP requests

        The third path component (``/api/<handler>/...``) selects the endpoint.
        A trailing ``.lua`` or ``.json`` is stripped from it; ``.json`` also
        switches body parsing from form data to JSON.

        Responses produced here:
          * 404 - path too short, unknown endpoint, or the endpoint returned nothing
          * 400 - the request body could not be parsed (``ValueError``)
          * 200 - JSON-serialised endpoint output, or the test-only stop/refresh replies
          * 500 - the endpoint raised an exception
        A response object returned by the endpoint itself is passed through untouched.
        """
        # Define response headers first...
        headers = {
            "Server": "Apache Pony Mail (Foal/%s ~%s)" % (PONYMAIL_FOAL_VERSION, PONYMAIL_SERVER_VERSION),
        }
        if self.api_logger:
            self.api_logger.info(request.raw_path)

        # Figure out who is going to handle this request, if any
        # We are backwards compatible with the old Lua interface URLs
        body_type = "form"
        # Support URLs of form /api/handler/extra?query
        parts = request.path.split("/")
        if len(parts) < 3:
            return aiohttp.web.Response(
                headers=headers, status=404, text="API Endpoint not found!"
            )
        handler = parts[2]

        # handle test requests
        if self.stoppable and handler == 'stop':
            self.background_event.set()
            return aiohttp.web.Response(headers=headers, status=200, text='Stop requested\n')
        if self.refreshable and handler == 'refresh':
            await plugins.background.get_data(self)
            return aiohttp.web.Response(headers=headers, status=200, text='Refresh performed\n')

        if handler.endswith(".lua"):
            body_type = "form"
            handler = handler[:-4]
        if handler.endswith(".json"):
            body_type = "json"
            handler = handler[:-5]

        # Parse form data if any
        try:
            indata = await plugins.formdata.parse_formdata(body_type, request)
            if self.api_logger:
                self.api_logger.info(indata)
        except ValueError as e:
            return aiohttp.web.Response(headers=headers, status=400, text=str(e))

        # Find a handler, or 404
        if handler not in self.handlers:
            return aiohttp.web.Response(
                headers=headers, status=404, text="API Endpoint not found!"
            )

        session = await plugins.session.get_session(self, request)
        try:
            # Wait for endpoint response. This is typically JSON in case of success,
            # but could be an exception (that needs a traceback) OR
            # it could be a custom response, which we just pass along to the client.
            xhandler = self.handlers[handler]
            if isinstance(xhandler, plugins.server.StreamingEndpoint):
                output = await xhandler.exec(self, request, session, indata)
            elif isinstance(xhandler, plugins.server.Endpoint):
                output = await xhandler.exec(self, session, indata)
            else:
                # Fail with a readable message instead of an unbound local further down
                raise TypeError(
                    "Handler for endpoint %r has unsupported type %s"
                    % (handler, type(xhandler).__name__)
                )
            # Give the connection back before serialising, so it is not held needlessly
            self._release_database(session)
            if isinstance(output, (aiohttp.web.Response, aiohttp.web.StreamResponse)):
                return output
            if output:
                jsout = await self.runners.run(json.dumps, output, indent=2)
                headers["content-type"] = "application/json"
                headers["Content-Length"] = str(len(jsout))
                return aiohttp.web.Response(headers=headers, status=200, text=jsout)
            return aiohttp.web.Response(
                headers=headers, status=404, text="Content not found"
            )
        # If a handler hit an exception, we need to print that exception somewhere,
        # either to the web client or stderr.
        # Only Exception is caught: task cancellation, KeyboardInterrupt and SystemExit
        # must propagate rather than be turned into a 500 response.
        except Exception:
            exc_type, exc_value, exc_traceback = sys.exc_info()
            err = "\n".join(
                traceback.format_exception(exc_type, exc_value, exc_traceback)
            )
            # If enabled in the UI configuration, the traceback goes to the user, for easy debugging.
            if self.config.ui.traceback:
                return aiohttp.web.Response(
                    headers=headers, status=500, text="API error occurred: \n" + err
                )
            # If client traceback is disabled, we print it to stderr instead, but leave an
            # error ID for the client to report back to the admin. Every line of the traceback
            # will have this error ID at the beginning of the line, for easy grepping.
            # We only need a short ID here, let's pick 18 chars.
            eid = str(uuid.uuid4())[:18]
            sys.stderr.write("API Endpoint %s got into trouble (%s): \n" % (request.path, eid))
            for line in err.split("\n"):
                sys.stderr.write("%s: %s\n" % (eid, line))
            return aiohttp.web.Response(
                headers=headers, status=500, text="API error occurred. The application journal will have "
                                                  "information. Error ID: %s" % eid
            )
        finally:
            # Runs on success, failure and cancellation alike; a no-op if already released
            self._release_database(session)

    async def server_loop(self):
        """Start the HTTP listener, run the background tasks, then shut down.

        Control returns once ``plugins.background.run_tasks`` finishes, after which
        the pooled database clients are closed and the site is stopped.
        """
        self.server = aiohttp.web.Server(self.handle_request)
        runner = aiohttp.web.ServerRunner(self.server)
        await runner.setup()
        site = aiohttp.web.TCPSite(
            runner, self.config.server.ip, self.config.server.port
        )
        await site.start()
        print(
            "==== Serving up Pony goodness at %s:%s ===="
            % (self.config.server.ip, self.config.server.port)
        )
        await plugins.background.run_tasks(self)
        await self.cleanup()
        await site.stop()  # try to clean up

    async def cleanup(self):
        """Close the client of every database connection currently sitting in the pool."""
        while not self.dbpool.empty():
            await self.dbpool.get_nowait().client.close()

    def run(self):
        """Run the server until the background tasks end or the user interrupts it."""
        # get_event_loop is deprecated in 3.10, but the replacement new_event_loop
        # does not seem to work properly in earlier versions
        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.server_loop())
        except KeyboardInterrupt:
            self.background_event.set()
            loop.run_until_complete(self.cleanup())
        finally:
            # Close the loop however we got here
            loop.close()
if __name__ == "__main__":
    # Command line entry point: build the option parser, then start the server.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="ponymail.yaml",
        help="Configuration file to load (default: ponymail.yaml)",
    )

    # Options that take a log level as their value
    level_options = (
        ("--logger", "elasticsearch level (e.g. INFO or DEBUG)"),
        ("--trace", "elasticsearch.trace level (e.g. INFO or DEBUG)"),
        ("--apilog", "api log level (e.g. INFO or DEBUG)"),
    )
    for option, description in level_options:
        parser.add_argument(option, help=description)

    # Plain on/off switches
    switch_options = (
        ("--stoppable", "Allow remote stop for testing"),
        ("--refreshable", "Allow remote refresh for testing"),
        ("--testendpoints", "Enable test endpoints"),
    )
    for option, description in switch_options:
        parser.add_argument(option, action='store_true', help=description)

    cliargs = parser.parse_args()
    Server(cliargs).run()