| |
| try: |
| import Queue |
| except ImportError: |
| # pylint: disable=F0401 |
| # http is a Python3 module, replacing httplib. Ditto. |
| import queue as Queue |
| import threading |
| |
| try: |
| import httplib |
| except ImportError: |
| # pylint: disable=F0401 |
| from http import client as httplib |
| |
| try: |
| from urllib import urlencode |
| except ImportError: |
| # pylint: disable=F0401,E0611 |
| from urllib.parse import urlencode |
| |
| import datetime |
| import json |
| import logging |
| |
# Python 2/3 compatibility: Python 3's builtin ``range`` is already lazy, so
# expose it as ``xrange``; under Python 2 the builtin ``xrange`` is used as is.
try:
    xrange
except NameError:
    xrange = range

# Extra attempts made for a failed HTTP request after the first one
# (0 disables retrying).
MAX_RETRY = 1

# Debug logging stays off until enable_log() binds a logger and sets the flag.
logger, DEBUG_LOG = None, False
| |
| |
def enable_log(filename=None):
    """Turn on debug logging for this module.

    Calls ``logging.basicConfig`` to send DEBUG-level records to a file
    (truncated on open), binds the module-level ``logger`` and sets
    ``DEBUG_LOG`` so the request code starts emitting debug messages.

    Args:
      filename: path of the log file. When falsy, a timestamped file under
        ``./log/`` is used. Missing parent directories are created.

    NOTE: ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, in which case no log file is opened by this call.
    NOTE(review): the default file name contains ':' characters, which are
    not allowed in Windows file names.
    """
    global logger
    global DEBUG_LOG
    import os  # local import: only this function needs it

    if not filename:
        timestamp = datetime.datetime.today()
        logfile = "./log/predictionio_%s.log" % timestamp.strftime(
            "%Y-%m-%d_%H:%M:%S.%f")
    else:
        logfile = filename

    # The file handler created by basicConfig fails when the parent directory
    # does not exist (the default ./log/ is usually absent), so create it.
    log_dir = os.path.dirname(logfile)
    if log_dir and not os.path.isdir(log_dir):
        try:
            os.makedirs(log_dir)
        except OSError:
            # Tolerate a concurrent creator; re-raise any real failure.
            if not os.path.isdir(log_dir):
                raise

    logging.basicConfig(filename=logfile,
                        filemode='w',
                        level=logging.DEBUG,
                        format='[%(levelname)s] %(name)s (%(threadName)s) %(message)s')
    logger = logging.getLogger(__name__)
    DEBUG_LOG = True
| |
| |
class PredictionIOAPIError(Exception):
    """Root of this module's exception hierarchy."""
| |
| |
class NotSupportMethodError(PredictionIOAPIError):
    """Reported (as AsyncResponse.error) for a request method the connection
    worker does not handle."""
| |
| |
class ProgramError(PredictionIOAPIError):
    """Internal programming/misuse error (not raised in this module; presumably
    raised by other parts of the package — verify)."""
| |
| |
class AsyncRequest(object):

    """A request handed to a connection_worker thread.

    The worker delivers its AsyncResponse through ``set_response``. The caller
    retrieves it (blocking) with ``get_response``, optionally post-processed
    by the function registered with ``set_rfunc``.
    """

    def __init__(self, method, path, **params):
        # HTTP method name: "GET", "POST", "DELETE" (or "KILL" to stop a worker)
        self.method = method
        # The sub path, e.g. POST /v1/users.json, GET /v1/users/1.json
        self.path = path
        # Request parameters, e.g. {"appkey": 123, "id": 3}
        self.params = params
        # One-slot queue through which the worker hands back the AsyncResponse
        self.response_q = Queue.Queue(1)
        # Path with params URL-encoded as a query string
        self.qpath = "%s?%s" % (self.path, urlencode(self.params))
        # Final (possibly rfunc-processed) response; valid once _response_ready
        # is True. A flag is used because the final value may itself be None.
        self._response = None
        self._response_ready = False
        # Raw response taken off response_q. It is cached so that a failing
        # rfunc can be re-run instead of blocking on the drained queue.
        self._raw_response = None
        self._raw_received = False
        # Optional function applied to the raw response by get_response()
        self.rfunc = None

    def __str__(self):
        return "%s %s %s %s" % (self.method, self.path, self.params,
                                self.qpath)

    def set_rfunc(self, func):
        """Register ``func(raw_response)`` to post-process the response."""
        self.rfunc = func

    def set_response(self, response):
        """Store the raw response (called by the worker).

        NOTE: Must be called only once; response_q holds a single item, so a
        second call blocks until the first one has been consumed.
        """
        self.response_q.put(response)

    def get_response(self):
        """Return the response, blocking until the worker has delivered it.

        The result is computed once and cached, so repeated calls never block
        again. This holds even when rfunc returns None. If rfunc raises, the
        exception propagates, and the next call re-runs rfunc on the cached
        raw response.

        :returns: the raw response if no rfunc is set, else rfunc's result.
        """
        if not self._response_ready:
            if not self._raw_received:
                self._raw_response = self.response_q.get(True)  # NOTE: blocking
                self._raw_received = True
            if self.rfunc is None:
                self._response = self._raw_response
            else:
                self._response = self.rfunc(self._raw_response)
            self._response_ready = True
        return self._response
