blob: 0a2a606c75120a254913aa8893ed1002771a0155 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
``ctx`` proxy server implementation.
"""
import collections
import json
import re
import socket
import threading
import traceback
import Queue
import StringIO
import wsgiref.simple_server
import bottle
from aria import modeling
from .. import exceptions
class CtxProxy(object):
    """Serve ``ctx`` lookups over a local HTTP endpoint.

    Construction starts a bottle application in a background thread, bound to
    ``localhost`` on a port chosen by ``_get_unused_port``; the endpoint URL is exposed
    as ``socket_url``. Each ``POST /`` carries a JSON body whose ``args`` list is
    resolved against ``ctx`` by ``_process_ctx_request``; the reply is a JSON envelope
    ``{'type': ..., 'payload': ...}`` (see :meth:`_process`).

    Use as a context manager, or call :meth:`close`, to stop the server.
    """

    def __init__(self, ctx, ctx_patcher=(lambda *args, **kwargs: None)):
        """
        :param ctx: object that incoming requests are resolved against.
        :param ctx_patcher: callable invoked with ``ctx`` inside the server thread before
            serving starts; a no-op by default.
        :raises Queue.Empty: if the server thread does not signal start-up within 5 seconds.
        """
        self.ctx = ctx
        self._ctx_patcher = ctx_patcher
        self.port = _get_unused_port()
        self.socket_url = 'http://localhost:{0}'.format(self.port)
        # Filled in from the server thread once the WSGI server has been created.
        self.server = None
        # One-slot handshake: the server thread puts True here after setting self.server.
        self._started = Queue.Queue(1)
        self.thread = self._start_server()
        self._started.get(timeout=5)

    def _start_server(self):
        """Start the bottle server in a new thread and return that thread.

        The thread applies ``self._ctx_patcher`` to ``self.ctx`` and then serves
        ``POST /`` with :meth:`_request_handler` on ``localhost:self.port``. Once the
        WSGI server exists, it is stored in ``self.server`` and ``True`` is put on
        ``self._started``.
        """

        class BottleServerAdapter(bottle.ServerAdapter):
            # Class attribute so the adapter instance created by bottle can reach the proxy.
            proxy = self

            def close_session(self):
                self.proxy.ctx.model.log._session.remove()

            def run(self, app):

                class Server(wsgiref.simple_server.WSGIServer):
                    allow_reuse_address = True
                    # `self` here is the enclosing adapter (the `run` scope).
                    bottle_server = self

                    def handle_error(self, request, client_address):
                        # Deliberately silent: suppresses the default traceback printout
                        # for requests that fail while being handled.
                        pass

                    def serve_forever(self, poll_interval=0.5):
                        try:
                            wsgiref.simple_server.WSGIServer.serve_forever(self, poll_interval)
                        finally:
                            # Once shutdown is called, we need to close the session.
                            # If the session is not closed properly, it might raise warnings,
                            # or even lock the database.
                            self.bottle_server.close_session()

                class Handler(wsgiref.simple_server.WSGIRequestHandler):
                    def address_string(self):
                        # Report the bare client IP instead of a resolved host name.
                        return self.client_address[0]

                    def log_request(*args, **kwargs):  # pylint: disable=no-method-argument
                        # `self` is the adapter from the enclosing `run` scope (its `quiet`
                        # flag), not the handler; the handler instance arrives as args[0].
                        if not self.quiet:
                            return wsgiref.simple_server.WSGIRequestHandler.log_request(*args,
                                                                                        **kwargs)

                server = wsgiref.simple_server.make_server(
                    host=self.host,
                    port=self.port,
                    app=app,
                    server_class=Server,
                    handler_class=Handler)
                self.proxy.server = server
                self.proxy._started.put(True)
                server.serve_forever(poll_interval=0.1)

        def serve():
            # Since task is a thread_local object, we need to patch it inside the server thread.
            self._ctx_patcher(self.ctx)
            bottle_app = bottle.Bottle()
            bottle_app.post('/', callback=self._request_handler)
            bottle.run(
                app=bottle_app,
                host='localhost',
                port=self.port,
                quiet=True,
                server=BottleServerAdapter)

        thread = threading.Thread(target=serve)
        thread.start()
        return thread

    def close(self):
        """Stop serving and close the listening socket.

        ``shutdown`` makes ``serve_forever`` return, which triggers the session cleanup
        in the server thread. Does nothing if the server was never created.
        """
        if self.server:
            self.server.shutdown()
            self.server.server_close()

    def _request_handler(self):
        """Handle ``POST /``: process the raw body and reply with JSON (always HTTP 200)."""
        request = bottle.request.body.read()  # pylint: disable=no-member
        response = self._process(request)
        return bottle.LocalResponse(
            body=json.dumps(response, cls=modeling.utils.ModelJSONEncoder),
            status=200,
            headers={'content-type': 'application/json'}
        )

    def _process(self, request):
        """Resolve one JSON-encoded request against ``self.ctx``.

        :param request: raw request body; a JSON object with an ``args`` list.
        :return: one of
            ``{'type': 'result', 'payload': value}`` on success,
            ``{'type': 'stop_operation', 'payload': {'message': ...}}`` when the resolved
            value is a ``ScriptException``, or
            ``{'type': 'error', 'payload': {'type', 'message', 'traceback'}}`` when any
            exception is raised while processing.
        """
        try:
            with self.ctx.model.instrument(*self.ctx.INSTRUMENTATION_FIELDS):
                typed_request = json.loads(request)
                args = typed_request['args']
                payload = _process_ctx_request(self.ctx, args)
                result_type = 'result'
                if isinstance(payload, exceptions.ScriptException):
                    payload = dict(message=str(payload))
                    result_type = 'stop_operation'
                result = {'type': result_type, 'payload': payload}
        except Exception as e:
            # Turn any failure into a structured error payload for the client.
            payload = {
                'type': type(e).__name__,
                'message': str(e),
                'traceback': traceback.format_exc()
            }
            result = {'type': 'error', 'payload': payload}
        return result

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.close()
def _process_ctx_request(ctx, args):
    """Resolve a sequence of request tokens against ``ctx`` and return the outcome.

    Tokens are consumed left to right, starting from ``ctx``:

    * a token naming an attribute of the current value (see ``_desugar_attr``) moves to
      that attribute;
    * otherwise, when the current value is a mutable mapping, the token is a dotted path
      into it: as the last token it reads the value at that path; followed by exactly one
      more token it stores that token at the path and yields ``None``;
    * otherwise, when the current value is callable, it is called with all remaining
      tokens, and a trailing mapping token is passed as keyword arguments.

    A value that is still callable at the end is called with no arguments.

    :param ctx: root object of the lookup.
    :param args: sequence of tokens (attribute names, dict paths, call arguments).
    :raises RuntimeError: if a token cannot be applied to the current value, or more
        than one token follows a dict path.
    """
    target = ctx
    total = len(args)
    for position in range(total):
        token = args[position]

        attr_name = _desugar_attr(target, token)
        if attr_name:
            target = getattr(target, attr_name)
            continue

        if isinstance(target, collections.MutableMapping):
            accessor = _PathDictAccess(target)
            tokens_left = total - position
            if tokens_left == 1:
                # Last token: read the value at this path.
                target = accessor.get(token)
            elif tokens_left == 2:
                # Exactly one value follows the path: write it.
                accessor.set(token, args[position + 1])
                target = None
            else:
                raise RuntimeError('Illegal argument while accessing dict')
            break

        if not callable(target):
            raise RuntimeError('{0} cannot be processed in {1}'.format(token, args))

        call_args = args[position:]
        call_kwargs = {}
        if isinstance(call_args[-1], collections.MutableMapping):
            call_kwargs = call_args[-1]
            call_args = call_args[:-1]
        target = target(*call_args, **call_kwargs)
        break

    if callable(target):
        target = target()
    return target
def _desugar_attr(obj, attr):
    """Map a request token to an attribute name of ``obj``, if one exists.

    The token is tried verbatim first, then with every ``-`` replaced by ``_``, so
    ``'node-template'`` can reach an attribute named ``node_template``.

    :param obj: object whose attributes are probed with ``hasattr``.
    :param attr: request token; non-string tokens never match.
    :return: the matching attribute name, or ``None`` when neither spelling exists.
    """
    if isinstance(attr, basestring):
        for candidate in (attr, attr.replace('-', '_')):
            if hasattr(obj, candidate):
                return candidate
    return None
class _PathDictAccess(object):
    """Read and write values in a nested dict using dotted paths.

    A path is a ``.``-separated list of segments. A plain segment is a dict key. A
    segment of the form ``name[N]`` selects the list stored under ``name`` and then its
    element at index ``N``. For example, with ``{'a': {'b': [1, 2]}}`` the path
    ``'a.b[1]'`` refers to ``2``.
    """

    # ``name[N]`` segment: group 1 is the key, group 2 the list index.
    # NOTE(review): only anchored at the start, so text after the closing ``]`` is ignored.
    pattern = re.compile(r"(.+)\[(\d+)\]")

    def __init__(self, obj):
        self.obj = obj

    def set(self, prop_path, value):
        """Store ``value`` at ``prop_path``.

        Missing intermediate dict keys are created as empty dicts. When the last segment
        is ``name[N]``, element ``N`` of the existing list ``parent[name]`` is replaced.
        This matches how :meth:`get` reads the same path.

        :raises RuntimeError: if an indexed segment (intermediate or last) does not refer
            to an existing list.
        :raises IndexError: if a list index is out of range.
        """
        obj, prop_name = self._get_parent_obj_prop_name_by_path(prop_path)
        match = self.pattern.match(prop_name)
        if match:
            list_name = match.group(1)
            index = int(match.group(2))
            if list_name not in obj or not isinstance(obj[list_name], list):
                self._raise_illegal(prop_path)
            obj[list_name][index] = value
        else:
            obj[prop_name] = value

    def get(self, prop_path):
        """Return the value stored at ``prop_path``.

        :raises RuntimeError: if a segment along the path is missing, or an indexed
            segment does not refer to a list.
        :raises IndexError: if a list index is out of range.
        """
        return self._get_object_by_path(prop_path)

    def _get_object_by_path(self, prop_path, fail_on_missing=True):
        """Walk ``prop_path`` from the root and return the object it reaches.

        :param fail_on_missing: when false, a missing plain (non-indexed) segment is
            created as an empty dict instead of raising. :meth:`set` relies on this to
            build intermediate objects. Indexed segments must always already exist.
        """
        current = self.obj
        for prop_segment in prop_path.split('.'):
            match = self.pattern.match(prop_segment)
            if match:
                index = int(match.group(2))
                property_name = match.group(1)
                if property_name not in current:
                    self._raise_illegal(prop_path)
                if not isinstance(current[property_name], list):
                    self._raise_illegal(prop_path)
                current = current[property_name][index]
            else:
                if prop_segment not in current:
                    if fail_on_missing:
                        self._raise_illegal(prop_path)
                    else:
                        current[prop_segment] = {}
                current = current[prop_segment]
        return current

    def _get_parent_obj_prop_name_by_path(self, prop_path):
        """Split off the last segment of ``prop_path``.

        :return: ``(parent, last_segment)``. Missing intermediate dicts along the way are
            created.
        """
        split = prop_path.split('.')
        if len(split) == 1:
            return self.obj, prop_path
        parent_path = '.'.join(split[:-1])
        parent_obj = self._get_object_by_path(parent_path, fail_on_missing=False)
        prop_name = split[-1]
        return parent_obj, prop_name

    @staticmethod
    def _raise_illegal(prop_path):
        raise RuntimeError('illegal path: {0}'.format(prop_path))
def _get_unused_port():
    """Return a TCP port on 127.0.0.1 that was free at the time of the call.

    Binding to port 0 lets the OS choose a free port. The probe socket is closed
    before returning, even if ``bind``/``getsockname`` fails. Another process may
    still take the port before the caller binds it.
    """
    sock = socket.socket()
    try:
        sock.bind(('127.0.0.1', 0))
        _, port = sock.getsockname()
    finally:
        sock.close()
    return port