blob: a2db0909a83ca2c63f4168dea7523b85efeeadde [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import threading
import requests
import flask
import socket
import SocketServer
import ssl
from collections import defaultdict
from wsgiref.simple_server import make_server
# dict of testid -> {client_request, client_response}
# Backed by a defaultdict, so looking up an unknown testid creates an empty dict for it.
# NOTE(review): nothing in the visible part of this file reads or writes REQUESTS -- confirm it is still used.
REQUESTS = defaultdict(dict)
# TODO: some request/response class to load the various library's implementations and allow for comparison
class TrackingRequests():
    '''
    This class gives you a "requests" like object that will return a dict of:

        - client_request
        - client_response
        - server_request
        - server_response

    assuming the request is going to the instance of DynamicHTTPEndpoint this object
    was created with.

    In general this is useful for a proxy testing framework because you commonly
    need to check that the proxy (for example) added a header to the request
    before the origin got it.
    '''
    def __init__(self, endpoint):
        '''
        endpoint: object providing TRACKING_HEADER, get_tracking_key() and
                  get_tracking_by_key(key) (e.g. a DynamicHTTPEndpoint)
        '''
        self.endpoint = endpoint

    def __getattr__(self, name):
        '''
        Return a wrapper around the `requests` function called `name`.

        The wrapper accepts the same arguments as the wrapped function, tags
        the outgoing request with a fresh tracking key and returns the
        four-entry dict described on the class.

        Raises AttributeError for dunder names so that protocol probing
        (hasattr, copy, truth testing on old-style classes, ...) is not handed
        a request function by accident.
        '''
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        def handlerFunction(*args, **kwargs):
            # resolved at call time, so an unknown name fails when it is used
            func = getattr(requests, name)

            # Work on a copy of the headers: the caller's dict must not be
            # modified, and an explicit headers=None is treated as "none given".
            headers = dict(kwargs.get('headers') or {})

            # set the tracking header
            key = self.endpoint.get_tracking_key()
            headers[self.endpoint.TRACKING_HEADER] = key
            kwargs['headers'] = headers

            resp = func(*args, **kwargs)
            server_resp = self.endpoint.get_tracking_by_key(key)

            # TODO: create intermediate objects that you can compare
            return {
                'client_request': resp.request,
                'client_response': resp,
                'server_request': server_resp['request'],
                'server_response': server_resp['response'],
            }
        return handlerFunction
class DynamicHTTPEndpoint(threading.Thread):
    '''
    A threaded webserver which allows you to dynamically add/remove handlers.

    This is implemented using flask (http://flask.pocoo.org/) primarily because
    it is very common and (almost more importantly) *very* picky about http
    semantics.

    To use this in a TestCase you simply need to create the thread:

        # create the thread object
        http_endpoint = tsqa.endpoint.DynamicHTTPEndpoint(port=cls.endpoint_port)
        # start the thread
        http_endpoint.start()
        # wait for the webserver to listen
        http_endpoint.ready.wait()

    At this point the webserver is listening and returning 404 for all requests.
    To register an endpoint you must (1) define a request-handler function and
    (2) add that handler to the http_endpoint.

    (1): To define a request handler you must create a function which takes a single
    argument which is the Request wrapper (http://werkzeug.pocoo.org/docs/0.10/wrappers/#werkzeug.wrappers.Request).
    Flask supports a variety of return types (http://flask.pocoo.org/docs/0.10/quickstart/#about-responses),
    for this example we will simply return "hello world"

        def handler_func(request):
            return "hello world"

    (2): Now that we have a function, we can add it as a handler to a context path

        http_endpoint.add_handler('/hello', handler_func)
    '''
    TRACKING_HEADER = '__cool_test_header__'  # TODO: better name?

    @property
    def address(self):
        '''
        Return the tuple (server.server_address, server.server_port) of the
        underlying server. Only available once run() has created the server,
        i.e. after `ready` is set.

        NOTE(review): server_address is presumably itself a (host, port) pair,
        so index 0 is that pair rather than a bare ip. Index 1 is the port and
        is what url() relies on. The shape is kept as-is for existing callers.
        '''
        return (self.server.server_address, self.server.server_port)

    def __init__(self, port=0):
        '''
        port: port to listen on; 0 (the default) is passed through to
              make_server unchanged
        '''
        threading.Thread.__init__(self)
        # dict to store request data in
        self._tracked_requests = {}

        self.daemon = True
        self.port = port
        self.ready = threading.Event()

        # dict of pathname (no starting /) -> function
        self._handlers = {}

        self.app = flask.Flask(__name__)
        self.app.debug = True

        @self.app.before_request
        def save_request():
            '''
            If the tracking header is set, save the request
            '''
            tracking_key = flask.request.headers.get(self.TRACKING_HEADER)
            if tracking_key:
                # flask.request is a context-local proxy; keep the concrete
                # request object so it can still be inspected after the
                # request has finished (and from the test's thread).
                self._tracked_requests[tracking_key] = {
                    'request': flask.request._get_current_object(),
                }

        @self.app.after_request
        def save_response(response):
            '''
            If the tracking header is set, save the response
            '''
            tracking_key = flask.request.headers.get(self.TRACKING_HEADER)
            if tracking_key:
                # setdefault: never fail the response just because the
                # request side was not recorded
                self._tracked_requests.setdefault(tracking_key, {})['response'] = response
            return response

        @self.app.route('/', defaults={'path': ''})
        @self.app.route('/<path:path>')
        def catch_all(path=''):
            # get path key
            if path in self._handlers:
                return self._handlers[path](flask.request)

            # return a 404 since we didn't find it
            return ('', 404)

        # A little magic to make flask accept *all* methods on the catch_all path
        for rule in self.app.url_map.iter_rules():
            rule.methods = None
            rule.refresh()

    def get_tracking_key(self):
        '''
        Return a new key for tracking a request by key

        The key is the current number of tracked entries as a string; an empty
        entry is reserved under it immediately.
        '''
        key = str(len(self._tracked_requests))
        self._tracked_requests[key] = {}
        return key

    def get_tracking_by_key(self, key):
        '''
        Return tracking data by key

        Raises Exception if the key is unknown.
        '''
        if key not in self._tracked_requests:
            raise Exception('no tracking data for key {0!r}'.format(key))
        return self._tracked_requests[key]

    def normalize_path(self, path):
        '''
        Normalize the path, since its common (and convenient) to start with / in your paths

        Exactly one leading '/' is removed if present.
        '''
        if path.startswith('/'):
            return path[1:]
        return path

    def add_handler(self, path, func):
        '''
        Add a new handler attached to a specific path

        Raises Exception if a handler is already registered for the path.
        '''
        path = self.normalize_path(path)
        if path in self._handlers:
            raise Exception('a handler is already registered for path {0!r}'.format(path))
        self._handlers[path] = func

    def remove_handler(self, path):
        '''
        remove a handler attached to a specific path

        Raises Exception if no handler is registered for the path.
        '''
        path = self.normalize_path(path)
        if path not in self._handlers:
            raise Exception('no handler is registered for path {0!r}'.format(path))
        del self._handlers[path]

    def clear_handlers(self):
        '''
        Clear all handlers that have been registered
        '''
        self._handlers = {}

    def url(self, path=''):
        '''
        Get the url for the given path in this endpoint

        The host is always 127.0.0.1; the port is taken from `address`, so the
        server must already be running.
        '''
        if path and not path.startswith('/'):
            path = '/' + path
        return 'http://127.0.0.1:{0}{1}'.format(self.address[1], path)

    def run(self):
        '''
        Thread body: create the WSGI server, signal `ready`, serve forever.
        '''
        self.server = make_server('',
                                  self.port,
                                  self.app.wsgi_app)
        # mark the socket as SO_REUSEADDR
        # NOTE(review): make_server has presumably already bound the socket at
        # this point, so this likely does not influence the bind -- confirm.
        self.server.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # mark it as ready
        self.ready.set()
        # serve it
        self.server.serve_forever()
class TrackingWSGIServer(threading.Thread):
    '''
    A threaded webserver which will wrap any wsgi app and track request/response
    headers to the origin

        # create the thread object
        http_endpoint = tsqa.endpoint.TrackingWSGIServer(app)
        # start the thread
        http_endpoint.start()
        # wait for the webserver to listen
        http_endpoint.ready.wait()

    The app must offer flask-style `before_request` / `after_request`
    decorators and a `wsgi_app` attribute; its `debug` flag is switched on.
    '''
    TRACKING_HEADER = '__cool_test_header__'  # TODO: better name?

    @property
    def address(self):
        '''
        Return the tuple (server.server_address, server.server_port) of the
        underlying server. Only available once run() has created the server,
        i.e. after `ready` is set.

        NOTE(review): server_address is presumably itself a (host, port) pair,
        so index 0 is that pair rather than a bare ip; index 1 is the port.
        The shape is kept as-is for existing callers.
        '''
        return (self.server.server_address, self.server.server_port)

    def __init__(self, app, port=0):
        '''
        app:  the application to serve and instrument
        port: port to listen on; 0 (the default) is passed through to
              make_server unchanged
        '''
        threading.Thread.__init__(self)
        # dict to store request data in
        self._tracked_requests = {}

        self.daemon = True
        self.port = port
        self.ready = threading.Event()

        self.app = app
        self.app.debug = True

        @self.app.before_request
        def save_request():
            '''
            If the tracking header is set, save the request
            '''
            tracking_key = flask.request.headers.get(self.TRACKING_HEADER)
            if tracking_key:
                # flask.request is a context-local proxy; keep the concrete
                # request object so it can still be inspected after the
                # request has finished (and from the test's thread).
                self._tracked_requests[tracking_key] = {
                    'request': flask.request._get_current_object(),
                }

        @self.app.after_request
        def save_response(response):
            '''
            If the tracking header is set, save the response
            '''
            tracking_key = flask.request.headers.get(self.TRACKING_HEADER)
            if tracking_key:
                # setdefault: never fail the response just because the
                # request side was not recorded
                self._tracked_requests.setdefault(tracking_key, {})['response'] = response
            return response

    def get_tracking_key(self):
        '''
        Return a new key for tracking a request by key

        The key is the current number of tracked entries as a string; an empty
        entry is reserved under it immediately.
        '''
        key = str(len(self._tracked_requests))
        self._tracked_requests[key] = {}
        return key

    def get_tracking_by_key(self, key):
        '''
        Return tracking data by key

        Raises Exception if the key is unknown.
        '''
        if key not in self._tracked_requests:
            raise Exception('no tracking data for key {0!r}'.format(key))
        return self._tracked_requests[key]

    def run(self):
        '''
        Thread body: create the WSGI server, signal `ready`, serve forever.
        '''
        self.server = make_server('',
                                  self.port,
                                  self.app.wsgi_app)
        # mark it as ready
        self.ready.set()
        # serve it
        self.server.serve_forever()
class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    '''
    TCPServer combined with ThreadingMixIn; adds nothing of its own.
    '''
class SocketServerDaemon(threading.Thread):
    '''
    A daemon thread to run a socketserver

    Wait on `ready` after start(); from then on `port` holds the port that
    was actually bound and `server` the running server object.
    '''
    def __init__(self, handler, port=0):
        '''
        handler: request handler class handed to ThreadedTCPServer
        port: port to bind on 0.0.0.0; with 0 the bound port is read back
              from the socket in run()
        '''
        threading.Thread.__init__(self)
        self.port = port
        self.handler = handler
        self.ready = threading.Event()
        self.daemon = True

    def run(self):
        '''
        Thread body: bind the server, publish the bound port, serve forever.
        '''
        # allow_reuse_address is consulted while binding, so it has to be set
        # before the socket is bound: build the server unbound, set the flag,
        # then bind and start listening explicitly.
        server = ThreadedTCPServer(('0.0.0.0', self.port), self.handler,
                                   bind_and_activate=False)
        server.allow_reuse_address = True
        try:
            server.server_bind()
            server.server_activate()
        except Exception:
            # do not leave the socket open if the bind/listen failed
            server.server_close()
            raise
        self.server = server
        self.port = self.server.socket.getsockname()[1]
        self.ready.set()
        # Activate the server; this will keep running until you
        # interrupt the program with Ctrl-C
        self.server.serve_forever()
class ThreadedSSLTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    '''
    Threading TCP server whose accepted connections are wrapped in
    server-side SSL before being handed to the request handler.
    '''
    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 certfile,
                 keyfile,
                 ssl_version=ssl.PROTOCOL_TLSv1,
                 bind_and_activate=True):
        '''
        server_address, RequestHandlerClass and bind_and_activate go straight
        to TCPServer; certfile, keyfile and ssl_version are remembered for
        wrapping each accepted connection.
        '''
        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
        self.certfile, self.keyfile, self.ssl_version = certfile, keyfile, ssl_version

    def get_request(self):
        '''
        Accept one connection and return (ssl-wrapped socket, peer address).
        '''
        raw_conn, peer = self.socket.accept()
        tls_options = dict(server_side=True,
                           certfile=self.certfile,
                           keyfile=self.keyfile,
                           ssl_version=self.ssl_version)
        return ssl.wrap_socket(raw_conn, **tls_options), peer