| |
| |
class AsyncResponse(object):
    """Store the response of an asynchronous request.

    When reading the response, first check that ``error`` is None (meaning no
    exception happened). If it is, then check that ``status`` is the expected
    one.
    """

    def __init__(self):
        #: exception object if any happened
        self.error = None

        #: HTTP version, status code and reason phrase from the server
        self.version = None
        self.status = None
        self.reason = None
        #: Response headers as a dict (built from the header tuples by the
        #: connection code)
        self.headers = None
        #: Raw response body (bytes under Python 3, str under Python 2)
        self.body = None
        #: JSON-decoded response body. Remains None if decoding fails.
        self.json_body = None
        #: Points back to the AsyncRequest object
        self.request = None

    def __str__(self):
        return "e:%s v:%s s:%s r:%s h:%s b:%s" % (self.error, self.version,
                                                  self.status, self.reason,
                                                  self.headers, self.body)

    def set_resp(self, version, status, reason, headers, body):
        """Record the server's response and try to decode its body as JSON.

        ``json_body`` is left as None when the body is not valid UTF-8 JSON,
        or is not a string/bytes at all (e.g. None).
        """
        self.version = version
        self.status = status
        self.reason = reason
        self.headers = headers
        self.body = body
        try:
            # Only byte strings need decoding; text is passed through as is.
            text = body.decode('utf8') if isinstance(body, bytes) else body
            self.json_body = json.loads(text)
        except (TypeError, ValueError):
            # ValueError covers malformed JSON and UnicodeDecodeError;
            # TypeError covers non-string bodies such as None.
            self.json_body = None

    def set_error(self, error):
        self.error = error

    def set_request(self, request):
        self.request = request
| |
| |
class PredictionIOHttpConnection(object):
    """A single HTTP(S) connection to the PredictionIO server.

    Wraps ``httplib.HTTPConnection`` / ``HTTPSConnection`` and adds JSON
    request bodies, keep-alive headers and a bounded retry loop (MAX_RETRY).
    connection_worker creates one instance per worker thread.
    """

    def __init__(self, host, https=True, timeout=5):
        if https:  # https connection
            self._connection = httplib.HTTPSConnection(host, timeout=timeout)
        else:
            self._connection = httplib.HTTPConnection(host, timeout=timeout)

    def connect(self):
        self._connection.connect()

    def close(self):
        self._connection.close()

    def request(self, method, url, body=None, headers=None):
        """HTTP request wrapper with retry on error.

        Exceptions are not raised. They are caught and stored in the returned
        AsyncResponse, whose ``error`` attribute is set when every attempt
        failed.

        NOTE: every method is retried, so a POST whose response was lost may
        reach the server twice.

        Args:
          method: HTTP method, type str
          url: URL path (including any query string), type str
          body: request body content, type dict; sent as JSON when non-empty
          headers: extra request headers, type dict (copied, never mutated)

        Returns:
          AsyncResponse object
        """
        response = AsyncResponse()

        try:
            mod_headers = dict(headers) if headers else {}  # copy the headers
            mod_headers["Connection"] = "keep-alive"
            enc_body = None
            if body:  # if body is not empty
                # json.dumps escapes non-ASCII by default, so the str body is
                # safe under http.client's ISO-8859-1 encoding of str bodies.
                enc_body = json.dumps(body)
                mod_headers["Content-type"] = "application/json"
        except Exception as e:  # e.g. body not JSON-serializable
            response.set_error(e)
            return response

        if DEBUG_LOG:
            logger.debug("Request m:%s u:%s h:%s b:%s", method, url,
                         mod_headers, enc_body)

        # number of retries in case of error (0 means no retry)
        retry_limit = MAX_RETRY
        for attempt in range(retry_limit + 1):
            if attempt != 0 and DEBUG_LOG:
                logger.debug("retry request %s times", attempt)
            try:
                if self._connection.sock is None:
                    self._connection.connect()
                self._connection.request(method, url, enc_body, mod_headers)
                resp = self._connection.getresponse()
                # The body must be read completely before the next request can
                # be sent on this keep-alive connection. read() itself may fail
                # (timeout, IncompleteRead), so it is covered by the retry too.
                resp_body = resp.read()
            except Exception as e:
                # Drop the possibly half-used socket; the next attempt
                # reconnects.
                self._connection.close()
                if attempt == retry_limit:
                    response.set_error(e)
                continue
            # resp.getheaders() returns a list of tuples; store it as a dict.
            response.set_resp(version=resp.version, status=resp.status,
                              reason=resp.reason,
                              headers=dict(resp.getheaders()),
                              body=resp_body)
            break  # success: exit retry loop

        if DEBUG_LOG:
            logger.debug("Response %s", response)
        return response  # AsyncResponse object
| |
| |
def connection_worker(host, request_queue, https=True, timeout=5, loop=True):
    """Worker function: own one connection and serve AsyncRequests from a queue.

    Each job taken from ``request_queue`` is executed on this worker's
    connection. The resulting AsyncResponse is linked back to the request,
    delivered through ``request.set_response``, and the job is marked done.

    Args:
      host: server host passed to PredictionIOHttpConnection.
      request_queue: the queue storing AsyncRequest objects. Valid methods:
        GET, POST, DELETE, and KILL (acknowledge and stop this worker).
        Any other method is answered with a NotSupportMethodError.
      https: HTTPS (True) or HTTP (False)
      timeout: timeout for HTTP connection attempts and requests in seconds
      loop: when False, the worker handles one request and returns.
        For testing purposes only; should normally be True.
    """
    connect = PredictionIOHttpConnection(host, https, timeout)

    # With loop=False the worker exits after its first job.
    killed = not loop

    try:
        while True:
            request = request_queue.get(True)  # NOTE: blocking get
            try:
                method = request.method
                if method == "GET":
                    response = connect.request("GET", request.qpath)
                elif method == "POST":
                    response = connect.request("POST", request.path,
                                               request.params)
                elif method == "DELETE":
                    response = connect.request("DELETE", request.qpath)
                elif method == "KILL":
                    # tell the thread to kill the connection
                    killed = True
                    response = AsyncResponse()
                else:
                    response = AsyncResponse()
                    response.set_error(NotSupportMethodError(
                        "Don't Support the method %s" % method))
            except Exception as e:
                # Never let an unexpected error end the thread silently.
                # Otherwise the caller blocks forever in get_response(), and
                # q.join() in Connection.close() never returns because
                # task_done() is not called.
                response = AsyncResponse()
                response.set_error(e)

            response.set_request(request)
            request.set_response(response)
            request_queue.task_done()
            if killed:
                break
    finally:
        connect.close()
| |
| |
class Connection(object):

    """Client-side connection pool to the server.

    Spawns ``threads`` daemon connection_worker threads that consume
    AsyncRequest jobs from the shared queue ``q``.
    """

    def __init__(self, host, threads=1, qsize=0, https=True, timeout=5):
        """constructor

        Args:
          host: host of the server.
          threads: type int, number of worker threads to spawn
          qsize: maximum size of the queue q (0 means unbounded)
          https: httpS (True) or plain http connection (False)
          timeout: timeout for HTTP connection attempts and requests in
            seconds
        """
        self.host = host
        self.https = https
        self.q = Queue.Queue(qsize)  # if qsize=0, means infinite
        self.threads = threads
        self.timeout = timeout
        self.tid = {}  # thread index -> threading.Thread
        self._closed = False

        for i in range(threads):
            worker = threading.Thread(
                target=connection_worker, name="PredictionIOThread-%s" % i,
                kwargs={'host': self.host, 'request_queue': self.q,
                        'https': self.https, 'timeout': self.timeout})
            # Attribute form; Thread.setDaemon() is deprecated since 3.10.
            # Daemon threads do not keep the interpreter alive at exit.
            worker.daemon = True
            self.tid[i] = worker
            worker.start()

    def make_request(self, request):
        """Put the request into the queue q (blocks while a bounded q is full)."""
        self.q.put(request)

    def pending_requests(self):
        """Number of requests currently waiting in the queue (approximate)."""
        return self.q.qsize()

    def close(self):
        """Stop all workers and wait for them. Call this when the program exits.

        Safe to call more than once. After the first call the workers are
        gone, and enqueueing more KILL jobs would make q.join() wait forever.
        """
        if getattr(self, "_closed", False):
            return
        self._closed = True

        # one KILL message per worker; each worker stops after handling one
        for _ in range(self.threads):
            self.make_request(AsyncRequest("KILL", ""))

        self.q.join()  # wait until every queued job has been processed

        for worker in self.tid.values():  # wait for all threads to finish
            worker.join()