class SSLSocketServerDaemon(threading.Thread):
    '''
    A daemon thread to run a socketserver

    This is just a thread wrapper to https://docs.python.org/2/library/socketserver.html

    Wait on `ready` after start(); from then on `port` holds the port that
    was actually bound and `server` the running server object.
    '''
    def __init__(self, handler, cert, key, port=0):
        '''
        handler: request handler class (it is handed to the server as its
                 RequestHandlerClass), see
                 https://docs.python.org/2/library/socketserver.html#socketserver-tcpserver-example
        cert: path to certificate file
        key: path to key file
        port: port to bind on 0.0.0.0; with 0 the bound port is read back
              from the socket in run()
        '''
        # for testing it is *very* common to have self-signed certs, so we
        # will disable warnings so we don't flood logs
        requests.packages.urllib3.disable_warnings()

        threading.Thread.__init__(self)
        self.handler = handler
        self.cert = cert
        self.key = key
        self.port = port
        self.ready = threading.Event()
        self.daemon = True

    def run(self):
        '''
        Thread body: bind the SSL server, publish the bound port, serve forever.
        '''
        # allow_reuse_address is consulted while binding, so it has to be set
        # before the socket is bound: build the server unbound, set the flag,
        # then bind and start listening explicitly.
        server = ThreadedSSLTCPServer(('0.0.0.0', self.port),
                                      self.handler,
                                      self.cert,
                                      self.key,
                                      bind_and_activate=False,
                                      )
        server.allow_reuse_address = True
        try:
            server.server_bind()
            server.server_activate()
        except Exception:
            # do not leave the socket open if the bind/listen failed
            server.server_close()
            raise
        self.server = server
        self.port = self.server.socket.getsockname()[1]
        self.ready.set()
        self.server.serve_forever()